diff --git a/src/GraphQL.Client.Abstractions.Websocket/GraphQLWebsocketConnectionState.cs b/src/GraphQL.Client.Abstractions.Websocket/GraphQLWebsocketConnectionState.cs new file mode 100644 index 00000000..3ab5a0e2 --- /dev/null +++ b/src/GraphQL.Client.Abstractions.Websocket/GraphQLWebsocketConnectionState.cs @@ -0,0 +1,7 @@ +namespace GraphQL.Client.Abstractions.Websocket { + public enum GraphQLWebsocketConnectionState { + Disconnected, + Connecting, + Connected + } +} diff --git a/src/GraphQL.Client/GraphQLHttpClient.cs b/src/GraphQL.Client/GraphQLHttpClient.cs index 465a0286..dbf4bcb7 100644 --- a/src/GraphQL.Client/GraphQLHttpClient.cs +++ b/src/GraphQL.Client/GraphQLHttpClient.cs @@ -1,6 +1,5 @@ using System; using System.Collections.Concurrent; -using System.Linq; using System.Net.Http; using System.Text; using System.Threading; @@ -33,6 +32,11 @@ public class GraphQLHttpClient : IGraphQLClient { /// public IObservable WebSocketReceiveErrors => graphQlHttpWebSocket.ReceiveErrors; + /// + /// the websocket connection state + /// + public IObservable WebsocketConnectionState => + graphQlHttpWebSocket.ConnectionState; #region Constructors @@ -47,7 +51,7 @@ public GraphQLHttpClient(Action configure) : this(conf public GraphQLHttpClient(GraphQLHttpClientOptions options, HttpClient httpClient) { Options = options; this.HttpClient = httpClient ?? throw new ArgumentNullException(nameof(httpClient)); - this.graphQlHttpWebSocket = new GraphQLHttpWebSocket(GetWebSocketUri(), Options); + this.graphQlHttpWebSocket = new GraphQLHttpWebSocket(GetWebSocketUri(), this); Options.JsonSerializer = JsonSerializer.EnsureAssigned(); } @@ -55,7 +59,7 @@ public GraphQLHttpClient(GraphQLHttpClientOptions options, HttpClient httpClient Options = options ?? throw new ArgumentNullException(nameof(options)); Options.JsonSerializer = serializer ?? throw new ArgumentNullException(nameof(serializer)); this.HttpClient = httpClient ?? throw new ArgumentNullException(nameof(httpClient)); - this.graphQlHttpWebSocket = new GraphQLHttpWebSocket(GetWebSocketUri(), Options); + this.graphQlHttpWebSocket = new GraphQLHttpWebSocket(GetWebSocketUri(), this); } #endregion diff --git a/src/GraphQL.Client/GraphQLHttpClientOptions.cs b/src/GraphQL.Client/GraphQLHttpClientOptions.cs index 8a9afae8..a52c772c 100644 --- a/src/GraphQL.Client/GraphQLHttpClientOptions.cs +++ b/src/GraphQL.Client/GraphQLHttpClientOptions.cs @@ -49,5 +49,10 @@ public class GraphQLHttpClientOptions { /// Request preprocessing function. Can be used i.e. to inject authorization info into a GraphQL request payload. /// public Func> PreprocessRequest { get; set; } = (request, client) => Task.FromResult(request); + + /// + /// This callback is called after successfully establishing a websocket connection but before any regular request is made. + /// + public Func OnWebsocketConnected { get; set; } = client => Task.CompletedTask; } } diff --git a/src/GraphQL.Client/Websocket/GraphQLHttpWebSocket.cs b/src/GraphQL.Client/Websocket/GraphQLHttpWebSocket.cs index 579be9e7..832fb4cf 100644 --- a/src/GraphQL.Client/Websocket/GraphQLHttpWebSocket.cs +++ b/src/GraphQL.Client/Websocket/GraphQLHttpWebSocket.cs @@ -14,57 +14,63 @@ namespace GraphQL.Client.Http.Websocket { internal class GraphQLHttpWebSocket : IDisposable { private readonly Uri webSocketUri; - private readonly GraphQLHttpClientOptions _options; + private readonly GraphQLHttpClient client; private readonly ArraySegment buffer; - private readonly CancellationTokenSource _cancellationTokenSource = new CancellationTokenSource(); + private readonly CancellationTokenSource cancellationTokenSource = new CancellationTokenSource(); + private readonly Subject requestSubject = new Subject(); + private readonly Subject exceptionSubject = new Subject(); + private readonly BehaviorSubject stateSubject = + new BehaviorSubject(GraphQLWebsocketConnectionState.Disconnected); + private readonly IDisposable requestSubscription; - private Subject _responseSubject; - private readonly Subject _requestSubject = new Subject(); - private readonly Subject _exceptionSubject = new Subject(); - private IDisposable _requestSubscription; - - public WebSocketState WebSocketState => clientWebSocket?.State ?? WebSocketState.None; + private int connectionAttempt = 0; + private Subject responseSubject; + private GraphQLHttpClientOptions Options => client.Options; #if NETFRAMEWORK private WebSocket clientWebSocket = null; #else private ClientWebSocket clientWebSocket = null; #endif - private int _connectionAttempt = 0; - public GraphQLHttpWebSocket(Uri webSocketUri, GraphQLHttpClientOptions options) { + + public WebSocketState WebSocketState => clientWebSocket?.State ?? WebSocketState.None; + public IObservable ReceiveErrors => exceptionSubject.AsObservable(); + public IObservable ConnectionState => stateSubject.DistinctUntilChanged(); + + public IObservable ResponseStream { get; } + + public GraphQLHttpWebSocket(Uri webSocketUri, GraphQLHttpClient client) { this.webSocketUri = webSocketUri; - _options = options; + this.client = client; buffer = new ArraySegment(new byte[8192]); - _responseStream = _createResponseStream(); + ResponseStream = _createResponseStream(); - _requestSubscription = _requestSubject.Select(request => Observable.FromAsync(() => _sendWebSocketRequest(request))).Concat().Subscribe(); + requestSubscription = requestSubject.Select(request => Observable.FromAsync(() => _sendWebSocketRequest(request))).Concat().Subscribe(); } - public IObservable ReceiveErrors => _exceptionSubject.AsObservable(); - public IObservable ResponseStream => _responseStream; - public readonly IObservable _responseStream; + #region Send requests public Task SendWebSocketRequest(GraphQLWebSocketRequest request) { - _requestSubject.OnNext(request); + requestSubject.OnNext(request); return request.SendTask(); } private async Task _sendWebSocketRequest(GraphQLWebSocketRequest request) { try { - if (_cancellationTokenSource.Token.IsCancellationRequested) { + if (cancellationTokenSource.Token.IsCancellationRequested) { request.SendCanceled(); return; } await InitializeWebSocket().ConfigureAwait(false); - var requestBytes = _options.JsonSerializer.SerializeToBytes(request); + var requestBytes = Options.JsonSerializer.SerializeToBytes(request); await this.clientWebSocket.SendAsync( new ArraySegment(requestBytes), WebSocketMessageType.Text, true, - _cancellationTokenSource.Token).ConfigureAwait(false); + cancellationTokenSource.Token).ConfigureAwait(false); request.SendCompleted(); } catch (Exception e) { @@ -72,40 +78,28 @@ await this.clientWebSocket.SendAsync( } } - public Task InitializeWebSocketTask { get; private set; } = Task.CompletedTask; - - private readonly object _initializeLock = new object(); - -#region Private Methods - - private Task _backOff() { - _connectionAttempt++; - - if (_connectionAttempt == 1) return Task.CompletedTask; - - var delay = _options.BackOffStrategy(_connectionAttempt - 1); - Debug.WriteLine($"connection attempt #{_connectionAttempt}, backing off for {delay.TotalSeconds} s"); - return Task.Delay(delay); - } + #endregion + private Task initializeWebSocketTask = Task.CompletedTask; + private readonly object initializeLock = new object(); + public Task InitializeWebSocket() { // do not attempt to initialize if cancellation is requested - if (_disposed != null) + if (Completion != null) throw new OperationCanceledException(); - lock (_initializeLock) { + lock (initializeLock) { // if an initialization task is already running, return that - if (InitializeWebSocketTask != null && - !InitializeWebSocketTask.IsFaulted && - !InitializeWebSocketTask.IsCompleted) - return InitializeWebSocketTask; + if (initializeWebSocketTask != null && + !initializeWebSocketTask.IsFaulted && + !initializeWebSocketTask.IsCompleted) + return initializeWebSocketTask; // if the websocket is open, return a completed task if (clientWebSocket != null && clientWebSocket.State == WebSocketState.Open) return Task.CompletedTask; // else (re-)create websocket and connect - //_responseStreamConnection?.Dispose(); clientWebSocket?.Dispose(); #if NETFRAMEWORK @@ -115,13 +109,13 @@ public Task InitializeWebSocket() { switch (clientWebSocket) { case ClientWebSocket nativeWebSocket: nativeWebSocket.Options.AddSubProtocol("graphql-ws"); - nativeWebSocket.Options.ClientCertificates = ((HttpClientHandler)_options.HttpMessageHandler).ClientCertificates; - nativeWebSocket.Options.UseDefaultCredentials = ((HttpClientHandler)_options.HttpMessageHandler).UseDefaultCredentials; + nativeWebSocket.Options.ClientCertificates = ((HttpClientHandler)Options.HttpMessageHandler).ClientCertificates; + nativeWebSocket.Options.UseDefaultCredentials = ((HttpClientHandler)Options.HttpMessageHandler).UseDefaultCredentials; break; case System.Net.WebSockets.Managed.ClientWebSocket managedWebSocket: managedWebSocket.Options.AddSubProtocol("graphql-ws"); - managedWebSocket.Options.ClientCertificates = ((HttpClientHandler)_options.HttpMessageHandler).ClientCertificates; - managedWebSocket.Options.UseDefaultCredentials = ((HttpClientHandler)_options.HttpMessageHandler).UseDefaultCredentials; + managedWebSocket.Options.ClientCertificates = ((HttpClientHandler)Options.HttpMessageHandler).ClientCertificates; + managedWebSocket.Options.UseDefaultCredentials = ((HttpClientHandler)Options.HttpMessageHandler).UseDefaultCredentials; break; default: throw new NotSupportedException($"unknown websocket type {clientWebSocket.GetType().Name}"); @@ -129,13 +123,47 @@ public Task InitializeWebSocket() { #else clientWebSocket = new ClientWebSocket(); clientWebSocket.Options.AddSubProtocol("graphql-ws"); - clientWebSocket.Options.ClientCertificates = ((HttpClientHandler)_options.HttpMessageHandler).ClientCertificates; - clientWebSocket.Options.UseDefaultCredentials = ((HttpClientHandler)_options.HttpMessageHandler).UseDefaultCredentials; + clientWebSocket.Options.ClientCertificates = ((HttpClientHandler)Options.HttpMessageHandler).ClientCertificates; + clientWebSocket.Options.UseDefaultCredentials = ((HttpClientHandler)Options.HttpMessageHandler).UseDefaultCredentials; #endif - return InitializeWebSocketTask = _connectAsync(_cancellationTokenSource.Token); + return initializeWebSocketTask = _connectAsync(cancellationTokenSource.Token); } } + private async Task _connectAsync(CancellationToken token) { + try { + await _backOff().ConfigureAwait(false); + stateSubject.OnNext(GraphQLWebsocketConnectionState.Connecting); + Debug.WriteLine($"opening websocket {clientWebSocket.GetHashCode()}"); + await clientWebSocket.ConnectAsync(webSocketUri, token).ConfigureAwait(false); + stateSubject.OnNext(GraphQLWebsocketConnectionState.Connected); + Debug.WriteLine($"connection established on websocket {clientWebSocket.GetHashCode()}, invoking Options.OnWebsocketConnected()"); + await (Options.OnWebsocketConnected?.Invoke(client) ?? Task.CompletedTask); + Debug.WriteLine($"invoking Options.OnWebsocketConnected() on websocket {clientWebSocket.GetHashCode()}"); + connectionAttempt = 1; + } + catch (Exception e) { + stateSubject.OnNext(GraphQLWebsocketConnectionState.Disconnected); + exceptionSubject.OnNext(e); + throw; + } + } + + /// + /// delay the next connection attempt using + /// + /// + private Task _backOff() { + connectionAttempt++; + + if (connectionAttempt == 1) return Task.CompletedTask; + + var delay = Options.BackOffStrategy?.Invoke(connectionAttempt - 1) ?? TimeSpan.FromSeconds(5); + Debug.WriteLine($"connection attempt #{connectionAttempt}, backing off for {delay.TotalSeconds} s"); + return Task.Delay(delay); + } + + private IObservable _createResponseStream() { return Observable.Create(_createResultStream) // complete sequence on OperationCanceledException, this is triggered by the cancellation token on disposal @@ -144,68 +172,60 @@ private IObservable _createResponseStream() { } private async Task _createResultStream(IObserver observer, CancellationToken token) { - if (_responseSubject == null || _responseSubject.IsDisposed) { - _responseSubject = new Subject(); - var observable = await _getReceiveResultStream().ConfigureAwait(false); - observable.Subscribe(_responseSubject); - - _responseSubject.Subscribe(_ => { }, ex => { - _exceptionSubject.OnNext(ex); - _responseSubject?.Dispose(); - _responseSubject = null; + if (responseSubject == null || responseSubject.IsDisposed) { + // create new response subject + responseSubject = new Subject(); + + // initialize and connect websocket + await InitializeWebSocket().ConfigureAwait(false); + + // loop the receive task and subscribe the created subject to the results + Observable.Defer(() => _getReceiveTask().ToObservable()).Repeat().Subscribe(responseSubject); + + // dispose the subject on any error or completion (will be recreated) + responseSubject.Subscribe(_ => { }, ex => { + exceptionSubject.OnNext(ex); + responseSubject?.Dispose(); + responseSubject = null; + stateSubject.OnNext(GraphQLWebsocketConnectionState.Disconnected); }, () => { - _responseSubject?.Dispose(); - _responseSubject = null; + responseSubject?.Dispose(); + responseSubject = null; + stateSubject.OnNext(GraphQLWebsocketConnectionState.Disconnected); }); } return new CompositeDisposable ( - _responseSubject.Subscribe(observer), + responseSubject.Subscribe(observer), Disposable.Create(() => { Debug.WriteLine("response stream disposed"); }) ); } - private async Task> _getReceiveResultStream() { - await InitializeWebSocket().ConfigureAwait(false); - return Observable.Defer(() => _getReceiveTask().ToObservable()).Repeat(); - } - - private async Task _connectAsync(CancellationToken token) { - try { - await _backOff().ConfigureAwait(false); - Debug.WriteLine($"opening websocket {clientWebSocket.GetHashCode()}"); - await clientWebSocket.ConnectAsync(webSocketUri, token).ConfigureAwait(false); - Debug.WriteLine($"connection established on websocket {clientWebSocket.GetHashCode()}"); - _connectionAttempt = 1; - } - catch (Exception e) { - _exceptionSubject.OnNext(e); - throw; - } - } - - - private Task _receiveAsyncTask = null; - private readonly object _receiveTaskLocker = new object(); + private Task receiveAsyncTask = null; + private readonly object receiveTaskLocker = new object(); /// /// wrapper method to pick up the existing request task if already running /// /// private Task _getReceiveTask() { - lock (_receiveTaskLocker) { - if (_receiveAsyncTask == null || - _receiveAsyncTask.IsFaulted || - _receiveAsyncTask.IsCompleted) - _receiveAsyncTask = _receiveResultAsync(); + lock (receiveTaskLocker) { + if (receiveAsyncTask == null || + receiveAsyncTask.IsFaulted || + receiveAsyncTask.IsCompleted) + receiveAsyncTask = _receiveResultAsync(); } - return _receiveAsyncTask; + return receiveAsyncTask; } + /// + /// read a single message from the websocket + /// + /// private async Task _receiveResultAsync() { try { Debug.WriteLine($"receiving data on websocket {clientWebSocket.GetHashCode()} ..."); @@ -213,17 +233,17 @@ private async Task _receiveResultAsync() { using (var ms = new MemoryStream()) { WebSocketReceiveResult webSocketReceiveResult = null; do { - _cancellationTokenSource.Token.ThrowIfCancellationRequested(); + cancellationTokenSource.Token.ThrowIfCancellationRequested(); webSocketReceiveResult = await clientWebSocket.ReceiveAsync(buffer, CancellationToken.None); ms.Write(buffer.Array, buffer.Offset, webSocketReceiveResult.Count); } while (!webSocketReceiveResult.EndOfMessage); - _cancellationTokenSource.Token.ThrowIfCancellationRequested(); + cancellationTokenSource.Token.ThrowIfCancellationRequested(); ms.Seek(0, SeekOrigin.Begin); if (webSocketReceiveResult.MessageType == WebSocketMessageType.Text) { - var response = await _options.JsonSerializer.DeserializeToWebsocketResponseWrapperAsync(ms); + var response = await Options.JsonSerializer.DeserializeToWebsocketResponseWrapperAsync(ms); response.MessageBytes = ms.ToArray(); return response; } @@ -252,28 +272,36 @@ private async Task _closeAsync(CancellationToken cancellationToken = default) { Debug.WriteLine($"closing websocket {clientWebSocket.GetHashCode()}"); await this.clientWebSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "", cancellationToken).ConfigureAwait(false); + stateSubject.OnNext(GraphQLWebsocketConnectionState.Disconnected); } -#endregion - -#region IDisposable + #region IDisposable + public void Dispose() => Complete(); - private Task _disposed; - private object _disposedLocker = new object(); - public void Dispose() { - // Async disposal as recommended by Stephen Cleary (https://blog.stephencleary.com/2013/03/async-oop-6-disposal.html) - lock (_disposedLocker) { - if (_disposed == null) _disposed = DisposeAsync(); + /// + /// Cancels the current operation, closes the websocket connection and disposes of internal resources. + /// + public void Complete() { + lock (completedLocker) { + if (Completion == null) Completion = CompleteAsync(); } } - private async Task DisposeAsync() { + /// + /// Task to await the completion (a.k.a. disposal) of this websocket. + /// + /// Async disposal as recommended by Stephen Cleary (https://blog.stephencleary.com/2013/03/async-oop-6-disposal.html) + public Task Completion { get; private set; } + + private readonly object completedLocker = new object(); + private async Task CompleteAsync() { Debug.WriteLine($"disposing websocket {clientWebSocket.GetHashCode()}..."); - if (!_cancellationTokenSource.IsCancellationRequested) - _cancellationTokenSource.Cancel(); + if (!cancellationTokenSource.IsCancellationRequested) + cancellationTokenSource.Cancel(); await _closeAsync().ConfigureAwait(false); + requestSubscription?.Dispose(); clientWebSocket?.Dispose(); - _cancellationTokenSource.Dispose(); + cancellationTokenSource.Dispose(); Debug.WriteLine($"websocket {clientWebSocket.GetHashCode()} disposed"); } #endregion diff --git a/tests/GraphQL.Client.Serializer.Tests/GraphQL.Client.Serializer.Tests.csproj b/tests/GraphQL.Client.Serializer.Tests/GraphQL.Client.Serializer.Tests.csproj index 583a8e82..b61c1466 100644 --- a/tests/GraphQL.Client.Serializer.Tests/GraphQL.Client.Serializer.Tests.csproj +++ b/tests/GraphQL.Client.Serializer.Tests/GraphQL.Client.Serializer.Tests.csproj @@ -2,8 +2,7 @@ netcoreapp3.1 - - false + false diff --git a/tests/GraphQL.Client.Tests.Common/GraphQL.Client.Tests.Common.csproj b/tests/GraphQL.Client.Tests.Common/GraphQL.Client.Tests.Common.csproj index 8526c03a..dfbb9146 100644 --- a/tests/GraphQL.Client.Tests.Common/GraphQL.Client.Tests.Common.csproj +++ b/tests/GraphQL.Client.Tests.Common/GraphQL.Client.Tests.Common.csproj @@ -2,6 +2,7 @@ netstandard2.0 + false @@ -16,6 +17,7 @@ + diff --git a/tests/GraphQL.Client.Tests.Common/Helpers/CallbackMonitor.cs b/tests/GraphQL.Client.Tests.Common/Helpers/CallbackMonitor.cs new file mode 100644 index 00000000..4b623ac6 --- /dev/null +++ b/tests/GraphQL.Client.Tests.Common/Helpers/CallbackMonitor.cs @@ -0,0 +1,115 @@ +using System; +using System.Threading; +using FluentAssertions; +using FluentAssertions.Execution; +using FluentAssertions.Primitives; + +namespace GraphQL.Client.Tests.Common.Helpers { + public class CallbackMonitor { + private readonly ManualResetEventSlim callbackInvoked = new ManualResetEventSlim(); + + /// + /// The timeout for . Defaults to 1 s + /// + public TimeSpan Timeout { get; set; } = TimeSpan.FromSeconds(1); + + /// + /// Indicates that an update has been received since the last + /// + public bool CallbackInvoked => callbackInvoked.IsSet; + /// + /// The last payload which was received. + /// + public T LastPayload { get; private set; } + + public void Invoke(T param) { + LastPayload = param; + callbackInvoked.Set(); + } + + /// + /// Asserts that a new update has been pushed to the within the configured since the last . + /// If supplied, the action is executed on the submitted payload. + /// + /// action to assert the contents of the payload + public void CallbackShouldHaveBeenInvoked(Action assertPayload = null, TimeSpan? timeout = null) { + try { + callbackInvoked.Wait(timeout ?? Timeout).Should().BeTrue("because the callback method should have been invoked (timeout: {0} s)", + (timeout ?? Timeout).TotalSeconds); + + assertPayload?.Invoke(LastPayload); + } + finally { + Reset(); + } + } + + /// + /// Asserts that no new update has been pushed within the given since the last + /// + /// the time in ms in which no new update must be pushed to the . defaults to 100 + public void CallbackShouldNotHaveBeenInvoked(TimeSpan? timeout = null) { + if (!timeout.HasValue) timeout = TimeSpan.FromMilliseconds(100); + try { + callbackInvoked.Wait(timeout.Value).Should().BeFalse("because the callback method should not have been invoked"); + } + finally { + Reset(); + } + } + + /// + /// Resets the tester class. Should be called before triggering the potential update + /// + public void Reset() { + LastPayload = default(T); + callbackInvoked.Reset(); + } + + + public CallbackAssertions Should() { + return new CallbackAssertions(this); + } + + public class CallbackAssertions : ReferenceTypeAssertions, CallbackAssertions> { + public CallbackAssertions(CallbackMonitor tester) { + Subject = tester; + } + + protected override string Identifier => "callback"; + + public AndWhichConstraint, TPayload> HaveBeenInvokedWithPayload(TimeSpan timeout, + string because = "", params object[] becauseArgs) { + Execute.Assertion + .BecauseOf(because, becauseArgs) + .Given(() => Subject.callbackInvoked.Wait(timeout)) + .ForCondition(isSet => isSet) + .FailWith("Expected {context:callback} to be invoked{reason}, but did not receive a call within {0}", timeout); + + Subject.callbackInvoked.Reset(); + return new AndWhichConstraint, TPayload>(this, Subject.LastPayload); + } + public AndWhichConstraint, TPayload> HaveBeenInvokedWithPayload(string because = "", params object[] becauseArgs) + => HaveBeenInvokedWithPayload(Subject.Timeout, because, becauseArgs); + + public AndConstraint> HaveBeenInvoked(TimeSpan timeout, string because = "", params object[] becauseArgs) + => HaveBeenInvokedWithPayload(timeout, because, becauseArgs); + public AndConstraint> HaveBeenInvoked(string because = "", params object[] becauseArgs) + => HaveBeenInvokedWithPayload(Subject.Timeout, because, becauseArgs); + + public AndConstraint> NotHaveBeenInvoked(TimeSpan timeout, + string because = "", params object[] becauseArgs) { + Execute.Assertion + .BecauseOf(because, becauseArgs) + .Given(() => Subject.callbackInvoked.Wait(timeout)) + .ForCondition(isSet => !isSet) + .FailWith("Expected {context:callback} to not be invoked{reason}, but did receive a call: {0}", Subject.LastPayload); + + Subject.callbackInvoked.Reset(); + return new AndConstraint>(this); + } + public AndConstraint> NotHaveBeenInvoked(string because = "", params object[] becauseArgs) + => NotHaveBeenInvoked(TimeSpan.FromMilliseconds(100), because, becauseArgs); + } + } +} diff --git a/tests/GraphQL.Client.Tests.Common/Helpers/CallbackTester.cs b/tests/GraphQL.Client.Tests.Common/Helpers/CallbackTester.cs deleted file mode 100644 index c8ca29c5..00000000 --- a/tests/GraphQL.Client.Tests.Common/Helpers/CallbackTester.cs +++ /dev/null @@ -1,67 +0,0 @@ -using System; -using System.Threading; -using FluentAssertions; - -namespace GraphQL.Client.Tests.Common.Helpers { - public class CallbackTester { - private ManualResetEventSlim _callbackInvoked { get; } = new ManualResetEventSlim(); - - /// - /// The timeout for . Defaults to 1 s - /// - public TimeSpan Timeout { get; set; } = TimeSpan.FromSeconds(1); - - /// - /// Indicates that an update has been received since the last - /// - public bool CallbackInvoked => _callbackInvoked.IsSet; - /// - /// The last payload which was received. - /// - public T LastPayload { get; private set; } - - public void Callback(T param) { - LastPayload = param; - _callbackInvoked.Set(); - } - - /// - /// Asserts that a new update has been pushed to the within the configured since the last . - /// If supplied, the action is executed on the submitted payload. - /// - /// action to assert the contents of the payload - public void CallbackShouldHaveBeenInvoked(Action assertPayload = null, TimeSpan? timeout = null) { - try { - _callbackInvoked.Wait(timeout ?? Timeout).Should().BeTrue("because the callback method should have been invoked (timeout: {0} s)", - (timeout ?? Timeout).TotalSeconds); - - assertPayload?.Invoke(LastPayload); - } - finally { - Reset(); - } - } - - /// - /// Asserts that no new update has been pushed within the given since the last - /// - /// the time in ms in which no new update must be pushed to the . defaults to 100 - public void CallbackShouldNotHaveBeenInvoked(TimeSpan? timeout = null) { - if (!timeout.HasValue) timeout = TimeSpan.FromMilliseconds(100); - try { - _callbackInvoked.Wait(timeout.Value).Should().BeFalse("because the callback method should not have been invoked"); - } - finally { - Reset(); - } - } - - /// - /// Resets the tester class. Should be called before triggering the potential update - /// - public void Reset() { - LastPayload = default(T); - _callbackInvoked.Reset(); - } - } -} diff --git a/tests/GraphQL.Client.Tests.Common/Helpers/MiscellaneousExtensions.cs b/tests/GraphQL.Client.Tests.Common/Helpers/MiscellaneousExtensions.cs index 2da34009..0e254672 100644 --- a/tests/GraphQL.Client.Tests.Common/Helpers/MiscellaneousExtensions.cs +++ b/tests/GraphQL.Client.Tests.Common/Helpers/MiscellaneousExtensions.cs @@ -1,4 +1,6 @@ using System.Linq; +using System.Threading.Tasks; +using GraphQL.Client.Http; namespace GraphQL.Client.Tests.Common.Helpers { public static class MiscellaneousExtensions { @@ -7,5 +9,17 @@ public static string RemoveWhitespace(this string input) { .Where(c => !char.IsWhiteSpace(c)) .ToArray()); } + + public static CallbackMonitor ConfigureMonitorForOnWebsocketConnected( + this GraphQLHttpClient client) { + var tester = new CallbackMonitor(); + client.Options.OnWebsocketConnected = c => { + tester.Invoke(c); + return Task.CompletedTask; + }; + return tester; + } + + } } diff --git a/tests/GraphQL.Integration.Tests/QueryAndMutationTests/Base.cs b/tests/GraphQL.Integration.Tests/QueryAndMutationTests/Base.cs index be153ba2..16e73c92 100644 --- a/tests/GraphQL.Integration.Tests/QueryAndMutationTests/Base.cs +++ b/tests/GraphQL.Integration.Tests/QueryAndMutationTests/Base.cs @@ -139,9 +139,9 @@ query Human($id: String!){ [Fact] public async void PreprocessHttpRequestMessageIsCalled() { - var callbackTester = new CallbackTester(); + var callbackTester = new CallbackMonitor(); var graphQLRequest = new GraphQLHttpRequest($"{{ human(id: \"1\") {{ name }} }}") { - PreprocessHttpRequestMessage = callbackTester.Callback + PreprocessHttpRequestMessage = callbackTester.Invoke }; using (var setup = SetupTest()) { diff --git a/tests/GraphQL.Integration.Tests/WebsocketTests/Base.cs b/tests/GraphQL.Integration.Tests/WebsocketTests/Base.cs index ff1cabe1..b17ba3c0 100644 --- a/tests/GraphQL.Integration.Tests/WebsocketTests/Base.cs +++ b/tests/GraphQL.Integration.Tests/WebsocketTests/Base.cs @@ -16,24 +16,20 @@ namespace GraphQL.Integration.Tests.WebsocketTests { public abstract class Base { - protected readonly ITestOutputHelper output; - protected readonly IGraphQLWebsocketJsonSerializer serializer; + protected readonly ITestOutputHelper Output; + protected readonly IGraphQLWebsocketJsonSerializer Serializer; protected IWebHost CreateServer(int port) => WebHostHelpers.CreateServer(port); - public Base(ITestOutputHelper output, IGraphQLWebsocketJsonSerializer serializer) { - this.output = output; - this.serializer = serializer; - } - - public Base(ITestOutputHelper output) { - this.output = output; + protected Base(ITestOutputHelper output, IGraphQLWebsocketJsonSerializer serializer) { + this.Output = output; + this.Serializer = serializer; } [Fact] public async void AssertTestingHarness() { var port = NetworkHelpers.GetFreeTcpPortNumber(); using (CreateServer(port)) { - var client = WebHostHelpers.GetGraphQLClient(port, serializer: serializer); + var client = WebHostHelpers.GetGraphQLClient(port, serializer: Serializer); const string message = "some random testing message"; var response = await client.AddMessageAsync(message).ConfigureAwait(false); @@ -47,7 +43,7 @@ public async void AssertTestingHarness() { public async void CanSendRequestViaWebsocket() { var port = NetworkHelpers.GetFreeTcpPortNumber(); using (CreateServer(port)) { - var client = WebHostHelpers.GetGraphQLClient(port, true, serializer); + var client = WebHostHelpers.GetGraphQLClient(port, true, Serializer); const string message = "some random testing message"; var response = await client.AddMessageAsync(message).ConfigureAwait(false); @@ -59,7 +55,7 @@ public async void CanSendRequestViaWebsocket() { public async void CanHandleRequestErrorViaWebsocket() { var port = NetworkHelpers.GetFreeTcpPortNumber(); using (CreateServer(port)) { - var client = WebHostHelpers.GetGraphQLClient(port, true, serializer); + var client = WebHostHelpers.GetGraphQLClient(port, true, Serializer); var response = await client.SendQueryAsync("this query is formatted quite badly").ConfigureAwait(false); Assert.Single(response.Errors); @@ -79,9 +75,11 @@ public async void CanHandleRequestErrorViaWebsocket() { [Fact] public async void CanCreateObservableSubscription() { var port = NetworkHelpers.GetFreeTcpPortNumber(); - using (CreateServer(port)) { - var client = WebHostHelpers.GetGraphQLClient(port, serializer: serializer); + using (CreateServer(port)){ + var client = WebHostHelpers.GetGraphQLClient(port, serializer: Serializer); + var callbackMonitor = client.ConfigureMonitorForOnWebsocketConnected(); await client.InitializeWebsocketConnection(); + callbackMonitor.Should().HaveBeenInvokedWithPayload(); Debug.WriteLine("creating subscription stream"); IObservable> observable = client.CreateSubscriptionStream(SubscriptionRequest); @@ -121,14 +119,15 @@ public class MessageAddedContent { public async void CanReconnectWithSameObservable() { var port = NetworkHelpers.GetFreeTcpPortNumber(); using (CreateServer(port)) { - var client = WebHostHelpers.GetGraphQLClient(port, serializer: serializer); - await client.InitializeWebsocketConnection(); + var client = WebHostHelpers.GetGraphQLClient(port, serializer: Serializer); + var callbackMonitor = client.ConfigureMonitorForOnWebsocketConnected(); Debug.WriteLine("creating subscription stream"); - IObservable> observable = client.CreateSubscriptionStream(SubscriptionRequest); + var observable = client.CreateSubscriptionStream(SubscriptionRequest); Debug.WriteLine("subscribing..."); var tester = observable.Monitor(); + callbackMonitor.Should().HaveBeenInvokedWithPayload(); const string message1 = "Hello World"; var response = await client.AddMessageAsync(message1).ConfigureAwait(false); @@ -143,9 +142,7 @@ public async void CanReconnectWithSameObservable() { .Which.Data.MessageAdded.Content.Should().Be(message2); Debug.WriteLine("disposing subscription..."); - tester.Dispose(); - await Task.Delay(500); - await client.InitializeWebsocketConnection(); + tester.Dispose(); // does not close the websocket connection Debug.WriteLine("creating new subscription..."); tester = observable.Monitor(); @@ -188,17 +185,19 @@ public class UserJoinedContent { [Fact] public async void CanConnectTwoSubscriptionsSimultaneously() { var port = NetworkHelpers.GetFreeTcpPortNumber(); - var callbackTester = new CallbackTester(); - var callbackTester2 = new CallbackTester(); + var callbackTester = new CallbackMonitor(); + var callbackTester2 = new CallbackMonitor(); using (CreateServer(port)) { - var client = WebHostHelpers.GetGraphQLClient(port, serializer: serializer); + var client = WebHostHelpers.GetGraphQLClient(port, serializer: Serializer); + var callbackMonitor = client.ConfigureMonitorForOnWebsocketConnected(); await client.InitializeWebsocketConnection(); + callbackMonitor.Should().HaveBeenInvokedWithPayload(); Debug.WriteLine("creating subscription stream"); IObservable> observable1 = - client.CreateSubscriptionStream(SubscriptionRequest, callbackTester.Callback); + client.CreateSubscriptionStream(SubscriptionRequest, callbackTester.Invoke); IObservable> observable2 = - client.CreateSubscriptionStream(SubscriptionRequest2, callbackTester2.Callback); + client.CreateSubscriptionStream(SubscriptionRequest2, callbackTester2.Invoke); Debug.WriteLine("subscribing..."); var tester = observable1.Monitor(); @@ -237,15 +236,31 @@ public async void CanConnectTwoSubscriptionsSimultaneously() { public async void CanHandleConnectionTimeout() { var port = NetworkHelpers.GetFreeTcpPortNumber(); var server = CreateServer(port); - var callbackTester = new CallbackTester(); + var errorMonitor = new CallbackMonitor(); + var reconnectBlocker = new ManualResetEventSlim(false); + + var client = WebHostHelpers.GetGraphQLClient(port, serializer: Serializer); + var callbackMonitor = client.ConfigureMonitorForOnWebsocketConnected(); + // configure back-off strategy to allow it to be controlled from within the unit test + client.Options.BackOffStrategy = i => { + reconnectBlocker.Wait(); + return TimeSpan.Zero; + }; + + var statusMonitor = client.WebsocketConnectionState.Monitor(); + statusMonitor.Should().HaveReceivedPayload().Which.Should() + .Be(GraphQLWebsocketConnectionState.Disconnected); - var client = WebHostHelpers.GetGraphQLClient(port, serializer: serializer); - await client.InitializeWebsocketConnection(); Debug.WriteLine("creating subscription stream"); - IObservable> observable = client.CreateSubscriptionStream(SubscriptionRequest, callbackTester.Callback); + IObservable> observable = client.CreateSubscriptionStream(SubscriptionRequest, errorMonitor.Invoke); Debug.WriteLine("subscribing..."); var tester = observable.Monitor(); + statusMonitor.Should().HaveReceivedPayload().Which.Should() + .Be(GraphQLWebsocketConnectionState.Connecting); + statusMonitor.Should().HaveReceivedPayload().Which.Should() + .Be(GraphQLWebsocketConnectionState.Connected); + callbackMonitor.Should().HaveBeenInvokedWithPayload(); const string message1 = "Hello World"; var response = await client.AddMessageAsync(message1).ConfigureAwait(false); @@ -255,18 +270,21 @@ public async void CanHandleConnectionTimeout() { Debug.WriteLine("stopping web host..."); await server.StopAsync(CancellationToken.None).ConfigureAwait(false); + server.Dispose(); Debug.WriteLine("web host stopped..."); - callbackTester.CallbackShouldHaveBeenInvoked(exception => { - Assert.IsType(exception); - }, TimeSpan.FromSeconds(10)); + errorMonitor.Should().HaveBeenInvokedWithPayload(TimeSpan.FromSeconds(10)) + .Which.Should().BeOfType(); + statusMonitor.Should().HaveReceivedPayload().Which.Should() + .Be(GraphQLWebsocketConnectionState.Disconnected); - try { - server.Start(); - } - catch (Exception e) { - output.WriteLine($"failed to restart server: {e}"); - } + server = CreateServer(port); + reconnectBlocker.Set(); + statusMonitor.Should().HaveReceivedPayload(TimeSpan.FromSeconds(10)).Which.Should() + .Be(GraphQLWebsocketConnectionState.Connecting); + statusMonitor.Should().HaveReceivedPayload(TimeSpan.FromSeconds(10)).Which.Should() + .Be(GraphQLWebsocketConnectionState.Connected); + callbackMonitor.Should().HaveBeenInvokedWithPayload(); // disposing the client should complete the subscription client.Dispose(); @@ -279,8 +297,10 @@ public async void CanHandleConnectionTimeout() { public async void CanHandleSubscriptionError() { var port = NetworkHelpers.GetFreeTcpPortNumber(); using (CreateServer(port)) { - var client = WebHostHelpers.GetGraphQLClient(port, serializer: serializer); + var client = WebHostHelpers.GetGraphQLClient(port, serializer: Serializer); + var callbackMonitor = client.ConfigureMonitorForOnWebsocketConnected(); await client.InitializeWebsocketConnection(); + callbackMonitor.Should().HaveBeenInvokedWithPayload(); Debug.WriteLine("creating subscription stream"); IObservable> observable = client.CreateSubscriptionStream( new GraphQLRequest(@" @@ -293,7 +313,7 @@ public async void CanHandleSubscriptionError() { Debug.WriteLine("subscribing..."); using (var tester = observable.Monitor()) { - tester.Should().HaveReceivedPayload() + tester.Should().HaveReceivedPayload(TimeSpan.FromSeconds(3)) .Which.Errors.Should().ContainSingle(); tester.Should().HaveCompleted(); client.Dispose(); @@ -309,8 +329,10 @@ public async void CanHandleQueryErrorInSubscription() { var test = new GraphQLRequest("tset", new { test = "blaa" }); - var client = WebHostHelpers.GetGraphQLClient(port, serializer: serializer); + var client = WebHostHelpers.GetGraphQLClient(port, serializer: Serializer); + var callbackMonitor = client.ConfigureMonitorForOnWebsocketConnected(); await client.InitializeWebsocketConnection(); + callbackMonitor.Should().HaveBeenInvokedWithPayload(); Debug.WriteLine("creating subscription stream"); IObservable> observable = client.CreateSubscriptionStream( new GraphQLRequest(@"