diff --git a/Sources/GRPCHTTP2Core/GRPCStreamStateMachine.swift b/Sources/GRPCHTTP2Core/GRPCStreamStateMachine.swift index 753409587..c3eda1363 100644 --- a/Sources/GRPCHTTP2Core/GRPCStreamStateMachine.swift +++ b/Sources/GRPCHTTP2Core/GRPCStreamStateMachine.swift @@ -332,7 +332,7 @@ struct GRPCStreamStateMachine { case .server: if endStream { try self.invalidState( - "Can't end response stream by sending a message - send(status:metadata:trailersOnly:) must be called" + "Can't end response stream by sending a message - send(status:metadata:) must be called" ) } try self.serverSend(message: message) diff --git a/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift b/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift new file mode 100644 index 000000000..c4da0d4ed --- /dev/null +++ b/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift @@ -0,0 +1,217 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import GRPCCore +import NIOCore +import NIOHTTP2 + +@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *) +final class GRPCServerStreamHandler: ChannelDuplexHandler { + typealias InboundIn = HTTP2Frame.FramePayload + typealias InboundOut = RPCRequestPart + + typealias OutboundIn = RPCResponsePart + typealias OutboundOut = HTTP2Frame.FramePayload + + private var stateMachine: GRPCStreamStateMachine + + private var isReading = false + private var flushPending = false + + // We buffer the final status + trailers to avoid reordering issues (i.e., + // if there are messages still not written into the channel because flush has + // not been called, but the server sends back trailers). + private var pendingTrailers: HTTP2Frame.FramePayload? + + init( + scheme: Scheme, + acceptedEncodings: [CompressionAlgorithm], + maximumPayloadSize: Int, + skipStateMachineAssertions: Bool = false + ) { + self.stateMachine = .init( + configuration: .server(.init(scheme: scheme, acceptedEncodings: acceptedEncodings)), + maximumPayloadSize: maximumPayloadSize, + skipAssertions: skipStateMachineAssertions + ) + } +} + +// - MARK: ChannelInboundHandler + +@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *) +extension GRPCServerStreamHandler { + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + self.isReading = true + let frame = self.unwrapInboundIn(data) + switch frame { + case .data(let frameData): + let endStream = frameData.endStream + switch frameData.data { + case .byteBuffer(let buffer): + do { + try self.stateMachine.receive(message: buffer, endStream: endStream) + loop: while true { + switch self.stateMachine.nextInboundMessage() { + case .receiveMessage(let message): + context.fireChannelRead(self.wrapInboundOut(.message(message))) + case .awaitMoreMessages: + break loop + case .noMoreMessages: + context.fireUserInboundEventTriggered(ChannelEvent.inputClosed) + break loop + } + } + } catch { + context.fireErrorCaught(error) + } + + case .fileRegion: + preconditionFailure("Unexpected IOData.fileRegion") + } + + case .headers(let headers): + do { + let action = try self.stateMachine.receive( + metadata: headers.headers, + endStream: headers.endStream + ) + switch action { + case .receivedMetadata(let metadata): + context.fireChannelRead(self.wrapInboundOut(.metadata(metadata))) + + case .rejectRPC(let trailers): + self.flushPending = true + let response = HTTP2Frame.FramePayload.headers(.init(headers: trailers, endStream: true)) + context.write(self.wrapOutboundOut(response), promise: nil) + + case .receivedStatusAndMetadata: + throw RPCError( + code: .internalError, + message: "Server cannot get receivedStatusAndMetadata." + ) + + case .doNothing: + throw RPCError(code: .internalError, message: "Server cannot receive doNothing.") + } + } catch { + context.fireErrorCaught(error) + } + + case .ping, .goAway, .priority, .rstStream, .settings, .pushPromise, .windowUpdate, + .alternativeService, .origin: + () + } + } + + func channelReadComplete(context: ChannelHandlerContext) { + self.isReading = false + if self.flushPending { + self.flushPending = false + context.flush() + } + context.fireChannelReadComplete() + } + + func handlerRemoved(context: ChannelHandlerContext) { + self.stateMachine.tearDown() + } +} + +// - MARK: ChannelOutboundHandler + +@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *) +extension GRPCServerStreamHandler { + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + let frame = self.unwrapOutboundIn(data) + switch frame { + case .metadata(let metadata): + do { + self.flushPending = true + let headers = try self.stateMachine.send(metadata: metadata) + context.write(self.wrapOutboundOut(.headers(.init(headers: headers))), promise: nil) + // TODO: move the promise handling into the state machine + promise?.succeed() + } catch { + context.fireErrorCaught(error) + // TODO: move the promise handling into the state machine + promise?.fail(error) + } + + case .message(let message): + do { + try self.stateMachine.send(message: message, endStream: false) + // TODO: move the promise handling into the state machine + promise?.succeed() + } catch { + context.fireErrorCaught(error) + // TODO: move the promise handling into the state machine + promise?.fail(error) + } + + case .status(let status, let metadata): + do { + let headers = try self.stateMachine.send(status: status, metadata: metadata) + let response = HTTP2Frame.FramePayload.headers(.init(headers: headers, endStream: true)) + self.pendingTrailers = response + // TODO: move the promise handling into the state machine + promise?.succeed() + } catch { + context.fireErrorCaught(error) + // TODO: move the promise handling into the state machine + promise?.fail(error) + } + } + } + + func flush(context: ChannelHandlerContext) { + if self.isReading { + // We don't want to flush yet if we're still in a read loop. + return + } + + do { + loop: while true { + switch try self.stateMachine.nextOutboundMessage() { + case .sendMessage(let byteBuffer): + self.flushPending = true + context.write( + self.wrapOutboundOut(.data(.init(data: .byteBuffer(byteBuffer)))), + promise: nil + ) + + case .noMoreMessages: + if let pendingTrailers = self.pendingTrailers { + self.flushPending = true + self.pendingTrailers = nil + context.write(self.wrapOutboundOut(pendingTrailers), promise: nil) + } + break loop + + case .awaitMoreMessages: + break loop + } + } + + if self.flushPending { + self.flushPending = false + context.flush() + } + } catch { + context.fireErrorCaught(error) + } + } +} diff --git a/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift b/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift new file mode 100644 index 000000000..5839a7aa3 --- /dev/null +++ b/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift @@ -0,0 +1,789 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import GRPCCore +import NIOCore +import NIOEmbedded +import NIOHPACK +import NIOHTTP2 +import XCTest + +@testable import GRPCHTTP2Core + +@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *) +final class GRPCServerStreamHandlerTests: XCTestCase { + func testH2FramesAreIgnored() throws { + let handler = GRPCServerStreamHandler( + scheme: .http, + acceptedEncodings: [], + maximumPayloadSize: 1 + ) + + let channel = EmbeddedChannel(handler: handler) + + let framesToBeIgnored: [HTTP2Frame.FramePayload] = [ + .ping(.init(), ack: false), + .goAway(lastStreamID: .rootStream, errorCode: .cancel, opaqueData: nil), + // TODO: add .priority(StreamPriorityData) - right now, StreamPriorityData's + // initialiser is internal, so I can't create one of these frames. + .rstStream(.cancel), + .settings(.ack), + .pushPromise(.init(pushedStreamID: .maxID, headers: [:])), + .windowUpdate(windowSizeIncrement: 4), + .alternativeService(origin: nil, field: nil), + .origin([]), + ] + + for toBeIgnored in framesToBeIgnored { + XCTAssertNoThrow(try channel.writeInbound(toBeIgnored)) + XCTAssertNil(try channel.readInbound(as: HTTP2Frame.FramePayload.self)) + } + } + + func testClientInitialMetadataWithoutContentTypeResultsInRejectedRPC() throws { + let handler = GRPCServerStreamHandler( + scheme: .http, + acceptedEncodings: [], + maximumPayloadSize: 1 + ) + + let channel = EmbeddedChannel(handler: handler) + + // Receive client's initial metadata without content-type + let clientInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.path.rawValue: "test/test", + GRPCHTTP2Keys.scheme.rawValue: "http", + GRPCHTTP2Keys.method.rawValue: "POST", + GRPCHTTP2Keys.te.rawValue: "trailers", + ] + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata)) + ) + ) + + // Make sure we have sent a trailers-only response + let writtenTrailersOnlyResponse = try channel.assertReadHeadersOutbound() + + XCTAssertEqual(writtenTrailersOnlyResponse.headers, [":status": "415"]) + XCTAssertTrue(writtenTrailersOnlyResponse.endStream) + } + + func testClientInitialMetadataWithoutMethodResultsInRejectedRPC() throws { + let handler = GRPCServerStreamHandler( + scheme: .http, + acceptedEncodings: [], + maximumPayloadSize: 1 + ) + + let channel = EmbeddedChannel(handler: handler) + + // Receive client's initial metadata without :method + let clientInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.path.rawValue: "test/test", + GRPCHTTP2Keys.scheme.rawValue: "http", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.te.rawValue: "trailers", + ] + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata)) + ) + ) + + // Make sure we have sent a trailers-only response + let writtenTrailersOnlyResponse = try channel.assertReadHeadersOutbound() + + XCTAssertEqual( + writtenTrailersOnlyResponse.headers, + [ + GRPCHTTP2Keys.status.rawValue: "200", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.grpcStatus.rawValue: String(Status.Code.invalidArgument.rawValue), + GRPCHTTP2Keys.grpcStatusMessage.rawValue: + ":method header is expected to be present and have a value of \"POST\".", + ] + ) + XCTAssertTrue(writtenTrailersOnlyResponse.endStream) + } + + func testClientInitialMetadataWithoutSchemeResultsInRejectedRPC() throws { + let handler = GRPCServerStreamHandler( + scheme: .http, + acceptedEncodings: [], + maximumPayloadSize: 1 + ) + + let channel = EmbeddedChannel(handler: handler) + + // Receive client's initial metadata without :scheme + let clientInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.path.rawValue: "test/test", + GRPCHTTP2Keys.method.rawValue: "POST", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.te.rawValue: "trailers", + ] + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata)) + ) + ) + + // Make sure we have sent a trailers-only response + let writtenTrailersOnlyResponse = try channel.assertReadHeadersOutbound() + + XCTAssertEqual( + writtenTrailersOnlyResponse.headers, + [ + GRPCHTTP2Keys.status.rawValue: "200", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.grpcStatus.rawValue: String(Status.Code.invalidArgument.rawValue), + GRPCHTTP2Keys.grpcStatusMessage.rawValue: + ":scheme header must be present and one of \"http\" or \"https\".", + ] + ) + XCTAssertTrue(writtenTrailersOnlyResponse.endStream) + } + + func testClientInitialMetadataWithoutPathResultsInRejectedRPC() throws { + let handler = GRPCServerStreamHandler( + scheme: .http, + acceptedEncodings: [], + maximumPayloadSize: 1 + ) + + let channel = EmbeddedChannel(handler: handler) + + // Receive client's initial metadata without :path + let clientInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.scheme.rawValue: "http", + GRPCHTTP2Keys.method.rawValue: "POST", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.te.rawValue: "trailers", + ] + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata)) + ) + ) + + // Make sure we have sent a trailers-only response + let writtenTrailersOnlyResponse = try channel.assertReadHeadersOutbound() + + XCTAssertEqual( + writtenTrailersOnlyResponse.headers, + [ + GRPCHTTP2Keys.status.rawValue: "200", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.grpcStatus.rawValue: String(Status.Code.unimplemented.rawValue), + GRPCHTTP2Keys.grpcStatusMessage.rawValue: "No :path header has been set.", + ] + ) + XCTAssertTrue(writtenTrailersOnlyResponse.endStream) + } + + func testClientInitialMetadataWithoutTEResultsInRejectedRPC() throws { + let handler = GRPCServerStreamHandler( + scheme: .http, + acceptedEncodings: [], + maximumPayloadSize: 1 + ) + + let channel = EmbeddedChannel(handler: handler) + + // Receive client's initial metadata without TE + let clientInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.path.rawValue: "test/test", + GRPCHTTP2Keys.scheme.rawValue: "http", + GRPCHTTP2Keys.method.rawValue: "POST", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + ] + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata)) + ) + ) + + // Make sure we have sent a trailers-only response + let writtenTrailersOnlyResponse = try channel.assertReadHeadersOutbound() + + XCTAssertEqual( + writtenTrailersOnlyResponse.headers, + [ + GRPCHTTP2Keys.status.rawValue: "200", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.grpcStatus.rawValue: String(Status.Code.invalidArgument.rawValue), + GRPCHTTP2Keys.grpcStatusMessage.rawValue: + "\"te\" header is expected to be present and have a value of \"trailers\".", + ] + ) + XCTAssertTrue(writtenTrailersOnlyResponse.endStream) + } + + func testNotAcceptedEncodingResultsInRejectedRPC() throws { + let handler = GRPCServerStreamHandler( + scheme: .http, + acceptedEncodings: [], + maximumPayloadSize: 100 + ) + + let channel = EmbeddedChannel(handler: handler) + + // Receive client's initial metadata + let clientInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.path.rawValue: "test/test", + GRPCHTTP2Keys.scheme.rawValue: "http", + GRPCHTTP2Keys.method.rawValue: "POST", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.te.rawValue: "trailers", + GRPCHTTP2Keys.encoding.rawValue: "deflate", + ] + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata)) + ) + ) + + // Make sure we haven't sent back an error response, and that we read the initial metadata + // Make sure we have sent a trailers-only response + let writtenTrailersOnlyResponse = try channel.assertReadHeadersOutbound() + + XCTAssertEqual( + writtenTrailersOnlyResponse.headers, + [ + GRPCHTTP2Keys.status.rawValue: "200", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.grpcStatus.rawValue: String(Status.Code.unimplemented.rawValue), + GRPCHTTP2Keys.grpcStatusMessage.rawValue: "Compression is not supported", + ] + ) + XCTAssertTrue(writtenTrailersOnlyResponse.endStream) + } + + func testOverMaximumPayloadSize() throws { + let handler = GRPCServerStreamHandler( + scheme: .http, + acceptedEncodings: [], + maximumPayloadSize: 1 + ) + + let channel = EmbeddedChannel(handler: handler) + + // Receive client's initial metadata + let clientInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.path.rawValue: "test/test", + GRPCHTTP2Keys.scheme.rawValue: "http", + GRPCHTTP2Keys.method.rawValue: "POST", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.te.rawValue: "trailers", + ] + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata)) + ) + ) + + // Make sure we haven't sent back an error response, and that we read the initial metadata + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + XCTAssertEqual( + try channel.readInbound(as: RPCRequestPart.self), + RPCRequestPart.metadata(Metadata(headers: clientInitialMetadata)) + ) + + // Write back server's initial metadata + let headers: HPACKHeaders = [ + "some-custom-header": "some-custom-value" + ] + let serverInitialMetadata = RPCResponsePart.metadata(Metadata(headers: headers)) + XCTAssertNoThrow(try channel.writeOutbound(serverInitialMetadata)) + + // Make sure we wrote back the initial metadata + let writtenHeaders = try channel.assertReadHeadersOutbound() + + XCTAssertEqual( + writtenHeaders.headers, + [ + GRPCHTTP2Keys.status.rawValue: "200", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + "some-custom-header": "some-custom-value", + ] + ) + + // Receive client's message + var buffer = ByteBuffer() + buffer.writeInteger(UInt8(0)) // not compressed + buffer.writeInteger(UInt32(42)) // message length + buffer.writeRepeatingByte(0, count: 42) // message + let clientDataPayload = HTTP2Frame.FramePayload.Data(data: .byteBuffer(buffer), endStream: true) + XCTAssertThrowsError( + ofType: RPCError.self, + try channel.writeInbound(HTTP2Frame.FramePayload.data(clientDataPayload)) + ) { error in + XCTAssertEqual(error.code, .resourceExhausted) + XCTAssertEqual( + error.message, + "Message has exceeded the configured maximum payload size (max: 1, actual: 42)" + ) + } + + // Make sure we haven't sent a response back and that we didn't read the received message + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + XCTAssertNil(try channel.readInbound(as: RPCRequestPart.self)) + } + + func testClientEndsStream() throws { + let handler = GRPCServerStreamHandler( + scheme: .http, + acceptedEncodings: [], + maximumPayloadSize: 100, + skipStateMachineAssertions: true + ) + + let channel = EmbeddedChannel(handler: handler) + + // Receive client's initial metadata with end stream set + let clientInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.path.rawValue: "test/test", + GRPCHTTP2Keys.scheme.rawValue: "http", + GRPCHTTP2Keys.method.rawValue: "POST", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.te.rawValue: "trailers", + ] + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata, endStream: true)) + ) + ) + + // Make sure we haven't sent back an error response, and that we read the initial metadata + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + XCTAssertEqual( + try channel.readInbound(as: RPCRequestPart.self), + RPCRequestPart.metadata(Metadata(headers: clientInitialMetadata)) + ) + + // Write back server's initial metadata + let headers: HPACKHeaders = [ + "some-custom-header": "some-custom-value" + ] + let serverInitialMetadata = RPCResponsePart.metadata(Metadata(headers: headers)) + XCTAssertNoThrow(try channel.writeOutbound(serverInitialMetadata)) + + // Make sure we wrote back the initial metadata + let writtenHeaders = try channel.assertReadHeadersOutbound() + + XCTAssertEqual( + writtenHeaders.headers, + [ + GRPCHTTP2Keys.status.rawValue: "200", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + "some-custom-header": "some-custom-value", + ] + ) + + // We should throw if the client sends another message, since it's closed the stream already. + var buffer = ByteBuffer() + buffer.writeInteger(UInt8(0)) // not compressed + buffer.writeInteger(UInt32(42)) // message length + buffer.writeRepeatingByte(0, count: 42) // message + let clientDataPayload = HTTP2Frame.FramePayload.Data(data: .byteBuffer(buffer), endStream: true) + XCTAssertThrowsError( + ofType: RPCError.self, + try channel.writeInbound(HTTP2Frame.FramePayload.data(clientDataPayload)) + ) { error in + XCTAssertEqual(error.code, .internalError) + XCTAssertEqual(error.message, "Client can't send a message if closed.") + } + } + + func testNormalFlow() throws { + let handler = GRPCServerStreamHandler( + scheme: .http, + acceptedEncodings: [], + maximumPayloadSize: 100 + ) + + let channel = EmbeddedChannel(handler: handler) + + // Receive client's initial metadata + let clientInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.path.rawValue: "test/test", + GRPCHTTP2Keys.scheme.rawValue: "http", + GRPCHTTP2Keys.method.rawValue: "POST", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.te.rawValue: "trailers", + ] + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata)) + ) + ) + + // Make sure we haven't sent back an error response, and that we read the initial metadata + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + XCTAssertEqual( + try channel.readInbound(as: RPCRequestPart.self), + RPCRequestPart.metadata(Metadata(headers: clientInitialMetadata)) + ) + + // Write back server's initial metadata + let headers: HPACKHeaders = [ + "some-custom-header": "some-custom-value" + ] + let serverInitialMetadata = RPCResponsePart.metadata(Metadata(headers: headers)) + XCTAssertNoThrow(try channel.writeOutbound(serverInitialMetadata)) + + // Make sure we wrote back the initial metadata + let writtenHeaders = try channel.assertReadHeadersOutbound() + + XCTAssertEqual( + writtenHeaders.headers, + [ + GRPCHTTP2Keys.status.rawValue: "200", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + "some-custom-header": "some-custom-value", + ] + ) + + // Receive client's message + var buffer = ByteBuffer() + buffer.writeInteger(UInt8(0)) // not compressed + buffer.writeInteger(UInt32(42)) // message length + buffer.writeRepeatingByte(0, count: 42) // message + let clientDataPayload = HTTP2Frame.FramePayload.Data(data: .byteBuffer(buffer), endStream: true) + XCTAssertNoThrow(try channel.writeInbound(HTTP2Frame.FramePayload.data(clientDataPayload))) + + // Make sure we haven't sent back an error response, and that we read the message properly + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + XCTAssertEqual( + try channel.readInbound(as: RPCRequestPart.self), + RPCRequestPart.message([UInt8](repeating: 0, count: 42)) + ) + + // Write back response + let serverDataPayload = RPCResponsePart.message([UInt8](repeating: 1, count: 42)) + XCTAssertNoThrow(try channel.writeOutbound(serverDataPayload)) + + // Make sure we wrote back the right message + let writtenMessage = try channel.assertReadDataOutbound() + + var expectedBuffer = ByteBuffer() + expectedBuffer.writeInteger(UInt8(0)) // not compressed + expectedBuffer.writeInteger(UInt32(42)) // message length + expectedBuffer.writeRepeatingByte(1, count: 42) // message + XCTAssertEqual(writtenMessage.data, .byteBuffer(expectedBuffer)) + + // Send back status to end RPC + let trailers = RPCResponsePart.status( + .init(code: .dataLoss, message: "Test data loss"), + ["custom-header": "custom-value"] + ) + XCTAssertNoThrow(try channel.writeOutbound(trailers)) + + // Make sure we wrote back the status and trailers + let writtenStatus = try channel.assertReadHeadersOutbound() + + XCTAssertTrue(writtenStatus.endStream) + XCTAssertEqual( + writtenStatus.headers, + [ + GRPCHTTP2Keys.grpcStatus.rawValue: String(Status.Code.dataLoss.rawValue), + GRPCHTTP2Keys.grpcStatusMessage.rawValue: "Test data loss", + "custom-header": "custom-value", + ] + ) + } + + func testReceiveMessageSplitAcrossMultipleBuffers() throws { + let handler = GRPCServerStreamHandler( + scheme: .http, + acceptedEncodings: [], + maximumPayloadSize: 100 + ) + + let channel = EmbeddedChannel(handler: handler) + + // Receive client's initial metadata + let clientInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.path.rawValue: "test/test", + GRPCHTTP2Keys.scheme.rawValue: "http", + GRPCHTTP2Keys.method.rawValue: "POST", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.te.rawValue: "trailers", + ] + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata)) + ) + ) + + // Make sure we haven't sent back an error response, and that we read the initial metadata + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + XCTAssertEqual( + try channel.readInbound(as: RPCRequestPart.self), + RPCRequestPart.metadata(Metadata(headers: clientInitialMetadata)) + ) + + // Write back server's initial metadata + let headers: HPACKHeaders = [ + "some-custom-header": "some-custom-value" + ] + let serverInitialMetadata = RPCResponsePart.metadata(Metadata(headers: headers)) + XCTAssertNoThrow(try channel.writeOutbound(serverInitialMetadata)) + + // Make sure we wrote back the initial metadata + let writtenHeaders = try channel.assertReadHeadersOutbound() + + XCTAssertEqual( + writtenHeaders.headers, + [ + GRPCHTTP2Keys.status.rawValue: "200", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + "some-custom-header": "some-custom-value", + ] + ) + + // Receive client's first message + var buffer = ByteBuffer() + buffer.writeInteger(UInt8(0)) // not compressed + XCTAssertNoThrow( + try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer)))) + ) + XCTAssertNil(try channel.readInbound(as: RPCRequestPart.self)) + + buffer.clear() + buffer.writeInteger(UInt32(30)) // message length + XCTAssertNoThrow( + try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer)))) + ) + XCTAssertNil(try channel.readInbound(as: RPCRequestPart.self)) + + buffer.clear() + buffer.writeRepeatingByte(0, count: 10) // first part of the message + XCTAssertNoThrow( + try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer)))) + ) + XCTAssertNil(try channel.readInbound(as: RPCRequestPart.self)) + + buffer.clear() + buffer.writeRepeatingByte(1, count: 10) // second part of the message + XCTAssertNoThrow( + try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer)))) + ) + XCTAssertNil(try channel.readInbound(as: RPCRequestPart.self)) + + buffer.clear() + buffer.writeRepeatingByte(2, count: 10) // third part of the message + XCTAssertNoThrow( + try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer)))) + ) + + // Make sure we haven't sent back an error response, and that we read the message properly + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + XCTAssertEqual( + try channel.readInbound(as: RPCRequestPart.self), + RPCRequestPart.message( + [UInt8](repeating: 0, count: 10) + [UInt8](repeating: 1, count: 10) + + [UInt8](repeating: 2, count: 10) + ) + ) + } + + func testSendMultipleMessagesInSingleBuffer() throws { + let handler = GRPCServerStreamHandler( + scheme: .http, + acceptedEncodings: [], + maximumPayloadSize: 100 + ) + + let channel = EmbeddedChannel(handler: handler) + + // Receive client's initial metadata + let clientInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.path.rawValue: "test/test", + GRPCHTTP2Keys.scheme.rawValue: "http", + GRPCHTTP2Keys.method.rawValue: "POST", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.te.rawValue: "trailers", + ] + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata)) + ) + ) + + // Make sure we haven't sent back an error response, and that we read the initial metadata + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + XCTAssertEqual( + try channel.readInbound(as: RPCRequestPart.self), + RPCRequestPart.metadata(Metadata(headers: clientInitialMetadata)) + ) + + // Write back server's initial metadata + let headers: HPACKHeaders = [ + "some-custom-header": "some-custom-value" + ] + let serverInitialMetadata = RPCResponsePart.metadata(Metadata(headers: headers)) + XCTAssertNoThrow(try channel.writeOutbound(serverInitialMetadata)) + + // Read out the metadata + _ = try channel.readOutbound(as: HTTP2Frame.FramePayload.self) + + // This is where this test actually begins. We want to write two messages + // without flushing, and make sure that no messages are sent down the pipeline + // until we flush. Once we flush, both messages should be sent in the same ByteBuffer. + + // Write back first message and make sure nothing's written in the channel. + XCTAssertNoThrow(channel.write(RPCResponsePart.message([UInt8](repeating: 1, count: 4)))) + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + + // Write back second message and make sure nothing's written in the channel. + XCTAssertNoThrow(channel.write(RPCResponsePart.message([UInt8](repeating: 2, count: 4)))) + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + + // Now flush and check we *do* write the data. + channel.flush() + + let writtenMessage = try channel.assertReadDataOutbound() + + // Make sure both messages have been framed together in the ByteBuffer. + XCTAssertEqual( + writtenMessage.data, + .byteBuffer( + .init(bytes: [ + // First message + 0, // Compression disabled + 0, 0, 0, 4, // Message length + 1, 1, 1, 1, // First message data + + // Second message + 0, // Compression disabled + 0, 0, 0, 4, // Message length + 2, 2, 2, 2, // Second message data + ]) + ) + ) + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + } + + func testMessageAndStatusAreNotReordered() throws { + let handler = GRPCServerStreamHandler( + scheme: .http, + acceptedEncodings: [], + maximumPayloadSize: 100 + ) + + let channel = EmbeddedChannel(handler: handler) + + // Receive client's initial metadata + let clientInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.path.rawValue: "test/test", + GRPCHTTP2Keys.scheme.rawValue: "http", + GRPCHTTP2Keys.method.rawValue: "POST", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.te.rawValue: "trailers", + ] + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata)) + ) + ) + + // Make sure we haven't sent back an error response, and that we read the initial metadata + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + XCTAssertEqual( + try channel.readInbound(as: RPCRequestPart.self), + RPCRequestPart.metadata(Metadata(headers: clientInitialMetadata)) + ) + + // Write back server's initial metadata + let serverInitialMetadata = RPCResponsePart.metadata(Metadata(headers: [:])) + XCTAssertNoThrow(try channel.writeOutbound(serverInitialMetadata)) + + // Read out the metadata + _ = try channel.readOutbound(as: HTTP2Frame.FramePayload.self) + + // This is where this test actually begins. We want to write a message followed + // by status and trailers, and only flush after both writes. + // Because messages are buffered and potentially bundled together in a single + // ByteBuffer by the GPRCMessageFramer, we want to make sure that the status + // and trailers won't be written before the messages. + + // Write back message and make sure nothing's written in the channel. + XCTAssertNoThrow(channel.write(RPCResponsePart.message([UInt8](repeating: 1, count: 4)))) + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + + // Write status + metadata and make sure nothing's written. + XCTAssertNoThrow(channel.write(RPCResponsePart.status(.init(code: .ok, message: ""), [:]))) + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + + // Now flush and check we *do* write the data in the right order: message first, + // trailers second. + channel.flush() + + let writtenMessage = try channel.assertReadDataOutbound() + + // Make sure we first get message. + XCTAssertEqual( + writtenMessage.data, + .byteBuffer( + .init(bytes: [ + // First message + 0, // Compression disabled + 0, 0, 0, 4, // Message length + 1, 1, 1, 1, // First message data + ]) + ) + ) + XCTAssertFalse(writtenMessage.endStream) + + // Make sure we get trailers. + let writtenTrailers = try channel.assertReadHeadersOutbound() + XCTAssertEqual(writtenTrailers.headers, ["grpc-status": "0"]) + XCTAssertTrue(writtenTrailers.endStream) + + // Make sure we get nothing else. + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + } +} + +extension EmbeddedChannel { + fileprivate func assertReadHeadersOutbound() throws -> HTTP2Frame.FramePayload.Headers { + guard + case .headers(let writtenHeaders) = try XCTUnwrap( + try self.readOutbound(as: HTTP2Frame.FramePayload.self) + ) + else { + throw TestError.assertionFailure("Expected to write headers") + } + return writtenHeaders + } + + fileprivate func assertReadDataOutbound() throws -> HTTP2Frame.FramePayload.Data { + guard + case .data(let writtenMessage) = try XCTUnwrap( + try self.readOutbound(as: HTTP2Frame.FramePayload.self) + ) + else { + throw TestError.assertionFailure("Expected to write data") + } + return writtenMessage + } +} + +private enum TestError: Error { + case assertionFailure(String) +}