diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs index 6a09abb8d026..9d5b908af61d 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs @@ -565,20 +565,13 @@ public IAsyncEnumerable StreamAsyncCore(string methodName, obj private async IAsyncEnumerable CastIAsyncEnumerable(string methodName, object[] args, CancellationTokenSource cts) { var reader = await StreamAsChannelCoreAsync(methodName, typeof(T), args, cts.Token); - try + while (await reader.WaitToReadAsync(cts.Token)) { - while (await reader.WaitToReadAsync(cts.Token)) + while (reader.TryRead(out var item)) { - while (reader.TryRead(out var item)) - { - yield return (T)item; - } + yield return (T)item; } } - finally - { - cts.Dispose(); - } } private async Task> StreamAsChannelCoreAsyncCore(string methodName, Type returnType, object[] args, CancellationToken cancellationToken) diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.cs index 618d6f14a92b..c1090a0df66b 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.cs @@ -254,6 +254,36 @@ public async Task StreamAsyncCanceledWhenPassedCanceledToken() } } + [Fact] + public async Task CanCancelTokenAfterStreamIsCompleted() + { + using (StartVerifiableLog()) + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection, loggerFactory: LoggerFactory); + + await hubConnection.StartAsync().OrTimeout(); + + var asyncEnumerable = hubConnection.StreamAsync("Stream", 1); + using var cts = new CancellationTokenSource(); + await using var e = asyncEnumerable.GetAsyncEnumerator(cts.Token); + var task = e.MoveNextAsync(); + + var item = await connection.ReadSentJsonAsync().OrTimeout(); + await connection.ReceiveJsonMessage( + new { type = HubProtocolConstants.CompletionMessageType, invocationId = item["invocationId"] } + ).OrTimeout(); + + await task.OrTimeout(); + + while (await e.MoveNextAsync().OrTimeout()) + { + } + // Cancel after stream is completed but before the AsyncEnumerator is disposed + cts.Cancel(); + } + } + [Fact] public async Task ConnectionTerminatedIfServerTimeoutIntervalElapsesWithNoMessages() {