Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@
<Compile Include="$(BclSourcesRoot)\System\Runtime\CompilerServices\GenericsHelpers.cs" />
<Compile Include="$(BclSourcesRoot)\System\Runtime\CompilerServices\InitHelpers.cs" />
<Compile Include="$(BclSourcesRoot)\System\Runtime\CompilerServices\AsyncHelpers.CoreCLR.cs" />
<Compile Include="$(BclSourcesRoot)\System\Runtime\CompilerServices\AsyncHelpers.TaskContinuation.cs" />
<Compile Include="$(BclSourcesRoot)\System\Runtime\CompilerServices\AsyncHelpers.ValueTaskContinuation.cs" />
<Compile Include="$(BclSourcesRoot)\System\Runtime\CompilerServices\AsyncProfiler.CoreCLR.cs" />
<Compile Include="$(BclSourcesRoot)\System\Runtime\CompilerServices\RuntimeHelpers.CoreCLR.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ internal enum ContinuationFlags
ContinueOnCapturedSynchronizationContext = 1 << 1,
ContinueOnCapturedTaskScheduler = 1 << 2,

AllContinuationFlags = ContinueOnThreadPool | ContinueOnCapturedSynchronizationContext | ContinueOnCapturedTaskScheduler,

// The flags encode where in the continuation various members are stored.
// If the encoded index is 0, it means no such member is present.
// Otherwise the exact offset of the member is computed as
Expand Down Expand Up @@ -213,7 +215,7 @@ private ref struct RuntimeAsyncStackState
public ICriticalNotifyCompletion? CriticalNotifier;
public INotifyCompletion? Notifier;
public ValueTaskContinuation? ValueTaskContinuation;
public Task? TaskNotifier;
public TaskContinuation? TaskContinuation;

// When we suspend in the leaf, the contexts are captured into these fields.
public ExecutionContext? LeafExecutionContext;
Expand Down Expand Up @@ -256,6 +258,7 @@ private unsafe struct RuntimeAsyncAwaitState
{
public Continuation? SentinelContinuation;
public ValueTaskContinuation? CachedValueTaskContinuation;
public TaskContinuation? CachedTaskContinuation;

// We cache the thread here to avoid unnecessary repeated TLS lookups.
public Thread? CurrentThread;
Expand Down Expand Up @@ -470,18 +473,61 @@ private static unsafe void AwaitValueTaskSourceOfT<T>(object source, short token
/// <summary>
/// Used by internal thunks that implement awaiting on Task.
/// </summary>
/// <param name="t">Task whose completion we are awaiting.</param>
/// <param name="task">Task whose completion we are awaiting.</param>
[BypassReadyToRun]
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.Async)]
private static unsafe void TransparentAwait(Task task)
{
ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState;
Continuation? sentinelContinuation = state.SentinelContinuation ??= new Continuation();

TaskContinuation? taskCont = state.CachedTaskContinuation;
if (taskCont != null)
{
state.CachedTaskContinuation = null;
}
else
{
taskCont = new TaskContinuation();
}

taskCont.Initialize(task);

sentinelContinuation.Next = taskCont;
state.StackState->TaskContinuation = taskCont;

state.CaptureContexts();
AsyncSuspend(taskCont);
}

/// <summary>
/// Used by internal thunks that implement awaiting on Task.
/// </summary>
/// <param name="task">Task whose completion we are awaiting.</param>
[BypassReadyToRun]
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.Async)]
private static unsafe void TransparentAwait(Task t)
private static unsafe void TransparentAwaitOfT<T>(Task<T> task)
{
ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState;
Continuation? sentinelContinuation = state.SentinelContinuation ??= new Continuation();

state.StackState->TaskNotifier = t;
TaskContinuation? taskCont = state.CachedTaskContinuation;
if (taskCont != null)
{
state.CachedTaskContinuation = null;
}
else
{
taskCont = new TaskContinuation();
}

taskCont.Initialize<T>(task);

sentinelContinuation.Next = taskCont;
state.StackState->TaskContinuation = taskCont;

state.CaptureContexts();
AsyncSuspend(sentinelContinuation);
AsyncSuspend(taskCont);
}

// Represents execution of a chain of suspended and resuming runtime
Expand Down Expand Up @@ -553,31 +599,29 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state)
Continuation headContinuation = sentinelContinuation.Next!;
sentinelContinuation.Next = null;

// Head continuation should be the result of async call to AwaitAwaiter or UnsafeAwaitAwaiter.
// These never have special continuation context handling.
// Except for the scenario with ValueTaskContinuation that wraps ValueTaskSource
// which can capture continuation context flags.
const ContinuationFlags continueFlags =
ContinuationFlags.ContinueOnCapturedSynchronizationContext |
ContinuationFlags.ContinueOnThreadPool |
ContinuationFlags.ContinueOnCapturedTaskScheduler;

SetContinuationState(headContinuation);

try
{
if (stackState->CriticalNotifier is { } critNotifier)
{
Debug.Assert((headContinuation.Flags & continueFlags) == 0);
// Result of async call to AwaitAwaiter or UnsafeAwaitAwaiter.
// These never have special continuation context handling.
Debug.Assert((headContinuation.Flags & ContinuationFlags.AllContinuationFlags) == 0);
critNotifier.UnsafeOnCompleted(GetContinuationAction());
}
else if (stackState->TaskNotifier is { } taskNotifier)
else if (stackState->TaskContinuation is { } taskCont)
{
Debug.Assert((headContinuation.Flags & continueFlags) == 0);
Debug.Assert(headContinuation == taskCont);
// Similarly for transparent awwaits we do not expect
// any continuation flags.
Debug.Assert((headContinuation.Flags & ContinuationFlags.AllContinuationFlags) == 0);
// Runtime async callable wrapper for task returning
// method. This implements the context transparent
// forwarding and makes these wrappers minimal cost.
if (!taskNotifier.TryAddCompletionAction(this))
Debug.Assert(taskCont.Task != null);
if (!taskCont.Task.TryAddCompletionAction(this))
{
ThreadPool.UnsafeQueueUserWorkItemInternal(this, preferLocal: true);
}
Expand All @@ -589,7 +633,7 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state)
Debug.Assert(source != null);
if (source is Task t)
{
Debug.Assert((headContinuation.Flags & continueFlags) == 0);
Debug.Assert((headContinuation.Flags & ContinuationFlags.AllContinuationFlags) == 0);
if (!t.TryAddCompletionAction(this))
{
ThreadPool.UnsafeQueueUserWorkItemInternal(this, preferLocal: true);
Expand All @@ -612,7 +656,7 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state)
// the direct AsyncHelpers.Await(ValueTask/ValueTask<T>) path.
// In either case, that can only happen in nontransparent/user code.
Continuation contWithContinueFlags = valueTaskSourceCont;
while ((contWithContinueFlags.Flags & continueFlags) == 0 && contWithContinueFlags.Next != null)
while ((contWithContinueFlags.Flags & ContinuationFlags.AllContinuationFlags) == 0 && contWithContinueFlags.Next != null)
{
contWithContinueFlags = contWithContinueFlags.Next;
}
Expand All @@ -630,7 +674,7 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state)
}

// Clear continuation flags, so that continuation runs transparently
contWithContinueFlags.Flags &= ~continueFlags;
contWithContinueFlags.Flags &= ~ContinuationFlags.AllContinuationFlags;

valueTaskSourceCont.OnCompletedValueTaskSource(
source,
Expand All @@ -642,7 +686,7 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state)
}
else
{
Debug.Assert((headContinuation.Flags & continueFlags) == 0);
Debug.Assert((headContinuation.Flags & ContinuationFlags.AllContinuationFlags) == 0);
Debug.Assert(stackState->Notifier != null);
stackState->Notifier!.OnCompleted(GetContinuationAction());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Threading.Tasks;

namespace System.Runtime.CompilerServices
{
public static partial class AsyncHelpers
{
private sealed unsafe class TaskContinuation : Continuation
{
internal Task? Task;
private delegate*<Task, ref byte, void> _getResult;

public TaskContinuation()
{
ResumeInfo = (ResumeInfo*)Unsafe.AsPointer(in TaskContinuationResume.ResumeInfo);
}

public void GetResult(ref byte returnValue)
{
Debug.Assert(Task != null);

// Avoid retaining the task. The call below may throw.
Task task = Task;
Task = null;

_getResult(task, ref returnValue);
}
Comment thread
jakobbotsch marked this conversation as resolved.

public void Initialize(Task task)
{
Task = task;
_getResult = &GetResult;
}

public void Initialize<T>(Task<T> task)
{
Task = task;
_getResult = &GetResult<T>;
}

private static void GetResult(Task task, ref byte result)
{
TaskAwaiter.ValidateEnd(task);
}

private static void GetResult<T>(Task task, ref byte result)
{
Debug.Assert(task is Task<T>);

Task<T> taskOfT = Unsafe.As<Task, Task<T>>(ref task);
TaskAwaiter.ValidateEnd(taskOfT);
Unsafe.As<byte, T>(ref result) = taskOfT.ResultOnSuccess;
}

private static class TaskContinuationResume
{
[FixedAddressValueType]
public static readonly ResumeInfo ResumeInfo = new ResumeInfo
{
DiagnosticIP = null,
Resume = &ResumeTaskContinuation,
};

private static Continuation? ResumeTaskContinuation(Continuation cont, ref byte result)
{
var taskCont = (TaskContinuation)cont;
Comment thread
jakobbotsch marked this conversation as resolved.
taskCont.Next = null;

Debug.Assert((taskCont.Flags & ContinuationFlags.AllContinuationFlags) == 0);

t_runtimeAsyncAwaitState.CachedTaskContinuation = taskCont;

taskCont.GetResult(ref result);
return null;
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,7 @@ private static class ValueTaskContinuationResume
var vtsCont = (ValueTaskContinuation)cont;
vtsCont.Next = null;

const ContinuationFlags continueFlags =
ContinuationFlags.ContinueOnCapturedSynchronizationContext |
ContinuationFlags.ContinueOnThreadPool |
ContinuationFlags.ContinueOnCapturedTaskScheduler;

Debug.Assert((vtsCont.Flags & continueFlags) == 0);
Debug.Assert((vtsCont.Flags & ContinuationFlags.AllContinuationFlags) == 0);

t_runtimeAsyncAwaitState.CachedValueTaskContinuation = vtsCont;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
<ItemGroup>
<!-- TODO: (async) once we know which helpers can actually be shared, move those to libraries partition -->
<Compile Include="$(CoreClrProjectRoot)System.Private.CoreLib\src\System\Runtime\CompilerServices\AsyncHelpers.CoreCLR.cs" />
<Compile Include="$(CoreClrProjectRoot)System.Private.CoreLib\src\System\Runtime\CompilerServices\AsyncHelpers.TaskContinuation.cs" />
<Compile Include="$(CoreClrProjectRoot)System.Private.CoreLib\src\System\Runtime\CompilerServices\AsyncHelpers.ValueTaskContinuation.cs" />
<Compile Include="$(CoreClrProjectRoot)System.Private.CoreLib\src\System\Runtime\CompilerServices\AsyncProfiler.CoreCLR.cs" />
</ItemGroup>
Expand Down
14 changes: 11 additions & 3 deletions src/coreclr/tools/Common/TypeSystem/IL/Stubs/AsyncThunks.cs
Original file line number Diff line number Diff line change
Expand Up @@ -341,13 +341,17 @@ public static MethodIL EmitAsyncMethodThunk(MethodDesc asyncMethod, MethodDesc t
// Task path
TypeDesc taskType = taskReturningMethodReturnType;
MethodDesc completedTaskResultMethod;
MethodDesc transparentAwaitMethod;

if (!taskReturningMethodReturnType.HasInstantiation)
{
// Task (non-generic)
completedTaskResultMethod = context.SystemModule
.GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8)
.GetKnownMethod("CompletedTask"u8, null);
transparentAwaitMethod = context.SystemModule
.GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8)
.GetKnownMethod("TransparentAwait"u8, null);
}
else
{
Expand All @@ -357,7 +361,12 @@ public static MethodIL EmitAsyncMethodThunk(MethodDesc asyncMethod, MethodDesc t
MethodDesc completedTaskResultMethodOpen = context.SystemModule
.GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8)
.GetKnownMethod("CompletedTaskResult"u8, null);
MethodDesc transparentAwaitMethodOpen = context.SystemModule
.GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8)
.GetKnownMethod("TransparentAwaitOfT"u8, null);

completedTaskResultMethod = completedTaskResultMethodOpen.MakeInstantiatedMethod(new Instantiation(logicalReturnType));
transparentAwaitMethod = transparentAwaitMethodOpen.MakeInstantiatedMethod(new Instantiation(logicalReturnType));
}

ILLocalVariable taskLocal = emitter.NewLocal(taskType);
Expand All @@ -373,9 +382,8 @@ public static MethodIL EmitAsyncMethodThunk(MethodDesc asyncMethod, MethodDesc t
codestream.Emit(ILOpcode.brtrue, getResultLabel);

codestream.EmitLdLoc(taskLocal);
codestream.Emit(ILOpcode.call, emitter.NewToken(
context.SystemModule.GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8)
.GetKnownMethod("TransparentAwait"u8, null)));
codestream.Emit(ILOpcode.call, emitter.NewToken(context.GetCoreLibEntryPoint("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8, "TailAwait"u8, null)));
codestream.Emit(ILOpcode.call, emitter.NewToken(transparentAwaitMethod));

codestream.EmitLabel(getResultLabel);
codestream.EmitLdLoc(taskLocal);
Expand Down
20 changes: 17 additions & 3 deletions src/coreclr/vm/asyncthunks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ void MethodDesc::EmitAsyncMethodThunk(MethodDesc* pTaskReturningVariant, MetaSig
// Task task = other(arg);
// if (!task.IsCompleted)
// {
// // Magic function which will suspend the current run of async methods
// TailAwait();
// AsyncHelpers.TransparentAwait(task);
// }
// return AsyncHelpers.CompletedTaskResult(task);
Expand Down Expand Up @@ -595,37 +595,51 @@ void MethodDesc::EmitAsyncMethodThunk(MethodDesc* pTaskReturningVariant, MetaSig
MethodTable* pMTTask;

int completedTaskResultToken;
int transparentAwaitToken;

if (msig.IsReturnTypeVoid())
{
pMTTask = CoreLibBinder::GetClass(CLASS__TASK);

MethodDesc* pMDCompletedTask = CoreLibBinder::GetMethod(METHOD__ASYNC_HELPERS__COMPLETED_TASK);
MethodDesc* pMDTransparentAwait = CoreLibBinder::GetMethod(METHOD__ASYNC_HELPERS__TRANSPARENT_AWAIT);

completedTaskResultToken = pCode->GetToken(pMDCompletedTask);
transparentAwaitToken = pCode->GetToken(pMDTransparentAwait);
}
else
{
MethodTable* pMTTaskOpen = CoreLibBinder::GetClass(CLASS__TASK_1);
pMTTask = ClassLoader::LoadGenericInstantiationThrowing(pMTTaskOpen->GetModule(), pMTTaskOpen->GetCl(), Instantiation(&thLogicalRetType, 1)).GetMethodTable();

MethodDesc* pMDCompletedTaskResult = CoreLibBinder::GetMethod(METHOD__ASYNC_HELPERS__COMPLETED_TASK_RESULT);
MethodDesc* pMDTransparentAwait = CoreLibBinder::GetMethod(METHOD__ASYNC_HELPERS__TRANSPARENT_AWAIT_OF_T);

pMDCompletedTaskResult = FindOrCreateAssociatedMethodDesc(pMDCompletedTaskResult, pMDCompletedTaskResult->GetMethodTable(), FALSE, Instantiation(&thLogicalRetType, 1), FALSE);
pMDTransparentAwait = FindOrCreateAssociatedMethodDesc(pMDTransparentAwait, pMDTransparentAwait->GetMethodTable(), FALSE, Instantiation(&thLogicalRetType, 1), FALSE);

completedTaskResultToken = GetTokenForGenericMethodCallWithAsyncReturnType(pCode, pMDCompletedTaskResult);
transparentAwaitToken = GetTokenForGenericMethodCallWithAsyncReturnType(pCode, pMDTransparentAwait);
}

LocalDesc taskLocalDesc(pMTTask);
DWORD taskLocal = pCode->NewLocal(taskLocalDesc);
ILCodeLabel* pGetResultLabel = pCode->NewCodeLabel();

// Store task returned by actual user func or by ValueTask.AsTask
// Store task returned by actual user func
pCode->EmitSTLOC(taskLocal);

// Did it already complete?
pCode->EmitLDLOC(taskLocal);
pCode->EmitCALL(METHOD__TASK__GET_ISCOMPLETED, 1, 1);
pCode->EmitBRTRUE(pGetResultLabel);

// No, so tail await to TransparentAwait
pCode->EmitLDLOC(taskLocal);
pCode->EmitCALL(METHOD__ASYNC_HELPERS__TRANSPARENT_AWAIT, 1, 0);
pCode->EmitCALL(METHOD__ASYNC_HELPERS__TAIL_AWAIT, 0, 0);
pCode->EmitCALL(transparentAwaitToken, 1, 0);

// Yes, so just get the result
pCode->EmitLabel(pGetResultLabel);
pCode->EmitLDLOC(taskLocal);
pCode->EmitCALL(completedTaskResultToken, 1, msig.IsReturnTypeVoid() ? 0 : 1);
Expand Down
Loading
Loading