Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Allow multiple calls to GetAsyncEnumerator (#31105)
  • Loading branch information
jcouv committed Nov 18, 2018
1 parent 578ce2e commit 71e1473
Show file tree
Hide file tree
Showing 8 changed files with 483 additions and 176 deletions.
42 changes: 36 additions & 6 deletions docs/features/async-streams.md
Expand Up @@ -96,7 +96,8 @@ The state machine for an async-iterator method primarily implements `IAsyncEnume
It is similar to a state machine produced for an async method. It contains builder and awaiter fields, used to run the state machine in the background (when an `await` is reached in the async-iterator). It also captures parameter values (if any) or `this` (if needed).
But it contains additional state:
- a promise of a value-or-end,
- a current yielded value of type `T`.
- a current yielded value of type `T`,
- an `int` capturing the id of the thread that created it.

The central method of the state machine is `MoveNext()`. It gets run by `MoveNextAsync()`, or as a background continuation initiated from these from an `await` in the method.

Expand All @@ -120,17 +121,46 @@ This is reflected in the implementation, which extends the lowering machinery fo
```C#
ValueTask<bool> MoveNextAsync()
{
if (State == StateMachineStates.FinishedStateMachine)
if (state == StateMachineStates.FinishedStateMachine)
{
return default(ValueTask<bool>);
}
_valueOrEndPromise.Reset();
valueOrEndPromise.Reset();
var inst = this;
_builder.Start(ref inst);
return new ValueTask<bool>(this, _valueOrEndPromise.Version);
builder.Start(ref inst);
return new ValueTask<bool>(this, valueOrEndPromise.Version);
}
```

```C#
T Current => _current;
T Current => current;
```

The kick-off method and the initialization of the state machine for an async-iterator method follows those for regular iterator methods.
In particular, the synthesized `GetAsyncEnumerator()` method is like `GetEnuemrator()` except that it sets the initial state to to StateMachineStates.NotStartedStateMachine (-1):
```C#
IAsyncEnumerator<T> GetAsyncEnumerator()
{
{StateMachineType} result;
if (initialThreadId == /*managedThreadId*/ && state == StateMachineStates.FinishedStateMachine)
{
state = StateMachineStates.NotStartedStateMachine;
result = this;
}
else
{
result = new {StateMachineType}(StateMachineStates.NotStartedStateMachine);
}
/* copy all of the parameter proxies */
}
```
For a discussion of the threadID check, see https://github.com/dotnet/corefx/issues/3481

Similarly, the kick-off method is much like those of regular iterator methods:
```C#
{
{StateMachineType} result = new {StateMachineType}(StateMachineStates.FinishedStateMachine); // -2
/* save parameters into parameter proxies */
return result;
}
```
Expand Up @@ -59,7 +59,7 @@ protected override void VerifyPresenceOfRequiredAPIs(DiagnosticBag bag)

protected override void GenerateMethodImplementations()
{
// IAsyncStateMachine and constructor
// IAsyncStateMachine methods and constructor
base.GenerateMethodImplementations();

// IAsyncEnumerable
Expand Down Expand Up @@ -100,52 +100,71 @@ protected override void GenerateControlFields()
_currentField = F.StateMachineField(elementType, GeneratedNames.MakeIteratorCurrentFieldName());
}

/// <summary>
/// Generates the body of the replacement method, which initializes the state machine. Unlike regular async methods, we won't start it.
/// </summary>
protected override BoundStatement GenerateStateMachineCreation(LocalSymbol stateMachineVariable, NamedTypeSymbol frameType)
protected override void GenerateConstructor()
{
// If the async method's result type is a type parameter of the method, then the AsyncTaskMethodBuilder<T>
// needs to use the method's type parameters inside the rewritten method body. All other methods generated
// during async rewriting are members of the synthesized state machine struct, and use the type parameters
// from the struct.
AsyncMethodBuilderMemberCollection methodScopeAsyncMethodBuilderMemberCollection;
if (!AsyncMethodBuilderMemberCollection.TryCreate(F, method, null, out methodScopeAsyncMethodBuilderMemberCollection))
// Produces:
// .ctor(int state)
// {
// this.state = state;
// this.initialThreadId = {managedThreadId};
// this.builder = System.Runtime.CompilerServices.AsyncVoidMethodBuilder.Create();
// this.valueOrEndPromise = new ManualResetValueTaskSourceLogic<bool>(this);
// }
Debug.Assert(stateMachineType.Constructor is IteratorConstructor);

F.CurrentFunction = stateMachineType.Constructor;
var bodyBuilder = ArrayBuilder<BoundStatement>.GetInstance();
bodyBuilder.Add(F.BaseInitialization());
bodyBuilder.Add(F.Assignment(F.Field(F.This(), stateField), F.Parameter(F.CurrentFunction.Parameters[0]))); // this.state = state;

var managedThreadId = MakeCurrentThreadId();
if (managedThreadId != null && (object)initialThreadIdField != null)
{
return new BoundBadStatement(F.Syntax, ImmutableArray<BoundNode>.Empty, hasErrors: true);
// this.initialThreadId = {managedThreadId};
bodyBuilder.Add(F.Assignment(F.Field(F.This(), initialThreadIdField), managedThreadId));
}

var bodyBuilder = ArrayBuilder<BoundStatement>.GetInstance();
// this.builder = System.Runtime.CompilerServices.AsyncVoidMethodBuilder.Create();
AsyncMethodBuilderMemberCollection methodScopeAsyncMethodBuilderMemberCollection;
bool found = AsyncMethodBuilderMemberCollection.TryCreate(F, method, typeMap: null, out methodScopeAsyncMethodBuilderMemberCollection);
Debug.Assert(found);

// local.$builder = System.Runtime.CompilerServices.AsyncTaskMethodBuilder<typeArgs>.Create();
bodyBuilder.Add(
F.Assignment(
F.Field(F.Local(stateMachineVariable), _builderField.AsMember(frameType)),
F.Field(F.This(), _builderField),
F.StaticCall(
null,
methodScopeAsyncMethodBuilderMemberCollection.CreateBuilder)));

// local.$stateField = NotStartedStateMachine;
bodyBuilder.Add(
F.Assignment(
F.Field(F.Local(stateMachineVariable), stateField.AsMember(frameType)),
F.Literal(StateMachineStates.NotStartedStateMachine)));

// local._valueOrEndPromise = new ManualResetValueTaskSourceLogic<bool>(stateMachine);
// this._valueOrEndPromise = new ManualResetValueTaskSourceLogic<bool>(this);
MethodSymbol mrvtslCtor =
F.WellKnownMethod(WellKnownMember.System_Threading_Tasks_ManualResetValueTaskSourceLogic_T__ctor)
.AsMember((NamedTypeSymbol)_promiseOfValueOrEndField.Type.TypeSymbol);

bodyBuilder.Add(
F.Assignment(
F.Field(F.Local(stateMachineVariable), _promiseOfValueOrEndField.AsMember(frameType)),
F.New(mrvtslCtor, F.Local(stateMachineVariable))));
F.Field(F.This(), _promiseOfValueOrEndField),
F.New(mrvtslCtor, F.This())));

// return local;
bodyBuilder.Add(F.Return(F.Local(stateMachineVariable)));
bodyBuilder.Add(F.Return());
F.CloseMethod(F.Block(bodyBuilder.ToImmutableAndFree()));
bodyBuilder = null;
}

return F.Block(
bodyBuilder.ToImmutableAndFree());
protected override void InitializeStateMachine(ArrayBuilder<BoundStatement> bodyBuilder, NamedTypeSymbol frameType, LocalSymbol stateMachineLocal)
{
// var stateMachineLocal = new {StateMachineType}(FinishedStateMachine)
int initialState = StateMachineStates.FinishedStateMachine;
bodyBuilder.Add(
F.Assignment(
F.Local(stateMachineLocal),
F.New(stateMachineType.Constructor.AsMember(frameType), F.Literal(initialState))));
}

protected override BoundStatement GenerateStateMachineCreation(LocalSymbol stateMachineVariable, NamedTypeSymbol frameType)
{
// return local;
return F.Block(F.Return(F.Local(stateMachineVariable)));
}

/// <summary>
Expand Down Expand Up @@ -396,8 +415,6 @@ private void GenerateIAsyncDisposable_DisposeAsync()
/// </summary>
private void GenerateIAsyncEnumerableImplementation_GetAsyncEnumerator()
{
// https://github.com/dotnet/roslyn/issues/30275 do the threadID dance to decide if we can return this or should instantiate.

NamedTypeSymbol IAsyncEnumerableOfElementType =
F.WellKnownType(WellKnownType.System_Collections_Generic_IAsyncEnumerable_T)
.Construct(_currentField.Type.TypeSymbol);
Expand All @@ -406,14 +423,8 @@ private void GenerateIAsyncEnumerableImplementation_GetAsyncEnumerator()
F.WellKnownMethod(WellKnownMember.System_Collections_Generic_IAsyncEnumerable_T__GetAsyncEnumerator)
.AsMember(IAsyncEnumerableOfElementType);

// The implementation doesn't depend on the method body of the iterator method.
// Generates IAsyncEnumerator<elementType> IAsyncEnumerable<elementType>.GetEnumerator()
OpenMethodImplementation(IAsyncEnumerableOfElementType_GetEnumerator, hasMethodBodyDependency: false);

// https://github.com/dotnet/roslyn/issues/30275 0 may not be the proper state to start with
F.CloseMethod(F.Block(
//F.Assignment(F.Field(F.This(), stateField), F.Literal(StateMachineStates.FirstUnusedState)), // this.state = 0;
F.Return(F.This()))); // return this;
BoundExpression managedThreadId = null;
GenerateIteratorGetEnumerator(IAsyncEnumerableOfElementType_GetEnumerator, ref managedThreadId, StateMachineStates.NotStartedStateMachine);
}

protected override void GenerateMoveNext(SynthesizedImplementationMethod moveNextMethod)
Expand Down
Expand Up @@ -106,10 +106,9 @@ private Symbol EnsureWellKnownMember(WellKnownMember member, DiagnosticBag bag)
return Binder.GetWellKnownTypeMember(F.Compilation, member, bag, body.Syntax.Location);
}

protected override bool PreserveInitialParameterValues
{
get { return false; }
}
// Should only be true for async-enumerables, not async-enumerators. Tracked by https://github.com/dotnet/roslyn/issues/31057
protected override bool PreserveInitialParameterValuesAndThreadId
=> method.IsIterator;

protected override void GenerateControlFields()
{
Expand Down Expand Up @@ -158,6 +157,11 @@ protected override void GenerateMethodImplementations()
}

// Constructor
GenerateConstructor();
}

protected virtual void GenerateConstructor()
{
if (stateMachineType.TypeKind == TypeKind.Class)
{
F.CurrentFunction = stateMachineType.Constructor;
Expand Down
Expand Up @@ -25,7 +25,8 @@ public AsyncStateMachine(VariableSlotAllocator variableAllocatorOpt, TypeCompila
CSharpCompilation compilation = asyncMethod.DeclaringCompilation;
var interfaces = ArrayBuilder<NamedTypeSymbol>.GetInstance();

if (asyncMethod.IsIterator)
bool isIterator = asyncMethod.IsIterator;
if (isIterator)
{
var elementType = TypeMap.SubstituteType(asyncMethod.IteratorElementType).TypeSymbol;
this.IteratorElementType = elementType;
Expand All @@ -51,7 +52,7 @@ public AsyncStateMachine(VariableSlotAllocator variableAllocatorOpt, TypeCompila
interfaces.Add(compilation.GetWellKnownType(WellKnownType.System_Runtime_CompilerServices_IAsyncStateMachine));
_interfaces = interfaces.ToImmutableAndFree();

_constructor = new AsyncConstructor(this);
_constructor = isIterator ? (MethodSymbol)new IteratorConstructor(this) : new AsyncConstructor(this);
}

public override TypeKind TypeKind
Expand Down
Expand Up @@ -15,7 +15,7 @@ internal sealed class IteratorConstructor : SynthesizedInstanceConstructor, ISyn
{
private readonly ImmutableArray<ParameterSymbol> _parameters;

internal IteratorConstructor(IteratorStateMachine container)
internal IteratorConstructor(StateMachineTypeSymbol container)
: base(container)
{
var intType = container.DeclaringCompilation.GetSpecialType(SpecialType.System_Int32);
Expand Down

0 comments on commit 71e1473

Please sign in to comment.