From 6b2d8bcd38426bebab2ed0de9e18ec3303213728 Mon Sep 17 00:00:00 2001 From: George Barnett Date: Tue, 29 Jan 2019 15:34:49 +0000 Subject: [PATCH 01/10] Improve error handling in NIO server. - Adds a user-configurable error handler to the server - Updates NIO server codegen to provide an optional error handler - Errors are handled by GRPCChannelHandler or BaseCallHandler, depending on the pipeline state - Adds some error handling tests - Tidies some logic in HTTP1ToRawGRPCServerCodec - Extends message handling logic in HTTP1ToRawGRPCServerCodec to handle messages split across multiple ByteBuffers (i.e. when a message exceeds a the size of a frame) --- Makefile | 2 +- .../CallHandlers/BaseCallHandler.swift | 63 ++++- .../BidirectionalStreamingCallHandler.swift | 4 +- .../ClientStreamingCallHandler.swift | 4 +- .../ServerStreamingCallHandler.swift | 10 +- .../CallHandlers/UnaryCallHandler.swift | 10 +- Sources/SwiftGRPCNIO/GRPCChannelHandler.swift | 47 ++-- Sources/SwiftGRPCNIO/GRPCServer.swift | 6 +- Sources/SwiftGRPCNIO/GRPCServerCodec.swift | 12 +- Sources/SwiftGRPCNIO/GRPCStatus.swift | 15 ++ .../HTTP1ToRawGRPCServerCodec.swift | 227 ++++++++++++------ .../Generator-Server.swift | 4 +- ...nnelHandlerResponseCapturingTestCase.swift | 78 ++++++ .../GRPCChannelHandlerTests.swift | 197 +++++++++++++++ Tests/SwiftGRPCNIOTests/NIOServerTests.swift | 10 + Tests/SwiftGRPCNIOTests/echo_nio.grpc.swift | 10 +- 16 files changed, 578 insertions(+), 121 deletions(-) create mode 100644 Tests/SwiftGRPCNIOTests/GRPCChannelHandlerResponseCapturingTestCase.swift create mode 100644 Tests/SwiftGRPCNIOTests/GRPCChannelHandlerTests.swift diff --git a/Makefile b/Makefile index b19dbccfd..c30e61663 100644 --- a/Makefile +++ b/Makefile @@ -48,7 +48,7 @@ test-plugin: test-plugin-nio: swift build $(CFLAGS) --product protoc-gen-swiftgrpc protoc Sources/Examples/Echo/echo.proto --proto_path=Sources/Examples/Echo --plugin=.build/debug/protoc-gen-swift --plugin=.build/debug/protoc-gen-swiftgrpc --swiftgrpc_out=/tmp --swiftgrpc_opt=Client=false,NIO=true - diff -u /tmp/echo.grpc.swift Tests/SwiftGRPCNIOTests/echo_nio.grpc.swift + diff -u /tmp/echo.grpc.swift Tests/SwiftGRPCNIOTests/echo_nio.grpc.swift xcodebuild: project xcodebuild -project SwiftGRPC.xcodeproj -configuration "Debug" -parallelizeTargets -target SwiftGRPC -target Echo -target Simple -target protoc-gen-swiftgrpc build diff --git a/Sources/SwiftGRPCNIO/CallHandlers/BaseCallHandler.swift b/Sources/SwiftGRPCNIO/CallHandlers/BaseCallHandler.swift index 33d18d5e6..1808cafa5 100644 --- a/Sources/SwiftGRPCNIO/CallHandlers/BaseCallHandler.swift +++ b/Sources/SwiftGRPCNIO/CallHandlers/BaseCallHandler.swift @@ -8,29 +8,78 @@ import NIOHTTP1 /// Calls through to `processMessage` for individual messages it receives, which needs to be implemented by subclasses. public class BaseCallHandler: GRPCCallHandler { public func makeGRPCServerCodec() -> ChannelHandler { return GRPCServerCodec() } - + /// Called whenever a message has been received. /// /// Overridden by subclasses. - public func processMessage(_ message: RequestMessage) { + public func processMessage(_ message: RequestMessage) throws { fatalError("needs to be overridden") } - + /// Called when the client has half-closed the stream, indicating that they won't send any further data. /// /// Overridden by subclasses if the "end-of-stream" event is relevant. public func endOfStreamReceived() { } + + /// Whether this handler can still write messages to the client. + private var serverCanWrite = true + + /// Called for each error recieved in `errorCaught(ctx:error:)`. + private let errorHandler: ((Error) -> Void)? + + public init(errorHandler: ((Error) -> Void)? = nil) { + self.errorHandler = errorHandler + } } extension BaseCallHandler: ChannelInboundHandler { public typealias InboundIn = GRPCServerRequestPart - public typealias OutboundOut = GRPCServerResponsePart + + /// Passes errors to the user-provided `errorHandler`. After an error has been received an + /// appropriate status is written. Errors which don't conform to `GRPCStatusTransformable` + /// return a status with code `.internalError`. + public func errorCaught(ctx: ChannelHandlerContext, error: Error) { + errorHandler?(error) + + let status = (error as? GRPCStatusTransformable)?.asGRPCStatus() ?? GRPCStatus.processingError + self.write(ctx: ctx, data: NIOAny(GRPCServerResponsePart.status(status)), promise: nil) + } public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) { switch self.unwrapInboundIn(data) { - case .head: preconditionFailure("should not have received headers") - case .message(let message): processMessage(message) - case .end: endOfStreamReceived() + case .head: + // Head should have been handled by `GRPCChannelHandler`. + self.errorCaught(ctx: ctx, error: GRPCStatus(code: .unknown, message: "unexpectedly received head")) + + case .message(let message): + do { + try processMessage(message) + } catch { + self.errorCaught(ctx: ctx, error: error) + } + + case .end: + endOfStreamReceived() + } + } +} + +extension BaseCallHandler: ChannelOutboundHandler { + public typealias OutboundIn = GRPCServerResponsePart + public typealias OutboundOut = GRPCServerResponsePart + + public func write(ctx: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + guard serverCanWrite else { + promise?.fail(error: GRPCStatus.processingError) + return + } + + // We can only write one status; make sure we don't write again. + if case .status = unwrapOutboundIn(data) { + serverCanWrite = false + ctx.writeAndFlush(data, promise: promise) + } else { + ctx.write(data, promise: promise) } } } diff --git a/Sources/SwiftGRPCNIO/CallHandlers/BidirectionalStreamingCallHandler.swift b/Sources/SwiftGRPCNIO/CallHandlers/BidirectionalStreamingCallHandler.swift index 46f4b7622..f147ec719 100644 --- a/Sources/SwiftGRPCNIO/CallHandlers/BidirectionalStreamingCallHandler.swift +++ b/Sources/SwiftGRPCNIO/CallHandlers/BidirectionalStreamingCallHandler.swift @@ -15,8 +15,8 @@ public class BidirectionalStreamingCallHandler) -> EventLoopFuture) { - super.init() + public init(channel: Channel, request: HTTPRequestHead, errorHandler: ((Error) -> Void)?, eventObserverFactory: (StreamingResponseCallContext) -> EventLoopFuture) { + super.init(errorHandler: errorHandler) let context = StreamingResponseCallContextImpl(channel: channel, request: request) self.context = context let eventObserver = eventObserverFactory(context) diff --git a/Sources/SwiftGRPCNIO/CallHandlers/ClientStreamingCallHandler.swift b/Sources/SwiftGRPCNIO/CallHandlers/ClientStreamingCallHandler.swift index bd03ae744..8886d5660 100644 --- a/Sources/SwiftGRPCNIO/CallHandlers/ClientStreamingCallHandler.swift +++ b/Sources/SwiftGRPCNIO/CallHandlers/ClientStreamingCallHandler.swift @@ -14,8 +14,8 @@ public class ClientStreamingCallHandler) -> EventLoopFuture) { - super.init() + public init(channel: Channel, request: HTTPRequestHead, errorHandler: ((Error) -> Void)?, eventObserverFactory: (UnaryResponseCallContext) -> EventLoopFuture) { + super.init(errorHandler: errorHandler) let context = UnaryResponseCallContextImpl(channel: channel, request: request) self.context = context let eventObserver = eventObserverFactory(context) diff --git a/Sources/SwiftGRPCNIO/CallHandlers/ServerStreamingCallHandler.swift b/Sources/SwiftGRPCNIO/CallHandlers/ServerStreamingCallHandler.swift index 893745c69..ed01046d3 100644 --- a/Sources/SwiftGRPCNIO/CallHandlers/ServerStreamingCallHandler.swift +++ b/Sources/SwiftGRPCNIO/CallHandlers/ServerStreamingCallHandler.swift @@ -13,8 +13,8 @@ public class ServerStreamingCallHandler? - public init(channel: Channel, request: HTTPRequestHead, eventObserverFactory: (StreamingResponseCallContext) -> EventObserver) { - super.init() + public init(channel: Channel, request: HTTPRequestHead, errorHandler: ((Error) -> Void)?, eventObserverFactory: (StreamingResponseCallContext) -> EventObserver) { + super.init(errorHandler: errorHandler) let context = StreamingResponseCallContextImpl(channel: channel, request: request) self.context = context self.eventObserver = eventObserverFactory(context) @@ -26,12 +26,10 @@ public class ServerStreamingCallHandler private var context: UnaryResponseCallContext? - public init(channel: Channel, request: HTTPRequestHead, eventObserverFactory: (UnaryResponseCallContext) -> EventObserver) { - super.init() + public init(channel: Channel, request: HTTPRequestHead, errorHandler: ((Error) -> Void)?, eventObserverFactory: (UnaryResponseCallContext) -> EventObserver) { + super.init(errorHandler: errorHandler) let context = UnaryResponseCallContextImpl(channel: channel, request: request) self.context = context self.eventObserver = eventObserverFactory(context) @@ -26,12 +26,10 @@ public class UnaryCallHandler } } - public override func processMessage(_ message: RequestMessage) { + public override func processMessage(_ message: RequestMessage) throws { guard let eventObserver = self.eventObserver, let context = self.context else { - //! FIXME: Better handle this error? - print("multiple messages received on unary call") - return + throw GRPCStatus(code: .unimplemented, message: "multiple messages received on unary call") } let resultFuture = eventObserver(message) diff --git a/Sources/SwiftGRPCNIO/GRPCChannelHandler.swift b/Sources/SwiftGRPCNIO/GRPCChannelHandler.swift index 3b8b475eb..00ad07d43 100644 --- a/Sources/SwiftGRPCNIO/GRPCChannelHandler.swift +++ b/Sources/SwiftGRPCNIO/GRPCChannelHandler.swift @@ -19,7 +19,7 @@ public protocol CallHandlerProvider: class { /// Determines, calls and returns the appropriate request handler (`GRPCCallHandler`), depending on the request's /// method. Returns nil for methods not handled by this service. - func handleMethod(_ methodName: String, request: HTTPRequestHead, serverHandler: GRPCChannelHandler, channel: Channel) -> GRPCCallHandler? + func handleMethod(_ methodName: String, request: HTTPRequestHead, serverHandler: GRPCChannelHandler, channel: Channel, errorHandler: ((Error) -> Void)?) -> GRPCCallHandler? } /// Listens on a newly-opened HTTP2 subchannel and yields to the sub-handler matching a call, if available. @@ -28,30 +28,31 @@ public protocol CallHandlerProvider: class { /// for an `GRPCCallHandler` object. That object is then forwarded the individual gRPC messages. public final class GRPCChannelHandler { private let servicesByName: [String: CallHandlerProvider] + private let errorHandler: ((Error) -> Void)? - public init(servicesByName: [String: CallHandlerProvider]) { + public init(servicesByName: [String: CallHandlerProvider], errorHandler: ((Error) -> Void)? = nil) { self.servicesByName = servicesByName + self.errorHandler = errorHandler } } extension GRPCChannelHandler: ChannelInboundHandler { public typealias InboundIn = RawGRPCServerRequestPart public typealias OutboundOut = RawGRPCServerResponsePart - + + public func errorCaught(ctx: ChannelHandlerContext, error: Error) { + errorHandler?(error) + + let status = (error as? GRPCStatusTransformable)?.asGRPCStatus() ?? GRPCStatus.processingError + ctx.writeAndFlush(wrapOutboundOut(.status(status)), promise: nil) + } + public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) { let requestPart = self.unwrapInboundIn(data) switch requestPart { case .head(let requestHead): - // URI format: "/package.Servicename/MethodName", resulting in the following components separated by a slash: - // - uriComponents[0]: empty - // - uriComponents[1]: service name (including the package name); - // `CallHandlerProvider`s should provide the service name including the package name. - // - uriComponents[2]: method name. - let uriComponents = requestHead.uri.components(separatedBy: "/") - guard uriComponents.count >= 3 && uriComponents[0].isEmpty, - let providerForServiceName = servicesByName[uriComponents[1]], - let callHandler = providerForServiceName.handleMethod(uriComponents[2], request: requestHead, serverHandler: self, channel: ctx.channel) else { - ctx.writeAndFlush(self.wrapOutboundOut(.status(.unimplemented(method: requestHead.uri))), promise: nil) + guard let callHandler = getCallHandler(channel: ctx.channel, requestHead: requestHead) else { + errorCaught(ctx: ctx, error: GRPCStatus.unimplemented(method: requestHead.uri)) return } @@ -71,7 +72,25 @@ extension GRPCChannelHandler: ChannelInboundHandler { .whenComplete { ctx.pipeline.remove(handler: self, promise: handlerRemoved) } case .message, .end: - preconditionFailure("received \(requestPart), should have been removed as a handler at this point") + // We can reach this point if we're receiving messages for a method that isn't implemented. + // A status resposne will have been fired which should also close the stream; there's not a + // lot we can do at this point. + break + } + } + + private func getCallHandler(channel: Channel, requestHead: HTTPRequestHead) -> GRPCCallHandler? { + // URI format: "/package.Servicename/MethodName", resulting in the following components separated by a slash: + // - uriComponents[0]: empty + // - uriComponents[1]: service name (including the package name); + // `CallHandlerProvider`s should provide the service name including the package name. + // - uriComponents[2]: method name. + let uriComponents = requestHead.uri.components(separatedBy: "/") + guard uriComponents.count >= 3 && uriComponents[0].isEmpty, + let providerForServiceName = servicesByName[uriComponents[1]], + let callHandler = providerForServiceName.handleMethod(uriComponents[2], request: requestHead, serverHandler: self, channel: channel, errorHandler: errorHandler) else { + return nil } + return callHandler } } diff --git a/Sources/SwiftGRPCNIO/GRPCServer.swift b/Sources/SwiftGRPCNIO/GRPCServer.swift index f87a5166a..67aa1a9c0 100644 --- a/Sources/SwiftGRPCNIO/GRPCServer.swift +++ b/Sources/SwiftGRPCNIO/GRPCServer.swift @@ -12,7 +12,9 @@ public final class GRPCServer { hostname: String, port: Int, eventLoopGroup: EventLoopGroup, - serviceProviders: [CallHandlerProvider]) -> EventLoopFuture { + serviceProviders: [CallHandlerProvider], + errorHandler: ((Error) -> Void)? = nil + ) -> EventLoopFuture { let servicesByName = Dictionary(uniqueKeysWithValues: serviceProviders.map { ($0.serviceName, $0) }) let bootstrap = ServerBootstrap(group: eventLoopGroup) // Specify a backlog to avoid overloading the server. @@ -27,7 +29,7 @@ public final class GRPCServer { let multiplexer = HTTP2StreamMultiplexer { (channel, streamID) -> EventLoopFuture in return channel.pipeline.add(handler: HTTP2ToHTTP1ServerCodec(streamID: streamID)) .then { channel.pipeline.add(handler: HTTP1ToRawGRPCServerCodec()) } - .then { channel.pipeline.add(handler: GRPCChannelHandler(servicesByName: servicesByName)) } + .then { channel.pipeline.add(handler: GRPCChannelHandler(servicesByName: servicesByName, errorHandler: errorHandler)) } } return channel.pipeline.add(handler: multiplexer) diff --git a/Sources/SwiftGRPCNIO/GRPCServerCodec.swift b/Sources/SwiftGRPCNIO/GRPCServerCodec.swift index 21652e8bb..4cc5b214d 100644 --- a/Sources/SwiftGRPCNIO/GRPCServerCodec.swift +++ b/Sources/SwiftGRPCNIO/GRPCServerCodec.swift @@ -19,7 +19,7 @@ public enum GRPCServerResponsePart { } /// A simple channel handler that translates raw gRPC packets into decoded protobuf messages, and vice versa. -public final class GRPCServerCodec { } +public final class GRPCServerCodec {} extension GRPCServerCodec: ChannelInboundHandler { public typealias InboundIn = RawGRPCServerRequestPart @@ -35,8 +35,7 @@ extension GRPCServerCodec: ChannelInboundHandler { do { ctx.fireChannelRead(self.wrapInboundOut(.message(try RequestMessage(serializedData: messageAsData)))) } catch { - //! FIXME: Ensure that the last handler in the pipeline returns `.dataLoss` here? - ctx.fireErrorCaught(error) + ctx.fireErrorCaught(GRPCStatus.requestProtoParseError) } case .end: @@ -54,6 +53,7 @@ extension GRPCServerCodec: ChannelOutboundHandler { switch responsePart { case .headers(let headers): ctx.write(self.wrapOutboundOut(.headers(headers)), promise: promise) + case .message(let message): do { let messageData = try message.serializedData() @@ -61,9 +61,11 @@ extension GRPCServerCodec: ChannelOutboundHandler { responseBuffer.write(bytes: messageData) ctx.write(self.wrapOutboundOut(.message(responseBuffer)), promise: promise) } catch { - promise?.fail(error: error) - ctx.fireErrorCaught(error) + let status = GRPCStatus.responseProtoSerializationError + promise?.fail(error: status) + ctx.fireErrorCaught(status) } + case .status(let status): ctx.write(self.wrapOutboundOut(.status(status)), promise: promise) } diff --git a/Sources/SwiftGRPCNIO/GRPCStatus.swift b/Sources/SwiftGRPCNIO/GRPCStatus.swift index 9e4109b82..cb1ae07d4 100644 --- a/Sources/SwiftGRPCNIO/GRPCStatus.swift +++ b/Sources/SwiftGRPCNIO/GRPCStatus.swift @@ -27,4 +27,19 @@ public struct GRPCStatus: Error { public static func unimplemented(method: String) -> GRPCStatus { return GRPCStatus(code: .unimplemented, message: "unknown method " + method) } + + // These status codes are informed by: https://github.com/grpc/grpc/blob/master/doc/statuscodes.md + static internal let requestProtoParseError = GRPCStatus(code: .internalError, message: "could not parse request proto") + static internal let responseProtoSerializationError = GRPCStatus(code: .internalError, message: "could not serialize response proto") + static internal let unsupportedCompression = GRPCStatus(code: .unimplemented, message: "compression is not supported on the server") +} + +protocol GRPCStatusTransformable: Error { + func asGRPCStatus() -> GRPCStatus +} + +extension GRPCStatus: GRPCStatusTransformable { + func asGRPCStatus() -> GRPCStatus { + return self + } } diff --git a/Sources/SwiftGRPCNIO/HTTP1ToRawGRPCServerCodec.swift b/Sources/SwiftGRPCNIO/HTTP1ToRawGRPCServerCodec.swift index f8a84cfe3..3db609585 100644 --- a/Sources/SwiftGRPCNIO/HTTP1ToRawGRPCServerCodec.swift +++ b/Sources/SwiftGRPCNIO/HTTP1ToRawGRPCServerCodec.swift @@ -27,111 +27,200 @@ public enum RawGRPCServerResponsePart { /// /// The translation from HTTP2 to HTTP1 is done by `HTTP2ToHTTP1ServerCodec`. public final class HTTP1ToRawGRPCServerCodec { - private enum State { + internal var inboundState = InboundState.expectingHeaders + internal var outboundState = OutboundState.expectingHeaders + + private var buffer: NIO.ByteBuffer? + + // 1-byte for compression flag, 4-bytes for message length. + private let protobufMetadataSize = 5 +} + +extension HTTP1ToRawGRPCServerCodec { + enum InboundState { case expectingHeaders - case expectingCompressedFlag - case expectingMessageLength - case receivedMessageLength(UInt32) - - var expectingBody: Bool { - switch self { - case .expectingHeaders: return false - case .expectingCompressedFlag, .expectingMessageLength, .receivedMessageLength: return true - } + case expectingBody(Body) + // ignore any additional messages; e.g. we've seen .end or we've sent an error and are waiting for the stream to close. + case ignore + + enum Body { + case expectingCompressedFlag + case expectingMessageLength + case receivedMessageLength(UInt32) } } - private var state = State.expectingHeaders + enum OutboundState { + case expectingHeaders + case expectingBodyOrStatus + case ignore + } +} - private var buffer: NIO.ByteBuffer? +extension HTTP1ToRawGRPCServerCodec { + struct StateMachineError: Error, GRPCStatusTransformable { + private let message: String + + init(_ message: String) { + self.message = message + } + + func asGRPCStatus() -> GRPCStatus { + return GRPCStatus.processingError + } + } } extension HTTP1ToRawGRPCServerCodec: ChannelInboundHandler { public typealias InboundIn = HTTPServerRequestPart public typealias InboundOut = RawGRPCServerRequestPart - + public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) { - switch self.unwrapInboundIn(data) { - case .head(let requestHead): - guard case .expectingHeaders = state - else { preconditionFailure("received headers while in state \(state)") } - - state = .expectingCompressedFlag - buffer = ctx.channel.allocator.buffer(capacity: 5) - - ctx.fireChannelRead(self.wrapInboundOut(.head(requestHead))) - - case .body(var body): - guard var buffer = buffer - else { preconditionFailure("buffer not initialized") } - assert(state.expectingBody, "received body while in state \(state)") - buffer.write(buffer: &body) - - // Iterate over all available incoming data, trying to read length-delimited messages. - // Each message has the following format: - // - 1 byte "compressed" flag (currently always zero, as we do not support for compression) - // - 4 byte signed-integer payload length (N) - // - N bytes payload (normally a valid wire-format protocol buffer) - requestProcessing: while true { - switch state { - case .expectingHeaders: preconditionFailure("unexpected state \(state)") - case .expectingCompressedFlag: - guard let compressionFlag: Int8 = buffer.readInteger() else { break requestProcessing } - //! FIXME: Avoid crashing here and instead drop the connection. - precondition(compressionFlag == 0, "unexpected compression flag \(compressionFlag); compression is not supported and we did not indicate support for it") - state = .expectingMessageLength - - case .expectingMessageLength: - guard let messageLength: UInt32 = buffer.readInteger() else { break requestProcessing } - state = .receivedMessageLength(messageLength) - - case .receivedMessageLength(let messageLength): - guard let messageBytes = buffer.readBytes(length: numericCast(messageLength)) else { break } - - //! FIXME: Use a slice of this buffer instead of copying to a new buffer. - var responseBuffer = ctx.channel.allocator.buffer(capacity: messageBytes.count) - responseBuffer.write(bytes: messageBytes) - ctx.fireChannelRead(self.wrapInboundOut(.message(responseBuffer))) - //! FIXME: Call buffer.discardReadBytes() here? - //! ALTERNATIVE: Check if the buffer has no further data right now, then clear it. - - state = .expectingCompressedFlag - } + if case .ignore = inboundState { return } + + do { + switch self.unwrapInboundIn(data) { + case .head(let requestHead): + inboundState = try processHead(ctx: ctx, requestHead: requestHead) + + case .body(var body): + inboundState = try processBody(ctx: ctx, body: &body) + + case .end(let trailers): + inboundState = try processEnd(ctx: ctx, trailers: trailers) } + } catch { + ctx.fireErrorCaught(error) + inboundState = .ignore + } + } + + func processHead(ctx: ChannelHandlerContext, requestHead: HTTPRequestHead) throws -> InboundState { + guard case .expectingHeaders = inboundState else { + throw StateMachineError("expecteded state .expectingHeaders, got \(inboundState)") + } + + ctx.fireChannelRead(self.wrapInboundOut(.head(requestHead))) - case .end(let trailers): - if let trailers = trailers { - //! FIXME: Better handle this error. - print("unexpected trailers received: \(trailers)") - return + return .expectingBody(.expectingCompressedFlag) + } + + func processBody(ctx: ChannelHandlerContext, body: inout ByteBuffer) throws -> InboundState { + guard case .expectingBody(let bodyState) = inboundState else { + throw StateMachineError("expecteded state .expectingBody(_), got \(inboundState)") + } + + return .expectingBody(try processBodyState(ctx: ctx, initialState: bodyState, messageBuffer: &body)) + } + + func processBodyState(ctx: ChannelHandlerContext, initialState: InboundState.Body, messageBuffer: inout ByteBuffer) throws -> InboundState.Body { + var bodyState = initialState + + // Iterate over all available incoming data, trying to read length-delimited messages. + // Each message has the following format: + // - 1 byte "compressed" flag (currently always zero, as we do not support for compression) + // - 4 byte signed-integer payload length (N) + // - N bytes payload (normally a valid wire-format protocol buffer) + while true { + switch bodyState { + case .expectingCompressedFlag: + guard let compressionFlag: Int8 = messageBuffer.readInteger() else { return .expectingCompressedFlag } + + // TODO: Add support for compression. + guard compressionFlag == 0 else { throw GRPCStatus.unsupportedCompression } + + bodyState = .expectingMessageLength + + case .expectingMessageLength: + guard let messageLength: UInt32 = messageBuffer.readInteger() else { return .expectingMessageLength } + bodyState = .receivedMessageLength(messageLength) + + case .receivedMessageLength(let messageLength): + // We need to account for messages being spread across multiple `ByteBuffer`s so buffer them + // into `buffer`. Note: when messages are contained within a single `ByteBuffer` we're just + // taking a slice so don't incur any extra writes. + guard messageBuffer.readableBytes >= messageLength else { + let remainingBytes = messageLength - numericCast(messageBuffer.readableBytes) + + if var buffer = buffer { + buffer.write(buffer: &messageBuffer) + self.buffer = buffer + } else { + messageBuffer.reserveCapacity(numericCast(messageLength)) + self.buffer = messageBuffer + } + + return .receivedMessageLength(remainingBytes) + } + + // We know buffer.readableBytes >= messageLength, so it's okay to force unwrap here. + var slice = messageBuffer.readSlice(length: numericCast(messageLength))! + + if var buffer = buffer { + buffer.write(buffer: &slice) + ctx.fireChannelRead(self.wrapInboundOut(.message(buffer))) + } else { + ctx.fireChannelRead(self.wrapInboundOut(.message(slice))) + } + + self.buffer = nil + bodyState = .expectingCompressedFlag } - ctx.fireChannelRead(self.wrapInboundOut(.end)) } } + + private func processEnd(ctx: ChannelHandlerContext, trailers: HTTPHeaders?) throws -> InboundState { + guard trailers == nil else { + throw StateMachineError("unexpected trailers received \(String(describing: trailers))") + } + + ctx.fireChannelRead(self.wrapInboundOut(.end)) + return .ignore + } } extension HTTP1ToRawGRPCServerCodec: ChannelOutboundHandler { public typealias OutboundIn = RawGRPCServerResponsePart public typealias OutboundOut = HTTPServerResponsePart - + public func write(ctx: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + if case .ignore = outboundState { return } + let responsePart = self.unwrapOutboundIn(data) switch responsePart { case .headers(let headers): + guard case .expectingHeaders = outboundState else { return } + //! FIXME: Should return a different version if we want to support pPRC. ctx.write(self.wrapOutboundOut(.head(HTTPResponseHead(version: .init(major: 2, minor: 0), status: .ok, headers: headers))), promise: promise) + outboundState = .expectingBodyOrStatus + case .message(var messageBytes): - // Write out a length-delimited message payload. See `channelRead` fpor the corresponding format. - var responseBuffer = ctx.channel.allocator.buffer(capacity: messageBytes.readableBytes + 5) + guard case .expectingBodyOrStatus = outboundState else { return } + + // Write out a length-delimited message payload. See `processBodyState` for the corresponding format. + var responseBuffer = ctx.channel.allocator.buffer(capacity: messageBytes.readableBytes + protobufMetadataSize) responseBuffer.write(integer: Int8(0)) // Compression flag: no compression responseBuffer.write(integer: UInt32(messageBytes.readableBytes)) responseBuffer.write(buffer: &messageBytes) ctx.write(self.wrapOutboundOut(.body(.byteBuffer(responseBuffer))), promise: promise) + outboundState = .expectingBodyOrStatus + case .status(let status): var trailers = status.trailingMetadata trailers.add(name: "grpc-status", value: String(describing: status.code.rawValue)) trailers.add(name: "grpc-message", value: status.message) - ctx.write(self.wrapOutboundOut(.end(trailers)), promise: promise) + + // "Trailers-Only" response + if case .expectingHeaders = outboundState { + trailers.add(name: "content-type", value: "application/grpc") + let responseHead = HTTPResponseHead(version: .init(major: 2, minor: 0), status: .ok) + ctx.write(self.wrapOutboundOut(.head(responseHead)), promise: nil) + } + + ctx.writeAndFlush(self.wrapOutboundOut(.end(trailers)), promise: promise) + outboundState = .ignore + inboundState = .ignore } } } diff --git a/Sources/protoc-gen-swiftgrpc/Generator-Server.swift b/Sources/protoc-gen-swiftgrpc/Generator-Server.swift index 45de736ff..ec46ab4e1 100644 --- a/Sources/protoc-gen-swiftgrpc/Generator-Server.swift +++ b/Sources/protoc-gen-swiftgrpc/Generator-Server.swift @@ -85,7 +85,7 @@ extension Generator { if options.generateNIOImplementation { println("/// Determines, calls and returns the appropriate request handler, depending on the request's method.") println("/// Returns nil for methods not handled by this service.") - println("\(access) func handleMethod(_ methodName: String, request: HTTPRequestHead, serverHandler: GRPCChannelHandler, channel: Channel) -> GRPCCallHandler? {") + println("\(access) func handleMethod(_ methodName: String, request: HTTPRequestHead, serverHandler: GRPCChannelHandler, channel: Channel, errorHandler: ((Error) -> Void)? = nil) -> GRPCCallHandler? {") indent() println("switch methodName {") for method in service.methods { @@ -99,7 +99,7 @@ extension Generator { case .clientStreaming: callHandlerType = "ClientStreamingCallHandler" case .bidirectionalStreaming: callHandlerType = "BidirectionalStreamingCallHandler" } - println("return \(callHandlerType)(channel: channel, request: request) { context in") + println("return \(callHandlerType)(channel: channel, request: request, errorHandler: errorHandler) { context in") indent() switch streamingType(method) { case .unary, .serverStreaming: diff --git a/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerResponseCapturingTestCase.swift b/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerResponseCapturingTestCase.swift new file mode 100644 index 000000000..3f0eda98f --- /dev/null +++ b/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerResponseCapturingTestCase.swift @@ -0,0 +1,78 @@ +import Foundation +import NIO +import NIOHTTP1 +@testable import SwiftGRPCNIO +import XCTest + +internal struct CaseExtractError: Error { + let message: String +} + +@discardableResult +func extractHeaders(_ response: RawGRPCServerResponsePart) throws -> HTTPHeaders { + guard case .headers(let headers) = response else { + throw CaseExtractError(message: "\(response) did not match .headers") + } + return headers +} + +@discardableResult +func extractMessage(_ response: RawGRPCServerResponsePart) throws -> ByteBuffer { + guard case .message(let message) = response else { + throw CaseExtractError(message: "\(response) did not match .message") + } + return message +} + +@discardableResult +func extractStatus(_ response: RawGRPCServerResponsePart) throws -> GRPCStatus { + guard case .status(let status) = response else { + throw CaseExtractError(message: "\(response) did not match .status") + } + return status +} + +class CollectingChannelHandler: ChannelOutboundHandler { + var responses: [OutboundIn] = [] + var responseExpectation: XCTestExpectation? + + init(responseExpectation: XCTestExpectation? = nil) { + self.responseExpectation = responseExpectation + } + + func write(ctx: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + responses.append(unwrapOutboundIn(data)) + responseExpectation?.fulfill() + } +} + +class GRPCChannelHandlerResponseCapturingTestCase: XCTestCase { + static let defaultTimeout: TimeInterval = 0.2 + static let echoProvider: [String: CallHandlerProvider] = ["echo.Echo": EchoProvider_NIO()] + + func configureChannel(withHandlers handlers: [ChannelHandler]) -> EventLoopFuture { + let channel = EmbeddedChannel() + return channel.pipeline.addHandlers(handlers, first: true) + .map { _ in channel } + } + + /// Waits for `count` responses to be collected and then returns them. The test fails if too many + /// responses are collected or not enough are collected before the timeout. + func waitForGRPCChannelHandlerResponses( + count: Int, + servicesByName: [String: CallHandlerProvider] = echoProvider, + timeout: TimeInterval = defaultTimeout, + callback: @escaping (EmbeddedChannel) throws -> Void + ) -> [RawGRPCServerResponsePart] { + let responseExpectation = expectation(description: "expecting \(count) responses") + responseExpectation.expectedFulfillmentCount = count + responseExpectation.assertForOverFulfill = true + + let collector = CollectingChannelHandler(responseExpectation: responseExpectation) + _ = configureChannel(withHandlers: [collector, GRPCChannelHandler(servicesByName: servicesByName)]) + .thenThrowing(callback) + + waitForExpectations(timeout: timeout) + return collector.responses + } +} diff --git a/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerTests.swift b/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerTests.swift new file mode 100644 index 000000000..1bae3fa20 --- /dev/null +++ b/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerTests.swift @@ -0,0 +1,197 @@ +import Foundation +import XCTest +import NIO +import NIOHTTP1 +@testable import SwiftGRPCNIO + +func gRPCMessage(channel: EmbeddedChannel, compression: Bool = false, message: Data? = nil) -> ByteBuffer { + let messageLength = message?.count ?? 0 + var buffer = channel.allocator.buffer(capacity: 5 + messageLength) + buffer.write(integer: Int8(compression ? 1 : 0)) + buffer.write(integer: UInt32(messageLength)) + if let bytes = message { + buffer.write(bytes: bytes) + } + return buffer +} + +class GRPCChannelHandlerTests: GRPCChannelHandlerResponseCapturingTestCase { + func testUnimplementedMethodReturnsUnimplementedStatus() throws { + let responses = waitForGRPCChannelHandlerResponses(count: 1) { channel in + let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "unimplemented") + try channel.writeInbound(RawGRPCServerRequestPart.head(requestHead)) + } + + XCTAssertNoThrow(try extractStatus(responses[0])) { status in + XCTAssertEqual(status.code, .unimplemented) + } + } + + func testImplementedMethodReturnsHeadersMessageAndStatus() throws { + let responses = waitForGRPCChannelHandlerResponses(count: 3) { channel in + let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Get") + try channel.writeInbound(RawGRPCServerRequestPart.head(requestHead)) + + let request = Echo_EchoRequest.with { $0.text = "echo!" } + let requestData = try request.serializedData() + var buffer = channel.allocator.buffer(capacity: requestData.count) + buffer.write(bytes: requestData) + try channel.writeInbound(RawGRPCServerRequestPart.message(buffer)) + } + + XCTAssertNoThrow(try extractHeaders(responses[0])) + XCTAssertNoThrow(try extractMessage(responses[1])) + XCTAssertNoThrow(try extractStatus(responses[2])) { status in + XCTAssertEqual(status.code, .ok) + } + } + + func testImplementedMethodReturnsStatusForBadlyFormedProto() throws { + let responses = waitForGRPCChannelHandlerResponses(count: 2) { channel in + let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Get") + try channel.writeInbound(RawGRPCServerRequestPart.head(requestHead)) + + var buffer = channel.allocator.buffer(capacity: 3) + buffer.write(bytes: [1, 2, 3]) + try channel.writeInbound(RawGRPCServerRequestPart.message(buffer)) + } + + XCTAssertNoThrow(try extractHeaders(responses[0])) + XCTAssertNoThrow(try extractStatus(responses[1])) { status in + let expectedStatus = GRPCStatus.requestProtoParseError + XCTAssertEqual(status.code, expectedStatus.code) + XCTAssertEqual(status.message, expectedStatus.message) + } + } +} + +class HTTP1ToRawGRPCServerCodecTests: GRPCChannelHandlerResponseCapturingTestCase { + func testUnimplementedStatusReturnedWhenCompressionFlagIsSet() throws { + let responses = waitForGRPCChannelHandlerResponses(count: 2) { channel in + let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Get") + try channel.writeInbound(HTTPServerRequestPart.head(requestHead)) + try channel.writeInbound(HTTPServerRequestPart.body(gRPCMessage(channel: channel, compression: true))) + } + + XCTAssertNoThrow(try extractHeaders(responses[0])) + XCTAssertNoThrow(try extractStatus(responses[1])) { status in + let expected = GRPCStatus.unsupportedCompression + XCTAssertEqual(status.code, expected.code) + XCTAssertEqual(status.message, expected.message) + } + } + + func testMessageCanBeSentAcrossMultipleByteBuffers() throws { + let responses = waitForGRPCChannelHandlerResponses(count: 3) { channel in + let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Get") + // Sending the header allocates a buffer. + try channel.writeInbound(HTTPServerRequestPart.head(requestHead)) + + let request = Echo_EchoRequest.with { $0.text = "echo!" } + let requestAsData = try request.serializedData() + + var buffer = channel.allocator.buffer(capacity: 1) + buffer.write(integer: Int8(0)) + try channel.writeInbound(HTTPServerRequestPart.body(buffer)) + + buffer = channel.allocator.buffer(capacity: 4) + buffer.write(integer: Int32(requestAsData.count)) + try channel.writeInbound(HTTPServerRequestPart.body(buffer)) + + buffer = channel.allocator.buffer(capacity: requestAsData.count) + buffer.write(bytes: requestAsData) + try channel.writeInbound(HTTPServerRequestPart.body(buffer)) + } + + XCTAssertNoThrow(try extractHeaders(responses[0])) + XCTAssertNoThrow(try extractMessage(responses[1])) + XCTAssertNoThrow(try extractStatus(responses[2])) { status in + XCTAssertEqual(status.code, .ok) + } + } + + func testInternalErrorStatusIsReturnedIfMessageCannotBeDeserialized() throws { + let responses = waitForGRPCChannelHandlerResponses(count: 2) { channel in + let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Get") + try channel.writeInbound(HTTPServerRequestPart.head(requestHead)) + + let buffer = gRPCMessage(channel: channel, message: Data(bytes: [42])) + try channel.writeInbound(HTTPServerRequestPart.body(buffer)) + } + + XCTAssertNoThrow(try extractHeaders(responses[0])) + XCTAssertNoThrow(try extractStatus(responses[1])) { status in + let expected = GRPCStatus.requestProtoParseError + XCTAssertEqual(status.code, expected.code) + XCTAssertEqual(status.message, expected.message) + } + } + + func testInternalErrorStatusIsReturnedWhenSendingTrailersInRequest() throws { + let responses = waitForGRPCChannelHandlerResponses(count: 2) { channel in + // We have to use "Collect" (client streaming) as the tests rely on `EmbeddedChannel` which runs in this thread. + // In the current server implementation, responses from unary calls send a status immediately after sending the response. + // As such, a unary "Get" would return an "ok" status before the trailers would be sent. + let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Collect") + try channel.writeInbound(HTTPServerRequestPart.head(requestHead)) + try channel.writeInbound(HTTPServerRequestPart.body(gRPCMessage(channel: channel))) + + var trailers = HTTPHeaders() + trailers.add(name: "foo", value: "bar") + try channel.writeInbound(HTTPServerRequestPart.end(trailers)) + } + + XCTAssertNoThrow(try extractHeaders(responses[0])) + XCTAssertNoThrow(try extractStatus(responses[1])) { status in + XCTAssertEqual(status.code, .internalError) + } + } + + func testOnlyOneStatusIsReturned() throws { + let responses = waitForGRPCChannelHandlerResponses(count: 3) { channel in + let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Get") + try channel.writeInbound(HTTPServerRequestPart.head(requestHead)) + try channel.writeInbound(HTTPServerRequestPart.body(gRPCMessage(channel: channel))) + + // Sending trailers with `.end` should trigger an error. However, writing a message to a unary call + // will trigger a response and status to be sent back. Since we're using `EmbeddedChannel` this will + // be done before the trailers are sent. If a 4th resposne were to be sent (for the error status) then + // the test would fail. + + var trailers = HTTPHeaders() + trailers.add(name: "foo", value: "bar") + try channel.writeInbound(HTTPServerRequestPart.end(trailers)) + } + + XCTAssertNoThrow(try extractHeaders(responses[0])) + XCTAssertNoThrow(try extractMessage(responses[1])) + XCTAssertNoThrow(try extractStatus(responses[2])) { status in + XCTAssertEqual(status.code, .ok) + } + } + + override func waitForGRPCChannelHandlerResponses( + count: Int, + servicesByName: [String: CallHandlerProvider] = GRPCChannelHandlerResponseCapturingTestCase.echoProvider, + timeout: TimeInterval = GRPCChannelHandlerResponseCapturingTestCase.defaultTimeout, + callback: @escaping (EmbeddedChannel) throws -> Void + ) -> [RawGRPCServerResponsePart] { + return super.waitForGRPCChannelHandlerResponses(count: count, servicesByName: servicesByName, timeout: timeout) { channel in + _ = channel.pipeline.addHandlers(HTTP1ToRawGRPCServerCodec(), first: true) + .thenThrowing { _ in try callback(channel) } + } + } +} + +// Assert the given expression does not throw, and validate the return value from that expression. +public func XCTAssertNoThrow( + _ expression: @autoclosure () throws -> T, + _ message: String = "", + file: StaticString = #file, + line: UInt = #line, + validate: (T) -> Void +) { + var value: T? = nil + XCTAssertNoThrow(try value = expression(), message, file: file, line: line) + value.map { validate($0) } +} diff --git a/Tests/SwiftGRPCNIOTests/NIOServerTests.swift b/Tests/SwiftGRPCNIOTests/NIOServerTests.swift index 61801389d..60fbbe526 100644 --- a/Tests/SwiftGRPCNIOTests/NIOServerTests.swift +++ b/Tests/SwiftGRPCNIOTests/NIOServerTests.swift @@ -122,6 +122,12 @@ extension NIOServerTests { XCTAssertEqual("Swift echo get: foo", try! client.get(Echo_EchoRequest(text: "foo")).text) } + func testUnaryWithLargeData() throws { + // Default max frame size is: 16,384. We'll exceed this as we also have to send the size and compression flag. + let request = Echo_EchoRequest.with { $0.text = String(repeating: "e", count: 16_384) } + XCTAssertNoThrow(try client.get(request)) + } + func testUnaryLotsOfRequests() { // Sending that many requests at once can sometimes trip things up, it seems. client.timeout = 5.0 @@ -135,6 +141,10 @@ extension NIOServerTests { } print("total time for \(numberOfRequests) requests: \(Double(clock() - clockStart) / Double(CLOCKS_PER_SEC))") } + + func testUnaryEmptyRequest() throws { + XCTAssertNoThrow(try client.get(Echo_EchoRequest())) + } } extension NIOServerTests { diff --git a/Tests/SwiftGRPCNIOTests/echo_nio.grpc.swift b/Tests/SwiftGRPCNIOTests/echo_nio.grpc.swift index e3cfbfb44..6a235fece 100644 --- a/Tests/SwiftGRPCNIOTests/echo_nio.grpc.swift +++ b/Tests/SwiftGRPCNIOTests/echo_nio.grpc.swift @@ -40,29 +40,29 @@ extension Echo_EchoProvider_NIO { /// Determines, calls and returns the appropriate request handler, depending on the request's method. /// Returns nil for methods not handled by this service. - internal func handleMethod(_ methodName: String, request: HTTPRequestHead, serverHandler: GRPCChannelHandler, channel: Channel) -> GRPCCallHandler? { + internal func handleMethod(_ methodName: String, request: HTTPRequestHead, serverHandler: GRPCChannelHandler, channel: Channel, errorHandler: ((Error) -> Void)? = nil) -> GRPCCallHandler? { switch methodName { case "Get": - return UnaryCallHandler(channel: channel, request: request) { context in + return UnaryCallHandler(channel: channel, request: request, errorHandler: errorHandler) { context in return { request in self.get(request: request, context: context) } } case "Expand": - return ServerStreamingCallHandler(channel: channel, request: request) { context in + return ServerStreamingCallHandler(channel: channel, request: request, errorHandler: errorHandler) { context in return { request in self.expand(request: request, context: context) } } case "Collect": - return ClientStreamingCallHandler(channel: channel, request: request) { context in + return ClientStreamingCallHandler(channel: channel, request: request, errorHandler: errorHandler) { context in return self.collect(context: context) } case "Update": - return BidirectionalStreamingCallHandler(channel: channel, request: request) { context in + return BidirectionalStreamingCallHandler(channel: channel, request: request, errorHandler: errorHandler) { context in return self.update(context: context) } From d2be70d10bbc9ffd65011670d6a7dd7ec9ef9d47 Mon Sep 17 00:00:00 2001 From: George Barnett Date: Wed, 30 Jan 2019 14:19:41 +0000 Subject: [PATCH 02/10] Update error delegate --- .../CallHandlers/BaseCallHandler.swift | 11 ++++---- .../BidirectionalStreamingCallHandler.swift | 4 +-- .../ClientStreamingCallHandler.swift | 4 +-- .../ServerStreamingCallHandler.swift | 4 +-- .../CallHandlers/UnaryCallHandler.swift | 4 +-- Sources/SwiftGRPCNIO/GRPCChannelHandler.swift | 15 +++++------ Sources/SwiftGRPCNIO/GRPCServer.swift | 4 +-- .../SwiftGRPCNIO/ServerErrorDelegate.swift | 23 +++++++++++++++++ .../Generator-Server.swift | 4 +-- ...nnelHandlerResponseCapturingTestCase.swift | 25 ++++++------------- .../GRPCChannelHandlerTests.swift | 21 ++++++++-------- Tests/SwiftGRPCNIOTests/echo_nio.grpc.swift | 10 ++++---- 12 files changed, 71 insertions(+), 58 deletions(-) create mode 100644 Sources/SwiftGRPCNIO/ServerErrorDelegate.swift diff --git a/Sources/SwiftGRPCNIO/CallHandlers/BaseCallHandler.swift b/Sources/SwiftGRPCNIO/CallHandlers/BaseCallHandler.swift index 1808cafa5..3fdfa1b20 100644 --- a/Sources/SwiftGRPCNIO/CallHandlers/BaseCallHandler.swift +++ b/Sources/SwiftGRPCNIO/CallHandlers/BaseCallHandler.swift @@ -25,10 +25,10 @@ public class BaseCallHandler: private var serverCanWrite = true /// Called for each error recieved in `errorCaught(ctx:error:)`. - private let errorHandler: ((Error) -> Void)? + private weak var errorDelegate: ServerErrorDelegate? - public init(errorHandler: ((Error) -> Void)? = nil) { - self.errorHandler = errorHandler + public init(errorDelegate: ServerErrorDelegate? = nil) { + self.errorDelegate = errorDelegate } } @@ -39,9 +39,10 @@ extension BaseCallHandler: ChannelInboundHandler { /// appropriate status is written. Errors which don't conform to `GRPCStatusTransformable` /// return a status with code `.internalError`. public func errorCaught(ctx: ChannelHandlerContext, error: Error) { - errorHandler?(error) + errorDelegate?.observe(error) - let status = (error as? GRPCStatusTransformable)?.asGRPCStatus() ?? GRPCStatus.processingError + let transformed = errorDelegate?.transform(error) ?? error + let status = (transformed as? GRPCStatusTransformable)?.asGRPCStatus() ?? GRPCStatus.processingError self.write(ctx: ctx, data: NIOAny(GRPCServerResponsePart.status(status)), promise: nil) } diff --git a/Sources/SwiftGRPCNIO/CallHandlers/BidirectionalStreamingCallHandler.swift b/Sources/SwiftGRPCNIO/CallHandlers/BidirectionalStreamingCallHandler.swift index f147ec719..2d5a4294c 100644 --- a/Sources/SwiftGRPCNIO/CallHandlers/BidirectionalStreamingCallHandler.swift +++ b/Sources/SwiftGRPCNIO/CallHandlers/BidirectionalStreamingCallHandler.swift @@ -15,8 +15,8 @@ public class BidirectionalStreamingCallHandler Void)?, eventObserverFactory: (StreamingResponseCallContext) -> EventLoopFuture) { - super.init(errorHandler: errorHandler) + public init(channel: Channel, request: HTTPRequestHead, errorDelegate: ServerErrorDelegate?, eventObserverFactory: (StreamingResponseCallContext) -> EventLoopFuture) { + super.init(errorDelegate: errorDelegate) let context = StreamingResponseCallContextImpl(channel: channel, request: request) self.context = context let eventObserver = eventObserverFactory(context) diff --git a/Sources/SwiftGRPCNIO/CallHandlers/ClientStreamingCallHandler.swift b/Sources/SwiftGRPCNIO/CallHandlers/ClientStreamingCallHandler.swift index 8886d5660..a6213a497 100644 --- a/Sources/SwiftGRPCNIO/CallHandlers/ClientStreamingCallHandler.swift +++ b/Sources/SwiftGRPCNIO/CallHandlers/ClientStreamingCallHandler.swift @@ -14,8 +14,8 @@ public class ClientStreamingCallHandler Void)?, eventObserverFactory: (UnaryResponseCallContext) -> EventLoopFuture) { - super.init(errorHandler: errorHandler) + public init(channel: Channel, request: HTTPRequestHead, errorDelegate: ServerErrorDelegate?, eventObserverFactory: (UnaryResponseCallContext) -> EventLoopFuture) { + super.init(errorDelegate: errorDelegate) let context = UnaryResponseCallContextImpl(channel: channel, request: request) self.context = context let eventObserver = eventObserverFactory(context) diff --git a/Sources/SwiftGRPCNIO/CallHandlers/ServerStreamingCallHandler.swift b/Sources/SwiftGRPCNIO/CallHandlers/ServerStreamingCallHandler.swift index ed01046d3..dde2c2306 100644 --- a/Sources/SwiftGRPCNIO/CallHandlers/ServerStreamingCallHandler.swift +++ b/Sources/SwiftGRPCNIO/CallHandlers/ServerStreamingCallHandler.swift @@ -13,8 +13,8 @@ public class ServerStreamingCallHandler? - public init(channel: Channel, request: HTTPRequestHead, errorHandler: ((Error) -> Void)?, eventObserverFactory: (StreamingResponseCallContext) -> EventObserver) { - super.init(errorHandler: errorHandler) + public init(channel: Channel, request: HTTPRequestHead, errorDelegate: ServerErrorDelegate?, eventObserverFactory: (StreamingResponseCallContext) -> EventObserver) { + super.init(errorDelegate: errorDelegate) let context = StreamingResponseCallContextImpl(channel: channel, request: request) self.context = context self.eventObserver = eventObserverFactory(context) diff --git a/Sources/SwiftGRPCNIO/CallHandlers/UnaryCallHandler.swift b/Sources/SwiftGRPCNIO/CallHandlers/UnaryCallHandler.swift index a1fe31c26..2e275d207 100644 --- a/Sources/SwiftGRPCNIO/CallHandlers/UnaryCallHandler.swift +++ b/Sources/SwiftGRPCNIO/CallHandlers/UnaryCallHandler.swift @@ -14,8 +14,8 @@ public class UnaryCallHandler private var context: UnaryResponseCallContext? - public init(channel: Channel, request: HTTPRequestHead, errorHandler: ((Error) -> Void)?, eventObserverFactory: (UnaryResponseCallContext) -> EventObserver) { - super.init(errorHandler: errorHandler) + public init(channel: Channel, request: HTTPRequestHead, errorDelegate: ServerErrorDelegate?, eventObserverFactory: (UnaryResponseCallContext) -> EventObserver) { + super.init(errorDelegate: errorDelegate) let context = UnaryResponseCallContextImpl(channel: channel, request: request) self.context = context self.eventObserver = eventObserverFactory(context) diff --git a/Sources/SwiftGRPCNIO/GRPCChannelHandler.swift b/Sources/SwiftGRPCNIO/GRPCChannelHandler.swift index 00ad07d43..3a913b061 100644 --- a/Sources/SwiftGRPCNIO/GRPCChannelHandler.swift +++ b/Sources/SwiftGRPCNIO/GRPCChannelHandler.swift @@ -19,7 +19,7 @@ public protocol CallHandlerProvider: class { /// Determines, calls and returns the appropriate request handler (`GRPCCallHandler`), depending on the request's /// method. Returns nil for methods not handled by this service. - func handleMethod(_ methodName: String, request: HTTPRequestHead, serverHandler: GRPCChannelHandler, channel: Channel, errorHandler: ((Error) -> Void)?) -> GRPCCallHandler? + func handleMethod(_ methodName: String, request: HTTPRequestHead, serverHandler: GRPCChannelHandler, channel: Channel, errorDelegate: ServerErrorDelegate?) -> GRPCCallHandler? } /// Listens on a newly-opened HTTP2 subchannel and yields to the sub-handler matching a call, if available. @@ -28,11 +28,11 @@ public protocol CallHandlerProvider: class { /// for an `GRPCCallHandler` object. That object is then forwarded the individual gRPC messages. public final class GRPCChannelHandler { private let servicesByName: [String: CallHandlerProvider] - private let errorHandler: ((Error) -> Void)? + private weak var errorDelegate: ServerErrorDelegate? - public init(servicesByName: [String: CallHandlerProvider], errorHandler: ((Error) -> Void)? = nil) { + public init(servicesByName: [String: CallHandlerProvider], errorDelegate: ServerErrorDelegate? = nil) { self.servicesByName = servicesByName - self.errorHandler = errorHandler + self.errorDelegate = errorDelegate } } @@ -41,9 +41,10 @@ extension GRPCChannelHandler: ChannelInboundHandler { public typealias OutboundOut = RawGRPCServerResponsePart public func errorCaught(ctx: ChannelHandlerContext, error: Error) { - errorHandler?(error) + errorDelegate?.observe(error) - let status = (error as? GRPCStatusTransformable)?.asGRPCStatus() ?? GRPCStatus.processingError + let transformedError = (errorDelegate?.transform(error) ?? error) + let status = (transformedError as? GRPCStatusTransformable)?.asGRPCStatus() ?? GRPCStatus.processingError ctx.writeAndFlush(wrapOutboundOut(.status(status)), promise: nil) } @@ -88,7 +89,7 @@ extension GRPCChannelHandler: ChannelInboundHandler { let uriComponents = requestHead.uri.components(separatedBy: "/") guard uriComponents.count >= 3 && uriComponents[0].isEmpty, let providerForServiceName = servicesByName[uriComponents[1]], - let callHandler = providerForServiceName.handleMethod(uriComponents[2], request: requestHead, serverHandler: self, channel: channel, errorHandler: errorHandler) else { + let callHandler = providerForServiceName.handleMethod(uriComponents[2], request: requestHead, serverHandler: self, channel: channel, errorDelegate: errorDelegate) else { return nil } return callHandler diff --git a/Sources/SwiftGRPCNIO/GRPCServer.swift b/Sources/SwiftGRPCNIO/GRPCServer.swift index 67aa1a9c0..c658d5fce 100644 --- a/Sources/SwiftGRPCNIO/GRPCServer.swift +++ b/Sources/SwiftGRPCNIO/GRPCServer.swift @@ -13,7 +13,7 @@ public final class GRPCServer { port: Int, eventLoopGroup: EventLoopGroup, serviceProviders: [CallHandlerProvider], - errorHandler: ((Error) -> Void)? = nil + errorDelegate: ServerErrorDelegate? = nil ) -> EventLoopFuture { let servicesByName = Dictionary(uniqueKeysWithValues: serviceProviders.map { ($0.serviceName, $0) }) let bootstrap = ServerBootstrap(group: eventLoopGroup) @@ -29,7 +29,7 @@ public final class GRPCServer { let multiplexer = HTTP2StreamMultiplexer { (channel, streamID) -> EventLoopFuture in return channel.pipeline.add(handler: HTTP2ToHTTP1ServerCodec(streamID: streamID)) .then { channel.pipeline.add(handler: HTTP1ToRawGRPCServerCodec()) } - .then { channel.pipeline.add(handler: GRPCChannelHandler(servicesByName: servicesByName, errorHandler: errorHandler)) } + .then { channel.pipeline.add(handler: GRPCChannelHandler(servicesByName: servicesByName, errorDelegate: errorDelegate)) } } return channel.pipeline.add(handler: multiplexer) diff --git a/Sources/SwiftGRPCNIO/ServerErrorDelegate.swift b/Sources/SwiftGRPCNIO/ServerErrorDelegate.swift new file mode 100644 index 000000000..fee423ceb --- /dev/null +++ b/Sources/SwiftGRPCNIO/ServerErrorDelegate.swift @@ -0,0 +1,23 @@ +import Foundation + +public protocol ServerErrorDelegate: class { + /// Called when an error thrown in the channel pipeline. + func observe(_ error: Error) + + /// Transforms the given error into a new error. + /// + /// This allows framework to transform errors which may be out of their control + /// due to third-party libraries, for example, into more meaningful errors or + /// `GRPCStatus` errors. Errors returned from this protocol are not passed to + /// `observe`. + /// + /// - note: + /// This defaults to returning the provided error. + func transform(_ error: Error) -> Error +} + +public extension ServerErrorDelegate { + func transform(_ error: Error) -> Error { + return error + } +} diff --git a/Sources/protoc-gen-swiftgrpc/Generator-Server.swift b/Sources/protoc-gen-swiftgrpc/Generator-Server.swift index ec46ab4e1..3d782e9a9 100644 --- a/Sources/protoc-gen-swiftgrpc/Generator-Server.swift +++ b/Sources/protoc-gen-swiftgrpc/Generator-Server.swift @@ -85,7 +85,7 @@ extension Generator { if options.generateNIOImplementation { println("/// Determines, calls and returns the appropriate request handler, depending on the request's method.") println("/// Returns nil for methods not handled by this service.") - println("\(access) func handleMethod(_ methodName: String, request: HTTPRequestHead, serverHandler: GRPCChannelHandler, channel: Channel, errorHandler: ((Error) -> Void)? = nil) -> GRPCCallHandler? {") + println("\(access) func handleMethod(_ methodName: String, request: HTTPRequestHead, serverHandler: GRPCChannelHandler, channel: Channel, errorDelegate: ServerErrorDelegate? = nil) -> GRPCCallHandler? {") indent() println("switch methodName {") for method in service.methods { @@ -99,7 +99,7 @@ extension Generator { case .clientStreaming: callHandlerType = "ClientStreamingCallHandler" case .bidirectionalStreaming: callHandlerType = "BidirectionalStreamingCallHandler" } - println("return \(callHandlerType)(channel: channel, request: request, errorHandler: errorHandler) { context in") + println("return \(callHandlerType)(channel: channel, request: request, errorDelegate: errorDelegate) { context in") indent() switch streamingType(method) { case .unary, .serverStreaming: diff --git a/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerResponseCapturingTestCase.swift b/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerResponseCapturingTestCase.swift index 3f0eda98f..5daaf4eb8 100644 --- a/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerResponseCapturingTestCase.swift +++ b/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerResponseCapturingTestCase.swift @@ -34,20 +34,13 @@ func extractStatus(_ response: RawGRPCServerResponsePart) throws -> GRPCStatus { class CollectingChannelHandler: ChannelOutboundHandler { var responses: [OutboundIn] = [] - var responseExpectation: XCTestExpectation? - - init(responseExpectation: XCTestExpectation? = nil) { - self.responseExpectation = responseExpectation - } func write(ctx: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { responses.append(unwrapOutboundIn(data)) - responseExpectation?.fulfill() } } class GRPCChannelHandlerResponseCapturingTestCase: XCTestCase { - static let defaultTimeout: TimeInterval = 0.2 static let echoProvider: [String: CallHandlerProvider] = ["echo.Echo": EchoProvider_NIO()] func configureChannel(withHandlers handlers: [ChannelHandler]) -> EventLoopFuture { @@ -56,23 +49,19 @@ class GRPCChannelHandlerResponseCapturingTestCase: XCTestCase { .map { _ in channel } } - /// Waits for `count` responses to be collected and then returns them. The test fails if too many - /// responses are collected or not enough are collected before the timeout. + /// Waits for `count` responses to be collected and then returns them. The test fails if the number + /// of collected responses does not match the expected. func waitForGRPCChannelHandlerResponses( count: Int, servicesByName: [String: CallHandlerProvider] = echoProvider, - timeout: TimeInterval = defaultTimeout, callback: @escaping (EmbeddedChannel) throws -> Void - ) -> [RawGRPCServerResponsePart] { - let responseExpectation = expectation(description: "expecting \(count) responses") - responseExpectation.expectedFulfillmentCount = count - responseExpectation.assertForOverFulfill = true - - let collector = CollectingChannelHandler(responseExpectation: responseExpectation) - _ = configureChannel(withHandlers: [collector, GRPCChannelHandler(servicesByName: servicesByName)]) + ) throws -> [RawGRPCServerResponsePart] { + let collector = CollectingChannelHandler() + try configureChannel(withHandlers: [collector, GRPCChannelHandler(servicesByName: servicesByName)]) .thenThrowing(callback) + .wait() - waitForExpectations(timeout: timeout) + XCTAssertEqual(count, collector.responses.count) return collector.responses } } diff --git a/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerTests.swift b/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerTests.swift index 1bae3fa20..4ecd8b96e 100644 --- a/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerTests.swift +++ b/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerTests.swift @@ -17,7 +17,7 @@ func gRPCMessage(channel: EmbeddedChannel, compression: Bool = false, message: D class GRPCChannelHandlerTests: GRPCChannelHandlerResponseCapturingTestCase { func testUnimplementedMethodReturnsUnimplementedStatus() throws { - let responses = waitForGRPCChannelHandlerResponses(count: 1) { channel in + let responses = try waitForGRPCChannelHandlerResponses(count: 1) { channel in let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "unimplemented") try channel.writeInbound(RawGRPCServerRequestPart.head(requestHead)) } @@ -28,7 +28,7 @@ class GRPCChannelHandlerTests: GRPCChannelHandlerResponseCapturingTestCase { } func testImplementedMethodReturnsHeadersMessageAndStatus() throws { - let responses = waitForGRPCChannelHandlerResponses(count: 3) { channel in + let responses = try waitForGRPCChannelHandlerResponses(count: 3) { channel in let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Get") try channel.writeInbound(RawGRPCServerRequestPart.head(requestHead)) @@ -47,7 +47,7 @@ class GRPCChannelHandlerTests: GRPCChannelHandlerResponseCapturingTestCase { } func testImplementedMethodReturnsStatusForBadlyFormedProto() throws { - let responses = waitForGRPCChannelHandlerResponses(count: 2) { channel in + let responses = try waitForGRPCChannelHandlerResponses(count: 2) { channel in let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Get") try channel.writeInbound(RawGRPCServerRequestPart.head(requestHead)) @@ -67,7 +67,7 @@ class GRPCChannelHandlerTests: GRPCChannelHandlerResponseCapturingTestCase { class HTTP1ToRawGRPCServerCodecTests: GRPCChannelHandlerResponseCapturingTestCase { func testUnimplementedStatusReturnedWhenCompressionFlagIsSet() throws { - let responses = waitForGRPCChannelHandlerResponses(count: 2) { channel in + let responses = try waitForGRPCChannelHandlerResponses(count: 2) { channel in let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Get") try channel.writeInbound(HTTPServerRequestPart.head(requestHead)) try channel.writeInbound(HTTPServerRequestPart.body(gRPCMessage(channel: channel, compression: true))) @@ -82,7 +82,7 @@ class HTTP1ToRawGRPCServerCodecTests: GRPCChannelHandlerResponseCapturingTestCas } func testMessageCanBeSentAcrossMultipleByteBuffers() throws { - let responses = waitForGRPCChannelHandlerResponses(count: 3) { channel in + let responses = try waitForGRPCChannelHandlerResponses(count: 3) { channel in let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Get") // Sending the header allocates a buffer. try channel.writeInbound(HTTPServerRequestPart.head(requestHead)) @@ -111,7 +111,7 @@ class HTTP1ToRawGRPCServerCodecTests: GRPCChannelHandlerResponseCapturingTestCas } func testInternalErrorStatusIsReturnedIfMessageCannotBeDeserialized() throws { - let responses = waitForGRPCChannelHandlerResponses(count: 2) { channel in + let responses = try waitForGRPCChannelHandlerResponses(count: 2) { channel in let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Get") try channel.writeInbound(HTTPServerRequestPart.head(requestHead)) @@ -128,7 +128,7 @@ class HTTP1ToRawGRPCServerCodecTests: GRPCChannelHandlerResponseCapturingTestCas } func testInternalErrorStatusIsReturnedWhenSendingTrailersInRequest() throws { - let responses = waitForGRPCChannelHandlerResponses(count: 2) { channel in + let responses = try waitForGRPCChannelHandlerResponses(count: 2) { channel in // We have to use "Collect" (client streaming) as the tests rely on `EmbeddedChannel` which runs in this thread. // In the current server implementation, responses from unary calls send a status immediately after sending the response. // As such, a unary "Get" would return an "ok" status before the trailers would be sent. @@ -148,7 +148,7 @@ class HTTP1ToRawGRPCServerCodecTests: GRPCChannelHandlerResponseCapturingTestCas } func testOnlyOneStatusIsReturned() throws { - let responses = waitForGRPCChannelHandlerResponses(count: 3) { channel in + let responses = try waitForGRPCChannelHandlerResponses(count: 3) { channel in let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Get") try channel.writeInbound(HTTPServerRequestPart.head(requestHead)) try channel.writeInbound(HTTPServerRequestPart.body(gRPCMessage(channel: channel))) @@ -173,10 +173,9 @@ class HTTP1ToRawGRPCServerCodecTests: GRPCChannelHandlerResponseCapturingTestCas override func waitForGRPCChannelHandlerResponses( count: Int, servicesByName: [String: CallHandlerProvider] = GRPCChannelHandlerResponseCapturingTestCase.echoProvider, - timeout: TimeInterval = GRPCChannelHandlerResponseCapturingTestCase.defaultTimeout, callback: @escaping (EmbeddedChannel) throws -> Void - ) -> [RawGRPCServerResponsePart] { - return super.waitForGRPCChannelHandlerResponses(count: count, servicesByName: servicesByName, timeout: timeout) { channel in + ) throws -> [RawGRPCServerResponsePart] { + return try super.waitForGRPCChannelHandlerResponses(count: count, servicesByName: servicesByName) { channel in _ = channel.pipeline.addHandlers(HTTP1ToRawGRPCServerCodec(), first: true) .thenThrowing { _ in try callback(channel) } } diff --git a/Tests/SwiftGRPCNIOTests/echo_nio.grpc.swift b/Tests/SwiftGRPCNIOTests/echo_nio.grpc.swift index 6a235fece..4d39364b3 100644 --- a/Tests/SwiftGRPCNIOTests/echo_nio.grpc.swift +++ b/Tests/SwiftGRPCNIOTests/echo_nio.grpc.swift @@ -40,29 +40,29 @@ extension Echo_EchoProvider_NIO { /// Determines, calls and returns the appropriate request handler, depending on the request's method. /// Returns nil for methods not handled by this service. - internal func handleMethod(_ methodName: String, request: HTTPRequestHead, serverHandler: GRPCChannelHandler, channel: Channel, errorHandler: ((Error) -> Void)? = nil) -> GRPCCallHandler? { + internal func handleMethod(_ methodName: String, request: HTTPRequestHead, serverHandler: GRPCChannelHandler, channel: Channel, errorDelegate: ServerErrorDelegate? = nil) -> GRPCCallHandler? { switch methodName { case "Get": - return UnaryCallHandler(channel: channel, request: request, errorHandler: errorHandler) { context in + return UnaryCallHandler(channel: channel, request: request, errorDelegate: errorDelegate) { context in return { request in self.get(request: request, context: context) } } case "Expand": - return ServerStreamingCallHandler(channel: channel, request: request, errorHandler: errorHandler) { context in + return ServerStreamingCallHandler(channel: channel, request: request, errorDelegate: errorDelegate) { context in return { request in self.expand(request: request, context: context) } } case "Collect": - return ClientStreamingCallHandler(channel: channel, request: request, errorHandler: errorHandler) { context in + return ClientStreamingCallHandler(channel: channel, request: request, errorDelegate: errorDelegate) { context in return self.collect(context: context) } case "Update": - return BidirectionalStreamingCallHandler(channel: channel, request: request, errorHandler: errorHandler) { context in + return BidirectionalStreamingCallHandler(channel: channel, request: request, errorDelegate: errorDelegate) { context in return self.update(context: context) } From 005c70db4267921dce593d48b0297c30d664f2d7 Mon Sep 17 00:00:00 2001 From: George Barnett Date: Thu, 31 Jan 2019 10:19:53 +0000 Subject: [PATCH 03/10] Strongly hold errorDelegate in the server until shutdown --- Sources/SwiftGRPCNIO/GRPCServer.swift | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/Sources/SwiftGRPCNIO/GRPCServer.swift b/Sources/SwiftGRPCNIO/GRPCServer.swift index c658d5fce..a54b6c3d8 100644 --- a/Sources/SwiftGRPCNIO/GRPCServer.swift +++ b/Sources/SwiftGRPCNIO/GRPCServer.swift @@ -41,13 +41,22 @@ public final class GRPCServer { .childChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) return bootstrap.bind(host: hostname, port: port) - .map { GRPCServer(channel: $0) } + .map { GRPCServer(channel: $0, errorDelegate: errorDelegate) } } private let channel: Channel + private var errorDelegate: ServerErrorDelegate? - private init(channel: Channel) { + private init(channel: Channel, errorDelegate: ServerErrorDelegate?) { self.channel = channel + + // Maintain a strong reference to ensure it lives as long as the server. + self.errorDelegate = errorDelegate + + // `BaseCallHandler` holds a weak reference to the delegate; nil out this reference to avoid retain cycles. + onClose.whenComplete { + self.errorDelegate = nil + } } /// Fired when the server shuts down. @@ -55,6 +64,7 @@ public final class GRPCServer { return channel.closeFuture } + /// Shut down the server; this should be called to avoid leaking resources. public func close() -> EventLoopFuture { return channel.close(mode: .all) } From f9e60b14b16ed2e21da8729f735ed48397170570 Mon Sep 17 00:00:00 2001 From: George Barnett Date: Tue, 19 Feb 2019 17:21:07 +0000 Subject: [PATCH 04/10] More errors to a dedicated enum, fix typos, etc. --- .../CallHandlers/BaseCallHandler.swift | 8 +-- .../ServerStreamingCallHandler.swift | 2 +- .../CallHandlers/UnaryCallHandler.swift | 2 +- Sources/SwiftGRPCNIO/GRPCChannelHandler.swift | 6 +- Sources/SwiftGRPCNIO/GRPCError.swift | 70 +++++++++++++++++++ Sources/SwiftGRPCNIO/GRPCServer.swift | 4 +- Sources/SwiftGRPCNIO/GRPCServerCodec.swift | 8 +-- Sources/SwiftGRPCNIO/GRPCStatus.swift | 16 +---- .../HTTP1ToRawGRPCServerCodec.swift | 63 +++++++---------- .../LoggingServerErrorDelegate.swift | 24 +++++++ .../SwiftGRPCNIO/ServerErrorDelegate.swift | 21 +++++- .../Generator-Server.swift | 2 +- ...nnelHandlerResponseCapturingTestCase.swift | 56 +++++++-------- .../GRPCChannelHandlerTests.swift | 57 ++++++++------- Tests/SwiftGRPCNIOTests/NIOServerTests.swift | 6 +- Tests/SwiftGRPCNIOTests/TestHelpers.swift | 61 ++++++++++++++++ Tests/SwiftGRPCNIOTests/echo_nio.grpc.swift | 2 +- 17 files changed, 277 insertions(+), 131 deletions(-) create mode 100644 Sources/SwiftGRPCNIO/GRPCError.swift create mode 100644 Sources/SwiftGRPCNIO/LoggingServerErrorDelegate.swift create mode 100644 Tests/SwiftGRPCNIOTests/TestHelpers.swift diff --git a/Sources/SwiftGRPCNIO/CallHandlers/BaseCallHandler.swift b/Sources/SwiftGRPCNIO/CallHandlers/BaseCallHandler.swift index 3fdfa1b20..718e211b1 100644 --- a/Sources/SwiftGRPCNIO/CallHandlers/BaseCallHandler.swift +++ b/Sources/SwiftGRPCNIO/CallHandlers/BaseCallHandler.swift @@ -27,7 +27,7 @@ public class BaseCallHandler: /// Called for each error recieved in `errorCaught(ctx:error:)`. private weak var errorDelegate: ServerErrorDelegate? - public init(errorDelegate: ServerErrorDelegate? = nil) { + public init(errorDelegate: ServerErrorDelegate?) { self.errorDelegate = errorDelegate } } @@ -48,9 +48,9 @@ extension BaseCallHandler: ChannelInboundHandler { public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) { switch self.unwrapInboundIn(data) { - case .head: + case .head(let requestHead): // Head should have been handled by `GRPCChannelHandler`. - self.errorCaught(ctx: ctx, error: GRPCStatus(code: .unknown, message: "unexpectedly received head")) + self.errorCaught(ctx: ctx, error: GRPCError.invalidState("unexpected request head received \(requestHead)")) case .message(let message): do { @@ -71,7 +71,7 @@ extension BaseCallHandler: ChannelOutboundHandler { public func write(ctx: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { guard serverCanWrite else { - promise?.fail(error: GRPCStatus.processingError) + promise?.fail(error: GRPCError.serverNotWritable) return } diff --git a/Sources/SwiftGRPCNIO/CallHandlers/ServerStreamingCallHandler.swift b/Sources/SwiftGRPCNIO/CallHandlers/ServerStreamingCallHandler.swift index dde2c2306..e282c2964 100644 --- a/Sources/SwiftGRPCNIO/CallHandlers/ServerStreamingCallHandler.swift +++ b/Sources/SwiftGRPCNIO/CallHandlers/ServerStreamingCallHandler.swift @@ -29,7 +29,7 @@ public class ServerStreamingCallHandler public override func processMessage(_ message: RequestMessage) throws { guard let eventObserver = self.eventObserver, let context = self.context else { - throw GRPCStatus(code: .unimplemented, message: "multiple messages received on unary call") + throw GRPCError.requestCardinalityViolation } let resultFuture = eventObserver(message) diff --git a/Sources/SwiftGRPCNIO/GRPCChannelHandler.swift b/Sources/SwiftGRPCNIO/GRPCChannelHandler.swift index 3a913b061..156a44863 100644 --- a/Sources/SwiftGRPCNIO/GRPCChannelHandler.swift +++ b/Sources/SwiftGRPCNIO/GRPCChannelHandler.swift @@ -30,7 +30,7 @@ public final class GRPCChannelHandler { private let servicesByName: [String: CallHandlerProvider] private weak var errorDelegate: ServerErrorDelegate? - public init(servicesByName: [String: CallHandlerProvider], errorDelegate: ServerErrorDelegate? = nil) { + public init(servicesByName: [String: CallHandlerProvider], errorDelegate: ServerErrorDelegate?) { self.servicesByName = servicesByName self.errorDelegate = errorDelegate } @@ -43,7 +43,7 @@ extension GRPCChannelHandler: ChannelInboundHandler { public func errorCaught(ctx: ChannelHandlerContext, error: Error) { errorDelegate?.observe(error) - let transformedError = (errorDelegate?.transform(error) ?? error) + let transformedError = errorDelegate?.transform(error) ?? error let status = (transformedError as? GRPCStatusTransformable)?.asGRPCStatus() ?? GRPCStatus.processingError ctx.writeAndFlush(wrapOutboundOut(.status(status)), promise: nil) } @@ -53,7 +53,7 @@ extension GRPCChannelHandler: ChannelInboundHandler { switch requestPart { case .head(let requestHead): guard let callHandler = getCallHandler(channel: ctx.channel, requestHead: requestHead) else { - errorCaught(ctx: ctx, error: GRPCStatus.unimplemented(method: requestHead.uri)) + errorCaught(ctx: ctx, error: GRPCError.unimplementedMethod(requestHead.uri)) return } diff --git a/Sources/SwiftGRPCNIO/GRPCError.swift b/Sources/SwiftGRPCNIO/GRPCError.swift new file mode 100644 index 000000000..5106feb07 --- /dev/null +++ b/Sources/SwiftGRPCNIO/GRPCError.swift @@ -0,0 +1,70 @@ +/* + * Copyright 2019, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import Foundation + +public enum GRPCError: Error, Equatable { + /// The RPC method is not implemented on the server. + case unimplementedMethod(String) + + /// It was not possible to parse the request protobuf. + case requestProtoParseFailure + + /// It was not possible to serialize the response protobuf. + case responseProtoSerializationFailure + + /// The given compression mechanism is not supported. + case unsupportedCompressionMechanism(String) + + /// Compression was indicated in the gRPC message, but not for the call. + case unexpectedCompression + + /// More than one request was sent for a unary-request call. + case requestCardinalityViolation + + /// The server received a message when it was not in a writable state. + case serverNotWritable + + /// An invalid state has been reached; something has gone very wrong. + case invalidState(String) +} + +extension GRPCError: GRPCStatusTransformable { + public func asGRPCStatus() -> GRPCStatus { + // These status codes are informed by: https://github.com/grpc/grpc/blob/master/doc/statuscodes.md + switch self { + case .unimplementedMethod(let method): + return GRPCStatus(code: .unimplemented, message: "unknown method \(method)") + + case .requestProtoParseFailure: + return GRPCStatus(code: .internalError, message: "could not parse request proto") + + case .responseProtoSerializationFailure: + return GRPCStatus(code: .internalError, message: "could not serialize response proto") + + case .unsupportedCompressionMechanism(let mechanism): + return GRPCStatus(code: .unimplemented, message: "unsupported compression mechanism \(mechanism)") + + case .unexpectedCompression: + return GRPCStatus(code: .unimplemented, message: "compression was enabled for this gRPC message but not for this call") + + case .requestCardinalityViolation: + return GRPCStatus(code: .unimplemented, message: "request cardinality violation; method requires exactly one request but client sent more") + + case .serverNotWritable, .invalidState: + return GRPCStatus.processingError + } + } +} diff --git a/Sources/SwiftGRPCNIO/GRPCServer.swift b/Sources/SwiftGRPCNIO/GRPCServer.swift index a54b6c3d8..8ca9edc5c 100644 --- a/Sources/SwiftGRPCNIO/GRPCServer.swift +++ b/Sources/SwiftGRPCNIO/GRPCServer.swift @@ -13,7 +13,7 @@ public final class GRPCServer { port: Int, eventLoopGroup: EventLoopGroup, serviceProviders: [CallHandlerProvider], - errorDelegate: ServerErrorDelegate? = nil + errorDelegate: ServerErrorDelegate? = LoggingServerErrorDelegate() ) -> EventLoopFuture { let servicesByName = Dictionary(uniqueKeysWithValues: serviceProviders.map { ($0.serviceName, $0) }) let bootstrap = ServerBootstrap(group: eventLoopGroup) @@ -53,7 +53,7 @@ public final class GRPCServer { // Maintain a strong reference to ensure it lives as long as the server. self.errorDelegate = errorDelegate - // `BaseCallHandler` holds a weak reference to the delegate; nil out this reference to avoid retain cycles. + // nil out errorDelegate to avoid retain cycles. onClose.whenComplete { self.errorDelegate = nil } diff --git a/Sources/SwiftGRPCNIO/GRPCServerCodec.swift b/Sources/SwiftGRPCNIO/GRPCServerCodec.swift index 4cc5b214d..ef67a75c7 100644 --- a/Sources/SwiftGRPCNIO/GRPCServerCodec.swift +++ b/Sources/SwiftGRPCNIO/GRPCServerCodec.swift @@ -35,7 +35,7 @@ extension GRPCServerCodec: ChannelInboundHandler { do { ctx.fireChannelRead(self.wrapInboundOut(.message(try RequestMessage(serializedData: messageAsData)))) } catch { - ctx.fireErrorCaught(GRPCStatus.requestProtoParseError) + ctx.fireErrorCaught(GRPCError.requestProtoParseFailure) } case .end: @@ -61,9 +61,9 @@ extension GRPCServerCodec: ChannelOutboundHandler { responseBuffer.write(bytes: messageData) ctx.write(self.wrapOutboundOut(.message(responseBuffer)), promise: promise) } catch { - let status = GRPCStatus.responseProtoSerializationError - promise?.fail(error: status) - ctx.fireErrorCaught(status) + let error = GRPCError.responseProtoSerializationFailure + promise?.fail(error: error) + ctx.fireErrorCaught(error) } case .status(let status): diff --git a/Sources/SwiftGRPCNIO/GRPCStatus.swift b/Sources/SwiftGRPCNIO/GRPCStatus.swift index cb1ae07d4..7a023242c 100644 --- a/Sources/SwiftGRPCNIO/GRPCStatus.swift +++ b/Sources/SwiftGRPCNIO/GRPCStatus.swift @@ -2,7 +2,7 @@ import Foundation import NIOHTTP1 /// Encapsulates the result of a gRPC call. -public struct GRPCStatus: Error { +public struct GRPCStatus: Error, Equatable { /// The code to return in the `grpc-status` header. public let code: StatusCode /// The message to return in the `grpc-message` header. @@ -22,24 +22,14 @@ public struct GRPCStatus: Error { public static let ok = GRPCStatus(code: .ok, message: "OK") /// "Internal server error" status. public static let processingError = GRPCStatus(code: .internalError, message: "unknown error processing request") - - /// Status indicating that the given method is not implemented. - public static func unimplemented(method: String) -> GRPCStatus { - return GRPCStatus(code: .unimplemented, message: "unknown method " + method) - } - - // These status codes are informed by: https://github.com/grpc/grpc/blob/master/doc/statuscodes.md - static internal let requestProtoParseError = GRPCStatus(code: .internalError, message: "could not parse request proto") - static internal let responseProtoSerializationError = GRPCStatus(code: .internalError, message: "could not serialize response proto") - static internal let unsupportedCompression = GRPCStatus(code: .unimplemented, message: "compression is not supported on the server") } -protocol GRPCStatusTransformable: Error { +public protocol GRPCStatusTransformable: Error { func asGRPCStatus() -> GRPCStatus } extension GRPCStatus: GRPCStatusTransformable { - func asGRPCStatus() -> GRPCStatus { + public func asGRPCStatus() -> GRPCStatus { return self } } diff --git a/Sources/SwiftGRPCNIO/HTTP1ToRawGRPCServerCodec.swift b/Sources/SwiftGRPCNIO/HTTP1ToRawGRPCServerCodec.swift index 3db609585..022eeea98 100644 --- a/Sources/SwiftGRPCNIO/HTTP1ToRawGRPCServerCodec.swift +++ b/Sources/SwiftGRPCNIO/HTTP1ToRawGRPCServerCodec.swift @@ -27,8 +27,8 @@ public enum RawGRPCServerResponsePart { /// /// The translation from HTTP2 to HTTP1 is done by `HTTP2ToHTTP1ServerCodec`. public final class HTTP1ToRawGRPCServerCodec { - internal var inboundState = InboundState.expectingHeaders - internal var outboundState = OutboundState.expectingHeaders + var inboundState = InboundState.expectingHeaders + var outboundState = OutboundState.expectingHeaders private var buffer: NIO.ByteBuffer? @@ -46,7 +46,7 @@ extension HTTP1ToRawGRPCServerCodec { enum Body { case expectingCompressedFlag case expectingMessageLength - case receivedMessageLength(UInt32) + case expectingMoreMessageBytes(UInt32) } } @@ -57,20 +57,6 @@ extension HTTP1ToRawGRPCServerCodec { } } -extension HTTP1ToRawGRPCServerCodec { - struct StateMachineError: Error, GRPCStatusTransformable { - private let message: String - - init(_ message: String) { - self.message = message - } - - func asGRPCStatus() -> GRPCStatus { - return GRPCStatus.processingError - } - } -} - extension HTTP1ToRawGRPCServerCodec: ChannelInboundHandler { public typealias InboundIn = HTTPServerRequestPart public typealias InboundOut = RawGRPCServerRequestPart @@ -97,7 +83,7 @@ extension HTTP1ToRawGRPCServerCodec: ChannelInboundHandler { func processHead(ctx: ChannelHandlerContext, requestHead: HTTPRequestHead) throws -> InboundState { guard case .expectingHeaders = inboundState else { - throw StateMachineError("expecteded state .expectingHeaders, got \(inboundState)") + throw GRPCError.invalidState("expecteded state .expectingHeaders, got \(inboundState)") } ctx.fireChannelRead(self.wrapInboundOut(.head(requestHead))) @@ -107,7 +93,7 @@ extension HTTP1ToRawGRPCServerCodec: ChannelInboundHandler { func processBody(ctx: ChannelHandlerContext, body: inout ByteBuffer) throws -> InboundState { guard case .expectingBody(let bodyState) = inboundState else { - throw StateMachineError("expecteded state .expectingBody(_), got \(inboundState)") + throw GRPCError.invalidState("expecteded state .expectingBody(_), got \(inboundState)") } return .expectingBody(try processBodyState(ctx: ctx, initialState: bodyState, messageBuffer: &body)) @@ -124,37 +110,36 @@ extension HTTP1ToRawGRPCServerCodec: ChannelInboundHandler { while true { switch bodyState { case .expectingCompressedFlag: - guard let compressionFlag: Int8 = messageBuffer.readInteger() else { return .expectingCompressedFlag } + guard let compressedFlag: Int8 = messageBuffer.readInteger() else { return .expectingCompressedFlag } // TODO: Add support for compression. - guard compressionFlag == 0 else { throw GRPCStatus.unsupportedCompression } + guard compressedFlag == 0 else { throw GRPCError.unexpectedCompression } bodyState = .expectingMessageLength case .expectingMessageLength: guard let messageLength: UInt32 = messageBuffer.readInteger() else { return .expectingMessageLength } - bodyState = .receivedMessageLength(messageLength) + bodyState = .expectingMoreMessageBytes(messageLength) - case .receivedMessageLength(let messageLength): + case .expectingMoreMessageBytes(let bytesOutstanding): // We need to account for messages being spread across multiple `ByteBuffer`s so buffer them // into `buffer`. Note: when messages are contained within a single `ByteBuffer` we're just // taking a slice so don't incur any extra writes. - guard messageBuffer.readableBytes >= messageLength else { - let remainingBytes = messageLength - numericCast(messageBuffer.readableBytes) + guard messageBuffer.readableBytes >= bytesOutstanding else { + let remainingBytes = bytesOutstanding - numericCast(messageBuffer.readableBytes) if var buffer = buffer { buffer.write(buffer: &messageBuffer) self.buffer = buffer } else { - messageBuffer.reserveCapacity(numericCast(messageLength)) + messageBuffer.reserveCapacity(numericCast(bytesOutstanding)) self.buffer = messageBuffer } - - return .receivedMessageLength(remainingBytes) + return .expectingMoreMessageBytes(remainingBytes) } // We know buffer.readableBytes >= messageLength, so it's okay to force unwrap here. - var slice = messageBuffer.readSlice(length: numericCast(messageLength))! + var slice = messageBuffer.readSlice(length: numericCast(bytesOutstanding))! if var buffer = buffer { buffer.write(buffer: &slice) @@ -170,8 +155,8 @@ extension HTTP1ToRawGRPCServerCodec: ChannelInboundHandler { } private func processEnd(ctx: ChannelHandlerContext, trailers: HTTPHeaders?) throws -> InboundState { - guard trailers == nil else { - throw StateMachineError("unexpected trailers received \(String(describing: trailers))") + if let trailers = trailers { + throw GRPCError.invalidState("unexpected trailers received \(trailers)") } ctx.fireChannelRead(self.wrapInboundOut(.end)) @@ -207,17 +192,19 @@ extension HTTP1ToRawGRPCServerCodec: ChannelOutboundHandler { outboundState = .expectingBodyOrStatus case .status(let status): + // If we error before sending the initial headers (e.g. unimplemtned method) then we won't have sent the request head. + // NIOHTTP2 doesn't support sending a single frame as a "Trailers-Only" response so we still need to loop back and + // send the request head first. + if case .expectingHeaders = outboundState { + var headers = HTTPHeaders() + headers.add(name: "content-type", value: "application/grpc") + self.write(ctx: ctx, data: NIOAny(RawGRPCServerResponsePart.headers(headers)), promise: nil) + } + var trailers = status.trailingMetadata trailers.add(name: "grpc-status", value: String(describing: status.code.rawValue)) trailers.add(name: "grpc-message", value: status.message) - // "Trailers-Only" response - if case .expectingHeaders = outboundState { - trailers.add(name: "content-type", value: "application/grpc") - let responseHead = HTTPResponseHead(version: .init(major: 2, minor: 0), status: .ok) - ctx.write(self.wrapOutboundOut(.head(responseHead)), promise: nil) - } - ctx.writeAndFlush(self.wrapOutboundOut(.end(trailers)), promise: promise) outboundState = .ignore inboundState = .ignore diff --git a/Sources/SwiftGRPCNIO/LoggingServerErrorDelegate.swift b/Sources/SwiftGRPCNIO/LoggingServerErrorDelegate.swift new file mode 100644 index 000000000..86c128d31 --- /dev/null +++ b/Sources/SwiftGRPCNIO/LoggingServerErrorDelegate.swift @@ -0,0 +1,24 @@ +/* + * Copyright 2019, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import Foundation + +public class LoggingServerErrorDelegate: ServerErrorDelegate { + public init() {} + + public func observe(_ error: Error) { + print("[grpc-server] \(error)") + } +} diff --git a/Sources/SwiftGRPCNIO/ServerErrorDelegate.swift b/Sources/SwiftGRPCNIO/ServerErrorDelegate.swift index fee423ceb..15b312959 100644 --- a/Sources/SwiftGRPCNIO/ServerErrorDelegate.swift +++ b/Sources/SwiftGRPCNIO/ServerErrorDelegate.swift @@ -1,12 +1,29 @@ +/* + * Copyright 2019, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ import Foundation +import NIO public protocol ServerErrorDelegate: class { - /// Called when an error thrown in the channel pipeline. + //: FIXME: provide more context about where the error was thrown. + /// Called when an error is thrown in the channel pipeline. func observe(_ error: Error) /// Transforms the given error into a new error. /// - /// This allows framework to transform errors which may be out of their control + /// This allows framework users to transform errors which may be out of their control /// due to third-party libraries, for example, into more meaningful errors or /// `GRPCStatus` errors. Errors returned from this protocol are not passed to /// `observe`. diff --git a/Sources/protoc-gen-swiftgrpc/Generator-Server.swift b/Sources/protoc-gen-swiftgrpc/Generator-Server.swift index 3d782e9a9..a7de29b47 100644 --- a/Sources/protoc-gen-swiftgrpc/Generator-Server.swift +++ b/Sources/protoc-gen-swiftgrpc/Generator-Server.swift @@ -85,7 +85,7 @@ extension Generator { if options.generateNIOImplementation { println("/// Determines, calls and returns the appropriate request handler, depending on the request's method.") println("/// Returns nil for methods not handled by this service.") - println("\(access) func handleMethod(_ methodName: String, request: HTTPRequestHead, serverHandler: GRPCChannelHandler, channel: Channel, errorDelegate: ServerErrorDelegate? = nil) -> GRPCCallHandler? {") + println("\(access) func handleMethod(_ methodName: String, request: HTTPRequestHead, serverHandler: GRPCChannelHandler, channel: Channel, errorDelegate: ServerErrorDelegate?) -> GRPCCallHandler? {") indent() println("switch methodName {") for method in service.methods { diff --git a/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerResponseCapturingTestCase.swift b/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerResponseCapturingTestCase.swift index 5daaf4eb8..863bb11ae 100644 --- a/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerResponseCapturingTestCase.swift +++ b/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerResponseCapturingTestCase.swift @@ -4,34 +4,6 @@ import NIOHTTP1 @testable import SwiftGRPCNIO import XCTest -internal struct CaseExtractError: Error { - let message: String -} - -@discardableResult -func extractHeaders(_ response: RawGRPCServerResponsePart) throws -> HTTPHeaders { - guard case .headers(let headers) = response else { - throw CaseExtractError(message: "\(response) did not match .headers") - } - return headers -} - -@discardableResult -func extractMessage(_ response: RawGRPCServerResponsePart) throws -> ByteBuffer { - guard case .message(let message) = response else { - throw CaseExtractError(message: "\(response) did not match .message") - } - return message -} - -@discardableResult -func extractStatus(_ response: RawGRPCServerResponsePart) throws -> GRPCStatus { - guard case .status(let status) = response else { - throw CaseExtractError(message: "\(response) did not match .status") - } - return status -} - class CollectingChannelHandler: ChannelOutboundHandler { var responses: [OutboundIn] = [] @@ -40,8 +12,19 @@ class CollectingChannelHandler: ChannelOutboundHandler { } } +class CollectingServerErrorDelegate: ServerErrorDelegate { + var errors: [Error] = [] + + func observe(_ error: Error) { + self.errors.append(error) + } +} + class GRPCChannelHandlerResponseCapturingTestCase: XCTestCase { static let echoProvider: [String: CallHandlerProvider] = ["echo.Echo": EchoProvider_NIO()] + class var defaultServiceProvider: [String: CallHandlerProvider] { + return echoProvider + } func configureChannel(withHandlers handlers: [ChannelHandler]) -> EventLoopFuture { let channel = EmbeddedChannel() @@ -49,15 +32,28 @@ class GRPCChannelHandlerResponseCapturingTestCase: XCTestCase { .map { _ in channel } } + var errorCollector: CollectingServerErrorDelegate = CollectingServerErrorDelegate() + + override func setUp() { + errorCollector.errors.removeAll() + } + /// Waits for `count` responses to be collected and then returns them. The test fails if the number /// of collected responses does not match the expected. + /// + /// - Parameters: + /// - count: expected number of responses. + /// - servicesByName: service providers keyed by their service name. + /// - callback: a callback called after the channel has been setup, intended to "fill" the channel + /// with messages. The callback is called before this function returns. + /// - Returns: The responses collected from the pipeline. func waitForGRPCChannelHandlerResponses( count: Int, - servicesByName: [String: CallHandlerProvider] = echoProvider, + servicesByName: [String: CallHandlerProvider] = defaultServiceProvider, callback: @escaping (EmbeddedChannel) throws -> Void ) throws -> [RawGRPCServerResponsePart] { let collector = CollectingChannelHandler() - try configureChannel(withHandlers: [collector, GRPCChannelHandler(servicesByName: servicesByName)]) + try configureChannel(withHandlers: [collector, GRPCChannelHandler(servicesByName: servicesByName, errorDelegate: errorCollector)]) .thenThrowing(callback) .wait() diff --git a/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerTests.swift b/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerTests.swift index 4ecd8b96e..60e1e7afd 100644 --- a/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerTests.swift +++ b/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerTests.swift @@ -18,12 +18,15 @@ func gRPCMessage(channel: EmbeddedChannel, compression: Bool = false, message: D class GRPCChannelHandlerTests: GRPCChannelHandlerResponseCapturingTestCase { func testUnimplementedMethodReturnsUnimplementedStatus() throws { let responses = try waitForGRPCChannelHandlerResponses(count: 1) { channel in - let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "unimplemented") + let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "unimplementedMethodName") try channel.writeInbound(RawGRPCServerRequestPart.head(requestHead)) } + let expectedError = GRPCError.unimplementedMethod("unimplementedMethodName") + XCTAssertEqual(expectedError, errorCollector.errors.first as? GRPCError) + XCTAssertNoThrow(try extractStatus(responses[0])) { status in - XCTAssertEqual(status.code, .unimplemented) + XCTAssertEqual(status, expectedError.asGRPCStatus()) } } @@ -42,7 +45,7 @@ class GRPCChannelHandlerTests: GRPCChannelHandlerResponseCapturingTestCase { XCTAssertNoThrow(try extractHeaders(responses[0])) XCTAssertNoThrow(try extractMessage(responses[1])) XCTAssertNoThrow(try extractStatus(responses[2])) { status in - XCTAssertEqual(status.code, .ok) + XCTAssertEqual(status, .ok) } } @@ -56,28 +59,30 @@ class GRPCChannelHandlerTests: GRPCChannelHandlerResponseCapturingTestCase { try channel.writeInbound(RawGRPCServerRequestPart.message(buffer)) } + let expectedError = GRPCError.requestProtoParseFailure + XCTAssertEqual(expectedError, errorCollector.errors.first as? GRPCError) + XCTAssertNoThrow(try extractHeaders(responses[0])) XCTAssertNoThrow(try extractStatus(responses[1])) { status in - let expectedStatus = GRPCStatus.requestProtoParseError - XCTAssertEqual(status.code, expectedStatus.code) - XCTAssertEqual(status.message, expectedStatus.message) + XCTAssertEqual(status, expectedError.asGRPCStatus()) } } } class HTTP1ToRawGRPCServerCodecTests: GRPCChannelHandlerResponseCapturingTestCase { - func testUnimplementedStatusReturnedWhenCompressionFlagIsSet() throws { + func testInternalErrorStatusReturnedWhenCompressionFlagIsSet() throws { let responses = try waitForGRPCChannelHandlerResponses(count: 2) { channel in let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Get") try channel.writeInbound(HTTPServerRequestPart.head(requestHead)) try channel.writeInbound(HTTPServerRequestPart.body(gRPCMessage(channel: channel, compression: true))) } + let expectedError = GRPCError.unexpectedCompression + XCTAssertEqual(expectedError, errorCollector.errors.first as? GRPCError) + XCTAssertNoThrow(try extractHeaders(responses[0])) XCTAssertNoThrow(try extractStatus(responses[1])) { status in - let expected = GRPCStatus.unsupportedCompression - XCTAssertEqual(status.code, expected.code) - XCTAssertEqual(status.message, expected.message) + XCTAssertEqual(status, expectedError.asGRPCStatus()) } } @@ -106,7 +111,7 @@ class HTTP1ToRawGRPCServerCodecTests: GRPCChannelHandlerResponseCapturingTestCas XCTAssertNoThrow(try extractHeaders(responses[0])) XCTAssertNoThrow(try extractMessage(responses[1])) XCTAssertNoThrow(try extractStatus(responses[2])) { status in - XCTAssertEqual(status.code, .ok) + XCTAssertEqual(status, .ok) } } @@ -119,11 +124,12 @@ class HTTP1ToRawGRPCServerCodecTests: GRPCChannelHandlerResponseCapturingTestCas try channel.writeInbound(HTTPServerRequestPart.body(buffer)) } + let expectedError = GRPCError.requestProtoParseFailure + XCTAssertEqual(expectedError, errorCollector.errors.first as? GRPCError) + XCTAssertNoThrow(try extractHeaders(responses[0])) XCTAssertNoThrow(try extractStatus(responses[1])) { status in - let expected = GRPCStatus.requestProtoParseError - XCTAssertEqual(status.code, expected.code) - XCTAssertEqual(status.message, expected.message) + XCTAssertEqual(status, expectedError.asGRPCStatus()) } } @@ -141,9 +147,15 @@ class HTTP1ToRawGRPCServerCodecTests: GRPCChannelHandlerResponseCapturingTestCas try channel.writeInbound(HTTPServerRequestPart.end(trailers)) } + if case .invalidState(let message)? = errorCollector.errors.first as? GRPCError { + XCTAssert(message.contains("trailers")) + } else { + XCTFail("\(String(describing: errorCollector.errors.first)) was not GRPCError.invalidState") + } + XCTAssertNoThrow(try extractHeaders(responses[0])) XCTAssertNoThrow(try extractStatus(responses[1])) { status in - XCTAssertEqual(status.code, .internalError) + XCTAssertEqual(status, .processingError) } } @@ -166,7 +178,7 @@ class HTTP1ToRawGRPCServerCodecTests: GRPCChannelHandlerResponseCapturingTestCas XCTAssertNoThrow(try extractHeaders(responses[0])) XCTAssertNoThrow(try extractMessage(responses[1])) XCTAssertNoThrow(try extractStatus(responses[2])) { status in - XCTAssertEqual(status.code, .ok) + XCTAssertEqual(status, .ok) } } @@ -181,16 +193,3 @@ class HTTP1ToRawGRPCServerCodecTests: GRPCChannelHandlerResponseCapturingTestCas } } } - -// Assert the given expression does not throw, and validate the return value from that expression. -public func XCTAssertNoThrow( - _ expression: @autoclosure () throws -> T, - _ message: String = "", - file: StaticString = #file, - line: UInt = #line, - validate: (T) -> Void -) { - var value: T? = nil - XCTAssertNoThrow(try value = expression(), message, file: file, line: line) - value.map { validate($0) } -} diff --git a/Tests/SwiftGRPCNIOTests/NIOServerTests.swift b/Tests/SwiftGRPCNIOTests/NIOServerTests.swift index 60fbbe526..5db92c607 100644 --- a/Tests/SwiftGRPCNIOTests/NIOServerTests.swift +++ b/Tests/SwiftGRPCNIOTests/NIOServerTests.swift @@ -124,8 +124,10 @@ extension NIOServerTests { func testUnaryWithLargeData() throws { // Default max frame size is: 16,384. We'll exceed this as we also have to send the size and compression flag. - let request = Echo_EchoRequest.with { $0.text = String(repeating: "e", count: 16_384) } - XCTAssertNoThrow(try client.get(request)) + let longMessage = String(repeating: "e", count: 16_384) + XCTAssertNoThrow(try client.get(Echo_EchoRequest(text: longMessage))) { response in + XCTAssertEqual("Swift echo get: \(longMessage)", response.text) + } } func testUnaryLotsOfRequests() { diff --git a/Tests/SwiftGRPCNIOTests/TestHelpers.swift b/Tests/SwiftGRPCNIOTests/TestHelpers.swift new file mode 100644 index 000000000..73be06743 --- /dev/null +++ b/Tests/SwiftGRPCNIOTests/TestHelpers.swift @@ -0,0 +1,61 @@ +/* + * Copyright 2019, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import Foundation +import XCTest +import SwiftGRPCNIO +import NIO +import NIOHTTP1 + +// Assert the given expression does not throw, and validate the return value from that expression. +public func XCTAssertNoThrow( + _ expression: @autoclosure () throws -> T, + _ message: String = "", + file: StaticString = #file, + line: UInt = #line, + validate: (T) -> Void +) { + var value: T? = nil + XCTAssertNoThrow(try value = expression(), message, file: file, line: line) + value.map { validate($0) } +} + +struct CaseExtractError: Error { + let message: String +} + +@discardableResult +func extractHeaders(_ response: RawGRPCServerResponsePart) throws -> HTTPHeaders { + guard case .headers(let headers) = response else { + throw CaseExtractError(message: "\(response) did not match .headers") + } + return headers +} + +@discardableResult +func extractMessage(_ response: RawGRPCServerResponsePart) throws -> ByteBuffer { + guard case .message(let message) = response else { + throw CaseExtractError(message: "\(response) did not match .message") + } + return message +} + +@discardableResult +func extractStatus(_ response: RawGRPCServerResponsePart) throws -> GRPCStatus { + guard case .status(let status) = response else { + throw CaseExtractError(message: "\(response) did not match .status") + } + return status +} diff --git a/Tests/SwiftGRPCNIOTests/echo_nio.grpc.swift b/Tests/SwiftGRPCNIOTests/echo_nio.grpc.swift index 4d39364b3..ecf86a285 100644 --- a/Tests/SwiftGRPCNIOTests/echo_nio.grpc.swift +++ b/Tests/SwiftGRPCNIOTests/echo_nio.grpc.swift @@ -40,7 +40,7 @@ extension Echo_EchoProvider_NIO { /// Determines, calls and returns the appropriate request handler, depending on the request's method. /// Returns nil for methods not handled by this service. - internal func handleMethod(_ methodName: String, request: HTTPRequestHead, serverHandler: GRPCChannelHandler, channel: Channel, errorDelegate: ServerErrorDelegate? = nil) -> GRPCCallHandler? { + internal func handleMethod(_ methodName: String, request: HTTPRequestHead, serverHandler: GRPCChannelHandler, channel: Channel, errorDelegate: ServerErrorDelegate?) -> GRPCCallHandler? { switch methodName { case "Get": return UnaryCallHandler(channel: channel, request: request, errorDelegate: errorDelegate) { context in From aa0acdcd3c91836c334e799f02a226258f5babb6 Mon Sep 17 00:00:00 2001 From: George Barnett Date: Mon, 25 Feb 2019 17:01:45 +0000 Subject: [PATCH 05/10] Renaming, typo fixes --- .../CallHandlers/BaseCallHandler.swift | 4 ++-- .../ServerStreamingCallHandler.swift | 2 +- .../CallHandlers/UnaryCallHandler.swift | 2 +- Sources/SwiftGRPCNIO/GRPCChannelHandler.swift | 2 +- Sources/SwiftGRPCNIO/GRPCServerCodec.swift | 4 ++-- ...{GRPCError.swift => GRPCServerError.swift} | 4 ++-- .../HTTP1ToRawGRPCServerCodec.swift | 10 +++++----- .../LoggingServerErrorDelegate.swift | 2 +- .../SwiftGRPCNIO/ServerErrorDelegate.swift | 2 +- ...nnelHandlerResponseCapturingTestCase.swift | 4 ---- .../GRPCChannelHandlerTests.swift | 20 ++++++++++--------- 11 files changed, 27 insertions(+), 29 deletions(-) rename Sources/SwiftGRPCNIO/{GRPCError.swift => GRPCServerError.swift} (96%) diff --git a/Sources/SwiftGRPCNIO/CallHandlers/BaseCallHandler.swift b/Sources/SwiftGRPCNIO/CallHandlers/BaseCallHandler.swift index 718e211b1..ae3986fd5 100644 --- a/Sources/SwiftGRPCNIO/CallHandlers/BaseCallHandler.swift +++ b/Sources/SwiftGRPCNIO/CallHandlers/BaseCallHandler.swift @@ -50,7 +50,7 @@ extension BaseCallHandler: ChannelInboundHandler { switch self.unwrapInboundIn(data) { case .head(let requestHead): // Head should have been handled by `GRPCChannelHandler`. - self.errorCaught(ctx: ctx, error: GRPCError.invalidState("unexpected request head received \(requestHead)")) + self.errorCaught(ctx: ctx, error: GRPCServerError.invalidState("unexpected request head received \(requestHead)")) case .message(let message): do { @@ -71,7 +71,7 @@ extension BaseCallHandler: ChannelOutboundHandler { public func write(ctx: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { guard serverCanWrite else { - promise?.fail(error: GRPCError.serverNotWritable) + promise?.fail(error: GRPCServerError.serverNotWritable) return } diff --git a/Sources/SwiftGRPCNIO/CallHandlers/ServerStreamingCallHandler.swift b/Sources/SwiftGRPCNIO/CallHandlers/ServerStreamingCallHandler.swift index e282c2964..6374cea58 100644 --- a/Sources/SwiftGRPCNIO/CallHandlers/ServerStreamingCallHandler.swift +++ b/Sources/SwiftGRPCNIO/CallHandlers/ServerStreamingCallHandler.swift @@ -29,7 +29,7 @@ public class ServerStreamingCallHandler public override func processMessage(_ message: RequestMessage) throws { guard let eventObserver = self.eventObserver, let context = self.context else { - throw GRPCError.requestCardinalityViolation + throw GRPCServerError.requestCardinalityViolation } let resultFuture = eventObserver(message) diff --git a/Sources/SwiftGRPCNIO/GRPCChannelHandler.swift b/Sources/SwiftGRPCNIO/GRPCChannelHandler.swift index 156a44863..466499cff 100644 --- a/Sources/SwiftGRPCNIO/GRPCChannelHandler.swift +++ b/Sources/SwiftGRPCNIO/GRPCChannelHandler.swift @@ -53,7 +53,7 @@ extension GRPCChannelHandler: ChannelInboundHandler { switch requestPart { case .head(let requestHead): guard let callHandler = getCallHandler(channel: ctx.channel, requestHead: requestHead) else { - errorCaught(ctx: ctx, error: GRPCError.unimplementedMethod(requestHead.uri)) + errorCaught(ctx: ctx, error: GRPCServerError.unimplementedMethod(requestHead.uri)) return } diff --git a/Sources/SwiftGRPCNIO/GRPCServerCodec.swift b/Sources/SwiftGRPCNIO/GRPCServerCodec.swift index ef67a75c7..9193c9ea9 100644 --- a/Sources/SwiftGRPCNIO/GRPCServerCodec.swift +++ b/Sources/SwiftGRPCNIO/GRPCServerCodec.swift @@ -35,7 +35,7 @@ extension GRPCServerCodec: ChannelInboundHandler { do { ctx.fireChannelRead(self.wrapInboundOut(.message(try RequestMessage(serializedData: messageAsData)))) } catch { - ctx.fireErrorCaught(GRPCError.requestProtoParseFailure) + ctx.fireErrorCaught(GRPCServerError.requestProtoParseFailure) } case .end: @@ -61,7 +61,7 @@ extension GRPCServerCodec: ChannelOutboundHandler { responseBuffer.write(bytes: messageData) ctx.write(self.wrapOutboundOut(.message(responseBuffer)), promise: promise) } catch { - let error = GRPCError.responseProtoSerializationFailure + let error = GRPCServerError.responseProtoSerializationFailure promise?.fail(error: error) ctx.fireErrorCaught(error) } diff --git a/Sources/SwiftGRPCNIO/GRPCError.swift b/Sources/SwiftGRPCNIO/GRPCServerError.swift similarity index 96% rename from Sources/SwiftGRPCNIO/GRPCError.swift rename to Sources/SwiftGRPCNIO/GRPCServerError.swift index 5106feb07..e42848a54 100644 --- a/Sources/SwiftGRPCNIO/GRPCError.swift +++ b/Sources/SwiftGRPCNIO/GRPCServerError.swift @@ -15,7 +15,7 @@ */ import Foundation -public enum GRPCError: Error, Equatable { +public enum GRPCServerError: Error, Equatable { /// The RPC method is not implemented on the server. case unimplementedMethod(String) @@ -41,7 +41,7 @@ public enum GRPCError: Error, Equatable { case invalidState(String) } -extension GRPCError: GRPCStatusTransformable { +extension GRPCServerError: GRPCStatusTransformable { public func asGRPCStatus() -> GRPCStatus { // These status codes are informed by: https://github.com/grpc/grpc/blob/master/doc/statuscodes.md switch self { diff --git a/Sources/SwiftGRPCNIO/HTTP1ToRawGRPCServerCodec.swift b/Sources/SwiftGRPCNIO/HTTP1ToRawGRPCServerCodec.swift index 022eeea98..1bdd169f3 100644 --- a/Sources/SwiftGRPCNIO/HTTP1ToRawGRPCServerCodec.swift +++ b/Sources/SwiftGRPCNIO/HTTP1ToRawGRPCServerCodec.swift @@ -83,7 +83,7 @@ extension HTTP1ToRawGRPCServerCodec: ChannelInboundHandler { func processHead(ctx: ChannelHandlerContext, requestHead: HTTPRequestHead) throws -> InboundState { guard case .expectingHeaders = inboundState else { - throw GRPCError.invalidState("expecteded state .expectingHeaders, got \(inboundState)") + throw GRPCServerError.invalidState("expecteded state .expectingHeaders, got \(inboundState)") } ctx.fireChannelRead(self.wrapInboundOut(.head(requestHead))) @@ -93,7 +93,7 @@ extension HTTP1ToRawGRPCServerCodec: ChannelInboundHandler { func processBody(ctx: ChannelHandlerContext, body: inout ByteBuffer) throws -> InboundState { guard case .expectingBody(let bodyState) = inboundState else { - throw GRPCError.invalidState("expecteded state .expectingBody(_), got \(inboundState)") + throw GRPCServerError.invalidState("expecteded state .expectingBody(_), got \(inboundState)") } return .expectingBody(try processBodyState(ctx: ctx, initialState: bodyState, messageBuffer: &body)) @@ -113,7 +113,7 @@ extension HTTP1ToRawGRPCServerCodec: ChannelInboundHandler { guard let compressedFlag: Int8 = messageBuffer.readInteger() else { return .expectingCompressedFlag } // TODO: Add support for compression. - guard compressedFlag == 0 else { throw GRPCError.unexpectedCompression } + guard compressedFlag == 0 else { throw GRPCServerError.unexpectedCompression } bodyState = .expectingMessageLength @@ -156,7 +156,7 @@ extension HTTP1ToRawGRPCServerCodec: ChannelInboundHandler { private func processEnd(ctx: ChannelHandlerContext, trailers: HTTPHeaders?) throws -> InboundState { if let trailers = trailers { - throw GRPCError.invalidState("unexpected trailers received \(trailers)") + throw GRPCServerError.invalidState("unexpected trailers received \(trailers)") } ctx.fireChannelRead(self.wrapInboundOut(.end)) @@ -192,7 +192,7 @@ extension HTTP1ToRawGRPCServerCodec: ChannelOutboundHandler { outboundState = .expectingBodyOrStatus case .status(let status): - // If we error before sending the initial headers (e.g. unimplemtned method) then we won't have sent the request head. + // If we error before sending the initial headers (e.g. unimplemented method) then we won't have sent the request head. // NIOHTTP2 doesn't support sending a single frame as a "Trailers-Only" response so we still need to loop back and // send the request head first. if case .expectingHeaders = outboundState { diff --git a/Sources/SwiftGRPCNIO/LoggingServerErrorDelegate.swift b/Sources/SwiftGRPCNIO/LoggingServerErrorDelegate.swift index 86c128d31..b0a30c178 100644 --- a/Sources/SwiftGRPCNIO/LoggingServerErrorDelegate.swift +++ b/Sources/SwiftGRPCNIO/LoggingServerErrorDelegate.swift @@ -19,6 +19,6 @@ public class LoggingServerErrorDelegate: ServerErrorDelegate { public init() {} public func observe(_ error: Error) { - print("[grpc-server] \(error)") + print("[grpc-server][\(Date())] \(error)") } } diff --git a/Sources/SwiftGRPCNIO/ServerErrorDelegate.swift b/Sources/SwiftGRPCNIO/ServerErrorDelegate.swift index 15b312959..83521a6e3 100644 --- a/Sources/SwiftGRPCNIO/ServerErrorDelegate.swift +++ b/Sources/SwiftGRPCNIO/ServerErrorDelegate.swift @@ -17,7 +17,7 @@ import Foundation import NIO public protocol ServerErrorDelegate: class { - //: FIXME: provide more context about where the error was thrown. + //! FIXME: Provide more context about where the error was thrown. /// Called when an error is thrown in the channel pipeline. func observe(_ error: Error) diff --git a/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerResponseCapturingTestCase.swift b/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerResponseCapturingTestCase.swift index 863bb11ae..0999801fd 100644 --- a/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerResponseCapturingTestCase.swift +++ b/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerResponseCapturingTestCase.swift @@ -34,10 +34,6 @@ class GRPCChannelHandlerResponseCapturingTestCase: XCTestCase { var errorCollector: CollectingServerErrorDelegate = CollectingServerErrorDelegate() - override func setUp() { - errorCollector.errors.removeAll() - } - /// Waits for `count` responses to be collected and then returns them. The test fails if the number /// of collected responses does not match the expected. /// diff --git a/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerTests.swift b/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerTests.swift index 60e1e7afd..4d7e190f2 100644 --- a/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerTests.swift +++ b/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerTests.swift @@ -22,8 +22,8 @@ class GRPCChannelHandlerTests: GRPCChannelHandlerResponseCapturingTestCase { try channel.writeInbound(RawGRPCServerRequestPart.head(requestHead)) } - let expectedError = GRPCError.unimplementedMethod("unimplementedMethodName") - XCTAssertEqual(expectedError, errorCollector.errors.first as? GRPCError) + let expectedError = GRPCServerError.unimplementedMethod("unimplementedMethodName") + XCTAssertEqual([expectedError], errorCollector.errors as? [GRPCServerError]) XCTAssertNoThrow(try extractStatus(responses[0])) { status in XCTAssertEqual(status, expectedError.asGRPCStatus()) @@ -59,8 +59,8 @@ class GRPCChannelHandlerTests: GRPCChannelHandlerResponseCapturingTestCase { try channel.writeInbound(RawGRPCServerRequestPart.message(buffer)) } - let expectedError = GRPCError.requestProtoParseFailure - XCTAssertEqual(expectedError, errorCollector.errors.first as? GRPCError) + let expectedError = GRPCServerError.requestProtoParseFailure + XCTAssertEqual([expectedError], errorCollector.errors as? [GRPCServerError]) XCTAssertNoThrow(try extractHeaders(responses[0])) XCTAssertNoThrow(try extractStatus(responses[1])) { status in @@ -77,8 +77,8 @@ class HTTP1ToRawGRPCServerCodecTests: GRPCChannelHandlerResponseCapturingTestCas try channel.writeInbound(HTTPServerRequestPart.body(gRPCMessage(channel: channel, compression: true))) } - let expectedError = GRPCError.unexpectedCompression - XCTAssertEqual(expectedError, errorCollector.errors.first as? GRPCError) + let expectedError = GRPCServerError.unexpectedCompression + XCTAssertEqual([expectedError], errorCollector.errors as? [GRPCServerError]) XCTAssertNoThrow(try extractHeaders(responses[0])) XCTAssertNoThrow(try extractStatus(responses[1])) { status in @@ -124,8 +124,8 @@ class HTTP1ToRawGRPCServerCodecTests: GRPCChannelHandlerResponseCapturingTestCas try channel.writeInbound(HTTPServerRequestPart.body(buffer)) } - let expectedError = GRPCError.requestProtoParseFailure - XCTAssertEqual(expectedError, errorCollector.errors.first as? GRPCError) + let expectedError = GRPCServerError.requestProtoParseFailure + XCTAssertEqual([expectedError], errorCollector.errors as? [GRPCServerError]) XCTAssertNoThrow(try extractHeaders(responses[0])) XCTAssertNoThrow(try extractStatus(responses[1])) { status in @@ -147,7 +147,9 @@ class HTTP1ToRawGRPCServerCodecTests: GRPCChannelHandlerResponseCapturingTestCas try channel.writeInbound(HTTPServerRequestPart.end(trailers)) } - if case .invalidState(let message)? = errorCollector.errors.first as? GRPCError { + XCTAssertEqual(errorCollector.errors.count, 1) + + if case .invalidState(let message)? = errorCollector.errors.first as? GRPCServerError { XCTAssert(message.contains("trailers")) } else { XCTFail("\(String(describing: errorCollector.errors.first)) was not GRPCError.invalidState") From 2b82ae83025240b48ff27a679c776fd32b157c6c Mon Sep 17 00:00:00 2001 From: George Barnett Date: Tue, 26 Feb 2019 10:52:51 +0000 Subject: [PATCH 06/10] Split out GRPCChannelHandlerTests and HTTPToRawGRPCServerCodecTests --- .../GRPCChannelHandlerTests.swift | 144 +---------------- .../HTTP1ToRawGRPCServerCodecTests.swift | 153 ++++++++++++++++++ 2 files changed, 160 insertions(+), 137 deletions(-) create mode 100644 Tests/SwiftGRPCNIOTests/HTTP1ToRawGRPCServerCodecTests.swift diff --git a/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerTests.swift b/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerTests.swift index 4d7e190f2..b97c3f49f 100644 --- a/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerTests.swift +++ b/Tests/SwiftGRPCNIOTests/GRPCChannelHandlerTests.swift @@ -4,18 +4,15 @@ import NIO import NIOHTTP1 @testable import SwiftGRPCNIO -func gRPCMessage(channel: EmbeddedChannel, compression: Bool = false, message: Data? = nil) -> ByteBuffer { - let messageLength = message?.count ?? 0 - var buffer = channel.allocator.buffer(capacity: 5 + messageLength) - buffer.write(integer: Int8(compression ? 1 : 0)) - buffer.write(integer: UInt32(messageLength)) - if let bytes = message { - buffer.write(bytes: bytes) +class GRPCChannelHandlerTests: GRPCChannelHandlerResponseCapturingTestCase { + static var allTests: [(String, (GRPCChannelHandlerTests) -> () throws -> Void)] { + return [ + ("testUnimplementedMethodReturnsUnimplementedStatus", testUnimplementedMethodReturnsUnimplementedStatus), + ("testImplementedMethodReturnsHeadersMessageAndStatus", testImplementedMethodReturnsHeadersMessageAndStatus), + ("testImplementedMethodReturnsStatusForBadlyFormedProto", testImplementedMethodReturnsStatusForBadlyFormedProto), + ] } - return buffer -} -class GRPCChannelHandlerTests: GRPCChannelHandlerResponseCapturingTestCase { func testUnimplementedMethodReturnsUnimplementedStatus() throws { let responses = try waitForGRPCChannelHandlerResponses(count: 1) { channel in let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "unimplementedMethodName") @@ -68,130 +65,3 @@ class GRPCChannelHandlerTests: GRPCChannelHandlerResponseCapturingTestCase { } } } - -class HTTP1ToRawGRPCServerCodecTests: GRPCChannelHandlerResponseCapturingTestCase { - func testInternalErrorStatusReturnedWhenCompressionFlagIsSet() throws { - let responses = try waitForGRPCChannelHandlerResponses(count: 2) { channel in - let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Get") - try channel.writeInbound(HTTPServerRequestPart.head(requestHead)) - try channel.writeInbound(HTTPServerRequestPart.body(gRPCMessage(channel: channel, compression: true))) - } - - let expectedError = GRPCServerError.unexpectedCompression - XCTAssertEqual([expectedError], errorCollector.errors as? [GRPCServerError]) - - XCTAssertNoThrow(try extractHeaders(responses[0])) - XCTAssertNoThrow(try extractStatus(responses[1])) { status in - XCTAssertEqual(status, expectedError.asGRPCStatus()) - } - } - - func testMessageCanBeSentAcrossMultipleByteBuffers() throws { - let responses = try waitForGRPCChannelHandlerResponses(count: 3) { channel in - let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Get") - // Sending the header allocates a buffer. - try channel.writeInbound(HTTPServerRequestPart.head(requestHead)) - - let request = Echo_EchoRequest.with { $0.text = "echo!" } - let requestAsData = try request.serializedData() - - var buffer = channel.allocator.buffer(capacity: 1) - buffer.write(integer: Int8(0)) - try channel.writeInbound(HTTPServerRequestPart.body(buffer)) - - buffer = channel.allocator.buffer(capacity: 4) - buffer.write(integer: Int32(requestAsData.count)) - try channel.writeInbound(HTTPServerRequestPart.body(buffer)) - - buffer = channel.allocator.buffer(capacity: requestAsData.count) - buffer.write(bytes: requestAsData) - try channel.writeInbound(HTTPServerRequestPart.body(buffer)) - } - - XCTAssertNoThrow(try extractHeaders(responses[0])) - XCTAssertNoThrow(try extractMessage(responses[1])) - XCTAssertNoThrow(try extractStatus(responses[2])) { status in - XCTAssertEqual(status, .ok) - } - } - - func testInternalErrorStatusIsReturnedIfMessageCannotBeDeserialized() throws { - let responses = try waitForGRPCChannelHandlerResponses(count: 2) { channel in - let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Get") - try channel.writeInbound(HTTPServerRequestPart.head(requestHead)) - - let buffer = gRPCMessage(channel: channel, message: Data(bytes: [42])) - try channel.writeInbound(HTTPServerRequestPart.body(buffer)) - } - - let expectedError = GRPCServerError.requestProtoParseFailure - XCTAssertEqual([expectedError], errorCollector.errors as? [GRPCServerError]) - - XCTAssertNoThrow(try extractHeaders(responses[0])) - XCTAssertNoThrow(try extractStatus(responses[1])) { status in - XCTAssertEqual(status, expectedError.asGRPCStatus()) - } - } - - func testInternalErrorStatusIsReturnedWhenSendingTrailersInRequest() throws { - let responses = try waitForGRPCChannelHandlerResponses(count: 2) { channel in - // We have to use "Collect" (client streaming) as the tests rely on `EmbeddedChannel` which runs in this thread. - // In the current server implementation, responses from unary calls send a status immediately after sending the response. - // As such, a unary "Get" would return an "ok" status before the trailers would be sent. - let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Collect") - try channel.writeInbound(HTTPServerRequestPart.head(requestHead)) - try channel.writeInbound(HTTPServerRequestPart.body(gRPCMessage(channel: channel))) - - var trailers = HTTPHeaders() - trailers.add(name: "foo", value: "bar") - try channel.writeInbound(HTTPServerRequestPart.end(trailers)) - } - - XCTAssertEqual(errorCollector.errors.count, 1) - - if case .invalidState(let message)? = errorCollector.errors.first as? GRPCServerError { - XCTAssert(message.contains("trailers")) - } else { - XCTFail("\(String(describing: errorCollector.errors.first)) was not GRPCError.invalidState") - } - - XCTAssertNoThrow(try extractHeaders(responses[0])) - XCTAssertNoThrow(try extractStatus(responses[1])) { status in - XCTAssertEqual(status, .processingError) - } - } - - func testOnlyOneStatusIsReturned() throws { - let responses = try waitForGRPCChannelHandlerResponses(count: 3) { channel in - let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Get") - try channel.writeInbound(HTTPServerRequestPart.head(requestHead)) - try channel.writeInbound(HTTPServerRequestPart.body(gRPCMessage(channel: channel))) - - // Sending trailers with `.end` should trigger an error. However, writing a message to a unary call - // will trigger a response and status to be sent back. Since we're using `EmbeddedChannel` this will - // be done before the trailers are sent. If a 4th resposne were to be sent (for the error status) then - // the test would fail. - - var trailers = HTTPHeaders() - trailers.add(name: "foo", value: "bar") - try channel.writeInbound(HTTPServerRequestPart.end(trailers)) - } - - XCTAssertNoThrow(try extractHeaders(responses[0])) - XCTAssertNoThrow(try extractMessage(responses[1])) - XCTAssertNoThrow(try extractStatus(responses[2])) { status in - XCTAssertEqual(status, .ok) - } - } - - override func waitForGRPCChannelHandlerResponses( - count: Int, - servicesByName: [String: CallHandlerProvider] = GRPCChannelHandlerResponseCapturingTestCase.echoProvider, - callback: @escaping (EmbeddedChannel) throws -> Void - ) throws -> [RawGRPCServerResponsePart] { - return try super.waitForGRPCChannelHandlerResponses(count: count, servicesByName: servicesByName) { channel in - _ = channel.pipeline.addHandlers(HTTP1ToRawGRPCServerCodec(), first: true) - .thenThrowing { _ in try callback(channel) } - } - } -} diff --git a/Tests/SwiftGRPCNIOTests/HTTP1ToRawGRPCServerCodecTests.swift b/Tests/SwiftGRPCNIOTests/HTTP1ToRawGRPCServerCodecTests.swift new file mode 100644 index 000000000..bb17fef8e --- /dev/null +++ b/Tests/SwiftGRPCNIOTests/HTTP1ToRawGRPCServerCodecTests.swift @@ -0,0 +1,153 @@ +import Foundation +import XCTest +import NIO +import NIOHTTP1 +@testable import SwiftGRPCNIO + +func gRPCMessage(channel: EmbeddedChannel, compression: Bool = false, message: Data? = nil) -> ByteBuffer { + let messageLength = message?.count ?? 0 + var buffer = channel.allocator.buffer(capacity: 5 + messageLength) + buffer.write(integer: Int8(compression ? 1 : 0)) + buffer.write(integer: UInt32(messageLength)) + if let bytes = message { + buffer.write(bytes: bytes) + } + return buffer +} + +class HTTP1ToRawGRPCServerCodecTests: GRPCChannelHandlerResponseCapturingTestCase { + static var allTests: [(String, (HTTP1ToRawGRPCServerCodecTests) -> () throws -> Void)] { + return [ + ("testInternalErrorStatusReturnedWhenCompressionFlagIsSet", testInternalErrorStatusReturnedWhenCompressionFlagIsSet), + ("testMessageCanBeSentAcrossMultipleByteBuffers", testMessageCanBeSentAcrossMultipleByteBuffers), + ("testInternalErrorStatusIsReturnedIfMessageCannotBeDeserialized", testInternalErrorStatusIsReturnedIfMessageCannotBeDeserialized), + ("testInternalErrorStatusIsReturnedWhenSendingTrailersInRequest", testInternalErrorStatusIsReturnedWhenSendingTrailersInRequest), + ("testOnlyOneStatusIsReturned", testOnlyOneStatusIsReturned), + ] + } + + func testInternalErrorStatusReturnedWhenCompressionFlagIsSet() throws { + let responses = try waitForGRPCChannelHandlerResponses(count: 2) { channel in + let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Get") + try channel.writeInbound(HTTPServerRequestPart.head(requestHead)) + try channel.writeInbound(HTTPServerRequestPart.body(gRPCMessage(channel: channel, compression: true))) + } + + let expectedError = GRPCServerError.unexpectedCompression + XCTAssertEqual([expectedError], errorCollector.errors as? [GRPCServerError]) + + XCTAssertNoThrow(try extractHeaders(responses[0])) + XCTAssertNoThrow(try extractStatus(responses[1])) { status in + XCTAssertEqual(status, expectedError.asGRPCStatus()) + } + } + + func testMessageCanBeSentAcrossMultipleByteBuffers() throws { + let responses = try waitForGRPCChannelHandlerResponses(count: 3) { channel in + let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Get") + // Sending the header allocates a buffer. + try channel.writeInbound(HTTPServerRequestPart.head(requestHead)) + + let request = Echo_EchoRequest.with { $0.text = "echo!" } + let requestAsData = try request.serializedData() + + var buffer = channel.allocator.buffer(capacity: 1) + buffer.write(integer: Int8(0)) + try channel.writeInbound(HTTPServerRequestPart.body(buffer)) + + buffer = channel.allocator.buffer(capacity: 4) + buffer.write(integer: Int32(requestAsData.count)) + try channel.writeInbound(HTTPServerRequestPart.body(buffer)) + + buffer = channel.allocator.buffer(capacity: requestAsData.count) + buffer.write(bytes: requestAsData) + try channel.writeInbound(HTTPServerRequestPart.body(buffer)) + } + + XCTAssertNoThrow(try extractHeaders(responses[0])) + XCTAssertNoThrow(try extractMessage(responses[1])) + XCTAssertNoThrow(try extractStatus(responses[2])) { status in + XCTAssertEqual(status, .ok) + } + } + + func testInternalErrorStatusIsReturnedIfMessageCannotBeDeserialized() throws { + let responses = try waitForGRPCChannelHandlerResponses(count: 2) { channel in + let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Get") + try channel.writeInbound(HTTPServerRequestPart.head(requestHead)) + + let buffer = gRPCMessage(channel: channel, message: Data(bytes: [42])) + try channel.writeInbound(HTTPServerRequestPart.body(buffer)) + } + + let expectedError = GRPCServerError.requestProtoParseFailure + XCTAssertEqual([expectedError], errorCollector.errors as? [GRPCServerError]) + + XCTAssertNoThrow(try extractHeaders(responses[0])) + XCTAssertNoThrow(try extractStatus(responses[1])) { status in + XCTAssertEqual(status, expectedError.asGRPCStatus()) + } + } + + func testInternalErrorStatusIsReturnedWhenSendingTrailersInRequest() throws { + let responses = try waitForGRPCChannelHandlerResponses(count: 2) { channel in + // We have to use "Collect" (client streaming) as the tests rely on `EmbeddedChannel` which runs in this thread. + // In the current server implementation, responses from unary calls send a status immediately after sending the response. + // As such, a unary "Get" would return an "ok" status before the trailers would be sent. + let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Collect") + try channel.writeInbound(HTTPServerRequestPart.head(requestHead)) + try channel.writeInbound(HTTPServerRequestPart.body(gRPCMessage(channel: channel))) + + var trailers = HTTPHeaders() + trailers.add(name: "foo", value: "bar") + try channel.writeInbound(HTTPServerRequestPart.end(trailers)) + } + + XCTAssertEqual(errorCollector.errors.count, 1) + + if case .invalidState(let message)? = errorCollector.errors.first as? GRPCServerError { + XCTAssert(message.contains("trailers")) + } else { + XCTFail("\(String(describing: errorCollector.errors.first)) was not GRPCError.invalidState") + } + + XCTAssertNoThrow(try extractHeaders(responses[0])) + XCTAssertNoThrow(try extractStatus(responses[1])) { status in + XCTAssertEqual(status, .processingError) + } + } + + func testOnlyOneStatusIsReturned() throws { + let responses = try waitForGRPCChannelHandlerResponses(count: 3) { channel in + let requestHead = HTTPRequestHead(version: .init(major: 2, minor: 0), method: .POST, uri: "/echo.Echo/Get") + try channel.writeInbound(HTTPServerRequestPart.head(requestHead)) + try channel.writeInbound(HTTPServerRequestPart.body(gRPCMessage(channel: channel))) + + // Sending trailers with `.end` should trigger an error. However, writing a message to a unary call + // will trigger a response and status to be sent back. Since we're using `EmbeddedChannel` this will + // be done before the trailers are sent. If a 4th resposne were to be sent (for the error status) then + // the test would fail. + + var trailers = HTTPHeaders() + trailers.add(name: "foo", value: "bar") + try channel.writeInbound(HTTPServerRequestPart.end(trailers)) + } + + XCTAssertNoThrow(try extractHeaders(responses[0])) + XCTAssertNoThrow(try extractMessage(responses[1])) + XCTAssertNoThrow(try extractStatus(responses[2])) { status in + XCTAssertEqual(status, .ok) + } + } + + override func waitForGRPCChannelHandlerResponses( + count: Int, + servicesByName: [String: CallHandlerProvider] = GRPCChannelHandlerResponseCapturingTestCase.echoProvider, + callback: @escaping (EmbeddedChannel) throws -> Void + ) throws -> [RawGRPCServerResponsePart] { + return try super.waitForGRPCChannelHandlerResponses(count: count, servicesByName: servicesByName) { channel in + _ = channel.pipeline.addHandlers(HTTP1ToRawGRPCServerCodec(), first: true) + .thenThrowing { _ in try callback(channel) } + } + } +} From 344a6e3707f8c963c948235595f9e09a8aab0cab Mon Sep 17 00:00:00 2001 From: George Barnett Date: Tue, 26 Feb 2019 10:54:00 +0000 Subject: [PATCH 07/10] Update LinuxMain --- Tests/LinuxMain.swift | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Tests/LinuxMain.swift b/Tests/LinuxMain.swift index b0abb3e49..5c13ee0aa 100644 --- a/Tests/LinuxMain.swift +++ b/Tests/LinuxMain.swift @@ -38,4 +38,7 @@ XCTMain([ // SwiftGRPCNIO testCase(NIOServerTests.allTests) + testCase(NIOServerWebTests.allTests) + testCase(GRPCChannelHandlerTests.allTests) + testCase(HTTP1ToRawGRPCServerCodecTests.allTests) ]) From 0fe5dd247c4dd56972a5cad2668c252fcfa497e8 Mon Sep 17 00:00:00 2001 From: George Barnett Date: Tue, 26 Feb 2019 12:51:16 +0000 Subject: [PATCH 08/10] Add missing commas to LinuxMain --- Tests/LinuxMain.swift | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Tests/LinuxMain.swift b/Tests/LinuxMain.swift index 5c13ee0aa..91d0a9807 100644 --- a/Tests/LinuxMain.swift +++ b/Tests/LinuxMain.swift @@ -37,8 +37,8 @@ XCTMain([ testCase(ServerTimeoutTests.allTests), // SwiftGRPCNIO - testCase(NIOServerTests.allTests) - testCase(NIOServerWebTests.allTests) - testCase(GRPCChannelHandlerTests.allTests) + testCase(NIOServerTests.allTests), + testCase(NIOServerWebTests.allTests), + testCase(GRPCChannelHandlerTests.allTests), testCase(HTTP1ToRawGRPCServerCodecTests.allTests) ]) From e39d0a20a5d1f36d487341b704c2f6d266901822 Mon Sep 17 00:00:00 2001 From: George Barnett Date: Tue, 26 Feb 2019 13:35:16 +0000 Subject: [PATCH 09/10] Fix grpc-web testUnaryLotsOfRequests on Linux --- Tests/SwiftGRPCNIOTests/NIOServerWebTests.swift | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/Tests/SwiftGRPCNIOTests/NIOServerWebTests.swift b/Tests/SwiftGRPCNIOTests/NIOServerWebTests.swift index 342975796..44f094467 100644 --- a/Tests/SwiftGRPCNIOTests/NIOServerWebTests.swift +++ b/Tests/SwiftGRPCNIOTests/NIOServerWebTests.swift @@ -114,13 +114,15 @@ extension NIOServerWebTests { // Sending that many requests at once can sometimes trip things up, it seems. let clockStart = clock() let numberOfRequests = 2_000 + let completionHandlerExpectation = expectation(description: "completion handler called") -#if os(macOS) - // Linux version of Swift doesn't have this API yet. + // Linux version of Swift doesn't have the `expectedFulfillmentCount` API yet. // Implemented in https://github.com/apple/swift-corelibs-xctest/pull/228 but not yet // released. - completionHandlerExpectation.expectedFulfillmentCount = numberOfRequests -#endif + // + // Wait for the expected number of responses (i.e. `numberOfRequests`) instead. + var responses = 0 + for i in 0.. Date: Tue, 26 Feb 2019 14:38:12 +0000 Subject: [PATCH 10/10] Disable broken Linux test --- Tests/SwiftGRPCNIOTests/NIOServerWebTests.swift | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Tests/SwiftGRPCNIOTests/NIOServerWebTests.swift b/Tests/SwiftGRPCNIOTests/NIOServerWebTests.swift index 44f094467..ce87064b5 100644 --- a/Tests/SwiftGRPCNIOTests/NIOServerWebTests.swift +++ b/Tests/SwiftGRPCNIOTests/NIOServerWebTests.swift @@ -25,7 +25,8 @@ class NIOServerWebTests: NIOServerTestCase { static var allTests: [(String, (NIOServerWebTests) -> () throws -> Void)] { return [ ("testUnary", testUnary), - ("testUnaryLotsOfRequests", testUnaryLotsOfRequests), + //! FIXME: Broken on Linux: https://github.com/grpc/grpc-swift/issues/382 + // ("testUnaryLotsOfRequests", testUnaryLotsOfRequests), ("testServerStreaming", testServerStreaming), ] }