diff --git a/Sources/GRPCHTTP2Core/Client/Connection/ClientConnectionHandler.swift b/Sources/GRPCHTTP2Core/Client/Connection/ClientConnectionHandler.swift new file mode 100644 index 000000000..0b7f27791 --- /dev/null +++ b/Sources/GRPCHTTP2Core/Client/Connection/ClientConnectionHandler.swift @@ -0,0 +1,421 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import NIOCore +import NIOHTTP2 + +/// An event which happens on a client's HTTP/2 connection. +enum ClientConnectionEvent: Sendable, Hashable { + enum CloseReason: Sendable, Hashable { + /// The server sent a GOAWAY frame to the client. + case goAway(HTTP2ErrorCode, String) + /// The keep alive timer fired and subsequently timed out. + case keepAliveExpired + /// The connection became idle. + case idle + } + + /// The connection has started shutting down, no new streams should be created. + case closing(CloseReason) +} + +/// A `ChannelHandler` which manages part of the lifecycle of a gRPC connection over HTTP/2. +/// +/// This handler is responsible for managing several aspects of the connection. These include: +/// 1. Periodically sending keep alive pings to the server (if configured) and closing the +/// connection if necessary. +/// 2. Closing the connection if it is idle (has no open streams) for a configured amount of time. +/// 3. Forwarding lifecycle events to the next handler. +/// +/// Some of the behaviours are described in [gRFC A8](https://github.com/grpc/proposal/blob/master/A8-client-side-keepalive.md). +final class ClientConnectionHandler: ChannelInboundHandler, ChannelOutboundHandler { + typealias InboundIn = HTTP2Frame + typealias InboundOut = ClientConnectionEvent + + typealias OutboundIn = Never + typealias OutboundOut = HTTP2Frame + + /// The `EventLoop` of the `Channel` this handler exists in. + private let eventLoop: EventLoop + + /// The maximum amount of time the connection may be idle for. If the connection remains idle + /// (i.e. has no open streams) for this period of time then the connection will be gracefully + /// closed. + private var maxIdleTimer: Timer? + + /// The amount of time to wait before sending a keep alive ping. + private var keepAliveTimer: Timer? + + /// The amount of time the client has to reply after sending a keep alive ping. Only used if + /// `keepAliveTimer` is set. + private var keepAliveTimeoutTimer: Timer + + /// Opaque data sent in keep alive pings. + private let keepAlivePingData: HTTP2PingData + + /// The current state of the connection. + private var state: StateMachine + + /// Whether a flush is pending. + private var flushPending: Bool + /// Whether `channelRead` has been called and `channelReadComplete` hasn't yet been called. + /// Resets once `channelReadComplete` returns. + private var inReadLoop: Bool + + /// Creates a new handler which manages the lifecycle of a connection. + /// + /// - Parameters: + /// - eventLoop: The `EventLoop` of the `Channel` this handler is placed in. + /// - maxIdleTime: The maximum amount time a connection may be idle for before being closed. + /// - keepAliveTime: The amount of time to wait after reading data before sending a keep-alive + /// ping. + /// - keepAliveTimeout: The amount of time the client has to reply after the server sends a + /// keep-alive ping to keep the connection open. The connection is closed if no reply + /// is received. + /// - keepAliveWithoutCalls: Whether the client sends keep-alive pings when there are no calls + /// in progress. + init( + eventLoop: EventLoop, + maxIdleTime: TimeAmount?, + keepAliveTime: TimeAmount?, + keepAliveTimeout: TimeAmount?, + keepAliveWithoutCalls: Bool + ) { + self.eventLoop = eventLoop + self.maxIdleTimer = maxIdleTime.map { Timer(delay: $0) } + self.keepAliveTimer = keepAliveTime.map { Timer(delay: $0, repeat: true) } + self.keepAliveTimeoutTimer = Timer(delay: keepAliveTimeout ?? .seconds(20)) + self.keepAlivePingData = HTTP2PingData(withInteger: .random(in: .min ... .max)) + self.state = StateMachine(allowKeepAliveWithoutCalls: keepAliveWithoutCalls) + + self.flushPending = false + self.inReadLoop = false + } + + func handlerAdded(context: ChannelHandlerContext) { + assert(context.eventLoop === self.eventLoop) + } + + func channelActive(context: ChannelHandlerContext) { + self.keepAliveTimer?.schedule(on: context.eventLoop) { + self.keepAliveTimerFired(context: context) + } + + self.maxIdleTimer?.schedule(on: context.eventLoop) { + self.maxIdleTimerFired(context: context) + } + } + + func channelInactive(context: ChannelHandlerContext) { + self.state.closed() + self.keepAliveTimer?.cancel() + self.keepAliveTimeoutTimer.cancel() + } + + func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { + switch event { + case let event as NIOHTTP2StreamCreatedEvent: + // Stream created, so the connection isn't idle. + self.maxIdleTimer?.cancel() + self.state.streamOpened(event.streamID) + + case let event as StreamClosedEvent: + switch self.state.streamClosed(event.streamID) { + case .startIdleTimer(let cancelKeepAlive): + // All streams are closed, restart the idle timer, and stop the keep-alive timer (it may + // not stop if keep-alive is allowed when there are no active calls). + self.maxIdleTimer?.schedule(on: context.eventLoop) { + self.maxIdleTimerFired(context: context) + } + + if cancelKeepAlive { + self.keepAliveTimer?.cancel() + } + + case .close: + // Connection was closing but waiting for all streams to close. They must all be closed + // now so close the connection. + context.close(promise: nil) + + case .none: + () + } + + default: + () + } + + context.fireUserInboundEventTriggered(event) + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let frame = self.unwrapInboundIn(data) + self.inReadLoop = true + + switch frame.payload { + case .goAway(_, let errorCode, let data): + // Receiving a GOAWAY frame means we need to stop creating streams immediately and start + // closing the connection. + switch self.state.beginGracefulShutdown() { + case .sendGoAway(let close): + // gRPC servers may indicate why the GOAWAY was sent in the opaque data. + let message = data.map { String(buffer: $0) } ?? "" + context.fireChannelRead(self.wrapInboundOut(.closing(.goAway(errorCode, message)))) + + // Clients should send GOAWAYs when closing a connection. + self.writeAndFlushGoAway(context: context, errorCode: .noError) + if close { + context.close(promise: nil) + } + + case .none: + () + } + + case .ping(let data, let ack): + // Pings are ack'd by the HTTP/2 handler so we only pay attention to acks here, and in + // particular only those carrying the keep-alive data. + if ack, data == self.keepAlivePingData { + self.keepAliveTimeoutTimer.cancel() + self.keepAliveTimer?.schedule(on: context.eventLoop) { + self.keepAliveTimerFired(context: context) + } + } + + default: + () + } + } + + func channelReadComplete(context: ChannelHandlerContext) { + while self.flushPending { + self.flushPending = false + context.flush() + } + + self.inReadLoop = false + context.fireChannelReadComplete() + } +} + +extension ClientConnectionHandler { + private func maybeFlush(context: ChannelHandlerContext) { + if self.inReadLoop { + self.flushPending = true + } else { + context.flush() + } + } + + private func keepAliveTimerFired(context: ChannelHandlerContext) { + guard self.state.sendKeepAlivePing() else { return } + + // Cancel the keep alive timer when the client sends a ping. The timer is resumed when the ping + // is acknowledged. + self.keepAliveTimer?.cancel() + + let ping = HTTP2Frame(streamID: .rootStream, payload: .ping(self.keepAlivePingData, ack: false)) + context.write(self.wrapOutboundOut(ping), promise: nil) + self.maybeFlush(context: context) + + // Schedule a timeout on waiting for the response. + self.keepAliveTimeoutTimer.schedule(on: context.eventLoop) { + self.keepAliveTimeoutExpired(context: context) + } + } + + private func keepAliveTimeoutExpired(context: ChannelHandlerContext) { + guard self.state.beginClosing() else { return } + + context.fireChannelRead(self.wrapInboundOut(.closing(.keepAliveExpired))) + self.writeAndFlushGoAway(context: context, message: "keepalive_expired") + context.close(promise: nil) + } + + private func maxIdleTimerFired(context: ChannelHandlerContext) { + guard self.state.beginClosing() else { return } + + context.fireChannelRead(self.wrapInboundOut(.closing(.idle))) + self.writeAndFlushGoAway(context: context, message: "idle") + context.close(promise: nil) + } + + private func writeAndFlushGoAway( + context: ChannelHandlerContext, + errorCode: HTTP2ErrorCode = .noError, + message: String? = nil + ) { + let goAway = HTTP2Frame( + streamID: .rootStream, + payload: .goAway( + lastStreamID: 0, + errorCode: errorCode, + opaqueData: message.map { context.channel.allocator.buffer(string: $0) } + ) + ) + + context.write(self.wrapOutboundOut(goAway), promise: nil) + self.maybeFlush(context: context) + } +} + +extension ClientConnectionHandler { + struct StateMachine { + private var state: State + + private enum State { + case active(Active) + case closing(Closing) + case closed + + struct Active { + var openStreams: Set + var allowKeepAliveWithoutCalls: Bool + + init(allowKeepAliveWithoutCalls: Bool) { + self.openStreams = [] + self.allowKeepAliveWithoutCalls = allowKeepAliveWithoutCalls + } + } + + struct Closing { + var allowKeepAliveWithoutCalls: Bool + var openStreams: Set + + init(from state: Active) { + self.openStreams = state.openStreams + self.allowKeepAliveWithoutCalls = state.allowKeepAliveWithoutCalls + } + } + } + + init(allowKeepAliveWithoutCalls: Bool) { + self.state = .active(State.Active(allowKeepAliveWithoutCalls: allowKeepAliveWithoutCalls)) + } + + /// Record that the stream with the given ID has been opened. + mutating func streamOpened(_ id: HTTP2StreamID) { + switch self.state { + case .active(var state): + let (inserted, _) = state.openStreams.insert(id) + assert(inserted, "Can't open stream \(Int(id)), it's already open") + self.state = .active(state) + + case .closing(var state): + let (inserted, _) = state.openStreams.insert(id) + assert(inserted, "Can't open stream \(Int(id)), it's already open") + self.state = .closing(state) + + case .closed: + () + } + } + + enum OnStreamClosed: Equatable { + /// Start the idle timer, after which the connection should be closed gracefully. + case startIdleTimer(cancelKeepAlive: Bool) + /// Close the connection. + case close + /// Do nothing. + case none + } + + /// Record that the stream with the given ID has been closed. + mutating func streamClosed(_ id: HTTP2StreamID) -> OnStreamClosed { + let onStreamClosed: OnStreamClosed + + switch self.state { + case .active(var state): + let removedID = state.openStreams.remove(id) + assert(removedID != nil, "Can't close stream \(Int(id)), it wasn't open") + if state.openStreams.isEmpty { + onStreamClosed = .startIdleTimer(cancelKeepAlive: !state.allowKeepAliveWithoutCalls) + } else { + onStreamClosed = .none + } + self.state = .active(state) + + case .closing(var state): + let removedID = state.openStreams.remove(id) + assert(removedID != nil, "Can't close stream \(Int(id)), it wasn't open") + onStreamClosed = state.openStreams.isEmpty ? .close : .none + self.state = .closing(state) + + case .closed: + onStreamClosed = .none + } + + return onStreamClosed + } + + /// Returns whether a keep alive ping should be sent to the server. + mutating func sendKeepAlivePing() -> Bool { + let sendKeepAlivePing: Bool + + // Only send a ping if there are open streams or there are no open streams and keep alive + // is permitted when there are no active calls. + switch self.state { + case .active(let state): + sendKeepAlivePing = !state.openStreams.isEmpty || state.allowKeepAliveWithoutCalls + case .closing(let state): + sendKeepAlivePing = !state.openStreams.isEmpty || state.allowKeepAliveWithoutCalls + case .closed: + sendKeepAlivePing = false + } + + return sendKeepAlivePing + } + + enum OnGracefulShutDown: Equatable { + case sendGoAway(Bool) + case none + } + + mutating func beginGracefulShutdown() -> OnGracefulShutDown { + let onGracefulShutdown: OnGracefulShutDown + + switch self.state { + case .active(let state): + // Only close immediately if there are no open streams. The client doesn't need to + // ratchet down the last stream ID as only the client creates streams in gRPC. + let close = state.openStreams.isEmpty + onGracefulShutdown = .sendGoAway(close) + self.state = .closing(State.Closing(from: state)) + + case .closing, .closed: + onGracefulShutdown = .none + } + + return onGracefulShutdown + } + + /// Returns whether the connection should be closed. + mutating func beginClosing() -> Bool { + switch self.state { + case .active(let active): + self.state = .closing(State.Closing(from: active)) + return true + case .closing, .closed: + return false + } + } + + /// Marks the state as closed. + mutating func closed() { + self.state = .closed + } + } +} diff --git a/Sources/GRPCHTTP2Core/Internal/Timer.swift b/Sources/GRPCHTTP2Core/Internal/Timer.swift new file mode 100644 index 000000000..0d97d148e --- /dev/null +++ b/Sources/GRPCHTTP2Core/Internal/Timer.swift @@ -0,0 +1,67 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import NIOCore + +struct Timer { + /// The delay to wait before running the task. + private let delay: TimeAmount + /// The task to run, if scheduled. + private var task: Kind? + /// Whether the task to schedule is repeated. + private let `repeat`: Bool + + private enum Kind { + case once(Scheduled) + case repeated(RepeatedTask) + + func cancel() { + switch self { + case .once(let task): + task.cancel() + case .repeated(let task): + task.cancel() + } + } + } + + init(delay: TimeAmount, repeat: Bool = false) { + self.delay = delay + self.task = nil + self.repeat = `repeat` + } + + /// Schedule a task on the given `EventLoop`. + mutating func schedule(on eventLoop: EventLoop, work: @escaping () throws -> Void) { + self.task?.cancel() + + if self.repeat { + let task = eventLoop.scheduleRepeatedTask(initialDelay: self.delay, delay: self.delay) { _ in + try work() + } + self.task = .repeated(task) + } else { + let task = eventLoop.scheduleTask(in: self.delay, work) + self.task = .once(task) + } + } + + /// Cancels the task, if one was scheduled. + mutating func cancel() { + self.task?.cancel() + self.task = nil + } +} diff --git a/Sources/GRPCHTTP2Core/Server/Connection/ServerConnectionManagementHandler.swift b/Sources/GRPCHTTP2Core/Server/Connection/ServerConnectionManagementHandler.swift index ffdbb8946..459da88af 100644 --- a/Sources/GRPCHTTP2Core/Server/Connection/ServerConnectionManagementHandler.swift +++ b/Sources/GRPCHTTP2Core/Server/Connection/ServerConnectionManagementHandler.swift @@ -66,30 +66,6 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler { /// `keepAliveTimer` is set. private var keepAliveTimeoutTimer: Timer - private struct Timer { - /// The delay to wait before running the task. - private let delay: TimeAmount - /// The task to run, if scheduled. - private var task: Scheduled? - - init(delay: TimeAmount) { - self.delay = delay - self.task = nil - } - - /// Schedule a task on the given `EventLoop`. - mutating func schedule(on eventLoop: EventLoop, task: @escaping () throws -> Void) { - self.task?.cancel() - self.task = eventLoop.scheduleTask(in: self.delay, task) - } - - /// Cancels the task, if one was scheduled. - mutating func cancel() { - self.task?.cancel() - self.task = nil - } - } - /// Opaque data sent in keep alive pings. private let keepAlivePingData: HTTP2PingData diff --git a/Tests/GRPCHTTP2CoreTests/Client/Connection/ClientConnectionHandlerStateMachineTests.swift b/Tests/GRPCHTTP2CoreTests/Client/Connection/ClientConnectionHandlerStateMachineTests.swift new file mode 100644 index 000000000..206c317e4 --- /dev/null +++ b/Tests/GRPCHTTP2CoreTests/Client/Connection/ClientConnectionHandlerStateMachineTests.swift @@ -0,0 +1,107 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import NIOCore +import NIOEmbedded +import XCTest + +@testable import GRPCHTTP2Core + +final class ClientConnectionHandlerStateMachineTests: XCTestCase { + private func makeStateMachine( + keepAliveWithoutCalls: Bool = false + ) -> ClientConnectionHandler.StateMachine { + return ClientConnectionHandler.StateMachine(allowKeepAliveWithoutCalls: keepAliveWithoutCalls) + } + + func testCloseSomeStreamsWhenActive() { + var state = self.makeStateMachine() + state.streamOpened(1) + state.streamOpened(2) + XCTAssertEqual(state.streamClosed(2), .none) + XCTAssertEqual(state.streamClosed(1), .startIdleTimer(cancelKeepAlive: true)) + } + + func testCloseSomeStreamsWhenClosing() { + var state = self.makeStateMachine() + state.streamOpened(1) + state.streamOpened(2) + XCTAssertTrue(state.beginClosing()) + XCTAssertEqual(state.streamClosed(2), .none) + XCTAssertEqual(state.streamClosed(1), .close) + } + + func testOpenAndCloseStreamWhenClosed() { + var state = self.makeStateMachine() + state.closed() + state.streamOpened(1) + XCTAssertEqual(state.streamClosed(1), .none) + } + + func testSendKeepAlivePing() { + var state = self.makeStateMachine(keepAliveWithoutCalls: false) + // No streams open so ping isn't allowed. + XCTAssertFalse(state.sendKeepAlivePing()) + + // Stream open, ping allowed. + state.streamOpened(1) + XCTAssertTrue(state.sendKeepAlivePing()) + + // No stream, no ping. + XCTAssertEqual(state.streamClosed(1), .startIdleTimer(cancelKeepAlive: true)) + XCTAssertFalse(state.sendKeepAlivePing()) + } + + func testSendKeepAlivePingWhenAllowedWithoutCalls() { + var state = self.makeStateMachine(keepAliveWithoutCalls: true) + // Keep alive is allowed when no streams are open, so pings are allowed. + XCTAssertTrue(state.sendKeepAlivePing()) + + state.streamOpened(1) + XCTAssertTrue(state.sendKeepAlivePing()) + + XCTAssertEqual(state.streamClosed(1), .startIdleTimer(cancelKeepAlive: false)) + XCTAssertTrue(state.sendKeepAlivePing()) + } + + func testSendKeepAlivePingWhenClosing() { + var state = self.makeStateMachine(keepAliveWithoutCalls: false) + state.streamOpened(1) + XCTAssertTrue(state.beginClosing()) + + // Stream is opened and state is closing, ping is allowed. + XCTAssertTrue(state.sendKeepAlivePing()) + } + + func testSendKeepAlivePingWhenClosed() { + var state = self.makeStateMachine(keepAliveWithoutCalls: true) + state.closed() + XCTAssertFalse(state.sendKeepAlivePing()) + } + + func testBeginGracefulShutdownWhenStreamsAreOpen() { + var state = self.makeStateMachine() + state.streamOpened(1) + // Close is false as streams are still open. + XCTAssertEqual(state.beginGracefulShutdown(), .sendGoAway(false)) + } + + func testBeginGracefulShutdownWhenNoStreamsAreOpen() { + var state = self.makeStateMachine() + // Close immediately, not streams are open. + XCTAssertEqual(state.beginGracefulShutdown(), .sendGoAway(true)) + } +} diff --git a/Tests/GRPCHTTP2CoreTests/Client/Connection/ClientConnectionHandlerTests.swift b/Tests/GRPCHTTP2CoreTests/Client/Connection/ClientConnectionHandlerTests.swift new file mode 100644 index 000000000..882086751 --- /dev/null +++ b/Tests/GRPCHTTP2CoreTests/Client/Connection/ClientConnectionHandlerTests.swift @@ -0,0 +1,274 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import NIOCore +import NIOEmbedded +import NIOHTTP2 +import XCTest + +@testable import GRPCHTTP2Core + +final class ClientConnectionHandlerTests: XCTestCase { + func testMaxIdleTime() throws { + let connection = try Connection(maxIdleTime: .minutes(5)) + try connection.activate() + + // Idle with no streams open we should: + // - read out a closing event, + // - write a GOAWAY frame, + // - close. + connection.loop.advanceTime(by: .minutes(5)) + + XCTAssertEqual(try connection.readEvent(), .closing(.idle)) + + let frame = try XCTUnwrap(try connection.readFrame()) + XCTAssertEqual(frame.streamID, .rootStream) + XCTAssertGoAway(frame.payload) { lastStreamID, error, data in + XCTAssertEqual(lastStreamID, .rootStream) + XCTAssertEqual(error, .noError) + XCTAssertEqual(data, ByteBuffer(string: "idle")) + } + + try connection.waitUntilClosed() + } + + func testMaxIdleTimeWhenOpenStreams() throws { + let connection = try Connection(maxIdleTime: .minutes(5)) + try connection.activate() + + // Open a stream, the idle timer should be cancelled. + connection.streamOpened(1) + + // Advance by the idle time, nothing should happen. + connection.loop.advanceTime(by: .minutes(5)) + XCTAssertNil(try connection.readEvent()) + XCTAssertNil(try connection.readFrame()) + + // Close the stream, the idle timer should begin again. + connection.streamClosed(1) + connection.loop.advanceTime(by: .minutes(5)) + let frame = try XCTUnwrap(try connection.readFrame()) + XCTAssertGoAway(frame.payload) { lastStreamID, error, data in + XCTAssertEqual(lastStreamID, .rootStream) + XCTAssertEqual(error, .noError) + XCTAssertEqual(data, ByteBuffer(string: "idle")) + } + + try connection.waitUntilClosed() + } + + func testKeepAliveWithOpenStreams() throws { + let connection = try Connection(keepAliveTime: .minutes(1), keepAliveTimeout: .seconds(10)) + try connection.activate() + + // Open a stream so keep-alive starts. + connection.streamOpened(1) + + for _ in 0 ..< 10 { + // Advance time, a PING should be sent, ACK it. + connection.loop.advanceTime(by: .minutes(1)) + let frame1 = try XCTUnwrap(connection.readFrame()) + XCTAssertEqual(frame1.streamID, .rootStream) + try XCTAssertPing(frame1.payload) { data, ack in + XCTAssertFalse(ack) + try connection.ping(data: data, ack: true) + } + + XCTAssertNil(try connection.readFrame()) + } + + // Close the stream, keep-alive pings should stop. + connection.streamClosed(1) + connection.loop.advanceTime(by: .minutes(1)) + XCTAssertNil(try connection.readFrame()) + } + + func testKeepAliveWithNoOpenStreams() throws { + let connection = try Connection(keepAliveTime: .minutes(1), allowKeepAliveWithoutCalls: true) + try connection.activate() + + for _ in 0 ..< 10 { + // Advance time, a PING should be sent, ACK it. + connection.loop.advanceTime(by: .minutes(1)) + let frame1 = try XCTUnwrap(connection.readFrame()) + XCTAssertEqual(frame1.streamID, .rootStream) + try XCTAssertPing(frame1.payload) { data, ack in + XCTAssertFalse(ack) + try connection.ping(data: data, ack: true) + } + + XCTAssertNil(try connection.readFrame()) + } + } + + func testKeepAliveWithOpenStreamsTimingOut() throws { + let connection = try Connection(keepAliveTime: .minutes(1), keepAliveTimeout: .seconds(10)) + try connection.activate() + + // Open a stream so keep-alive starts. + connection.streamOpened(1) + + // Advance time, a PING should be sent, don't ACK it. + connection.loop.advanceTime(by: .minutes(1)) + let frame1 = try XCTUnwrap(connection.readFrame()) + XCTAssertEqual(frame1.streamID, .rootStream) + XCTAssertPing(frame1.payload) { _, ack in + XCTAssertFalse(ack) + } + + // Advance time by the keep alive timeout. We should: + // - read a connection event + // - read out a GOAWAY frame + // - be closed + connection.loop.advanceTime(by: .seconds(10)) + + XCTAssertEqual(try connection.readEvent(), .closing(.keepAliveExpired)) + + let frame2 = try XCTUnwrap(connection.readFrame()) + XCTAssertEqual(frame2.streamID, .rootStream) + XCTAssertGoAway(frame2.payload) { lastStreamID, error, data in + XCTAssertEqual(lastStreamID, .rootStream) + XCTAssertEqual(error, .noError) + XCTAssertEqual(data, ByteBuffer(string: "keepalive_expired")) + } + + // Doesn't wait for streams to close: the connection is bad. + try connection.waitUntilClosed() + } + + func testPingsAreIgnored() throws { + let connection = try Connection() + try connection.activate() + + // PING frames without ack set should be ignored, we rely on the HTTP/2 handler replying to them. + try connection.ping(data: HTTP2PingData(), ack: false) + XCTAssertNil(try connection.readFrame()) + } + + func testReceiveGoAway() throws { + let connection = try Connection() + try connection.activate() + + try connection.goAway( + lastStreamID: 0, + errorCode: .enhanceYourCalm, + opaqueData: ByteBuffer(string: "too_many_pings") + ) + + // Should read out an event and close (because there are no open streams). + XCTAssertEqual( + try connection.readEvent(), + .closing(.goAway(.enhanceYourCalm, "too_many_pings")) + ) + try connection.waitUntilClosed() + } + + func testReceiveGoAwayWithOpenStreams() throws { + let connection = try Connection() + try connection.activate() + + connection.streamOpened(1) + connection.streamOpened(2) + connection.streamOpened(3) + + try connection.goAway(lastStreamID: .maxID, errorCode: .noError) + + // Should read out an event. + XCTAssertEqual(try connection.readEvent(), .closing(.goAway(.noError, ""))) + + // Close streams so the connection can close. + connection.streamClosed(1) + connection.streamClosed(2) + connection.streamClosed(3) + try connection.waitUntilClosed() + } +} + +extension ClientConnectionHandlerTests { + struct Connection { + let channel: EmbeddedChannel + var loop: EmbeddedEventLoop { + self.channel.embeddedEventLoop + } + + init( + maxIdleTime: TimeAmount? = nil, + keepAliveTime: TimeAmount? = nil, + keepAliveTimeout: TimeAmount? = nil, + allowKeepAliveWithoutCalls: Bool = false + ) throws { + let loop = EmbeddedEventLoop() + let handler = ClientConnectionHandler( + eventLoop: loop, + maxIdleTime: maxIdleTime, + keepAliveTime: keepAliveTime, + keepAliveTimeout: keepAliveTimeout, + keepAliveWithoutCalls: allowKeepAliveWithoutCalls + ) + + self.channel = EmbeddedChannel(handler: handler, loop: loop) + } + + func activate() throws { + try self.channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 0)).wait() + } + + func streamOpened(_ id: HTTP2StreamID) { + let event = NIOHTTP2StreamCreatedEvent( + streamID: id, + localInitialWindowSize: nil, + remoteInitialWindowSize: nil + ) + self.channel.pipeline.fireUserInboundEventTriggered(event) + } + + func streamClosed(_ id: HTTP2StreamID) { + let event = StreamClosedEvent(streamID: id, reason: nil) + self.channel.pipeline.fireUserInboundEventTriggered(event) + } + + func goAway( + lastStreamID: HTTP2StreamID, + errorCode: HTTP2ErrorCode, + opaqueData: ByteBuffer? = nil + ) throws { + let frame = HTTP2Frame( + streamID: .rootStream, + payload: .goAway(lastStreamID: lastStreamID, errorCode: errorCode, opaqueData: opaqueData) + ) + + try self.channel.writeInbound(frame) + } + + func ping(data: HTTP2PingData, ack: Bool) throws { + let frame = HTTP2Frame(streamID: .rootStream, payload: .ping(data, ack: ack)) + try self.channel.writeInbound(frame) + } + + func readFrame() throws -> HTTP2Frame? { + return try self.channel.readOutbound(as: HTTP2Frame.self) + } + + func readEvent() throws -> ClientConnectionEvent? { + return try self.channel.readInbound(as: ClientConnectionEvent.self) + } + + func waitUntilClosed() throws { + self.channel.embeddedEventLoop.run() + try self.channel.closeFuture.wait() + } + } +} diff --git a/Tests/GRPCHTTP2CoreTests/Internal/TimerTests.swift b/Tests/GRPCHTTP2CoreTests/Internal/TimerTests.swift new file mode 100644 index 000000000..fadfe4fd6 --- /dev/null +++ b/Tests/GRPCHTTP2CoreTests/Internal/TimerTests.swift @@ -0,0 +1,99 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import Atomics +import NIOEmbedded +import XCTest + +@testable import GRPCHTTP2Core + +internal final class TimerTests: XCTestCase { + func testScheduleOneOffTimer() { + let loop = EmbeddedEventLoop() + defer { try! loop.close() } + + var value = 0 + + var timer = Timer(delay: .seconds(1), repeat: false) + timer.schedule(on: loop) { + XCTAssertEqual(value, 0) + value += 1 + } + + loop.advanceTime(by: .milliseconds(999)) + XCTAssertEqual(value, 0) + loop.advanceTime(by: .milliseconds(1)) + XCTAssertEqual(value, 1) + + // Run again to make sure the task wasn't repeated. + loop.advanceTime(by: .seconds(1)) + XCTAssertEqual(value, 1) + } + + func testCancelOneOffTimer() { + let loop = EmbeddedEventLoop() + defer { try! loop.close() } + + var timer = Timer(delay: .seconds(1), repeat: false) + timer.schedule(on: loop) { + XCTFail("Timer wasn't cancelled") + } + + loop.advanceTime(by: .milliseconds(999)) + timer.cancel() + loop.advanceTime(by: .milliseconds(1)) + } + + func testScheduleRepeatedTimer() throws { + let loop = EmbeddedEventLoop() + defer { try! loop.close() } + + var values = [Int]() + + var timer = Timer(delay: .seconds(1), repeat: true) + timer.schedule(on: loop) { + values.append(values.count) + } + + loop.advanceTime(by: .milliseconds(999)) + XCTAssertEqual(values, []) + loop.advanceTime(by: .milliseconds(1)) + XCTAssertEqual(values, [0]) + + loop.advanceTime(by: .seconds(1)) + XCTAssertEqual(values, [0, 1]) + loop.advanceTime(by: .seconds(1)) + XCTAssertEqual(values, [0, 1, 2]) + + timer.cancel() + loop.advanceTime(by: .seconds(1)) + XCTAssertEqual(values, [0, 1, 2]) + } + + func testCancelRepeatedTimer() { + let loop = EmbeddedEventLoop() + defer { try! loop.close() } + + var timer = Timer(delay: .seconds(1), repeat: true) + timer.schedule(on: loop) { + XCTFail("Timer wasn't cancelled") + } + + loop.advanceTime(by: .milliseconds(999)) + timer.cancel() + loop.advanceTime(by: .milliseconds(1)) + } +}