Skip to content

Commit

Permalink
Support IAsyncEnumerable<T> and ChannelReader<T> with ValueTypes in S…
Browse files Browse the repository at this point in the history
…ignalR native AOT

Support streaming ValueTypes from a SignalR Hub method in both the client and the server in native AOT. In order to make this work, we need to use pure reflection to read from the streaming object.

Support passing in an IAsyncEnumerable/ChannelReader of ValueType to a parameter in SignalR.Client. This works because the user code creates the concrete object, and the SignalR.Client library just needs to read from it using reflection.

The only scenario that can't be supported is on the SignalR server we can't support receiving an IAsyncEnumerable/ChannelReader of ValueType. This is because there is no way for the SignalR library code to construct a concrete instance to pass into the user-defined method on native AOT.

Fix dotnet#56179
  • Loading branch information
eerhardt committed Jul 2, 2024
1 parent 102f4bd commit c3f7649
Show file tree
Hide file tree
Showing 5 changed files with 351 additions and 48 deletions.
63 changes: 53 additions & 10 deletions src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -857,26 +857,69 @@ private void LaunchStreams(ConnectionState connectionState, Dictionary<string, o
[UnconditionalSuppressMessage("Trimming", "IL2060:MakeGenericMethod",
Justification = "The methods passed into here (SendStreamItems and SendIAsyncEnumerableStreamItems) don't have trimming annotations.")]
[UnconditionalSuppressMessage("AOT", "IL3050:RequiresDynamicCode",
Justification = "There is a runtime check for ValueType streaming item type when PublishAot=true. Developers will get an exception in this situation before publishing.")]
Justification = "ValueTypes are handled without using MakeGenericMethod.")]
private void InvokeStreamMethod(MethodInfo methodInfo, Type[] genericTypes, ConnectionState connectionState, string streamId, object reader, CancellationTokenSource tokenSource)
{
#if NET
Debug.Assert(genericTypes.Length == 1);

#if NET6_0_OR_GREATER
if (!RuntimeFeature.IsDynamicCodeSupported && genericTypes[0].IsValueType)
{
// NativeAOT apps are not able to stream IAsyncEnumerable and ChannelReader of ValueTypes
// since we cannot create SendStreamItems and SendIAsyncEnumerableStreamItems methods with a generic ValueType.
throw new InvalidOperationException($"Unable to stream an item with type '{genericTypes[0]}' because it is a ValueType. Native code to support streaming this ValueType will not be available with native AOT.");
_ = ReflectionSendStreamItems(methodInfo, connectionState, streamId, reader, tokenSource);
}
else
#endif
{
_ = methodInfo
.MakeGenericMethod(genericTypes)
.Invoke(this, [connectionState, streamId, reader, tokenSource]);
}
}

#if NET6_0_OR_GREATER

/// <summary>
/// Uses reflection to read items from an IAsyncEnumerable{T} or ChannelReader{T} and send them to the server.
///
/// Used when the runtime does not support dynamic code generation (ex. native AOT) and the generic type is a value type. In this scenario,
/// we cannot use MakeGenericMethod to call the appropriate SendStreamItems method because the generic type is a value type.
/// </summary>
private Task ReflectionSendStreamItems(MethodInfo methodInfo, ConnectionState connectionState, string streamId, object reader, CancellationTokenSource tokenSource)
{
async Task ReadAsyncEnumeratorStream(IAsyncEnumerator<object?> enumerator)
{
try
{
while (await enumerator.MoveNextAsync().ConfigureAwait(false))
{
await SendWithLock(connectionState, new StreamItemMessage(streamId, enumerator.Current), tokenSource.Token).ConfigureAwait(false);
Log.SendingStreamItem(_logger, streamId);
}
}
finally
{
await enumerator.DisposeAsync().ConfigureAwait(false);
}
}

_ = methodInfo
.MakeGenericMethod(genericTypes)
.Invoke(this, [connectionState, streamId, reader, tokenSource]);
Func<Task> createAndConsumeStream;
if (methodInfo == _sendStreamItemsMethod)
{
// reader is a ChannelReader<T>
createAndConsumeStream = () => ReadAsyncEnumeratorStream(AsyncEnumerableAdapters.MakeReflectionAsyncEnumeratorFromChannel(reader, tokenSource.Token));
}
else
{
// reader is an IAsyncEnumerable<T>
Debug.Assert(methodInfo == _sendIAsyncStreamItemsMethod);

createAndConsumeStream = () => ReadAsyncEnumeratorStream(AsyncEnumerableAdapters.MakeReflectionAsyncEnumerator(reader, tokenSource.Token));
}

return CommonStreaming(connectionState, streamId, createAndConsumeStream, tokenSource);
}
#endif

// this is called via reflection using the `_sendStreamItems` field
// this is called via reflection using the `_sendStreamItemsMethod` field
private Task SendStreamItems<T>(ConnectionState connectionState, string streamId, ChannelReader<T> reader, CancellationTokenSource tokenSource)
{
async Task ReadChannelStream()
Expand Down
118 changes: 114 additions & 4 deletions src/SignalR/common/Shared/AsyncEnumerableAdapters.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Reflection;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
Expand All @@ -11,7 +13,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal;
// True-internal because this is a weird and tricky class to use :)
internal static class AsyncEnumerableAdapters
{
public static IAsyncEnumerator<object?> MakeCancelableAsyncEnumerator<T>(IAsyncEnumerable<T> asyncEnumerable, CancellationToken cancellationToken = default)
public static IAsyncEnumerator<object?> MakeAsyncEnumerator<T>(IAsyncEnumerable<T> asyncEnumerable, CancellationToken cancellationToken = default)
{
var enumerator = asyncEnumerable.GetAsyncEnumerator(cancellationToken);
return enumerator as IAsyncEnumerator<object?> ?? new BoxedAsyncEnumerator<T>(enumerator);
Expand Down Expand Up @@ -52,10 +54,13 @@ public ValueTask<bool> MoveNextAsync()

private async Task<bool> MoveNextAsyncAwaited()
{
if (await _channel.WaitToReadAsync(_cancellationToken).ConfigureAwait(false) && _channel.TryRead(out var item))
while (await _channel.WaitToReadAsync(_cancellationToken).ConfigureAwait(false))
{
Current = item;
return true;
if (_channel.TryRead(out var item))
{
Current = item;
return true;
}
}
return false;
}
Expand Down Expand Up @@ -137,4 +142,109 @@ public ValueTask DisposeAsync()
return _asyncEnumerator.DisposeAsync();
}
}

#if NET6_0_OR_GREATER

private static readonly MethodInfo _asyncEnumerableGetAsyncEnumeratorMethodInfo = typeof(IAsyncEnumerable<>).GetMethod("GetAsyncEnumerator")!;

/// <summary>
/// Creates an IAsyncEnumerator{object} from an IAsyncEnumerable{T} using reflection.
///
/// Used when the runtime does not support dynamic code generation (ex. native AOT) and the generic type is a value type. In this scenario,
/// we cannot use MakeGenericMethod to call a generic method because the generic type is a value type.
/// </summary>
public static IAsyncEnumerator<object?> MakeReflectionAsyncEnumerator(object asyncEnumerable, CancellationToken cancellationToken)
{
var constructedIAsyncEnumerableInterface = ReflectionHelper.GetIAsyncEnumerableInterface(asyncEnumerable.GetType())!;
var enumerator = ((MethodInfo)constructedIAsyncEnumerableInterface.GetMemberWithSameMetadataDefinitionAs(_asyncEnumerableGetAsyncEnumeratorMethodInfo)).Invoke(asyncEnumerable, [cancellationToken])!;
return new ReflectionAsyncEnumerator(enumerator);
}

/// <summary>
/// Creates an IAsyncEnumerator{object} from a ChannelReader{T} using reflection.
///
/// Used when the runtime does not support dynamic code generation (ex. native AOT) and the generic type is a value type. In this scenario,
/// we cannot use MakeGenericMethod to call a generic method because the generic type is a value type.
/// </summary>
public static IAsyncEnumerator<object?> MakeReflectionAsyncEnumeratorFromChannel(object channelReader, CancellationToken cancellationToken)
{
return new ReflectionChannelAsyncEnumerator(channelReader, cancellationToken);
}

private sealed class ReflectionAsyncEnumerator : IAsyncEnumerator<object?>
{
private static readonly MethodInfo _asyncEnumeratorMoveNextAsyncMethodInfo = typeof(IAsyncEnumerator<>).GetMethod("MoveNextAsync")!;
private static readonly MethodInfo _asyncEnumeratorGetCurrentMethodInfo = typeof(IAsyncEnumerator<>).GetMethod("get_Current")!;

private readonly object _enumerator;
private readonly MethodInfo _moveNextAsyncMethodInfo;
private readonly MethodInfo _getCurrentMethodInfo;

public ReflectionAsyncEnumerator(object enumerator)
{
_enumerator = enumerator;

var type = ReflectionHelper.GetIAsyncEnumeratorInterface(enumerator.GetType())!;
_moveNextAsyncMethodInfo = (MethodInfo)type.GetMemberWithSameMetadataDefinitionAs(_asyncEnumeratorMoveNextAsyncMethodInfo)!;
_getCurrentMethodInfo = (MethodInfo)type.GetMemberWithSameMetadataDefinitionAs(_asyncEnumeratorGetCurrentMethodInfo)!;
}

public object? Current => _getCurrentMethodInfo.Invoke(_enumerator, []);

public ValueTask<bool> MoveNextAsync() => (ValueTask<bool>)_moveNextAsyncMethodInfo.Invoke(_enumerator, [])!;

public ValueTask DisposeAsync() => ((IAsyncDisposable)_enumerator).DisposeAsync();
}

private sealed class ReflectionChannelAsyncEnumerator : IAsyncEnumerator<object?>
{
private static readonly MethodInfo _channelReaderTryReadMethodInfo = typeof(ChannelReader<>).GetMethod("TryRead")!;
private static readonly MethodInfo _channelReaderWaitToReadAsyncMethodInfo = typeof(ChannelReader<>).GetMethod("WaitToReadAsync")!;

private readonly object _channelReader;
private readonly object?[] _tryReadResult = [null];
private readonly object[] _waitToReadArgs;
private readonly MethodInfo _tryReadMethodInfo;
private readonly MethodInfo _waitToReadAsyncMethodInfo;

public ReflectionChannelAsyncEnumerator(object channelReader, CancellationToken cancellationToken)
{
_channelReader = channelReader;
_waitToReadArgs = [cancellationToken];

var type = channelReader.GetType();
_tryReadMethodInfo = (MethodInfo)type.GetMemberWithSameMetadataDefinitionAs(_channelReaderTryReadMethodInfo)!;
_waitToReadAsyncMethodInfo = (MethodInfo)type.GetMemberWithSameMetadataDefinitionAs(_channelReaderWaitToReadAsyncMethodInfo)!;
}

public object? Current { get; private set; }

public ValueTask<bool> MoveNextAsync()
{
if ((bool)_tryReadMethodInfo.Invoke(_channelReader, _tryReadResult)!)
{
Current = _tryReadResult[0];
return new ValueTask<bool>(true);
}

return new ValueTask<bool>(MoveNextAsyncAwaited());
}

private async Task<bool> MoveNextAsyncAwaited()
{
while (await ((ValueTask<bool>)_waitToReadAsyncMethodInfo.Invoke(_channelReader, _waitToReadArgs)!).ConfigureAwait(false))
{
if ((bool)_tryReadMethodInfo.Invoke(_channelReader, _tryReadResult)!)
{
Current = _tryReadResult[0];
return true;
}
}
return false;
}

public ValueTask DisposeAsync() => default;
}

#endif
}
23 changes: 23 additions & 0 deletions src/SignalR/common/Shared/ReflectionHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,27 @@ public static bool TryGetStreamType(Type streamType, [NotNullWhen(true)] out Typ

return null;
}

[UnconditionalSuppressMessage("ReflectionAnalysis", "IL2070:UnrecognizedReflectionPattern",
Justification = "The 'IAsyncEnumerator<>' Type must exist and so trimmer kept it. In which case " +
"It also kept it on any type which implements it. The below call to GetInterfaces " +
"may return fewer results when trimmed but it will return 'IAsyncEnumerator<>' " +
"if the type implemented it, even after trimming.")]
public static Type? GetIAsyncEnumeratorInterface(Type type)
{
if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(IAsyncEnumerator<>))
{
return type;
}

foreach (Type typeToCheck in type.GetInterfaces())
{
if (typeToCheck.IsGenericType && typeToCheck.GetGenericTypeDefinition() == typeof(IAsyncEnumerator<>))
{
return typeToCheck;
}
}

return null;
}
}
Loading

0 comments on commit c3f7649

Please sign in to comment.