Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make CancellationToken available in call credentials interceptor #2107

Merged
merged 5 commits into from
May 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,7 @@ protected override string PeerCore
get
{
// Follows the standard at https://github.com/grpc/grpc/blob/master/doc/naming.md
if (_peer == null)
{
_peer = BuildPeer();
}

return _peer;
return _peer ??= BuildPeer();
}
}

Expand Down Expand Up @@ -291,10 +286,7 @@ private void EndCallCore()

private void LogCallEnd()
{
if (_activity != null)
{
_activity.AddTag(GrpcServerConstants.ActivityStatusCodeTag, _status.StatusCode.ToTrailerString());
}
_activity?.AddTag(GrpcServerConstants.ActivityStatusCodeTag, _status.StatusCode.ToTrailerString());
if (_status.StatusCode != StatusCode.OK)
{
if (GrpcEventSource.Log.IsEnabled())
Expand Down Expand Up @@ -387,10 +379,7 @@ protected override Task WriteResponseHeadersAsyncCore(Metadata responseHeaders)
public void Initialize(ISystemClock? clock = null)
{
_activity = GetHostActivity();
if (_activity != null)
{
_activity.AddTag(GrpcServerConstants.ActivityMethodTag, MethodCore);
}
_activity?.AddTag(GrpcServerConstants.ActivityMethodTag, MethodCore);

if (GrpcEventSource.Log.IsEnabled())
{
Expand Down
24 changes: 21 additions & 3 deletions src/Grpc.Core.Api/AsyncAuthInterceptor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#endregion

using System.Threading;
using System.Threading.Tasks;
using Grpc.Core.Utils;

Expand All @@ -34,16 +35,25 @@ namespace Grpc.Core;
/// </summary>
public class AuthInterceptorContext
{
readonly string serviceUrl;
readonly string methodName;
private readonly string serviceUrl;
private readonly string methodName;
private readonly CancellationToken cancellationToken;

/// <summary>
/// Initializes a new instance of <c>AuthInterceptorContext</c>.
/// </summary>
public AuthInterceptorContext(string serviceUrl, string methodName)
public AuthInterceptorContext(string serviceUrl, string methodName) : this(serviceUrl, methodName, CancellationToken.None)
{
}

/// <summary>
/// Initializes a new instance of <c>AuthInterceptorContext</c>.
/// </summary>
public AuthInterceptorContext(string serviceUrl, string methodName, CancellationToken cancellationToken)
{
this.serviceUrl = GrpcPreconditions.CheckNotNull(serviceUrl, nameof(serviceUrl));
this.methodName = GrpcPreconditions.CheckNotNull(methodName, nameof(methodName));
this.cancellationToken = cancellationToken;
}

/// <summary>
Expand All @@ -61,4 +71,12 @@ public string MethodName
{
get { return methodName; }
}

/// <summary>
/// The cancellation token of the RPC being called.
/// </summary>
public CancellationToken CancellationToken
{
get { return cancellationToken; }
}
}
16 changes: 10 additions & 6 deletions src/Grpc.Net.Client/Internal/DefaultCallCredentialsConfigurator.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand All @@ -20,15 +20,19 @@

namespace Grpc.Net.Client.Internal;

internal class DefaultCallCredentialsConfigurator : CallCredentialsConfiguratorBase
internal sealed class DefaultCallCredentialsConfigurator : CallCredentialsConfiguratorBase
{
public AsyncAuthInterceptor? Interceptor { get; private set; }
public IReadOnlyList<CallCredentials>? Credentials { get; private set; }
public IReadOnlyList<CallCredentials>? CompositeCredentials { get; private set; }

public void Reset()
// A place to cache the context to avoid creating a new instance for each auth interceptor call.
// It's ok not to reset this state because the context is only used for the lifetime of the call.
public AuthInterceptorContext? CachedContext { get; set; }

public void ResetPerCallCredentialState()
{
Interceptor = null;
Credentials = null;
CompositeCredentials = null;
}

public override void SetAsyncAuthInterceptorCredentials(object? state, AsyncAuthInterceptor interceptor)
Expand All @@ -38,6 +42,6 @@ public override void SetAsyncAuthInterceptorCredentials(object? state, AsyncAuth

public override void SetCompositeCredentials(object? state, IReadOnlyList<CallCredentials> credentials)
{
Credentials = credentials;
CompositeCredentials = credentials;
}
}
4 changes: 2 additions & 2 deletions src/Grpc.Net.Client/Internal/GrpcCall.cs
Original file line number Diff line number Diff line change
Expand Up @@ -955,13 +955,13 @@ private async Task ReadCredentials(HttpRequestMessage request)

if (Options.Credentials != null)
{
await GrpcProtocolHelpers.ReadCredentialMetadata(configurator, Channel, request, Method, Options.Credentials).ConfigureAwait(false);
await GrpcProtocolHelpers.ReadCredentialMetadata(configurator, Channel, request, Method, Options.Credentials, _callCts.Token).ConfigureAwait(false);
}
if (Channel.CallCredentials?.Count > 0)
{
foreach (var credentials in Channel.CallCredentials)
{
await GrpcProtocolHelpers.ReadCredentialMetadata(configurator, Channel, request, Method, credentials).ConfigureAwait(false);
await GrpcProtocolHelpers.ReadCredentialMetadata(configurator, Channel, request, Method, credentials, _callCts.Token).ConfigureAwait(false);
}
}
}
Expand Down
63 changes: 45 additions & 18 deletions src/Grpc.Net.Client/Internal/GrpcProtocolHelpers.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -121,13 +121,34 @@ internal static bool ShouldSkipHeader(string name)
/* round an integer up to the next value with three significant figures */
private static long TimeoutRoundUpToThreeSignificantFigures(long x)
{
if (x < 1000) return x;
if (x < 10000) return RoundUp(x, 10);
if (x < 100000) return RoundUp(x, 100);
if (x < 1000000) return RoundUp(x, 1000);
if (x < 10000000) return RoundUp(x, 10000);
if (x < 100000000) return RoundUp(x, 100000);
if (x < 1000000000) return RoundUp(x, 1000000);
if (x < 1000)
{
return x;
}
if (x < 10000)
{
return RoundUp(x, 10);
}
if (x < 100000)
{
return RoundUp(x, 100);
}
if (x < 1000000)
{
return RoundUp(x, 1000);
}
if (x < 10000000)
{
return RoundUp(x, 10000);
}
if (x < 100000000)
{
return RoundUp(x, 100000);
}
if (x < 1000000000)
{
return RoundUp(x, 1000000);
}
return RoundUp(x, 10000000);

static long RoundUp(long x, long divisor)
Expand Down Expand Up @@ -235,7 +256,7 @@ internal static bool CanWriteCompressed(WriteOptions? writeOptions)
return canCompress;
}

internal static AuthInterceptorContext CreateAuthInterceptorContext(Uri baseAddress, IMethod method)
internal static AuthInterceptorContext CreateAuthInterceptorContext(Uri baseAddress, IMethod method, CancellationToken cancellationToken)
{
var authority = baseAddress.Authority;
if (baseAddress.Scheme == Uri.UriSchemeHttps && authority.EndsWith(":443", StringComparison.Ordinal))
Expand All @@ -252,38 +273,44 @@ internal static AuthInterceptorContext CreateAuthInterceptorContext(Uri baseAddr
serviceUrl += "/";
}
serviceUrl += method.ServiceName;
return new AuthInterceptorContext(serviceUrl, method.Name);
return new AuthInterceptorContext(serviceUrl, method.Name, cancellationToken);
}

internal static async Task ReadCredentialMetadata(
DefaultCallCredentialsConfigurator configurator,
GrpcChannel channel,
HttpRequestMessage message,
IMethod method,
CallCredentials credentials)
CallCredentials credentials,
CancellationToken cancellationToken)
{
credentials.InternalPopulateConfiguration(configurator, null);

if (configurator.Interceptor != null)
{
var authInterceptorContext = GrpcProtocolHelpers.CreateAuthInterceptorContext(channel.Address, method);
// Multiple auth interceptors can be called for a gRPC call.
// These all have the same data: address, method and cancellation token.
// Lazily allocate the context if it is needed.
// Stored on the configurator instead of a ref parameter because ref parameters are not supported in async methods.
configurator.CachedContext ??= CreateAuthInterceptorContext(channel.Address, method, cancellationToken);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looked scary at first, but the call sites don't reuse the DefaultCallCredentialsConfigurator for different channel + method + cancellation combos.

Can we make that more obvious somehow so this code doesn't look scary?

e.g.

ReadCredentialMetadata(...)
{
    ReadCredentialMetadataCore(...);
    configurator.CachedContext = null;
}

// Rename current method
ReadCredentialMetadataInner(...)
{
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've updated comments and renamed some fields to make it clearer what is going on.

var metadata = new Metadata();
await configurator.Interceptor(authInterceptorContext, metadata).ConfigureAwait(false);
await configurator.Interceptor(configurator.CachedContext, metadata).ConfigureAwait(false);

foreach (var entry in metadata)
{
AddHeader(message.Headers, entry);
}
}

if (configurator.Credentials != null)
if (configurator.CompositeCredentials != null)
{
// Copy credentials locally. ReadCredentialMetadata will update it.
var callCredentials = configurator.Credentials;
foreach (var c in callCredentials)
var compositeCredentials = configurator.CompositeCredentials;
foreach (var callCredentials in compositeCredentials)
{
configurator.Reset();
await ReadCredentialMetadata(configurator, channel, message, method, c).ConfigureAwait(false);
configurator.ResetPerCallCredentialState();
await ReadCredentialMetadata(configurator, channel, message, method, callCredentials, cancellationToken).ConfigureAwait(false);
}
}
}
Expand Down
61 changes: 57 additions & 4 deletions test/Grpc.Net.Client.Tests/CallCredentialTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand All @@ -18,6 +18,7 @@

using System.Net;
using System.Net.Http.Headers;
using System.Threading;
using Greet;
using Grpc.Core;
using Grpc.Net.Client.Tests.Infrastructure;
Expand Down Expand Up @@ -79,19 +80,71 @@ public async Task CallCredentialsWithHttps_MetadataOnRequest()
var invoker = HttpClientCallInvokerFactory.Create(httpClient);

// Act
var syncPoint = new SyncPoint(runContinuationsAsynchronously: true);
var callCredentials = CallCredentials.FromInterceptor(async (context, metadata) =>
{
// The operation is asynchronous to ensure delegate is awaited
await Task.Delay(50);
// The operation is asynchronous to ensure auth interceptor is awaited.
// Sending the request and returning a response is blocked until the auth interceptor completes.
await syncPoint.WaitToContinue();

// Set header.
metadata.Add("authorization", "SECRET_TOKEN");
});
var call = invoker.AsyncUnaryCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(credentials: callCredentials), new HelloRequest());
await call.ResponseAsync.DefaultTimeout();
var responseTask = call.ResponseAsync;

await syncPoint.WaitForSyncPoint().DefaultTimeout();

// Response task should be blocked waiting for the auth interceptor to complete.
Assert.False(responseTask.IsCompleted);
// Sending the request should be blocked waiting for the auth interceptor to complete.
Assert.Null(authorizationValue);

syncPoint.Continue();
await responseTask.DefaultTimeout();

// Assert
Assert.AreEqual("SECRET_TOKEN", authorizationValue);
}

[Test]
public async Task CallCredentialsWithHttps_CancellationToken()
{
// Arrange
string? authorizationValue = null;
var httpClient = ClientTestHelpers.CreateTestClient(async request =>
{
authorizationValue = request.Headers.GetValues("authorization").Single();

var reply = new HelloReply { Message = "Hello world" };
var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout();
return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent);
});
var invoker = HttpClientCallInvokerFactory.Create(httpClient);

// Act
var unreachableAuthInterceptorSection = false;
var callCredentials = CallCredentials.FromInterceptor(async (context, metadata) =>
{
var tcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
context.CancellationToken.Register(s => ((TaskCompletionSource<object?>)s!).SetCanceled(), tcs);

await tcs.Task;

unreachableAuthInterceptorSection = true;
});
var call = invoker.AsyncUnaryCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(credentials: callCredentials), new HelloRequest());
var responseTask = call.ResponseAsync;

call.Dispose();

var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => responseTask).DefaultTimeout();
Assert.AreEqual(StatusCode.Cancelled, ex.StatusCode);

// Assert
Assert.False(unreachableAuthInterceptorSection);
}

[Test]
public async Task CallCredentialsWithHttp_NoMetadataOnRequest()
{
Expand Down