Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,38 @@ private static unsafe void TransparentAwaitValueTask(ValueTask valueTask)

[BypassReadyToRun]
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.Async)]
private static unsafe void TransparentAwaitValueTaskOfT<T>(ValueTask<T?> valueTask)
private static unsafe void AwaitValueTaskSource(object source, short token)
{
ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState;
Continuation? sentinelContinuation = state.SentinelContinuation ??= new Continuation();

ValueTaskContinuation? vtsCont = state.CachedValueTaskContinuation;
if (vtsCont != null)
{
state.CachedValueTaskContinuation = null;
}
else
{
vtsCont = new ValueTaskContinuation();
}

Debug.Assert(source != null);
vtsCont.Initialize(source, token);

// We only need to capture flags.
// If needed, VTS will use the scheduling context captured in the "state".
CaptureContinuationContextFlags(ref vtsCont.Flags, state.CurrentThread!);

Comment thread
VSadov marked this conversation as resolved.
sentinelContinuation.Next = vtsCont;
state.StackState->ValueTaskContinuation = vtsCont;

state.CaptureContexts();
AsyncSuspend(vtsCont);
}

[BypassReadyToRun]
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.Async)]
private static unsafe void TransparentAwaitValueTaskOfT<T>(ValueTask<T> valueTask)
{
ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState;
Continuation? sentinelContinuation = state.SentinelContinuation ??= new Continuation();
Expand All @@ -405,6 +436,37 @@ private static unsafe void TransparentAwaitValueTaskOfT<T>(ValueTask<T?> valueTa
AsyncSuspend(vtsCont);
}

[BypassReadyToRun]
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.Async)]
private static unsafe void AwaitValueTaskSourceOfT<T>(object source, short token)
{
ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState;
Continuation? sentinelContinuation = state.SentinelContinuation ??= new Continuation();

ValueTaskContinuation? vtsCont = state.CachedValueTaskContinuation;
if (vtsCont != null)
{
state.CachedValueTaskContinuation = null;
}
else
{
vtsCont = new ValueTaskContinuation();
}

Debug.Assert(source != null);
vtsCont.Initialize<T>(source, token);

// We only need to capture flags.
// If needed, VTS will use the scheduling context captured in the "state".
CaptureContinuationContextFlags(ref vtsCont.Flags, state.CurrentThread!);

Comment thread
VSadov marked this conversation as resolved.
sentinelContinuation.Next = vtsCont;
state.StackState->ValueTaskContinuation = vtsCont;

state.CaptureContexts();
AsyncSuspend(vtsCont);
}

/// <summary>
/// Used by internal thunks that implement awaiting on Task.
/// </summary>
Expand Down Expand Up @@ -493,23 +555,25 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state)

// Head continuation should be the result of async call to AwaitAwaiter or UnsafeAwaitAwaiter.
// These never have special continuation context handling.
// Except for the scenario with ValueTaskContinuation that wraps ValueTaskSource
// which can capture continuation context flags.
const ContinuationFlags continueFlags =
ContinuationFlags.ContinueOnCapturedSynchronizationContext |
ContinuationFlags.ContinueOnThreadPool |
ContinuationFlags.ContinueOnCapturedTaskScheduler;

Debug.Assert((headContinuation.Flags & continueFlags) == 0);

SetContinuationState(headContinuation);

try
{
if (stackState->CriticalNotifier is { } critNotifier)
{
Debug.Assert((headContinuation.Flags & continueFlags) == 0);
critNotifier.UnsafeOnCompleted(GetContinuationAction());
}
else if (stackState->TaskNotifier is { } taskNotifier)
{
Debug.Assert((headContinuation.Flags & continueFlags) == 0);
// Runtime async callable wrapper for task returning
// method. This implements the context transparent
// forwarding and makes these wrappers minimal cost.
Expand All @@ -525,6 +589,7 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state)
Debug.Assert(source != null);
if (source is Task t)
{
Debug.Assert((headContinuation.Flags & continueFlags) == 0);
if (!t.TryAddCompletionAction(this))
{
ThreadPool.UnsafeQueueUserWorkItemInternal(this, preferLocal: true);
Expand All @@ -541,17 +606,18 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state)
// the continuation chain builds from the innermost frame out and at the time when the
// notifier is created we do not know yet if the caller wants to continue on a context.

// Skip to a nontransparent/user continuation. Such continuaton must exist.
// Skip to a nontransparent/user continuation. Such continuation must exist.
// Since we see a VTS notifier, something was directly or indirectly
// awaiting an async thunk for a ValueTask-returning method.
// That can only happen in nontransparent/user code.
Continuation nextUserContinuation = valueTaskSourceCont.Next!;
while ((nextUserContinuation.Flags & continueFlags) == 0 && nextUserContinuation.Next != null)
// awaiting either an async thunk for a ValueTask-returning method or
// the direct AsyncHelpers.Await(ValueTask/ValueTask<T>) path.
// In either case, that can only happen in nontransparent/user code.
Continuation contWithContinueFlags = valueTaskSourceCont;
while ((contWithContinueFlags.Flags & continueFlags) == 0 && contWithContinueFlags.Next != null)
{
nextUserContinuation = nextUserContinuation.Next;
contWithContinueFlags = contWithContinueFlags.Next;
}

ContinuationFlags continuationFlags = nextUserContinuation.Flags;
ContinuationFlags continuationFlags = contWithContinueFlags.Flags;
const ContinuationFlags continueOnContextFlags =
ContinuationFlags.ContinueOnCapturedSynchronizationContext |
ContinuationFlags.ContinueOnCapturedTaskScheduler;
Expand All @@ -564,7 +630,7 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state)
}

// Clear continuation flags, so that continuation runs transparently
nextUserContinuation.Flags &= ~continueFlags;
contWithContinueFlags.Flags &= ~continueFlags;

valueTaskSourceCont.OnCompletedValueTaskSource(
source,
Expand All @@ -576,6 +642,7 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state)
}
else
{
Debug.Assert((headContinuation.Flags & continueFlags) == 0);
Debug.Assert(stackState->Notifier != null);
stackState->Notifier!.OnCompleted(GetContinuationAction());
}
Expand Down Expand Up @@ -1117,7 +1184,7 @@ private static void RestoreContextsOnSuspension(bool resumed, ExecutionContext?
}
}

private static void CaptureContinuationContext(ref object continuationContext, ref ContinuationFlags flags)
private static void CaptureContinuationContext(ref object? continuationContext, ref ContinuationFlags flags)
{
SynchronizationContext? syncCtx = Thread.CurrentThreadAssumedInitialized._synchronizationContext;
if (syncCtx != null && syncCtx.GetType() != typeof(SynchronizationContext))
Expand All @@ -1138,6 +1205,26 @@ private static void CaptureContinuationContext(ref object continuationContext, r
flags |= ContinuationFlags.ContinueOnThreadPool;
}

// Same as above, but only captures flags
private static void CaptureContinuationContextFlags(ref ContinuationFlags flags, Thread currentThread)
{
SynchronizationContext? syncCtx = currentThread._synchronizationContext;
if (syncCtx != null && syncCtx.GetType() != typeof(SynchronizationContext))
{
flags |= ContinuationFlags.ContinueOnCapturedSynchronizationContext;
return;
}

TaskScheduler? sched = TaskScheduler.InternalCurrent;
if (sched != null && sched != TaskScheduler.Default)
{
flags |= ContinuationFlags.ContinueOnCapturedTaskScheduler;
return;
}

flags |= ContinuationFlags.ContinueOnThreadPool;
}

// Finish suspension in the common case of a custom await or for a ConfigureAwait(false) task await:
// - Capture current ExecutionContext into the continuation
// - Restore ExecutionContext and SynchronizationContext to the current Thread object
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,14 @@ private static class ValueTaskContinuationResume
{
var vtsCont = (ValueTaskContinuation)cont;
vtsCont.Next = null;

const ContinuationFlags continueFlags =
ContinuationFlags.ContinueOnCapturedSynchronizationContext |
ContinuationFlags.ContinueOnThreadPool |
ContinuationFlags.ContinueOnCapturedTaskScheduler;

Debug.Assert((vtsCont.Flags & continueFlags) == 0);

t_runtimeAsyncAwaitState.CachedValueTaskContinuation = vtsCont;

vtsCont.GetResult(ref result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,21 @@ public static void Await(Task task)
[StackTraceHidden]
public static T Await<T>(ValueTask<T> task)
{
ValueTaskAwaiter<T> awaiter = task.GetAwaiter();
if (!awaiter.IsCompleted)
if (!task.IsCompleted)
{
UnsafeAwaitAwaiter(awaiter);
if (task._obj is Task<T> t)
{
TailAwait();
Comment thread
VSadov marked this conversation as resolved.
Await(t);
}
else
{
TailAwait();
AwaitValueTaskSourceOfT<T>(task._obj!, task._token);
Comment thread
VSadov marked this conversation as resolved.
}
}

return awaiter.GetResult();
return task.Result;
}
Comment thread
VSadov marked this conversation as resolved.

/// <summary>
Expand All @@ -123,13 +131,21 @@ public static T Await<T>(ValueTask<T> task)
[StackTraceHidden]
public static void Await(ValueTask task)
{
ValueTaskAwaiter awaiter = task.GetAwaiter();
if (!awaiter.IsCompleted)
if (!task.IsCompleted)
{
UnsafeAwaitAwaiter(awaiter);
if (task._obj is Task t)
{
TailAwait();
Await(t);
}
else
{
TailAwait();
AwaitValueTaskSource(task._obj!, task._token);
Comment thread
VSadov marked this conversation as resolved.
}
}

awaiter.GetResult();
task.ThrowIfCompletedUnsuccessfully();
}

/// <summary>
Expand Down
Loading