diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs index a7c1dd6abfff..5b20d770d9ba 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs @@ -59,6 +59,8 @@ public partial class HubConnection : IAsyncDisposable // Default amount of bytes we'll buffer when using Stateful Reconnect until applying backpressure to sends from the client. internal const long DefaultStatefulReconnectBufferSize = 100_000; + internal const string ActivityName = "Microsoft.AspNetCore.SignalR.Client.InvocationOut"; + // The receive loop has a single reader and single writer at a time so optimize the channel for that private static readonly UnboundedChannelOptions _receiveLoopOptions = new UnboundedChannelOptions { @@ -73,11 +75,13 @@ public partial class HubConnection : IAsyncDisposable private readonly ILoggerFactory _loggerFactory; private readonly ILogger _logger; private readonly ConnectionLogScope _logScope; + private readonly ActivitySource _activitySource; private readonly IHubProtocol _protocol; private readonly IServiceProvider _serviceProvider; private readonly IConnectionFactory _connectionFactory; private readonly IRetryPolicy? _reconnectPolicy; private readonly EndPoint _endPoint; + private readonly string? _serviceName; private readonly ConcurrentDictionary _handlers = new ConcurrentDictionary(StringComparer.Ordinal); // Holds all mutable state other than user-defined handlers and settable properties. @@ -235,6 +239,10 @@ public HubConnection(IConnectionFactory connectionFactory, _logScope = new ConnectionLogScope(); + // ActivitySource can be resolved from the service provider when unit testing. + _activitySource = (serviceProvider.GetService() ?? SignalRClientActivitySource.Instance).ActivitySource; + _serviceName = (_endPoint is UriEndPoint e) ? e.Uri.AbsolutePath.Trim('/') : null; + var options = serviceProvider.GetService>(); ServerTimeout = options?.Value.ServerTimeout ?? DefaultServerTimeout; @@ -720,7 +728,8 @@ async Task OnStreamCanceled(InvocationRequest irq) var readers = default(Dictionary); CheckDisposed(); - var connectionState = await _state.WaitForActiveConnectionAsync(nameof(StreamAsChannelCoreAsync), token: cancellationToken).ConfigureAwait(false); + + var (connectionState, activity) = await WaitForActiveConnectionWithActivityAsync(nameof(StreamAsChannelCoreAsync), methodName, token: cancellationToken).ConfigureAwait(false); ChannelReader channel; try @@ -731,7 +740,7 @@ async Task OnStreamCanceled(InvocationRequest irq) readers = PackageStreamingParams(connectionState, ref args, out var streamIds); // I just want an excuse to use 'irq' as a variable name... - var irq = InvocationRequest.Stream(cancellationToken, returnType, connectionState.GetNextId(), _loggerFactory, this, out channel); + var irq = InvocationRequest.Stream(cancellationToken, returnType, connectionState.GetNextId(), _loggerFactory, this, activity, out channel); await InvokeStreamCore(connectionState, methodName, irq, args, streamIds?.ToArray(), cancellationToken).ConfigureAwait(false); if (cancellationToken.CanBeCanceled) @@ -1003,12 +1012,71 @@ private async Task CommonStreaming(ConnectionState connectionState, string strea } } + private async Task<(ConnectionState, Activity?)> WaitForActiveConnectionWithActivityAsync(string sendingMethodName, string invokedMethodName, CancellationToken token) + { + // Start the activity before waiting on the connection. + // Starting the activity here means time to connect or reconnect is included in the invoke. + var activity = CreateActivity(invokedMethodName); + + try + { + ConnectionState connectionState; + var connectionStateTask = _state.WaitForActiveConnectionAsync(sendingMethodName, token); + if (connectionStateTask.Status == TaskStatus.RanToCompletion) + { + // Attempt to get already connected connection and set server tags using it. + connectionState = connectionStateTask.Result; + SetServerTags(activity, connectionState.ConnectionUrl); + activity?.Start(); + } + else + { + // Fallback to using configured endpoint. + var initialUri = (_endPoint as UriEndPoint)?.Uri; + SetServerTags(activity, initialUri); + activity?.Start(); + + connectionState = await connectionStateTask.ConfigureAwait(false); + + // After connection is returned, check if URL is different. If so, update activity server tags. + if (connectionState.ConnectionUrl != null && connectionState.ConnectionUrl != initialUri) + { + SetServerTags(activity, connectionState.ConnectionUrl); + } + } + + return (connectionState, activity); + } + catch (Exception ex) + { + // If there is an error getting an active connection then the invocation has failed. + if (activity is not null) + { + activity.SetStatus(ActivityStatusCode.Error); + activity.SetTag("error.type", ex.GetType().FullName); + activity.Stop(); + } + + throw; + } + + static void SetServerTags(Activity? activity, Uri? uri) + { + if (activity != null && uri != null) + { + activity.SetTag("server.address", uri.Host); + activity.SetTag("server.port", uri.Port); + } + } + } + private async Task InvokeCoreAsyncCore(string methodName, Type returnType, object?[] args, CancellationToken cancellationToken) { var readers = default(Dictionary); CheckDisposed(); - var connectionState = await _state.WaitForActiveConnectionAsync(nameof(InvokeCoreAsync), token: cancellationToken).ConfigureAwait(false); + + var (connectionState, activity) = await WaitForActiveConnectionWithActivityAsync(nameof(InvokeCoreAsync), methodName, token: cancellationToken).ConfigureAwait(false); Task invocationTask; try @@ -1017,7 +1085,7 @@ private async Task CommonStreaming(ConnectionState connectionState, string strea readers = PackageStreamingParams(connectionState, ref args, out var streamIds); - var irq = InvocationRequest.Invoke(cancellationToken, returnType, connectionState.GetNextId(), _loggerFactory, this, out invocationTask); + var irq = InvocationRequest.Invoke(cancellationToken, returnType, connectionState.GetNextId(), _loggerFactory, this, activity, out invocationTask); await InvokeCore(connectionState, methodName, irq, args, streamIds?.ToArray(), cancellationToken).ConfigureAwait(false); LaunchStreams(connectionState, readers, cancellationToken); @@ -1031,13 +1099,43 @@ private async Task CommonStreaming(ConnectionState connectionState, string strea return await invocationTask.ConfigureAwait(false); } + private Activity? CreateActivity(string methodName) + { + var activity = _activitySource.CreateActivity(ActivityName, ActivityKind.Client); + if (activity is null && Activity.Current is not null && _logger.IsEnabled(LogLevel.Critical)) + { + activity = new Activity(ActivityName); + } + + if (activity is not null) + { + if (!string.IsNullOrEmpty(_serviceName)) + { + activity.DisplayName = $"{_serviceName}/{methodName}"; + activity.SetTag("rpc.service", _serviceName); + } + else + { + activity.DisplayName = methodName; + } + + activity.SetTag("rpc.system", "signalr"); + activity.SetTag("rpc.method", methodName); + } + + return activity; + } + private async Task InvokeCore(ConnectionState connectionState, string methodName, InvocationRequest irq, object?[] args, string[]? streams, CancellationToken cancellationToken) { Log.PreparingBlockingInvocation(_logger, irq.InvocationId, methodName, irq.ResultType.FullName!, args.Length); // Client invocations are always blocking var invocationMessage = new InvocationMessage(irq.InvocationId, methodName, args, streams); - InjectHeaders(invocationMessage); + if (irq.Activity is not null) + { + InjectHeaders(irq.Activity, invocationMessage); + } Log.RegisteringInvocation(_logger, irq.InvocationId); connectionState.AddInvocation(irq); @@ -1064,7 +1162,10 @@ private async Task InvokeStreamCore(ConnectionState connectionState, string meth Log.PreparingStreamingInvocation(_logger, irq.InvocationId, methodName, irq.ResultType.FullName!, args.Length); var invocationMessage = new StreamInvocationMessage(irq.InvocationId, methodName, args, streams); - InjectHeaders(invocationMessage); + if (irq.Activity is not null) + { + InjectHeaders(irq.Activity, invocationMessage); + } Log.RegisteringInvocation(_logger, irq.InvocationId); @@ -1085,23 +1186,16 @@ private async Task InvokeStreamCore(ConnectionState connectionState, string meth } } - private static void InjectHeaders(HubInvocationMessage invocationMessage) + private static void InjectHeaders(Activity currentActivity, HubInvocationMessage invocationMessage) { - // TODO: Change when SignalR client has an activity. - // This sends info about the current activity, regardless of the activity source, to the SignalR server. - // When SignalR client supports client activities this logic should be updated to only send headers - // if the SignalR client activity is created. The goal is to match the behavior of distributed tracing in HttpClient. - if (Activity.Current is { } currentActivity) + DistributedContextPropagator.Current.Inject(currentActivity, invocationMessage, static (carrier, key, value) => { - DistributedContextPropagator.Current.Inject(currentActivity, invocationMessage, static (carrier, key, value) => + if (carrier is HubInvocationMessage invocationMessage) { - if (carrier is HubInvocationMessage invocationMessage) - { - invocationMessage.Headers ??= new Dictionary(); - invocationMessage.Headers[key] = value; - } - }); - } + invocationMessage.Headers ??= new Dictionary(); + invocationMessage.Headers[key] = value; + } + }); } private async Task SendHubMessage(ConnectionState connectionState, HubMessage hubMessage, CancellationToken cancellationToken = default) @@ -1131,7 +1225,8 @@ private async Task SendCoreAsyncCore(string methodName, object?[] args, Cancella var readers = default(Dictionary); CheckDisposed(); - var connectionState = await _state.WaitForActiveConnectionAsync(nameof(SendCoreAsync), token: cancellationToken).ConfigureAwait(false); + + var (connectionState, activity) = await WaitForActiveConnectionWithActivityAsync(nameof(SendCoreAsync), methodName, token: cancellationToken).ConfigureAwait(false); try { CheckDisposed(); @@ -1140,12 +1235,27 @@ private async Task SendCoreAsyncCore(string methodName, object?[] args, Cancella Log.PreparingNonBlockingInvocation(_logger, methodName, args.Length); var invocationMessage = new InvocationMessage(null, methodName, args, streamIds?.ToArray()); + if (activity is not null) + { + InjectHeaders(activity, invocationMessage); + } await SendHubMessage(connectionState, invocationMessage, cancellationToken).ConfigureAwait(false); LaunchStreams(connectionState, readers, cancellationToken); } + catch (Exception ex) + { + if (activity is not null) + { + activity.SetStatus(ActivityStatusCode.Error); + activity.SetTag("error.type", ex.GetType().FullName); + activity.Stop(); + } + throw; + } finally { + activity?.Stop(); _state.ReleaseConnectionLock(); } } @@ -2018,6 +2128,7 @@ private sealed class ConnectionState : IInvocationBinder private long _nextActivationSendPing; public ConnectionContext Connection { get; } + public Uri? ConnectionUrl { get; } public Task? ReceiveTask { get; set; } public Exception? CloseException { get; set; } public CancellationToken UploadStreamToken { get; set; } @@ -2036,6 +2147,7 @@ public bool Stopping public ConnectionState(ConnectionContext connection, HubConnection hubConnection) { Connection = connection; + ConnectionUrl = (connection.RemoteEndPoint is UriEndPoint ep) ? ep.Uri : null; _hubConnection = hubConnection; _hubConnection._logScope.ConnectionId = connection.ConnectionId; diff --git a/src/SignalR/clients/csharp/Client.Core/src/Internal/InvocationRequest.cs b/src/SignalR/clients/csharp/Client.Core/src/Internal/InvocationRequest.cs index a188e635d0b1..e0bbc9172773 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/Internal/InvocationRequest.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/Internal/InvocationRequest.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; @@ -13,15 +15,17 @@ namespace Microsoft.AspNetCore.SignalR.Client.Internal; internal abstract partial class InvocationRequest : IDisposable { private readonly CancellationTokenRegistration _cancellationTokenRegistration; + private int _isActivityStopping; protected ILogger Logger { get; } public Type ResultType { get; } public CancellationToken CancellationToken { get; } public string InvocationId { get; } - public HubConnection HubConnection { get; private set; } + public HubConnection HubConnection { get; } + public Activity? Activity { get; } - protected InvocationRequest(CancellationToken cancellationToken, Type resultType, string invocationId, ILogger logger, HubConnection hubConnection) + protected InvocationRequest(CancellationToken cancellationToken, Type resultType, string invocationId, ILogger logger, HubConnection hubConnection, Activity? activity) { _cancellationTokenRegistration = cancellationToken.Register(self => ((InvocationRequest)self!).Cancel(), this); @@ -30,21 +34,28 @@ protected InvocationRequest(CancellationToken cancellationToken, Type resultType ResultType = resultType; Logger = logger; HubConnection = hubConnection; + Activity = activity; Log.InvocationCreated(Logger, InvocationId); } - public static InvocationRequest Invoke(CancellationToken cancellationToken, Type resultType, string invocationId, ILoggerFactory loggerFactory, HubConnection hubConnection, out Task result) + [MemberNotNullWhen(true, nameof(Activity))] + protected bool TryBeginStopActivity() { - var req = new NonStreaming(cancellationToken, resultType, invocationId, loggerFactory, hubConnection); + return Activity != null && Interlocked.Exchange(ref _isActivityStopping, 1) == 0; + } + + public static InvocationRequest Invoke(CancellationToken cancellationToken, Type resultType, string invocationId, ILoggerFactory loggerFactory, HubConnection hubConnection, Activity? activity, out Task result) + { + var req = new NonStreaming(cancellationToken, resultType, invocationId, loggerFactory, hubConnection, activity); result = req.Result; return req; } public static InvocationRequest Stream(CancellationToken cancellationToken, Type resultType, string invocationId, - ILoggerFactory loggerFactory, HubConnection hubConnection, out ChannelReader result) + ILoggerFactory loggerFactory, HubConnection hubConnection, Activity? activity, out ChannelReader result) { - var req = new Streaming(cancellationToken, resultType, invocationId, loggerFactory, hubConnection); + var req = new Streaming(cancellationToken, resultType, invocationId, loggerFactory, hubConnection, activity); result = req.Result; return req; } @@ -69,8 +80,8 @@ private sealed class Streaming : InvocationRequest { private readonly Channel _channel = Channel.CreateUnbounded(); - public Streaming(CancellationToken cancellationToken, Type resultType, string invocationId, ILoggerFactory loggerFactory, HubConnection hubConnection) - : base(cancellationToken, resultType, invocationId, loggerFactory.CreateLogger(typeof(Streaming)), hubConnection) + public Streaming(CancellationToken cancellationToken, Type resultType, string invocationId, ILoggerFactory loggerFactory, HubConnection hubConnection, Activity? activity) + : base(cancellationToken, resultType, invocationId, loggerFactory.CreateLogger(typeof(Streaming)), hubConnection, activity) { } @@ -82,7 +93,15 @@ public override void Complete(CompletionMessage completionMessage) if (completionMessage.Result != null) { Log.ReceivedUnexpectedComplete(Logger, InvocationId); + + if (TryBeginStopActivity()) + { + Activity.SetStatus(ActivityStatusCode.Error); + Activity.Stop(); + } + _channel.Writer.TryComplete(new InvalidOperationException("Server provided a result in a completion response to a streamed invocation.")); + return; } if (!string.IsNullOrEmpty(completionMessage.Error)) @@ -91,12 +110,25 @@ public override void Complete(CompletionMessage completionMessage) return; } + if (TryBeginStopActivity()) + { + Activity.Stop(); + } + _channel.Writer.TryComplete(); } public override void Fail(Exception exception) { Log.InvocationFailed(Logger, InvocationId); + + if (TryBeginStopActivity()) + { + Activity.SetStatus(ActivityStatusCode.Error); + Activity.SetTag("error.type", exception.GetType().FullName); + Activity.Stop(); + } + _channel.Writer.TryComplete(exception); } @@ -121,6 +153,13 @@ public override async ValueTask StreamItem(object? item) protected override void Cancel() { + if (TryBeginStopActivity()) + { + Activity.SetStatus(ActivityStatusCode.Error); + Activity.SetTag("error.type", typeof(OperationCanceledException).FullName); + Activity.Stop(); + } + _channel.Writer.TryComplete(new OperationCanceledException()); } } @@ -129,8 +168,8 @@ private sealed class NonStreaming : InvocationRequest { private readonly TaskCompletionSource _completionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - public NonStreaming(CancellationToken cancellationToken, Type resultType, string invocationId, ILoggerFactory loggerFactory, HubConnection hubConnection) - : base(cancellationToken, resultType, invocationId, loggerFactory.CreateLogger(typeof(NonStreaming)), hubConnection) + public NonStreaming(CancellationToken cancellationToken, Type resultType, string invocationId, ILoggerFactory loggerFactory, HubConnection hubConnection, Activity? activity) + : base(cancellationToken, resultType, invocationId, loggerFactory.CreateLogger(typeof(NonStreaming)), hubConnection, activity) { } @@ -145,12 +184,26 @@ public override void Complete(CompletionMessage completionMessage) } Log.InvocationCompleted(Logger, InvocationId); + + if (TryBeginStopActivity()) + { + Activity.Stop(); + } + _completionSource.TrySetResult(completionMessage.Result); } public override void Fail(Exception exception) { Log.InvocationFailed(Logger, InvocationId); + + if (TryBeginStopActivity()) + { + Activity.SetStatus(ActivityStatusCode.Error); + Activity.SetTag("error.type", exception.GetType().FullName); + Activity.Stop(); + } + _completionSource.TrySetException(exception); } @@ -165,6 +218,13 @@ public override ValueTask StreamItem(object? item) protected override void Cancel() { + if (TryBeginStopActivity()) + { + Activity.SetStatus(ActivityStatusCode.Error); + Activity.SetTag("error.type", typeof(OperationCanceledException).FullName); + Activity.Stop(); + } + _completionSource.TrySetCanceled(); } } diff --git a/src/SignalR/clients/csharp/Client.Core/src/Internal/SignalRClientActivitySource.cs b/src/SignalR/clients/csharp/Client.Core/src/Internal/SignalRClientActivitySource.cs new file mode 100644 index 000000000000..cc6efbf4de24 --- /dev/null +++ b/src/SignalR/clients/csharp/Client.Core/src/Internal/SignalRClientActivitySource.cs @@ -0,0 +1,16 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; + +namespace Microsoft.AspNetCore.SignalR.Client.Internal; + +// Internal for now so we don't need API review. +// Just a wrapper for the ActivitySource. Don't want to put ActivitySource directly in DI as +// it is a public type and could conflict with activity source from another library. +internal sealed class SignalRClientActivitySource +{ + public static readonly SignalRClientActivitySource Instance = new(); + + public ActivitySource ActivitySource { get; } = new ActivitySource("Microsoft.AspNetCore.SignalR.Client"); +} diff --git a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.Tracing.cs b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.Tracing.cs new file mode 100644 index 000000000000..8cd134bcceba --- /dev/null +++ b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.Tracing.cs @@ -0,0 +1,747 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Net; +using System.Net.Http; +using System.Net.WebSockets; +using System.Text.Json; +using System.Threading.Channels; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http.Connections; +using Microsoft.AspNetCore.Http.Connections.Client; +using Microsoft.AspNetCore.InternalTesting; +using Microsoft.AspNetCore.SignalR.Client.Internal; +using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.AspNetCore.SignalR.Test.Internal; +using Microsoft.AspNetCore.SignalR.Tests; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Testing; + +namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests; + +partial class HubConnectionTests : FunctionalTestBase +{ + [Theory] + [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] + public async Task InvokeAsync_SendTraceHeader(string protocolName, HttpTransportType transportType, string path) + { + var protocol = HubProtocols[protocolName]; + await using (var server = await StartServer()) + { + var serverChannel = Channel.CreateUnbounded(); + var clientChannel = Channel.CreateUnbounded(); + var serverSource = server.Services.GetRequiredService().ActivitySource; + var clientSourceContainer = new SignalRClientActivitySource(); + + using var listener = new ActivityListener + { + ShouldListenTo = activitySource => ReferenceEquals(activitySource, serverSource) || ReferenceEquals(activitySource, clientSourceContainer.ActivitySource), + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllData, + ActivityStarted = activity => + { + if (activity.Source == clientSourceContainer.ActivitySource) + { + clientChannel.Writer.TryWrite(activity); + } + else + { + serverChannel.Writer.TryWrite(activity); + } + } + }; + ActivitySource.AddActivityListener(listener); + + var connectionBuilder = new HubConnectionBuilder() + .WithLoggerFactory(LoggerFactory) + .WithUrl(server.Url + path, transportType); + connectionBuilder.Services.AddSingleton(protocol); + connectionBuilder.Services.AddSingleton(clientSourceContainer); + + var connection = connectionBuilder.Build(); + + Activity clientParentActivity1 = null; + Activity clientActivity1 = null; + Activity clientParentActivity2 = null; + Activity clientActivity2 = null; + try + { + await connection.StartAsync().DefaultTimeout(); + + // Invocation 1 + try + { + clientParentActivity1 = new Activity("ClientActivity1"); + clientParentActivity1.AddBaggage("baggage-1", "value-1"); + clientParentActivity1.Start(); + + var resultTask = connection.InvokeAsync(nameof(TestHub.HelloWorld)).DefaultTimeout(); + + clientActivity1 = await clientChannel.Reader.ReadAsync().DefaultTimeout(); + + // The SignalR client activity shouldn't escape into user code. + Assert.Equal(clientParentActivity1, Activity.Current); + + var result = await resultTask.ConfigureAwait(false); + Assert.Equal("Hello World!", result); + } + finally + { + clientParentActivity1?.Stop(); + } + + // Invocation 2 + try + { + clientParentActivity2 = new Activity("ClientActivity2"); + clientParentActivity2.AddBaggage("baggage-2", "value-2"); + clientParentActivity2.Start(); + + var resultTask = connection.InvokeAsync(nameof(TestHub.HelloWorld)); + + clientActivity2 = await clientChannel.Reader.ReadAsync().DefaultTimeout(); + + // The SignalR client activity shouldn't escape into user code. + Assert.Equal(clientParentActivity2, Activity.Current); + + var result = await resultTask.DefaultTimeout(); + Assert.Equal("Hello World!", result); + } + finally + { + clientParentActivity2?.Stop(); + } + } + catch (Exception ex) + { + LoggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + await connection.DisposeAsync().DefaultTimeout(); + } + + var port = new Uri(server.Url).Port; + var serverHubName = path switch + { + "/default" => typeof(TestHub).FullName, + "/hubT" => typeof(TestHubT).FullName, + "/dynamic" => typeof(DynamicTestHub).FullName, + _ => throw new InvalidOperationException("Unexpected path: " + path) + }; + var clientHubName = path.TrimStart('/'); + + var serverActivities = await serverChannel.Reader.ReadAtLeastAsync(minimumCount: 4).DefaultTimeout(); + + Assert.Collection(serverActivities, + a => + { + Assert.Equal(SignalRServerActivitySource.OnConnected, a.OperationName); + Assert.Equal($"{serverHubName}/OnConnectedAsync", a.DisplayName); + Assert.Equal("Microsoft.AspNetCore.Hosting.HttpRequestIn", a.Parent.OperationName); + Assert.Equal(ActivityKind.Internal, a.Kind); + Assert.False(a.HasRemoteParent); + Assert.Equal(ActivityStatusCode.Unset, a.Status); + Assert.Empty(a.Baggage); + }, + a => + { + Assert.Equal(SignalRServerActivitySource.InvocationIn, a.OperationName); + Assert.Equal($"{serverHubName}/HelloWorld", a.DisplayName); + Assert.Equal(clientActivity1.Id, a.ParentId); + Assert.Equal(ActivityKind.Server, a.Kind); + Assert.True(a.HasRemoteParent); + Assert.Equal(ActivityStatusCode.Unset, a.Status); + + var baggage = a.Baggage.ToDictionary(); + Assert.Equal("value-1", baggage["baggage-1"]); + }, + a => + { + Assert.Equal(SignalRServerActivitySource.InvocationIn, a.OperationName); + Assert.Equal($"{serverHubName}/HelloWorld", a.DisplayName); + Assert.Equal(clientActivity2.Id, a.ParentId); + Assert.Equal(ActivityKind.Server, a.Kind); + Assert.True(a.HasRemoteParent); + Assert.Equal(ActivityStatusCode.Unset, a.Status); + + var baggage = a.Baggage.ToDictionary(); + Assert.Equal("value-2", baggage["baggage-2"]); + }, + a => + { + Assert.Equal(SignalRServerActivitySource.OnDisconnected, a.OperationName); + Assert.Equal($"{serverHubName}/OnDisconnectedAsync", a.DisplayName); + Assert.Equal("Microsoft.AspNetCore.Hosting.HttpRequestIn", a.Parent.OperationName); + Assert.Equal(ActivityKind.Internal, a.Kind); + Assert.False(a.HasRemoteParent); + Assert.Empty(a.Baggage); + Assert.Equal(ActivityStatusCode.Unset, a.Status); + }); + + // Client activity 1 + Assert.Equal(HubConnection.ActivityName, clientActivity1.OperationName); + Assert.Equal($"{clientHubName}/HelloWorld", clientActivity1.DisplayName); + Assert.Equal(ActivityKind.Client, clientActivity1.Kind); + Assert.Equal(clientParentActivity1, clientActivity1.Parent); + Assert.Equal(ActivityStatusCode.Unset, clientActivity1.Status); + + var baggage = clientActivity1.Baggage.ToDictionary(); + Assert.Equal("value-1", baggage["baggage-1"]); + + var tags = clientActivity1.TagObjects.ToDictionary(); + Assert.Equal("signalr", tags["rpc.system"]); + Assert.Equal("HelloWorld", tags["rpc.method"]); + Assert.Equal(clientHubName, tags["rpc.service"]); + Assert.Equal("127.0.0.1", tags["server.address"]); + Assert.Equal(port, (int)tags["server.port"]); + + // Client activity 2 + Assert.Equal(HubConnection.ActivityName, clientActivity2.OperationName); + Assert.Equal($"{clientHubName}/HelloWorld", clientActivity2.DisplayName); + Assert.Equal(ActivityKind.Client, clientActivity2.Kind); + Assert.Equal(clientParentActivity2, clientActivity2.Parent); + Assert.Equal(ActivityStatusCode.Unset, clientActivity2.Status); + + baggage = clientActivity2.Baggage.ToDictionary(); + Assert.Equal("value-2", baggage["baggage-2"]); + + tags = clientActivity2.TagObjects.ToDictionary(); + Assert.Equal("signalr", tags["rpc.system"]); + Assert.Equal("HelloWorld", tags["rpc.method"]); + Assert.Equal(clientHubName, tags["rpc.service"]); + Assert.Equal("127.0.0.1", tags["server.address"]); + Assert.Equal(port, (int)tags["server.port"]); + } + } + + [Theory] + [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] + public async Task StreamAsyncCore_SendTraceHeader(string protocolName, HttpTransportType transportType, string path) + { + var protocol = HubProtocols[protocolName]; + await using (var server = await StartServer()) + { + var serverChannel = Channel.CreateUnbounded(); + var clientActivityTcs = new TaskCompletionSource(); + Activity clientActivity = null; + var serverSource = server.Services.GetRequiredService().ActivitySource; + var clientSourceContainer = new SignalRClientActivitySource(); + + using var listener = new ActivityListener + { + ShouldListenTo = activitySource => ReferenceEquals(activitySource, serverSource) || ReferenceEquals(activitySource, clientSourceContainer.ActivitySource), + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllData, + ActivityStarted = activity => + { + if (activity.Source == clientSourceContainer.ActivitySource) + { + clientActivityTcs.SetResult(activity); + } + else + { + serverChannel.Writer.TryWrite(activity); + } + } + }; + ActivitySource.AddActivityListener(listener); + + var connection = CreateHubConnection(server.Url, path, transportType, protocol, LoggerFactory, activitySourceContainer: clientSourceContainer); + + Activity clientParentActivity = null; + try + { + await connection.StartAsync().DefaultTimeout(); + + clientParentActivity = new Activity("ClientActivity"); + clientParentActivity.AddBaggage("baggage-1", "value-1"); + clientParentActivity.Start(); + + var expectedValue = 0; + var streamTo = 5; + var asyncEnumerable = connection.StreamAsyncCore("Stream", new object[] { streamTo }); + + await foreach (var streamValue in asyncEnumerable) + { + // Call starts after user reads from the enumerable. + if (streamValue == 0) + { + // The SignalR client activity should be: + // - Started + // - Still running + // - Not escaped into user code + clientActivity = await clientActivityTcs.Task.DefaultTimeout(); + Assert.NotNull(clientActivity); + Assert.False(clientActivity.IsStopped); + Assert.Equal(clientParentActivity, Activity.Current); + } + + Assert.Equal(expectedValue, streamValue); + expectedValue++; + } + + Assert.Equal(streamTo, expectedValue); + } + catch (Exception ex) + { + LoggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + clientParentActivity?.Stop(); + await connection.DisposeAsync().DefaultTimeout(); + } + + var port = new Uri(server.Url).Port; + var hubName = path switch + { + "/default" => typeof(TestHub).FullName, + "/hubT" => typeof(TestHubT).FullName, + "/dynamic" => typeof(DynamicTestHub).FullName, + _ => throw new InvalidOperationException("Unexpected path: " + path) + }; + var clientHubName = path.TrimStart('/'); + + var serverActivities = await serverChannel.Reader.ReadAtLeastAsync(minimumCount: 3).DefaultTimeout(); + + Assert.Collection(serverActivities, + a => + { + Assert.Equal(SignalRServerActivitySource.OnConnected, a.OperationName); + Assert.Equal($"{hubName}/OnConnectedAsync", a.DisplayName); + Assert.Equal(ActivityKind.Internal, a.Kind); + Assert.Equal("Microsoft.AspNetCore.Hosting.HttpRequestIn", a.Parent.OperationName); + Assert.Equal(ActivityStatusCode.Unset, a.Status); + Assert.False(a.HasRemoteParent); + Assert.Empty(a.Baggage); + }, + a => + { + Assert.Equal(SignalRServerActivitySource.InvocationIn, a.OperationName); + Assert.Equal($"{hubName}/Stream", a.DisplayName); + Assert.Equal(ActivityKind.Server, a.Kind); + Assert.Equal(clientActivity.Id, a.ParentId); + Assert.Equal(ActivityStatusCode.Unset, a.Status); + Assert.True(a.HasRemoteParent); + Assert.True(a.IsStopped); + + var baggage = a.Baggage.ToDictionary(); + Assert.Equal("value-1", baggage["baggage-1"]); + }, + a => + { + Assert.Equal(SignalRServerActivitySource.OnDisconnected, a.OperationName); + Assert.Equal($"{hubName}/OnDisconnectedAsync", a.DisplayName); + Assert.Equal(ActivityKind.Internal, a.Kind); + Assert.Equal("Microsoft.AspNetCore.Hosting.HttpRequestIn", a.Parent.OperationName); + Assert.Equal(ActivityStatusCode.Unset, a.Status); + Assert.False(a.HasRemoteParent); + Assert.Empty(a.Baggage); + }); + + Assert.Equal(HubConnection.ActivityName, clientActivity.OperationName); + Assert.Equal($"{clientHubName}/Stream", clientActivity.DisplayName); + Assert.Equal(ActivityKind.Client, clientActivity.Kind); + Assert.Equal(clientParentActivity, clientActivity.Parent); + Assert.Equal(ActivityStatusCode.Unset, clientActivity.Status); + Assert.True(clientActivity.IsStopped); + + var baggage = clientActivity.Baggage.ToDictionary(); + Assert.Equal("value-1", baggage["baggage-1"]); + + var tags = clientActivity.TagObjects.ToDictionary(); + Assert.Equal("signalr", tags["rpc.system"]); + Assert.Equal("Stream", tags["rpc.method"]); + Assert.Equal(clientHubName, tags["rpc.service"]); + Assert.Equal("127.0.0.1", tags["server.address"]); + Assert.Equal(port, (int)tags["server.port"]); + } + } + + [Theory] + [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] + [LogLevel(LogLevel.Trace)] + public async Task StreamAsyncCanBeCanceled_Tracing(string protocolName, HttpTransportType transportType, string path) + { + var protocol = HubProtocols[protocolName]; + await using (var server = await StartServer()) + { + var serverChannel = Channel.CreateUnbounded(); + var clientActivityTcs = new TaskCompletionSource(); + var serverSource = server.Services.GetRequiredService().ActivitySource; + var clientSourceContainer = new SignalRClientActivitySource(); + + using var listener = new ActivityListener + { + ShouldListenTo = activitySource => ReferenceEquals(activitySource, serverSource) || ReferenceEquals(activitySource, clientSourceContainer.ActivitySource), + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllData, + ActivityStarted = activity => + { + if (activity.Source == clientSourceContainer.ActivitySource) + { + clientActivityTcs.SetResult(activity); + } + else + { + serverChannel.Writer.TryWrite(activity); + } + } + }; + ActivitySource.AddActivityListener(listener); + + var connection = CreateHubConnection(server.Url, path, transportType, protocol, LoggerFactory, activitySourceContainer: clientSourceContainer); + try + { + await connection.StartAsync().DefaultTimeout(); + + var cts = new CancellationTokenSource(); + + var stream = connection.StreamAsync("Stream", 1000, cts.Token); + var results = new List(); + + var enumerator = stream.GetAsyncEnumerator(); + await Assert.ThrowsAsync(async () => + { + while (await enumerator.MoveNextAsync()) + { + results.Add(enumerator.Current); + cts.Cancel(); + } + }); + + Assert.True(results.Count > 0 && results.Count < 1000); + } + catch (Exception ex) + { + LoggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + await connection.DisposeAsync().DefaultTimeout(); + } + + var port = new Uri(server.Url).Port; + var hubName = path switch + { + "/default" => typeof(TestHub).FullName, + "/hubT" => typeof(TestHubT).FullName, + "/dynamic" => typeof(DynamicTestHub).FullName, + _ => throw new InvalidOperationException("Unexpected path: " + path) + }; + var clientHubName = path.TrimStart('/'); + + var serverActivities = await serverChannel.Reader.ReadAtLeastAsync(minimumCount: 3).DefaultTimeout(); + var clientActivity = await clientActivityTcs.Task.DefaultTimeout(); + + Assert.Collection(serverActivities, + a => Assert.Equal($"{hubName}/OnConnectedAsync", a.DisplayName), + a => + { + Assert.Equal($"{hubName}/Stream", a.DisplayName); + Assert.Equal(ActivityStatusCode.Error, clientActivity.Status); + + var tags = clientActivity.TagObjects.ToDictionary(); + Assert.Equal(typeof(OperationCanceledException).FullName, tags["error.type"]); + }, + a => Assert.Equal($"{hubName}/OnDisconnectedAsync", a.DisplayName)); + + Assert.Equal($"{clientHubName}/Stream", clientActivity.DisplayName); + Assert.Equal(ActivityKind.Client, clientActivity.Kind); + Assert.Equal(ActivityStatusCode.Error, clientActivity.Status); + + var tags = clientActivity.TagObjects.ToDictionary(); + Assert.Equal(typeof(OperationCanceledException).FullName, tags["error.type"]); + } + } + + [Theory] + [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] + [LogLevel(LogLevel.Trace)] + public async Task StreamAsyncWithException_Tracing(string protocolName, HttpTransportType transportType, string path) + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == DefaultHubDispatcherLoggerName && + writeContext.EventId.Name == "FailedInvokingHubMethod"; + } + + var protocol = HubProtocols[protocolName]; + await using (var server = await StartServer(ExpectedErrors)) + { + var serverChannel = Channel.CreateUnbounded(); + var clientActivityTcs = new TaskCompletionSource(); + var serverSource = server.Services.GetRequiredService().ActivitySource; + var clientSourceContainer = new SignalRClientActivitySource(); + + using var listener = new ActivityListener + { + ShouldListenTo = activitySource => ReferenceEquals(activitySource, serverSource) || ReferenceEquals(activitySource, clientSourceContainer.ActivitySource), + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllData, + ActivityStarted = activity => + { + if (activity.Source == clientSourceContainer.ActivitySource) + { + clientActivityTcs.SetResult(activity); + } + else + { + serverChannel.Writer.TryWrite(activity); + } + } + }; + ActivitySource.AddActivityListener(listener); + + var connection = CreateHubConnection(server.Url, path, transportType, protocol, LoggerFactory, activitySourceContainer: clientSourceContainer); + try + { + await connection.StartAsync().DefaultTimeout(); + var asyncEnumerable = connection.StreamAsync("StreamException"); + var ex = await Assert.ThrowsAsync(async () => + { + await foreach (var streamValue in asyncEnumerable) + { + Assert.True(false, "Expected an exception from the streaming invocation."); + } + }); + + Assert.Equal("An unexpected error occurred invoking 'StreamException' on the server. InvalidOperationException: Error occurred while streaming.", ex.Message); + + } + catch (Exception ex) + { + LoggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + await connection.DisposeAsync().DefaultTimeout(); + } + + var port = new Uri(server.Url).Port; + var serverHubName = path switch + { + "/default" => typeof(TestHub).FullName, + "/hubT" => typeof(TestHubT).FullName, + "/dynamic" => typeof(DynamicTestHub).FullName, + _ => throw new InvalidOperationException("Unexpected path: " + path) + }; + var clientHubName = path.TrimStart('/'); + + var serverActivities = await serverChannel.Reader.ReadAtLeastAsync(minimumCount: 3).DefaultTimeout(); + var clientActivity = await clientActivityTcs.Task.DefaultTimeout(); + + Assert.Collection(serverActivities, + a => Assert.Equal($"{serverHubName}/OnConnectedAsync", a.DisplayName), + a => + { + Assert.Equal($"{serverHubName}/StreamException", a.DisplayName); + Assert.Equal(ActivityStatusCode.Error, clientActivity.Status); + + var tags = clientActivity.TagObjects.ToDictionary(); + Assert.Equal(typeof(HubException).FullName, tags["error.type"]); + }, + a => Assert.Equal($"{serverHubName}/OnDisconnectedAsync", a.DisplayName)); + + Assert.Equal($"{clientHubName}/StreamException", clientActivity.DisplayName); + Assert.Equal(ActivityKind.Client, clientActivity.Kind); + Assert.Equal(ActivityStatusCode.Error, clientActivity.Status); + + var tags = clientActivity.TagObjects.ToDictionary(); + Assert.Equal(typeof(HubException).FullName, tags["error.type"]); + } + } + + [Theory] + [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] + public async Task InvokeAsyncWithException_Tracing(string protocolName, HttpTransportType transportType, string path) + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == DefaultHubDispatcherLoggerName && + writeContext.EventId.Name == "FailedInvokingHubMethod"; + } + + var protocol = HubProtocols[protocolName]; + await using (var server = await StartServer(ExpectedErrors)) + { + var serverChannel = Channel.CreateUnbounded(); + var clientActivityTcs = new TaskCompletionSource(); + var serverSource = server.Services.GetRequiredService().ActivitySource; + var clientSourceContainer = new SignalRClientActivitySource(); + + using var listener = new ActivityListener + { + ShouldListenTo = activitySource => ReferenceEquals(activitySource, serverSource) || ReferenceEquals(activitySource, clientSourceContainer.ActivitySource), + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllData, + ActivityStarted = activity => + { + if (activity.Source == clientSourceContainer.ActivitySource) + { + clientActivityTcs.SetResult(activity); + } + else + { + serverChannel.Writer.TryWrite(activity); + } + } + }; + ActivitySource.AddActivityListener(listener); + + var connectionBuilder = new HubConnectionBuilder() + .WithLoggerFactory(LoggerFactory) + .WithUrl(server.Url + path, transportType); + connectionBuilder.Services.AddSingleton(protocol); + connectionBuilder.Services.AddSingleton(clientSourceContainer); + + var connection = connectionBuilder.Build(); + + try + { + await connection.StartAsync().DefaultTimeout(); + + await Assert.ThrowsAnyAsync( + async () => await connection.InvokeAsync(nameof(TestHub.InvokeException))).DefaultTimeout(); + } + catch (Exception ex) + { + LoggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + await connection.DisposeAsync().DefaultTimeout(); + } + + var port = new Uri(server.Url).Port; + var serverHubName = path switch + { + "/default" => typeof(TestHub).FullName, + "/hubT" => typeof(TestHubT).FullName, + "/dynamic" => typeof(DynamicTestHub).FullName, + _ => throw new InvalidOperationException("Unexpected path: " + path) + }; + var clientHubName = path.TrimStart('/'); + + var serverActivities = await serverChannel.Reader.ReadAtLeastAsync(minimumCount: 3).DefaultTimeout(); + var clientActivity = await clientActivityTcs.Task.DefaultTimeout(); + + Assert.Collection(serverActivities, + a => Assert.Equal($"{serverHubName}/OnConnectedAsync", a.DisplayName), + a => + { + Assert.Equal($"{serverHubName}/InvokeException", a.DisplayName); + Assert.Equal(ActivityStatusCode.Error, clientActivity.Status); + + var tags = clientActivity.TagObjects.ToDictionary(); + Assert.Equal(typeof(HubException).FullName, tags["error.type"]); + }, + a => Assert.Equal($"{serverHubName}/OnDisconnectedAsync", a.DisplayName)); + + Assert.Equal(HubConnection.ActivityName, clientActivity.OperationName); + Assert.Equal($"{clientHubName}/InvokeException", clientActivity.DisplayName); + Assert.Equal(ActivityKind.Client, clientActivity.Kind); + Assert.Equal(ActivityStatusCode.Error, clientActivity.Status); + + var tags = clientActivity.TagObjects.ToDictionary(); + Assert.Equal(typeof(HubException).FullName, tags["error.type"]); + } + } + + [Theory] + [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] + public async Task SendAsync_Tracing(string protocolName, HttpTransportType transportType, string path) + { + var protocol = HubProtocols[protocolName]; + await using (var server = await StartServer()) + { + var serverChannel = Channel.CreateUnbounded(); + var clientActivityTcs = new TaskCompletionSource(); + var serverSource = server.Services.GetRequiredService().ActivitySource; + var clientSourceContainer = new SignalRClientActivitySource(); + + using var listener = new ActivityListener + { + ShouldListenTo = activitySource => ReferenceEquals(activitySource, serverSource) || ReferenceEquals(activitySource, clientSourceContainer.ActivitySource), + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllData, + ActivityStarted = activity => + { + if (activity.Source == clientSourceContainer.ActivitySource) + { + clientActivityTcs.SetResult(activity); + } + else + { + serverChannel.Writer.TryWrite(activity); + } + } + }; + ActivitySource.AddActivityListener(listener); + + var connectionBuilder = new HubConnectionBuilder() + .WithLoggerFactory(LoggerFactory) + .WithUrl(server.Url + path, transportType); + connectionBuilder.Services.AddSingleton(protocol); + connectionBuilder.Services.AddSingleton(clientSourceContainer); + + var connection = connectionBuilder.Build(); + + try + { + await connection.StartAsync().DefaultTimeout(); + + var echoTcs = new TaskCompletionSource(); + connection.On("Echo", echoTcs.SetResult); + + await connection.SendAsync(nameof(TestHub.CallEcho), "Hi"); + + // The SignalR client activity shouldn't escape into user code. + Assert.Null(Activity.Current); + + // Wait until message is echoed back from the server. + // Needed so the client doesn't stop the connection before the server gets the invocation. + await echoTcs.Task.DefaultTimeout(); + } + catch (Exception ex) + { + LoggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + await connection.DisposeAsync().DefaultTimeout(); + } + + var port = new Uri(server.Url).Port; + var serverHubName = path switch + { + "/default" => typeof(TestHub).FullName, + "/hubT" => typeof(TestHubT).FullName, + "/dynamic" => typeof(DynamicTestHub).FullName, + _ => throw new InvalidOperationException("Unexpected path: " + path) + }; + var clientHubName = path.TrimStart('/'); + + var serverActivities = await serverChannel.Reader.ReadAtLeastAsync(minimumCount: 3).DefaultTimeout(); + var clientActivity = await clientActivityTcs.Task.DefaultTimeout(); + + Assert.Collection(serverActivities, + a => Assert.Equal($"{serverHubName}/OnConnectedAsync", a.DisplayName), + a => + { + Assert.Equal($"{serverHubName}/CallEcho", a.DisplayName); + Assert.Equal(clientActivity.Id, a.ParentId); + }, + a => Assert.Equal($"{serverHubName}/OnDisconnectedAsync", a.DisplayName)); + + Assert.Equal(HubConnection.ActivityName, clientActivity.OperationName); + Assert.Equal($"{clientHubName}/CallEcho", clientActivity.DisplayName); + Assert.Equal(ActivityKind.Client, clientActivity.Kind); + } + } +} diff --git a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs index f36614e184a8..be7cd64e1a64 100644 --- a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs +++ b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs @@ -1,7 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Diagnostics; using System.Net; using System.Net.Http; using System.Net.WebSockets; @@ -12,7 +11,7 @@ using Microsoft.AspNetCore.Http.Connections; using Microsoft.AspNetCore.Http.Connections.Client; using Microsoft.AspNetCore.InternalTesting; -using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Client.Internal; using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.AspNetCore.SignalR.Test.Internal; using Microsoft.AspNetCore.SignalR.Tests; @@ -28,7 +27,7 @@ public class HubConnectionTestsCollection : ICollectionFixture(delegateConnectionFactory); + if (activitySourceContainer != null) + { + hubConnectionBuilder.Services.AddSingleton(activitySourceContainer); + } return hubConnectionBuilder.Build(); } @@ -115,131 +119,6 @@ public async Task CheckFixedMessage(string protocolName, HttpTransportType trans } } - [Theory] - [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] - public async Task InvokeAsync_SendTraceHeader(string protocolName, HttpTransportType transportType, string path) - { - var protocol = HubProtocols[protocolName]; - await using (var server = await StartServer()) - { - var channel = Channel.CreateUnbounded(); - var serverSource = server.Services.GetRequiredService().ActivitySource; - - using var listener = new ActivityListener - { - ShouldListenTo = activitySource => ReferenceEquals(activitySource, serverSource), - Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllData, - ActivityStarted = activity => channel.Writer.TryWrite(activity) - }; - ActivitySource.AddActivityListener(listener); - - var connectionBuilder = new HubConnectionBuilder() - .WithLoggerFactory(LoggerFactory) - .WithUrl(server.Url + path, transportType); - connectionBuilder.Services.AddSingleton(protocol); - - var connection = connectionBuilder.Build(); - - Activity clientActivity1 = null; - Activity clientActivity2 = null; - try - { - await connection.StartAsync().DefaultTimeout(); - - // Invocation 1 - try - { - clientActivity1 = new Activity("ClientActivity1"); - clientActivity1.AddBaggage("baggage-1", "value-1"); - clientActivity1.Start(); - - var result = await connection.InvokeAsync(nameof(TestHub.HelloWorld)).DefaultTimeout(); - - Assert.Equal("Hello World!", result); - } - finally - { - clientActivity1?.Stop(); - } - - // Invocation 2 - try - { - clientActivity2 = new Activity("ClientActivity2"); - clientActivity2.AddBaggage("baggage-2", "value-2"); - clientActivity2.Start(); - - var result = await connection.InvokeAsync(nameof(TestHub.HelloWorld)).DefaultTimeout(); - - Assert.Equal("Hello World!", result); - } - finally - { - clientActivity2?.Stop(); - } - } - catch (Exception ex) - { - LoggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); - throw; - } - finally - { - await connection.DisposeAsync().DefaultTimeout(); - } - - var activities = await channel.Reader.ReadAtLeastAsync(minimumCount: 4).DefaultTimeout(); - - var hubName = path switch - { - "/default" => typeof(TestHub).FullName, - "/hubT" => typeof(TestHubT).FullName, - "/dynamic" => typeof(DynamicTestHub).FullName, - _ => throw new InvalidOperationException("Unexpected path: " + path) - }; - - Assert.Collection(activities, - a => - { - Assert.Equal(SignalRServerActivitySource.OnConnected, a.OperationName); - Assert.Equal("Microsoft.AspNetCore.Hosting.HttpRequestIn", a.Parent.OperationName); - Assert.False(a.HasRemoteParent); - Assert.Empty(a.Baggage); - }, - a => - { - Assert.Equal(SignalRServerActivitySource.InvocationIn, a.OperationName); - Assert.Equal(clientActivity1.Id, a.ParentId); - Assert.True(a.HasRemoteParent); - Assert.Collection(a.Baggage, - b => - { - Assert.Equal("baggage-1", b.Key); - Assert.Equal("value-1", b.Value); - }); - }, - a => - { - Assert.Equal(SignalRServerActivitySource.InvocationIn, a.OperationName); - Assert.Equal(clientActivity2.Id, a.ParentId); - Assert.True(a.HasRemoteParent); - Assert.Collection(a.Baggage, - b => - { - Assert.Equal("baggage-2", b.Key); - Assert.Equal("value-2", b.Value); - }); - }, - a => - { - Assert.Equal(SignalRServerActivitySource.OnDisconnected, a.OperationName); - Assert.Equal("Microsoft.AspNetCore.Hosting.HttpRequestIn", a.Parent.OperationName); - Assert.False(a.HasRemoteParent); - Assert.Empty(a.Baggage); - }); - } - } - [Fact] public async Task ServerRejectsClientWithOldProtocol() { @@ -596,97 +475,6 @@ public async Task StreamAsyncCoreTest(string protocolName, HttpTransportType tra } } - [Theory] - [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] - public async Task StreamAsyncCore_SendTraceHeader(string protocolName, HttpTransportType transportType, string path) - { - var protocol = HubProtocols[protocolName]; - await using (var server = await StartServer()) - { - var channel = Channel.CreateUnbounded(); - var serverSource = server.Services.GetRequiredService().ActivitySource; - - using var listener = new ActivityListener - { - ShouldListenTo = activitySource => ReferenceEquals(activitySource, serverSource), - Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllData, - ActivityStarted = activity => channel.Writer.TryWrite(activity) - }; - ActivitySource.AddActivityListener(listener); - - var connection = CreateHubConnection(server.Url, path, transportType, protocol, LoggerFactory); - - Activity clientActivity = null; - try - { - await connection.StartAsync().DefaultTimeout(); - - clientActivity = new Activity("ClientActivity"); - clientActivity.AddBaggage("baggage-1", "value-1"); - clientActivity.Start(); - - var expectedValue = 0; - var streamTo = 5; - var asyncEnumerable = connection.StreamAsyncCore("Stream", new object[] { streamTo }); - await foreach (var streamValue in asyncEnumerable) - { - Assert.Equal(expectedValue, streamValue); - expectedValue++; - } - - Assert.Equal(streamTo, expectedValue); - } - catch (Exception ex) - { - LoggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); - throw; - } - finally - { - clientActivity?.Stop(); - await connection.DisposeAsync().DefaultTimeout(); - } - - var activities = await channel.Reader.ReadAtLeastAsync(minimumCount: 3).DefaultTimeout(); - - var hubName = path switch - { - "/default" => typeof(TestHub).FullName, - "/hubT" => typeof(TestHubT).FullName, - "/dynamic" => typeof(DynamicTestHub).FullName, - _ => throw new InvalidOperationException("Unexpected path: " + path) - }; - - Assert.Collection(activities, - a => - { - Assert.Equal(SignalRServerActivitySource.OnConnected, a.OperationName); - Assert.Equal("Microsoft.AspNetCore.Hosting.HttpRequestIn", a.Parent.OperationName); - Assert.False(a.HasRemoteParent); - Assert.Empty(a.Baggage); - }, - a => - { - Assert.Equal(SignalRServerActivitySource.InvocationIn, a.OperationName); - Assert.Equal(clientActivity.Id, a.ParentId); - Assert.True(a.HasRemoteParent); - Assert.Collection(a.Baggage, - b => - { - Assert.Equal("baggage-1", b.Key); - Assert.Equal("value-1", b.Value); - }); - }, - a => - { - Assert.Equal(SignalRServerActivitySource.OnDisconnected, a.OperationName); - Assert.Equal("Microsoft.AspNetCore.Hosting.HttpRequestIn", a.Parent.OperationName); - Assert.False(a.HasRemoteParent); - Assert.Empty(a.Baggage); - }); - } - } - [Theory] [InlineData("json")] [InlineData("messagepack")] diff --git a/src/SignalR/clients/csharp/Client/test/FunctionalTests/Hubs.cs b/src/SignalR/clients/csharp/Client/test/FunctionalTests/Hubs.cs index fdb2f56c175a..17d09d87818c 100644 --- a/src/SignalR/clients/csharp/Client/test/FunctionalTests/Hubs.cs +++ b/src/SignalR/clients/csharp/Client/test/FunctionalTests/Hubs.cs @@ -20,6 +20,8 @@ public class TestHub : Hub public string Echo(string message) => TestHubMethodsImpl.Echo(message); + public string InvokeException() => TestHubMethodsImpl.InvokeException(); + public ChannelReader Stream(int count) => TestHubMethodsImpl.Stream(count); public ChannelReader StreamException() => TestHubMethodsImpl.StreamException(); @@ -140,6 +142,8 @@ public class DynamicTestHub : DynamicHub public string Echo(string message) => TestHubMethodsImpl.Echo(message); + public string InvokeException() => TestHubMethodsImpl.InvokeException(); + public ChannelReader Stream(int count) => TestHubMethodsImpl.Stream(count); public ChannelReader StreamException() => TestHubMethodsImpl.StreamException(); @@ -174,6 +178,8 @@ public class TestHubT : Hub public string Echo(string message) => TestHubMethodsImpl.Echo(message); + public string InvokeException() => TestHubMethodsImpl.InvokeException(); + public ChannelReader Stream(int count) => TestHubMethodsImpl.Stream(count); public ChannelReader StreamException() => TestHubMethodsImpl.StreamException(); @@ -214,6 +220,8 @@ public static string Echo(string message) return message; } + public static string InvokeException() => throw new InvalidOperationException(); + public static ChannelReader Stream(int count) { var channel = Channel.CreateUnbounded(); diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs index 4692ba603537..ac74c8c7d0ef 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs @@ -456,6 +456,7 @@ public async Task NegotiateThatReturnsRedirectUrlDoesNotAddAnotherNegotiateVersi testHttpHandler.OnLongPollDelete((token) => ResponseUtils.CreateResponse(HttpStatusCode.Accepted)); + EndPoint connectedEndpoint = null; using (var noErrorScope = new VerifyNoErrorsScope()) { await WithConnectionAsync( @@ -463,6 +464,7 @@ await WithConnectionAsync( async (connection) => { await connection.StartAsync().DefaultTimeout(); + connectedEndpoint = connection.RemoteEndPoint; }); } @@ -471,6 +473,7 @@ await WithConnectionAsync( Assert.Equal("https://another.domain.url/chat?negotiateVersion=1&id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[2].RequestUri.ToString()); Assert.Equal("https://another.domain.url/chat?negotiateVersion=1&id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[3].RequestUri.ToString()); Assert.Equal(5, testHttpHandler.ReceivedRequests.Count); + Assert.Equal("https://another.domain.url/chat", connectedEndpoint.ToString()); } [Fact] diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Helpers.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Helpers.cs index 7415c5acc8eb..9737f544eddc 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Helpers.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Helpers.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.SignalR.Client.Internal; using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.AspNetCore.SignalR.Tests; using Microsoft.Extensions.DependencyInjection; @@ -11,12 +12,20 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests; public partial class HubConnectionTests { - private static HubConnection CreateHubConnection(TestConnection connection, IHubProtocol protocol = null, ILoggerFactory loggerFactory = null) + private static HubConnection CreateHubConnection( + TestConnection connection, + IHubProtocol protocol = null, + ILoggerFactory loggerFactory = null, + SignalRClientActivitySource clientActivitySource = null) { var builder = new HubConnectionBuilder().WithUrl("http://example.com"); var delegateConnectionFactory = new DelegateConnectionFactory( - endPoint => connection.StartAsync()); + async endPoint => + { + connection.RemoteEndPoint = endPoint; + return await connection.StartAsync(); + }); builder.Services.AddSingleton(delegateConnectionFactory); @@ -30,6 +39,11 @@ private static HubConnection CreateHubConnection(TestConnection connection, IHub builder.Services.AddSingleton(protocol); } + if (clientActivitySource != null) + { + builder.Services.AddSingleton(clientActivitySource); + } + return builder.Build(); } } diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Protocol.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Protocol.cs index 738762bcac45..cd7eff93f282 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Protocol.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Protocol.cs @@ -1,7 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Diagnostics; using System.Threading.Channels; using Microsoft.AspNetCore.InternalTesting; using Microsoft.AspNetCore.SignalR.Tests; @@ -140,34 +139,6 @@ public async Task InvokeSendsAnInvocationMessage() } } - [Fact] - public async Task InvokeSendsAnInvocationMessage_SendTraceHeaders() - { - var connection = new TestConnection(); - var hubConnection = CreateHubConnection(connection); - try - { - await hubConnection.StartAsync().DefaultTimeout(); - - using var clientActivity = new Activity("ClientActivity"); - clientActivity.Start(); - - var invokeTask = hubConnection.InvokeAsync("Foo"); - - var invokeMessage = await connection.ReadSentJsonAsync().DefaultTimeout(); - var traceParent = (string)invokeMessage["headers"]["traceparent"]; - - Assert.Equal(clientActivity.Id, traceParent); - - Assert.Equal(TaskStatus.WaitingForActivation, invokeTask.Status); - } - finally - { - await hubConnection.DisposeAsync().DefaultTimeout(); - await connection.DisposeAsync().DefaultTimeout(); - } - } - [Fact] public async Task ReceiveCloseMessageWithoutErrorWillCloseHubConnection() { @@ -253,36 +224,6 @@ public async Task StreamSendsAnInvocationMessage() } } - [Fact] - public async Task StreamSendsAnInvocationMessage_SendTraceHeaders() - { - var connection = new TestConnection(); - var hubConnection = CreateHubConnection(connection); - try - { - await hubConnection.StartAsync().DefaultTimeout(); - - using var clientActivity = new Activity("ClientActivity"); - clientActivity.Start(); - - var channel = await hubConnection.StreamAsChannelAsync("Foo").DefaultTimeout(); - - var invokeMessage = await connection.ReadSentJsonAsync().DefaultTimeout(); - var traceParent = (string)invokeMessage["headers"]["traceparent"]; - - Assert.Equal(clientActivity.Id, traceParent); - - // Complete the channel - await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).DefaultTimeout(); - await channel.Completion.DefaultTimeout(); - } - finally - { - await hubConnection.DisposeAsync().DefaultTimeout(); - await connection.DisposeAsync().DefaultTimeout(); - } - } - [Fact] public async Task InvokeCompletedWhenCompletionMessageReceived() { diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Tracing.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Tracing.cs new file mode 100644 index 000000000000..a02e4c47ff5b --- /dev/null +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Tracing.cs @@ -0,0 +1,219 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Internal; +using Microsoft.AspNetCore.InternalTesting; +using Microsoft.AspNetCore.SignalR.Client.Internal; +using Microsoft.AspNetCore.SignalR.Tests; +using Newtonsoft.Json.Linq; +using Xunit; + +namespace Microsoft.AspNetCore.SignalR.Client.Tests; + +public partial class HubConnectionTests +{ + public class Tracing : VerifiableLoggedTest + { + [Fact] + public async Task InvokeSendsAnInvocationMessage_SendTraceHeaders() + { + var clientSourceContainer = new SignalRClientActivitySource(); + Activity clientActivity = null; + + using var listener = new ActivityListener + { + ShouldListenTo = activitySource => ReferenceEquals(activitySource, clientSourceContainer.ActivitySource), + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllData, + ActivityStarted = activity => clientActivity = activity + }; + ActivitySource.AddActivityListener(listener); + + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection, clientActivitySource: clientSourceContainer); + try + { + await hubConnection.StartAsync().DefaultTimeout(); + + var invokeTask = hubConnection.InvokeAsync("Foo"); + + var invokeMessage = await connection.ReadSentJsonAsync().DefaultTimeout(); + var traceParent = (string)invokeMessage["headers"]["traceparent"]; + + Assert.Equal(clientActivity.Id, traceParent); + Assert.Equal("example.com", clientActivity.TagObjects.Single(t => t.Key == "server.address").Value); + + Assert.Equal(TaskStatus.WaitingForActivation, invokeTask.Status); + } + finally + { + await hubConnection.DisposeAsync().DefaultTimeout(); + await connection.DisposeAsync().DefaultTimeout(); + } + } + + [Fact] + public async Task StreamSendsAnInvocationMessage_SendTraceHeaders() + { + var clientSourceContainer = new SignalRClientActivitySource(); + Activity clientActivity = null; + + using var listener = new ActivityListener + { + ShouldListenTo = activitySource => ReferenceEquals(activitySource, clientSourceContainer.ActivitySource), + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllData, + ActivityStarted = activity => clientActivity = activity + }; + ActivitySource.AddActivityListener(listener); + + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection, clientActivitySource: clientSourceContainer); + try + { + await hubConnection.StartAsync().DefaultTimeout(); + + var channel = await hubConnection.StreamAsChannelAsync("Foo").DefaultTimeout(); + + var invokeMessage = await connection.ReadSentJsonAsync().DefaultTimeout(); + var traceParent = (string)invokeMessage["headers"]["traceparent"]; + + Assert.Equal(clientActivity.Id, traceParent); + + // Complete the channel + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).DefaultTimeout(); + await channel.Completion.DefaultTimeout(); + } + finally + { + await hubConnection.DisposeAsync().DefaultTimeout(); + await connection.DisposeAsync().DefaultTimeout(); + } + } + + [Fact] + public async Task SendAnInvocationMessage_SendTraceHeaders() + { + var clientSourceContainer = new SignalRClientActivitySource(); + Activity clientActivity = null; + + using var listener = new ActivityListener + { + ShouldListenTo = activitySource => ReferenceEquals(activitySource, clientSourceContainer.ActivitySource), + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllData, + ActivityStarted = activity => clientActivity = activity + }; + ActivitySource.AddActivityListener(listener); + + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection, clientActivitySource: clientSourceContainer); + try + { + await hubConnection.StartAsync().DefaultTimeout(); + + await hubConnection.SendAsync("Foo"); + + var invokeMessage = await connection.ReadSentJsonAsync().DefaultTimeout(); + var traceParent = (string)invokeMessage["headers"]["traceparent"]; + + Assert.Equal(clientActivity.Id, traceParent); + } + finally + { + await hubConnection.DisposeAsync().DefaultTimeout(); + await connection.DisposeAsync().DefaultTimeout(); + } + } + + [Fact] + public async Task InvokeSendsAnInvocationMessage_ConnectionRemoteEndPointChanged_UseRemoteEndpointUrl() + { + var clientSourceContainer = new SignalRClientActivitySource(); + Activity clientActivity = null; + + using var listener = new ActivityListener + { + ShouldListenTo = activitySource => ReferenceEquals(activitySource, clientSourceContainer.ActivitySource), + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllData, + ActivityStarted = activity => clientActivity = activity + }; + ActivitySource.AddActivityListener(listener); + + TestConnection connection = null; + connection = new TestConnection(onStart: () => + { + connection.RemoteEndPoint = new UriEndPoint(new Uri("http://example.net")); + return Task.CompletedTask; + }); + var hubConnection = CreateHubConnection(connection, clientActivitySource: clientSourceContainer); + try + { + await hubConnection.StartAsync().DefaultTimeout(); + + _ = hubConnection.InvokeAsync("Foo"); + + await connection.ReadSentJsonAsync().DefaultTimeout(); + + Assert.Equal("example.net", clientActivity.TagObjects.Single(t => t.Key == "server.address").Value); + } + finally + { + await hubConnection.DisposeAsync().DefaultTimeout(); + await connection.DisposeAsync().DefaultTimeout(); + } + } + + [Fact] + public async Task InvokeSendsAnInvocationMessage_ConnectionRemoteEndPointChangedDuringConnect_UseRemoteEndpointUrl() + { + var clientSourceContainer = new SignalRClientActivitySource(); + var clientActivityTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); ; + + using var listener = new ActivityListener + { + ShouldListenTo = activitySource => ReferenceEquals(activitySource, clientSourceContainer.ActivitySource), + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllData, + ActivityStarted = clientActivityTcs.SetResult + }; + ActivitySource.AddActivityListener(listener); + + var syncPoint = new SyncPoint(); + TestConnection connection = null; + connection = new TestConnection(onStart: async () => + { + await syncPoint.WaitToContinue(); + connection.RemoteEndPoint = new UriEndPoint(new Uri("http://example.net:5050")); + }); + var hubConnection = CreateHubConnection(connection, clientActivitySource: clientSourceContainer); + try + { + var startTask = hubConnection.StartAsync(); + + _ = hubConnection.InvokeAsync("Foo"); + + var clientActivity = await clientActivityTcs.Task.DefaultTimeout(); + + // Initial server.address uses configured HubConnection URL. + Assert.Equal("example.com", clientActivity.TagObjects.Single(t => t.Key == "server.address").Value); + Assert.Equal(80, (int)clientActivity.TagObjects.Single(t => t.Key == "server.port").Value); + + syncPoint.Continue(); + + await startTask.DefaultTimeout(); + + await connection.ReadSentJsonAsync().DefaultTimeout(); + + // After connection is started, server.address is updated to the connection's remote endpoint. + Assert.Equal("example.net", clientActivity.TagObjects.Single(t => t.Key == "server.address").Value); + Assert.Equal(5050, (int)clientActivity.TagObjects.Single(t => t.Key == "server.port").Value); + } + finally + { + await hubConnection.DisposeAsync().DefaultTimeout(); + await connection.DisposeAsync().DefaultTimeout(); + } + } + } +} diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs index e193d346c037..19bca526098e 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs @@ -351,6 +351,9 @@ private async Task SelectAndStartTransport(TransferFormat transferFormat, Cancel throw new InvalidOperationException("Negotiate redirection limit exceeded."); } + // Set the final negotiated URI as the endpoint. + RemoteEndPoint = new UriEndPoint(Utils.CreateEndPointUri(uri)); + // This should only need to happen once var connectUrl = CreateConnectUrl(uri, negotiationResponse.ConnectionToken); @@ -403,6 +406,9 @@ private async Task SelectAndStartTransport(TransferFormat transferFormat, Cancel _httpConnectionOptions.UseStatefulReconnect = transportType == HttpTransportType.WebSockets ? _httpConnectionOptions.UseStatefulReconnect : false; negotiationResponse = await GetNegotiationResponseAsync(uri, cancellationToken).ConfigureAwait(false); connectUrl = CreateConnectUrl(uri, negotiationResponse.ConnectionToken); + + // Set the final negotiated URI as the endpoint. + RemoteEndPoint = new UriEndPoint(Utils.CreateEndPointUri(uri)); } Log.StartingTransport(_logger, transportType, uri); diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/Utils.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/Utils.cs index b9a44983f63b..265550d6aebc 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/Utils.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/Utils.cs @@ -7,6 +7,19 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal; internal static class Utils { + public static Uri CreateEndPointUri(Uri url) + { + // The EndPoint URI shouldn't have querystring or target. + var uriBuilder = new UriBuilder + { + Scheme = url.Scheme, + Host = url.Host, + Port = url.Port, + Path = url.AbsolutePath + }; + return uriBuilder.Uri; + } + public static Uri AppendPath(Uri url, string path) { var builder = new UriBuilder(url); diff --git a/src/SignalR/common/testassets/Tests.Utils/ChannelExtensions.cs b/src/SignalR/common/testassets/Tests.Utils/ChannelExtensions.cs index 1afb51b88e3d..e867a23a933b 100644 --- a/src/SignalR/common/testassets/Tests.Utils/ChannelExtensions.cs +++ b/src/SignalR/common/testassets/Tests.Utils/ChannelExtensions.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics; + namespace System.Threading.Channels; public static class ChannelExtensions @@ -41,7 +43,7 @@ public static async Task> ReadAtLeastAsync(this ChannelReader read var items = new List(); - while (items.Count < minimumCount && !cancellationToken.IsCancellationRequested) + while (items.Count < minimumCount) { while (reader.TryRead(out var item)) { @@ -52,10 +54,17 @@ public static async Task> ReadAtLeastAsync(this ChannelReader read } } - var readTask = reader.WaitToReadAsync(cancellationToken).AsTask(); - if (!await readTask.ConfigureAwait(false)) + try + { + var readTask = reader.WaitToReadAsync(cancellationToken).AsTask(); + if (!await readTask.ConfigureAwait(false)) + { + throw new InvalidOperationException($"Channel ended after writing {items.Count} items."); + } + } + catch (OperationCanceledException) { - throw new InvalidOperationException($"Channel ended after writing {items.Count} items."); + throw new OperationCanceledException($"ReadAtLeastAsync canceled with {items.Count} of {minimumCount} items."); } }