diff --git a/Package.swift b/Package.swift index 870d0d16d..8d7d65648 100644 --- a/Package.swift +++ b/Package.swift @@ -306,7 +306,10 @@ extension Target { static let grpcHTTP2CoreTests: Target = .testTarget( name: "GRPCHTTP2CoreTests", dependencies: [ - .grpcHTTP2Core + .grpcHTTP2Core, + .nioCore, + .nioHTTP2, + .nioEmbedded, ] ) diff --git a/Sources/GRPCHTTP2Core/Server/Connection/ServerConnectionHandler.swift b/Sources/GRPCHTTP2Core/Server/Connection/ServerConnectionHandler.swift deleted file mode 100644 index 52dada671..000000000 --- a/Sources/GRPCHTTP2Core/Server/Connection/ServerConnectionHandler.swift +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Copyright 2024, gRPC Authors All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Temporary namespace. Will be replaced with a channel handler. -enum ServerConnectionHandler { -} diff --git a/Sources/GRPCHTTP2Core/Server/Connection/ServerConnectionHandler+StateMachine.swift b/Sources/GRPCHTTP2Core/Server/Connection/ServerConnectionManagementHandler+StateMachine.swift similarity index 97% rename from Sources/GRPCHTTP2Core/Server/Connection/ServerConnectionHandler+StateMachine.swift rename to Sources/GRPCHTTP2Core/Server/Connection/ServerConnectionManagementHandler+StateMachine.swift index 5a10e41e0..156063adb 100644 --- a/Sources/GRPCHTTP2Core/Server/Connection/ServerConnectionHandler+StateMachine.swift +++ b/Sources/GRPCHTTP2Core/Server/Connection/ServerConnectionManagementHandler+StateMachine.swift @@ -17,7 +17,7 @@ import NIOCore import NIOHTTP2 -extension ServerConnectionHandler { +extension ServerConnectionManagementHandler { /// Tracks the state of TCP connections at the server. /// /// The state machine manages the state for the graceful shutdown procedure as well as policing @@ -248,7 +248,7 @@ extension ServerConnectionHandler { } } -extension ServerConnectionHandler.StateMachine { +extension ServerConnectionManagementHandler.StateMachine { fileprivate struct KeepAlive { /// Allow the client to send keep alive pings when there are no active calls. private let allowWithoutCalls: Bool @@ -267,8 +267,7 @@ extension ServerConnectionHandler.StateMachine { /// alive (a low number of strikes is therefore expected and okay). private var pingStrikes: Int - /// The last time a valid ping happened. This may be in the distant past if there is no such - /// time (for example the connection is new and there are no active calls). + /// The last time a valid ping happened. /// /// Note: `distantPast` isn't used to indicate no previous valid ping as `NIODeadline` uses /// the monotonic clock on Linux which uses an undefined starting point and in some cases isn't @@ -320,7 +319,7 @@ extension ServerConnectionHandler.StateMachine { } } -extension ServerConnectionHandler.StateMachine { +extension ServerConnectionManagementHandler.StateMachine { fileprivate enum State { /// The connection is active. struct Active { diff --git a/Sources/GRPCHTTP2Core/Server/Connection/ServerConnectionManagementHandler.swift b/Sources/GRPCHTTP2Core/Server/Connection/ServerConnectionManagementHandler.swift new file mode 100644 index 000000000..ffdbb8946 --- /dev/null +++ b/Sources/GRPCHTTP2Core/Server/Connection/ServerConnectionManagementHandler.swift @@ -0,0 +1,473 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import NIOCore +import NIOHTTP2 + +/// A `ChannelHandler` which manages the lifecycle of a gRPC connection over HTTP/2. +/// +/// This handler is responsible for managing several aspects of the connection. These include: +/// 1. Handling the graceful close of connections. When gracefully closing a connection the server +/// sends a GOAWAY frame with the last stream ID set to the maximum stream ID allowed followed by +/// a PING frame. On receipt of the PING frame the server sends another GOAWAY frame with the +/// highest ID of all streams which have been opened. After this, the handler closes the +/// connection once all streams are closed. +/// 2. Enforcing that graceful shutdown doesn't exceed a configured limit (if configured). +/// 3. Gracefully closing the connection once it reaches the maximum configured age (if configured). +/// 4. Gracefully closing the connection once it has been idle for a given period of time (if +/// configured). +/// 5. Periodically sending keep alive pings to the client (if configured) and closing the +/// connection if necessary. +/// 6. Policing pings sent by the client to ensure that the client isn't misconfigured to send +/// too many pings. +/// +/// Some of the behaviours are described in: +/// - [gRFC A8](https://github.com/grpc/proposal/blob/master/A8-client-side-keepalive.md), and +/// - [gRFC A9](https://github.com/grpc/proposal/blob/master/A9-server-side-conn-mgt.md). +final class ServerConnectionManagementHandler: ChannelDuplexHandler { + typealias InboundIn = HTTP2Frame + typealias InboundOut = HTTP2Frame + typealias OutboundIn = HTTP2Frame + typealias OutboundOut = HTTP2Frame + + /// The `EventLoop` of the `Channel` this handler exists in. + private let eventLoop: EventLoop + + /// The maximum amount of time a connection may be idle for. If the connection remains idle + /// (i.e. has no open streams) for this period of time then the connection will be gracefully + /// closed. + private var maxIdleTimer: Timer? + + /// The maximum age of a connection. If the connection remains open after this amount of time + /// then it will be gracefully closed. + private var maxAgeTimer: Timer? + + /// The maximum amount of time a connection may spend closing gracefully, after which it is + /// closed abruptly. The timer starts after the second GOAWAY frame has been sent. + private var maxGraceTimer: Timer? + + /// The amount of time to wait before sending a keep alive ping. + private var keepAliveTimer: Timer? + + /// The amount of time the client has to reply after sending a keep alive ping. Only used if + /// `keepAliveTimer` is set. + private var keepAliveTimeoutTimer: Timer + + private struct Timer { + /// The delay to wait before running the task. + private let delay: TimeAmount + /// The task to run, if scheduled. + private var task: Scheduled? + + init(delay: TimeAmount) { + self.delay = delay + self.task = nil + } + + /// Schedule a task on the given `EventLoop`. + mutating func schedule(on eventLoop: EventLoop, task: @escaping () throws -> Void) { + self.task?.cancel() + self.task = eventLoop.scheduleTask(in: self.delay, task) + } + + /// Cancels the task, if one was scheduled. + mutating func cancel() { + self.task?.cancel() + self.task = nil + } + } + + /// Opaque data sent in keep alive pings. + private let keepAlivePingData: HTTP2PingData + + /// Whether a flush is pending. + private var flushPending: Bool + /// Whether `channelRead` has been called and `channelReadComplete` hasn't yet been called. + /// Resets once `channelReadComplete` returns. + private var inReadLoop: Bool + + /// The current state of the connection. + private var state: StateMachine + + /// The clock. + private let clock: Clock + + /// A clock providing the current time. + /// + /// This is necessary for testing where a manual clock can be used and advanced from the test. + /// While NIO's `EmbeddedEventLoop` provides control over its view of time (and therefore any + /// events scheduled on it) it doesn't offer a way to get the current time. This is usually done + /// via `NIODeadline`. + enum Clock { + case nio + case manual(Manual) + + func now() -> NIODeadline { + switch self { + case .nio: + return .now() + case .manual(let clock): + return clock.time + } + } + + final class Manual { + private(set) var time: NIODeadline + + init() { + self.time = .uptimeNanoseconds(0) + } + + func advance(by amount: TimeAmount) { + self.time = self.time + amount + } + } + } + + /// Stats about recently written frames. Used to determine whether to reset keep-alive state. + private var frameStats: FrameStats + + struct FrameStats { + private(set) var didWriteHeadersOrData = false + + /// Mark that a HEADERS frame has been written. + mutating func wroteHeaders() { + self.didWriteHeadersOrData = true + } + + /// Mark that DATA frame has been written. + mutating func wroteData() { + self.didWriteHeadersOrData = true + } + + /// Resets the state such that no HEADERS or DATA frames have been written. + mutating func reset() { + self.didWriteHeadersOrData = false + } + } + + /// A synchronous view over this handler. + var syncView: SyncView { + return SyncView(self) + } + + /// A synchronous view over this handler. + /// + /// Methods on this view *must* be called from the same `EventLoop` as the `Channel` in which + /// this handler exists. + struct SyncView { + private let handler: ServerConnectionManagementHandler + + fileprivate init(_ handler: ServerConnectionManagementHandler) { + self.handler = handler + } + + /// Notify the handler that the connection has received a flush event. + func connectionWillFlush() { + // The handler can't rely on `flush(context:)` due to its expected position in the pipeline. + // It's expected to be placed after the HTTP/2 handler (i.e. closer to the application) as + // it needs to receive HTTP/2 frames. However, flushes from stream channels aren't sent down + // the entire connection channel, instead they are sent from the point in the channel they + // are multiplexed from (either the HTTP/2 handler or the HTTP/2 multiplexing handler, + // depending on how multiplexing is configured). + self.handler.eventLoop.assertInEventLoop() + if self.handler.frameStats.didWriteHeadersOrData { + self.handler.frameStats.reset() + self.handler.state.resetKeepAliveState() + } + } + + /// Notify the handler that a HEADERS frame was written in the last write loop. + func wroteHeadersFrame() { + self.handler.eventLoop.assertInEventLoop() + self.handler.frameStats.wroteHeaders() + } + + /// Notify the handler that a DATA frame was written in the last write loop. + func wroteDataFrame() { + self.handler.eventLoop.assertInEventLoop() + self.handler.frameStats.wroteData() + } + } + + /// Creates a new handler which manages the lifecycle of a connection. + /// + /// - Parameters: + /// - eventLoop: The `EventLoop` of the `Channel` this handler is placed in. + /// - maxIdleTime: The maximum amount time a connection may be idle for before being closed. + /// - maxAge: The maximum amount of time a connection may exist before being gracefully closed. + /// - maxGraceTime: The maximum amount of time that the connection has to close gracefully. + /// - keepAliveTime: The amount of time to wait after reading data before sending a keep-alive + /// ping. + /// - keepAliveTimeout: The amount of time the client has to reply after the server sends a + /// keep-alive ping to keep the connection open. The connection is closed if no reply + /// is received. + /// - allowKeepAliveWithoutCalls: Whether the server allows the client to send keep-alive pings + /// when there are no calls in progress. + /// - minPingIntervalWithoutCalls: The minimum allowed interval the client is allowed to send + /// keep-alive pings. Pings more frequent than this interval count as 'strikes' and the + /// connection is closed if there are too many strikes. + /// - clock: A clock providing the current time. + init( + eventLoop: EventLoop, + maxIdleTime: TimeAmount?, + maxAge: TimeAmount?, + maxGraceTime: TimeAmount?, + keepAliveTime: TimeAmount?, + keepAliveTimeout: TimeAmount?, + allowKeepAliveWithoutCalls: Bool, + minPingIntervalWithoutCalls: TimeAmount, + clock: Clock = .nio + ) { + self.eventLoop = eventLoop + + self.maxIdleTimer = maxIdleTime.map { Timer(delay: $0) } + self.maxAgeTimer = maxAge.map { Timer(delay: $0) } + self.maxGraceTimer = maxGraceTime.map { Timer(delay: $0) } + + self.keepAliveTimer = keepAliveTime.map { Timer(delay: $0) } + // Always create a keep alive timeout timer, it's only used if there is a keep alive timer. + self.keepAliveTimeoutTimer = Timer(delay: keepAliveTimeout ?? .seconds(20)) + + // Generate a random value to be used as keep alive ping data. + let pingData = UInt64.random(in: .min ... .max) + self.keepAlivePingData = HTTP2PingData(withInteger: pingData) + + self.state = StateMachine( + allowKeepAliveWithoutCalls: allowKeepAliveWithoutCalls, + minPingReceiveIntervalWithoutCalls: minPingIntervalWithoutCalls, + goAwayPingData: HTTP2PingData(withInteger: ~pingData) + ) + + self.flushPending = false + self.inReadLoop = false + self.clock = clock + self.frameStats = FrameStats() + } + + func handlerAdded(context: ChannelHandlerContext) { + assert(context.eventLoop === self.eventLoop) + } + + func channelActive(context: ChannelHandlerContext) { + self.maxAgeTimer?.schedule(on: context.eventLoop) { + self.initiateGracefulShutdown(context: context) + } + + self.maxIdleTimer?.schedule(on: context.eventLoop) { + self.initiateGracefulShutdown(context: context) + } + + self.keepAliveTimer?.schedule(on: context.eventLoop) { + self.keepAliveTimerFired(context: context) + } + + context.fireChannelActive() + } + + func channelInactive(context: ChannelHandlerContext) { + self.maxIdleTimer?.cancel() + self.maxAgeTimer?.cancel() + self.maxGraceTimer?.cancel() + self.keepAliveTimer?.cancel() + self.keepAliveTimeoutTimer.cancel() + context.fireChannelInactive() + } + + func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { + switch event { + case let event as NIOHTTP2StreamCreatedEvent: + // The connection isn't idle if a stream is open. + self.maxIdleTimer?.cancel() + self.state.streamOpened(event.streamID) + + case let event as StreamClosedEvent: + switch self.state.streamClosed(event.streamID) { + case .startIdleTimer: + self.maxIdleTimer?.schedule(on: context.eventLoop) { + self.initiateGracefulShutdown(context: context) + } + + case .close: + context.close(mode: .all, promise: nil) + + case .none: + () + } + + default: + () + } + + context.fireUserInboundEventTriggered(event) + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + self.inReadLoop = true + + // Any read data indicates that the connection is alive so cancel the keep-alive timers. + self.keepAliveTimer?.cancel() + self.keepAliveTimeoutTimer.cancel() + + let frame = self.unwrapInboundIn(data) + switch frame.payload { + case .ping(let data, let ack): + if ack { + self.handlePingAck(context: context, data: data) + } else { + self.handlePing(context: context, data: data) + } + + default: + () // Only interested in PING frames, ignore the rest. + } + + context.fireChannelRead(data) + } + + func channelReadComplete(context: ChannelHandlerContext) { + while self.flushPending { + self.flushPending = false + context.flush() + } + + self.inReadLoop = false + + // Done reading: schedule the keep-alive timer. + self.keepAliveTimer?.schedule(on: context.eventLoop) { + self.keepAliveTimerFired(context: context) + } + + context.fireChannelReadComplete() + } + + func flush(context: ChannelHandlerContext) { + self.maybeFlush(context: context) + } +} + +extension ServerConnectionManagementHandler { + private func maybeFlush(context: ChannelHandlerContext) { + if self.inReadLoop { + self.flushPending = true + } else { + context.flush() + } + } + + private func initiateGracefulShutdown(context: ChannelHandlerContext) { + context.eventLoop.assertInEventLoop() + + // Cancel any timers if initiating shutdown. + self.maxIdleTimer?.cancel() + self.maxAgeTimer?.cancel() + self.keepAliveTimer?.cancel() + self.keepAliveTimeoutTimer.cancel() + + switch self.state.startGracefulShutdown() { + case .sendGoAwayAndPing(let pingData): + // There's a time window between the server sending a GOAWAY frame and the client receiving + // it. During this time the client may open new streams as it doesn't yet know about the + // GOAWAY frame. + // + // The server therefore sends a GOAWAY with the last stream ID set to the maximum stream ID + // and follows it with a PING frame. When the server receives the ack for the PING frame it + // knows that the client has received the initial GOAWAY frame and that no more streams may + // be opened. The server can then send an additional GOAWAY frame with a more representative + // last stream ID. + let goAway = HTTP2Frame( + streamID: .rootStream, + payload: .goAway( + lastStreamID: .maxID, + errorCode: .noError, + opaqueData: nil + ) + ) + + let ping = HTTP2Frame(streamID: .rootStream, payload: .ping(pingData, ack: false)) + + context.write(self.wrapOutboundOut(goAway), promise: nil) + context.write(self.wrapOutboundOut(ping), promise: nil) + self.maybeFlush(context: context) + + case .none: + () // Already shutting down. + } + } + + private func handlePing(context: ChannelHandlerContext, data: HTTP2PingData) { + switch self.state.receivedPing(atTime: self.clock.now(), data: data) { + case .enhanceYourCalmThenClose(let streamID): + let goAway = HTTP2Frame( + streamID: .rootStream, + payload: .goAway( + lastStreamID: streamID, + errorCode: .enhanceYourCalm, + opaqueData: context.channel.allocator.buffer(string: "too_many_pings") + ) + ) + + context.write(self.wrapOutboundOut(goAway), promise: nil) + self.maybeFlush(context: context) + context.close(promise: nil) + + case .sendAck: + let ping = HTTP2Frame(streamID: .rootStream, payload: .ping(data, ack: true)) + context.write(self.wrapOutboundOut(ping), promise: nil) + self.maybeFlush(context: context) + + case .none: + () + } + } + + private func handlePingAck(context: ChannelHandlerContext, data: HTTP2PingData) { + switch self.state.receivedPingAck(data: data) { + case .sendGoAway(let streamID, let close): + let goAway = HTTP2Frame( + streamID: .rootStream, + payload: .goAway(lastStreamID: streamID, errorCode: .noError, opaqueData: nil) + ) + + context.write(self.wrapOutboundOut(goAway), promise: nil) + self.maybeFlush(context: context) + + if close { + context.close(promise: nil) + } else { + // RPCs may have a grace period for finishing once the second GOAWAY frame has finished. + // If this is set close the connection abruptly once the grace period passes. + self.maxGraceTimer?.schedule(on: context.eventLoop) { + context.close(promise: nil) + } + } + + case .none: + () + } + } + + private func keepAliveTimerFired(context: ChannelHandlerContext) { + let ping = HTTP2Frame(streamID: .rootStream, payload: .ping(self.keepAlivePingData, ack: false)) + context.write(self.wrapInboundOut(ping), promise: nil) + self.maybeFlush(context: context) + + // Schedule a timeout on waiting for the response. + self.keepAliveTimeoutTimer.schedule(on: context.eventLoop) { + self.initiateGracefulShutdown(context: context) + } + } +} diff --git a/Tests/GRPCHTTP2CoreTests/Server/Connection/ServerConnectionHandler+StateMachineTests.swift b/Tests/GRPCHTTP2CoreTests/Server/Connection/ServerConnectionManagementHandler+StateMachineTests.swift similarity index 97% rename from Tests/GRPCHTTP2CoreTests/Server/Connection/ServerConnectionHandler+StateMachineTests.swift rename to Tests/GRPCHTTP2CoreTests/Server/Connection/ServerConnectionManagementHandler+StateMachineTests.swift index ec4671c8a..47daf4d58 100644 --- a/Tests/GRPCHTTP2CoreTests/Server/Connection/ServerConnectionHandler+StateMachineTests.swift +++ b/Tests/GRPCHTTP2CoreTests/Server/Connection/ServerConnectionManagementHandler+StateMachineTests.swift @@ -20,12 +20,12 @@ import XCTest @testable import GRPCHTTP2Core -final class ServerConnectionHandlerStateMachineTests: XCTestCase { +final class ServerConnectionManagementHandlerStateMachineTests: XCTestCase { private func makeStateMachine( allowKeepAliveWithoutCalls: Bool = false, minPingReceiveIntervalWithoutCalls: TimeAmount = .minutes(5), goAwayPingData: HTTP2PingData = HTTP2PingData(withInteger: 42) - ) -> ServerConnectionHandler.StateMachine { + ) -> ServerConnectionManagementHandler.StateMachine { return .init( allowKeepAliveWithoutCalls: allowKeepAliveWithoutCalls, minPingReceiveIntervalWithoutCalls: minPingReceiveIntervalWithoutCalls, @@ -169,7 +169,7 @@ final class ServerConnectionHandlerStateMachineTests: XCTestCase { } func testPingStrikeUsingMinReceiveInterval( - state: inout ServerConnectionHandler.StateMachine, + state: inout ServerConnectionManagementHandler.StateMachine, interval: TimeAmount, expectedID id: HTTP2StreamID ) { diff --git a/Tests/GRPCHTTP2CoreTests/Server/Connection/ServerConnectionManagementHandlerTests.swift b/Tests/GRPCHTTP2CoreTests/Server/Connection/ServerConnectionManagementHandlerTests.swift new file mode 100644 index 000000000..abbd6bb52 --- /dev/null +++ b/Tests/GRPCHTTP2CoreTests/Server/Connection/ServerConnectionManagementHandlerTests.swift @@ -0,0 +1,428 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import NIOCore +import NIOEmbedded +import NIOHTTP2 +import XCTest + +@testable import GRPCHTTP2Core + +final class ServerConnectionManagementHandlerTests: XCTestCase { + func testIdleTimeoutOnNewConnection() throws { + let connection = try Connection(maxIdleTime: .minutes(1)) + try connection.activate() + // Hit the max idle time. + connection.advanceTime(by: .minutes(1)) + + // Follow the graceful shutdown flow. + try self.testGracefulShutdown(connection: connection, lastStreamID: 0) + + // Closed because no streams were open. + try connection.waitUntilClosed() + } + + func testIdleTimerIsCancelledWhenStreamIsOpened() throws { + let connection = try Connection(maxIdleTime: .minutes(1)) + try connection.activate() + + // Open a stream to cancel the idle timer and run through the max idle time. + connection.streamOpened(1) + connection.advanceTime(by: .minutes(1)) + + // No GOAWAY frame means the timer was cancelled. + XCTAssertNil(try connection.readFrame()) + } + + func testIdleTimerStartsWhenAllStreamsAreClosed() throws { + let connection = try Connection(maxIdleTime: .minutes(1)) + try connection.activate() + + // Open a stream to cancel the idle timer and run through the max idle time. + connection.streamOpened(1) + connection.advanceTime(by: .minutes(1)) + XCTAssertNil(try connection.readFrame()) + + // Close the stream to start the timer again. + connection.streamClosed(1) + connection.advanceTime(by: .minutes(1)) + + // Follow the graceful shutdown flow. + try self.testGracefulShutdown(connection: connection, lastStreamID: 1) + + // Closed because no streams were open. + try connection.waitUntilClosed() + } + + func testMaxAge() throws { + let connection = try Connection(maxAge: .minutes(1)) + try connection.activate() + + // Open some streams. + connection.streamOpened(1) + connection.streamOpened(3) + + // Run to the max age and follow the graceful shutdown flow. + connection.advanceTime(by: .minutes(1)) + try self.testGracefulShutdown(connection: connection, lastStreamID: 3) + + // Close the streams. + connection.streamClosed(1) + connection.streamClosed(3) + + // Connection will be closed now. + try connection.waitUntilClosed() + } + + func testGracefulShutdownRatchetsDownStreamID() throws { + // This test uses the idle timeout to trigger graceful shutdown. The mechanism is the same + // regardless of how it's triggered. + let connection = try Connection(maxIdleTime: .minutes(1)) + try connection.activate() + + // Trigger the shutdown, but open a stream during shutdown. + connection.advanceTime(by: .minutes(1)) + try self.testGracefulShutdown( + connection: connection, + lastStreamID: 1, + streamToOpenBeforePingAck: 1 + ) + + // Close the stream to trigger closing the connection. + connection.streamClosed(1) + try connection.waitUntilClosed() + } + + func testGracefulShutdownGracePeriod() throws { + // This test uses the idle timeout to trigger graceful shutdown. The mechanism is the same + // regardless of how it's triggered. + let connection = try Connection( + maxIdleTime: .minutes(1), + maxGraceTime: .seconds(5) + ) + try connection.activate() + + // Trigger the shutdown, but open a stream during shutdown. + connection.advanceTime(by: .minutes(1)) + try self.testGracefulShutdown( + connection: connection, + lastStreamID: 1, + streamToOpenBeforePingAck: 1 + ) + + // Wait out the grace period without closing the stream. + connection.advanceTime(by: .seconds(5)) + try connection.waitUntilClosed() + } + + func testKeepAliveOnNewConnection() throws { + let connection = try Connection( + keepAliveTime: .minutes(5), + keepAliveTimeout: .seconds(5) + ) + try connection.activate() + + // Wait for the keep alive timer to fire which should cause the server to send a keep + // alive PING. + connection.advanceTime(by: .minutes(5)) + let frame1 = try XCTUnwrap(connection.readFrame()) + XCTAssertEqual(frame1.streamID, .rootStream) + try XCTAssertPing(frame1.payload) { data, ack in + XCTAssertFalse(ack) + // Data is opaque, send it back. + try connection.ping(data: data, ack: true) + } + + // Run past the timeout, nothing should happen. + connection.advanceTime(by: .seconds(5)) + XCTAssertNil(try connection.readFrame()) + } + + func testKeepAliveStartsAfterReadLoop() throws { + let connection = try Connection( + keepAliveTime: .minutes(5), + keepAliveTimeout: .seconds(5) + ) + try connection.activate() + + // Write a frame into the channel _without_ calling channel read complete. This will cancel + // the keep alive timer. + let settings = HTTP2Frame(streamID: .rootStream, payload: .settings(.settings([]))) + connection.channel.pipeline.fireChannelRead(NIOAny(settings)) + + // Run out the keep alive timer, it shouldn't fire. + connection.advanceTime(by: .minutes(5)) + XCTAssertNil(try connection.readFrame()) + + // Fire channel read complete to start the keep alive timer again. + connection.channel.pipeline.fireChannelReadComplete() + + // Now expire the keep alive timer again, we should read out a PING frame. + connection.advanceTime(by: .minutes(5)) + let frame1 = try XCTUnwrap(connection.readFrame()) + XCTAssertEqual(frame1.streamID, .rootStream) + XCTAssertPing(frame1.payload) { data, ack in + XCTAssertFalse(ack) + } + } + + func testKeepAliveOnNewConnectionWithoutResponse() throws { + let connection = try Connection( + keepAliveTime: .minutes(5), + keepAliveTimeout: .seconds(5) + ) + try connection.activate() + + // Wait for the keep alive timer to fire which should cause the server to send a keep + // alive PING. + connection.advanceTime(by: .minutes(5)) + let frame1 = try XCTUnwrap(connection.readFrame()) + XCTAssertEqual(frame1.streamID, .rootStream) + XCTAssertPing(frame1.payload) { data, ack in + XCTAssertFalse(ack) + } + + // We didn't ack the PING, the connection should shutdown after the timeout. + connection.advanceTime(by: .seconds(5)) + try self.testGracefulShutdown(connection: connection, lastStreamID: 0) + + // Connection is closed now. + try connection.waitUntilClosed() + } + + func testClientKeepAlivePolicing() throws { + let connection = try Connection( + allowKeepAliveWithoutCalls: true, + minPingIntervalWithoutCalls: .minutes(1) + ) + try connection.activate() + + // The first ping is valid, the second and third are strikes. + for _ in 1 ... 3 { + try connection.ping(data: HTTP2PingData(), ack: false) + let frame = try XCTUnwrap(connection.readFrame()) + XCTAssertEqual(frame.streamID, .rootStream) + XCTAssertPing(frame.payload) { data, ack in + XCTAssertEqual(data, HTTP2PingData()) + XCTAssertTrue(ack) + } + } + + // The fourth ping is the third strike and triggers a GOAWAY. + try connection.ping(data: HTTP2PingData(), ack: false) + let frame = try XCTUnwrap(connection.readFrame()) + XCTAssertEqual(frame.streamID, .rootStream) + XCTAssertGoAway(frame.payload) { streamID, error, data in + XCTAssertEqual(streamID, .rootStream) + XCTAssertEqual(error, .enhanceYourCalm) + XCTAssertEqual(data, ByteBuffer(string: "too_many_pings")) + } + + // The server should close the connection. + try connection.waitUntilClosed() + } + + func testClientKeepAliveWithPermissibleIntervals() throws { + let connection = try Connection( + allowKeepAliveWithoutCalls: true, + minPingIntervalWithoutCalls: .minutes(1), + manualClock: true + ) + try connection.activate() + + for _ in 1 ... 100 { + try connection.ping(data: HTTP2PingData(), ack: false) + let frame = try XCTUnwrap(connection.readFrame()) + XCTAssertEqual(frame.streamID, .rootStream) + XCTAssertPing(frame.payload) { data, ack in + XCTAssertEqual(data, HTTP2PingData()) + XCTAssertTrue(ack) + } + + // Advance by the ping interval. + connection.advanceTime(by: .minutes(1)) + } + } + + func testClientKeepAliveResetState() throws { + let connection = try Connection( + allowKeepAliveWithoutCalls: true, + minPingIntervalWithoutCalls: .minutes(1) + ) + try connection.activate() + + func sendThreeKeepAlivePings() throws { + // The first ping is valid, the second and third are strikes. + for _ in 1 ... 3 { + try connection.ping(data: HTTP2PingData(), ack: false) + let frame = try XCTUnwrap(connection.readFrame()) + XCTAssertEqual(frame.streamID, .rootStream) + XCTAssertPing(frame.payload) { data, ack in + XCTAssertEqual(data, HTTP2PingData()) + XCTAssertTrue(ack) + } + } + } + + try sendThreeKeepAlivePings() + + // "send" a HEADERS frame and flush to reset keep alive state. + connection.syncView.wroteHeadersFrame() + connection.syncView.connectionWillFlush() + + // As above, the first ping is valid, the next two are strikes. + try sendThreeKeepAlivePings() + + // The next ping is the third strike and triggers a GOAWAY. + try connection.ping(data: HTTP2PingData(), ack: false) + let frame = try XCTUnwrap(connection.readFrame()) + XCTAssertEqual(frame.streamID, .rootStream) + XCTAssertGoAway(frame.payload) { streamID, error, data in + XCTAssertEqual(streamID, .rootStream) + XCTAssertEqual(error, .enhanceYourCalm) + XCTAssertEqual(data, ByteBuffer(string: "too_many_pings")) + } + + // The server should close the connection. + try connection.waitUntilClosed() + } +} + +extension ServerConnectionManagementHandlerTests { + private func testGracefulShutdown( + connection: Connection, + lastStreamID: HTTP2StreamID, + streamToOpenBeforePingAck: HTTP2StreamID? = nil + ) throws { + let frame1 = try XCTUnwrap(connection.readFrame()) + XCTAssertEqual(frame1.streamID, .rootStream) + XCTAssertGoAway(frame1.payload) { streamID, errorCode, _ in + XCTAssertEqual(streamID, .maxID) + XCTAssertEqual(errorCode, .noError) + } + + // Followed by a PING + let frame2 = try XCTUnwrap(connection.readFrame()) + XCTAssertEqual(frame2.streamID, .rootStream) + try XCTAssertPing(frame2.payload) { data, ack in + XCTAssertFalse(ack) + + if let id = streamToOpenBeforePingAck { + connection.streamOpened(id) + } + + // Send the PING ACK. + try connection.ping(data: data, ack: true) + } + + // PING ACK triggers another GOAWAY. + let frame3 = try XCTUnwrap(connection.readFrame()) + XCTAssertEqual(frame3.streamID, .rootStream) + XCTAssertGoAway(frame3.payload) { streamID, errorCode, _ in + XCTAssertEqual(streamID, lastStreamID) + XCTAssertEqual(errorCode, .noError) + } + } +} + +extension ServerConnectionManagementHandlerTests { + struct Connection { + let channel: EmbeddedChannel + let syncView: ServerConnectionManagementHandler.SyncView + + var loop: EmbeddedEventLoop { + self.channel.embeddedEventLoop + } + + private let clock: ServerConnectionManagementHandler.Clock + + init( + maxIdleTime: TimeAmount? = nil, + maxAge: TimeAmount? = nil, + maxGraceTime: TimeAmount? = nil, + keepAliveTime: TimeAmount? = nil, + keepAliveTimeout: TimeAmount? = nil, + allowKeepAliveWithoutCalls: Bool = false, + minPingIntervalWithoutCalls: TimeAmount = .minutes(5), + manualClock: Bool = false + ) throws { + if manualClock { + self.clock = .manual(ServerConnectionManagementHandler.Clock.Manual()) + } else { + self.clock = .nio + } + + let loop = EmbeddedEventLoop() + let handler = ServerConnectionManagementHandler( + eventLoop: loop, + maxIdleTime: maxIdleTime, + maxAge: maxAge, + maxGraceTime: maxGraceTime, + keepAliveTime: keepAliveTime, + keepAliveTimeout: keepAliveTimeout, + allowKeepAliveWithoutCalls: allowKeepAliveWithoutCalls, + minPingIntervalWithoutCalls: minPingIntervalWithoutCalls, + clock: self.clock + ) + + self.syncView = handler.syncView + self.channel = EmbeddedChannel(handler: handler, loop: loop) + } + + func activate() throws { + try self.channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 0)).wait() + } + + func advanceTime(by delta: TimeAmount) { + switch self.clock { + case .nio: + () + case .manual(let clock): + clock.advance(by: delta) + } + + self.loop.advanceTime(by: delta) + } + + func streamOpened(_ id: HTTP2StreamID) { + let event = NIOHTTP2StreamCreatedEvent( + streamID: id, + localInitialWindowSize: nil, + remoteInitialWindowSize: nil + ) + self.channel.pipeline.fireUserInboundEventTriggered(event) + } + + func streamClosed(_ id: HTTP2StreamID) { + let event = StreamClosedEvent(streamID: id, reason: nil) + self.channel.pipeline.fireUserInboundEventTriggered(event) + } + + func ping(data: HTTP2PingData, ack: Bool) throws { + let frame = HTTP2Frame(streamID: .rootStream, payload: .ping(data, ack: ack)) + try self.channel.writeInbound(frame) + } + + func readFrame() throws -> HTTP2Frame? { + return try self.channel.readOutbound(as: HTTP2Frame.self) + } + + func waitUntilClosed() throws { + self.channel.embeddedEventLoop.run() + try self.channel.closeFuture.wait() + } + } +} diff --git a/Tests/GRPCHTTP2CoreTests/XCTest+FramePayload.swift b/Tests/GRPCHTTP2CoreTests/XCTest+FramePayload.swift new file mode 100644 index 000000000..b6892d0db --- /dev/null +++ b/Tests/GRPCHTTP2CoreTests/XCTest+FramePayload.swift @@ -0,0 +1,43 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import NIOCore +import NIOHTTP2 +import XCTest + +func XCTAssertGoAway( + _ payload: HTTP2Frame.FramePayload, + verify: (HTTP2StreamID, HTTP2ErrorCode, ByteBuffer?) throws -> Void = { _, _, _ in } +) rethrows { + switch payload { + case .goAway(let lastStreamID, let errorCode, let opaqueData): + try verify(lastStreamID, errorCode, opaqueData) + default: + XCTFail("Expected '.goAway' got '\(payload)'") + } +} + +func XCTAssertPing( + _ payload: HTTP2Frame.FramePayload, + verify: (HTTP2PingData, Bool) throws -> Void = { _, _ in } +) rethrows { + switch payload { + case .ping(let data, ack: let ack): + try verify(data, ack) + default: + XCTFail("Expected '.ping' got '\(payload)'") + } +}