Skip to content

Commit

Permalink
Return close code from WebSocketClient.run/connect (#58)
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-fowler committed Apr 18, 2024
1 parent 505d324 commit 1b9bdfc
Show file tree
Hide file tree
Showing 10 changed files with 261 additions and 111 deletions.
3 changes: 2 additions & 1 deletion Sources/HummingbirdWSClient/Client/ClientChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import NIOCore
@_documentation(visibility: internal)
public protocol ClientConnectionChannel: Sendable {
associatedtype Value: Sendable
associatedtype Result

/// Setup child channel
/// - Parameters:
Expand All @@ -31,5 +32,5 @@ public protocol ClientConnectionChannel: Sendable {
/// - Parameters:
/// - value: Object to process input/output on child channel
/// - logger: Logger to use while processing messages
func handle(value: Value, logger: Logger) async throws
func handle(value: Value, logger: Logger) async throws -> Result
}
5 changes: 3 additions & 2 deletions Sources/HummingbirdWSClient/Client/ClientConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import NIOPosix
import Network
import NIOTransportServices
#endif
import NIOWebSocket

/// A generic client connection to a server.
///
Expand Down Expand Up @@ -86,12 +87,12 @@ public struct ClientConnection<ClientChannel: ClientConnectionChannel>: Sendable
}
#endif

public func run() async throws {
public func run() async throws -> ClientChannel.Result {
let channelResult = try await self.makeClient(
clientChannel: self.clientChannel,
address: self.address
)
try await self.clientChannel.handle(value: channelResult, logger: self.logger)
return try await self.clientChannel.handle(value: channelResult, logger: self.logger)
}

/// Connect to server
Expand Down
4 changes: 3 additions & 1 deletion Sources/HummingbirdWSClient/Client/TLSClientChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
import Logging
import NIOCore
import NIOSSL
import NIOWebSocket

/// Sets up client channel to use TLS before accessing base channel setup
@_documentation(visibility: internal)
public struct TLSClientChannel<BaseChannel: ClientConnectionChannel>: ClientConnectionChannel {
public typealias Value = BaseChannel.Value
public typealias Result = BaseChannel.Result

/// Initialize TLSChannel
/// - Parameters:
Expand Down Expand Up @@ -51,7 +53,7 @@ public struct TLSClientChannel<BaseChannel: ClientConnectionChannel>: ClientConn
/// - Parameters:
/// - value: Object to process input/output on child channel
/// - logger: Logger to use while processing messages
public func handle(value: BaseChannel.Value, logger: Logging.Logger) async throws {
public func handle(value: BaseChannel.Value, logger: Logging.Logger) async throws -> Result {
try await self.baseChannel.handle(value: value, logger: logger)
}

Expand Down
26 changes: 15 additions & 11 deletions Sources/HummingbirdWSClient/WebSocketClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import NIOCore
import NIOPosix
import NIOSSL
import NIOTransportServices
import NIOWebSocket

/// WebSocket client
///
Expand Down Expand Up @@ -111,8 +112,9 @@ public struct WebSocketClient {
}
#endif

/// Connect and run handler
public func run() async throws {
/// Connect and run handler
/// - Returns: WebSocket close frame details if server returned any
@discardableResult public func run() async throws -> WebSocketCloseFrame? {
guard let host = url.host else { throw WebSocketClientError.invalidURL }
let requiresTLS = self.url.scheme == .wss || self.url.scheme == .https
let port = self.url.port ?? (requiresTLS ? 443 : 80)
Expand All @@ -129,7 +131,7 @@ public struct WebSocketClient {
eventLoopGroup: self.eventLoopGroup,
logger: self.logger
)
try await client.run()
return try await client.run()

#if canImport(Network)
case .ts(let tlsOptions):
Expand All @@ -140,7 +142,7 @@ public struct WebSocketClient {
eventLoopGroup: self.eventLoopGroup,
logger: self.logger
)
try await client.run()
return try await client.run()

#endif
case .none:
Expand All @@ -158,7 +160,7 @@ public struct WebSocketClient {
eventLoopGroup: self.eventLoopGroup,
logger: self.logger
)
try await client.run()
return try await client.run()
}
} else {
let client = try ClientConnection(
Expand All @@ -171,7 +173,7 @@ public struct WebSocketClient {
eventLoopGroup: self.eventLoopGroup,
logger: self.logger
)
try await client.run()
return try await client.run()
}
}
}
Expand All @@ -186,14 +188,15 @@ extension WebSocketClient {
/// - eventLoopGroup: EventLoopGroup to run WebSocket client on
/// - logger: Logger
/// - process: Closure handling webSocket
public static func connect(
/// - Returns: WebSocket close frame details if server returned any
@discardableResult public static func connect(
url: String,
configuration: WebSocketClientConfiguration = .init(),
tlsConfiguration: TLSConfiguration? = nil,
eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton,
logger: Logger,
handler: @escaping WebSocketDataHandler<BasicWebSocketContext>
) async throws {
) async throws -> WebSocketCloseFrame? {
let ws = self.init(
url: url,
configuration: configuration,
Expand All @@ -202,7 +205,7 @@ extension WebSocketClient {
logger: logger,
handler: handler
)
try await ws.run()
return try await ws.run()
}

#if canImport(Network)
Expand All @@ -215,14 +218,15 @@ extension WebSocketClient {
/// - eventLoopGroup: EventLoopGroup to run WebSocket client on
/// - logger: Logger
/// - process: WebSocket data handler
/// - Returns: WebSocket close frame details if server returned any
public static func connect(
url: String,
configuration: WebSocketClientConfiguration = .init(),
transportServicesTLSOptions: TSTLSOptions,
eventLoopGroup: NIOTSEventLoopGroup = NIOTSEventLoopGroup.singleton,
logger: Logger,
handler: @escaping WebSocketDataHandler<BasicWebSocketContext>
) async throws {
) async throws -> WebSocketCloseFrame? {
let ws = self.init(
url: url,
configuration: configuration,
Expand All @@ -231,7 +235,7 @@ extension WebSocketClient {
logger: logger,
handler: handler
)
try await ws.run()
return try await ws.run()
}
#endif
}
4 changes: 2 additions & 2 deletions Sources/HummingbirdWSClient/WebSocketClientChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,10 @@ struct WebSocketClientChannel: ClientConnectionChannel {
}
}

func handle(value: Value, logger: Logger) async throws {
func handle(value: Value, logger: Logger) async throws -> WebSocketCloseFrame? {
switch try await value.get() {
case .websocket(let webSocketChannel, let extensions):
await WebSocketHandler.handle(
return try await WebSocketHandler.handle(
type: .client,
configuration: .init(
extensions: extensions,
Expand Down
Loading

0 comments on commit 1b9bdfc

Please sign in to comment.