diff --git a/src/GraphQL.Client.Http/GraphQLHttpClient.cs b/src/GraphQL.Client.Http/GraphQLHttpClient.cs index 8d2c5db5..43705941 100644 --- a/src/GraphQL.Client.Http/GraphQLHttpClient.cs +++ b/src/GraphQL.Client.Http/GraphQLHttpClient.cs @@ -12,9 +12,13 @@ public class GraphQLHttpClient : IGraphQLClient { private readonly GraphQLHttpWebSocket graphQlHttpWebSocket; private readonly CancellationTokenSource cancellationTokenSource = new CancellationTokenSource(); - private readonly HttpClient httpClient; private readonly ConcurrentDictionary, object> subscriptionStreams = new ConcurrentDictionary, object>(); + /// + /// the instance of which is used internally + /// + public HttpClient HttpClient { get; } + /// /// The Options to be used /// @@ -30,33 +34,33 @@ public GraphQLHttpClient(Uri endPoint) : this(o => o.EndPoint = endPoint) { } public GraphQLHttpClient(Action configure) { Options = new GraphQLHttpClientOptions(); configure(Options); - this.httpClient = new HttpClient(); + this.HttpClient = new HttpClient(Options.HttpMessageHandler); this.graphQlHttpWebSocket = new GraphQLHttpWebSocket(GetWebSocketUri(), Options); } public GraphQLHttpClient(GraphQLHttpClientOptions options) { Options = options; - this.httpClient = new HttpClient(); + this.HttpClient = new HttpClient(Options.HttpMessageHandler); this.graphQlHttpWebSocket = new GraphQLHttpWebSocket(GetWebSocketUri(), Options); } public GraphQLHttpClient(GraphQLHttpClientOptions options, HttpClient httpClient) { Options = options; - this.httpClient = httpClient; + this.HttpClient = httpClient; this.graphQlHttpWebSocket = new GraphQLHttpWebSocket(GetWebSocketUri(), Options); } + /// public Task> SendQueryAsync(GraphQLRequest request, CancellationToken cancellationToken = default) { return Options.UseWebSocketForQueriesAndMutations - ? this.graphQlHttpWebSocket.Request(request, Options, cancellationToken) + ? this.graphQlHttpWebSocket.SendRequest(request, this, cancellationToken) : this.SendHttpPostRequestAsync(request, cancellationToken); } - public Task> SendMutationAsync(GraphQLRequest request, CancellationToken cancellationToken = default) { - return Options.UseWebSocketForQueriesAndMutations - ? this.graphQlHttpWebSocket.Request(request, Options, cancellationToken) - : this.SendHttpPostRequestAsync(request, cancellationToken); - } + /// + public Task> SendMutationAsync(GraphQLRequest request, + CancellationToken cancellationToken = default) + => SendQueryAsync(request, cancellationToken); /// public IObservable> CreateSubscriptionStream(GraphQLRequest request) { @@ -68,7 +72,7 @@ public IObservable> CreateSubscriptionStream>)subscriptionStreams[key]; - var observable = graphQlHttpWebSocket.CreateSubscriptionStream(request, Options, cancellationToken: cancellationTokenSource.Token); + var observable = graphQlHttpWebSocket.CreateSubscriptionStream(request, this, cancellationToken: cancellationTokenSource.Token); subscriptionStreams.TryAdd(key, observable); return observable; @@ -84,7 +88,7 @@ public IObservable> CreateSubscriptionStream>)subscriptionStreams[key]; - var observable = graphQlHttpWebSocket.CreateSubscriptionStream(request, Options, exceptionHandler, cancellationTokenSource.Token); + var observable = graphQlHttpWebSocket.CreateSubscriptionStream(request, this, exceptionHandler, cancellationTokenSource.Token); subscriptionStreams.TryAdd(key, observable); return observable; } @@ -98,8 +102,9 @@ public IObservable> CreateSubscriptionStream> SendHttpPostRequestAsync(GraphQLRequest request, CancellationToken cancellationToken = default) { - using var httpRequestMessage = this.GenerateHttpRequestMessage(request.SerializeToJson(Options)); - using var httpResponseMessage = await this.httpClient.SendAsync(httpRequestMessage, cancellationToken); + var preprocessedRequest = await Options.PreprocessRequest(request, this); + using var httpRequestMessage = this.GenerateHttpRequestMessage(preprocessedRequest.SerializeToJson(Options)); + using var httpResponseMessage = await this.HttpClient.SendAsync(httpRequestMessage, cancellationToken); if (!httpResponseMessage.IsSuccessStatusCode) { throw new GraphQLHttpException(httpResponseMessage); } @@ -140,7 +145,7 @@ public void Dispose() { private void _dispose() { disposed = true; - this.httpClient.Dispose(); + this.HttpClient.Dispose(); this.graphQlHttpWebSocket.Dispose(); cancellationTokenSource.Cancel(); cancellationTokenSource.Dispose(); diff --git a/src/GraphQL.Client.Http/GraphQLHttpClientOptions.cs b/src/GraphQL.Client.Http/GraphQLHttpClientOptions.cs index ac5a8314..146ff3f4 100644 --- a/src/GraphQL.Client.Http/GraphQLHttpClientOptions.cs +++ b/src/GraphQL.Client.Http/GraphQLHttpClientOptions.cs @@ -2,6 +2,7 @@ using System.Net.Http; using System.Net.Http.Headers; using System.Text.Json; +using System.Threading.Tasks; using Dahomey.Json; namespace GraphQL.Client.Http { @@ -46,6 +47,10 @@ public class GraphQLHttpClientOptions { /// If , the websocket connection is also used for regular queries and mutations /// public bool UseWebSocketForQueriesAndMutations { get; set; } = false; - } + /// + /// Request preprocessing function. Can be used i.e. to inject authorization info into a GraphQL request payload. + /// + public Func> PreprocessRequest { get; set; } = (request, client) => Task.FromResult(request); + } } diff --git a/src/GraphQL.Client.Http/Websocket/GraphQLHttpWebSocket.cs b/src/GraphQL.Client.Http/Websocket/GraphQLHttpWebSocket.cs index 0c516c14..7098a8e5 100644 --- a/src/GraphQL.Client.Http/Websocket/GraphQLHttpWebSocket.cs +++ b/src/GraphQL.Client.Http/Websocket/GraphQLHttpWebSocket.cs @@ -1,6 +1,7 @@ using System; using System.Diagnostics; using System.IO; +using System.Net.Http; using System.Net.WebSockets; using System.Reactive.Disposables; using System.Reactive.Linq; @@ -108,9 +109,13 @@ public Task InitializeWebSocket() { switch (clientWebSocket) { case ClientWebSocket nativeWebSocket: nativeWebSocket.Options.AddSubProtocol("graphql-ws"); + nativeWebSocket.Options.ClientCertificates = ((HttpClientHandler)_options.HttpMessageHandler).ClientCertificates; + nativeWebSocket.Options.UseDefaultCredentials = ((HttpClientHandler)_options.HttpMessageHandler).UseDefaultCredentials; break; case System.Net.WebSockets.Managed.ClientWebSocket managedWebSocket: managedWebSocket.Options.AddSubProtocol("graphql-ws"); + managedWebSocket.Options.ClientCertificates = ((HttpClientHandler)_options.HttpMessageHandler).ClientCertificates; + managedWebSocket.Options.UseDefaultCredentials = ((HttpClientHandler)_options.HttpMessageHandler).UseDefaultCredentials; break; default: throw new NotSupportedException($"unknown websocket type {clientWebSocket.GetType().Name}"); diff --git a/src/GraphQL.Client.Http/Websocket/GraphQLHttpWebsocketHelpers.cs b/src/GraphQL.Client.Http/Websocket/GraphQLHttpWebsocketHelpers.cs index 2227b777..4d929a83 100644 --- a/src/GraphQL.Client.Http/Websocket/GraphQLHttpWebsocketHelpers.cs +++ b/src/GraphQL.Client.Http/Websocket/GraphQLHttpWebsocketHelpers.cs @@ -13,11 +13,12 @@ public static class GraphQLHttpWebsocketHelpers { internal static IObservable> CreateSubscriptionStream( this GraphQLHttpWebSocket graphQlHttpWebSocket, GraphQLRequest request, - GraphQLHttpClientOptions options, + GraphQLHttpClient client, Action exceptionHandler = null, CancellationToken cancellationToken = default) { return Observable.Defer(() => Observable.Create>(async observer => { + await client.Options.PreprocessRequest(request, client); var startRequest = new GraphQLWebSocketRequest { Id = Guid.NewGuid().ToString("N"), Type = GraphQLWebSocketMessageType.GQL_START, @@ -27,34 +28,38 @@ internal static IObservable> CreateSubscriptionStream Id = startRequest.Id, Type = GraphQLWebSocketMessageType.GQL_STOP }; + var initRequest = new GraphQLWebSocketRequest { + Id = startRequest.Id, + Type = GraphQLWebSocketMessageType.GQL_CONNECTION_INIT, + }; var observable = Observable.Create>(o => graphQlHttpWebSocket.ResponseStream // ignore null values and messages for other requests .Where(response => response != null && response.Id == startRequest.Id) .Subscribe(response => { - // terminate the sequence when a 'complete' message is received - if (response.Type == GraphQLWebSocketMessageType.GQL_COMPLETE) { - Debug.WriteLine($"received 'complete' message on subscription {startRequest.Id}"); - o.OnCompleted(); - return; - } - - // post the GraphQLResponse to the stream (even if a GraphQL error occurred) - Debug.WriteLine($"received payload on subscription {startRequest.Id}"); - var typedResponse = - JsonSerializer.Deserialize>(response.MessageBytes, - options.JsonSerializerOptions); - o.OnNext(typedResponse.Payload); - - // in case of a GraphQL error, terminate the sequence after the response has been posted - if (response.Type == GraphQLWebSocketMessageType.GQL_ERROR) { - Debug.WriteLine($"terminating subscription {startRequest.Id} because of a GraphQL error"); - o.OnCompleted(); - } - }, - o.OnError, - o.OnCompleted) + // terminate the sequence when a 'complete' message is received + if (response.Type == GraphQLWebSocketMessageType.GQL_COMPLETE) { + Debug.WriteLine($"received 'complete' message on subscription {startRequest.Id}"); + o.OnCompleted(); + return; + } + + // post the GraphQLResponse to the stream (even if a GraphQL error occurred) + Debug.WriteLine($"received payload on subscription {startRequest.Id}"); + var typedResponse = + JsonSerializer.Deserialize>(response.MessageBytes, + client.Options.JsonSerializerOptions); + o.OnNext(typedResponse.Payload); + + // in case of a GraphQL error, terminate the sequence after the response has been posted + if (response.Type == GraphQLWebSocketMessageType.GQL_ERROR) { + Debug.WriteLine($"terminating subscription {startRequest.Id} because of a GraphQL error"); + o.OnCompleted(); + } + }, + o.OnError, + o.OnCompleted) ); try { @@ -81,6 +86,16 @@ internal static IObservable> CreateSubscriptionStream }) ); + // send connection init + Debug.WriteLine($"sending connection init on subscription {startRequest.Id}"); + try { + await graphQlHttpWebSocket.SendWebSocketRequest(initRequest).ConfigureAwait(false); + } + catch (Exception e) { + Console.WriteLine(e); + throw; + } + Debug.WriteLine($"sending initial message on subscription {startRequest.Id}"); // send subscription request try { @@ -137,53 +152,54 @@ internal static IObservable> CreateSubscriptionStream .Publish().RefCount(); } - internal static Task> Request( + internal static Task> SendRequest( this GraphQLHttpWebSocket graphQlHttpWebSocket, GraphQLRequest request, - GraphQLHttpClientOptions options, + GraphQLHttpClient client, CancellationToken cancellationToken = default) { return Observable.Create>(async observer => { - var websocketRequest = new GraphQLWebSocketRequest { - Id = Guid.NewGuid().ToString("N"), - Type = GraphQLWebSocketMessageType.GQL_START, - Payload = request - }; - var observable = graphQlHttpWebSocket.ResponseStream - .Where(response => response != null && response.Id == websocketRequest.Id) - .TakeUntil(response => response.Type == GraphQLWebSocketMessageType.GQL_COMPLETE) - .Select(response => { - Debug.WriteLine($"received response for request {websocketRequest.Id}"); - var typedResponse = - JsonSerializer.Deserialize>(response.MessageBytes, - options.JsonSerializerOptions); - return typedResponse.Payload; - }); - - try { - // intialize websocket (completes immediately if socket is already open) - await graphQlHttpWebSocket.InitializeWebSocket().ConfigureAwait(false); - } - catch (Exception e) { - // subscribe observer to failed observable - return Observable.Throw>(e).Subscribe(observer); - } - - var disposable = new CompositeDisposable( - observable.Subscribe(observer) - ); - - Debug.WriteLine($"submitting request {websocketRequest.Id}"); - // send request - try { - await graphQlHttpWebSocket.SendWebSocketRequest(websocketRequest).ConfigureAwait(false); - } - catch (Exception e) { - Console.WriteLine(e); - throw; - } - - return disposable; - }) + await client.Options.PreprocessRequest(request, client); + var websocketRequest = new GraphQLWebSocketRequest { + Id = Guid.NewGuid().ToString("N"), + Type = GraphQLWebSocketMessageType.GQL_START, + Payload = request + }; + var observable = graphQlHttpWebSocket.ResponseStream + .Where(response => response != null && response.Id == websocketRequest.Id) + .TakeUntil(response => response.Type == GraphQLWebSocketMessageType.GQL_COMPLETE) + .Select(response => { + Debug.WriteLine($"received response for request {websocketRequest.Id}"); + var typedResponse = + JsonSerializer.Deserialize>(response.MessageBytes, + client.Options.JsonSerializerOptions); + return typedResponse.Payload; + }); + + try { + // intialize websocket (completes immediately if socket is already open) + await graphQlHttpWebSocket.InitializeWebSocket().ConfigureAwait(false); + } + catch (Exception e) { + // subscribe observer to failed observable + return Observable.Throw>(e).Subscribe(observer); + } + + var disposable = new CompositeDisposable( + observable.Subscribe(observer) + ); + + Debug.WriteLine($"submitting request {websocketRequest.Id}"); + // send request + try { + await graphQlHttpWebSocket.SendWebSocketRequest(websocketRequest).ConfigureAwait(false); + } + catch (Exception e) { + Console.WriteLine(e); + throw; + } + + return disposable; + }) // complete sequence on OperationCanceledException, this is triggered by the cancellation token .Catch, OperationCanceledException>(exception => Observable.Empty>()) diff --git a/tests/GraphQL.Integration.Tests/SubscriptionsTest.cs b/tests/GraphQL.Integration.Tests/WebsocketTest.cs similarity index 98% rename from tests/GraphQL.Integration.Tests/SubscriptionsTest.cs rename to tests/GraphQL.Integration.Tests/WebsocketTest.cs index a70a000c..3e5eafb1 100644 --- a/tests/GraphQL.Integration.Tests/SubscriptionsTest.cs +++ b/tests/GraphQL.Integration.Tests/WebsocketTest.cs @@ -12,14 +12,12 @@ using Xunit.Abstractions; namespace GraphQL.Integration.Tests { - public class SubscriptionsTest { + public class WebsocketTest { private readonly ITestOutputHelper output; private static IWebHost CreateServer(int port) => WebHostHelpers.CreateServer(port); - private static TimeSpan WaitForConnectionDelay = TimeSpan.FromMilliseconds(200); - - public SubscriptionsTest(ITestOutputHelper output) { + public WebsocketTest(ITestOutputHelper output) { this.output = output; }