diff --git a/root.props b/root.props index 61db8f41..df16b57f 100644 --- a/root.props +++ b/root.props @@ -17,7 +17,7 @@ git https://github.com/graphql-dotnet/graphql-client.git True - 2.0.0-alpha.4.subscription-api.8 + 2.0.0-alpha.4.subscription-api.10 4 diff --git a/src/GraphQL.Client/Http/GraphQLHttpSubscriptionHelpers.cs b/src/GraphQL.Client/Http/GraphQLHttpSubscriptionHelpers.cs index ef71a0ce..4ba91205 100644 --- a/src/GraphQL.Client/Http/GraphQLHttpSubscriptionHelpers.cs +++ b/src/GraphQL.Client/Http/GraphQLHttpSubscriptionHelpers.cs @@ -3,6 +3,7 @@ using System.Net.WebSockets; using System.Reactive.Disposables; using System.Reactive.Linq; +using System.Reactive.Threading.Tasks; using System.Threading; using System.Threading.Tasks; using GraphQL.Common; @@ -35,31 +36,39 @@ internal static IObservable CreateSubscriptionStream( Id = startRequest.Id, Type = GQLWebSocketMessageType.GQL_STOP }; - var observable = graphQlHttpWebSocket.ResponseStream - .Where(response => { - return response != null && response.Id == startRequest.Id; - }) - .SelectMany(response => - { - switch (response.Type) + + var observable = Observable.Create(o => + graphQlHttpWebSocket.ResponseStream.Subscribe(response => { - case GQLWebSocketMessageType.GQL_COMPLETE: + // ignore null values and messages for other requests + if (response == null || response.Id != startRequest.Id) return; + + // terminate the sequence when a 'complete' message is received + if (response.Type == GQLWebSocketMessageType.GQL_COMPLETE) + { Debug.WriteLine($"received 'complete' message on subscription {startRequest.Id}"); - return Observable.Empty(); - case GQLWebSocketMessageType.GQL_ERROR: - Debug.WriteLine($"received 'error' message on subscription {startRequest.Id}"); - return Observable.Throw( - new GraphQLSubscriptionException(response.Payload)); - default: - Debug.WriteLine($"received payload on subscription {startRequest.Id}"); - return Observable.Return(((JObject) response?.Payload) - ?.ToObject()); - } - }); + o.OnCompleted(); + return; + } + + // post the GraphQLResponse to the stream (even if a GraphQL error occurred) + Debug.WriteLine($"received payload on subscription {startRequest.Id}"); + o.OnNext(((JObject)response.Payload)?.ToObject()); + + // in case of a GraphQL error, terminate the sequence after the response has been posted + if (response.Type == GQLWebSocketMessageType.GQL_ERROR) + { + Debug.WriteLine($"terminating subscription {startRequest.Id} because of a GraphQL error"); + o.OnCompleted(); + } + }, + o.OnError, + o.OnCompleted) + ); try { - // intialize websocket (completes immediately if socket is already open) + // initialize websocket (completes immediately if socket is already open) await graphQlHttpWebSocket.InitializeWebSocket().ConfigureAwait(false); } catch (Exception e) @@ -148,12 +157,12 @@ internal static IObservable CreateSubscriptionStream( .Publish().RefCount(); } - internal static async Task Request( + internal static Task Request( this GraphQLHttpWebSocket graphQlHttpWebSocket, GraphQLRequest request, CancellationToken cancellationToken = default) { - return await Observable.Create(async observer => + return Observable.Create(async observer => { var websocketRequest = new GraphQLWebSocketRequest { @@ -162,25 +171,12 @@ internal static async Task Request( Payload = request }; var observable = graphQlHttpWebSocket.ResponseStream - .Where(response => { - return response != null && response.Id == websocketRequest.Id; - }) - .SelectMany(response => + .Where(response => response != null && response.Id == websocketRequest.Id) + .TakeUntil(response => response.Type == GQLWebSocketMessageType.GQL_COMPLETE) + .Select(response => { - switch (response.Type) - { - case GQLWebSocketMessageType.GQL_COMPLETE: - Debug.WriteLine($"received 'complete' message on request {websocketRequest.Id}"); - return Observable.Empty(); - case GQLWebSocketMessageType.GQL_ERROR: - Debug.WriteLine($"received 'error' message on request {websocketRequest.Id}"); - return Observable.Throw( - new GraphQLSubscriptionException(response.Payload)); - default: - Debug.WriteLine($"received response for request {websocketRequest.Id}"); - return Observable.Return(((JObject)response?.Payload) - ?.ToObject()); - } + Debug.WriteLine($"received response for request {websocketRequest.Id}"); + return ((JObject) response?.Payload)?.ToObject(); }); try @@ -215,7 +211,8 @@ internal static async Task Request( // complete sequence on OperationCanceledException, this is triggered by the cancellation token .Catch(exception => Observable.Empty()) - .FirstOrDefaultAsync(); + .FirstOrDefaultAsync() + .ToTask(cancellationToken); } } } diff --git a/src/GraphQL.Client/Http/GraphQLHttpWebSocket.cs b/src/GraphQL.Client/Http/GraphQLHttpWebSocket.cs index 28ae1e7b..eb2881ba 100644 --- a/src/GraphQL.Client/Http/GraphQLHttpWebSocket.cs +++ b/src/GraphQL.Client/Http/GraphQLHttpWebSocket.cs @@ -65,7 +65,7 @@ private async Task _sendWebSocketRequest(GraphQLWebSocketRequest request) } await InitializeWebSocket().ConfigureAwait(false); - var webSocketRequestString = JsonConvert.SerializeObject(request); + var webSocketRequestString = JsonConvert.SerializeObject(request, _options.JsonSerializerSettings); await this.clientWebSocket.SendAsync( new ArraySegment(Encoding.UTF8.GetBytes(webSocketRequestString)), WebSocketMessageType.Text, @@ -245,7 +245,7 @@ private async Task _receiveResultAsync() { var stringResult = await reader.ReadToEndAsync(); Debug.WriteLine($"data received on websocket {clientWebSocket.GetHashCode()}: {stringResult}"); - return JsonConvert.DeserializeObject(stringResult); + return JsonConvert.DeserializeObject(stringResult, _options.JsonSerializerSettings); } } else diff --git a/tests/GraphQL.Integration.Tests/SubscriptionsTest.cs b/tests/GraphQL.Integration.Tests/SubscriptionsTest.cs index 635a0d1c..fb704ad4 100644 --- a/tests/GraphQL.Integration.Tests/SubscriptionsTest.cs +++ b/tests/GraphQL.Integration.Tests/SubscriptionsTest.cs @@ -43,10 +43,11 @@ public SubscriptionsTest() { } - private GraphQLHttpClient GetGraphQLClient(int port) + private GraphQLHttpClient GetGraphQLClient(int port, bool requestsViaWebsocket = false) => new GraphQLHttpClient(new GraphQLHttpClientOptions { EndPoint = new Uri($"http://localhost:{port}/graphql"), + UseWebSocketForQueriesAndMutations = requestsViaWebsocket }); @@ -65,6 +66,34 @@ public async void AssertTestingHarness() } } + [Fact] + public async void CanSendRequestViaWebsocket() + { + var port = NetworkHelpers.GetFreeTcpPortNumber(); + using (CreateServer(port)) + { + var client = GetGraphQLClient(port, true); + const string message = "some random testing message"; + var response = await client.AddMessageAsync(message).ConfigureAwait(false); + + Assert.Equal(message, (string)response.Data.addMessage.content); + } + } + + [Fact] + public async void CanHandleRequestErrorViaWebsocket() + { + var port = NetworkHelpers.GetFreeTcpPortNumber(); + using (CreateServer(port)) + { + var client = GetGraphQLClient(port, true); + const string message = "some random testing message"; + var response = await client.SendQueryAsync(new GraphQLRequest("this query is formatted quite badly")).ConfigureAwait(false); + + Assert.Single(response.Errors); + } + } + private const string SubscriptionQuery = @" subscription { messageAdded{ @@ -272,5 +301,63 @@ public async void CanHandleConnectionTimeout() tester.ShouldHaveCompleted(TimeSpan.FromSeconds(5)); server.Dispose(); } + + [Fact] + public async void CanHandleSubscriptionError() + { + var port = NetworkHelpers.GetFreeTcpPortNumber(); + using (CreateServer(port)) + { + var client = GetGraphQLClient(port); + Debug.WriteLine("creating subscription stream"); + IObservable observable = client.CreateSubscriptionStream( + new GraphQLRequest(@" + subscription { + failImmediately { + content + } + }") + ); + + Debug.WriteLine("subscribing..."); + var tester = observable.SubscribeTester(); + tester.ShouldHaveReceivedUpdate(gqlResponse => + { + Assert.Single(gqlResponse.Errors); + }); + tester.ShouldHaveCompleted(); + + client.Dispose(); + } + } + + [Fact] + public async void CanHandleQueryErrorInSubscription() + { + var port = NetworkHelpers.GetFreeTcpPortNumber(); + using (CreateServer(port)) + { + var client = GetGraphQLClient(port); + Debug.WriteLine("creating subscription stream"); + IObservable observable = client.CreateSubscriptionStream( + new GraphQLRequest(@" + subscription { + fieldDoesNotExist { + content + } + }") + ); + + Debug.WriteLine("subscribing..."); + var tester = observable.SubscribeTester(); + tester.ShouldHaveReceivedUpdate(gqlResponse => + { + Assert.Single(gqlResponse.Errors); + }); + tester.ShouldHaveCompleted(); + + client.Dispose(); + } + } } } diff --git a/tests/IntegrationTestServer/ChatSchema/ChatSubscriptions.cs b/tests/IntegrationTestServer/ChatSchema/ChatSubscriptions.cs index 0ceb77a5..178d2b78 100644 --- a/tests/IntegrationTestServer/ChatSchema/ChatSubscriptions.cs +++ b/tests/IntegrationTestServer/ChatSchema/ChatSubscriptions.cs @@ -51,6 +51,15 @@ public ChatSubscriptions(IChat chat) Resolver = new FuncFieldResolver(context => context.Source as MessageFrom), Subscriber = new EventStreamResolver(context => _chat.UserJoined()) }); + + + AddField(new EventStreamFieldType + { + Name = "failImmediately", + Type = typeof(MessageType), + Resolver = new FuncFieldResolver(ResolveMessage), + Subscriber = new EventStreamResolver(context => throw new NotSupportedException("this is supposed to fail")) + }); } private IObservable SubscribeById(ResolveEventStreamContext context)