diff --git a/src/fsharp/FSharp.Core/control.fs b/src/fsharp/FSharp.Core/control.fs index 6be535a316e..23b27a716bc 100644 --- a/src/fsharp/FSharp.Core/control.fs +++ b/src/fsharp/FSharp.Core/control.fs @@ -494,6 +494,25 @@ namespace Microsoft.FSharp.Control // delayPrim = "bindA (return ()) f" let delayA f = callA f () + // protect an exception in an cancellation workflow and adds it to the given OperationCanceledException + // ie creates a new instance, where the new information is added, + // if cexn is a TaskCanceledException the type is preserved as well. + let augmentOperationCancelledException (cexn:OperationCanceledException) (edi:ExceptionDispatchInfo) = + // It's probably ok to not care about the stack of the cexn object, because + // 1. we often don't even collect the stack in the ccont route + // 2. there are no suited APIs to handle this (at best we could add the original instance to the new instance...) + // If we ever need this we probably need to provide our own sub-types of OperationCanceledException + // and TaskCanceledException + let exn = edi.GetAssociatedSourceException() + let collected = + match cexn.InnerException with + | null -> [|exn|] + | :? AggregateException as a -> Array.append (a.Flatten().InnerExceptions |> Seq.toArray) [|exn|] + | inner -> [|inner; exn|] + let aggr = (new AggregateException(collected)).Flatten() + match cexn with + | :? TaskCanceledException -> new TaskCanceledException(cexn.Message, aggr) :> OperationCanceledException + | _ -> new OperationCanceledException(cexn.Message, aggr) // Call p but augment the normal, exception and cancel continuations with a call to finallyFunction. // If the finallyFunction raises an exception then call the original exception continuation // with the new exception. If exception is raised after a cancellation, exception is ignored @@ -511,8 +530,8 @@ namespace Microsoft.FSharp.Control // If an exception is thrown we continue with the previous exception continuation. let econt exn = protect trampolineHolder args.aux.econt finallyFunction () (fun () -> args.aux.econt exn) // The cancellation continuation runs the finallyFunction and then runs the previous cancellation continuation. - // If an exception is thrown we continue with the previous cancellation continuation (the exception is lost) - let ccont cexn = protect trampolineHolder (fun _ -> args.aux.ccont cexn) finallyFunction () (fun () -> args.aux.ccont cexn) + // If an exception is thrown we collect/protect it in the OperationCancelledException + let ccont cexn = protect trampolineHolder (augmentOperationCancelledException cexn >> args.aux.ccont) finallyFunction () (fun () -> args.aux.ccont cexn) invokeA p { args with cont = cont; aux = { args.aux with econt = econt; ccont = ccont } }) // Re-route the exception continuation to call to catchFunction. If catchFunction or the new process fail @@ -531,7 +550,7 @@ namespace Microsoft.FSharp.Control /// Call the finallyFunction if the computation results in a cancellation let whenCancelledA (finallyFunction : OperationCanceledException -> unit) p = unprotectedPrimitive (fun ({ aux = aux } as args)-> - let ccont exn = protect aux.trampolineHolder (fun _ -> aux.ccont exn) finallyFunction exn (fun _ -> aux.ccont exn) + let ccont exn = protect aux.trampolineHolder (augmentOperationCancelledException exn >> aux.ccont) finallyFunction exn (fun _ -> aux.ccont exn) invokeA p { args with aux = { aux with ccont = ccont } }) let getCancellationToken() = @@ -866,7 +885,11 @@ namespace Microsoft.FSharp.Control queueAsync token (fun res -> resultCell.RegisterResult(Ok(res),reuseThread=true)) - (fun edi -> resultCell.RegisterResult(Error(edi),reuseThread=true)) + (fun edi -> + let result = + if token.IsCancellationRequested then Canceled(augmentOperationCancelledException (new OperationCanceledException()) edi) + else Error(edi) + resultCell.RegisterResult(result,reuseThread=true)) (fun exn -> resultCell.RegisterResult(Canceled(exn),reuseThread=true)) computation |> unfake @@ -897,7 +920,11 @@ namespace Microsoft.FSharp.Control token trampolineHolder (fun res -> resultCell.RegisterResult(Ok(res),reuseThread=true)) - (fun edi -> resultCell.RegisterResult(Error(edi),reuseThread=true)) + (fun edi -> + let result = + if token.IsCancellationRequested then Canceled(augmentOperationCancelledException (new OperationCanceledException()) edi) + else Error(edi) + resultCell.RegisterResult(result,reuseThread=true)) (fun exn -> resultCell.RegisterResult(Canceled(exn),reuseThread=true)) computation) |> unfake @@ -939,9 +966,15 @@ namespace Microsoft.FSharp.Control member __.Proceed = not isStopped member __.Stop() = isStopped <- true - let StartAsTask (token:CancellationToken, computation : Async<_>,taskCreationOptions) : Task<_> = + let StartAsTask (token:CancellationToken, computation : Async<'a>,taskCreationOptions) : Task<'a> = +#if FX_NO_ASYNCTASKMETHODBUILDER let taskCreationOptions = defaultArg taskCreationOptions TaskCreationOptions.None let tcs = new TaskCompletionSource<_>(taskCreationOptions) +#else + // AsyncTaskMethodBuilder allows us to better control the cancellation process by setting custom exception objects. + let _ = defaultArg taskCreationOptions TaskCreationOptions.None + let tcs = System.Runtime.CompilerServices.AsyncTaskMethodBuilder<'a>() +#endif // The contract: // a) cancellation signal should always propagate to the computation @@ -950,8 +983,15 @@ namespace Microsoft.FSharp.Control queueAsync token (fun r -> tcs.SetResult r |> fake) - (fun edi -> tcs.SetException edi.SourceException |> fake) - (fun _ -> tcs.SetCanceled() |> fake) + (fun edi -> + let wrapper = + if token.IsCancellationRequested then + augmentOperationCancelledException (new TaskCanceledException()) edi :> exn + else + edi.SourceException + tcs.SetException wrapper |> fake) + // We wrap in a TaskCanceledException to maintain backwards compat. + (fun exn -> tcs.SetException (new TaskCanceledException(exn.Message, exn.InnerException)) |> fake) computation |> unfake task @@ -1148,14 +1188,30 @@ namespace Microsoft.FSharp.Control // Contains helpers that will attach continuation to the given task. // Should be invoked as a part of protectedPrimitive(withResync) call module TaskHelpers = - let continueWith (task : Task<'T>, args, useCcontForTaskCancellation) = + // This uses a trick to get the underlying OperationCanceledException + let inline getCancelledException (completedTask:Task) (waitWithAwaiter) = + let fallback() = new TaskCanceledException(completedTask) :> OperationCanceledException + // sadly there is no other public api to retrieve it, but to call .GetAwaiter().GetResult(). + try waitWithAwaiter() + // should not happen, but just in case... + fallback() + with + | :? OperationCanceledException as o -> o + | other -> + // shouldn't happen, but just in case... + new TaskCanceledException(fallback().Message, other) :> OperationCanceledException + + let continueWith (task : Task<'T>, args, useSimpleCcontForTaskCancellationAndLooseException) = let continuation (completedTask : Task<_>) : unit = args.aux.trampolineHolder.Protect((fun () -> if completedTask.IsCanceled then - if useCcontForTaskCancellation + if useSimpleCcontForTaskCancellationAndLooseException then args.aux.ccont (new OperationCanceledException(args.aux.token)) - else args.aux.econt (ExceptionDispatchInfo.Capture(new TaskCanceledException(completedTask))) + else + let cancelledException = + getCancelledException completedTask (fun () -> completedTask.GetAwaiter().GetResult() |> ignore) + args.aux.econt (ExceptionDispatchInfo.Capture(cancelledException)) elif completedTask.IsFaulted then args.aux.econt (MayLoseStackTrace(completedTask.Exception)) else @@ -1163,14 +1219,18 @@ namespace Microsoft.FSharp.Control task.ContinueWith(Action>(continuation)) |> ignore |> fake - let continueWithUnit (task : Task, args, useCcontForTaskCancellation) = + let continueWithUnit (task : Task, args, useSimpleCcontForTaskCancellationAndLooseException) = + let continuation (completedTask : Task) : unit = args.aux.trampolineHolder.Protect((fun () -> if completedTask.IsCanceled then - if useCcontForTaskCancellation + if useSimpleCcontForTaskCancellationAndLooseException then args.aux.ccont (new OperationCanceledException(args.aux.token)) - else args.aux.econt (ExceptionDispatchInfo.Capture(new TaskCanceledException(completedTask))) + else + let cancelledException = + getCancelledException completedTask (fun () -> completedTask.GetAwaiter().GetResult()) + args.aux.econt (ExceptionDispatchInfo.Capture(cancelledException)) elif completedTask.IsFaulted then args.aux.econt (MayLoseStackTrace(completedTask.Exception)) else diff --git a/tests/FSharp.Core.UnitTests/FSharp.Core/Microsoft.FSharp.Control/AsyncType.fs b/tests/FSharp.Core.UnitTests/FSharp.Core/Microsoft.FSharp.Control/AsyncType.fs index d2fe26e0c00..1e508ede5ac 100644 --- a/tests/FSharp.Core.UnitTests/FSharp.Core/Microsoft.FSharp.Control/AsyncType.fs +++ b/tests/FSharp.Core.UnitTests/FSharp.Core/Microsoft.FSharp.Control/AsyncType.fs @@ -126,7 +126,12 @@ type AsyncType() = ) member private this.WaitASec (t:Task) = - let result = t.Wait(TimeSpan(hours=0,minutes=0,seconds=1)) + let result = + try t.Wait(TimeSpan(hours=0,minutes=0,seconds=1)) + with :? AggregateException -> + // This throws the "original" exception + t.GetAwaiter().GetResult() + false Assert.IsTrue(result, "Task did not finish after waiting for a second.") @@ -162,17 +167,118 @@ type AsyncType() = try let result = t.Wait(300) Assert.IsFalse (result) - with :? AggregateException -> Assert.Fail "Task should not finish, jet" + with :? AggregateException -> Assert.Fail "Task should not finish, yet" tcs.SetCanceled() try this.WaitASec t - with :? AggregateException as a -> - match a.InnerException with - | :? TaskCanceledException as t -> () + with :? TaskCanceledException -> () + Assert.IsTrue (t.IsCompleted, "Task is not completed") + Assert.IsTrue (t.IsCanceled, "Task is not cancelled") + + [] + member this.StartAsTaskCancellationViaException () = + let cts = new CancellationTokenSource() + let tcs = TaskCompletionSource() + let a = async { + cts.CancelAfter (100) + do! tcs.Task |> Async.AwaitTask } +#if FSCORE_PORTABLE_NEW || coreclr + let t : Task = +#else + use t : Task = +#endif + Async.StartAsTask(a, cancellationToken = cts.Token) + + // Should not finish + try + let result = t.Wait(300) + Assert.IsFalse (result) + with :? AggregateException -> Assert.Fail "Task should not finish, yet" + + let msg = "Custom non-conforming 3rd-Party-Api throws" + tcs.SetException(Exception msg) + + try + this.WaitASec t + with :? TaskCanceledException as t -> + match t.InnerException with + | :? AggregateException as a -> + Assert.AreEqual(1, a.InnerExceptions.Count) + Assert.AreEqual(msg, a.InnerException.Message) + | _ -> reraise() + Assert.IsTrue (t.IsCompleted, "Task is not completed") + Assert.IsTrue (t.IsCanceled, "Task is not cancelled") + + [] + member this.RunSynchronouslyCancellation () = + let cts = new CancellationTokenSource() + let tcs = TaskCompletionSource() + let a = async { + cts.CancelAfter (100) + do! tcs.Task |> Async.AwaitTask } +#if FSCORE_PORTABLE_NEW || coreclr + let t : Task = +#else + use t : Task = +#endif + Task.Run(new Func(fun () -> Async.RunSynchronously(a, cancellationToken = cts.Token))) + + // Should not finish + try + let result = t.Wait(300) + Assert.IsFalse (result) + with :? AggregateException -> Assert.Fail "Task should not finish, yet" + + tcs.SetCanceled() + + try + this.WaitASec t + with :? OperationCanceledException -> () + + Assert.IsTrue (t.IsCompleted, "Task is not completed") + // We used Task.Run for convenience, it will not notice the cancellation + // -> Cancellation is noticed by RunSynchronously throwing 'OperationCanceledException' + // which is tested above + //Assert.IsTrue (t.IsCanceled, "Task is not cancelled") + + [] + member this.RunSynchronouslyCancellationViaException () = + let cts = new CancellationTokenSource() + let tcs = TaskCompletionSource() + let a = async { + cts.CancelAfter (100) + do! tcs.Task |> Async.AwaitTask } +#if FSCORE_PORTABLE_NEW || coreclr + let t : Task = +#else + use t : Task = +#endif + Task.Run(new Func(fun () -> Async.RunSynchronously(a, cancellationToken = cts.Token))) + + // Should not finish + try + let result = t.Wait(300) + Assert.IsFalse (result) + with :? AggregateException -> Assert.Fail "Task should not finish, yet" + + let msg = "Custom non-conforming 3rd-Party-Api throws" + tcs.SetException(Exception msg) + + try + this.WaitASec t + with :? OperationCanceledException as t -> + match t.InnerException with + | :? AggregateException as a -> + Assert.AreEqual(1, a.InnerExceptions.Count) + Assert.AreEqual(msg, a.InnerException.Message) | _ -> reraise() Assert.IsTrue (t.IsCompleted, "Task is not completed") + // We used Task.Run for convenience, it will not notice the cancellation + // -> Cancellation is noticed by RunSynchronously throwing 'OperationCanceledException' + // which is tested above + //Assert.IsTrue (t.IsCanceled, "Task is not cancelled") [] member this.StartTask () = diff --git a/tests/FSharp.Core.UnitTests/FSharp.Core/Microsoft.FSharp.Control/Cancellation.fs b/tests/FSharp.Core.UnitTests/FSharp.Core/Microsoft.FSharp.Control/Cancellation.fs index 47f241f2462..ab441a1d13d 100644 --- a/tests/FSharp.Core.UnitTests/FSharp.Core/Microsoft.FSharp.Control/Cancellation.fs +++ b/tests/FSharp.Core.UnitTests/FSharp.Core/Microsoft.FSharp.Control/Cancellation.fs @@ -307,4 +307,119 @@ type CancellationType() = Assert.IsTrue((r1a <> r1b)) Assert.IsTrue((r1a <> r2)) + [] + member this.TestCancellationKeepsExceptionInfo() = + let cts = new CancellationTokenSource() + let ewh = new ManualResetEvent(false) + let msg = "Cleanup failure" + let a = async { + try ewh.Set() |> ignore + do! Async.Sleep 10000 + finally raise <| Exception msg} + async { + ewh.WaitOne() |> ignore + cts.Cancel() + } |> Async.Start + + try + Async.RunSynchronously(a, cancellationToken = cts.Token) + with :? OperationCanceledException as o -> + match o.InnerException with + | :? AggregateException as a -> + Assert.AreEqual (1, a.InnerExceptions.Count) + match a.InnerException with + | e when not (isNull e) -> + Assert.AreEqual(msg, e.Message) + | _ -> reraise() + | _ -> reraise() + + [] + member this.TestCancellationKeepsExceptionInfoAsTask() = + let cts = new CancellationTokenSource() + let ewh = new ManualResetEvent(false) + let msg = "Cleanup failure" + let a = async { + try ewh.Set() |> ignore + do! Async.Sleep 10000 + finally raise <| Exception msg} + async { + ewh.WaitOne() |> ignore + cts.Cancel() + } |> Async.Start + + let t = Async.StartAsTask(a, cancellationToken = cts.Token) + try + t.GetAwaiter().GetResult() + with :? OperationCanceledException as o -> + match o.InnerException with + | :? AggregateException as a -> + Assert.AreEqual (1, a.InnerExceptions.Count) + match a.InnerException with + | e when not (isNull e) -> + Assert.AreEqual(msg, e.Message) + | _ -> reraise() + | _ -> reraise() + + Assert.IsTrue(t.IsCompleted, "Task should be marked as completed") + Assert.IsTrue(t.IsCanceled, "Task should be marked as cancelled") + + [] + member this.TestCancellationKeepsExceptionInfoWithTryWith() = + let cts = new CancellationTokenSource() + let ewh = new ManualResetEvent(false) + let msg = "Cleanup failure" + let a = async { + try + try ewh.Set() |> ignore + do! Async.Sleep 10000 + finally raise <| Exception msg + with :? InvalidOperationException -> () } + async { + ewh.WaitOne() |> ignore + cts.Cancel() + } |> Async.Start + + try + Async.RunSynchronously(a, cancellationToken = cts.Token) + with :? OperationCanceledException as o -> + match o.InnerException with + | :? AggregateException as a -> + Assert.AreEqual (1, a.InnerExceptions.Count) + match a.InnerException with + | e when not (isNull e) -> + Assert.AreEqual(msg, e.Message) + | _ -> reraise() + | _ -> reraise() + + [] + member this.TestCancellationKeepsExceptionInfoWithTryWithAsTask() = + let cts = new CancellationTokenSource() + let ewh = new ManualResetEvent(false) + let msg = "Cleanup failure" + let a = async { + try + try ewh.Set() |> ignore + do! Async.Sleep 10000 + finally raise <| Exception msg + with :? InvalidOperationException -> () } + async { + ewh.WaitOne() |> ignore + cts.Cancel() + } |> Async.Start + + let t = Async.StartAsTask(a, cancellationToken = cts.Token) + try + t.GetAwaiter().GetResult() + with :? OperationCanceledException as o -> + match o.InnerException with + | :? AggregateException as a -> + Assert.AreEqual (1, a.InnerExceptions.Count) + match a.InnerException with + | e when not (isNull e) -> + Assert.AreEqual(msg, e.Message) + | _ -> reraise() + | _ -> reraise() + + Assert.IsTrue(t.IsCompleted, "Task should be marked as completed") + Assert.IsTrue(t.IsCanceled, "Task should be marked as cancelled")