diff --git a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs index d91895128a85ec..bf742b7a505f66 100644 --- a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs +++ b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs @@ -380,7 +380,38 @@ private static unsafe void TransparentAwaitValueTask(ValueTask valueTask) [BypassReadyToRun] [MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.Async)] - private static unsafe void TransparentAwaitValueTaskOfT(ValueTask valueTask) + private static unsafe void AwaitValueTaskSource(object source, short token) + { + ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState; + Continuation? sentinelContinuation = state.SentinelContinuation ??= new Continuation(); + + ValueTaskContinuation? vtsCont = state.CachedValueTaskContinuation; + if (vtsCont != null) + { + state.CachedValueTaskContinuation = null; + } + else + { + vtsCont = new ValueTaskContinuation(); + } + + Debug.Assert(source != null); + vtsCont.Initialize(source, token); + + // We only need to capture flags. + // If needed, VTS will use the scheduling context captured in the "state". + CaptureContinuationContextFlags(ref vtsCont.Flags, state.CurrentThread!); + + sentinelContinuation.Next = vtsCont; + state.StackState->ValueTaskContinuation = vtsCont; + + state.CaptureContexts(); + AsyncSuspend(vtsCont); + } + + [BypassReadyToRun] + [MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.Async)] + private static unsafe void TransparentAwaitValueTaskOfT(ValueTask valueTask) { ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState; Continuation? sentinelContinuation = state.SentinelContinuation ??= new Continuation(); @@ -405,6 +436,37 @@ private static unsafe void TransparentAwaitValueTaskOfT(ValueTask valueTa AsyncSuspend(vtsCont); } + [BypassReadyToRun] + [MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.Async)] + private static unsafe void AwaitValueTaskSourceOfT(object source, short token) + { + ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState; + Continuation? sentinelContinuation = state.SentinelContinuation ??= new Continuation(); + + ValueTaskContinuation? vtsCont = state.CachedValueTaskContinuation; + if (vtsCont != null) + { + state.CachedValueTaskContinuation = null; + } + else + { + vtsCont = new ValueTaskContinuation(); + } + + Debug.Assert(source != null); + vtsCont.Initialize(source, token); + + // We only need to capture flags. + // If needed, VTS will use the scheduling context captured in the "state". + CaptureContinuationContextFlags(ref vtsCont.Flags, state.CurrentThread!); + + sentinelContinuation.Next = vtsCont; + state.StackState->ValueTaskContinuation = vtsCont; + + state.CaptureContexts(); + AsyncSuspend(vtsCont); + } + /// /// Used by internal thunks that implement awaiting on Task. /// @@ -493,23 +555,25 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state) // Head continuation should be the result of async call to AwaitAwaiter or UnsafeAwaitAwaiter. // These never have special continuation context handling. + // Except for the scenario with ValueTaskContinuation that wraps ValueTaskSource + // which can capture continuation context flags. const ContinuationFlags continueFlags = ContinuationFlags.ContinueOnCapturedSynchronizationContext | ContinuationFlags.ContinueOnThreadPool | ContinuationFlags.ContinueOnCapturedTaskScheduler; - Debug.Assert((headContinuation.Flags & continueFlags) == 0); - SetContinuationState(headContinuation); try { if (stackState->CriticalNotifier is { } critNotifier) { + Debug.Assert((headContinuation.Flags & continueFlags) == 0); critNotifier.UnsafeOnCompleted(GetContinuationAction()); } else if (stackState->TaskNotifier is { } taskNotifier) { + Debug.Assert((headContinuation.Flags & continueFlags) == 0); // Runtime async callable wrapper for task returning // method. This implements the context transparent // forwarding and makes these wrappers minimal cost. @@ -525,6 +589,7 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state) Debug.Assert(source != null); if (source is Task t) { + Debug.Assert((headContinuation.Flags & continueFlags) == 0); if (!t.TryAddCompletionAction(this)) { ThreadPool.UnsafeQueueUserWorkItemInternal(this, preferLocal: true); @@ -541,17 +606,18 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state) // the continuation chain builds from the innermost frame out and at the time when the // notifier is created we do not know yet if the caller wants to continue on a context. - // Skip to a nontransparent/user continuation. Such continuaton must exist. + // Skip to a nontransparent/user continuation. Such continuation must exist. // Since we see a VTS notifier, something was directly or indirectly - // awaiting an async thunk for a ValueTask-returning method. - // That can only happen in nontransparent/user code. - Continuation nextUserContinuation = valueTaskSourceCont.Next!; - while ((nextUserContinuation.Flags & continueFlags) == 0 && nextUserContinuation.Next != null) + // awaiting either an async thunk for a ValueTask-returning method or + // the direct AsyncHelpers.Await(ValueTask/ValueTask) path. + // In either case, that can only happen in nontransparent/user code. + Continuation contWithContinueFlags = valueTaskSourceCont; + while ((contWithContinueFlags.Flags & continueFlags) == 0 && contWithContinueFlags.Next != null) { - nextUserContinuation = nextUserContinuation.Next; + contWithContinueFlags = contWithContinueFlags.Next; } - ContinuationFlags continuationFlags = nextUserContinuation.Flags; + ContinuationFlags continuationFlags = contWithContinueFlags.Flags; const ContinuationFlags continueOnContextFlags = ContinuationFlags.ContinueOnCapturedSynchronizationContext | ContinuationFlags.ContinueOnCapturedTaskScheduler; @@ -564,7 +630,7 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state) } // Clear continuation flags, so that continuation runs transparently - nextUserContinuation.Flags &= ~continueFlags; + contWithContinueFlags.Flags &= ~continueFlags; valueTaskSourceCont.OnCompletedValueTaskSource( source, @@ -576,6 +642,7 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state) } else { + Debug.Assert((headContinuation.Flags & continueFlags) == 0); Debug.Assert(stackState->Notifier != null); stackState->Notifier!.OnCompleted(GetContinuationAction()); } @@ -1117,7 +1184,7 @@ private static void RestoreContextsOnSuspension(bool resumed, ExecutionContext? } } - private static void CaptureContinuationContext(ref object continuationContext, ref ContinuationFlags flags) + private static void CaptureContinuationContext(ref object? continuationContext, ref ContinuationFlags flags) { SynchronizationContext? syncCtx = Thread.CurrentThreadAssumedInitialized._synchronizationContext; if (syncCtx != null && syncCtx.GetType() != typeof(SynchronizationContext)) @@ -1138,6 +1205,26 @@ private static void CaptureContinuationContext(ref object continuationContext, r flags |= ContinuationFlags.ContinueOnThreadPool; } + // Same as above, but only captures flags + private static void CaptureContinuationContextFlags(ref ContinuationFlags flags, Thread currentThread) + { + SynchronizationContext? syncCtx = currentThread._synchronizationContext; + if (syncCtx != null && syncCtx.GetType() != typeof(SynchronizationContext)) + { + flags |= ContinuationFlags.ContinueOnCapturedSynchronizationContext; + return; + } + + TaskScheduler? sched = TaskScheduler.InternalCurrent; + if (sched != null && sched != TaskScheduler.Default) + { + flags |= ContinuationFlags.ContinueOnCapturedTaskScheduler; + return; + } + + flags |= ContinuationFlags.ContinueOnThreadPool; + } + // Finish suspension in the common case of a custom await or for a ConfigureAwait(false) task await: // - Capture current ExecutionContext into the continuation // - Restore ExecutionContext and SynchronizationContext to the current Thread object diff --git a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.ValueTaskContinuation.cs b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.ValueTaskContinuation.cs index d7b205787da59e..7d5af875d08a50 100644 --- a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.ValueTaskContinuation.cs +++ b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.ValueTaskContinuation.cs @@ -105,6 +105,14 @@ private static class ValueTaskContinuationResume { var vtsCont = (ValueTaskContinuation)cont; vtsCont.Next = null; + + const ContinuationFlags continueFlags = + ContinuationFlags.ContinueOnCapturedSynchronizationContext | + ContinuationFlags.ContinueOnThreadPool | + ContinuationFlags.ContinueOnCapturedTaskScheduler; + + Debug.Assert((vtsCont.Flags & continueFlags) == 0); + t_runtimeAsyncAwaitState.CachedValueTaskContinuation = vtsCont; vtsCont.GetResult(ref result); diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.cs index 510b863d4dda8e..f54f58afab7e30 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.cs @@ -104,13 +104,21 @@ public static void Await(Task task) [StackTraceHidden] public static T Await(ValueTask task) { - ValueTaskAwaiter awaiter = task.GetAwaiter(); - if (!awaiter.IsCompleted) + if (!task.IsCompleted) { - UnsafeAwaitAwaiter(awaiter); + if (task._obj is Task t) + { + TailAwait(); + Await(t); + } + else + { + TailAwait(); + AwaitValueTaskSourceOfT(task._obj!, task._token); + } } - return awaiter.GetResult(); + return task.Result; } /// @@ -123,13 +131,21 @@ public static T Await(ValueTask task) [StackTraceHidden] public static void Await(ValueTask task) { - ValueTaskAwaiter awaiter = task.GetAwaiter(); - if (!awaiter.IsCompleted) + if (!task.IsCompleted) { - UnsafeAwaitAwaiter(awaiter); + if (task._obj is Task t) + { + TailAwait(); + Await(t); + } + else + { + TailAwait(); + AwaitValueTaskSource(task._obj!, task._token); + } } - awaiter.GetResult(); + task.ThrowIfCompletedUnsuccessfully(); } ///