diff --git a/Sources/GRPCHTTP2Core/Client/GRPCClientStreamHandler.swift b/Sources/GRPCHTTP2Core/Client/GRPCClientStreamHandler.swift new file mode 100644 index 000000000..6db75f843 --- /dev/null +++ b/Sources/GRPCHTTP2Core/Client/GRPCClientStreamHandler.swift @@ -0,0 +1,238 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import GRPCCore +import NIOCore +import NIOHTTP2 + +@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *) +final class GRPCClientStreamHandler: ChannelDuplexHandler { + typealias InboundIn = HTTP2Frame.FramePayload + typealias InboundOut = RPCResponsePart + + typealias OutboundIn = RPCRequestPart + typealias OutboundOut = HTTP2Frame.FramePayload + + private var stateMachine: GRPCStreamStateMachine + + private var isReading = false + private var flushPending = false + + init( + methodDescriptor: MethodDescriptor, + scheme: Scheme, + outboundEncoding: CompressionAlgorithm, + acceptedEncodings: [CompressionAlgorithm], + maximumPayloadSize: Int, + skipStateMachineAssertions: Bool = false + ) { + self.stateMachine = .init( + configuration: .client( + .init( + methodDescriptor: methodDescriptor, + scheme: scheme, + outboundEncoding: outboundEncoding, + acceptedEncodings: acceptedEncodings + ) + ), + maximumPayloadSize: maximumPayloadSize, + skipAssertions: skipStateMachineAssertions + ) + } +} + +// - MARK: ChannelInboundHandler + +@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *) +extension GRPCClientStreamHandler { + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + self.isReading = true + let frame = self.unwrapInboundIn(data) + switch frame { + case .data(let frameData): + let endStream = frameData.endStream + switch frameData.data { + case .byteBuffer(let buffer): + do { + try self.stateMachine.receive(buffer: buffer, endStream: endStream) + loop: while true { + switch self.stateMachine.nextInboundMessage() { + case .receiveMessage(let message): + context.fireChannelRead(self.wrapInboundOut(.message(message))) + case .awaitMoreMessages: + break loop + case .noMoreMessages: + context.fireUserInboundEventTriggered(ChannelEvent.inputClosed) + break loop + } + } + } catch { + context.fireErrorCaught(error) + } + + case .fileRegion: + preconditionFailure("Unexpected IOData.fileRegion") + } + + case .headers(let headers): + do { + let action = try self.stateMachine.receive( + headers: headers.headers, + endStream: headers.endStream + ) + switch action { + case .receivedMetadata(let metadata): + context.fireChannelRead(self.wrapInboundOut(.metadata(metadata))) + + case .rejectRPC: + throw RPCError( + code: .internalError, + message: "Client cannot get rejectRPC." + ) + + case .receivedStatusAndMetadata(let status, let metadata): + context.fireChannelRead(self.wrapInboundOut(.status(status, metadata))) + + case .doNothing: + () + } + } catch { + context.fireErrorCaught(error) + } + + case .ping, .goAway, .priority, .rstStream, .settings, .pushPromise, .windowUpdate, + .alternativeService, .origin: + () + } + } + + func channelReadComplete(context: ChannelHandlerContext) { + self.isReading = false + if self.flushPending { + self.flushPending = false + self.flush(context: context) + } + context.fireChannelReadComplete() + } + + func handlerRemoved(context: ChannelHandlerContext) { + self.stateMachine.tearDown() + } +} + +// - MARK: ChannelOutboundHandler + +@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *) +extension GRPCClientStreamHandler { + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + switch self.unwrapOutboundIn(data) { + case .metadata(let metadata): + do { + self.flushPending = true + let headers = try self.stateMachine.send(metadata: metadata) + context.write(self.wrapOutboundOut(.headers(.init(headers: headers))), promise: nil) + // TODO: move the promise handling into the state machine + promise?.succeed() + } catch { + context.fireErrorCaught(error) + // TODO: move the promise handling into the state machine + promise?.fail(error) + } + + case .message(let message): + do { + try self.stateMachine.send(message: message) + // TODO: move the promise handling into the state machine + promise?.succeed() + } catch { + context.fireErrorCaught(error) + // TODO: move the promise handling into the state machine + promise?.fail(error) + } + } + } + + func close(context: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise?) { + switch mode { + case .output, .all: + do { + try self.stateMachine.closeOutbound() + // Force a flush by calling _flush + // (otherwise, we'd skip flushing if we're in a read loop) + self._flush(context: context) + context.close(mode: mode, promise: promise) + } catch { + promise?.fail(error) + context.fireErrorCaught(error) + } + + case .input: + context.close(mode: .input, promise: promise) + } + } + + func flush(context: ChannelHandlerContext) { + if self.isReading { + // We don't want to flush yet if we're still in a read loop. + self.flushPending = true + return + } + + self._flush(context: context) + } + + private func _flush(context: ChannelHandlerContext) { + do { + loop: while true { + switch try self.stateMachine.nextOutboundMessage() { + case .sendMessage(let byteBuffer): + self.flushPending = true + context.write( + self.wrapOutboundOut(.data(.init(data: .byteBuffer(byteBuffer)))), + promise: nil + ) + + case .noMoreMessages: + // Write an empty data frame with the EOS flag set, to signal the RPC + // request is now finished. + context.write( + self.wrapOutboundOut( + HTTP2Frame.FramePayload.data( + .init( + data: .byteBuffer(.init()), + endStream: true + ) + ) + ), + promise: nil + ) + + context.flush() + break loop + + case .awaitMoreMessages: + if self.flushPending { + self.flushPending = false + context.flush() + } + break loop + } + } + } catch { + context.fireErrorCaught(error) + } + } +} diff --git a/Sources/GRPCHTTP2Core/GRPCStreamStateMachine.swift b/Sources/GRPCHTTP2Core/GRPCStreamStateMachine.swift index eb37b76ff..41961b568 100644 --- a/Sources/GRPCHTTP2Core/GRPCStreamStateMachine.swift +++ b/Sources/GRPCHTTP2Core/GRPCStreamStateMachine.swift @@ -373,8 +373,12 @@ struct GRPCStreamStateMachine { mutating func receive(headers: HPACKHeaders, endStream: Bool) throws -> OnMetadataReceived { switch self.configuration { - case .client: - return try self.clientReceive(headers: headers, endStream: endStream) + case .client(let clientConfiguration): + return try self.clientReceive( + headers: headers, + endStream: endStream, + configuration: clientConfiguration + ) case .server(let serverConfiguration): return try self.serverReceive( headers: headers, @@ -567,9 +571,7 @@ extension GRPCStreamStateMachine { case .clientOpenServerClosed(let state): self.state = .clientClosedServerClosed(.init(previousState: state)) case .clientClosedServerIdle, .clientClosedServerOpen, .clientClosedServerClosed: - try self.invalidState( - "Client is closed, cannot send a message." - ) + try self.invalidState("Client is already closed.") } } @@ -665,7 +667,7 @@ extension GRPCStreamStateMachine { .receivedStatusAndMetadata( status: .init( code: .internalError, - message: "Missing \(GRPCHTTP2Keys.contentType) header" + message: "Missing \(GRPCHTTP2Keys.contentType.rawValue) header" ), metadata: Metadata(headers: metadata) ) @@ -680,10 +682,15 @@ extension GRPCStreamStateMachine { case success(CompressionAlgorithm) } - private func processInboundEncoding(_ metadata: HPACKHeaders) -> ProcessInboundEncodingResult { + private func processInboundEncoding( + headers: HPACKHeaders, + configuration: GRPCStreamStateMachineConfiguration.ClientConfiguration + ) -> ProcessInboundEncodingResult { let inboundEncoding: CompressionAlgorithm - if let serverEncoding = metadata.first(name: GRPCHTTP2Keys.encoding.rawValue) { - guard let parsedEncoding = CompressionAlgorithm(rawValue: serverEncoding) else { + if let serverEncoding = headers.first(name: GRPCHTTP2Keys.encoding.rawValue) { + guard let parsedEncoding = CompressionAlgorithm(rawValue: serverEncoding), + configuration.acceptedEncodings.contains(parsedEncoding) + else { return .error( .receivedStatusAndMetadata( status: .init( @@ -691,7 +698,7 @@ extension GRPCStreamStateMachine { message: "The server picked a compression algorithm ('\(serverEncoding)') the client does not know about." ), - metadata: Metadata(headers: metadata) + metadata: Metadata(headers: headers) ) ) } @@ -732,7 +739,8 @@ extension GRPCStreamStateMachine { private mutating func clientReceive( headers: HPACKHeaders, - endStream: Bool + endStream: Bool, + configuration: GRPCStreamStateMachineConfiguration.ClientConfiguration ) throws -> OnMetadataReceived { switch self.state { case .clientOpenServerIdle(let state): @@ -750,7 +758,7 @@ extension GRPCStreamStateMachine { self.state = .clientOpenServerClosed(.init(previousState: state)) return try self.validateAndReturnStatusAndMetadata(headers) case (.valid, false): - switch self.processInboundEncoding(headers) { + switch self.processInboundEncoding(headers: headers, configuration: configuration) { case .error(let failure): return failure case .success(let inboundEncoding): @@ -798,7 +806,7 @@ extension GRPCStreamStateMachine { self.state = .clientClosedServerClosed(.init(previousState: state)) return try self.validateAndReturnStatusAndMetadata(headers) case (.valid, false): - switch self.processInboundEncoding(headers) { + switch self.processInboundEncoding(headers: headers, configuration: configuration) { case .error(let failure): return failure case .success(let inboundEncoding): diff --git a/Tests/GRPCHTTP2CoreTests/Client/GRPCClientStreamHandlerTests.swift b/Tests/GRPCHTTP2CoreTests/Client/GRPCClientStreamHandlerTests.swift new file mode 100644 index 000000000..3b0d461b1 --- /dev/null +++ b/Tests/GRPCHTTP2CoreTests/Client/GRPCClientStreamHandlerTests.swift @@ -0,0 +1,724 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import GRPCCore +import NIOCore +import NIOEmbedded +import NIOHPACK +import NIOHTTP1 +import NIOHTTP2 +import XCTest + +@testable import GRPCHTTP2Core + +@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *) +final class GRPCClientStreamHandlerTests: XCTestCase { + func testH2FramesAreIgnored() throws { + let handler = GRPCClientStreamHandler( + methodDescriptor: .init(service: "test", method: "test"), + scheme: .http, + outboundEncoding: .identity, + acceptedEncodings: [], + maximumPayloadSize: 1 + ) + + let channel = EmbeddedChannel(handler: handler) + + let framesToBeIgnored: [HTTP2Frame.FramePayload] = [ + .ping(.init(), ack: false), + .goAway(lastStreamID: .rootStream, errorCode: .cancel, opaqueData: nil), + // TODO: add .priority(StreamPriorityData) - right now, StreamPriorityData's + // initialiser is internal, so I can't create one of these frames. + .rstStream(.cancel), + .settings(.ack), + .pushPromise(.init(pushedStreamID: .maxID, headers: [:])), + .windowUpdate(windowSizeIncrement: 4), + .alternativeService(origin: nil, field: nil), + .origin([]), + ] + + for toBeIgnored in framesToBeIgnored { + XCTAssertNoThrow(try channel.writeInbound(toBeIgnored)) + XCTAssertNil(try channel.readInbound(as: HTTP2Frame.FramePayload.self)) + } + } + + func testServerInitialMetadataMissingHTTPStatusCodeResultsInFinishedRPC() throws { + let handler = GRPCClientStreamHandler( + methodDescriptor: .init(service: "test", method: "test"), + scheme: .http, + outboundEncoding: .identity, + acceptedEncodings: [], + maximumPayloadSize: 1, + skipStateMachineAssertions: true + ) + + let channel = EmbeddedChannel(handler: handler) + + // Send client's initial metadata + let request = RPCRequestPart.metadata([:]) + XCTAssertNoThrow(try channel.writeOutbound(request)) + + // Receive server's initial metadata without :status + let serverInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.contentType.rawValue: ContentType.grpc.canonicalValue + ] + + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: serverInitialMetadata)) + ) + ) + + XCTAssertEqual( + try channel.readInbound(as: RPCResponsePart.self), + .status( + .init(code: .unknown, message: "HTTP Status Code is missing."), + Metadata(headers: serverInitialMetadata) + ) + ) + } + + func testServerInitialMetadata1xxHTTPStatusCodeResultsInNothingRead() throws { + let handler = GRPCClientStreamHandler( + methodDescriptor: .init(service: "test", method: "test"), + scheme: .http, + outboundEncoding: .identity, + acceptedEncodings: [], + maximumPayloadSize: 1, + skipStateMachineAssertions: true + ) + + let channel = EmbeddedChannel(handler: handler) + + // Send client's initial metadata + let request = RPCRequestPart.metadata([:]) + XCTAssertNoThrow(try channel.writeOutbound(request)) + + // Receive server's initial metadata with 1xx status + let serverInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.status.rawValue: "104", + GRPCHTTP2Keys.contentType.rawValue: ContentType.grpc.canonicalValue, + ] + + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: serverInitialMetadata)) + ) + ) + + XCTAssertNil(try channel.readInbound(as: RPCResponsePart.self)) + } + + func testServerInitialMetadataOtherNon200HTTPStatusCodeResultsInFinishedRPC() throws { + let handler = GRPCClientStreamHandler( + methodDescriptor: .init(service: "test", method: "test"), + scheme: .http, + outboundEncoding: .identity, + acceptedEncodings: [], + maximumPayloadSize: 1, + skipStateMachineAssertions: true + ) + + let channel = EmbeddedChannel(handler: handler) + + // Send client's initial metadata + let request = RPCRequestPart.metadata([:]) + XCTAssertNoThrow(try channel.writeOutbound(request)) + + // Receive server's initial metadata with non-200 and non-1xx :status + let serverInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.status.rawValue: String(HTTPResponseStatus.tooManyRequests.code), + GRPCHTTP2Keys.contentType.rawValue: ContentType.grpc.canonicalValue, + ] + + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: serverInitialMetadata)) + ) + ) + + XCTAssertEqual( + try channel.readInbound(as: RPCResponsePart.self), + .status( + .init(code: .unavailable, message: "Unexpected non-200 HTTP Status Code."), + Metadata(headers: serverInitialMetadata) + ) + ) + } + + func testServerInitialMetadataMissingContentTypeResultsInFinishedRPC() throws { + let handler = GRPCClientStreamHandler( + methodDescriptor: .init(service: "test", method: "test"), + scheme: .http, + outboundEncoding: .identity, + acceptedEncodings: [], + maximumPayloadSize: 1, + skipStateMachineAssertions: true + ) + + let channel = EmbeddedChannel(handler: handler) + + // Send client's initial metadata + let request = RPCRequestPart.metadata([:]) + XCTAssertNoThrow(try channel.writeOutbound(request)) + + // Receive server's initial metadata without content-type + let serverInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.status.rawValue: "200" + ] + + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: serverInitialMetadata)) + ) + ) + + XCTAssertEqual( + try channel.readInbound(as: RPCResponsePart.self), + .status( + .init(code: .internalError, message: "Missing content-type header"), + Metadata(headers: serverInitialMetadata) + ) + ) + } + + func testNotAcceptedEncodingResultsInFinishedRPC() throws { + let handler = GRPCClientStreamHandler( + methodDescriptor: .init(service: "test", method: "test"), + scheme: .http, + outboundEncoding: .deflate, + acceptedEncodings: [.deflate], + maximumPayloadSize: 1 + ) + + let channel = EmbeddedChannel(handler: handler) + + // Send client's initial metadata + XCTAssertNoThrow( + try channel.writeOutbound(RPCRequestPart.metadata(Metadata())) + ) + + // Make sure we have sent right metadata. + let writtenMetadata = try channel.assertReadHeadersOutbound() + + XCTAssertEqual( + writtenMetadata.headers, + [ + GRPCHTTP2Keys.method.rawValue: "POST", + GRPCHTTP2Keys.scheme.rawValue: "http", + GRPCHTTP2Keys.path.rawValue: "test/test", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.te.rawValue: "trailers", + GRPCHTTP2Keys.encoding.rawValue: "deflate", + GRPCHTTP2Keys.acceptEncoding.rawValue: "deflate", + ] + ) + + // Server sends initial metadata with unsupported encoding + let serverInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.status.rawValue: "200", + GRPCHTTP2Keys.contentType.rawValue: ContentType.grpc.canonicalValue, + GRPCHTTP2Keys.encoding.rawValue: "gzip", + ] + + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: serverInitialMetadata)) + ) + ) + + XCTAssertEqual( + try channel.readInbound(as: RPCResponsePart.self), + .status( + .init( + code: .internalError, + message: + "The server picked a compression algorithm ('gzip') the client does not know about." + ), + Metadata(headers: serverInitialMetadata) + ) + ) + } + + func testOverMaximumPayloadSize() throws { + let handler = GRPCClientStreamHandler( + methodDescriptor: .init(service: "test", method: "test"), + scheme: .http, + outboundEncoding: .identity, + acceptedEncodings: [], + maximumPayloadSize: 1, + skipStateMachineAssertions: true + ) + + let channel = EmbeddedChannel(handler: handler) + + // Send client's initial metadata + XCTAssertNoThrow( + try channel.writeOutbound(RPCRequestPart.metadata(Metadata())) + ) + + // Make sure we have sent right metadata. + let writtenMetadata = try channel.assertReadHeadersOutbound() + + XCTAssertEqual( + writtenMetadata.headers, + [ + GRPCHTTP2Keys.method.rawValue: "POST", + GRPCHTTP2Keys.scheme.rawValue: "http", + GRPCHTTP2Keys.path.rawValue: "test/test", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.te.rawValue: "trailers", + ] + ) + + // Server sends initial metadata + let serverInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.status.rawValue: "200", + GRPCHTTP2Keys.contentType.rawValue: ContentType.grpc.canonicalValue, + ] + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: serverInitialMetadata)) + ) + ) + XCTAssertEqual( + try channel.readInbound(as: RPCResponsePart.self), + .metadata(Metadata(headers: serverInitialMetadata)) + ) + + // Server sends message over payload limit + var buffer = ByteBuffer() + buffer.writeInteger(UInt8(0)) // not compressed + buffer.writeInteger(UInt32(42)) // message length + buffer.writeRepeatingByte(0, count: 42) // message + let clientDataPayload = HTTP2Frame.FramePayload.Data(data: .byteBuffer(buffer), endStream: true) + XCTAssertThrowsError( + ofType: RPCError.self, + try channel.writeInbound(HTTP2Frame.FramePayload.data(clientDataPayload)) + ) { error in + XCTAssertEqual(error.code, .resourceExhausted) + XCTAssertEqual( + error.message, + "Message has exceeded the configured maximum payload size (max: 1, actual: 42)" + ) + } + + // Make sure we didn't read the received message + XCTAssertNil(try channel.readInbound(as: RPCRequestPart.self)) + } + + func testServerEndsStream() throws { + let handler = GRPCClientStreamHandler( + methodDescriptor: .init(service: "test", method: "test"), + scheme: .http, + outboundEncoding: .identity, + acceptedEncodings: [], + maximumPayloadSize: 1, + skipStateMachineAssertions: true + ) + + let channel = EmbeddedChannel(handler: handler) + + // Write client's initial metadata + XCTAssertNoThrow(try channel.writeOutbound(RPCRequestPart.metadata(Metadata()))) + let clientInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.path.rawValue: "test/test", + GRPCHTTP2Keys.scheme.rawValue: "http", + GRPCHTTP2Keys.method.rawValue: "POST", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.te.rawValue: "trailers", + ] + let writtenInitialMetadata = try channel.assertReadHeadersOutbound() + XCTAssertEqual(writtenInitialMetadata.headers, clientInitialMetadata) + + // Receive server's initial metadata with end stream set + let serverInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.status.rawValue: "200", + GRPCHTTP2Keys.grpcStatus.rawValue: "0", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + ] + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers( + .init( + headers: serverInitialMetadata, + endStream: true + ) + ) + ) + ) + XCTAssertEqual( + try channel.readInbound(as: RPCResponsePart.self), + .status( + .init(code: .ok, message: ""), + [ + GRPCHTTP2Keys.status.rawValue: "200", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + ] + ) + ) + + // We should throw if the server sends another message, since it's closed the stream already. + var buffer = ByteBuffer() + buffer.writeInteger(UInt8(0)) // not compressed + buffer.writeInteger(UInt32(42)) // message length + buffer.writeRepeatingByte(0, count: 42) // message + let serverDataPayload = HTTP2Frame.FramePayload.Data(data: .byteBuffer(buffer), endStream: true) + XCTAssertThrowsError( + ofType: RPCError.self, + try channel.writeInbound(HTTP2Frame.FramePayload.data(serverDataPayload)) + ) { error in + XCTAssertEqual(error.code, .internalError) + XCTAssertEqual(error.message, "Cannot have received anything from a closed server.") + } + } + + func testNormalFlow() throws { + let handler = GRPCClientStreamHandler( + methodDescriptor: .init(service: "test", method: "test"), + scheme: .http, + outboundEncoding: .identity, + acceptedEncodings: [], + maximumPayloadSize: 100, + skipStateMachineAssertions: true + ) + + let channel = EmbeddedChannel(handler: handler) + + // Send client's initial metadata + let request = RPCRequestPart.metadata([:]) + XCTAssertNoThrow(try channel.writeOutbound(request)) + + // Make sure we have sent the corresponding frame, and that nothing has been written back. + let writtenHeaders = try channel.assertReadHeadersOutbound() + XCTAssertEqual( + writtenHeaders.headers, + [ + GRPCHTTP2Keys.method.rawValue: "POST", + GRPCHTTP2Keys.scheme.rawValue: "http", + GRPCHTTP2Keys.path.rawValue: "test/test", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.te.rawValue: "trailers", + + ] + ) + XCTAssertNil(try channel.readInbound(as: RPCResponsePart.self)) + + // Receive server's initial metadata + let serverInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.status.rawValue: "200", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + "some-custom-header": "some-custom-value", + ] + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: serverInitialMetadata)) + ) + ) + + XCTAssertEqual( + try channel.readInbound(as: RPCResponsePart.self), + RPCResponsePart.metadata(Metadata(headers: serverInitialMetadata)) + ) + + // Send a message + XCTAssertNoThrow( + try channel.writeOutbound(RPCRequestPart.message(.init(repeating: 1, count: 42))) + ) + + // Assert we wrote it successfully into the channel + let writtenMessage = try channel.assertReadDataOutbound() + var expectedBuffer = ByteBuffer() + expectedBuffer.writeInteger(UInt8(0)) // not compressed + expectedBuffer.writeInteger(UInt32(42)) // message length + expectedBuffer.writeRepeatingByte(1, count: 42) // message + XCTAssertEqual(writtenMessage.data, .byteBuffer(expectedBuffer)) + + // Half-close the outbound end: this would be triggered by finishing the client's writer. + XCTAssertNoThrow(channel.close(mode: .output, promise: nil)) + + // Flush to make sure the EOS is written. + channel.flush() + + // Make sure the EOS frame was sent + let emptyEOSFrame = try channel.assertReadDataOutbound() + XCTAssertEqual(emptyEOSFrame.data, .byteBuffer(.init())) + XCTAssertTrue(emptyEOSFrame.endStream) + + // Make sure we cannot write anymore because client's closed. + XCTAssertThrowsError( + ofType: RPCError.self, + try channel.writeOutbound(RPCRequestPart.message(.init(repeating: 1, count: 42))) + ) { error in + XCTAssertEqual(error.code, .internalError) + XCTAssertEqual(error.message, "Client is closed, cannot send a message.") + } + + // This is needed to clear the EmbeddedChannel's stored error, otherwise + // it will be thrown when writing inbound. + try? channel.throwIfErrorCaught() + + // Server sends back response message + var buffer = ByteBuffer() + buffer.writeInteger(UInt8(0)) // not compressed + buffer.writeInteger(UInt32(42)) // message length + buffer.writeRepeatingByte(0, count: 42) // message + let serverDataPayload = HTTP2Frame.FramePayload.Data(data: .byteBuffer(buffer)) + XCTAssertNoThrow(try channel.writeInbound(HTTP2Frame.FramePayload.data(serverDataPayload))) + + // Make sure we read the message properly + XCTAssertEqual( + try channel.readInbound(as: RPCResponsePart.self), + RPCResponsePart.message([UInt8](repeating: 0, count: 42)) + ) + + // Server sends status to end RPC + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers( + .init(headers: [ + GRPCHTTP2Keys.grpcStatus.rawValue: String(Status.Code.dataLoss.rawValue), + GRPCHTTP2Keys.grpcStatusMessage.rawValue: "Test data loss", + "custom-header": "custom-value", + ]) + ) + ) + ) + + XCTAssertEqual( + try channel.readInbound(as: RPCResponsePart.self), + .status(.init(code: .dataLoss, message: "Test data loss"), ["custom-header": "custom-value"]) + ) + } + + func testReceiveMessageSplitAcrossMultipleBuffers() throws { + let handler = GRPCClientStreamHandler( + methodDescriptor: .init(service: "test", method: "test"), + scheme: .http, + outboundEncoding: .identity, + acceptedEncodings: [], + maximumPayloadSize: 100, + skipStateMachineAssertions: true + ) + + let channel = EmbeddedChannel(handler: handler) + + // Send client's initial metadata + let request = RPCRequestPart.metadata([:]) + XCTAssertNoThrow(try channel.writeOutbound(request)) + + // Make sure we have sent the corresponding frame, and that nothing has been written back. + let writtenHeaders = try channel.assertReadHeadersOutbound() + XCTAssertEqual( + writtenHeaders.headers, + [ + GRPCHTTP2Keys.method.rawValue: "POST", + GRPCHTTP2Keys.scheme.rawValue: "http", + GRPCHTTP2Keys.path.rawValue: "test/test", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.te.rawValue: "trailers", + + ] + ) + XCTAssertNil(try channel.readInbound(as: RPCResponsePart.self)) + + // Receive server's initial metadata + let serverInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.status.rawValue: "200", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + "some-custom-header": "some-custom-value", + ] + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: serverInitialMetadata)) + ) + ) + XCTAssertEqual( + try channel.readInbound(as: RPCResponsePart.self), + RPCResponsePart.metadata(Metadata(headers: serverInitialMetadata)) + ) + + // Send a message + XCTAssertNoThrow( + try channel.writeOutbound(RPCRequestPart.message(.init(repeating: 1, count: 42))) + ) + + // Assert we wrote it successfully into the channel + let writtenMessage = try channel.assertReadDataOutbound() + var expectedBuffer = ByteBuffer() + expectedBuffer.writeInteger(UInt8(0)) // not compressed + expectedBuffer.writeInteger(UInt32(42)) // message length + expectedBuffer.writeRepeatingByte(1, count: 42) // message + XCTAssertEqual(writtenMessage.data, .byteBuffer(expectedBuffer)) + + // Receive server's first message + var buffer = ByteBuffer() + buffer.writeInteger(UInt8(0)) // not compressed + XCTAssertNoThrow( + try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer)))) + ) + XCTAssertNil(try channel.readInbound(as: RPCResponsePart.self)) + + buffer.clear() + buffer.writeInteger(UInt32(30)) // message length + XCTAssertNoThrow( + try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer)))) + ) + XCTAssertNil(try channel.readInbound(as: RPCResponsePart.self)) + + buffer.clear() + buffer.writeRepeatingByte(0, count: 10) // first part of the message + XCTAssertNoThrow( + try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer)))) + ) + XCTAssertNil(try channel.readInbound(as: RPCResponsePart.self)) + + buffer.clear() + buffer.writeRepeatingByte(1, count: 10) // second part of the message + XCTAssertNoThrow( + try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer)))) + ) + XCTAssertNil(try channel.readInbound(as: RPCResponsePart.self)) + + buffer.clear() + buffer.writeRepeatingByte(2, count: 10) // third part of the message + XCTAssertNoThrow( + try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer)))) + ) + + // Make sure we read the message properly + XCTAssertEqual( + try channel.readInbound(as: RPCResponsePart.self), + RPCResponsePart.message( + [UInt8](repeating: 0, count: 10) + [UInt8](repeating: 1, count: 10) + + [UInt8](repeating: 2, count: 10) + ) + ) + } + + func testSendMultipleMessagesInSingleBuffer() throws { + let handler = GRPCClientStreamHandler( + methodDescriptor: .init(service: "test", method: "test"), + scheme: .http, + outboundEncoding: .identity, + acceptedEncodings: [], + maximumPayloadSize: 100, + skipStateMachineAssertions: true + ) + + let channel = EmbeddedChannel(handler: handler) + + // Send client's initial metadata + let request = RPCRequestPart.metadata([:]) + XCTAssertNoThrow(try channel.writeOutbound(request)) + + // Make sure we have sent the corresponding frame, and that nothing has been written back. + let writtenHeaders = try channel.assertReadHeadersOutbound() + XCTAssertEqual( + writtenHeaders.headers, + [ + GRPCHTTP2Keys.method.rawValue: "POST", + GRPCHTTP2Keys.scheme.rawValue: "http", + GRPCHTTP2Keys.path.rawValue: "test/test", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.te.rawValue: "trailers", + + ] + ) + XCTAssertNil(try channel.readInbound(as: RPCResponsePart.self)) + + // Receive server's initial metadata + let serverInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.status.rawValue: "200", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + "some-custom-header": "some-custom-value", + ] + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: serverInitialMetadata)) + ) + ) + XCTAssertEqual( + try channel.readInbound(as: RPCResponsePart.self), + RPCResponsePart.metadata(Metadata(headers: serverInitialMetadata)) + ) + + // This is where this test actually begins. We want to write two messages + // without flushing, and make sure that no messages are sent down the pipeline + // until we flush. Once we flush, both messages should be sent in the same ByteBuffer. + + // Write back first message and make sure nothing's written in the channel. + XCTAssertNoThrow(channel.write(RPCRequestPart.message([UInt8](repeating: 1, count: 4)))) + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + + // Write back second message and make sure nothing's written in the channel. + XCTAssertNoThrow(channel.write(RPCRequestPart.message([UInt8](repeating: 2, count: 4)))) + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + + // Now flush and check we *do* write the data. + channel.flush() + + let writtenMessage = try channel.assertReadDataOutbound() + + // Make sure both messages have been framed together in the ByteBuffer. + XCTAssertEqual( + writtenMessage.data, + .byteBuffer( + .init(bytes: [ + // First message + 0, // Compression disabled + 0, 0, 0, 4, // Message length + 1, 1, 1, 1, // First message data + + // Second message + 0, // Compression disabled + 0, 0, 0, 4, // Message length + 2, 2, 2, 2, // Second message data + ]) + ) + ) + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + } +} + +extension EmbeddedChannel { + fileprivate func assertReadHeadersOutbound() throws -> HTTP2Frame.FramePayload.Headers { + guard + case .headers(let writtenHeaders) = try XCTUnwrap( + try self.readOutbound(as: HTTP2Frame.FramePayload.self) + ) + else { + throw TestError.assertionFailure("Expected to write headers") + } + return writtenHeaders + } + + fileprivate func assertReadDataOutbound() throws -> HTTP2Frame.FramePayload.Data { + guard + case .data(let writtenMessage) = try XCTUnwrap( + try self.readOutbound(as: HTTP2Frame.FramePayload.self) + ) + else { + throw TestError.assertionFailure("Expected to write data") + } + return writtenMessage + } +} + +private enum TestError: Error { + case assertionFailure(String) +} diff --git a/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift b/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift index 5839a7aa3..9d56824be 100644 --- a/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift +++ b/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift @@ -257,7 +257,6 @@ final class GRPCServerStreamHandlerTests: XCTestCase { ) ) - // Make sure we haven't sent back an error response, and that we read the initial metadata // Make sure we have sent a trailers-only response let writtenTrailersOnlyResponse = try channel.assertReadHeadersOutbound() @@ -413,7 +412,8 @@ final class GRPCServerStreamHandlerTests: XCTestCase { let handler = GRPCServerStreamHandler( scheme: .http, acceptedEncodings: [], - maximumPayloadSize: 100 + maximumPayloadSize: 100, + skipStateMachineAssertions: true ) let channel = EmbeddedChannel(handler: handler) @@ -505,6 +505,16 @@ final class GRPCServerStreamHandlerTests: XCTestCase { "custom-header": "custom-value", ] ) + + // Try writing and assert it throws to make sure we don't allow writes + // after closing. + XCTAssertThrowsError( + ofType: RPCError.self, + try channel.writeOutbound(trailers) + ) { error in + XCTAssertEqual(error.code, .internalError) + XCTAssertEqual(error.message, "Server can't send anything if closed.") + } } func testReceiveMessageSplitAcrossMultipleBuffers() throws {