From 0abcab03c0a87f7f06e99f24dbef4500178478a2 Mon Sep 17 00:00:00 2001 From: Gus Cairo Date: Fri, 23 Feb 2024 09:43:26 +0000 Subject: [PATCH 1/8] Add GRPCServerStreamHandler --- .../GRPCStreamStateMachine.swift | 2 +- .../Server/GRPCServerStreamHandler.swift | 174 ++++++++++++++++++ 2 files changed, 175 insertions(+), 1 deletion(-) create mode 100644 Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift diff --git a/Sources/GRPCHTTP2Core/GRPCStreamStateMachine.swift b/Sources/GRPCHTTP2Core/GRPCStreamStateMachine.swift index 753409587..c3eda1363 100644 --- a/Sources/GRPCHTTP2Core/GRPCStreamStateMachine.swift +++ b/Sources/GRPCHTTP2Core/GRPCStreamStateMachine.swift @@ -332,7 +332,7 @@ struct GRPCStreamStateMachine { case .server: if endStream { try self.invalidState( - "Can't end response stream by sending a message - send(status:metadata:trailersOnly:) must be called" + "Can't end response stream by sending a message - send(status:metadata:) must be called" ) } try self.serverSend(message: message) diff --git a/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift b/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift new file mode 100644 index 000000000..7fb7f15aa --- /dev/null +++ b/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift @@ -0,0 +1,174 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import GRPCCore +import NIOCore +import NIOHTTP2 + +@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *) +final class GRPCServerStreamHandler: ChannelDuplexHandler { + typealias InboundIn = HTTP2Frame.FramePayload + typealias InboundOut = RPCRequestPart + + typealias OutboundIn = RPCResponsePart + typealias OutboundOut = HTTP2Frame.FramePayload + + private var stateMachine: GRPCStreamStateMachine + + private var isReading = false + private var flushPending = false + + init( + scheme: Scheme, + acceptedEncodings: [CompressionAlgorithm], + maximumPayloadSize: Int + ) { + self.stateMachine = .init( + configuration: .server(.init(scheme: scheme, acceptedEncodings: acceptedEncodings)), + maximumPayloadSize: maximumPayloadSize + ) + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + self.isReading = true + let frame = self.unwrapInboundIn(data) + switch frame { + case .data(let frameData): + let endStream = frameData.endStream + switch frameData.data { + case .byteBuffer(let buffer): + do { + try self.stateMachine.receive(message: buffer, endStream: endStream) + switch self.stateMachine.nextInboundMessage() { + case .awaitMoreMessages: + () + case .receiveMessage(let message): + context.fireChannelRead(self.wrapInboundOut(.message(message))) + case .noMoreMessages: + context.channel.close(mode: .input, promise: nil) + } + } catch { + context.fireErrorCaught(error) + } + case .fileRegion: + preconditionFailure("Unexpected IOData.fileRegion") + } + case .headers(let headers): + do { + let action = try self.stateMachine.receive( + metadata: headers.headers, + endStream: headers.endStream + ) + switch action { + case .receivedMetadata(let metadata): + context.fireChannelRead(self.wrapInboundOut(.metadata(metadata))) + case .rejectRPC(let trailers): + let response = HTTP2Frame.FramePayload.headers(.init(headers: trailers, endStream: true)) + context.write(self.wrapOutboundOut(response), promise: nil) + self.flushPending = true + case .receivedStatusAndMetadata: + throw RPCError( + code: .internalError, + message: "Server cannot get receivedStatusAndMetadata." + ) + case .doNothing: + throw RPCError(code: .internalError, message: "Server cannot receive doNothing.") + } + } catch { + context.fireErrorCaught(error) + } + case .ping, .goAway, .priority, .rstStream, .settings, .pushPromise, .windowUpdate, + .alternativeService, .origin: + () + } + } + + func channelReadComplete(context: ChannelHandlerContext) { + self.isReading = false + if self.flushPending { + self.flushPending = false + context.flush() + } + context.fireChannelReadComplete() + } +} + +// - MARK: ChannelOutboundHandler + +@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *) +extension GRPCServerStreamHandler { + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + let frame = self.unwrapOutboundIn(data) + switch frame { + case .metadata(let metadata): + do { + let headers = try self.stateMachine.send(metadata: metadata) + context.write(self.wrapOutboundOut(.headers(.init(headers: headers))), promise: nil) + // TODO: move the promise handling into the state machine + promise?.succeed() + } catch { + context.fireErrorCaught(error) + // TODO: move the promise handling into the state machine + promise?.fail(error) + } + case .message(let message): + do { + try self.stateMachine.send(message: message, endStream: false) + switch try self.stateMachine.nextOutboundMessage() { + case .noMoreMessages: + // We shouldn't close the channel in this case, because we still have + // to send back a status and trailers to properly end the RPC stream. + () + case .awaitMoreMessages: + () + case .sendMessage(let byteBuffer): + context.write( + self.wrapOutboundOut(.data(.init(data: .byteBuffer(byteBuffer)))), + promise: nil + ) + if self.isReading { + self.flushPending = true + } else { + context.flush() + } + } + // TODO: move the promise handling into the state machine + promise?.succeed() + } catch { + context.fireErrorCaught(error) + // TODO: move the promise handling into the state machine + promise?.fail(error) + } + case .status(let status, let metadata): + do { + let headers = try self.stateMachine.send(status: status, metadata: metadata) + let response = HTTP2Frame.FramePayload.headers(.init(headers: headers, endStream: true)) + context.write(self.wrapOutboundOut(response), promise: nil) + if self.isReading { + self.flushPending = true + } else { + context.flush() + } + // TODO: move the promise handling into the state machine + promise?.succeed() + } catch { + context.fireErrorCaught(error) + // TODO: move the promise handling into the state machine + promise?.fail(error) + } + } + } +} From 8fe0f3570a5bc4c1b38ad6b899bcb4b6767aa992 Mon Sep 17 00:00:00 2001 From: Gus Cairo Date: Wed, 13 Mar 2024 17:25:59 +0000 Subject: [PATCH 2/8] Add tests --- .../Server/GRPCServerStreamHandler.swift | 6 +- .../Server/GRPCServerStreamHandlerTests.swift | 644 ++++++++++++++++++ 2 files changed, 648 insertions(+), 2 deletions(-) create mode 100644 Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift diff --git a/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift b/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift index 7fb7f15aa..f2d281c6c 100644 --- a/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift +++ b/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift @@ -34,11 +34,13 @@ final class GRPCServerStreamHandler: ChannelDuplexHandler { init( scheme: Scheme, acceptedEncodings: [CompressionAlgorithm], - maximumPayloadSize: Int + maximumPayloadSize: Int, + skipStateMachineAssertions: Bool = false ) { self.stateMachine = .init( configuration: .server(.init(scheme: scheme, acceptedEncodings: acceptedEncodings)), - maximumPayloadSize: maximumPayloadSize + maximumPayloadSize: maximumPayloadSize, + skipAssertions: skipStateMachineAssertions ) } diff --git a/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift b/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift new file mode 100644 index 000000000..0cd97c0eb --- /dev/null +++ b/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift @@ -0,0 +1,644 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import GRPCCore +import NIOCore +import NIOEmbedded +import NIOHPACK +import NIOHTTP2 +import XCTest + +@testable import GRPCHTTP2Core + +@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *) +final class GRPCServerStreamHandlerTests: XCTestCase { + func testClientInitialMetadataWithoutContentTypeResultsInRejectedRPC() throws { + let handler = GRPCServerStreamHandler( + scheme: .http, + acceptedEncodings: [], + maximumPayloadSize: 1 + ) + + let channel = EmbeddedChannel(handler: handler) + + // Receive client's initial metadata without content-type + let clientInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.path.rawValue: "test/test", + GRPCHTTP2Keys.scheme.rawValue: "http", + GRPCHTTP2Keys.method.rawValue: "POST", + GRPCHTTP2Keys.te.rawValue: "trailers", + ] + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata)) + ) + ) + + // Make sure we have sent a trailers-only response + guard + case .headers(let writtenTrailersOnlyResponse) = try XCTUnwrap( + try channel.readOutbound(as: HTTP2Frame.FramePayload.self) + ) + else { + XCTFail("Expected to write headers") + return + } + + XCTAssertEqual(writtenTrailersOnlyResponse.headers, [":status": "415"]) + XCTAssertTrue(writtenTrailersOnlyResponse.endStream) + } + + func testClientInitialMetadataWithoutMethodResultsInRejectedRPC() throws { + let handler = GRPCServerStreamHandler( + scheme: .http, + acceptedEncodings: [], + maximumPayloadSize: 1 + ) + + let channel = EmbeddedChannel(handler: handler) + + // Receive client's initial metadata without :method + let clientInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.path.rawValue: "test/test", + GRPCHTTP2Keys.scheme.rawValue: "http", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.te.rawValue: "trailers", + ] + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata)) + ) + ) + + // Make sure we have sent a trailers-only response + guard + case .headers(let writtenTrailersOnlyResponse) = try XCTUnwrap( + try channel.readOutbound(as: HTTP2Frame.FramePayload.self) + ) + else { + XCTFail("Expected to write headers") + return + } + + XCTAssertEqual( + writtenTrailersOnlyResponse.headers, + [ + GRPCHTTP2Keys.status.rawValue: "200", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.grpcStatus.rawValue: String(Status.Code.invalidArgument.rawValue), + GRPCHTTP2Keys.grpcStatusMessage.rawValue: ":method header is expected to be present and have a value of \"POST\".", + ] + ) + XCTAssertTrue(writtenTrailersOnlyResponse.endStream) + } + + func testClientInitialMetadataWithoutSchemeResultsInRejectedRPC() throws { + let handler = GRPCServerStreamHandler( + scheme: .http, + acceptedEncodings: [], + maximumPayloadSize: 1 + ) + + let channel = EmbeddedChannel(handler: handler) + + // Receive client's initial metadata without :scheme + let clientInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.path.rawValue: "test/test", + GRPCHTTP2Keys.method.rawValue: "POST", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.te.rawValue: "trailers", + ] + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata)) + ) + ) + + // Make sure we have sent a trailers-only response + guard + case .headers(let writtenTrailersOnlyResponse) = try XCTUnwrap( + try channel.readOutbound(as: HTTP2Frame.FramePayload.self) + ) + else { + XCTFail("Expected to write headers") + return + } + + XCTAssertEqual( + writtenTrailersOnlyResponse.headers, + [ + GRPCHTTP2Keys.status.rawValue: "200", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.grpcStatus.rawValue: String(Status.Code.invalidArgument.rawValue), + GRPCHTTP2Keys.grpcStatusMessage.rawValue: ":scheme header must be present and one of \"http\" or \"https\".", + ] + ) + XCTAssertTrue(writtenTrailersOnlyResponse.endStream) + } + + func testClientInitialMetadataWithoutPathResultsInRejectedRPC() throws { + let handler = GRPCServerStreamHandler( + scheme: .http, + acceptedEncodings: [], + maximumPayloadSize: 1 + ) + + let channel = EmbeddedChannel(handler: handler) + + // Receive client's initial metadata without :path + let clientInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.scheme.rawValue: "http", + GRPCHTTP2Keys.method.rawValue: "POST", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.te.rawValue: "trailers", + ] + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata)) + ) + ) + + // Make sure we have sent a trailers-only response + guard + case .headers(let writtenTrailersOnlyResponse) = try XCTUnwrap( + try channel.readOutbound(as: HTTP2Frame.FramePayload.self) + ) + else { + XCTFail("Expected to write headers") + return + } + + XCTAssertEqual( + writtenTrailersOnlyResponse.headers, + [ + GRPCHTTP2Keys.status.rawValue: "200", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.grpcStatus.rawValue: String(Status.Code.unimplemented.rawValue), + GRPCHTTP2Keys.grpcStatusMessage.rawValue: "No :path header has been set.", + ] + ) + XCTAssertTrue(writtenTrailersOnlyResponse.endStream) + } + + func testClientInitialMetadataWithoutTEResultsInRejectedRPC() throws { + let handler = GRPCServerStreamHandler( + scheme: .http, + acceptedEncodings: [], + maximumPayloadSize: 1 + ) + + let channel = EmbeddedChannel(handler: handler) + + // Receive client's initial metadata without TE + let clientInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.path.rawValue: "test/test", + GRPCHTTP2Keys.scheme.rawValue: "http", + GRPCHTTP2Keys.method.rawValue: "POST", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc" + ] + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata)) + ) + ) + + // Make sure we have sent a trailers-only response + guard + case .headers(let writtenTrailersOnlyResponse) = try XCTUnwrap( + try channel.readOutbound(as: HTTP2Frame.FramePayload.self) + ) + else { + XCTFail("Expected to write headers") + return + } + + XCTAssertEqual( + writtenTrailersOnlyResponse.headers, + [ + GRPCHTTP2Keys.status.rawValue: "200", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.grpcStatus.rawValue: String(Status.Code.invalidArgument.rawValue), + GRPCHTTP2Keys.grpcStatusMessage.rawValue: "\"te\" header is expected to be present and have a value of \"trailers\".", + ] + ) + XCTAssertTrue(writtenTrailersOnlyResponse.endStream) + } + + func testNotAcceptedEncodingResultsInRejectedRPC() throws { + let handler = GRPCServerStreamHandler( + scheme: .http, + acceptedEncodings: [], + maximumPayloadSize: 100 + ) + + let channel = EmbeddedChannel(handler: handler) + + // Receive client's initial metadata + let clientInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.path.rawValue: "test/test", + GRPCHTTP2Keys.scheme.rawValue: "http", + GRPCHTTP2Keys.method.rawValue: "POST", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.te.rawValue: "trailers", + GRPCHTTP2Keys.encoding.rawValue: "deflate", + ] + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata)) + ) + ) + + // Make sure we haven't sent back an error response, and that we read the initial metadata + // Make sure we have sent a trailers-only response + guard + case .headers(let writtenTrailersOnlyResponse) = try XCTUnwrap( + try channel.readOutbound(as: HTTP2Frame.FramePayload.self) + ) + else { + XCTFail("Expected to write headers") + return + } + + XCTAssertEqual( + writtenTrailersOnlyResponse.headers, + [ + GRPCHTTP2Keys.status.rawValue: "200", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.grpcStatus.rawValue: String(Status.Code.unimplemented.rawValue), + GRPCHTTP2Keys.grpcStatusMessage.rawValue: "Compression is not supported", + ] + ) + XCTAssertTrue(writtenTrailersOnlyResponse.endStream) + } + + func testOverMaximumPayloadSize() throws { + let handler = GRPCServerStreamHandler( + scheme: .http, + acceptedEncodings: [], + maximumPayloadSize: 1 + ) + + let channel = EmbeddedChannel(handler: handler) + + // Receive client's initial metadata + let clientInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.path.rawValue: "test/test", + GRPCHTTP2Keys.scheme.rawValue: "http", + GRPCHTTP2Keys.method.rawValue: "POST", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.te.rawValue: "trailers", + ] + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata)) + ) + ) + + // Make sure we haven't sent back an error response, and that we read the initial metadata + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + XCTAssertEqual( + try channel.readInbound(as: RPCRequestPart.self), + RPCRequestPart.metadata(Metadata(headers: clientInitialMetadata)) + ) + + // Write back server's initial metadata + let headers: HPACKHeaders = [ + "some-custom-header": "some-custom-value" + ] + let serverInitialMetadata = RPCResponsePart.metadata(Metadata(headers: headers)) + XCTAssertNoThrow(try channel.writeOutbound(serverInitialMetadata)) + + // Make sure we wrote back the initial metadata + guard + case .headers(let writtenHeaders) = try XCTUnwrap( + try channel.readOutbound(as: HTTP2Frame.FramePayload.self) + ) + else { + XCTFail("Expected to write headers") + return + } + + XCTAssertEqual( + writtenHeaders.headers, + [ + GRPCHTTP2Keys.status.rawValue: "200", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + "some-custom-header": "some-custom-value", + ] + ) + + // Receive client's message + var buffer = ByteBuffer() + buffer.writeInteger(UInt8(0)) // not compressed + buffer.writeInteger(UInt32(42)) // message length + buffer.writeRepeatingByte(0, count: 42) // message + let clientDataPayload = HTTP2Frame.FramePayload.Data(data: .byteBuffer(buffer), endStream: true) + XCTAssertThrowsError( + ofType: RPCError.self, + try channel.writeInbound(HTTP2Frame.FramePayload.data(clientDataPayload)) + ) { error in + XCTAssertEqual(error.code, .resourceExhausted) + XCTAssertEqual( + error.message, + "Message has exceeded the configured maximum payload size (max: 1, actual: 42)" + ) + } + + // Make sure we haven't sent a response back and that we didn't read the received message + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + XCTAssertNil(try channel.readInbound(as: RPCRequestPart.self)) + } + + func testClientEndsStream() throws { + let handler = GRPCServerStreamHandler( + scheme: .http, + acceptedEncodings: [], + maximumPayloadSize: 100, + skipStateMachineAssertions: true + ) + + let channel = EmbeddedChannel(handler: handler) + + // Receive client's initial metadata with end stream set + let clientInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.path.rawValue: "test/test", + GRPCHTTP2Keys.scheme.rawValue: "http", + GRPCHTTP2Keys.method.rawValue: "POST", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.te.rawValue: "trailers", + ] + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata, endStream: true)) + ) + ) + + // Make sure we haven't sent back an error response, and that we read the initial metadata + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + XCTAssertEqual( + try channel.readInbound(as: RPCRequestPart.self), + RPCRequestPart.metadata(Metadata(headers: clientInitialMetadata)) + ) + + // Write back server's initial metadata + let headers: HPACKHeaders = [ + "some-custom-header": "some-custom-value" + ] + let serverInitialMetadata = RPCResponsePart.metadata(Metadata(headers: headers)) + XCTAssertNoThrow(try channel.writeOutbound(serverInitialMetadata)) + + // Make sure we wrote back the initial metadata + guard + case .headers(let writtenHeaders) = try XCTUnwrap( + try channel.readOutbound(as: HTTP2Frame.FramePayload.self) + ) + else { + XCTFail("Expected to write headers") + return + } + + XCTAssertEqual( + writtenHeaders.headers, + [ + GRPCHTTP2Keys.status.rawValue: "200", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + "some-custom-header": "some-custom-value", + ] + ) + + // We should throw if the client sends another message, since it's closed the stream already. + var buffer = ByteBuffer() + buffer.writeInteger(UInt8(0)) // not compressed + buffer.writeInteger(UInt32(42)) // message length + buffer.writeRepeatingByte(0, count: 42) // message + let clientDataPayload = HTTP2Frame.FramePayload.Data(data: .byteBuffer(buffer), endStream: true) + XCTAssertThrowsError( + ofType: RPCError.self, + try channel.writeInbound(HTTP2Frame.FramePayload.data(clientDataPayload)) + ) { error in + XCTAssertEqual(error.code, .internalError) + XCTAssertEqual(error.message, "Client can't send a message if closed.") + } + } + + func testNormalFlow() throws { + let handler = GRPCServerStreamHandler( + scheme: .http, + acceptedEncodings: [], + maximumPayloadSize: 100 + ) + + let channel = EmbeddedChannel(handler: handler) + + // Receive client's initial metadata + let clientInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.path.rawValue: "test/test", + GRPCHTTP2Keys.scheme.rawValue: "http", + GRPCHTTP2Keys.method.rawValue: "POST", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.te.rawValue: "trailers", + ] + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata)) + ) + ) + + // Make sure we haven't sent back an error response, and that we read the initial metadata + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + XCTAssertEqual( + try channel.readInbound(as: RPCRequestPart.self), + RPCRequestPart.metadata(Metadata(headers: clientInitialMetadata)) + ) + + // Write back server's initial metadata + let headers: HPACKHeaders = [ + "some-custom-header": "some-custom-value" + ] + let serverInitialMetadata = RPCResponsePart.metadata(Metadata(headers: headers)) + XCTAssertNoThrow(try channel.writeOutbound(serverInitialMetadata)) + + // Make sure we wrote back the initial metadata + guard + case .headers(let writtenHeaders) = try XCTUnwrap( + try channel.readOutbound(as: HTTP2Frame.FramePayload.self) + ) + else { + XCTFail("Expected to write headers") + return + } + + XCTAssertEqual( + writtenHeaders.headers, + [ + GRPCHTTP2Keys.status.rawValue: "200", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + "some-custom-header": "some-custom-value", + ] + ) + + // Receive client's message + var buffer = ByteBuffer() + buffer.writeInteger(UInt8(0)) // not compressed + buffer.writeInteger(UInt32(42)) // message length + buffer.writeRepeatingByte(0, count: 42) // message + let clientDataPayload = HTTP2Frame.FramePayload.Data(data: .byteBuffer(buffer), endStream: true) + XCTAssertNoThrow(try channel.writeInbound(HTTP2Frame.FramePayload.data(clientDataPayload))) + + // Make sure we haven't sent back an error response, and that we read the message properly + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + XCTAssertEqual( + try channel.readInbound(as: RPCRequestPart.self), + RPCRequestPart.message([UInt8](repeating: 0, count: 42)) + ) + + // Write back response + let serverDataPayload = RPCResponsePart.message([UInt8](repeating: 1, count: 42)) + XCTAssertNoThrow(try channel.writeOutbound(serverDataPayload)) + + // Make sure we wrote back the right message + guard + case .data(let writtenMessage) = try XCTUnwrap( + try channel.readOutbound(as: HTTP2Frame.FramePayload.self) + ) + else { + XCTFail("Expected to write data") + return + } + + var expectedBuffer = ByteBuffer() + expectedBuffer.writeInteger(UInt8(0)) // not compressed + expectedBuffer.writeInteger(UInt32(42)) // message length + expectedBuffer.writeRepeatingByte(1, count: 42) // message + XCTAssertEqual(writtenMessage.data, .byteBuffer(expectedBuffer)) + + // Send back status to end RPC + let trailers = RPCResponsePart.status( + .init(code: .dataLoss, message: "Test data loss"), + ["custom-header": "custom-value"] + ) + XCTAssertNoThrow(try channel.writeOutbound(trailers)) + + // Make sure we wrote back the status and trailers + guard + case .headers(let writtenStatus) = try XCTUnwrap( + try channel.readOutbound(as: HTTP2Frame.FramePayload.self) + ) + else { + XCTFail("Expected to write data") + return + } + + XCTAssertTrue(writtenStatus.endStream) + XCTAssertEqual( + writtenStatus.headers, + [ + GRPCHTTP2Keys.grpcStatus.rawValue: String(Status.Code.dataLoss.rawValue), + GRPCHTTP2Keys.grpcStatusMessage.rawValue: "Test data loss", + "custom-header": "custom-value", + ] + ) + } + + func testReceiveMessageSplitAcrossMultipleBuffers() throws { + let handler = GRPCServerStreamHandler( + scheme: .http, + acceptedEncodings: [], + maximumPayloadSize: 100 + ) + + let channel = EmbeddedChannel(handler: handler) + + // Receive client's initial metadata + let clientInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.path.rawValue: "test/test", + GRPCHTTP2Keys.scheme.rawValue: "http", + GRPCHTTP2Keys.method.rawValue: "POST", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.te.rawValue: "trailers", + ] + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata)) + ) + ) + + // Make sure we haven't sent back an error response, and that we read the initial metadata + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + XCTAssertEqual( + try channel.readInbound(as: RPCRequestPart.self), + RPCRequestPart.metadata(Metadata(headers: clientInitialMetadata)) + ) + + // Write back server's initial metadata + let headers: HPACKHeaders = [ + "some-custom-header": "some-custom-value" + ] + let serverInitialMetadata = RPCResponsePart.metadata(Metadata(headers: headers)) + XCTAssertNoThrow(try channel.writeOutbound(serverInitialMetadata)) + + // Make sure we wrote back the initial metadata + guard + case .headers(let writtenHeaders) = try XCTUnwrap( + try channel.readOutbound(as: HTTP2Frame.FramePayload.self) + ) + else { + XCTFail("Expected to write headers") + return + } + + XCTAssertEqual( + writtenHeaders.headers, + [ + GRPCHTTP2Keys.status.rawValue: "200", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + "some-custom-header": "some-custom-value", + ] + ) + + // Receive client's first message + var buffer = ByteBuffer() + buffer.writeInteger(UInt8(0)) // not compressed + XCTAssertNoThrow(try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer))))) + XCTAssertNil(try channel.readInbound(as: RPCRequestPart.self)) + + buffer.clear() + buffer.writeInteger(UInt32(30)) // message length + XCTAssertNoThrow(try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer))))) + XCTAssertNil(try channel.readInbound(as: RPCRequestPart.self)) + + buffer.clear() + buffer.writeRepeatingByte(0, count: 10) // first part of the message + XCTAssertNoThrow(try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer))))) + XCTAssertNil(try channel.readInbound(as: RPCRequestPart.self)) + + buffer.clear() + buffer.writeRepeatingByte(1, count: 10) // second part of the message + XCTAssertNoThrow(try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer))))) + XCTAssertNil(try channel.readInbound(as: RPCRequestPart.self)) + + buffer.clear() + buffer.writeRepeatingByte(2, count: 10) // third part of the message + XCTAssertNoThrow(try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer))))) + + // Make sure we haven't sent back an error response, and that we read the message properly + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + XCTAssertEqual( + try channel.readInbound(as: RPCRequestPart.self), + RPCRequestPart.message([UInt8](repeating: 0, count: 10) + [UInt8](repeating: 1, count: 10) + [UInt8](repeating: 2, count: 10)) + ) + } +} From 7a034461b8014ad704549bf53adf35a8536673c0 Mon Sep 17 00:00:00 2001 From: Gus Cairo Date: Mon, 18 Mar 2024 18:28:54 +0000 Subject: [PATCH 3/8] PR changes --- .../Server/GRPCServerStreamHandler.swift | 31 ++- .../Server/GRPCServerStreamHandlerTests.swift | 228 +++++++++--------- 2 files changed, 128 insertions(+), 131 deletions(-) diff --git a/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift b/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift index f2d281c6c..666f75c4d 100644 --- a/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift +++ b/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift @@ -60,7 +60,7 @@ final class GRPCServerStreamHandler: ChannelDuplexHandler { case .receiveMessage(let message): context.fireChannelRead(self.wrapInboundOut(.message(message))) case .noMoreMessages: - context.channel.close(mode: .input, promise: nil) + context.close(mode: .input, promise: nil) } } catch { context.fireErrorCaught(error) @@ -68,6 +68,7 @@ final class GRPCServerStreamHandler: ChannelDuplexHandler { case .fileRegion: preconditionFailure("Unexpected IOData.fileRegion") } + case .headers(let headers): do { let action = try self.stateMachine.receive( @@ -92,6 +93,7 @@ final class GRPCServerStreamHandler: ChannelDuplexHandler { } catch { context.fireErrorCaught(error) } + case .ping, .goAway, .priority, .rstStream, .settings, .pushPromise, .windowUpdate, .alternativeService, .origin: () @@ -106,12 +108,24 @@ final class GRPCServerStreamHandler: ChannelDuplexHandler { } context.fireChannelReadComplete() } + + func handlerRemoved(context: ChannelHandlerContext) { + self.stateMachine.tearDown() + } } // - MARK: ChannelOutboundHandler @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *) extension GRPCServerStreamHandler { + private func flushIfNeeded(_ context: ChannelHandlerContext) { + if self.isReading { + self.flushPending = true + } else { + context.flush() + } + } + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { let frame = self.unwrapOutboundIn(data) switch frame { @@ -119,6 +133,7 @@ extension GRPCServerStreamHandler { do { let headers = try self.stateMachine.send(metadata: metadata) context.write(self.wrapOutboundOut(.headers(.init(headers: headers))), promise: nil) + self.flushIfNeeded(context) // TODO: move the promise handling into the state machine promise?.succeed() } catch { @@ -126,6 +141,7 @@ extension GRPCServerStreamHandler { // TODO: move the promise handling into the state machine promise?.fail(error) } + case .message(let message): do { try self.stateMachine.send(message: message, endStream: false) @@ -141,11 +157,7 @@ extension GRPCServerStreamHandler { self.wrapOutboundOut(.data(.init(data: .byteBuffer(byteBuffer)))), promise: nil ) - if self.isReading { - self.flushPending = true - } else { - context.flush() - } + self.flushIfNeeded(context) } // TODO: move the promise handling into the state machine promise?.succeed() @@ -154,16 +166,13 @@ extension GRPCServerStreamHandler { // TODO: move the promise handling into the state machine promise?.fail(error) } + case .status(let status, let metadata): do { let headers = try self.stateMachine.send(status: status, metadata: metadata) let response = HTTP2Frame.FramePayload.headers(.init(headers: headers, endStream: true)) context.write(self.wrapOutboundOut(response), promise: nil) - if self.isReading { - self.flushPending = true - } else { - context.flush() - } + self.flushIfNeeded(context) // TODO: move the promise handling into the state machine promise?.succeed() } catch { diff --git a/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift b/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift index 0cd97c0eb..8c5d509fa 100644 --- a/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift +++ b/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift @@ -25,6 +25,34 @@ import XCTest @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *) final class GRPCServerStreamHandlerTests: XCTestCase { + func testH2FramesAreIgnored() throws { + let handler = GRPCServerStreamHandler( + scheme: .http, + acceptedEncodings: [], + maximumPayloadSize: 1 + ) + + let channel = EmbeddedChannel(handler: handler) + + let framesToBeIgnored: [HTTP2Frame.FramePayload] = [ + .ping(.init(), ack: false), + .goAway(lastStreamID: .rootStream, errorCode: .cancel, opaqueData: nil), + // TODO: add .priority(StreamPriorityData) - right now, StreamPriorityData's + // initialiser is internal, so I can't create one of these frames. + .rstStream(.cancel), + .settings(.ack), + .pushPromise(.init(pushedStreamID: .maxID, headers: [:])), + .windowUpdate(windowSizeIncrement: 4), + .alternativeService(origin: nil, field: nil), + .origin([]), + ] + + for toBeIgnored in framesToBeIgnored { + XCTAssertNoThrow(try channel.writeInbound(toBeIgnored)) + XCTAssertNil(try channel.readInbound(as: HTTP2Frame.FramePayload.self)) + } + } + func testClientInitialMetadataWithoutContentTypeResultsInRejectedRPC() throws { let handler = GRPCServerStreamHandler( scheme: .http, @@ -46,21 +74,14 @@ final class GRPCServerStreamHandlerTests: XCTestCase { HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata)) ) ) - + // Make sure we have sent a trailers-only response - guard - case .headers(let writtenTrailersOnlyResponse) = try XCTUnwrap( - try channel.readOutbound(as: HTTP2Frame.FramePayload.self) - ) - else { - XCTFail("Expected to write headers") - return - } + let writtenTrailersOnlyResponse = try channel.assertReadHeadersOutbound() XCTAssertEqual(writtenTrailersOnlyResponse.headers, [":status": "415"]) XCTAssertTrue(writtenTrailersOnlyResponse.endStream) } - + func testClientInitialMetadataWithoutMethodResultsInRejectedRPC() throws { let handler = GRPCServerStreamHandler( scheme: .http, @@ -82,16 +103,9 @@ final class GRPCServerStreamHandlerTests: XCTestCase { HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata)) ) ) - + // Make sure we have sent a trailers-only response - guard - case .headers(let writtenTrailersOnlyResponse) = try XCTUnwrap( - try channel.readOutbound(as: HTTP2Frame.FramePayload.self) - ) - else { - XCTFail("Expected to write headers") - return - } + let writtenTrailersOnlyResponse = try channel.assertReadHeadersOutbound() XCTAssertEqual( writtenTrailersOnlyResponse.headers, @@ -99,12 +113,13 @@ final class GRPCServerStreamHandlerTests: XCTestCase { GRPCHTTP2Keys.status.rawValue: "200", GRPCHTTP2Keys.contentType.rawValue: "application/grpc", GRPCHTTP2Keys.grpcStatus.rawValue: String(Status.Code.invalidArgument.rawValue), - GRPCHTTP2Keys.grpcStatusMessage.rawValue: ":method header is expected to be present and have a value of \"POST\".", + GRPCHTTP2Keys.grpcStatusMessage.rawValue: + ":method header is expected to be present and have a value of \"POST\".", ] ) XCTAssertTrue(writtenTrailersOnlyResponse.endStream) } - + func testClientInitialMetadataWithoutSchemeResultsInRejectedRPC() throws { let handler = GRPCServerStreamHandler( scheme: .http, @@ -126,16 +141,9 @@ final class GRPCServerStreamHandlerTests: XCTestCase { HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata)) ) ) - + // Make sure we have sent a trailers-only response - guard - case .headers(let writtenTrailersOnlyResponse) = try XCTUnwrap( - try channel.readOutbound(as: HTTP2Frame.FramePayload.self) - ) - else { - XCTFail("Expected to write headers") - return - } + let writtenTrailersOnlyResponse = try channel.assertReadHeadersOutbound() XCTAssertEqual( writtenTrailersOnlyResponse.headers, @@ -143,12 +151,13 @@ final class GRPCServerStreamHandlerTests: XCTestCase { GRPCHTTP2Keys.status.rawValue: "200", GRPCHTTP2Keys.contentType.rawValue: "application/grpc", GRPCHTTP2Keys.grpcStatus.rawValue: String(Status.Code.invalidArgument.rawValue), - GRPCHTTP2Keys.grpcStatusMessage.rawValue: ":scheme header must be present and one of \"http\" or \"https\".", + GRPCHTTP2Keys.grpcStatusMessage.rawValue: + ":scheme header must be present and one of \"http\" or \"https\".", ] ) XCTAssertTrue(writtenTrailersOnlyResponse.endStream) } - + func testClientInitialMetadataWithoutPathResultsInRejectedRPC() throws { let handler = GRPCServerStreamHandler( scheme: .http, @@ -170,16 +179,9 @@ final class GRPCServerStreamHandlerTests: XCTestCase { HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata)) ) ) - + // Make sure we have sent a trailers-only response - guard - case .headers(let writtenTrailersOnlyResponse) = try XCTUnwrap( - try channel.readOutbound(as: HTTP2Frame.FramePayload.self) - ) - else { - XCTFail("Expected to write headers") - return - } + let writtenTrailersOnlyResponse = try channel.assertReadHeadersOutbound() XCTAssertEqual( writtenTrailersOnlyResponse.headers, @@ -192,7 +194,7 @@ final class GRPCServerStreamHandlerTests: XCTestCase { ) XCTAssertTrue(writtenTrailersOnlyResponse.endStream) } - + func testClientInitialMetadataWithoutTEResultsInRejectedRPC() throws { let handler = GRPCServerStreamHandler( scheme: .http, @@ -207,23 +209,16 @@ final class GRPCServerStreamHandlerTests: XCTestCase { GRPCHTTP2Keys.path.rawValue: "test/test", GRPCHTTP2Keys.scheme.rawValue: "http", GRPCHTTP2Keys.method.rawValue: "POST", - GRPCHTTP2Keys.contentType.rawValue: "application/grpc" + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", ] XCTAssertNoThrow( try channel.writeInbound( HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata)) ) ) - + // Make sure we have sent a trailers-only response - guard - case .headers(let writtenTrailersOnlyResponse) = try XCTUnwrap( - try channel.readOutbound(as: HTTP2Frame.FramePayload.self) - ) - else { - XCTFail("Expected to write headers") - return - } + let writtenTrailersOnlyResponse = try channel.assertReadHeadersOutbound() XCTAssertEqual( writtenTrailersOnlyResponse.headers, @@ -231,12 +226,13 @@ final class GRPCServerStreamHandlerTests: XCTestCase { GRPCHTTP2Keys.status.rawValue: "200", GRPCHTTP2Keys.contentType.rawValue: "application/grpc", GRPCHTTP2Keys.grpcStatus.rawValue: String(Status.Code.invalidArgument.rawValue), - GRPCHTTP2Keys.grpcStatusMessage.rawValue: "\"te\" header is expected to be present and have a value of \"trailers\".", + GRPCHTTP2Keys.grpcStatusMessage.rawValue: + "\"te\" header is expected to be present and have a value of \"trailers\".", ] ) XCTAssertTrue(writtenTrailersOnlyResponse.endStream) } - + func testNotAcceptedEncodingResultsInRejectedRPC() throws { let handler = GRPCServerStreamHandler( scheme: .http, @@ -263,14 +259,7 @@ final class GRPCServerStreamHandlerTests: XCTestCase { // Make sure we haven't sent back an error response, and that we read the initial metadata // Make sure we have sent a trailers-only response - guard - case .headers(let writtenTrailersOnlyResponse) = try XCTUnwrap( - try channel.readOutbound(as: HTTP2Frame.FramePayload.self) - ) - else { - XCTFail("Expected to write headers") - return - } + let writtenTrailersOnlyResponse = try channel.assertReadHeadersOutbound() XCTAssertEqual( writtenTrailersOnlyResponse.headers, @@ -322,14 +311,7 @@ final class GRPCServerStreamHandlerTests: XCTestCase { XCTAssertNoThrow(try channel.writeOutbound(serverInitialMetadata)) // Make sure we wrote back the initial metadata - guard - case .headers(let writtenHeaders) = try XCTUnwrap( - try channel.readOutbound(as: HTTP2Frame.FramePayload.self) - ) - else { - XCTFail("Expected to write headers") - return - } + let writtenHeaders = try channel.assertReadHeadersOutbound() XCTAssertEqual( writtenHeaders.headers, @@ -401,14 +383,7 @@ final class GRPCServerStreamHandlerTests: XCTestCase { XCTAssertNoThrow(try channel.writeOutbound(serverInitialMetadata)) // Make sure we wrote back the initial metadata - guard - case .headers(let writtenHeaders) = try XCTUnwrap( - try channel.readOutbound(as: HTTP2Frame.FramePayload.self) - ) - else { - XCTFail("Expected to write headers") - return - } + let writtenHeaders = try channel.assertReadHeadersOutbound() XCTAssertEqual( writtenHeaders.headers, @@ -472,14 +447,7 @@ final class GRPCServerStreamHandlerTests: XCTestCase { XCTAssertNoThrow(try channel.writeOutbound(serverInitialMetadata)) // Make sure we wrote back the initial metadata - guard - case .headers(let writtenHeaders) = try XCTUnwrap( - try channel.readOutbound(as: HTTP2Frame.FramePayload.self) - ) - else { - XCTFail("Expected to write headers") - return - } + let writtenHeaders = try channel.assertReadHeadersOutbound() XCTAssertEqual( writtenHeaders.headers, @@ -510,14 +478,7 @@ final class GRPCServerStreamHandlerTests: XCTestCase { XCTAssertNoThrow(try channel.writeOutbound(serverDataPayload)) // Make sure we wrote back the right message - guard - case .data(let writtenMessage) = try XCTUnwrap( - try channel.readOutbound(as: HTTP2Frame.FramePayload.self) - ) - else { - XCTFail("Expected to write data") - return - } + let writtenMessage = try channel.assertReadDataOutbound() var expectedBuffer = ByteBuffer() expectedBuffer.writeInteger(UInt8(0)) // not compressed @@ -533,14 +494,7 @@ final class GRPCServerStreamHandlerTests: XCTestCase { XCTAssertNoThrow(try channel.writeOutbound(trailers)) // Make sure we wrote back the status and trailers - guard - case .headers(let writtenStatus) = try XCTUnwrap( - try channel.readOutbound(as: HTTP2Frame.FramePayload.self) - ) - else { - XCTFail("Expected to write data") - return - } + let writtenStatus = try channel.assertReadHeadersOutbound() XCTAssertTrue(writtenStatus.endStream) XCTAssertEqual( @@ -591,14 +545,7 @@ final class GRPCServerStreamHandlerTests: XCTestCase { XCTAssertNoThrow(try channel.writeOutbound(serverInitialMetadata)) // Make sure we wrote back the initial metadata - guard - case .headers(let writtenHeaders) = try XCTUnwrap( - try channel.readOutbound(as: HTTP2Frame.FramePayload.self) - ) - else { - XCTFail("Expected to write headers") - return - } + let writtenHeaders = try channel.assertReadHeadersOutbound() XCTAssertEqual( writtenHeaders.headers, @@ -612,33 +559,74 @@ final class GRPCServerStreamHandlerTests: XCTestCase { // Receive client's first message var buffer = ByteBuffer() buffer.writeInteger(UInt8(0)) // not compressed - XCTAssertNoThrow(try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer))))) + XCTAssertNoThrow( + try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer)))) + ) XCTAssertNil(try channel.readInbound(as: RPCRequestPart.self)) - + buffer.clear() buffer.writeInteger(UInt32(30)) // message length - XCTAssertNoThrow(try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer))))) + XCTAssertNoThrow( + try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer)))) + ) XCTAssertNil(try channel.readInbound(as: RPCRequestPart.self)) - + buffer.clear() buffer.writeRepeatingByte(0, count: 10) // first part of the message - XCTAssertNoThrow(try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer))))) + XCTAssertNoThrow( + try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer)))) + ) XCTAssertNil(try channel.readInbound(as: RPCRequestPart.self)) - + buffer.clear() buffer.writeRepeatingByte(1, count: 10) // second part of the message - XCTAssertNoThrow(try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer))))) + XCTAssertNoThrow( + try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer)))) + ) XCTAssertNil(try channel.readInbound(as: RPCRequestPart.self)) - + buffer.clear() buffer.writeRepeatingByte(2, count: 10) // third part of the message - XCTAssertNoThrow(try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer))))) + XCTAssertNoThrow( + try channel.writeInbound(HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(buffer)))) + ) // Make sure we haven't sent back an error response, and that we read the message properly XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) XCTAssertEqual( try channel.readInbound(as: RPCRequestPart.self), - RPCRequestPart.message([UInt8](repeating: 0, count: 10) + [UInt8](repeating: 1, count: 10) + [UInt8](repeating: 2, count: 10)) + RPCRequestPart.message( + [UInt8](repeating: 0, count: 10) + [UInt8](repeating: 1, count: 10) + + [UInt8](repeating: 2, count: 10) + ) ) } } + +extension EmbeddedChannel { + fileprivate func assertReadHeadersOutbound() throws -> HTTP2Frame.FramePayload.Headers { + guard + case .headers(let writtenHeaders) = try XCTUnwrap( + try self.readOutbound(as: HTTP2Frame.FramePayload.self) + ) + else { + throw TestError.assertionFailure("Expected to write headers") + } + return writtenHeaders + } + + fileprivate func assertReadDataOutbound() throws -> HTTP2Frame.FramePayload.Data { + guard + case .data(let writtenMessage) = try XCTUnwrap( + try self.readOutbound(as: HTTP2Frame.FramePayload.self) + ) + else { + throw TestError.assertionFailure("Expected to write data") + } + return writtenMessage + } +} + +private enum TestError: Error { + case assertionFailure(String) +} From 9fd0e04174ba20e02d49e8b22e30d2d43d9cd422 Mon Sep 17 00:00:00 2001 From: Gus Cairo Date: Tue, 19 Mar 2024 14:16:39 +0000 Subject: [PATCH 4/8] Fix flushing logic --- .../Server/GRPCServerStreamHandler.swift | 35 ++++---- .../Server/GRPCServerStreamHandlerTests.swift | 81 ++++++++++++++++++- 2 files changed, 101 insertions(+), 15 deletions(-) diff --git a/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift b/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift index 666f75c4d..fb647d8f8 100644 --- a/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift +++ b/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift @@ -145,20 +145,7 @@ extension GRPCServerStreamHandler { case .message(let message): do { try self.stateMachine.send(message: message, endStream: false) - switch try self.stateMachine.nextOutboundMessage() { - case .noMoreMessages: - // We shouldn't close the channel in this case, because we still have - // to send back a status and trailers to properly end the RPC stream. - () - case .awaitMoreMessages: - () - case .sendMessage(let byteBuffer): - context.write( - self.wrapOutboundOut(.data(.init(data: .byteBuffer(byteBuffer)))), - promise: nil - ) - self.flushIfNeeded(context) - } + self.flushIfNeeded(context) // TODO: move the promise handling into the state machine promise?.succeed() } catch { @@ -182,4 +169,24 @@ extension GRPCServerStreamHandler { } } } + + func flush(context: ChannelHandlerContext) { + do { + switch try self.stateMachine.nextOutboundMessage() { + case .noMoreMessages: + // We shouldn't close the channel in this case, because we still have + // to send back a status and trailers to properly end the RPC stream. + () + case .awaitMoreMessages: + () + case .sendMessage(let byteBuffer): + context.writeAndFlush( + self.wrapOutboundOut(.data(.init(data: .byteBuffer(byteBuffer)))), + promise: nil + ) + } + } catch { + context.fireErrorCaught(error) + } + } } diff --git a/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift b/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift index 8c5d509fa..5ed77295b 100644 --- a/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift +++ b/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift @@ -601,6 +601,85 @@ final class GRPCServerStreamHandlerTests: XCTestCase { ) ) } + + func testSendMultipleMessagesInSingleBuffer() throws { + let handler = GRPCServerStreamHandler( + scheme: .http, + acceptedEncodings: [], + maximumPayloadSize: 100 + ) + + let channel = EmbeddedChannel(handler: handler) + + // Receive client's initial metadata + let clientInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.path.rawValue: "test/test", + GRPCHTTP2Keys.scheme.rawValue: "http", + GRPCHTTP2Keys.method.rawValue: "POST", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.te.rawValue: "trailers", + ] + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata)) + ) + ) + + // Make sure we haven't sent back an error response, and that we read the initial metadata + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + XCTAssertEqual( + try channel.readInbound(as: RPCRequestPart.self), + RPCRequestPart.metadata(Metadata(headers: clientInitialMetadata)) + ) + + // Write back server's initial metadata + let headers: HPACKHeaders = [ + "some-custom-header": "some-custom-value" + ] + let serverInitialMetadata = RPCResponsePart.metadata(Metadata(headers: headers)) + XCTAssertNoThrow(try channel.writeOutbound(serverInitialMetadata)) + + // Read out the metadata + _ = try channel.readOutbound(as: HTTP2Frame.FramePayload.self) + + // This is where this test actually begins. We want to write two messages + // without flushing, and make sure that no messages are sent down the pipeline + // until we flush. Once we flush, both messages should be sent in the same ByteBuffer. + + // Write back first message and make sure nothing's written in the channel. + XCTAssertNoThrow(channel.write(RPCResponsePart.message([UInt8](repeating: 1, count: 4)))) + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + + // Write back second message and make sure nothing's written in the channel. + XCTAssertNoThrow(channel.write(RPCResponsePart.message([UInt8](repeating: 2, count: 4)))) + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + + // Now flush and check we *do* write the data. + channel.flush() + + guard + case .data(let writtenMessage) = try XCTUnwrap( + try channel.readOutbound(as: HTTP2Frame.FramePayload.self) + ) + else { + XCTFail("Expected to write data") + return + } + + // Make sure both messages have been framed together in the ByteBuffer. + XCTAssertEqual(writtenMessage.data, .byteBuffer(.init(bytes: [ + // First message + 0, // Compression disabled + 0, 0, 0, 4, // Message length + 1, 1, 1, 1, // First message data + + // Second message + 0, // Compression disabled + 0, 0, 0, 4, // Message length + 2, 2, 2, 2 // Second message data + ]))) + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + } } extension EmbeddedChannel { @@ -614,7 +693,7 @@ extension EmbeddedChannel { } return writtenHeaders } - + fileprivate func assertReadDataOutbound() throws -> HTTP2Frame.FramePayload.Data { guard case .data(let writtenMessage) = try XCTUnwrap( From 81fd8134cab736fe82c655b6432c3181d12d336f Mon Sep 17 00:00:00 2001 From: Gus Cairo Date: Tue, 19 Mar 2024 14:17:32 +0000 Subject: [PATCH 5/8] Formatting --- .../Server/GRPCServerStreamHandler.swift | 2 +- .../Server/GRPCServerStreamHandlerTests.swift | 47 ++++++++++--------- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift b/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift index fb647d8f8..6dff7aa68 100644 --- a/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift +++ b/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift @@ -169,7 +169,7 @@ extension GRPCServerStreamHandler { } } } - + func flush(context: ChannelHandlerContext) { do { switch try self.stateMachine.nextOutboundMessage() { diff --git a/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift b/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift index 5ed77295b..eea8564a3 100644 --- a/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift +++ b/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift @@ -601,16 +601,16 @@ final class GRPCServerStreamHandlerTests: XCTestCase { ) ) } - + func testSendMultipleMessagesInSingleBuffer() throws { let handler = GRPCServerStreamHandler( scheme: .http, acceptedEncodings: [], maximumPayloadSize: 100 ) - + let channel = EmbeddedChannel(handler: handler) - + // Receive client's initial metadata let clientInitialMetadata: HPACKHeaders = [ GRPCHTTP2Keys.path.rawValue: "test/test", @@ -624,14 +624,14 @@ final class GRPCServerStreamHandlerTests: XCTestCase { HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata)) ) ) - + // Make sure we haven't sent back an error response, and that we read the initial metadata XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) XCTAssertEqual( try channel.readInbound(as: RPCRequestPart.self), RPCRequestPart.metadata(Metadata(headers: clientInitialMetadata)) ) - + // Write back server's initial metadata let headers: HPACKHeaders = [ "some-custom-header": "some-custom-value" @@ -641,19 +641,19 @@ final class GRPCServerStreamHandlerTests: XCTestCase { // Read out the metadata _ = try channel.readOutbound(as: HTTP2Frame.FramePayload.self) - + // This is where this test actually begins. We want to write two messages // without flushing, and make sure that no messages are sent down the pipeline // until we flush. Once we flush, both messages should be sent in the same ByteBuffer. - + // Write back first message and make sure nothing's written in the channel. XCTAssertNoThrow(channel.write(RPCResponsePart.message([UInt8](repeating: 1, count: 4)))) XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) - + // Write back second message and make sure nothing's written in the channel. XCTAssertNoThrow(channel.write(RPCResponsePart.message([UInt8](repeating: 2, count: 4)))) XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) - + // Now flush and check we *do* write the data. channel.flush() @@ -667,17 +667,22 @@ final class GRPCServerStreamHandlerTests: XCTestCase { } // Make sure both messages have been framed together in the ByteBuffer. - XCTAssertEqual(writtenMessage.data, .byteBuffer(.init(bytes: [ - // First message - 0, // Compression disabled - 0, 0, 0, 4, // Message length - 1, 1, 1, 1, // First message data - - // Second message - 0, // Compression disabled - 0, 0, 0, 4, // Message length - 2, 2, 2, 2 // Second message data - ]))) + XCTAssertEqual( + writtenMessage.data, + .byteBuffer( + .init(bytes: [ + // First message + 0, // Compression disabled + 0, 0, 0, 4, // Message length + 1, 1, 1, 1, // First message data + + // Second message + 0, // Compression disabled + 0, 0, 0, 4, // Message length + 2, 2, 2, 2, // Second message data + ]) + ) + ) XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) } } @@ -693,7 +698,7 @@ extension EmbeddedChannel { } return writtenHeaders } - + fileprivate func assertReadDataOutbound() throws -> HTTP2Frame.FramePayload.Data { guard case .data(let writtenMessage) = try XCTUnwrap( From 18784ee659d2349d1515e7961f3716556ed5947d Mon Sep 17 00:00:00 2001 From: Gus Cairo Date: Tue, 19 Mar 2024 16:33:21 +0000 Subject: [PATCH 6/8] Replace channel close with user inbound inputClosed event fired --- .../GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift | 2 +- .../Server/GRPCServerStreamHandlerTests.swift | 9 +-------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift b/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift index 6dff7aa68..ceceb79b3 100644 --- a/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift +++ b/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift @@ -60,7 +60,7 @@ final class GRPCServerStreamHandler: ChannelDuplexHandler { case .receiveMessage(let message): context.fireChannelRead(self.wrapInboundOut(.message(message))) case .noMoreMessages: - context.close(mode: .input, promise: nil) + context.fireUserInboundEventTriggered(ChannelEvent.inputClosed) } } catch { context.fireErrorCaught(error) diff --git a/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift b/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift index eea8564a3..e8b1bbb3d 100644 --- a/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift +++ b/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift @@ -657,14 +657,7 @@ final class GRPCServerStreamHandlerTests: XCTestCase { // Now flush and check we *do* write the data. channel.flush() - guard - case .data(let writtenMessage) = try XCTUnwrap( - try channel.readOutbound(as: HTTP2Frame.FramePayload.self) - ) - else { - XCTFail("Expected to write data") - return - } + let writtenMessage = try channel.assertReadDataOutbound() // Make sure both messages have been framed together in the ByteBuffer. XCTAssertEqual( From 0ec7e3ce57416540dafbd0e389421a0cf3636001 Mon Sep 17 00:00:00 2001 From: Gus Cairo Date: Wed, 20 Mar 2024 16:39:22 +0000 Subject: [PATCH 7/8] PR changes --- .../Server/GRPCServerStreamHandler.swift | 74 +++++++++++------ .../Server/GRPCServerStreamHandlerTests.swift | 80 +++++++++++++++++++ 2 files changed, 130 insertions(+), 24 deletions(-) diff --git a/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift b/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift index ceceb79b3..588f5df51 100644 --- a/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift +++ b/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift @@ -31,6 +31,11 @@ final class GRPCServerStreamHandler: ChannelDuplexHandler { private var isReading = false private var flushPending = false + // We buffer the final status + trailers to avoid reordering issues (i.e., + // if there are messages still not written into the channel because flush has + // not been called, but the server sends back trailers). + private var pendingTrailers: HTTP2Frame.FramePayload? + init( scheme: Scheme, acceptedEncodings: [CompressionAlgorithm], @@ -43,7 +48,12 @@ final class GRPCServerStreamHandler: ChannelDuplexHandler { skipAssertions: skipStateMachineAssertions ) } +} +// - MARK: ChannelInboundHandler + +@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *) +extension GRPCServerStreamHandler { func channelRead(context: ChannelHandlerContext, data: NIOAny) { self.isReading = true let frame = self.unwrapInboundIn(data) @@ -54,11 +64,18 @@ final class GRPCServerStreamHandler: ChannelDuplexHandler { case .byteBuffer(let buffer): do { try self.stateMachine.receive(message: buffer, endStream: endStream) - switch self.stateMachine.nextInboundMessage() { + + var nextInboundMessage = self.stateMachine.nextInboundMessage() + while case .receiveMessage(let message) = nextInboundMessage { + context.fireChannelRead(self.wrapInboundOut(.message(message))) + nextInboundMessage = self.stateMachine.nextInboundMessage() + } + + switch nextInboundMessage { case .awaitMoreMessages: () - case .receiveMessage(let message): - context.fireChannelRead(self.wrapInboundOut(.message(message))) + case .receiveMessage: + preconditionFailure("This isn't possible: we'd still be inside the while loop.") case .noMoreMessages: context.fireUserInboundEventTriggered(ChannelEvent.inputClosed) } @@ -118,14 +135,6 @@ final class GRPCServerStreamHandler: ChannelDuplexHandler { @available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *) extension GRPCServerStreamHandler { - private func flushIfNeeded(_ context: ChannelHandlerContext) { - if self.isReading { - self.flushPending = true - } else { - context.flush() - } - } - func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { let frame = self.unwrapOutboundIn(data) switch frame { @@ -133,7 +142,7 @@ extension GRPCServerStreamHandler { do { let headers = try self.stateMachine.send(metadata: metadata) context.write(self.wrapOutboundOut(.headers(.init(headers: headers))), promise: nil) - self.flushIfNeeded(context) + self.flushPending = true // TODO: move the promise handling into the state machine promise?.succeed() } catch { @@ -145,7 +154,6 @@ extension GRPCServerStreamHandler { case .message(let message): do { try self.stateMachine.send(message: message, endStream: false) - self.flushIfNeeded(context) // TODO: move the promise handling into the state machine promise?.succeed() } catch { @@ -158,8 +166,7 @@ extension GRPCServerStreamHandler { do { let headers = try self.stateMachine.send(status: status, metadata: metadata) let response = HTTP2Frame.FramePayload.headers(.init(headers: headers, endStream: true)) - context.write(self.wrapOutboundOut(response), promise: nil) - self.flushIfNeeded(context) + self.pendingTrailers = response // TODO: move the promise handling into the state machine promise?.succeed() } catch { @@ -172,18 +179,37 @@ extension GRPCServerStreamHandler { func flush(context: ChannelHandlerContext) { do { - switch try self.stateMachine.nextOutboundMessage() { - case .noMoreMessages: - // We shouldn't close the channel in this case, because we still have - // to send back a status and trailers to properly end the RPC stream. - () - case .awaitMoreMessages: - () - case .sendMessage(let byteBuffer): - context.writeAndFlush( + if self.isReading { + // We don't want to flush yet if we're still in a read loop. + return + } + + var nextOutboundMessage = try self.stateMachine.nextOutboundMessage() + while case .sendMessage(let byteBuffer) = nextOutboundMessage { + context.write( self.wrapOutboundOut(.data(.init(data: .byteBuffer(byteBuffer)))), promise: nil ) + self.flushPending = true + nextOutboundMessage = try self.stateMachine.nextOutboundMessage() + } + + switch nextOutboundMessage { + case .noMoreMessages: + if let pendingTrailers = self.pendingTrailers { + context.write(self.wrapOutboundOut(pendingTrailers), promise: nil) + self.flushPending = true + self.pendingTrailers = nil + } + case .awaitMoreMessages: + () + case .sendMessage: + preconditionFailure("This isn't possible: we'd still be inside the while loop.") + } + + if self.flushPending { + context.flush() + self.flushPending = false } } catch { context.fireErrorCaught(error) diff --git a/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift b/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift index e8b1bbb3d..5839a7aa3 100644 --- a/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift +++ b/Tests/GRPCHTTP2CoreTests/Server/GRPCServerStreamHandlerTests.swift @@ -678,6 +678,86 @@ final class GRPCServerStreamHandlerTests: XCTestCase { ) XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) } + + func testMessageAndStatusAreNotReordered() throws { + let handler = GRPCServerStreamHandler( + scheme: .http, + acceptedEncodings: [], + maximumPayloadSize: 100 + ) + + let channel = EmbeddedChannel(handler: handler) + + // Receive client's initial metadata + let clientInitialMetadata: HPACKHeaders = [ + GRPCHTTP2Keys.path.rawValue: "test/test", + GRPCHTTP2Keys.scheme.rawValue: "http", + GRPCHTTP2Keys.method.rawValue: "POST", + GRPCHTTP2Keys.contentType.rawValue: "application/grpc", + GRPCHTTP2Keys.te.rawValue: "trailers", + ] + XCTAssertNoThrow( + try channel.writeInbound( + HTTP2Frame.FramePayload.headers(.init(headers: clientInitialMetadata)) + ) + ) + + // Make sure we haven't sent back an error response, and that we read the initial metadata + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + XCTAssertEqual( + try channel.readInbound(as: RPCRequestPart.self), + RPCRequestPart.metadata(Metadata(headers: clientInitialMetadata)) + ) + + // Write back server's initial metadata + let serverInitialMetadata = RPCResponsePart.metadata(Metadata(headers: [:])) + XCTAssertNoThrow(try channel.writeOutbound(serverInitialMetadata)) + + // Read out the metadata + _ = try channel.readOutbound(as: HTTP2Frame.FramePayload.self) + + // This is where this test actually begins. We want to write a message followed + // by status and trailers, and only flush after both writes. + // Because messages are buffered and potentially bundled together in a single + // ByteBuffer by the GPRCMessageFramer, we want to make sure that the status + // and trailers won't be written before the messages. + + // Write back message and make sure nothing's written in the channel. + XCTAssertNoThrow(channel.write(RPCResponsePart.message([UInt8](repeating: 1, count: 4)))) + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + + // Write status + metadata and make sure nothing's written. + XCTAssertNoThrow(channel.write(RPCResponsePart.status(.init(code: .ok, message: ""), [:]))) + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + + // Now flush and check we *do* write the data in the right order: message first, + // trailers second. + channel.flush() + + let writtenMessage = try channel.assertReadDataOutbound() + + // Make sure we first get message. + XCTAssertEqual( + writtenMessage.data, + .byteBuffer( + .init(bytes: [ + // First message + 0, // Compression disabled + 0, 0, 0, 4, // Message length + 1, 1, 1, 1, // First message data + ]) + ) + ) + XCTAssertFalse(writtenMessage.endStream) + + // Make sure we get trailers. + let writtenTrailers = try channel.assertReadHeadersOutbound() + XCTAssertEqual(writtenTrailers.headers, ["grpc-status": "0"]) + XCTAssertTrue(writtenTrailers.endStream) + + // Make sure we get nothing else. + XCTAssertNil(try channel.readOutbound(as: HTTP2Frame.FramePayload.self)) + } } extension EmbeddedChannel { From d8d5633b8ef51a9fd6075236097ac7df5cd1db59 Mon Sep 17 00:00:00 2001 From: Gus Cairo Date: Thu, 21 Mar 2024 11:08:03 +0000 Subject: [PATCH 8/8] PR nits --- .../Server/GRPCServerStreamHandler.swift | 79 +++++++++---------- 1 file changed, 39 insertions(+), 40 deletions(-) diff --git a/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift b/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift index 588f5df51..c4da0d4ed 100644 --- a/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift +++ b/Sources/GRPCHTTP2Core/Server/GRPCServerStreamHandler.swift @@ -64,24 +64,21 @@ extension GRPCServerStreamHandler { case .byteBuffer(let buffer): do { try self.stateMachine.receive(message: buffer, endStream: endStream) - - var nextInboundMessage = self.stateMachine.nextInboundMessage() - while case .receiveMessage(let message) = nextInboundMessage { - context.fireChannelRead(self.wrapInboundOut(.message(message))) - nextInboundMessage = self.stateMachine.nextInboundMessage() - } - - switch nextInboundMessage { - case .awaitMoreMessages: - () - case .receiveMessage: - preconditionFailure("This isn't possible: we'd still be inside the while loop.") - case .noMoreMessages: - context.fireUserInboundEventTriggered(ChannelEvent.inputClosed) + loop: while true { + switch self.stateMachine.nextInboundMessage() { + case .receiveMessage(let message): + context.fireChannelRead(self.wrapInboundOut(.message(message))) + case .awaitMoreMessages: + break loop + case .noMoreMessages: + context.fireUserInboundEventTriggered(ChannelEvent.inputClosed) + break loop + } } } catch { context.fireErrorCaught(error) } + case .fileRegion: preconditionFailure("Unexpected IOData.fileRegion") } @@ -95,15 +92,18 @@ extension GRPCServerStreamHandler { switch action { case .receivedMetadata(let metadata): context.fireChannelRead(self.wrapInboundOut(.metadata(metadata))) + case .rejectRPC(let trailers): + self.flushPending = true let response = HTTP2Frame.FramePayload.headers(.init(headers: trailers, endStream: true)) context.write(self.wrapOutboundOut(response), promise: nil) - self.flushPending = true + case .receivedStatusAndMetadata: throw RPCError( code: .internalError, message: "Server cannot get receivedStatusAndMetadata." ) + case .doNothing: throw RPCError(code: .internalError, message: "Server cannot receive doNothing.") } @@ -140,9 +140,9 @@ extension GRPCServerStreamHandler { switch frame { case .metadata(let metadata): do { + self.flushPending = true let headers = try self.stateMachine.send(metadata: metadata) context.write(self.wrapOutboundOut(.headers(.init(headers: headers))), promise: nil) - self.flushPending = true // TODO: move the promise handling into the state machine promise?.succeed() } catch { @@ -178,38 +178,37 @@ extension GRPCServerStreamHandler { } func flush(context: ChannelHandlerContext) { + if self.isReading { + // We don't want to flush yet if we're still in a read loop. + return + } + do { - if self.isReading { - // We don't want to flush yet if we're still in a read loop. - return - } + loop: while true { + switch try self.stateMachine.nextOutboundMessage() { + case .sendMessage(let byteBuffer): + self.flushPending = true + context.write( + self.wrapOutboundOut(.data(.init(data: .byteBuffer(byteBuffer)))), + promise: nil + ) - var nextOutboundMessage = try self.stateMachine.nextOutboundMessage() - while case .sendMessage(let byteBuffer) = nextOutboundMessage { - context.write( - self.wrapOutboundOut(.data(.init(data: .byteBuffer(byteBuffer)))), - promise: nil - ) - self.flushPending = true - nextOutboundMessage = try self.stateMachine.nextOutboundMessage() - } + case .noMoreMessages: + if let pendingTrailers = self.pendingTrailers { + self.flushPending = true + self.pendingTrailers = nil + context.write(self.wrapOutboundOut(pendingTrailers), promise: nil) + } + break loop - switch nextOutboundMessage { - case .noMoreMessages: - if let pendingTrailers = self.pendingTrailers { - context.write(self.wrapOutboundOut(pendingTrailers), promise: nil) - self.flushPending = true - self.pendingTrailers = nil + case .awaitMoreMessages: + break loop } - case .awaitMoreMessages: - () - case .sendMessage: - preconditionFailure("This isn't possible: we'd still be inside the while loop.") } if self.flushPending { - context.flush() self.flushPending = false + context.flush() } } catch { context.fireErrorCaught(error)