diff --git a/Sources/Examples/Echo/main.swift b/Sources/Examples/Echo/main.swift index 263a3c5c1..3e94b2561 100644 --- a/Sources/Examples/Echo/main.swift +++ b/Sources/Examples/Echo/main.swift @@ -82,9 +82,7 @@ func makeEchoClient(address: String, port: Int, ssl: Bool) -> Echo_EchoServiceCl eventLoopGroup: eventLoopGroup, tlsConfiguration: tlsConfiguration) - return try ClientConnection.start(configuration) - .map { Echo_EchoServiceClient(connection: $0) } - .wait() + return Echo_EchoServiceClient(connection: ClientConnection(configuration: configuration)) } catch { print("Unable to create an EchoClient: \(error)") return nil diff --git a/Sources/GRPC/ClientCalls/BaseClientCall.swift b/Sources/GRPC/ClientCalls/BaseClientCall.swift index 370d08c33..b7b6af793 100644 --- a/Sources/GRPC/ClientCalls/BaseClientCall.swift +++ b/Sources/GRPC/ClientCalls/BaseClientCall.swift @@ -108,8 +108,12 @@ open class BaseClientCall { /// Creates and configures an HTTP/2 stream channel. The `self.subchannel` future will hold the /// stream channel once it has been created. private func createStreamChannel() { - self.connection.channel.eventLoop.execute { - self.connection.multiplexer.createStreamChannel(promise: self.streamPromise) { (subchannel, streamID) -> EventLoopFuture in + self.connection.multiplexer.whenFailure { error in + self.streamPromise.fail(error) + } + + self.connection.multiplexer.whenSuccess { multiplexer in + multiplexer.createStreamChannel(promise: self.streamPromise) { (subchannel, streamID) -> EventLoopFuture in subchannel.pipeline.addHandlers( HTTP2ToHTTP1ClientCodec(streamID: streamID, httpProtocol: self.connection.configuration.httpProtocol), HTTP1ToRawGRPCClientCodec(), diff --git a/Sources/GRPC/ClientConnection.swift b/Sources/GRPC/ClientConnection.swift index 7ce6ba0a7..aa7869a66 100644 --- a/Sources/GRPC/ClientConnection.swift +++ b/Sources/GRPC/ClientConnection.swift @@ -53,139 +53,209 @@ import NIOTLS /// delegate associated with this connection (see `DelegatingErrorHandler`). /// /// See `BaseClientCall` for a description of the remainder of the client pipeline. -open class ClientConnection { - /// Makes and configures a `ClientBootstrap` using the provided configuration. - /// - /// Enables `SO_REUSEADDR` and `TCP_NODELAY` and configures the `channelInitializer` to use the - /// handlers detailed in the documentation for `ClientConnection`. - /// - /// - Parameter configuration: The configuration to prepare the bootstrap with. - public class func makeBootstrap(configuration: Configuration) -> ClientBootstrapProtocol { - let bootstrap = GRPCNIO.makeClientBootstrap(group: configuration.eventLoopGroup) - // Enable SO_REUSEADDR and TCP_NODELAY. - .channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) - .channelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1) - .channelInitializer { channel in - let tlsConfigured = configuration.tlsConfiguration.map { tlsConfiguration in - channel.configureTLS(tlsConfiguration, errorDelegate: configuration.errorDelegate) - } +public class ClientConnection { + /// The configuration this connection was created using. + internal let configuration: ClientConnection.Configuration - return (tlsConfigured ?? channel.eventLoop.makeSucceededFuture(())).flatMap { - channel.configureHTTP2Pipeline(mode: .client) - }.flatMap { _ in - let errorHandler = DelegatingErrorHandler(delegate: configuration.errorDelegate) - return channel.pipeline.addHandler(errorHandler) - } - } + /// The channel which will handle gRPC calls. + internal var channel: EventLoopFuture - return bootstrap - } + /// HTTP multiplexer from the `channel` handling gRPC calls. + internal var multiplexer: EventLoopFuture - /// Verifies that a TLS handshake was successful by using the `TLSVerificationHandler`. - /// - /// - Parameter channel: The channel to verify successful TLS setup on. - public class func verifyTLS(channel: Channel) -> EventLoopFuture { - return channel.pipeline.handler(type: TLSVerificationHandler.self).flatMap { - $0.verification + /// A monitor for the connectivity state. + public let connectivity: ConnectivityStateMonitor + + /// Creates a new connection from the given configuration. + public init(configuration: ClientConnection.Configuration) { + let monitor = ConnectivityStateMonitor(delegate: configuration.connectivityStateDelegate) + let channel = ClientConnection.makeChannel( + configuration: configuration, + connectivityMonitor: monitor + ) + + self.channel = channel + self.multiplexer = channel.flatMap { + $0.pipeline.handler(type: HTTP2StreamMultiplexer.self) + } + self.connectivity = monitor + self.configuration = configuration + + self.channel.whenSuccess { _ in + self.connectivity.state = .ready } + self.replaceChannelAndMultiplexerOnClose(channel: channel) } - /// Makes a `ClientConnection` from the given channel and configuration. - /// - /// - Parameter channel: The channel to use for the connection. - /// - Parameter configuration: The configuration used to create the channel. - public class func makeClientConnection( - channel: Channel, - configuration: Configuration - ) -> EventLoopFuture { - return channel.pipeline.handler(type: HTTP2StreamMultiplexer.self).map { multiplexer in - ClientConnection(channel: channel, multiplexer: multiplexer, configuration: configuration) + /// Registers a callback on the `closeFuture` of the given channel to replace this class's + /// channel and multiplexer. + private func replaceChannelAndMultiplexerOnClose(channel: EventLoopFuture) { + channel.always { result in + // If we failed to get a channel then we've exhausted our backoff; we should `.shutdown`. + if case .failure = result { + self.connectivity.state = .shutdown + } + }.flatMap { + $0.closeFuture + }.whenComplete { _ in + // `.shutdown` is terminal so don't attempt a reconnection. + guard self.connectivity.state != .shutdown else { + return + } + + let newChannel = ClientConnection.makeChannel( + configuration: self.configuration, + connectivityMonitor: self.connectivity + ) + + self.channel = newChannel + self.multiplexer = newChannel.flatMap { + $0.pipeline.handler(type: HTTP2StreamMultiplexer.self) + } + + // Change the state if the connection was successful. + newChannel.whenSuccess { _ in + self.connectivity.state = .ready + } + self.replaceChannelAndMultiplexerOnClose(channel: newChannel) } } - /// Starts a client connection using the given configuration. - /// - /// This involves: creating a `ClientBootstrap`, connecting to a target, verifying that the TLS - /// handshake was successful (if TLS was configured) and creating the `ClientConnection`. - /// See the individual functions for more information: - /// - `makeBootstrap(configuration:)`, - /// - `verifyTLS(channel:)`, and - /// - `makeClientConnection(channel:configuration:)`. - /// - /// - Parameter configuration: The configuration to start the connection with. - public class func start(_ configuration: Configuration) -> EventLoopFuture { - return start(configuration, backoffIterator: configuration.connectionBackoff?.makeIterator()) + /// The `EventLoop` this connection is using. + public var eventLoop: EventLoop { + return self.channel.eventLoop } - /// Starts a client connection using the given configuration and backoff. + /// Closes the connection to the server. + public func close() -> EventLoopFuture { + if self.connectivity.state == .shutdown { + // We're already shutdown or in the process of shutting down. + return channel.flatMap { $0.closeFuture } + } else { + self.connectivity.state = .shutdown + return channel.flatMap { $0.close() } + } + } +} + +extension ClientConnection { + /// Creates a `Channel` using the given configuration. /// - /// In addition to the steps taken in `start(configuration:)`, we _may_ additionally set a - /// connection timeout and schedule a retry attempt (should the connection fail) if a + /// This involves: creating a `ClientBootstrap`, connecting to a target and verifying that the TLS + /// handshake was successful (if TLS was configured). We _may_ additiionally set a connection + /// timeout and schedule a retry attempt (should the connection fail) if a /// `ConnectionBackoff.Iterator` is provided. /// + /// See the individual functions for more information: + /// - `makeBootstrap(configuration:)`, and + /// - `verifyTLS(channel:)`. + /// /// - Parameter configuration: The configuration to start the connection with. - /// - Parameter backoffIterator: A `ConnectionBackoff` iterator which generates connection - /// timeouts and backoffs to use when attempting to retry the connection. - internal class func start( - _ configuration: Configuration, + /// - Parameter connectivityMonitor: A connectivity state monitor. + /// - Parameter backoffIterator: An `Iterator` for `ConnectionBackoff` providing a sequence of + /// connection timeouts and backoff to use when attempting to create a connection. + private class func makeChannel( + configuration: ClientConnection.Configuration, + connectivityMonitor: ConnectivityStateMonitor, backoffIterator: ConnectionBackoff.Iterator? - ) -> EventLoopFuture { + ) -> EventLoopFuture { + connectivityMonitor.state = .connecting let timeoutAndBackoff = backoffIterator?.next() + var bootstrap = ClientConnection.makeBootstrap(configuration: configuration) - var bootstrap = makeBootstrap(configuration: configuration) // Set a timeout, if we have one. if let timeout = timeoutAndBackoff?.timeout { bootstrap = bootstrap.connectTimeout(.seconds(timeInterval: timeout)) } - let connection = bootstrap.connect(to: configuration.target) - .flatMap { channel -> EventLoopFuture in - let tlsVerified: EventLoopFuture? - if configuration.tlsConfiguration != nil { - tlsVerified = verifyTLS(channel: channel) - } else { - tlsVerified = nil - } - - return (tlsVerified ?? channel.eventLoop.makeSucceededFuture(())).flatMap { - makeClientConnection(channel: channel, configuration: configuration) - } + let channel = bootstrap.connect(to: configuration.target).flatMap { channel -> EventLoopFuture in + if configuration.tlsConfiguration != nil { + return ClientConnection.verifyTLS(channel: channel).map { channel } + } else { + return channel.eventLoop.makeSucceededFuture(channel) } + }.always { result in + switch result { + case .success: + // Update the state once the channel has been assigned, when it may be used for making + // RPCs. + break + + case .failure: + // We might try again in a moment. + connectivityMonitor.state = timeoutAndBackoff == nil ? .shutdown : .transientFailure + } + } guard let backoff = timeoutAndBackoff?.backoff else { - return connection + return channel } // If we're in error then schedule our next attempt. - return connection.flatMapError { error in + return channel.flatMapError { error in // The `futureResult` of the scheduled task is of type // `EventLoopFuture>`, so we need to `flatMap` it to // remove a level of indirection. - return connection.eventLoop.scheduleTask(in: .seconds(timeInterval: backoff)) { - return start(configuration, backoffIterator: backoffIterator) + return channel.eventLoop.scheduleTask(in: .seconds(timeInterval: backoff)) { + return makeChannel( + configuration: configuration, + connectivityMonitor: connectivityMonitor, + backoffIterator: backoffIterator + ) }.futureResult.flatMap { nextConnection in return nextConnection } } } - public let channel: Channel - public let multiplexer: HTTP2StreamMultiplexer - public let configuration: Configuration - - init(channel: Channel, multiplexer: HTTP2StreamMultiplexer, configuration: Configuration) { - self.channel = channel - self.multiplexer = multiplexer - self.configuration = configuration + /// Creates a `Channel` using the given configuration amd state connectivity monitor. + /// + /// See `makeChannel(configuration:connectivityMonitor:backoffIterator:)`. + private class func makeChannel( + configuration: ClientConnection.Configuration, + connectivityMonitor: ConnectivityStateMonitor + ) -> EventLoopFuture { + return makeChannel( + configuration: configuration, + connectivityMonitor: connectivityMonitor, + backoffIterator: configuration.connectionBackoff?.makeIterator() + ) } - /// Fired when the client shuts down. - public var onClose: EventLoopFuture { - return channel.closeFuture + /// Makes and configures a `ClientBootstrap` using the provided configuration. + /// + /// Enables `SO_REUSEADDR` and `TCP_NODELAY` and configures the `channelInitializer` to use the + /// handlers detailed in the documentation for `ClientConnection`. + /// + /// - Parameter configuration: The configuration to prepare the bootstrap with. + private class func makeBootstrap(configuration: Configuration) -> ClientBootstrapProtocol { + let bootstrap = GRPCNIO.makeClientBootstrap(group: configuration.eventLoopGroup) + // Enable SO_REUSEADDR and TCP_NODELAY. + .channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) + .channelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1) + .channelInitializer { channel in + let tlsConfigured = configuration.tlsConfiguration.map { tlsConfiguration in + channel.configureTLS(tlsConfiguration, errorDelegate: configuration.errorDelegate) + } + + return (tlsConfigured ?? channel.eventLoop.makeSucceededFuture(())).flatMap { + channel.configureHTTP2Pipeline(mode: .client) + }.flatMap { _ in + let errorHandler = DelegatingErrorHandler(delegate: configuration.errorDelegate) + return channel.pipeline.addHandler(errorHandler) + } + } + + return bootstrap } - public func close() -> EventLoopFuture { - return channel.close(mode: .all) + /// Verifies that a TLS handshake was successful by using the `TLSVerificationHandler`. + /// + /// - Parameter channel: The channel to verify successful TLS setup on. + private class func verifyTLS(channel: Channel) -> EventLoopFuture { + return channel.pipeline.handler(type: TLSVerificationHandler.self).flatMap { + $0.verification + } } } @@ -222,6 +292,9 @@ extension ClientConnection { /// cycle. public var errorDelegate: ClientErrorDelegate? + /// A delegate which is called when the connectivity state is changed. + public var connectivityStateDelegate: ConnectivityStateDelegate? + /// TLS configuration for this connection. `nil` if TLS is not desired. public var tlsConfiguration: TLSConfiguration? @@ -240,6 +313,7 @@ extension ClientConnection { /// - Parameter eventLoopGroup: The event loop group to run the connection on. /// - Parameter errorDelegate: The error delegate, defaulting to a delegate which will log only /// on debug builds. + /// - Parameter connectivityStateDelegate: A connectivity state delegate, defaulting to `nil`. /// - Parameter tlsConfiguration: TLS configuration, defaulting to `nil`. /// - Parameter connectionBackoff: The connection backoff configuration to use, defaulting /// to `nil`. @@ -247,12 +321,14 @@ extension ClientConnection { target: ConnectionTarget, eventLoopGroup: EventLoopGroup, errorDelegate: ClientErrorDelegate? = DebugOnlyLoggingClientErrorDelegate.shared, + connectivityStateDelegate: ConnectivityStateDelegate? = nil, tlsConfiguration: TLSConfiguration? = nil, connectionBackoff: ConnectionBackoff? = nil ) { self.target = target self.eventLoopGroup = eventLoopGroup self.errorDelegate = errorDelegate + self.connectivityStateDelegate = connectivityStateDelegate self.tlsConfiguration = tlsConfiguration self.connectionBackoff = connectionBackoff } @@ -309,8 +385,7 @@ fileprivate extension Channel { context: configuration.sslContext, serverHostname: configuration.hostnameOverride) - let verificationHandler = TLSVerificationHandler(errorDelegate: errorDelegate) - return self.pipeline.addHandlers(sslClientHandler, verificationHandler) + return self.pipeline.addHandlers(sslClientHandler, TLSVerificationHandler()) } catch { return self.eventLoop.makeFailedFuture(error) } diff --git a/Sources/GRPC/ConnectivityState.swift b/Sources/GRPC/ConnectivityState.swift new file mode 100644 index 000000000..2626410f9 --- /dev/null +++ b/Sources/GRPC/ConnectivityState.swift @@ -0,0 +1,133 @@ +/* + * Copyright 2019, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import Foundation + +/// The connectivity state of a client connection. Note that this is heavily lifted from the gRPC +/// documentation: https://github.com/grpc/grpc/blob/master/doc/connectivity-semantics-and-api.md. +public enum ConnectivityState { + /// This is the state where the channel has not yet been created. + case idle + + /// The channel is trying to establish a connection and is waiting to make progress on one of the + /// steps involved in name resolution, TCP connection establishment or TLS handshake. + case connecting + + /// The channel has successfully established a connection all the way through TLS handshake (or + /// equivalent) and protocol-level (HTTP/2, etc) handshaking. + case ready + + /// There has been some transient failure (such as a TCP 3-way handshake timing out or a socket + /// error). Channels in this state will eventually switch to the `.connecting` state and try to + /// establish a connection again. Since retries are done with exponential backoff, channels that + /// fail to connect will start out spending very little time in this state but as the attempts + /// fail repeatedly, the channel will spend increasingly large amounts of time in this state. + case transientFailure + + /// This channel has started shutting down. Any new RPCs should fail immediately. Pending RPCs + /// may continue running till the application cancels them. Channels may enter this state either + /// because the application explicitly requested a shutdown or if a non-recoverable error has + /// happened during attempts to connect. Channels that have entered this state will never leave + /// this state. + case shutdown +} + +public protocol ConnectivityStateDelegate: class { + /// Called when a change in `ConnectivityState` has occurred. + /// + /// - Parameter oldState: The old connectivity state. + /// - Parameter newState: The new connectivity state. + func connectivityStateDidChange(from oldState: ConnectivityState, to newState: ConnectivityState) +} + +public class ConnectivityStateMonitor { + public typealias Callback = () -> Void + + private var idleCallback: Callback? + private var connectingCallback: Callback? + private var readyCallback: Callback? + private var transientFailureCallback: Callback? + private var shutdownCallback: Callback? + + /// A delegate to call when the connectivity state changes. + public var delegate: ConnectivityStateDelegate? + + /// The current state of connectivity. + public internal(set) var state: ConnectivityState { + didSet { + if oldValue != self.state { + self.delegate?.connectivityStateDidChange(from: oldValue, to: self.state) + self.triggerAndResetCallback() + } + } + } + + /// Creates a new connectivity state monitor. + /// + /// - Parameter delegate: A delegate to call when the connectivity state changes. + public init(delegate: ConnectivityStateDelegate?) { + self.delegate = delegate + self.state = .idle + } + + /// Registers a callback on the given state and calls it the next time that state is observed. + /// Subsequent transitions to that state will **not** trigger the callback. + /// + /// - Parameter state: The state on which to call the given callback. + /// - Parameter callback: The closure to call once the given state has been transitioned to. The + /// `callback` can be removed by passing in `nil`. + public func onNext(state: ConnectivityState, callback: Callback?) { + switch state { + case .idle: + self.idleCallback = callback + + case .connecting: + self.connectingCallback = callback + + case .ready: + self.readyCallback = callback + + case .transientFailure: + self.transientFailureCallback = callback + + case .shutdown: + self.shutdownCallback = callback + } + } + + private func triggerAndResetCallback() { + switch self.state { + case .idle: + self.idleCallback?() + self.idleCallback = nil + + case .connecting: + self.connectingCallback?() + self.connectingCallback = nil + + case .ready: + self.readyCallback?() + self.readyCallback = nil + + case .transientFailure: + self.transientFailureCallback?() + self.transientFailureCallback = nil + + case .shutdown: + self.shutdownCallback?() + self.shutdownCallback = nil + } + } +} diff --git a/Sources/GRPC/TLSVerificationHandler.swift b/Sources/GRPC/TLSVerificationHandler.swift index a5e8ce9dd..11956f095 100644 --- a/Sources/GRPC/TLSVerificationHandler.swift +++ b/Sources/GRPC/TLSVerificationHandler.swift @@ -25,7 +25,6 @@ public class TLSVerificationHandler: ChannelInboundHandler, RemovableChannelHand public typealias InboundIn = Any private var verificationPromise: EventLoopPromise! - private let delegate: ClientErrorDelegate? /// A future which is fulfilled when the state of the TLS handshake is known. If the handshake /// was successful and the negotiated application protocol is valid then the future is succeeded. @@ -38,26 +37,21 @@ public class TLSVerificationHandler: ChannelInboundHandler, RemovableChannelHand return verificationPromise.futureResult } - public init(errorDelegate: ClientErrorDelegate?) { - self.delegate = errorDelegate - } + public init() { } public func handlerAdded(context: ChannelHandlerContext) { self.verificationPromise = context.eventLoop.makePromise() // Remove ourselves from the pipeline when the promise gets fulfilled. - self.verificationPromise.futureResult.whenComplete { _ in + self.verificationPromise.futureResult.recover { error in + // If we have an error we should let the rest of the pipeline know. + context.fireErrorCaught(error) + }.whenComplete { _ in context.pipeline.removeHandler(self, promise: nil) } } public func errorCaught(context: ChannelHandlerContext, error: Error) { precondition(self.verificationPromise != nil, "handler has not been added to the pipeline") - - if let delegate = self.delegate { - let grpcError = (error as? GRPCError) ?? GRPCError.unknown(error, origin: .client) - delegate.didCatchError(grpcError.wrappedError, file: grpcError.file, line: grpcError.line) - } - verificationPromise.fail(error) } @@ -73,7 +67,8 @@ public class TLSVerificationHandler: ChannelInboundHandler, RemovableChannelHand if let proto = negotiatedProtocol, GRPCApplicationProtocolIdentifier(rawValue: proto) != nil { self.verificationPromise.succeed(()) } else { - self.verificationPromise.fail(GRPCError.client(.applicationLevelProtocolNegotiationFailed)) + let error = GRPCError.client(.applicationLevelProtocolNegotiationFailed) + self.verificationPromise.fail(error) } } } diff --git a/Sources/GRPCInteroperabilityTests/InteroperabilityTestCases.swift b/Sources/GRPCInteroperabilityTests/InteroperabilityTestCases.swift index 4b5270da4..8f9e33eca 100644 --- a/Sources/GRPCInteroperabilityTests/InteroperabilityTestCases.swift +++ b/Sources/GRPCInteroperabilityTests/InteroperabilityTestCases.swift @@ -639,7 +639,7 @@ class CancelAfterFirstResponse: InteroperabilityTest { func run(using connection: ClientConnection) throws { let client = Grpc_Testing_TestServiceServiceClient(connection: connection) - let promise = client.connection.channel.eventLoop.makePromise(of: Void.self) + let promise = client.connection.eventLoop.makePromise(of: Void.self) let call = client.fullDuplexCall { _ in promise.succeed(()) diff --git a/Sources/GRPCInteroperabilityTests/InteroperabilityTestClientConnection.swift b/Sources/GRPCInteroperabilityTests/InteroperabilityTestClientConnection.swift index f7e7e2e7e..4373c650e 100644 --- a/Sources/GRPCInteroperabilityTests/InteroperabilityTestClientConnection.swift +++ b/Sources/GRPCInteroperabilityTests/InteroperabilityTestClientConnection.swift @@ -31,7 +31,7 @@ public func makeInteroperabilityTestClientConnection( port: Int, eventLoopGroup: EventLoopGroup, useTLS: Bool -) throws -> EventLoopFuture { +) throws -> ClientConnection { var configuration = ClientConnection.Configuration( target: .hostAndPort(host, port), eventLoopGroup: eventLoopGroup) @@ -48,5 +48,5 @@ public func makeInteroperabilityTestClientConnection( configuration.tlsConfiguration = .init(sslContext: context, hostnameOverride: hostOverride) } - return ClientConnection.start(configuration) + return ClientConnection(configuration: configuration) } diff --git a/Sources/GRPCInteroperabilityTestsCLI/main.swift b/Sources/GRPCInteroperabilityTestsCLI/main.swift index 46b7ea644..4c4c0344a 100644 --- a/Sources/GRPCInteroperabilityTestsCLI/main.swift +++ b/Sources/GRPCInteroperabilityTestsCLI/main.swift @@ -160,7 +160,7 @@ let group = Group { group in host: host, port: port, eventLoopGroup: eventLoopGroup, - useTLS: useTLS == "true").wait() + useTLS: useTLS == "true") try runTest(instance, name: testCaseName, connection: connection) } } diff --git a/Sources/GRPCPerformanceTests/main.swift b/Sources/GRPCPerformanceTests/main.swift index 8db22c6bf..d2a86cc6e 100644 --- a/Sources/GRPCPerformanceTests/main.swift +++ b/Sources/GRPCPerformanceTests/main.swift @@ -7,14 +7,12 @@ import Commander struct ConnectionFactory { var configuration: ClientConnection.Configuration - func makeConnection() throws -> EventLoopFuture { - return ClientConnection.start(configuration) + func makeConnection() -> ClientConnection { + return ClientConnection(configuration: self.configuration) } - func makeEchoClient() throws -> EventLoopFuture { - return try self.makeConnection().map { - Echo_EchoServiceClient(connection: $0) - } + func makeEchoClient() -> Echo_EchoServiceClient { + return Echo_EchoServiceClient(connection: self.makeConnection()) } } @@ -42,7 +40,7 @@ class UnaryThroughput: Benchmark { } func setUp() throws { - self.client = try self.factory.makeEchoClient().wait() + self.client = self.factory.makeEchoClient() self.request = String(repeating: "0", count: self.requestLength) } @@ -56,7 +54,7 @@ class UnaryThroughput: Benchmark { client.get(Echo_EchoRequest.with { $0.text = self.request }).response } - try EventLoopFuture.andAllSucceed(requests, on: self.client.connection.channel.eventLoop).wait() + try EventLoopFuture.andAllSucceed(requests, on: self.client.connection.eventLoop).wait() } } @@ -86,8 +84,32 @@ class BidirectionalThroughput: UnaryThroughput { final class ConnectionCreationThroughput: Benchmark { let factory: ConnectionFactory let connections: Int + var createdConnections: [ClientConnection] = [] + + class ConnectionReadinessDelegate: ConnectivityStateDelegate { + let promise: EventLoopPromise + + var ready: EventLoopFuture { + return promise.futureResult + } + + init(promise: EventLoopPromise) { + self.promise = promise + } + + func connectivityStateDidChange(from oldState: ConnectivityState, to newState: ConnectivityState) { + switch newState { + case .ready: + promise.succeed(()) - var createdConnections: [EventLoopFuture] = [] + case .shutdown: + promise.fail(GRPCStatus(code: .unavailable, message: nil)) + + default: + break + } + } + } init(factory: ConnectionFactory, connections: Int) { self.factory = factory @@ -97,20 +119,25 @@ final class ConnectionCreationThroughput: Benchmark { func setUp() throws { } func run() throws { - self.createdConnections = try (0.. ClientConnection { - return try ClientConnection.start(self.makeClientConfiguration(port: port)).wait() + return try ClientConnection(configuration: self.makeClientConfiguration(port: port)) } func makeEchoProvider() -> Echo_EchoProvider { return EchoProvider() } @@ -179,7 +179,6 @@ class EchoTestCaseBase: XCTestCase { override func tearDown() { // Some tests close the channel, so would throw here if called twice. try? self.client.connection.close().wait() - XCTAssertNoThrow(try self.clientEventLoopGroup.syncShutdownGracefully()) self.client = nil self.clientEventLoopGroup = nil diff --git a/Tests/GRPCTests/ClientConnectionBackoffTests.swift b/Tests/GRPCTests/ClientConnectionBackoffTests.swift index 88773e1d4..8f55ede92 100644 --- a/Tests/GRPCTests/ClientConnectionBackoffTests.swift +++ b/Tests/GRPCTests/ClientConnectionBackoffTests.swift @@ -18,32 +18,58 @@ import GRPC import NIO import XCTest +class ConnectivityStateCollectionDelegate: ConnectivityStateDelegate { + var states: [ConnectivityState] = [] + + func clearStates() -> [ConnectivityState] { + defer { + self.states = [] + } + return self.states + } + + func connectivityStateDidChange(from oldState: ConnectivityState, to newState: ConnectivityState) { + self.states.append(newState) + } +} + class ClientConnectionBackoffTests: XCTestCase { let port = 8080 - var client: EventLoopFuture! + var client: ClientConnection! var server: EventLoopFuture! - var group: EventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + var serverGroup: EventLoopGroup! + var clientGroup: EventLoopGroup! + + var stateDelegate = ConnectivityStateCollectionDelegate() + + override func setUp() { + self.serverGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + self.clientGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + } override func tearDown() { if let server = self.server { XCTAssertNoThrow(try server.flatMap { $0.channel.close() }.wait()) } - - // We don't always expect a client (since we deliberately timeout the connection in some cases). - if let client = try? self.client.wait(), client.channel.isActive { - XCTAssertNoThrow(try client.channel.close().wait()) - } - - XCTAssertNoThrow(try self.group.syncShutdownGracefully()) + XCTAssertNoThrow(try? self.serverGroup.syncShutdownGracefully()) + self.server = nil + self.serverGroup = nil + + // We don't always expect a client to be closed cleanly, since in some cases we deliberately + // timeout the connection. + try? self.client.close().wait() + XCTAssertNoThrow(try self.clientGroup.syncShutdownGracefully()) + self.client = nil + self.clientGroup = nil } func makeServer() -> EventLoopFuture { let configuration = Server.Configuration( target: .hostAndPort("localhost", self.port), - eventLoopGroup: self.group, - serviceProviders: []) + eventLoopGroup: self.serverGroup, + serviceProviders: [EchoProvider()]) return Server.start(configuration: configuration) } @@ -51,50 +77,115 @@ class ClientConnectionBackoffTests: XCTestCase { func makeClientConfiguration() -> ClientConnection.Configuration { return .init( target: .hostAndPort("localhost", self.port), - eventLoopGroup: self.group, - connectionBackoff: ConnectionBackoff()) + eventLoopGroup: self.clientGroup, + connectivityStateDelegate: self.stateDelegate, + connectionBackoff: ConnectionBackoff(maximumBackoff: 0.1)) } func makeClientConnection( _ configuration: ClientConnection.Configuration - ) -> EventLoopFuture { - return ClientConnection.start(configuration) + ) -> ClientConnection { + return ClientConnection(configuration: configuration) } func testClientConnectionFailsWithNoBackoff() throws { var configuration = self.makeClientConfiguration() configuration.connectionBackoff = nil + let connectionShutdown = self.expectation(description: "client shutdown") self.client = self.makeClientConnection(configuration) - XCTAssertThrowsError(try self.client.wait()) { error in - XCTAssert(error is NIOConnectionError) + self.client.connectivity.onNext(state: .shutdown) { + connectionShutdown.fulfill() } + + self.wait(for: [connectionShutdown], timeout: 1.0) + XCTAssertEqual(self.stateDelegate.states, [.connecting, .shutdown]) } func testClientEventuallyConnects() throws { - let clientConnected = self.expectation(description: "client connected") - let serverStarted = self.expectation(description: "server started") - // Start the client first. self.client = self.makeClientConnection(self.makeClientConfiguration()) - self.client.assertSuccess(fulfill: clientConnected) - // Sleep for a little bit to make sure we hit the backoff. - Thread.sleep(forTimeInterval: 0.2) + let transientFailure = self.expectation(description: "connection transientFailure") + self.client.connectivity.onNext(state: .transientFailure) { + transientFailure.fulfill() + } + + let connectionReady = self.expectation(description: "connection ready") + self.client.connectivity.onNext(state: .ready) { + connectionReady.fulfill() + } + + self.wait(for: [transientFailure], timeout: 1.0) self.server = self.makeServer() + let serverStarted = self.expectation(description: "server started") self.server.assertSuccess(fulfill: serverStarted) - self.wait(for: [serverStarted, clientConnected], timeout: 2.0, enforceOrder: true) + self.wait(for: [serverStarted, connectionReady], timeout: 2.0, enforceOrder: true) + XCTAssertEqual(self.stateDelegate.states, [.connecting, .transientFailure, .connecting, .ready]) } func testClientEventuallyTimesOut() throws { - var configuration = self.makeClientConfiguration() - configuration.connectionBackoff = ConnectionBackoff(maximumBackoff: 0.1) + let connectionShutdown = self.expectation(description: "connection shutdown") + self.client = self.makeClientConnection(self.makeClientConfiguration()) + self.client.connectivity.onNext(state: .shutdown) { + connectionShutdown.fulfill() + } + self.wait(for: [connectionShutdown], timeout: 1.0) + XCTAssertEqual(self.stateDelegate.states, [.connecting, .transientFailure, .connecting, .shutdown]) + } + + func testClientReconnectsAutomatically() throws { + self.server = self.makeServer() + let server = try self.server.wait() + + let connectionReady = self.expectation(description: "connection ready") + var configuration = self.makeClientConfiguration() + configuration.connectionBackoff!.maximumBackoff = 2.0 self.client = self.makeClientConnection(configuration) - XCTAssertThrowsError(try self.client.wait()) { error in - XCTAssert(error is NIOConnectionError) + self.client.connectivity.onNext(state: .ready) { + connectionReady.fulfill() } + + // Once the connection is ready we can kill the server. + self.wait(for: [connectionReady], timeout: 1.0) + XCTAssertEqual(self.stateDelegate.clearStates(), [.connecting, .ready]) + + try server.close().wait() + try self.serverGroup.syncShutdownGracefully() + self.server = nil + self.serverGroup = nil + + let transientFailure = self.expectation(description: "connection transientFailure") + self.client.connectivity.onNext(state: .transientFailure) { + transientFailure.fulfill() + } + + self.wait(for: [transientFailure], timeout: 1.0) + XCTAssertEqual(self.stateDelegate.clearStates(), [.connecting, .transientFailure]) + + let reconnectionReady = self.expectation(description: "(re)connection ready") + self.client.connectivity.onNext(state: .ready) { + reconnectionReady.fulfill() + } + + let echo = Echo_EchoServiceClient(connection: self.client) + // This should succeed once we get a connection again. + let get = echo.get(.with { $0.text = "hello" }) + + // Start a new server. + self.serverGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + self.server = self.makeServer() + + self.wait(for: [reconnectionReady], timeout: 2.0) + XCTAssertEqual(self.stateDelegate.clearStates(), [.connecting, .ready]) + + // The call should be able to succeed now. + XCTAssertEqual(try get.status.map { $0.code }.wait(), .ok) + + try self.client.close().wait() + XCTAssertEqual(self.stateDelegate.clearStates(), [.shutdown]) } } diff --git a/Tests/GRPCTests/ClientTLSFailureTests.swift b/Tests/GRPCTests/ClientTLSFailureTests.swift index f363dde4b..952a2df1c 100644 --- a/Tests/GRPCTests/ClientTLSFailureTests.swift +++ b/Tests/GRPCTests/ClientTLSFailureTests.swift @@ -20,6 +20,20 @@ import NIO import NIOSSL import XCTest +class ErrorRecordingDelegate: ClientErrorDelegate { + var errors: [Error] = [] + var expectation: XCTestExpectation + + init(expectation: XCTestExpectation) { + self.expectation = expectation + } + + func didCatchError(_ error: Error, file: StaticString, line: Int) { + self.errors.append(error) + self.expectation.fulfill() + } +} + class ClientTLSFailureTests: XCTestCase { let defaultServerTLSConfiguration = TLSConfiguration.forServer( certificateChain: [.certificate(SampleCertificate.server.certificate)], @@ -39,19 +53,26 @@ class ClientTLSFailureTests: XCTestCase { var server: Server! var port: Int! - func makeClientConnection( - configuration: TLSConfiguration, + func makeClientConfiguration( + tls: TLSConfiguration, hostOverride: String? = SampleCertificate.server.commonName - ) throws -> EventLoopFuture { - let context = try NIOSSLContext(configuration: configuration) - let clientConfiguration = ClientConnection.Configuration( + ) throws -> ClientConnection.Configuration { + return ClientConnection.Configuration( target: .hostAndPort("localhost", self.port), eventLoopGroup: self.clientEventLoopGroup, - tlsConfiguration: ClientConnection.TLSConfiguration( - sslContext: context, - hostnameOverride: hostOverride)) + tlsConfiguration: try .init( + sslContext: NIOSSLContext(configuration: tls), + hostnameOverride: hostOverride + ) + ) + } - return ClientConnection.start(clientConfiguration) + func makeClientTLSConfiguration( + tls: TLSConfiguration, + hostOverride: String? = SampleCertificate.server.commonName + ) throws -> ClientConnection.TLSConfiguration { + let context = try NIOSSLContext(configuration: tls) + return .init(sslContext: context, hostnameOverride: hostOverride) } func makeClientConnectionExpectation() -> XCTestExpectation { @@ -90,51 +111,77 @@ class ClientTLSFailureTests: XCTestCase { } func testClientConnectionFailsWhenProtocolCanNotBeNegotiated() throws { - var configuration = defaultClientTLSConfiguration - configuration.applicationProtocols = ["not-h2", "not-grpc-ext"] + let shutdownExpectation = self.expectation(description: "client shutdown") + let errorExpectation = self.expectation(description: "error") - let connection = try self.makeClientConnection(configuration: configuration) - let connectionExpectation = self.makeClientConnectionExpectation() + var tls = defaultClientTLSConfiguration + tls.applicationProtocols = ["not-h2", "not-grpc-ext"] + var configuration = try self.makeClientConfiguration(tls: tls) - connection.assertError(fulfill: connectionExpectation) { error in - let clientError = (error as? GRPCError)?.wrappedError as? GRPCClientError - XCTAssertEqual(clientError, .applicationLevelProtocolNegotiationFailed) + let errorRecorder = ErrorRecordingDelegate(expectation: errorExpectation) + configuration.errorDelegate = errorRecorder + + let connection = ClientConnection(configuration: configuration) + connection.connectivity.onNext(state: .shutdown) { + shutdownExpectation.fulfill() } - self.wait(for: [connectionExpectation], timeout: self.defaultTestTimeout) + self.wait(for: [shutdownExpectation, errorExpectation], timeout: self.defaultTestTimeout) + + let clientErrors = errorRecorder.errors.compactMap { $0 as? GRPCClientError } + XCTAssertEqual(clientErrors, [.applicationLevelProtocolNegotiationFailed]) } func testClientConnectionFailsWhenServerIsUnknown() throws { - var configuration = defaultClientTLSConfiguration - configuration.trustRoots = .certificates([]) + let shutdownExpectation = self.expectation(description: "client shutdown") + let errorExpectation = self.expectation(description: "error") - let connection = try self.makeClientConnection(configuration: configuration) - let connectionExpectation = self.makeClientConnectionExpectation() + var tls = defaultClientTLSConfiguration + tls.trustRoots = .certificates([]) + var configuration = try self.makeClientConfiguration(tls: tls) - connection.assertError(fulfill: connectionExpectation) { error in - guard case .some(.handshakeFailed(.sslError)) = error as? NIOSSLError else { - XCTFail("Expected NIOSSLError.handshakeFailed(BoringSSL.sslError) but got \(error)") - return - } + let errorRecorder = ErrorRecordingDelegate(expectation: errorExpectation) + configuration.errorDelegate = errorRecorder + + let connection = ClientConnection(configuration: configuration) + connection.connectivity.onNext(state: .shutdown) { + shutdownExpectation.fulfill() } - self.wait(for: [connectionExpectation], timeout: self.defaultTestTimeout) + self.wait(for: [shutdownExpectation, errorExpectation], timeout: self.defaultTestTimeout) + + if let nioSSLError = errorRecorder.errors.first as? NIOSSLError, + case .handshakeFailed(.sslError) = nioSSLError { + // Expected case. + } else { + XCTFail("Expected NIOSSLError.handshakeFailed(BoringSSL.sslError)") + } } func testClientConnectionFailsWhenHostnameIsNotValid() throws { - let connection = try self.makeClientConnection( - configuration: self.defaultClientTLSConfiguration, - hostOverride: "not-the-server-hostname") + let shutdownExpectation = self.expectation(description: "client shutdown") + let errorExpectation = self.expectation(description: "error") + + var configuration = try self.makeClientConfiguration( + tls: self.defaultClientTLSConfiguration, + hostOverride: "not-the-server-hostname" + ) - let connectionExpectation = self.makeClientConnectionExpectation() + let errorRecorder = ErrorRecordingDelegate(expectation: errorExpectation) + configuration.errorDelegate = errorRecorder - connection.assertError(fulfill: connectionExpectation) { error in - guard case .some(.unableToValidateCertificate) = error as? NIOSSLError else { - XCTFail("Expected NIOSSLError.unableToValidateCertificate but got \(error)") - return - } + let connection = ClientConnection(configuration: configuration) + connection.connectivity.onNext(state: .shutdown) { + shutdownExpectation.fulfill() } - self.wait(for: [connectionExpectation], timeout: self.defaultTestTimeout) + self.wait(for: [shutdownExpectation, errorExpectation], timeout: self.defaultTestTimeout) + + if let nioSSLError = errorRecorder.errors.first as? NIOSSLError, + case .unableToValidateCertificate = nioSSLError { + // Expected case. + } else { + XCTFail("Expected NIOSSLError.unableToValidateCertificate") + } } } diff --git a/Tests/GRPCTests/ConnectivityStateMonitorTests.swift b/Tests/GRPCTests/ConnectivityStateMonitorTests.swift new file mode 100644 index 000000000..6ed5866be --- /dev/null +++ b/Tests/GRPCTests/ConnectivityStateMonitorTests.swift @@ -0,0 +1,93 @@ +/* + * Copyright 2019, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +@testable import GRPC +import XCTest + +class ConnectivityStateMonitorTests: XCTestCase { + var monitor = ConnectivityStateMonitor(delegate: nil) + + // Ensure `.idle` isn't first since it is the initial state and we only trigger callbacks + // when the state changes, not when the state is set. + let states: [ConnectivityState] = [.connecting, .ready, .transientFailure, .shutdown, .idle] + + func testDelegateOnlyCalledForChanges() { + let recorder = StateRecordingDelegate() + self.monitor.delegate = recorder + + self.monitor.state = .connecting + self.monitor.state = .ready + self.monitor.state = .ready + self.monitor.state = .shutdown + + XCTAssertEqual(recorder.states, [.connecting, .ready, .shutdown]) + } + + func testOnNextIsOnlyInvokedOnce() { + for state in self.states { + let currentState = self.monitor.state + + var calls = 0 + self.monitor.onNext(state: state) { + calls += 1 + } + + // Trigger the callback. + self.monitor.state = state + XCTAssertEqual(calls, 1) + + // Go back and forth; the callback should not be triggered again. + self.monitor.state = currentState + self.monitor.state = state + XCTAssertEqual(calls, 1) + } + } + + func testRemovingCallbacks() { + for state in self.states { + self.monitor.onNext(state: state) { + XCTFail("Callback unexpectedly called") + } + + self.monitor.onNext(state: state, callback: nil) + self.monitor.state = state + } + } + + func testMultipleCallbacksRegistered() { + var calls = 0 + self.states.forEach { + self.monitor.onNext(state: $0) { + calls += 1 + } + } + + self.states.forEach { + self.monitor.state = $0 + } + + XCTAssertEqual(calls, self.states.count) + } +} + +extension ConnectivityStateMonitorTests { + /// A `ConnectivityStateDelegate` which each new state. + class StateRecordingDelegate: ConnectivityStateDelegate { + var states: [ConnectivityState] = [] + func connectivityStateDidChange(from oldState: ConnectivityState, to newState: ConnectivityState) { + self.states.append(newState) + } + } +} diff --git a/Tests/GRPCTests/GRPCInteroperabilityTests.swift b/Tests/GRPCTests/GRPCInteroperabilityTests.swift index c580297ef..840dbec16 100644 --- a/Tests/GRPCTests/GRPCInteroperabilityTests.swift +++ b/Tests/GRPCTests/GRPCInteroperabilityTests.swift @@ -51,7 +51,7 @@ class GRPCInsecureInteroperabilityTests: XCTestCase { port: serverPort, eventLoopGroup: self.clientEventLoopGroup, useTLS: self.useTLS - ).wait() + ) } override func tearDown() { diff --git a/Tests/GRPCTests/XCTestManifests.swift b/Tests/GRPCTests/XCTestManifests.swift index 173be35e7..d21c2dd66 100644 --- a/Tests/GRPCTests/XCTestManifests.swift +++ b/Tests/GRPCTests/XCTestManifests.swift @@ -48,6 +48,7 @@ extension ClientConnectionBackoffTests { ("testClientConnectionFailsWithNoBackoff", testClientConnectionFailsWithNoBackoff), ("testClientEventuallyConnects", testClientEventuallyConnects), ("testClientEventuallyTimesOut", testClientEventuallyTimesOut), + ("testClientReconnectsAutomatically", testClientReconnectsAutomatically), ] } @@ -101,6 +102,18 @@ extension ConnectionBackoffTests { ] } +extension ConnectivityStateMonitorTests { + // DO NOT MODIFY: This is autogenerated, use: + // `swift test --generate-linuxmain` + // to regenerate. + static let __allTests__ConnectivityStateMonitorTests = [ + ("testDelegateOnlyCalledForChanges", testDelegateOnlyCalledForChanges), + ("testMultipleCallbacksRegistered", testMultipleCallbacksRegistered), + ("testOnNextIsOnlyInvokedOnce", testOnNextIsOnlyInvokedOnce), + ("testRemovingCallbacks", testRemovingCallbacks), + ] +} + extension FunctionalTestsAnonymousClient { // DO NOT MODIFY: This is autogenerated, use: // `swift test --generate-linuxmain` @@ -411,6 +424,7 @@ public func __allTests() -> [XCTestCaseEntry] { testCase(ClientThrowingWhenServerReturningErrorTests.__allTests__ClientThrowingWhenServerReturningErrorTests), testCase(ClientTimeoutTests.__allTests__ClientTimeoutTests), testCase(ConnectionBackoffTests.__allTests__ConnectionBackoffTests), + testCase(ConnectivityStateMonitorTests.__allTests__ConnectivityStateMonitorTests), testCase(FunctionalTestsAnonymousClient.__allTests__FunctionalTestsAnonymousClient), testCase(FunctionalTestsAnonymousClientNIOTS.__allTests__FunctionalTestsAnonymousClientNIOTS), testCase(FunctionalTestsInsecureTransport.__allTests__FunctionalTestsInsecureTransport),