diff --git a/root.props b/root.props
index 61db8f41..df16b57f 100644
--- a/root.props
+++ b/root.props
@@ -17,7 +17,7 @@
git
https://github.com/graphql-dotnet/graphql-client.git
True
- 2.0.0-alpha.4.subscription-api.8
+ 2.0.0-alpha.4.subscription-api.10
4
diff --git a/src/GraphQL.Client/Http/GraphQLHttpSubscriptionHelpers.cs b/src/GraphQL.Client/Http/GraphQLHttpSubscriptionHelpers.cs
index ef71a0ce..4ba91205 100644
--- a/src/GraphQL.Client/Http/GraphQLHttpSubscriptionHelpers.cs
+++ b/src/GraphQL.Client/Http/GraphQLHttpSubscriptionHelpers.cs
@@ -3,6 +3,7 @@
using System.Net.WebSockets;
using System.Reactive.Disposables;
using System.Reactive.Linq;
+using System.Reactive.Threading.Tasks;
using System.Threading;
using System.Threading.Tasks;
using GraphQL.Common;
@@ -35,31 +36,39 @@ internal static IObservable CreateSubscriptionStream(
Id = startRequest.Id,
Type = GQLWebSocketMessageType.GQL_STOP
};
- var observable = graphQlHttpWebSocket.ResponseStream
- .Where(response => {
- return response != null && response.Id == startRequest.Id;
- })
- .SelectMany(response =>
- {
- switch (response.Type)
+
+ var observable = Observable.Create(o =>
+ graphQlHttpWebSocket.ResponseStream.Subscribe(response =>
{
- case GQLWebSocketMessageType.GQL_COMPLETE:
+ // ignore null values and messages for other requests
+ if (response == null || response.Id != startRequest.Id) return;
+
+ // terminate the sequence when a 'complete' message is received
+ if (response.Type == GQLWebSocketMessageType.GQL_COMPLETE)
+ {
Debug.WriteLine($"received 'complete' message on subscription {startRequest.Id}");
- return Observable.Empty();
- case GQLWebSocketMessageType.GQL_ERROR:
- Debug.WriteLine($"received 'error' message on subscription {startRequest.Id}");
- return Observable.Throw(
- new GraphQLSubscriptionException(response.Payload));
- default:
- Debug.WriteLine($"received payload on subscription {startRequest.Id}");
- return Observable.Return(((JObject) response?.Payload)
- ?.ToObject());
- }
- });
+ o.OnCompleted();
+ return;
+ }
+
+ // post the GraphQLResponse to the stream (even if a GraphQL error occurred)
+ Debug.WriteLine($"received payload on subscription {startRequest.Id}");
+ o.OnNext(((JObject)response.Payload)?.ToObject());
+
+ // in case of a GraphQL error, terminate the sequence after the response has been posted
+ if (response.Type == GQLWebSocketMessageType.GQL_ERROR)
+ {
+ Debug.WriteLine($"terminating subscription {startRequest.Id} because of a GraphQL error");
+ o.OnCompleted();
+ }
+ },
+ o.OnError,
+ o.OnCompleted)
+ );
try
{
- // intialize websocket (completes immediately if socket is already open)
+ // initialize websocket (completes immediately if socket is already open)
await graphQlHttpWebSocket.InitializeWebSocket().ConfigureAwait(false);
}
catch (Exception e)
@@ -148,12 +157,12 @@ internal static IObservable CreateSubscriptionStream(
.Publish().RefCount();
}
- internal static async Task Request(
+ internal static Task Request(
this GraphQLHttpWebSocket graphQlHttpWebSocket,
GraphQLRequest request,
CancellationToken cancellationToken = default)
{
- return await Observable.Create(async observer =>
+ return Observable.Create(async observer =>
{
var websocketRequest = new GraphQLWebSocketRequest
{
@@ -162,25 +171,12 @@ internal static async Task Request(
Payload = request
};
var observable = graphQlHttpWebSocket.ResponseStream
- .Where(response => {
- return response != null && response.Id == websocketRequest.Id;
- })
- .SelectMany(response =>
+ .Where(response => response != null && response.Id == websocketRequest.Id)
+ .TakeUntil(response => response.Type == GQLWebSocketMessageType.GQL_COMPLETE)
+ .Select(response =>
{
- switch (response.Type)
- {
- case GQLWebSocketMessageType.GQL_COMPLETE:
- Debug.WriteLine($"received 'complete' message on request {websocketRequest.Id}");
- return Observable.Empty();
- case GQLWebSocketMessageType.GQL_ERROR:
- Debug.WriteLine($"received 'error' message on request {websocketRequest.Id}");
- return Observable.Throw(
- new GraphQLSubscriptionException(response.Payload));
- default:
- Debug.WriteLine($"received response for request {websocketRequest.Id}");
- return Observable.Return(((JObject)response?.Payload)
- ?.ToObject());
- }
+ Debug.WriteLine($"received response for request {websocketRequest.Id}");
+ return ((JObject) response?.Payload)?.ToObject();
});
try
@@ -215,7 +211,8 @@ internal static async Task Request(
// complete sequence on OperationCanceledException, this is triggered by the cancellation token
.Catch(exception =>
Observable.Empty())
- .FirstOrDefaultAsync();
+ .FirstOrDefaultAsync()
+ .ToTask(cancellationToken);
}
}
}
diff --git a/src/GraphQL.Client/Http/GraphQLHttpWebSocket.cs b/src/GraphQL.Client/Http/GraphQLHttpWebSocket.cs
index 28ae1e7b..eb2881ba 100644
--- a/src/GraphQL.Client/Http/GraphQLHttpWebSocket.cs
+++ b/src/GraphQL.Client/Http/GraphQLHttpWebSocket.cs
@@ -65,7 +65,7 @@ private async Task _sendWebSocketRequest(GraphQLWebSocketRequest request)
}
await InitializeWebSocket().ConfigureAwait(false);
- var webSocketRequestString = JsonConvert.SerializeObject(request);
+ var webSocketRequestString = JsonConvert.SerializeObject(request, _options.JsonSerializerSettings);
await this.clientWebSocket.SendAsync(
new ArraySegment(Encoding.UTF8.GetBytes(webSocketRequestString)),
WebSocketMessageType.Text,
@@ -245,7 +245,7 @@ private async Task _receiveResultAsync()
{
var stringResult = await reader.ReadToEndAsync();
Debug.WriteLine($"data received on websocket {clientWebSocket.GetHashCode()}: {stringResult}");
- return JsonConvert.DeserializeObject(stringResult);
+ return JsonConvert.DeserializeObject(stringResult, _options.JsonSerializerSettings);
}
}
else
diff --git a/tests/GraphQL.Integration.Tests/SubscriptionsTest.cs b/tests/GraphQL.Integration.Tests/SubscriptionsTest.cs
index 635a0d1c..fb704ad4 100644
--- a/tests/GraphQL.Integration.Tests/SubscriptionsTest.cs
+++ b/tests/GraphQL.Integration.Tests/SubscriptionsTest.cs
@@ -43,10 +43,11 @@ public SubscriptionsTest()
{
}
- private GraphQLHttpClient GetGraphQLClient(int port)
+ private GraphQLHttpClient GetGraphQLClient(int port, bool requestsViaWebsocket = false)
=> new GraphQLHttpClient(new GraphQLHttpClientOptions
{
EndPoint = new Uri($"http://localhost:{port}/graphql"),
+ UseWebSocketForQueriesAndMutations = requestsViaWebsocket
});
@@ -65,6 +66,34 @@ public async void AssertTestingHarness()
}
}
+ [Fact]
+ public async void CanSendRequestViaWebsocket()
+ {
+ var port = NetworkHelpers.GetFreeTcpPortNumber();
+ using (CreateServer(port))
+ {
+ var client = GetGraphQLClient(port, true);
+ const string message = "some random testing message";
+ var response = await client.AddMessageAsync(message).ConfigureAwait(false);
+
+ Assert.Equal(message, (string)response.Data.addMessage.content);
+ }
+ }
+
+ [Fact]
+ public async void CanHandleRequestErrorViaWebsocket()
+ {
+ var port = NetworkHelpers.GetFreeTcpPortNumber();
+ using (CreateServer(port))
+ {
+ var client = GetGraphQLClient(port, true);
+ const string message = "some random testing message";
+ var response = await client.SendQueryAsync(new GraphQLRequest("this query is formatted quite badly")).ConfigureAwait(false);
+
+ Assert.Single(response.Errors);
+ }
+ }
+
private const string SubscriptionQuery = @"
subscription {
messageAdded{
@@ -272,5 +301,63 @@ public async void CanHandleConnectionTimeout()
tester.ShouldHaveCompleted(TimeSpan.FromSeconds(5));
server.Dispose();
}
+
+ [Fact]
+ public async void CanHandleSubscriptionError()
+ {
+ var port = NetworkHelpers.GetFreeTcpPortNumber();
+ using (CreateServer(port))
+ {
+ var client = GetGraphQLClient(port);
+ Debug.WriteLine("creating subscription stream");
+ IObservable observable = client.CreateSubscriptionStream(
+ new GraphQLRequest(@"
+ subscription {
+ failImmediately {
+ content
+ }
+ }")
+ );
+
+ Debug.WriteLine("subscribing...");
+ var tester = observable.SubscribeTester();
+ tester.ShouldHaveReceivedUpdate(gqlResponse =>
+ {
+ Assert.Single(gqlResponse.Errors);
+ });
+ tester.ShouldHaveCompleted();
+
+ client.Dispose();
+ }
+ }
+
+ [Fact]
+ public async void CanHandleQueryErrorInSubscription()
+ {
+ var port = NetworkHelpers.GetFreeTcpPortNumber();
+ using (CreateServer(port))
+ {
+ var client = GetGraphQLClient(port);
+ Debug.WriteLine("creating subscription stream");
+ IObservable observable = client.CreateSubscriptionStream(
+ new GraphQLRequest(@"
+ subscription {
+ fieldDoesNotExist {
+ content
+ }
+ }")
+ );
+
+ Debug.WriteLine("subscribing...");
+ var tester = observable.SubscribeTester();
+ tester.ShouldHaveReceivedUpdate(gqlResponse =>
+ {
+ Assert.Single(gqlResponse.Errors);
+ });
+ tester.ShouldHaveCompleted();
+
+ client.Dispose();
+ }
+ }
}
}
diff --git a/tests/IntegrationTestServer/ChatSchema/ChatSubscriptions.cs b/tests/IntegrationTestServer/ChatSchema/ChatSubscriptions.cs
index 0ceb77a5..178d2b78 100644
--- a/tests/IntegrationTestServer/ChatSchema/ChatSubscriptions.cs
+++ b/tests/IntegrationTestServer/ChatSchema/ChatSubscriptions.cs
@@ -51,6 +51,15 @@ public ChatSubscriptions(IChat chat)
Resolver = new FuncFieldResolver(context => context.Source as MessageFrom),
Subscriber = new EventStreamResolver(context => _chat.UserJoined())
});
+
+
+ AddField(new EventStreamFieldType
+ {
+ Name = "failImmediately",
+ Type = typeof(MessageType),
+ Resolver = new FuncFieldResolver(ResolveMessage),
+ Subscriber = new EventStreamResolver(context => throw new NotSupportedException("this is supposed to fail"))
+ });
}
private IObservable SubscribeById(ResolveEventStreamContext context)