Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,16 @@
#endregion

using System;
using System.Diagnostics.CodeAnalysis;
using System.Threading;
using Grpc.AspNetCore.Server;
using Grpc.Core;
using Grpc.Core.Interceptors;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

namespace Grpc.AspNetCore.Server.ClientFactory
namespace Grpc.AspNetCore.ClientFactory
{
/// <summary>
/// Interceptor that will set the current request's cancellation token and deadline onto CallOptions.
Expand All @@ -31,11 +35,15 @@ namespace Grpc.AspNetCore.Server.ClientFactory
/// </summary>
internal class ContextPropagationInterceptor : Interceptor
{
private readonly GrpcContextPropagationOptions _options;
private readonly IHttpContextAccessor _httpContextAccessor;
private readonly ILogger _logger;

public ContextPropagationInterceptor(IHttpContextAccessor httpContextAccessor)
public ContextPropagationInterceptor(IOptions<GrpcContextPropagationOptions> options, IHttpContextAccessor httpContextAccessor, ILogger<ContextPropagationInterceptor> logger)
{
_options = options.Value;
_httpContextAccessor = httpContextAccessor;
_logger = logger;
}

public override AsyncClientStreamingCall<TRequest, TResponse> AsyncClientStreamingCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, AsyncClientStreamingCallContinuation<TRequest, TResponse> continuation)
Expand Down Expand Up @@ -126,47 +134,73 @@ private ClientInterceptorContext<TRequest, TResponse> ConfigureContext<TRequest,
linkedCts = null;

var options = context.Options;
var serverCallContext = GetServerCallContext();

// Use propagated deadline if it is smaller than the specified deadline
if (serverCallContext.Deadline < context.Options.Deadline.GetValueOrDefault(DateTime.MaxValue))
if (TryGetServerCallContext(out var serverCallContext, out var errorMessage))
{
options = options.WithDeadline(serverCallContext.Deadline);
}
// Use propagated deadline if it is smaller than the specified deadline
if (serverCallContext.Deadline < context.Options.Deadline.GetValueOrDefault(DateTime.MaxValue))
{
options = options.WithDeadline(serverCallContext.Deadline);
}

if (serverCallContext.CancellationToken.CanBeCanceled)
{
if (options.CancellationToken.CanBeCanceled)
if (serverCallContext.CancellationToken.CanBeCanceled)
{
// If both propagated and options cancellation token can be canceled
// then set a new linked token of both
linkedCts = CancellationTokenSource.CreateLinkedTokenSource(serverCallContext.CancellationToken, options.CancellationToken);
options = options.WithCancellationToken(linkedCts.Token);
if (options.CancellationToken.CanBeCanceled)
{
// If both propagated and options cancellation token can be canceled
// then set a new linked token of both
linkedCts = CancellationTokenSource.CreateLinkedTokenSource(serverCallContext.CancellationToken, options.CancellationToken);
options = options.WithCancellationToken(linkedCts.Token);
}
else
{
options = options.WithCancellationToken(serverCallContext.CancellationToken);
}
}
else
}
else
{
Log.PropagateServerCallContextFailure(_logger, errorMessage);

if (!_options.SuppressContextNotFoundErrors)
{
options = options.WithCancellationToken(serverCallContext.CancellationToken);
throw new InvalidOperationException("Unable to propagate server context values to the call. " + errorMessage);
}
}

return new ClientInterceptorContext<TRequest, TResponse>(context.Method, context.Host, options);
}

private ServerCallContext GetServerCallContext()
private bool TryGetServerCallContext([NotNullWhen(true)]out ServerCallContext? serverCallContext, [NotNullWhen(false)]out string? errorMessage)
{
var httpContext = _httpContextAccessor.HttpContext;
if (httpContext == null)
{
throw new InvalidOperationException("Unable to propagate server context values to the call. Can't find the current HttpContext.");
errorMessage = "Can't find the current HttpContext.";
serverCallContext = null;
return false;
}

var serverCallContext = httpContext.Features.Get<IServerCallContextFeature>()?.ServerCallContext;
serverCallContext = httpContext.Features.Get<IServerCallContextFeature>()?.ServerCallContext;
if (serverCallContext == null)
{
throw new InvalidOperationException("Unable to propagate server context values to the call. Can't find the current gRPC ServerCallContext.");
errorMessage = "Can't find the gRPC ServerCallContext on the current HttpContext.";
serverCallContext = null;
return false;
}

return serverCallContext;
errorMessage = null;
return true;
}

private static class Log
{
private static readonly Action<ILogger, string, Exception?> _propagateServerCallContextFailure =
LoggerMessage.Define<string>(LogLevel.Debug, new EventId(1, "PropagateServerCallContextFailure"), "Unable to propagate server context values to the call. {ErrorMessage}");

public static void PropagateServerCallContextFailure(ILogger logger, string errorMessage)
{
_propagateServerCallContextFailure(logger, errorMessage, null);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#endregion

namespace Grpc.AspNetCore.ClientFactory
{
/// <summary>
/// Options used to configure gRPC call context propagation.
/// </summary>
public class GrpcContextPropagationOptions
{
/// <summary>
/// Gets or sets a value that determines if context not found errors are suppressed.
/// <para>
/// When <see langword="false"/>, the client will thrown an error if it is unable to
/// find a call context when propagating values to a gRPC call.
/// Otherwise, the error is suppressed and the gRPC call will be made without context
/// propagation.
/// </para>
/// </summary>
/// <remarks>
/// <para>
/// Call context propagation will error by default if propagation can't happen because
/// the call context wasn't found. This typically happens when a client is used
/// outside the context of an executing gRPC service.
/// </para>
/// <para>
/// Suppressing context not found errors allows a client with propagation enabled to be
/// used outside the context of an executing gRPC service.
/// </para>
/// </remarks>
/// <value>The default value is <see langword="false"/>.</value>
public bool SuppressContextNotFoundErrors { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#endregion

using System;
using Grpc.AspNetCore.Server.ClientFactory;
using Grpc.AspNetCore.ClientFactory;
using Grpc.Core;
using Grpc.Net.ClientFactory;
using Microsoft.Extensions.DependencyInjection.Extensions;
Expand Down Expand Up @@ -58,6 +58,24 @@ public static IHttpClientBuilder EnableCallContextPropagation(this IHttpClientBu
return builder;
}

/// <summary>
/// Configures the server to propagate values from a call's <see cref="ServerCallContext"/>
/// onto the gRPC client.
/// </summary>
/// <param name="builder">The <see cref="IHttpClientBuilder"/>.</param>
/// <param name="configureOptions">An <see cref="Action{GrpcContextPropagationOptions}"/> to configure the provided <see cref="GrpcContextPropagationOptions"/>.</param>
/// <returns>An <see cref="IHttpClientBuilder"/> that can be used to configure the client.</returns>
public static IHttpClientBuilder EnableCallContextPropagation(this IHttpClientBuilder builder, Action<GrpcContextPropagationOptions> configureOptions)
{
if (builder == null)
{
throw new ArgumentNullException(nameof(builder));
}

builder.Services.Configure(configureOptions);
return builder.EnableCallContextPropagation();
}

private static void ValidateGrpcClient(IHttpClientBuilder builder)
{
// Validate the builder is for a gRPC client
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#endregion

using System;
using System.Linq;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -28,6 +29,8 @@
using Grpc.Tests.Shared;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Testing;
using Microsoft.Extensions.Options;
using NUnit.Framework;

Expand Down Expand Up @@ -125,6 +128,42 @@ public async Task CreateClient_NoHttpContext_ThrowError()
Assert.AreEqual("Unable to propagate server context values to the call. Can't find the current HttpContext.", ex.Message);
}

[Test]
public async Task CreateClient_NoHttpContextIgnoreError_Success()
{
// Arrange
var testSink = new TestSink();
var testProvider = new TestLoggerProvider(testSink);

var baseAddress = new Uri("http://localhost");

var services = new ServiceCollection();
services.AddLogging(o => o.AddProvider(testProvider).SetMinimumLevel(LogLevel.Debug));
services.AddOptions();
services.AddSingleton(CreateHttpContextAccessor(null));
services
.AddGrpcClient<Greeter.GreeterClient>(o =>
{
o.Address = baseAddress;
})
.EnableCallContextPropagation(o => o.SuppressContextNotFoundErrors = true)
.AddHttpMessageHandler(() => ClientTestHelpers.CreateTestMessageHandler(new HelloReply()));

var serviceProvider = services.BuildServiceProvider(validateScopes: true);

var clientFactory = new DefaultGrpcClientFactory(
serviceProvider,
serviceProvider.GetRequiredService<IHttpClientFactory>());
var client = clientFactory.CreateClient<Greeter.GreeterClient>(nameof(Greeter.GreeterClient));

// Act
await client.SayHelloAsync(new HelloRequest(), new CallOptions()).ResponseAsync.DefaultTimeout();

// Assert
var log = testSink.Writes.Single(w => w.EventId.Name == "PropagateServerCallContextFailure");
Assert.AreEqual("Unable to propagate server context values to the call. Can't find the current HttpContext.", log.Message);
}

[Test]
public async Task CreateClient_NoServerCallContextOnHttpContext_ThrowError()
{
Expand Down Expand Up @@ -153,7 +192,43 @@ public async Task CreateClient_NoServerCallContextOnHttpContext_ThrowError()
var ex = await ExceptionAssert.ThrowsAsync<InvalidOperationException>(() => client.SayHelloAsync(new HelloRequest(), new CallOptions()).ResponseAsync).DefaultTimeout();

// Assert
Assert.AreEqual("Unable to propagate server context values to the call. Can't find the current gRPC ServerCallContext.", ex.Message);
Assert.AreEqual("Unable to propagate server context values to the call. Can't find the gRPC ServerCallContext on the current HttpContext.", ex.Message);
}

[Test]
public async Task CreateClient_NoServerCallContextOnHttpContextIgnoreError_Success()
{
// Arrange
var testSink = new TestSink();
var testProvider = new TestLoggerProvider(testSink);

var baseAddress = new Uri("http://localhost");

var services = new ServiceCollection();
services.AddLogging(o => o.AddProvider(testProvider).SetMinimumLevel(LogLevel.Debug));
services.AddOptions();
services.AddSingleton(CreateHttpContextAccessor(new DefaultHttpContext()));
services
.AddGrpcClient<Greeter.GreeterClient>(o =>
{
o.Address = baseAddress;
})
.EnableCallContextPropagation(o => o.SuppressContextNotFoundErrors = true)
.AddHttpMessageHandler(() => ClientTestHelpers.CreateTestMessageHandler(new HelloReply()));

var serviceProvider = services.BuildServiceProvider(validateScopes: true);

var clientFactory = new DefaultGrpcClientFactory(
serviceProvider,
serviceProvider.GetRequiredService<IHttpClientFactory>());
var client = clientFactory.CreateClient<Greeter.GreeterClient>(nameof(Greeter.GreeterClient));

// Act
await client.SayHelloAsync(new HelloRequest(), new CallOptions()).ResponseAsync.DefaultTimeout();

// Assert
var log = testSink.Writes.Single(w => w.EventId.Name == "PropagateServerCallContextFailure");
Assert.AreEqual("Unable to propagate server context values to the call. Can't find the gRPC ServerCallContext on the current HttpContext.", log.Message);
}

private IHttpContextAccessor CreateHttpContextAccessor(HttpContext? httpContext)
Expand Down