diff --git a/Sources/GRPCCore/Call/Server/RPCRouter.swift b/Sources/GRPCCore/Call/Server/RPCRouter.swift index ab0b615a8..4bfe57c3c 100644 --- a/Sources/GRPCCore/Call/Server/RPCRouter.swift +++ b/Sources/GRPCCore/Call/Server/RPCRouter.swift @@ -133,3 +133,27 @@ public struct RPCRouter: Sendable { return self.handlers.removeValue(forKey: descriptor) != nil } } + +@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *) +extension RPCRouter { + internal func handle( + stream: RPCStream, RPCWriter.Closable>, + interceptors: [any ServerInterceptor] + ) async { + if let handler = self.handlers[stream.descriptor] { + await handler.handle(stream: stream, interceptors: interceptors) + } else { + // If this throws then the stream must be closed which we can't do anything about, so ignore + // any error. + try? await stream.outbound.write(.status(.rpcNotImplemented, [:])) + stream.outbound.finish() + } + } +} + +extension Status { + fileprivate static let rpcNotImplemented = Status( + code: .unimplemented, + message: "Requested RPC isn't implemented by this server." + ) +} diff --git a/Sources/GRPCCore/Call/Server/ServerRequest.swift b/Sources/GRPCCore/Call/Server/ServerRequest.swift index 31bf0e1cf..dafbee3d7 100644 --- a/Sources/GRPCCore/Call/Server/ServerRequest.swift +++ b/Sources/GRPCCore/Call/Server/ServerRequest.swift @@ -72,8 +72,24 @@ extension ServerRequest { @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) extension ServerRequest.Stream { - @_spi(Testing) public init(single request: ServerRequest.Single) { self.init(metadata: request.metadata, messages: .one(request.message)) } } + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +extension ServerRequest.Single { + public init(stream request: ServerRequest.Stream) async throws { + var iterator = request.messages.makeAsyncIterator() + + guard let message = try await iterator.next() else { + throw RPCError(code: .internalError, message: "Empty stream.") + } + + guard try await iterator.next() == nil else { + throw RPCError(code: .internalError, message: "Too many messages.") + } + + self = ServerRequest.Single(metadata: request.metadata, message: message) + } +} diff --git a/Sources/GRPCCore/Call/Server/ServerResponse.swift b/Sources/GRPCCore/Call/Server/ServerResponse.swift index 5fbdb43ec..a0b516815 100644 --- a/Sources/GRPCCore/Call/Server/ServerResponse.swift +++ b/Sources/GRPCCore/Call/Server/ServerResponse.swift @@ -327,7 +327,6 @@ extension ServerResponse.Stream { @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) extension ServerResponse.Stream { - @_spi(Testing) public init(single response: ServerResponse.Single) { switch response.accepted { case .success(let contents): diff --git a/Sources/GRPCCore/Server.swift b/Sources/GRPCCore/Server.swift new file mode 100644 index 000000000..43a48041f --- /dev/null +++ b/Sources/GRPCCore/Server.swift @@ -0,0 +1,479 @@ +/* + * Copyright 2023, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import Atomics + +/// A gRPC server. +/// +/// The server accepts connections from clients and listens on each connection for new streams +/// which are initiated by the client. Each stream maps to a single RPC. The server routes accepted +/// streams to a service to handle the RPC or rejects them with an appropriate error if no service +/// can handle the RPC. +/// +/// A ``Server`` may listen with multiple transports (for example, HTTP/2 and in-process) and route +/// requests from each transport to the same service instance. You can also use "interceptors", +/// to implement cross-cutting logic which apply to all accepted RPCs. Example uses of interceptors +/// include request filtering, authentication, and logging. Once requests have been intercepted +/// they are passed to a handler which in turn returns a response to send back to the client. +/// +/// ## Creating and configuring a server +/// +/// The following example demonstrates how to create and configure a server. +/// +/// ```swift +/// let server = Server() +/// +/// // Create and add an in-process transport. +/// let inProcessTransport = InProcessServerTransport() +/// server.transports.add(inProcessTransport) +/// +/// // Create and register the 'Greeter' and 'Echo' services. +/// server.services.register(GreeterService()) +/// server.services.register(EchoService()) +/// +/// // Create and add some interceptors. +/// server.interceptors.add(StatsRecordingServerInterceptors()) +/// ``` +/// +/// ## Starting and stopping the server +/// +/// Once you have configured the server call ``run()`` to start it. Calling ``run()`` starts each +/// of the server's transports. A ``ServerError`` is thrown if any of the transports can't be +/// started. +/// +/// ```swift +/// // Start running the server. +/// try await server.run() +/// ``` +/// +/// The ``run()`` method won't return until the server has finished handling all requests. You can +/// signal to the server that it should stop accepting new requests by calling ``stopListening()``. +/// This allows the server to drain existing requests gracefully. To stop the server more abruptly +/// you can cancel the task running your server. If your application requires additional resources +/// that need their lifecycles managed you should consider using [Swift Service +/// Lifecycle](https://github.com/swift-server/swift-service-lifecycle). +@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *) +public final class Server: Sendable { + typealias Stream = RPCStream + + /// A collection of ``ServerTransport`` implementations that the server uses to listen + /// for new requests. + public var transports: Transports { + get { + self.storage.withLockedValue { $0.transports } + } + set { + self.storage.withLockedValue { $0.transports = newValue } + } + } + + /// The services registered which the server is serving. + public var services: Services { + get { + self.storage.withLockedValue { $0.services } + } + set { + self.storage.withLockedValue { $0.services = newValue } + } + } + + /// A collection of ``ServerInterceptor`` implementations which are applied to all accepted + /// RPCs. + /// + /// RPCs are intercepted in the order that interceptors are added. That is, a request received + /// from the client will first be intercepted by the first added interceptor followed by the + /// second, and so on. + public var interceptors: Interceptors { + get { + self.storage.withLockedValue { $0.interceptors } + } + set { + self.storage.withLockedValue { $0.interceptors = newValue } + } + } + + /// Underlying storage for the server. + private struct Storage { + var transports: Transports + var services: Services + var interceptors: Interceptors + var state: State + + init() { + self.transports = Transports() + self.services = Services() + self.interceptors = Interceptors() + self.state = .notStarted + } + } + + private let storage: LockedValueBox + + /// The state of the server. + private enum State { + /// The server hasn't been started yet. Can transition to `starting` or `stopped`. + case notStarted + /// The server is starting but isn't accepting requests yet. Can transition to `running` + /// and `stopping`. + case starting + /// The server is running and accepting RPCs. Can transition to `stopping`. + case running + /// The server is stopping and no new RPCs will be accepted. Existing RPCs may run to + /// completion. May transition to `stopped`. + case stopping + /// The server has stopped, no RPCs are in flight and no more will be accepted. This state + /// is terminal. + case stopped + } + + /// Creates a new server with no resources. + /// + /// You can add resources to the server via ``transports-swift.property``, + /// ``services-swift.property``, and ``interceptors-swift.property`` and start the server by + /// calling ``run()``. Any changes to resources after ``run()`` has been called will be ignored. + public init() { + self.storage = LockedValueBox(Storage()) + } + + /// Starts the server and runs until all registered transports have closed. + /// + /// No RPCs are processed until all transports are listening. If a transport fails to start + /// listening then all open transports are closed and a ``ServerError`` is thrown. + /// + /// This function returns when all transports have stopped listening and all requests have been + /// handled. You can signal to transports that they should stop listening by calling + /// ``stopListening()``. The server will continue to process existing requests. + /// + /// To stop the server more abruptly you can cancel the task that this function is running in. + /// + /// You must register all resources you wish to use with the server before calling this function + /// as changes made after calling ``run()`` won't be reflected. + /// + /// - Note: You can only call this function once, repeated calls will result in a + /// ``ServerError`` being thrown. + /// - Important: You must register at least one transport by calling + /// ``Transports-swift.struct/add(_:)`` before calling this method. + public func run() async throws { + let (transports, router, interceptors) = try self.storage.withLockedValue { storage in + switch storage.state { + case .notStarted: + storage.state = .starting + return (storage.transports, storage.services.router, storage.interceptors) + + case .starting, .running: + throw ServerError( + code: .serverIsAlreadyRunning, + message: "The server is already running and can only be started once." + ) + + case .stopping, .stopped: + throw ServerError( + code: .serverIsStopped, + message: "The server has stopped and can only be started once." + ) + } + } + + // When we exit this function we must have stopped. + defer { + self.storage.withLockedValue { $0.state = .stopped } + } + + if transports.values.isEmpty { + throw ServerError( + code: .noTransportsConfigured, + message: """ + Can't start server, no transports are configured. You must add at least one transport \ + to the server using 'transports.add(_:)' before calling 'run()'. + """ + ) + } + + var listeners: [RPCAsyncSequence] = [] + listeners.reserveCapacity(transports.values.count) + + for transport in transports.values { + do { + let listener = try await transport.listen() + listeners.append(listener) + } catch let cause { + // Failed to start, so start stopping. + self.storage.withLockedValue { $0.state = .stopping } + // Some listeners may have started and have streams which need closing. + await Self.rejectRequests(listeners, transports: transports) + + throw ServerError( + code: .failedToStartTransport, + message: """ + Server didn't start because the '\(type(of: transport))' transport threw an error \ + while starting. + """, + cause: cause + ) + } + } + + // May have been told to stop listening while starting the transports. + let isStopping = self.storage.withLockedValue { storage in + switch storage.state { + case .notStarted, .running, .stopped: + fatalError("Invalid state") + + case .starting: + storage.state = .running + return false + + case .stopping: + return true + } + } + + // If the server is stopping then notify the transport and then consume them: there may be + // streams opened at a lower level (e.g. HTTP/2) which are already open and need to be consumed. + if isStopping { + await Self.rejectRequests(listeners, transports: transports) + } else { + await Self.handleRequests(listeners, router: router, interceptors: interceptors) + } + } + + private static func rejectRequests( + _ listeners: [RPCAsyncSequence], + transports: Transports + ) async { + // Tell the active listeners to stop listening. + for transport in transports.values.prefix(listeners.count) { + transport.stopListening() + } + + // Drain any open streams on active listeners. + await withTaskGroup(of: Void.self) { group in + let unavailable = Status( + code: .unavailable, + message: "The server isn't ready to accept requests." + ) + + for listener in listeners { + do { + for try await stream in listener { + group.addTask { + try? await stream.outbound.write(.status(unavailable, [:])) + stream.outbound.finish() + } + } + } catch { + // Suppress any errors, the original error from the transport which failed to start + // should be thrown. + } + } + } + } + + private static func handleRequests( + _ listeners: [RPCAsyncSequence], + router: RPCRouter, + interceptors: Interceptors + ) async { + #if swift(>=5.9) + if #available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) { + await Self.handleRequestsInDiscardingTaskGroup( + listeners, + router: router, + interceptors: interceptors + ) + } else { + await Self.handleRequestsInTaskGroup(listeners, router: router, interceptors: interceptors) + } + #else + await Self.handleRequestsInTaskGroup(listeners, router: router, interceptors: interceptors) + #endif + } + + #if swift(>=5.9) + @available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) + private static func handleRequestsInDiscardingTaskGroup( + _ listeners: [RPCAsyncSequence], + router: RPCRouter, + interceptors: Interceptors + ) async { + await withDiscardingTaskGroup { group in + for listener in listeners { + group.addTask { + await withDiscardingTaskGroup { subGroup in + do { + for try await stream in listener { + subGroup.addTask { + await router.handle(stream: stream, interceptors: interceptors.values) + } + } + } catch { + // If the listener threw then the connection must be broken, cancel all work. + subGroup.cancelAll() + } + } + } + } + } + } + #endif + + private static func handleRequestsInTaskGroup( + _ listeners: [RPCAsyncSequence], + router: RPCRouter, + interceptors: Interceptors + ) async { + // If the discarding task group isn't available then fall back to using a regular task group + // with a limit on subtasks. Most servers will use an HTTP/2 based transport, most + // implementations limit connections to 100 concurrent streams. A limit of 4096 gives the server + // scope to handle nearly 41 completely saturated connections. + let maxConcurrentSubTasks = 4096 + let tasks = ManagedAtomic(0) + + await withTaskGroup(of: Void.self) { group in + for listener in listeners { + group.addTask { + await withTaskGroup(of: Void.self) { subGroup in + do { + for try await stream in listener { + let taskCount = tasks.wrappingIncrementThenLoad(ordering: .sequentiallyConsistent) + if taskCount >= maxConcurrentSubTasks { + _ = await subGroup.next() + tasks.wrappingDecrement(ordering: .sequentiallyConsistent) + } + + subGroup.addTask { + await router.handle(stream: stream, interceptors: interceptors.values) + } + } + } catch { + // If the listener threw then the connection must be broken, cancel all work. + subGroup.cancelAll() + } + } + } + } + } + } + + /// Signal to the server that it should stop listening for new requests. + /// + /// By calling this function you indicate to clients that they mustn't start new requests + /// against this server. Once the server has processed all requests the ``run()`` method returns. + /// + /// Calling this on a server which is already stopping or has stopped has no effect. + public func stopListening() { + let transports = self.storage.withLockedValue { storage in + let transports: Transports? + + switch storage.state { + case .notStarted: + storage.state = .stopped + transports = nil + case .starting: + storage.state = .stopping + transports = nil + case .running: + storage.state = .stopping + transports = storage.transports + case .stopping: + transports = nil + case .stopped: + transports = nil + } + + return transports + } + + if let transports = transports?.values { + for transport in transports { + transport.stopListening() + } + } + } +} + +@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *) +extension Server { + /// The transports which provide a bidirectional communication channel with clients. + /// + /// You can add a new transport by calling ``add(_:)``. + public struct Transports: Sendable { + private(set) var values: [any (ServerTransport & Sendable)] = [] + + /// Add a transport to the server. + /// + /// - Parameter transport: The transport to add. + public mutating func add(_ transport: some (ServerTransport & Sendable)) { + self.values.append(transport) + } + } + + /// The services registered with this server. + /// + /// You can register services by calling ``register(_:)`` or by manually adding handlers for + /// methods to the ``router``. + public struct Services: Sendable { + /// The router storing handlers for known methods. + public var router = RPCRouter() + + /// Registers service methods with the ``router``. + /// + /// - Parameter service: The service to register with the ``router``. + public mutating func register(_ service: some RegistrableRPCService) { + service.registerMethods(with: &self.router) + } + } + + /// A collection of interceptors providing cross-cutting functionality to each accepted RPC. + public struct Interceptors: Sendable { + private(set) var values: [any ServerInterceptor] = [] + + /// Add an interceptor to the server. + /// + /// The order in which interceptors are added reflects the order in which they are called. The + /// first interceptor added will be the first interceptor to intercept each request. The last + /// interceptor added will be the final interceptor to intercept each request before calling + /// the appropriate handler. + /// + /// - Parameter interceptor: The interceptor to add. + public mutating func add(_ interceptor: some ServerInterceptor) { + self.values.append(interceptor) + } + } +} + +@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *) +extension Server.Transports: CustomStringConvertible { + public var description: String { + return String(describing: self.values) + } +} + +@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *) +extension Server.Services: CustomStringConvertible { + public var description: String { + // List the fully qualified all methods ordered by service and then method + let rpcs = self.router.methods.map { $0.fullyQualifiedMethod }.sorted() + return String(describing: rpcs) + } +} + +@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *) +extension Server.Interceptors: CustomStringConvertible { + public var description: String { + return String(describing: self.values.map { String(describing: type(of: $0)) }) + } +} diff --git a/Sources/GRPCCore/ServerError.swift b/Sources/GRPCCore/ServerError.swift new file mode 100644 index 000000000..45e4f4e95 --- /dev/null +++ b/Sources/GRPCCore/ServerError.swift @@ -0,0 +1,148 @@ +/* + * Copyright 2023, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/// A runtime error thrown by the server. +/// +/// In contrast to ``RPCError``, the ``ServerError`` represents errors which happen at a scope +/// wider than an individual RPC. For example, attempting to start a server which is already +/// stopped would result in a ``ServerError``. +public struct ServerError: Error, Hashable, @unchecked Sendable { + private var storage: Storage + + // Ensures the underlying storage is unique. + private mutating func ensureUniqueStorage() { + if !isKnownUniquelyReferenced(&self.storage) { + self.storage = self.storage.copy() + } + } + + /// The code indicating the domain of the error. + public var code: Code { + get { self.storage.code } + set { + self.ensureUniqueStorage() + self.storage.code = newValue + } + } + + /// A message providing more details about the error which may include details specific to this + /// instance of the error. + public var message: String { + get { self.storage.message } + set { + self.ensureUniqueStorage() + self.storage.message = newValue + } + } + + /// The original error which led to this error being thrown. + public var cause: Error? { + get { self.storage.cause } + set { + self.ensureUniqueStorage() + self.storage.cause = newValue + } + } + + /// Creates a new error. + /// + /// - Parameters: + /// - code: The error code. + /// - message: A description of the error. + /// - cause: The original error which led to this error being thrown. + public init(code: Code, message: String, cause: Error? = nil) { + self.storage = Storage(code: code, message: message, cause: cause) + } +} + +extension ServerError: CustomStringConvertible { + public var description: String { + if let cause = self.cause { + return "\(self.code): \"\(self.message)\" (cause: \"\(cause)\")" + } else { + return "\(self.code): \"\(self.message)\"" + } + } +} + +extension ServerError { + private final class Storage: Hashable { + var code: Code + var message: String + var cause: Error? + + init(code: Code, message: String, cause: Error?) { + self.code = code + self.message = message + self.cause = cause + } + + func copy() -> Storage { + return Storage(code: self.code, message: self.message, cause: self.cause) + } + + func hash(into hasher: inout Hasher) { + hasher.combine(self.code) + hasher.combine(self.message) + } + + static func == (lhs: Storage, rhs: Storage) -> Bool { + return lhs.code == rhs.code && lhs.message == rhs.message + } + } +} + +extension ServerError { + public struct Code: Hashable, Sendable { + private enum Value { + case serverIsAlreadyRunning + case serverIsStopped + case failedToStartTransport + case noTransportsConfigured + } + + private var value: Value + private init(_ value: Value) { + self.value = value + } + + /// At attempt to start the server was made but it is already running. + public static var serverIsAlreadyRunning: Self { + Self(.serverIsAlreadyRunning) + } + + /// At attempt to start the server was made but it has already stopped. + public static var serverIsStopped: Self { + Self(.serverIsStopped) + } + + /// The server couldn't be started because a transport failed to start. + public static var failedToStartTransport: Self { + Self(.failedToStartTransport) + } + + /// The server couldn't be started because no transports were configured. + public static var noTransportsConfigured: Self { + Self(.noTransportsConfigured) + } + } +} + +extension ServerError.Code: CustomStringConvertible { + public var description: String { + String(describing: self.value) + } +} diff --git a/Sources/GRPCCore/Transport/ServerTransport.swift b/Sources/GRPCCore/Transport/ServerTransport.swift index e538c68f7..3c3dbc45c 100644 --- a/Sources/GRPCCore/Transport/ServerTransport.swift +++ b/Sources/GRPCCore/Transport/ServerTransport.swift @@ -16,8 +16,8 @@ @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) public protocol ServerTransport { - associatedtype Inbound: (AsyncSequence & Sendable) where Inbound.Element == RPCRequestPart - associatedtype Outbound: ClosableRPCWriterProtocol + typealias Inbound = RPCAsyncSequence + typealias Outbound = RPCWriter.Closable /// Starts the transport and returns a sequence of accepted streams to handle. /// diff --git a/Tests/GRPCCoreTests/ServerErrorTests.swift b/Tests/GRPCCoreTests/ServerErrorTests.swift new file mode 100644 index 000000000..afe2b8e2a --- /dev/null +++ b/Tests/GRPCCoreTests/ServerErrorTests.swift @@ -0,0 +1,53 @@ +/* + * Copyright 2023, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import GRPCCore +import XCTest + +final class ServerErrorTests: XCTestCase { + func testCopyOnWrite() { + // ServerError has a heap based storage, so check CoW semantics are correctly implemented. + let error1 = ServerError(code: .failedToStartTransport, message: "Failed to start transport") + var error2 = error1 + error2.code = .serverIsAlreadyRunning + XCTAssertEqual(error1.code, .failedToStartTransport) + XCTAssertEqual(error2.code, .serverIsAlreadyRunning) + + var error3 = error1 + error3.message = "foo" + XCTAssertEqual(error1.message, "Failed to start transport") + XCTAssertEqual(error3.message, "foo") + + var error4 = error1 + error4.cause = CancellationError() + XCTAssertNil(error1.cause) + XCTAssert(error4.cause is CancellationError) + } + + func testCustomStringConvertible() { + let error1 = ServerError(code: .failedToStartTransport, message: "Failed to start transport") + XCTAssertDescription(error1, #"failedToStartTransport: "Failed to start transport""#) + + let error2 = ServerError( + code: .failedToStartTransport, + message: "Failed to start transport", + cause: CancellationError() + ) + XCTAssertDescription( + error2, + #"failedToStartTransport: "Failed to start transport" (cause: "CancellationError()")"# + ) + } +} diff --git a/Tests/GRPCCoreTests/ServerTests.swift b/Tests/GRPCCoreTests/ServerTests.swift new file mode 100644 index 000000000..2fbc2f9be --- /dev/null +++ b/Tests/GRPCCoreTests/ServerTests.swift @@ -0,0 +1,455 @@ +/* + * Copyright 2023, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import Atomics +import GRPCCore +import XCTest + +@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *) +final class ServerTests: XCTestCase { + func makeInProcessPair() -> (client: InProcessClientTransport, server: InProcessServerTransport) { + let server = InProcessServerTransport() + let client = InProcessClientTransport( + server: server, + executionConfigurations: ClientRPCExecutionConfigurationCollection() + ) + + return (client, server) + } + + func withInProcessClientConnectedToServer( + services: [any RegistrableRPCService], + interceptors: [any ServerInterceptor] = [], + _ body: (InProcessClientTransport, Server) async throws -> Void + ) async throws { + let inProcess = self.makeInProcessPair() + let server = Server() + server.transports.add(inProcess.server) + + for service in services { + server.services.register(service) + } + + for interceptor in interceptors { + server.interceptors.add(interceptor) + } + + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + try await server.run() + } + + group.addTask { + try await inProcess.client.connect(lazily: true) + } + + try await body(inProcess.client, server) + inProcess.client.close() + server.stopListening() + } + + } + + func testServerHandlesUnary() async throws { + try await self.withInProcessClientConnectedToServer(services: [BinaryEcho()]) { client, _ in + try await client.withStream(descriptor: BinaryEcho.Methods.get) { stream in + try await stream.outbound.write(.metadata([:])) + try await stream.outbound.write(.message([3, 1, 4, 1, 5])) + stream.outbound.finish() + + var responseParts = stream.inbound.makeAsyncIterator() + let metadata = try await responseParts.next() + XCTAssertMetadata(metadata) + + let message = try await responseParts.next() + XCTAssertMessage(message) { + XCTAssertEqual($0, [3, 1, 4, 1, 5]) + } + + let status = try await responseParts.next() + XCTAssertStatus(status) { status, _ in + XCTAssertEqual(status.code, .ok) + } + } + } + } + + func testServerHandlesClientStreaming() async throws { + try await self.withInProcessClientConnectedToServer(services: [BinaryEcho()]) { client, _ in + try await client.withStream(descriptor: BinaryEcho.Methods.collect) { stream in + try await stream.outbound.write(.metadata([:])) + try await stream.outbound.write(.message([3])) + try await stream.outbound.write(.message([1])) + try await stream.outbound.write(.message([4])) + try await stream.outbound.write(.message([1])) + try await stream.outbound.write(.message([5])) + stream.outbound.finish() + + var responseParts = stream.inbound.makeAsyncIterator() + let metadata = try await responseParts.next() + XCTAssertMetadata(metadata) + + let message = try await responseParts.next() + XCTAssertMessage(message) { + XCTAssertEqual($0, [3, 1, 4, 1, 5]) + } + + let status = try await responseParts.next() + XCTAssertStatus(status) { status, _ in + XCTAssertEqual(status.code, .ok) + } + } + } + } + + func testServerHandlesServerStreaming() async throws { + try await self.withInProcessClientConnectedToServer(services: [BinaryEcho()]) { client, _ in + try await client.withStream(descriptor: BinaryEcho.Methods.expand) { stream in + try await stream.outbound.write(.metadata([:])) + try await stream.outbound.write(.message([3, 1, 4, 1, 5])) + stream.outbound.finish() + + var responseParts = stream.inbound.makeAsyncIterator() + let metadata = try await responseParts.next() + XCTAssertMetadata(metadata) + + for byte in [3, 1, 4, 1, 5] as [UInt8] { + let message = try await responseParts.next() + XCTAssertMessage(message) { + XCTAssertEqual($0, [byte]) + } + } + + let status = try await responseParts.next() + XCTAssertStatus(status) { status, _ in + XCTAssertEqual(status.code, .ok) + } + } + } + } + + func testServerHandlesBidirectionalStreaming() async throws { + try await self.withInProcessClientConnectedToServer(services: [BinaryEcho()]) { client, _ in + try await client.withStream(descriptor: BinaryEcho.Methods.update) { stream in + try await stream.outbound.write(.metadata([:])) + for byte in [3, 1, 4, 1, 5] as [UInt8] { + try await stream.outbound.write(.message([byte])) + } + stream.outbound.finish() + + var responseParts = stream.inbound.makeAsyncIterator() + let metadata = try await responseParts.next() + XCTAssertMetadata(metadata) + + for byte in [3, 1, 4, 1, 5] as [UInt8] { + let message = try await responseParts.next() + XCTAssertMessage(message) { + XCTAssertEqual($0, [byte]) + } + } + + let status = try await responseParts.next() + XCTAssertStatus(status) { status, _ in + XCTAssertEqual(status.code, .ok) + } + } + } + } + + func testUnimplementedMethod() async throws { + try await self.withInProcessClientConnectedToServer(services: [BinaryEcho()]) { client, _ in + try await client.withStream( + descriptor: MethodDescriptor(service: "not", method: "implemented") + ) { stream in + try await stream.outbound.write(.metadata([:])) + stream.outbound.finish() + + var responseParts = stream.inbound.makeAsyncIterator() + let status = try await responseParts.next() + XCTAssertStatus(status) { status, _ in + XCTAssertEqual(status.code, .unimplemented) + } + } + } + } + + func testMultipleConcurrentRequests() async throws { + try await self.withInProcessClientConnectedToServer(services: [BinaryEcho()]) { client, _ in + await withThrowingTaskGroup(of: Void.self) { group in + for i in UInt8.min ..< UInt8.max { + group.addTask { + try await client.withStream(descriptor: BinaryEcho.Methods.get) { stream in + try await stream.outbound.write(.metadata([:])) + try await stream.outbound.write(.message([i])) + stream.outbound.finish() + + var responseParts = stream.inbound.makeAsyncIterator() + let metadata = try await responseParts.next() + XCTAssertMetadata(metadata) + + let message = try await responseParts.next() + XCTAssertMessage(message) { XCTAssertEqual($0, [i]) } + + let status = try await responseParts.next() + XCTAssertStatus(status) { status, _ in + XCTAssertEqual(status.code, .ok) + } + } + } + } + } + } + } + + func testInterceptorsAreAppliedInOrder() async throws { + let counter1 = ManagedAtomic(0) + let counter2 = ManagedAtomic(0) + + try await self.withInProcessClientConnectedToServer( + services: [BinaryEcho()], + interceptors: [ + .requestCounter(counter1), + .rejectAll(with: RPCError(code: .unavailable, message: "")), + .requestCounter(counter2), + ] + ) { client, _ in + try await client.withStream(descriptor: BinaryEcho.Methods.get) { stream in + try await stream.outbound.write(.metadata([:])) + stream.outbound.finish() + + let parts = try await stream.inbound.collect() + XCTAssertStatus(parts.first) { status, _ in + XCTAssertEqual(status.code, .unavailable) + } + } + } + + XCTAssertEqual(counter1.load(ordering: .sequentiallyConsistent), 1) + XCTAssertEqual(counter2.load(ordering: .sequentiallyConsistent), 0) + } + + func testInterceptorsAreNotAppliedToUnimplementedMethods() async throws { + let counter = ManagedAtomic(0) + + try await self.withInProcessClientConnectedToServer( + services: [BinaryEcho()], + interceptors: [.requestCounter(counter)] + ) { client, _ in + try await client.withStream( + descriptor: MethodDescriptor(service: "not", method: "implemented") + ) { stream in + try await stream.outbound.write(.metadata([:])) + stream.outbound.finish() + + let parts = try await stream.inbound.collect() + XCTAssertStatus(parts.first) { status, _ in + XCTAssertEqual(status.code, .unimplemented) + } + } + } + + XCTAssertEqual(counter.load(ordering: .sequentiallyConsistent), 0) + } + + func testNoNewRPCsAfterServerStopListening() async throws { + try await withInProcessClientConnectedToServer(services: [BinaryEcho()]) { client, server in + // Run an RPC so we know the server is up. + try await self.doEchoGet(using: client) + + // New streams should fail immediately after this. + server.stopListening() + + // RPC should fail now. + await XCTAssertThrowsRPCErrorAsync { + try await client.withStream(descriptor: BinaryEcho.Methods.get) { stream in + XCTFail("Stream shouldn't be opened") + } + } errorHandler: { error in + XCTAssertEqual(error.code, .failedPrecondition) + } + } + } + + func testInFlightRPCsCanContinueAfterServerStopListening() async throws { + try await withInProcessClientConnectedToServer(services: [BinaryEcho()]) { client, server in + try await client.withStream(descriptor: BinaryEcho.Methods.update) { stream in + try await stream.outbound.write(.metadata([:])) + var iterator = stream.inbound.makeAsyncIterator() + // Don't need to validate the response, just that the server is running. + let metadata = try await iterator.next() + XCTAssertMetadata(metadata) + + // New streams should fail immediately after this. + server.stopListening() + + try await stream.outbound.write(.message([0])) + stream.outbound.finish() + + let message = try await iterator.next() + XCTAssertMessage(message) { XCTAssertEqual($0, [0]) } + let status = try await iterator.next() + XCTAssertStatus(status) + } + } + } + + func testCancelRunningServer() async throws { + let inProcess = self.makeInProcessPair() + let task = Task { + let server = Server() + server.services.register(BinaryEcho()) + server.transports.add(inProcess.server) + try await server.run() + } + + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + try? await inProcess.client.connect(lazily: true) + } + + try await self.doEchoGet(using: inProcess.client) + // The server must be running at this point as an RPC has completed. + task.cancel() + try await task.value + + group.cancelAll() + } + } + + func testTestRunServerWithNoTransport() async throws { + let server = Server() + await XCTAssertThrowsErrorAsync(ofType: ServerError.self) { + try await server.run() + } errorHandler: { error in + XCTAssertEqual(error.code, .noTransportsConfigured) + } + } + + func testTestRunStoppedServer() async throws { + let server = Server() + server.transports.add(InProcessServerTransport()) + // Run the server. + let task = Task { try await server.run() } + task.cancel() + try await task.value + + // Server is stopped, should throw an error. + await XCTAssertThrowsErrorAsync(ofType: ServerError.self) { + try await server.run() + } errorHandler: { error in + XCTAssertEqual(error.code, .serverIsStopped) + } + } + + func testRunServerWhenTransportThrows() async throws { + let server = Server() + server.transports.add(ThrowOnRunServerTransport()) + await XCTAssertThrowsErrorAsync(ofType: ServerError.self) { + try await server.run() + } errorHandler: { error in + XCTAssertEqual(error.code, .failedToStartTransport) + } + } + + func testRunServerDrainsRunningTransportsWhenOneFailsToStart() async throws { + let server = Server() + + // Register the in process transport first and allow it to come up. + let inProcess = self.makeInProcessPair() + server.transports.add(inProcess.server) + + // Register a transport waits for a signal before throwing. + let signal = AsyncStream.makeStream(of: Void.self) + server.transports.add(ThrowOnSignalServerTransport(signal: signal.stream)) + + // Connect the in process client and start an RPC. When the stream is opened signal the + // other transport to throw. This stream should be failed by the server. + await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + try await inProcess.client.connect(lazily: true) + } + + group.addTask { + try await inProcess.client.withStream(descriptor: BinaryEcho.Methods.get) { stream in + // The stream is open to the in-process transport. Let the other transport start. + signal.continuation.finish() + try await stream.outbound.write(.metadata([:])) + stream.outbound.finish() + + let parts = try await stream.inbound.collect() + XCTAssertStatus(parts.first) { status, _ in + XCTAssertEqual(status.code, .unavailable) + } + } + } + + await XCTAssertThrowsErrorAsync(ofType: ServerError.self) { + try await server.run() + } errorHandler: { error in + XCTAssertEqual(error.code, .failedToStartTransport) + } + + group.cancelAll() + } + } + + func testInterceptorsDescription() async throws { + let server = Server() + server.interceptors.add(.rejectAll(with: .init(code: .aborted, message: ""))) + server.interceptors.add(.requestCounter(.init(0))) + let description = String(describing: server.interceptors) + let expected = #"["RejectAllServerInterceptor", "RequestCountingServerInterceptor"]"# + XCTAssertEqual(description, expected) + } + + func testServicesDescription() async throws { + let server = Server() + let methods: [(String, String)] = [ + ("helloworld.Greeter", "SayHello"), + ("echo.Echo", "Foo"), + ("echo.Echo", "Bar"), + ("echo.Echo", "Baz"), + ] + + for (service, method) in methods { + let descriptor = MethodDescriptor(service: service, method: method) + server.services.router.registerHandler( + forMethod: descriptor, + deserializer: IdentityDeserializer(), + serializer: IdentitySerializer() + ) { _ in + fatalError("Unreachable") + } + } + + let description = String(describing: server.services) + let expected = """ + ["echo.Echo/Bar", "echo.Echo/Baz", "echo.Echo/Foo", "helloworld.Greeter/SayHello"] + """ + + XCTAssertEqual(description, expected) + } + + private func doEchoGet(using transport: some ClientTransport) async throws { + try await transport.withStream(descriptor: BinaryEcho.Methods.get) { stream in + try await stream.outbound.write(.metadata([:])) + try await stream.outbound.write(.message([0])) + stream.outbound.finish() + // Don't need to validate the response, just that the server is running. + let parts = try await stream.inbound.collect() + XCTAssertEqual(parts.count, 3) + } + } +} diff --git a/Tests/GRPCCoreTests/Test Utilities/Services/BinaryEcho.swift b/Tests/GRPCCoreTests/Test Utilities/Services/BinaryEcho.swift new file mode 100644 index 000000000..6a4ceb07e --- /dev/null +++ b/Tests/GRPCCoreTests/Test Utilities/Services/BinaryEcho.swift @@ -0,0 +1,104 @@ +/* + * Copyright 2023, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import GRPCCore +import XCTest + +struct BinaryEcho: RegistrableRPCService { + func get( + _ request: ServerRequest.Single<[UInt8]> + ) async throws -> ServerResponse.Single<[UInt8]> { + ServerResponse.Single(message: request.message, metadata: request.metadata) + } + + func collect( + _ request: ServerRequest.Stream<[UInt8]> + ) async throws -> ServerResponse.Single<[UInt8]> { + let collected = try await request.messages.reduce(into: []) { $0.append(contentsOf: $1) } + return ServerResponse.Single(message: collected, metadata: request.metadata) + } + + func expand( + _ request: ServerRequest.Single<[UInt8]> + ) async throws -> ServerResponse.Stream<[UInt8]> { + return ServerResponse.Stream(metadata: request.metadata) { + for byte in request.message { + try await $0.write([byte]) + } + return [:] + } + } + + func update( + _ request: ServerRequest.Stream<[UInt8]> + ) async throws -> ServerResponse.Stream<[UInt8]> { + return ServerResponse.Stream(metadata: request.metadata) { + for try await message in request.messages { + try await $0.write(message) + } + return [:] + } + } + + func registerMethods(with router: inout RPCRouter) { + let serializer = IdentitySerializer() + let deserializer = IdentityDeserializer() + + router.registerHandler( + forMethod: Methods.get, + deserializer: deserializer, + serializer: serializer + ) { streamRequest in + let singleRequest = try await ServerRequest.Single(stream: streamRequest) + let singleResponse = try await self.get(singleRequest) + return ServerResponse.Stream(single: singleResponse) + } + + router.registerHandler( + forMethod: Methods.collect, + deserializer: deserializer, + serializer: serializer + ) { streamRequest in + let singleResponse = try await self.collect(streamRequest) + return ServerResponse.Stream(single: singleResponse) + } + + router.registerHandler( + forMethod: Methods.expand, + deserializer: deserializer, + serializer: serializer + ) { streamRequest in + let singleRequest = try await ServerRequest.Single(stream: streamRequest) + let streamResponse = try await self.expand(singleRequest) + return streamResponse + } + + router.registerHandler( + forMethod: Methods.update, + deserializer: deserializer, + serializer: serializer + ) { streamRequest in + let streamResponse = try await self.update(streamRequest) + return streamResponse + } + } + + enum Methods { + static let get = MethodDescriptor(service: "echo.Echo", method: "Get") + static let collect = MethodDescriptor(service: "echo.Echo", method: "Collect") + static let expand = MethodDescriptor(service: "echo.Echo", method: "Expand") + static let update = MethodDescriptor(service: "echo.Echo", method: "Update") + } +} diff --git a/Tests/GRPCCoreTests/Test Utilities/Transport/ThrowingTransport.swift b/Tests/GRPCCoreTests/Test Utilities/Transport/ThrowingTransport.swift index a80dc023d..0dd1cee9b 100644 --- a/Tests/GRPCCoreTests/Test Utilities/Transport/ThrowingTransport.swift +++ b/Tests/GRPCCoreTests/Test Utilities/Transport/ThrowingTransport.swift @@ -49,3 +49,36 @@ struct ThrowOnStreamCreationTransport: ClientTransport { throw RPCError(code: self.code, message: "") } } + +struct ThrowOnRunServerTransport: ServerTransport { + func listen() async throws -> RPCAsyncSequence> { + throw RPCError( + code: .unavailable, + message: "The '\(type(of: self))' transport is never available." + ) + } + + func stopListening() { + // no-op + } +} + +struct ThrowOnSignalServerTransport: ServerTransport { + let signal: AsyncStream + + init(signal: AsyncStream) { + self.signal = signal + } + + func listen() async throws -> RPCAsyncSequence> { + for await _ in self.signal {} + throw RPCError( + code: .unavailable, + message: "The '\(type(of: self))' transport is never available." + ) + } + + func stopListening() { + // no-op + } +} diff --git a/Tests/GRPCCoreTests/Test Utilities/XCTest+Utilities.swift b/Tests/GRPCCoreTests/Test Utilities/XCTest+Utilities.swift index b85889188..7bc88edef 100644 --- a/Tests/GRPCCoreTests/Test Utilities/XCTest+Utilities.swift +++ b/Tests/GRPCCoreTests/Test Utilities/XCTest+Utilities.swift @@ -37,6 +37,34 @@ func XCTAssertThrowsErrorAsync( } } +func XCTAssertThrowsError( + ofType: E.Type, + _ expression: @autoclosure () throws -> T, + _ errorHandler: (E) -> Void +) { + XCTAssertThrowsError(try expression()) { error in + guard let error = error as? E else { + return XCTFail("Error had unexpected type '\(type(of: error))'") + } + errorHandler(error) + } +} + +func XCTAssertThrowsErrorAsync( + ofType: E.Type = E.self, + _ expression: () async throws -> T, + errorHandler: (E) -> Void +) async { + do { + _ = try await expression() + XCTFail("Expression didn't throw") + } catch let error as E { + errorHandler(error) + } catch { + XCTFail("Error had unexpected type '\(type(of: error))'") + } +} + func XCTAssertThrowsRPCError( _ expression: @autoclosure () throws -> T, _ errorHandler: (RPCError) -> Void @@ -76,6 +104,30 @@ func XCTAssertRejected( } } +func XCTAssertMetadata( + _ part: RPCResponsePart?, + metadataHandler: (Metadata) -> Void = { _ in } +) { + switch part { + case .some(.metadata(let metadata)): + metadataHandler(metadata) + default: + XCTFail("Expected '.metadata' but found '\(String(describing: part))'") + } +} + +func XCTAssertMessage( + _ part: RPCResponsePart?, + messageHandler: ([UInt8]) -> Void = { _ in } +) { + switch part { + case .some(.message(let message)): + messageHandler(message) + default: + XCTFail("Expected '.metadata' but found '\(String(describing: part))'") + } +} + func XCTAssertStatus( _ part: RPCResponsePart?, statusHandler: (Status, Metadata) -> Void = { _, _ in } diff --git a/Tests/GRPCTests/ConnectionPool/GRPCChannelPoolTests.swift b/Tests/GRPCTests/ConnectionPool/GRPCChannelPoolTests.swift index dd41402d2..52b3f9377 100644 --- a/Tests/GRPCTests/ConnectionPool/GRPCChannelPoolTests.swift +++ b/Tests/GRPCTests/ConnectionPool/GRPCChannelPoolTests.swift @@ -520,13 +520,13 @@ final class GRPCChannelPoolTests: GRPCTestCase { func testDelegateCanTellWhenFirstConnectionIsBeingEstablished() { final class State { - private enum _State { + private enum Storage { case idle case connecting case connected } - private var state: _State = .idle + private var state: Storage = .idle private let lock = NIOLock() var isConnected: Bool {