Skip to content

Commit

Permalink
Merge branch 'main' into update-limits
Browse files Browse the repository at this point in the history
  • Loading branch information
glbrntt committed May 22, 2024
2 parents b6df647 + 8b6a8f4 commit 0a9f2b2
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ final class ClientConnectionHandler: ChannelInboundHandler, ChannelOutboundHandl
/// Resets once `channelReadComplete` returns.
private var inReadLoop: Bool

/// The context of the channel this handler is in.
private var context: ChannelHandlerContext?

/// Creates a new handler which manages the lifecycle of a connection.
///
/// - Parameters:
Expand Down Expand Up @@ -118,6 +121,11 @@ final class ClientConnectionHandler: ChannelInboundHandler, ChannelOutboundHandl

func handlerAdded(context: ChannelHandlerContext) {
assert(context.eventLoop === self.eventLoop)
self.context = context
}

func handlerRemoved(context: ChannelHandlerContext) {
self.context = nil
}

func channelActive(context: ChannelHandlerContext) {
Expand All @@ -144,31 +152,10 @@ final class ClientConnectionHandler: ChannelInboundHandler, ChannelOutboundHandl
func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
switch event {
case let event as NIOHTTP2StreamCreatedEvent:
// Stream created, so the connection isn't idle.
self.maxIdleTimer?.cancel()
self.state.streamOpened(event.streamID)
self.streamCreated(event.streamID, channel: context.channel)

case let event as StreamClosedEvent:
switch self.state.streamClosed(event.streamID) {
case .startIdleTimer(let cancelKeepalive):
// All streams are closed, restart the idle timer, and stop the keep-alive timer (it may
// not stop if keep-alive is allowed when there are no active calls).
self.maxIdleTimer?.schedule(on: context.eventLoop) {
self.maxIdleTimerFired(context: context)
}

if cancelKeepalive {
self.keepaliveTimer?.cancel()
}

case .close:
// Connection was closing but waiting for all streams to close. They must all be closed
// now so close the connection.
context.close(promise: nil)

case .none:
()
}
self.streamClosed(event.streamID, channel: context.channel)

default:
()
Expand Down Expand Up @@ -263,6 +250,42 @@ final class ClientConnectionHandler: ChannelInboundHandler, ChannelOutboundHandl
}
}

extension ClientConnectionHandler: NIOHTTP2StreamDelegate {
func streamCreated(_ id: HTTP2StreamID, channel: any Channel) {
self.eventLoop.assertInEventLoop()

// Stream created, so the connection isn't idle.
self.maxIdleTimer?.cancel()
self.state.streamOpened(id)
}

func streamClosed(_ id: HTTP2StreamID, channel: any Channel) {
guard let context = self.context else { return }
self.eventLoop.assertInEventLoop()

switch self.state.streamClosed(id) {
case .startIdleTimer(let cancelKeepalive):
// All streams are closed, restart the idle timer, and stop the keep-alive timer (it may
// not stop if keep-alive is allowed when there are no active calls).
self.maxIdleTimer?.schedule(on: context.eventLoop) {
self.maxIdleTimerFired(context: context)
}

if cancelKeepalive {
self.keepaliveTimer?.cancel()
}

case .close:
// Connection was closing but waiting for all streams to close. They must all be closed
// now so close the connection.
context.close(promise: nil)

case .none:
()
}
}
}

extension ClientConnectionHandler {
private func maybeFlush(context: ChannelHandlerContext) {
if self.inReadLoop {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler {
/// Resets once `channelReadComplete` returns.
private var inReadLoop: Bool

/// The context of the channel this handler is in.
private var context: ChannelHandlerContext?

/// The current state of the connection.
private var state: StateMachine

Expand Down Expand Up @@ -236,6 +239,11 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler {

func handlerAdded(context: ChannelHandlerContext) {
assert(context.eventLoop === self.eventLoop)
self.context = context
}

func handlerRemoved(context: ChannelHandlerContext) {
self.context = nil
}

func channelActive(context: ChannelHandlerContext) {
Expand Down Expand Up @@ -266,23 +274,10 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler {
func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
switch event {
case let event as NIOHTTP2StreamCreatedEvent:
// The connection isn't idle if a stream is open.
self.maxIdleTimer?.cancel()
self.state.streamOpened(event.streamID)
self.streamCreated(event.streamID, channel: context.channel)

case let event as StreamClosedEvent:
switch self.state.streamClosed(event.streamID) {
case .startIdleTimer:
self.maxIdleTimer?.schedule(on: context.eventLoop) {
self.initiateGracefulShutdown(context: context)
}

case .close:
context.close(mode: .all, promise: nil)

case .none:
()
}
self.streamClosed(event.streamID, channel: context.channel)

default:
()
Expand Down Expand Up @@ -335,6 +330,31 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler {
}
}

extension ServerConnectionManagementHandler: NIOHTTP2StreamDelegate {
func streamCreated(_ id: HTTP2StreamID, channel: any Channel) {
// The connection isn't idle if a stream is open.
self.maxIdleTimer?.cancel()
self.state.streamOpened(id)
}

func streamClosed(_ id: HTTP2StreamID, channel: any Channel) {
guard let context = self.context else { return }

switch self.state.streamClosed(id) {
case .startIdleTimer:
self.maxIdleTimer?.schedule(on: context.eventLoop) {
self.initiateGracefulShutdown(context: context)
}

case .close:
context.close(mode: .all, promise: nil)

case .none:
()
}
}
}

extension ServerConnectionManagementHandler {
private func maybeFlush(context: ChannelHandlerContext) {
if self.inReadLoop {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ final class ClientConnectionHandlerTests: XCTestCase {
extension ClientConnectionHandlerTests {
struct Connection {
let channel: EmbeddedChannel
let streamDelegate: any NIOHTTP2StreamDelegate
var loop: EmbeddedEventLoop {
self.channel.embeddedEventLoop
}
Expand All @@ -245,6 +246,7 @@ extension ClientConnectionHandlerTests {
keepaliveWithoutCalls: allowKeepaliveWithoutCalls
)

self.streamDelegate = handler
self.channel = EmbeddedChannel(handler: handler, loop: loop)
}

Expand All @@ -253,17 +255,11 @@ extension ClientConnectionHandlerTests {
}

func streamOpened(_ id: HTTP2StreamID) {
let event = NIOHTTP2StreamCreatedEvent(
streamID: id,
localInitialWindowSize: nil,
remoteInitialWindowSize: nil
)
self.channel.pipeline.fireUserInboundEventTriggered(event)
self.streamDelegate.streamCreated(id, channel: self.channel)
}

func streamClosed(_ id: HTTP2StreamID) {
let event = StreamClosedEvent(streamID: id, reason: nil)
self.channel.pipeline.fireUserInboundEventTriggered(event)
self.streamDelegate.streamClosed(id, channel: self.channel)
}

func goAway(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ extension ServerConnectionManagementHandlerTests {
extension ServerConnectionManagementHandlerTests {
struct Connection {
let channel: EmbeddedChannel
let streamDelegate: any NIOHTTP2StreamDelegate
let syncView: ServerConnectionManagementHandler.SyncView

var loop: EmbeddedEventLoop {
Expand Down Expand Up @@ -378,6 +379,7 @@ extension ServerConnectionManagementHandlerTests {
clock: self.clock
)

self.streamDelegate = handler
self.syncView = handler.syncView
self.channel = EmbeddedChannel(handler: handler, loop: loop)
}
Expand All @@ -398,17 +400,11 @@ extension ServerConnectionManagementHandlerTests {
}

func streamOpened(_ id: HTTP2StreamID) {
let event = NIOHTTP2StreamCreatedEvent(
streamID: id,
localInitialWindowSize: nil,
remoteInitialWindowSize: nil
)
self.channel.pipeline.fireUserInboundEventTriggered(event)
self.streamDelegate.streamCreated(id, channel: self.channel)
}

func streamClosed(_ id: HTTP2StreamID) {
let event = StreamClosedEvent(streamID: id, reason: nil)
self.channel.pipeline.fireUserInboundEventTriggered(event)
self.streamDelegate.streamClosed(id, channel: self.channel)
}

func ping(data: HTTP2PingData, ack: Bool) throws {
Expand Down

0 comments on commit 0a9f2b2

Please sign in to comment.