Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 20 additions & 15 deletions src/GraphQL.Client.Http/GraphQLHttpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@ public class GraphQLHttpClient : IGraphQLClient {

private readonly GraphQLHttpWebSocket graphQlHttpWebSocket;
private readonly CancellationTokenSource cancellationTokenSource = new CancellationTokenSource();
private readonly HttpClient httpClient;
private readonly ConcurrentDictionary<Tuple<GraphQLRequest, Type>, object> subscriptionStreams = new ConcurrentDictionary<Tuple<GraphQLRequest, Type>, object>();

/// <summary>
/// the instance of <see cref="HttpClient"/> which is used internally
/// </summary>
public HttpClient HttpClient { get; }

/// <summary>
/// The Options to be used
/// </summary>
Expand All @@ -30,33 +34,33 @@ public GraphQLHttpClient(Uri endPoint) : this(o => o.EndPoint = endPoint) { }
public GraphQLHttpClient(Action<GraphQLHttpClientOptions> configure) {
Options = new GraphQLHttpClientOptions();
configure(Options);
this.httpClient = new HttpClient();
this.HttpClient = new HttpClient(Options.HttpMessageHandler);
this.graphQlHttpWebSocket = new GraphQLHttpWebSocket(GetWebSocketUri(), Options);
}

public GraphQLHttpClient(GraphQLHttpClientOptions options) {
Options = options;
this.httpClient = new HttpClient();
this.HttpClient = new HttpClient(Options.HttpMessageHandler);
this.graphQlHttpWebSocket = new GraphQLHttpWebSocket(GetWebSocketUri(), Options);
}

public GraphQLHttpClient(GraphQLHttpClientOptions options, HttpClient httpClient) {
Options = options;
this.httpClient = httpClient;
this.HttpClient = httpClient;
this.graphQlHttpWebSocket = new GraphQLHttpWebSocket(GetWebSocketUri(), Options);
}

/// <inheritdoc />
public Task<GraphQLResponse<TResponse>> SendQueryAsync<TResponse>(GraphQLRequest request, CancellationToken cancellationToken = default) {
return Options.UseWebSocketForQueriesAndMutations
? this.graphQlHttpWebSocket.Request<TResponse>(request, Options, cancellationToken)
? this.graphQlHttpWebSocket.SendRequest<TResponse>(request, this, cancellationToken)
: this.SendHttpPostRequestAsync<TResponse>(request, cancellationToken);
}

public Task<GraphQLResponse<TResponse>> SendMutationAsync<TResponse>(GraphQLRequest request, CancellationToken cancellationToken = default) {
return Options.UseWebSocketForQueriesAndMutations
? this.graphQlHttpWebSocket.Request<TResponse>(request, Options, cancellationToken)
: this.SendHttpPostRequestAsync<TResponse>(request, cancellationToken);
}
/// <inheritdoc />
public Task<GraphQLResponse<TResponse>> SendMutationAsync<TResponse>(GraphQLRequest request,
CancellationToken cancellationToken = default)
=> SendQueryAsync<TResponse>(request, cancellationToken);

/// <inheritdoc />
public IObservable<GraphQLResponse<TResponse>> CreateSubscriptionStream<TResponse>(GraphQLRequest request) {
Expand All @@ -68,7 +72,7 @@ public IObservable<GraphQLResponse<TResponse>> CreateSubscriptionStream<TRespons
if (subscriptionStreams.ContainsKey(key))
return (IObservable<GraphQLResponse<TResponse>>)subscriptionStreams[key];

var observable = graphQlHttpWebSocket.CreateSubscriptionStream<TResponse>(request, Options, cancellationToken: cancellationTokenSource.Token);
var observable = graphQlHttpWebSocket.CreateSubscriptionStream<TResponse>(request, this, cancellationToken: cancellationTokenSource.Token);

subscriptionStreams.TryAdd(key, observable);
return observable;
Expand All @@ -84,7 +88,7 @@ public IObservable<GraphQLResponse<TResponse>> CreateSubscriptionStream<TRespons
if (subscriptionStreams.ContainsKey(key))
return (IObservable<GraphQLResponse<TResponse>>)subscriptionStreams[key];

var observable = graphQlHttpWebSocket.CreateSubscriptionStream<TResponse>(request, Options, exceptionHandler, cancellationTokenSource.Token);
var observable = graphQlHttpWebSocket.CreateSubscriptionStream<TResponse>(request, this, exceptionHandler, cancellationTokenSource.Token);
subscriptionStreams.TryAdd(key, observable);
return observable;
}
Expand All @@ -98,8 +102,9 @@ public IObservable<GraphQLResponse<TResponse>> CreateSubscriptionStream<TRespons
#region Private Methods

private async Task<GraphQLResponse<TResponse>> SendHttpPostRequestAsync<TResponse>(GraphQLRequest request, CancellationToken cancellationToken = default) {
using var httpRequestMessage = this.GenerateHttpRequestMessage(request.SerializeToJson(Options));
using var httpResponseMessage = await this.httpClient.SendAsync(httpRequestMessage, cancellationToken);
var preprocessedRequest = await Options.PreprocessRequest(request, this);
using var httpRequestMessage = this.GenerateHttpRequestMessage(preprocessedRequest.SerializeToJson(Options));
using var httpResponseMessage = await this.HttpClient.SendAsync(httpRequestMessage, cancellationToken);
if (!httpResponseMessage.IsSuccessStatusCode) {
throw new GraphQLHttpException(httpResponseMessage);
}
Expand Down Expand Up @@ -140,7 +145,7 @@ public void Dispose() {

private void _dispose() {
disposed = true;
this.httpClient.Dispose();
this.HttpClient.Dispose();
this.graphQlHttpWebSocket.Dispose();
cancellationTokenSource.Cancel();
cancellationTokenSource.Dispose();
Expand Down
7 changes: 6 additions & 1 deletion src/GraphQL.Client.Http/GraphQLHttpClientOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Net.Http;
using System.Net.Http.Headers;
using System.Text.Json;
using System.Threading.Tasks;
using Dahomey.Json;

namespace GraphQL.Client.Http {
Expand Down Expand Up @@ -46,6 +47,10 @@ public class GraphQLHttpClientOptions {
/// If <see langword="true"/>, the websocket connection is also used for regular queries and mutations
/// </summary>
public bool UseWebSocketForQueriesAndMutations { get; set; } = false;
}

/// <summary>
/// Request preprocessing function. Can be used i.e. to inject authorization info into a GraphQL request payload.
/// </summary>
public Func<GraphQLRequest, GraphQLHttpClient, Task<GraphQLRequest>> PreprocessRequest { get; set; } = (request, client) => Task.FromResult(request);
}
}
5 changes: 5 additions & 0 deletions src/GraphQL.Client.Http/Websocket/GraphQLHttpWebSocket.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Diagnostics;
using System.IO;
using System.Net.Http;
using System.Net.WebSockets;
using System.Reactive.Disposables;
using System.Reactive.Linq;
Expand Down Expand Up @@ -108,9 +109,13 @@ public Task InitializeWebSocket() {
switch (clientWebSocket) {
case ClientWebSocket nativeWebSocket:
nativeWebSocket.Options.AddSubProtocol("graphql-ws");
nativeWebSocket.Options.ClientCertificates = ((HttpClientHandler)_options.HttpMessageHandler).ClientCertificates;
nativeWebSocket.Options.UseDefaultCredentials = ((HttpClientHandler)_options.HttpMessageHandler).UseDefaultCredentials;
break;
case System.Net.WebSockets.Managed.ClientWebSocket managedWebSocket:
managedWebSocket.Options.AddSubProtocol("graphql-ws");
managedWebSocket.Options.ClientCertificates = ((HttpClientHandler)_options.HttpMessageHandler).ClientCertificates;
managedWebSocket.Options.UseDefaultCredentials = ((HttpClientHandler)_options.HttpMessageHandler).UseDefaultCredentials;
break;
default:
throw new NotSupportedException($"unknown websocket type {clientWebSocket.GetType().Name}");
Expand Down
148 changes: 82 additions & 66 deletions src/GraphQL.Client.Http/Websocket/GraphQLHttpWebsocketHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ public static class GraphQLHttpWebsocketHelpers {
internal static IObservable<GraphQLResponse<TResponse>> CreateSubscriptionStream<TResponse>(
this GraphQLHttpWebSocket graphQlHttpWebSocket,
GraphQLRequest request,
GraphQLHttpClientOptions options,
GraphQLHttpClient client,
Action<Exception> exceptionHandler = null,
CancellationToken cancellationToken = default) {
return Observable.Defer(() =>
Observable.Create<GraphQLResponse<TResponse>>(async observer => {
await client.Options.PreprocessRequest(request, client);
var startRequest = new GraphQLWebSocketRequest {
Id = Guid.NewGuid().ToString("N"),
Type = GraphQLWebSocketMessageType.GQL_START,
Expand All @@ -27,34 +28,38 @@ internal static IObservable<GraphQLResponse<TResponse>> CreateSubscriptionStream
Id = startRequest.Id,
Type = GraphQLWebSocketMessageType.GQL_STOP
};
var initRequest = new GraphQLWebSocketRequest {
Id = startRequest.Id,
Type = GraphQLWebSocketMessageType.GQL_CONNECTION_INIT,
};

var observable = Observable.Create<GraphQLResponse<TResponse>>(o =>
graphQlHttpWebSocket.ResponseStream
// ignore null values and messages for other requests
.Where(response => response != null && response.Id == startRequest.Id)
.Subscribe(response => {
// terminate the sequence when a 'complete' message is received
if (response.Type == GraphQLWebSocketMessageType.GQL_COMPLETE) {
Debug.WriteLine($"received 'complete' message on subscription {startRequest.Id}");
o.OnCompleted();
return;
}

// post the GraphQLResponse to the stream (even if a GraphQL error occurred)
Debug.WriteLine($"received payload on subscription {startRequest.Id}");
var typedResponse =
JsonSerializer.Deserialize<GraphQLWebSocketResponse<TResponse>>(response.MessageBytes,
options.JsonSerializerOptions);
o.OnNext(typedResponse.Payload);

// in case of a GraphQL error, terminate the sequence after the response has been posted
if (response.Type == GraphQLWebSocketMessageType.GQL_ERROR) {
Debug.WriteLine($"terminating subscription {startRequest.Id} because of a GraphQL error");
o.OnCompleted();
}
},
o.OnError,
o.OnCompleted)
// terminate the sequence when a 'complete' message is received
if (response.Type == GraphQLWebSocketMessageType.GQL_COMPLETE) {
Debug.WriteLine($"received 'complete' message on subscription {startRequest.Id}");
o.OnCompleted();
return;
}

// post the GraphQLResponse to the stream (even if a GraphQL error occurred)
Debug.WriteLine($"received payload on subscription {startRequest.Id}");
var typedResponse =
JsonSerializer.Deserialize<GraphQLWebSocketResponse<TResponse>>(response.MessageBytes,
client.Options.JsonSerializerOptions);
o.OnNext(typedResponse.Payload);

// in case of a GraphQL error, terminate the sequence after the response has been posted
if (response.Type == GraphQLWebSocketMessageType.GQL_ERROR) {
Debug.WriteLine($"terminating subscription {startRequest.Id} because of a GraphQL error");
o.OnCompleted();
}
},
o.OnError,
o.OnCompleted)
);

try {
Expand All @@ -81,6 +86,16 @@ internal static IObservable<GraphQLResponse<TResponse>> CreateSubscriptionStream
})
);

// send connection init
Debug.WriteLine($"sending connection init on subscription {startRequest.Id}");
try {
await graphQlHttpWebSocket.SendWebSocketRequest(initRequest).ConfigureAwait(false);
}
catch (Exception e) {
Console.WriteLine(e);
throw;
}

Debug.WriteLine($"sending initial message on subscription {startRequest.Id}");
// send subscription request
try {
Expand Down Expand Up @@ -137,53 +152,54 @@ internal static IObservable<GraphQLResponse<TResponse>> CreateSubscriptionStream
.Publish().RefCount();
}

internal static Task<GraphQLResponse<TResponse>> Request<TResponse>(
internal static Task<GraphQLResponse<TResponse>> SendRequest<TResponse>(
this GraphQLHttpWebSocket graphQlHttpWebSocket,
GraphQLRequest request,
GraphQLHttpClientOptions options,
GraphQLHttpClient client,
CancellationToken cancellationToken = default) {
return Observable.Create<GraphQLResponse<TResponse>>(async observer => {
var websocketRequest = new GraphQLWebSocketRequest {
Id = Guid.NewGuid().ToString("N"),
Type = GraphQLWebSocketMessageType.GQL_START,
Payload = request
};
var observable = graphQlHttpWebSocket.ResponseStream
.Where(response => response != null && response.Id == websocketRequest.Id)
.TakeUntil(response => response.Type == GraphQLWebSocketMessageType.GQL_COMPLETE)
.Select(response => {
Debug.WriteLine($"received response for request {websocketRequest.Id}");
var typedResponse =
JsonSerializer.Deserialize<GraphQLWebSocketResponse<TResponse>>(response.MessageBytes,
options.JsonSerializerOptions);
return typedResponse.Payload;
});

try {
// intialize websocket (completes immediately if socket is already open)
await graphQlHttpWebSocket.InitializeWebSocket().ConfigureAwait(false);
}
catch (Exception e) {
// subscribe observer to failed observable
return Observable.Throw<GraphQLResponse<TResponse>>(e).Subscribe(observer);
}

var disposable = new CompositeDisposable(
observable.Subscribe(observer)
);

Debug.WriteLine($"submitting request {websocketRequest.Id}");
// send request
try {
await graphQlHttpWebSocket.SendWebSocketRequest(websocketRequest).ConfigureAwait(false);
}
catch (Exception e) {
Console.WriteLine(e);
throw;
}

return disposable;
})
await client.Options.PreprocessRequest(request, client);
var websocketRequest = new GraphQLWebSocketRequest {
Id = Guid.NewGuid().ToString("N"),
Type = GraphQLWebSocketMessageType.GQL_START,
Payload = request
};
var observable = graphQlHttpWebSocket.ResponseStream
.Where(response => response != null && response.Id == websocketRequest.Id)
.TakeUntil(response => response.Type == GraphQLWebSocketMessageType.GQL_COMPLETE)
.Select(response => {
Debug.WriteLine($"received response for request {websocketRequest.Id}");
var typedResponse =
JsonSerializer.Deserialize<GraphQLWebSocketResponse<TResponse>>(response.MessageBytes,
client.Options.JsonSerializerOptions);
return typedResponse.Payload;
});

try {
// intialize websocket (completes immediately if socket is already open)
await graphQlHttpWebSocket.InitializeWebSocket().ConfigureAwait(false);
}
catch (Exception e) {
// subscribe observer to failed observable
return Observable.Throw<GraphQLResponse<TResponse>>(e).Subscribe(observer);
}

var disposable = new CompositeDisposable(
observable.Subscribe(observer)
);

Debug.WriteLine($"submitting request {websocketRequest.Id}");
// send request
try {
await graphQlHttpWebSocket.SendWebSocketRequest(websocketRequest).ConfigureAwait(false);
}
catch (Exception e) {
Console.WriteLine(e);
throw;
}

return disposable;
})
// complete sequence on OperationCanceledException, this is triggered by the cancellation token
.Catch<GraphQLResponse<TResponse>, OperationCanceledException>(exception =>
Observable.Empty<GraphQLResponse<TResponse>>())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,12 @@
using Xunit.Abstractions;

namespace GraphQL.Integration.Tests {
public class SubscriptionsTest {
public class WebsocketTest {
private readonly ITestOutputHelper output;

private static IWebHost CreateServer(int port) => WebHostHelpers.CreateServer<StartupChat>(port);

private static TimeSpan WaitForConnectionDelay = TimeSpan.FromMilliseconds(200);

public SubscriptionsTest(ITestOutputHelper output) {
public WebsocketTest(ITestOutputHelper output) {
this.output = output;
}

Expand Down