Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bugs related to channel dispose while there are active calls #2120

Merged
merged 8 commits into from
May 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Grpc.Net.Client/Configuration/HedgingPolicy.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -37,7 +37,7 @@ public sealed class HedgingPolicy : ConfigObject
internal const string HedgingDelayPropertyName = "hedgingDelay";
internal const string NonFatalStatusCodesPropertyName = "nonFatalStatusCodes";

private ConfigProperty<Values<StatusCode, object>, IList<object>> _nonFatalStatusCodes =
private readonly ConfigProperty<Values<StatusCode, object>, IList<object>> _nonFatalStatusCodes =
new(i => new Values<StatusCode, object>(i ?? new List<object>(), s => ConvertHelpers.ConvertStatusCode(s), s => ConvertHelpers.ConvertStatusCode(s.ToString()!)), NonFatalStatusCodesPropertyName);

/// <summary>
Expand Down
38 changes: 25 additions & 13 deletions src/Grpc.Net.Client/GrpcChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,13 @@ internal void RegisterActiveCall(IDisposable grpcCall)
{
lock (_lock)
{
// Test the disposed flag inside the lock to ensure there is no chance of a race and adding a call after dispose.
// Note that a GrpcCall has been created but hasn't been started. The error will prevent it from starting.
if (Disposed)
{
throw new ObjectDisposedException(nameof(GrpcChannel));
}

ActiveCalls.Add(grpcCall);
}
}
Expand Down Expand Up @@ -733,23 +740,29 @@ public Task WaitForStateChangedAsync(ConnectivityState lastObservedState, Cancel
/// </summary>
public void Dispose()
{
if (Disposed)
{
return;
}

IDisposable[]? activeCallsCopy = null;
lock (_lock)
{
// Check and set disposed flag inside lock.
if (Disposed)
{
return;
}

if (ActiveCalls.Count > 0)
{
// Disposing a call will remove it from ActiveCalls. Need to take a copy
// to avoid the collection being modified while it is enumerated.
var activeCallsCopy = ActiveCalls.ToArray();
activeCallsCopy = ActiveCalls.ToArray();
}

foreach (var activeCall in activeCallsCopy)
{
activeCall.Dispose();
}
Disposed = true;
}

// Dispose calls outside of lock to avoid chance of deadlock.
if (activeCallsCopy is not null)
{
foreach (var activeCall in activeCallsCopy)
{
activeCall.Dispose();
}
}

Expand All @@ -760,7 +773,6 @@ public void Dispose()
#if SUPPORT_LOAD_BALANCING
ConnectionManager.Dispose();
#endif
Disposed = true;
}

internal bool TryAddToRetryBuffer(long messageSize)
Expand Down
7 changes: 1 addition & 6 deletions src/Grpc.Net.Client/Internal/HttpClientCallInvoker.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -150,11 +150,6 @@ public HttpClientCallInvoker(GrpcChannel channel)
where TRequest : class
where TResponse : class
{
if (channel.Disposed)
{
throw new ObjectDisposedException(nameof(GrpcChannel));
}

var methodInfo = channel.GetCachedGrpcMethodInfo(method);
var call = new GrpcCall<TRequest, TResponse>(method, methodInfo, options, channel, attempt);

Expand Down
4 changes: 3 additions & 1 deletion src/Grpc.Net.Client/Internal/Retry/HedgingCall.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -54,6 +54,8 @@ public HedgingCall(HedgingPolicyInfo hedgingPolicy, GrpcChannel channel, Method<
_delayInterruptTcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
_hedgingDelayCts = new CancellationTokenSource();
}

Channel.RegisterActiveCall(this);
}

private async Task StartCall(Action<GrpcCall<TRequest, TResponse>> startCallFunc)
Expand Down
4 changes: 3 additions & 1 deletion src/Grpc.Net.Client/Internal/Retry/RetryCall.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -42,6 +42,8 @@ public RetryCall(RetryPolicyInfo retryPolicy, GrpcChannel channel, Method<TReque
_retryPolicy = retryPolicy;

_nextRetryDelayMilliseconds = Convert.ToInt32(retryPolicy.InitialBackoff.TotalMilliseconds);

Channel.RegisterActiveCall(this);
}

private int CalculateNextRetryDelay()
Expand Down
4 changes: 3 additions & 1 deletion src/Grpc.Net.Client/Internal/Retry/RetryCallBase.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -440,6 +440,8 @@ protected virtual void Dispose(bool disposing)

protected void Cleanup()
{
Channel.FinishActiveCall(this);

_ctsRegistration?.Dispose();
_ctsRegistration = null;
CancellationTokenSource.Cancel();
Expand Down
6 changes: 5 additions & 1 deletion test/FunctionalTests/Client/RetryTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -108,6 +108,8 @@ await foreach (var message in requestStream.ReadAllAsync())

// Assert
Assert.IsTrue(result.Data.Span.SequenceEqual(sentData.ToArray()));

Assert.AreEqual(0, channel.ActiveCalls.Count);
}

[Test]
Expand Down Expand Up @@ -390,6 +392,8 @@ Task FakeServerStreamCall(DataMessage request, IServerStreamWriter<DataMessage>
await MakeCallsAsync(channel, method, references, cts.Token).DefaultTimeout();

// Assert
Assert.AreEqual(0, channel.ActiveCalls.Count);

// There is a race when cleaning up cancellation token registry.
// Retry a few times to ensure GC is run after unregister.
await TestHelpers.AssertIsTrueRetryAsync(() =>
Expand Down
82 changes: 81 additions & 1 deletion test/Grpc.Net.Client.Tests/AsyncDuplexStreamingCallTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand All @@ -17,13 +17,16 @@
#endregion

using System.Net;
using System.Threading.Tasks;
using Greet;
using Grpc.Core;
using Grpc.Net.Client.Internal;
using Grpc.Net.Client.Internal.Http;
using Grpc.Net.Client.Tests.Infrastructure;
using Grpc.Shared;
using Grpc.Tests.Shared;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using NUnit.Framework;

namespace Grpc.Net.Client.Tests;
Expand Down Expand Up @@ -177,4 +180,81 @@ public async Task AsyncDuplexStreamingCall_MessagesStreamed_MessagesReceived()
Assert.IsTrue(moveNextTask4.IsCompleted);
Assert.IsFalse(await moveNextTask3.DefaultTimeout());
}

// Regression test: cancelling a call's token and disposing its channel at the same
// moment must both finish promptly. A deadlock between the two shows up as a timeout
// on the final await. The scenario is repeated because the outcome depends on a race.
[Test]
public async Task AsyncDuplexStreamingCall_CancellationDisposeRace_Success()
{
// Arrange
var services = new ServiceCollection();
services.AddNUnitLogger();
var loggerFactory = services.BuildServiceProvider().GetRequiredService<ILoggerFactory>();
var logger = loggerFactory.CreateLogger(GetType());

// Each iteration builds a fresh handler, channel, and call so attempts do not share state.
for (var i = 0; i < 20; i++)
{
// Let's mimic a real call first to get GrpcCall.RunCall where we need to for reproducing the deadlock.
var streamContent = new SyncPointMemoryStream();
// Hands the handler's request-stream task back to the test body.
var requestContentTcs = new TaskCompletionSource<Task<Stream>>(TaskCreationOptions.RunContinuationsAsynchronously);

PushStreamContent<HelloRequest, HelloReply>? content = null;

// Fake transport: captures the request content, publishes its stream task through
// requestContentTcs, and only responds once that stream task has completed.
var handler = TestHttpMessageHandler.Create(async request =>
{
content = (PushStreamContent<HelloRequest, HelloReply>)request.Content!;
var streamTask = content.ReadAsStreamAsync();
requestContentTcs.SetResult(streamTask);
// Wait for RequestStream.CompleteAsync()
await streamTask;
return ResponseUtils.CreateResponse(HttpStatusCode.OK, new StreamContent(streamContent));
});
var channel = GrpcChannel.ForAddress("http://localhost", new GrpcChannelOptions
{
HttpHandler = handler,
LoggerFactory = loggerFactory
});
var invoker = channel.CreateCallInvoker();

// Token source that one of the racing tasks cancels below.
var cts = new CancellationTokenSource();

// Start the duplex call, send a single message, and complete the request stream.
var call = invoker.AsyncDuplexStreamingCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(cancellationToken: cts.Token));
await call.RequestStream.WriteAsync(new HelloRequest { Name = "1" }).DefaultTimeout();
await call.RequestStream.CompleteAsync().DefaultTimeout();

// Read back the message the handler received to confirm the request was actually sent.
var deserializationContext = new DefaultDeserializationContext();
// The outer await yields the handler's stream task; the inner await yields the stream itself.
var requestContent = await await requestContentTcs.Task.DefaultTimeout();
var requestMessage = await StreamSerializationHelper.ReadMessageAsync(
requestContent,
ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer,
GrpcProtocolConstants.IdentityGrpcEncoding,
maximumMessageSize: null,
GrpcProtocolConstants.DefaultCompressionProviders,
singleMessage: false,
CancellationToken.None).DefaultTimeout();
Assert.AreEqual("1", requestMessage!.Name);

// Shared gate: both racing tasks wait on it so they are released together.
var actTcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);

// Racer 1: cancels the call's token once the gate opens.
var cancellationTask = Task.Run(async () =>
{
await actTcs.Task;
cts.Cancel();
});
// Racer 2: disposes the channel once the gate opens.
var disposingTask = Task.Run(async () =>
{
await actTcs.Task;
channel.Dispose();
});

// Small pause to make sure we're waiting at the TCS everywhere.
await Task.Delay(50);

// Act
actTcs.SetResult(null);

// Assert
// Cancellation and disposing should both complete quickly. If there is a deadlock then the await will timeout.
await Task.WhenAll(cancellationTask, disposingTask).DefaultTimeout();
}
}
}
53 changes: 51 additions & 2 deletions test/Grpc.Net.Client.Tests/Retry/HedgingTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -72,6 +72,8 @@ public async Task AsyncUnaryCall_OneAttempt_Success(int maxAttempts)
var rs = await call.ResponseAsync.DefaultTimeout();
Assert.AreEqual("Hello world", rs.Message);
Assert.AreEqual(StatusCode.OK, call.GetStatus().StatusCode);

Assert.AreEqual(0, invoker.Channel.ActiveCalls.Count);
}

[Test]
Expand Down Expand Up @@ -591,7 +593,6 @@ public async Task AsyncClientStreamingCall_SuccessAfterRetry_RequestContentSent(
var responseTask = call.ResponseAsync;
Assert.IsFalse(responseTask.IsCompleted, "Response not returned until client stream is complete.");


await call.RequestStream.WriteAsync(new HelloRequest { Name = "1" }).DefaultTimeout();
await call.RequestStream.WriteAsync(new HelloRequest { Name = "2" }).DefaultTimeout();

Expand Down Expand Up @@ -687,6 +688,54 @@ public async Task AsyncClientStreamingCall_WriteAfterResult_Error()
Assert.AreEqual(StatusCode.OK, ex.StatusCode);
}

// Starting a unary call through an invoker whose channel has already been disposed
// must throw ObjectDisposedException.
[Test]
public void AsyncUnaryCall_DisposedChannel_Error()
{
    // Arrange: hedging-enabled invoker backed by a handler that returns a plain 200 response.
    var httpClient = ClientTestHelpers.CreateTestClient(
        request => Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.OK)));
    var hedgingConfig = ServiceConfigHelpers.CreateHedgingServiceConfig();
    var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: hedgingConfig);

    // Act: dispose the channel before any call is made.
    invoker.Channel.Dispose();

    // Assert: creating the call on the disposed channel fails immediately.
    Assert.Throws<ObjectDisposedException>(() =>
    {
        invoker.AsyncUnaryCall<HelloRequest, HelloReply>(
            ClientTestHelpers.ServiceMethod,
            string.Empty,
            new CallOptions(),
            new HelloRequest { Name = "World" });
    });
}

// Disposing the channel while a hedging call is still pending must complete the call
// with a Cancelled status and the "gRPC call disposed." detail.
[Test]
public async Task AsyncUnaryCall_ChannelDisposeDuringBackoff_CanceledStatus()
{
    // Arrange
    var httpClient = ClientTestHelpers.CreateTestClient(async request =>
    {
        // Drain the request body, then answer with Unavailable plus a 10 second
        // retry-pushback header so no final result is produced within the test's wait window.
        await request.Content!.CopyToAsync(new MemoryStream());
        return ResponseUtils.CreateHeadersOnlyResponse(HttpStatusCode.OK, StatusCode.Unavailable, retryPushbackHeader: TimeSpan.FromSeconds(10).TotalMilliseconds.ToString(CultureInfo.InvariantCulture));
    });
    var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(hedgingDelay: TimeSpan.FromSeconds(10));
    var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig);
    var cts = new CancellationTokenSource();

    // Act
    var call = invoker.AsyncUnaryCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(cancellationToken: cts.Token), new HelloRequest { Name = "World" });

    var delayTask = Task.Delay(100);
    var completedTask = await Task.WhenAny(call.ResponseAsync, delayTask);

    // Assert
    // The short delay must win the race: the response is still pending, so the call is
    // waiting rather than already finished when the channel is disposed.
    Assert.AreEqual(delayTask, completedTask);

    invoker.Channel.Dispose();

    var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => call.ResponseAsync).DefaultTimeout();
    Assert.AreEqual(StatusCode.Cancelled, ex.StatusCode);
    Assert.AreEqual("gRPC call disposed.", ex.Status.Detail);
}

private static Task<HelloRequest?> ReadRequestMessage(Stream requestContent)
{
return StreamSerializationHelper.ReadMessageAsync(
Expand Down