diff --git a/Package.swift b/Package.swift index 948bb21..8eac5ac 100644 --- a/Package.swift +++ b/Package.swift @@ -8,37 +8,50 @@ let package = Package( platforms: [.macOS(.v14), .iOS(.v17), .tvOS(.v17)], products: [ .library(name: "HummingbirdWebSocket", targets: ["HummingbirdWebSocket"]), + .library(name: "HummingbirdWSClient", targets: ["HummingbirdWSClient"]), .library(name: "HummingbirdWSCompression", targets: ["HummingbirdWSCompression"]), ], dependencies: [ .package(url: "https://github.com/hummingbird-project/hummingbird.git", from: "2.0.0-beta.2"), - .package(url: "https://github.com/apple/swift-async-algorithms.git", from: "1.0.0"), - .package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.0"), .package(url: "https://github.com/apple/swift-http-types.git", from: "1.0.0"), + .package(url: "https://github.com/apple/swift-log.git", from: "1.4.0"), .package(url: "https://github.com/apple/swift-nio.git", from: "2.62.0"), .package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.22.0"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.5.0"), + .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.20.0"), .package(url: "https://github.com/adam-fowler/compress-nio.git", from: "1.0.0"), ], targets: [ .target(name: "HummingbirdWebSocket", dependencies: [ + .byName(name: "HummingbirdWSCore"), .product(name: "Hummingbird", package: "hummingbird"), - .product(name: "HummingbirdTLS", package: "hummingbird"), - .product(name: "AsyncAlgorithms", package: "swift-async-algorithms"), + .product(name: "NIOHTTPTypes", package: "swift-nio-extras"), + .product(name: "NIOHTTPTypesHTTP1", package: "swift-nio-extras"), + ]), + .target(name: "HummingbirdWSClient", dependencies: [ + .byName(name: "HummingbirdWSCore"), .product(name: "HTTPTypes", package: "swift-http-types"), + .product(name: "Logging", package: "swift-log"), .product(name: "NIOCore", package: "swift-nio"), - .product(name: "NIOHTTPTypes", package: "swift-nio-extras"), .product(name: "NIOHTTPTypesHTTP1", package: "swift-nio-extras"), + .product(name: "NIOPosix", package: "swift-nio"), + .product(name: "NIOSSL", package: "swift-nio-ssl"), + .product(name: "NIOTransportServices", package: "swift-nio-transport-services"), + .product(name: "NIOWebSocket", package: "swift-nio"), + ]), + .target(name: "HummingbirdWSCore", dependencies: [ + .product(name: "HTTPTypes", package: "swift-http-types"), + .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOWebSocket", package: "swift-nio"), ]), .target(name: "HummingbirdWSCompression", dependencies: [ - .byName(name: "HummingbirdWebSocket"), + .byName(name: "HummingbirdWSCore"), .product(name: "CompressNIO", package: "compress-nio"), ]), .testTarget(name: "HummingbirdWebSocketTests", dependencies: [ .byName(name: "HummingbirdWebSocket"), + .byName(name: "HummingbirdWSClient"), .byName(name: "HummingbirdWSCompression"), - .product(name: "Atomics", package: "swift-atomics"), .product(name: "Hummingbird", package: "hummingbird"), .product(name: "HummingbirdTesting", package: "hummingbird"), .product(name: "HummingbirdTLS", package: "hummingbird"), diff --git a/Snippets/AutobahnClientTest.swift b/Snippets/AutobahnClientTest.swift new file mode 100644 index 0000000..05e4a77 --- /dev/null +++ b/Snippets/AutobahnClientTest.swift @@ -0,0 +1,32 @@ +import HummingbirdWSClient +import HummingbirdWSCompression +import Logging + +let cases = 1...1 + +var logger = Logger(label: "TestClient") +logger.logLevel = .trace +do { + for c in cases { + logger.info("Case \(c)") + try await WebSocketClient.connect( + url: .init("ws://127.0.0.1:9001/runCase?case=\(c)&agent=HB"), + configuration: .init(maxFrameSize: 1 << 16, extensions: [.perMessageDeflate(maxDecompressedFrameSize: 65536)]), + logger: logger + ) { inbound, outbound, _ in + for try await msg in inbound.messages(maxSize: .max) { + switch msg { + case .binary(let buffer): + try await outbound.write(.binary(buffer)) + case .text(let string): + try await outbound.write(.text(string)) + } + } + } + } + try await WebSocketClient.connect(url: .init("ws://127.0.0.1:9001/updateReports?agent=HB"), logger: logger) { inbound, _, _ in + for try await _ in inbound {} + } +} catch { + logger.error("Error: \(error)") +} diff --git a/Snippets/WebSocketClientTest.swift b/Snippets/WebSocketClientTest.swift new file mode 100644 index 0000000..cf91de1 --- /dev/null +++ b/Snippets/WebSocketClientTest.swift @@ -0,0 +1,19 @@ +import HummingbirdWSClient +import Logging + +var logger = Logger(label: "TestClient") +logger.logLevel = .trace +do { + try await WebSocketClient.connect( + url: .init("https://echo.websocket.org"), + configuration: .init(maxFrameSize: 1 << 16), + logger: logger + ) { inbound, outbound, _ in + try await outbound.write(.text("Hello")) + for try await msg in inbound.messages(maxSize: .max) { + print(msg) + } + } +} catch { + logger.error("Error: \(error)") +} diff --git a/Snippets/WebsocketTest.swift b/Snippets/WebsocketTest.swift index 1c40638..03a013b 100644 --- a/Snippets/WebsocketTest.swift +++ b/Snippets/WebsocketTest.swift @@ -14,6 +14,7 @@ router.get { _, _ in } router.ws("/ws") { inbound, outbound, _ in + try await outbound.write(.text("Hello")) for try await frame in inbound { if frame.opcode == .text, String(buffer: frame.data) == "disconnect", frame.fin == true { break diff --git a/Sources/HummingbirdWSClient/Client/ClientChannel.swift b/Sources/HummingbirdWSClient/Client/ClientChannel.swift new file mode 100644 index 0000000..b044df0 --- /dev/null +++ b/Sources/HummingbirdWSClient/Client/ClientChannel.swift @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2024 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Logging +import NIOCore + +/// ClientConnection child channel setup protocol +public protocol ClientConnectionChannel: Sendable { + associatedtype Value: Sendable + + /// Setup child channel + /// - Parameters: + /// - channel: Child channel + /// - logger: Logger used during setup + /// - Returns: Object to process input/output on child channel + func setup(channel: Channel, logger: Logger) -> EventLoopFuture + + /// handle messages being passed down the channel pipeline + /// - Parameters: + /// - value: Object to process input/output on child channel + /// - logger: Logger to use while processing messages + func handle(value: Value, logger: Logger) async throws +} diff --git a/Sources/HummingbirdWSClient/Client/ClientConnection.swift b/Sources/HummingbirdWSClient/Client/ClientConnection.swift new file mode 100644 index 0000000..6a62820 --- /dev/null +++ b/Sources/HummingbirdWSClient/Client/ClientConnection.swift @@ -0,0 +1,171 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2024 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Logging +import NIOCore +import NIOPosix +#if canImport(Network) +import Network +import NIOTransportServices +#endif + +/// A generic client connection to a server. +/// +/// Actual client protocol is implemented in `ClientChannel` generic parameter +public struct ClientConnection: Sendable { + /// Address to connect to + public struct Address: Sendable, Equatable { + enum _Internal: Equatable { + case hostname(_ host: String, port: Int) + case unixDomainSocket(path: String) + } + + let value: _Internal + init(_ value: _Internal) { + self.value = value + } + + // Address define by host and port + public static func hostname(_ host: String, port: Int) -> Self { .init(.hostname(host, port: port)) } + // Address defined by unxi domain socket + public static func unixDomainSocket(path: String) -> Self { .init(.unixDomainSocket(path: path)) } + } + + typealias ChannelResult = ClientChannel.Value + /// Logger used by Server + let logger: Logger + let eventLoopGroup: EventLoopGroup + let clientChannel: ClientChannel + let address: Address + #if canImport(Network) + let tlsOptions: NWProtocolTLS.Options? + #endif + + /// Initialize Client + public init( + _ clientChannel: ClientChannel, + address: Address, + eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton, + logger: Logger + ) { + self.clientChannel = clientChannel + self.address = address + self.eventLoopGroup = eventLoopGroup + self.logger = logger + #if canImport(Network) + self.tlsOptions = nil + #endif + } + + #if canImport(Network) + /// Initialize Client with TLS options + public init( + _ clientChannel: ClientChannel, + address: Address, + transportServicesTLSOptions: TSTLSOptions, + eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton, + logger: Logger + ) throws { + self.clientChannel = clientChannel + self.address = address + self.eventLoopGroup = eventLoopGroup + self.logger = logger + self.tlsOptions = transportServicesTLSOptions.options + } + #endif + + public func run() async throws { + let channelResult = try await self.makeClient( + clientChannel: self.clientChannel, + address: self.address + ) + try await self.clientChannel.handle(value: channelResult, logger: self.logger) + } + + /// Connect to server + func makeClient(clientChannel: ClientChannel, address: Address) async throws -> ChannelResult { + // get bootstrap + let bootstrap: ClientBootstrapProtocol + #if canImport(Network) + if let tsBootstrap = self.createTSBootstrap() { + bootstrap = tsBootstrap + } else { + #if os(iOS) || os(tvOS) + self.logger.warning("Running BSD sockets on iOS or tvOS is not recommended. Please use NIOTSEventLoopGroup, to run with the Network framework") + #endif + bootstrap = self.createSocketsBootstrap() + } + #else + bootstrap = self.createSocketsBootstrap() + #endif + + // connect + let result: ChannelResult + do { + switch address.value { + case .hostname(let host, let port): + result = try await bootstrap + .connect(host: host, port: port) { channel in + clientChannel.setup(channel: channel, logger: self.logger) + } + self.logger.debug("Client connnected to \(host):\(port)") + case .unixDomainSocket(let path): + result = try await bootstrap + .connect(unixDomainSocketPath: path) { channel in + clientChannel.setup(channel: channel, logger: self.logger) + } + self.logger.debug("Client connnected to socket path \(path)") + } + return result + } catch { + throw error + } + } + + /// create a BSD sockets based bootstrap + private func createSocketsBootstrap() -> ClientBootstrap { + return ClientBootstrap(group: self.eventLoopGroup) + } + + #if canImport(Network) + /// create a NIOTransportServices bootstrap using Network.framework + private func createTSBootstrap() -> NIOTSConnectionBootstrap? { + guard let bootstrap = NIOTSConnectionBootstrap(validatingGroup: self.eventLoopGroup) else { + return nil + } + if let tlsOptions { + return bootstrap.tlsOptions(tlsOptions) + } + return bootstrap + } + #endif +} + +protocol ClientBootstrapProtocol { + func connect( + host: String, + port: Int, + channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture + ) async throws -> Output + + func connect( + unixDomainSocketPath: String, + channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture + ) async throws -> Output +} + +extension ClientBootstrap: ClientBootstrapProtocol {} +#if canImport(Network) +extension NIOTSConnectionBootstrap: ClientBootstrapProtocol {} +#endif diff --git a/Sources/HummingbirdWSClient/Client/Parser.swift b/Sources/HummingbirdWSClient/Client/Parser.swift new file mode 100644 index 0000000..6bff3fb --- /dev/null +++ b/Sources/HummingbirdWSClient/Client/Parser.swift @@ -0,0 +1,661 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2021-2024 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +// Half inspired by Reader class from John Sundell's Ink project +// https://github.com/JohnSundell/Ink/blob/master/Sources/Ink/Internal/Reader.swift +// with optimisation working ie removing String and doing my own UTF8 processing inspired by Fabian Fett's work in +// https://github.com/fabianfett/pure-swift-json/blob/master/Sources/PureSwiftJSONParsing/DocumentReader.swift + +/// Reader object for parsing String buffers +struct Parser: Sendable { + enum Error: Swift.Error { + case overflow + case unexpected + case emptyString + case invalidUTF8 + } + + /// Create a Parser object + /// - Parameter string: UTF8 data to parse + init?(_ utf8Data: some Collection, validateUTF8: Bool = true) { + if let buffer = utf8Data as? [UInt8] { + self.buffer = buffer + } else { + self.buffer = Array(utf8Data) + } + self.index = 0 + self.range = 0.. +} + +// MARK: sub-parsers + +extension Parser { + /// initialise a parser that parses a section of the buffer attached to another parser + private init(_ parser: Parser, range: Range) { + self.buffer = parser.buffer + self.index = range.startIndex + self.range = range + + precondition(range.startIndex >= 0 && range.endIndex <= self.buffer.endIndex) + precondition(range.startIndex == self.buffer.endIndex || self.buffer[range.startIndex] & 0xC0 != 0x80) // check we arent in the middle of a UTF8 character + } + + /// initialise a parser that parses a section of the buffer attached to this parser + func subParser(_ range: Range) -> Parser { + return Parser(self, range: range) + } +} + +extension Parser { + /// Return current character + /// - Throws: .overflow + /// - Returns: Current character + mutating func character() throws -> Unicode.Scalar { + guard !self.reachedEnd() else { throw Error.overflow } + return unsafeCurrentAndAdvance() + } + + /// Read the current character and return if it is as intended. If character test returns true then move forward 1 + /// - Parameter char: character to compare against + /// - Throws: .overflow + /// - Returns: If current character was the one we expected + mutating func read(_ char: Unicode.Scalar) throws -> Bool { + let initialIndex = self.index + let c = try character() + guard c == char else { self.index = initialIndex; return false } + return true + } + + /// Read the current character and check if it is in a set of characters If character test returns true then move forward 1 + /// - Parameter characterSet: Set of characters to compare against + /// - Throws: .overflow + /// - Returns: If current character is in character set + mutating func read(_ characterSet: Set) throws -> Bool { + let initialIndex = self.index + let c = try character() + guard characterSet.contains(c) else { self.index = initialIndex; return false } + return true + } + + /// Compare characters at current position against provided string. If the characters are the same as string provided advance past string + /// - Parameter string: String to compare against + /// - Throws: .overflow, .emptyString + /// - Returns: If characters at current position equal string + mutating func read(_ string: String) throws -> Bool { + let initialIndex = self.index + guard string.count > 0 else { throw Error.emptyString } + let subString = try read(count: string.count) + guard subString.string == string else { self.index = initialIndex; return false } + return true + } + + /// Read next so many characters from buffer + /// - Parameter count: Number of characters to read + /// - Throws: .overflow + /// - Returns: The string read from the buffer + mutating func read(count: Int) throws -> Parser { + var count = count + var readEndIndex = self.index + while count > 0 { + guard readEndIndex != self.range.endIndex else { throw Error.overflow } + readEndIndex = skipUTF8Character(at: readEndIndex) + count -= 1 + } + let result = self.subParser(self.index.. Parser { + let startIndex = self.index + while !self.reachedEnd() { + if unsafeCurrent() == until { + return self.subParser(startIndex.., throwOnOverflow: Bool = true) throws -> Parser { + let startIndex = self.index + while !self.reachedEnd() { + if characterSet.contains(unsafeCurrent()) { + return self.subParser(startIndex.. Bool, throwOnOverflow: Bool = true) throws -> Parser { + let startIndex = self.index + while !self.reachedEnd() { + if until(unsafeCurrent()) { + return self.subParser(startIndex.., throwOnOverflow: Bool = true) throws -> Parser { + let startIndex = self.index + while !self.reachedEnd() { + if unsafeCurrent()[keyPath: keyPath] { + return self.subParser(startIndex.. Parser { + var untilString = untilString + return try untilString.withUTF8 { utf8 in + guard utf8.count > 0 else { throw Error.emptyString } + let startIndex = self.index + var foundIndex = self.index + var untilIndex = 0 + while !self.reachedEnd() { + if self.buffer[self.index] == utf8[untilIndex] { + if untilIndex == 0 { + foundIndex = self.index + } + untilIndex += 1 + if untilIndex == utf8.endIndex { + unsafeAdvance() + if skipToEnd == false { + self.index = foundIndex + } + let result = self.subParser(startIndex.. Parser { + let startIndex = self.index + self.index = self.range.endIndex + return self.subParser(startIndex.. Int { + var count = 0 + while !self.reachedEnd(), + unsafeCurrent() == `while` + { + unsafeAdvance() + count += 1 + } + return count + } + + /// Read while character at current position is in supplied set + /// - Parameter while: character set to check + /// - Returns: String read from buffer + @discardableResult mutating func read(while characterSet: Set) -> Parser { + let startIndex = self.index + while !self.reachedEnd(), + characterSet.contains(unsafeCurrent()) + { + unsafeAdvance() + } + return self.subParser(startIndex.. Bool) -> Parser { + let startIndex = self.index + while !self.reachedEnd(), + `while`(unsafeCurrent()) + { + unsafeAdvance() + } + return self.subParser(startIndex..) -> Parser { + let startIndex = self.index + while !self.reachedEnd(), + unsafeCurrent()[keyPath: keyPath] + { + unsafeAdvance() + } + return self.subParser(startIndex.. [Parser] { + var subParsers: [Parser] = [] + while !self.reachedEnd() { + do { + let section = try read(until: separator) + subParsers.append(section) + unsafeAdvance() + } catch { + if !self.reachedEnd() { + subParsers.append(self.readUntilTheEnd()) + } + } + } + return subParsers + } + + /// Return whether we have reached the end of the buffer + /// - Returns: Have we reached the end + func reachedEnd() -> Bool { + return self.index == self.range.endIndex + } +} + +/// Public versions of internal functions which include tests for overflow +extension Parser { + /// Return the character at the current position + /// - Throws: .overflow + /// - Returns: Unicode.Scalar + func current() -> Unicode.Scalar { + guard !self.reachedEnd() else { return Unicode.Scalar(0) } + return unsafeCurrent() + } + + /// Move forward one character + /// - Throws: .overflow + mutating func advance() throws { + guard !self.reachedEnd() else { throw Error.overflow } + return self.unsafeAdvance() + } + + /// Move forward so many character + /// - Parameter amount: number of characters to move forward + /// - Throws: .overflow + mutating func advance(by amount: Int) throws { + var amount = amount + while amount > 0 { + guard !self.reachedEnd() else { throw Error.overflow } + self.index = skipUTF8Character(at: self.index) + amount -= 1 + } + } + + /// Move backwards one character + /// - Throws: .overflow + mutating func retreat() throws { + guard self.index > self.range.startIndex else { throw Error.overflow } + self.index = backOneUTF8Character(at: self.index) + } + + /// Move back so many characters + /// - Parameter amount: number of characters to move back + /// - Throws: .overflow + mutating func retreat(by amount: Int) throws { + var amount = amount + while amount > 0 { + guard self.index > self.range.startIndex else { throw Error.overflow } + self.index = backOneUTF8Character(at: self.index) + amount -= 1 + } + } + + /// Move parser to beginning of string + mutating func moveToStart() { + self.index = self.range.startIndex + } + + /// Move parser to end of string + mutating func moveToEnd() { + self.index = self.range.endIndex + } + + mutating func unsafeAdvance() { + self.index = skipUTF8Character(at: self.index) + } + + mutating func unsafeAdvance(by amount: Int) { + var amount = amount + while amount > 0 { + self.index = skipUTF8Character(at: self.index) + amount -= 1 + } + } +} + +/// extend Parser to conform to Sequence +extension Parser: Sequence { + public typealias Element = Unicode.Scalar + + public func makeIterator() -> Iterator { + return Iterator(self) + } + + public struct Iterator: IteratorProtocol { + public typealias Element = Unicode.Scalar + + var parser: Parser + + init(_ parser: Parser) { + self.parser = parser + } + + public mutating func next() -> Unicode.Scalar? { + guard !self.parser.reachedEnd() else { return nil } + return self.parser.unsafeCurrentAndAdvance() + } + } +} + +// internal versions without checks +private extension Parser { + func unsafeCurrent() -> Unicode.Scalar { + return decodeUTF8Character(at: self.index).0 + } + + mutating func unsafeCurrentAndAdvance() -> Unicode.Scalar { + let (unicodeScalar, index) = decodeUTF8Character(at: self.index) + self.index = index + return unicodeScalar + } + + mutating func _setPosition(_ index: Int) { + self.index = index + } + + func makeString(_ bytes: Bytes) -> String where Bytes.Element == UInt8, Bytes.Index == Int { + if let string = bytes.withContiguousStorageIfAvailable({ String(decoding: $0, as: Unicode.UTF8.self) }) { + return string + } else { + return String(decoding: bytes, as: Unicode.UTF8.self) + } + } +} + +// UTF8 parsing +extension Parser { + func decodeUTF8Character(at index: Int) -> (Unicode.Scalar, Int) { + var index = index + let byte1 = UInt32(buffer[index]) + var value: UInt32 + if byte1 & 0xC0 == 0xC0 { + index += 1 + let byte2 = UInt32(buffer[index] & 0x3F) + if byte1 & 0xE0 == 0xE0 { + index += 1 + let byte3 = UInt32(buffer[index] & 0x3F) + if byte1 & 0xF0 == 0xF0 { + index += 1 + let byte4 = UInt32(buffer[index] & 0x3F) + value = (byte1 & 0x7) << 18 + byte2 << 12 + byte3 << 6 + byte4 + } else { + value = (byte1 & 0xF) << 12 + byte2 << 6 + byte3 + } + } else { + value = (byte1 & 0x1F) << 6 + byte2 + } + } else { + value = byte1 & 0x7F + } + let unicodeScalar = Unicode.Scalar(value)! + return (unicodeScalar, index + 1) + } + + func skipUTF8Character(at index: Int) -> Int { + if self.buffer[index] & 0x80 != 0x80 { return index + 1 } + if self.buffer[index + 1] & 0xC0 == 0x80 { return index + 2 } + if self.buffer[index + 2] & 0xC0 == 0x80 { return index + 3 } + return index + 4 + } + + func backOneUTF8Character(at index: Int) -> Int { + if self.buffer[index - 1] & 0xC0 != 0x80 { return index - 1 } + if self.buffer[index - 2] & 0xC0 != 0x80 { return index - 2 } + if self.buffer[index - 3] & 0xC0 != 0x80 { return index - 3 } + return index - 4 + } + + /// same as `decodeUTF8Character` but adds extra validation, so we can make assumptions later on in decode and skip + func validateUTF8Character(at index: Int) -> (Unicode.Scalar?, Int) { + var index = index + let byte1 = UInt32(buffer[index]) + var value: UInt32 + if byte1 & 0xC0 == 0xC0 { + index += 1 + let byte = UInt32(buffer[index]) + guard byte & 0xC0 == 0x80 else { return (nil, index) } + let byte2 = UInt32(byte & 0x3F) + if byte1 & 0xE0 == 0xE0 { + index += 1 + let byte = UInt32(buffer[index]) + guard byte & 0xC0 == 0x80 else { return (nil, index) } + let byte3 = UInt32(byte & 0x3F) + if byte1 & 0xF0 == 0xF0 { + index += 1 + let byte = UInt32(buffer[index]) + guard byte & 0xC0 == 0x80 else { return (nil, index) } + let byte4 = UInt32(byte & 0x3F) + value = (byte1 & 0x7) << 18 + byte2 << 12 + byte3 << 6 + byte4 + } else { + value = (byte1 & 0xF) << 12 + byte2 << 6 + byte3 + } + } else { + value = (byte1 & 0x1F) << 6 + byte2 + } + } else { + value = byte1 & 0x7F + } + let unicodeScalar = Unicode.Scalar(value) + return (unicodeScalar, index + 1) + } + + /// return if the buffer is valid UTF8 + func validateUTF8() -> Bool { + var index = self.range.startIndex + while index < self.range.endIndex { + let (scalar, newIndex) = self.validateUTF8Character(at: index) + guard scalar != nil else { return false } + index = newIndex + } + return true + } + + private static let asciiHexValues: [UInt8] = [ + /* 00 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* 08 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* 10 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* 18 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* 20 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* 28 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* 30 */ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + /* 38 */ 0x08, 0x09, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* 40 */ 0x80, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x80, + /* 48 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* 50 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* 58 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* 60 */ 0x80, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x80, + /* 68 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* 70 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* 78 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + + /* 80 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* 88 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* 90 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* 98 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* A0 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* A8 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* B0 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* B8 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* C0 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* C8 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* D0 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* D8 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* E0 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* E8 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* F0 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + /* F8 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + ] + + /// percent decode UTF8 + public func percentDecode() -> String? { + struct DecodeError: Swift.Error {} + func _percentDecode(_ original: ArraySlice, _ bytes: UnsafeMutableBufferPointer) throws -> Int { + var newIndex = 0 + var index = original.startIndex + while index < (original.endIndex - 2) { + // if we have found a percent sign + if original[index] == 0x25 { + let high = Self.asciiHexValues[Int(original[index + 1])] + let low = Self.asciiHexValues[Int(original[index + 2])] + index += 3 + if ((high | low) & 0x80) != 0 { + throw DecodeError() + } + bytes[newIndex] = (high << 4) | low + newIndex += 1 + } else { + bytes[newIndex] = original[index] + newIndex += 1 + index += 1 + } + } + while index < original.endIndex { + bytes[newIndex] = original[index] + newIndex += 1 + index += 1 + } + return newIndex + } + guard self.index != self.range.endIndex else { return "" } + do { + if #available(macOS 11, macCatalyst 14.0, iOS 14.0, tvOS 14.0, *) { + return try String(unsafeUninitializedCapacity: range.endIndex - index) { bytes -> Int in + return try _percentDecode(self.buffer[self.index.. { + init(_ string: String) { + self = Set(string.unicodeScalars) + } +} diff --git a/Sources/HummingbirdWSClient/Client/TLSClientChannel.swift b/Sources/HummingbirdWSClient/Client/TLSClientChannel.swift new file mode 100644 index 0000000..00c278c --- /dev/null +++ b/Sources/HummingbirdWSClient/Client/TLSClientChannel.swift @@ -0,0 +1,63 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2024 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Logging +import NIOCore +import NIOSSL + +/// Sets up client channel to use TLS before accessing base channel setup +public struct TLSClientChannel: ClientConnectionChannel { + public typealias Value = BaseChannel.Value + + /// Initialize TLSChannel + /// - Parameters: + /// - baseChannel: Base child channel wrap + /// - tlsConfiguration: TLS configuration + public init(_ baseChannel: BaseChannel, tlsConfiguration: TLSConfiguration, serverHostname: String? = nil) throws { + self.baseChannel = baseChannel + self.sslContext = try NIOSSLContext(configuration: tlsConfiguration) + self.serverHostname = serverHostname + } + + /// Setup child channel with TLS and the base channel setup + /// - Parameters: + /// - channel: Child channel + /// - logger: Logger used during setup + /// - Returns: Object to process input/output on child channel + @inlinable + public func setup(channel: Channel, logger: Logger) -> EventLoopFuture { + channel.eventLoop.makeCompletedFuture { + let sslHandler = try NIOSSLClientHandler(context: self.sslContext, serverHostname: self.serverHostname) + try channel.pipeline.syncOperations.addHandler(sslHandler) + }.flatMap { + self.baseChannel.setup(channel: channel, logger: logger) + } + } + + @inlinable + /// handle messages being passed down the channel pipeline + /// - Parameters: + /// - value: Object to process input/output on child channel + /// - logger: Logger to use while processing messages + public func handle(value: BaseChannel.Value, logger: Logging.Logger) async throws { + try await self.baseChannel.handle(value: value, logger: logger) + } + + @usableFromInline + let sslContext: NIOSSLContext + @usableFromInline + let serverHostname: String? + @usableFromInline + var baseChannel: BaseChannel +} diff --git a/Sources/HummingbirdWSClient/Client/TSTLSOptions.swift b/Sources/HummingbirdWSClient/Client/TSTLSOptions.swift new file mode 100644 index 0000000..3d721f1 --- /dev/null +++ b/Sources/HummingbirdWSClient/Client/TSTLSOptions.swift @@ -0,0 +1,175 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2021-2024 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#if canImport(Network) +import Foundation +import Network +import Security + +/// Wrapper for NIO transport services TLS options +public struct TSTLSOptions: Sendable { + public struct Error: Swift.Error, Equatable { + enum _Internal: Equatable { + case invalidFormat + case interactionNotAllowed + case verificationFailed + } + + private let value: _Internal + init(_ value: _Internal) { + self.value = value + } + + // invalid format + public static var invalidFormat: Self { .init(.invalidFormat) } + // unable to import p12 as no interaction is allowed + public static var interactionNotAllowed: Self { .init(.interactionNotAllowed) } + // MAC verification failed during PKCS12 import (wrong password?) + public static var verificationFailed: Self { .init(.verificationFailed) } + } + + public struct Identity { + let secIdentity: SecIdentity + + public static func secIdentity(_ secIdentity: SecIdentity) -> Self { + return .init(secIdentity: secIdentity) + } + + public static func p12(filename: String, password: String) throws -> Self { + guard let secIdentity = try Self.loadP12(filename: filename, password: password) else { throw Error.invalidFormat } + return .init(secIdentity: secIdentity) + } + + private static func loadP12(filename: String, password: String) throws -> SecIdentity? { + let data = try Data(contentsOf: URL(fileURLWithPath: filename)) + let options: [String: String] = [kSecImportExportPassphrase as String: password] + var rawItems: CFArray? + let result = SecPKCS12Import(data as CFData, options as CFDictionary, &rawItems) + switch result { + case errSecSuccess: + break + case errSecInteractionNotAllowed: + throw Error.interactionNotAllowed + case errSecPkcs12VerifyFailure: + throw Error.verificationFailed + default: + throw Error.invalidFormat + } + let items = rawItems! as! [[String: Any]] + let firstItem = items[0] + return firstItem[kSecImportItemIdentity as String] as! SecIdentity? + } + } + + /// Struct defining an array of certificates + public struct Certificates { + let certificates: [SecCertificate] + + /// Create certificate array from already loaded SecCertificate array + public static var none: Self { .init(certificates: []) } + + /// Create certificate array from already loaded SecCertificate array + public static func certificates(_ secCertificates: [SecCertificate]) -> Self { .init(certificates: secCertificates) } + + /// Create certificate array from DER file + public static func der(filename: String) throws -> Self { + let certificateData = try Data(contentsOf: URL(fileURLWithPath: filename)) + guard let secCertificate = SecCertificateCreateWithData(nil, certificateData as CFData) else { throw Error.invalidFormat } + return .init(certificates: [secCertificate]) + } + } + + /// Initialize TSTLSOptions + public init(_ options: NWProtocolTLS.Options?) { + if let options { + self.value = .some(options) + } else { + self.value = .none + } + } + + /// TSTLSOptions holding options + public static func options(_ options: NWProtocolTLS.Options) -> Self { + return .init(value: .some(options)) + } + + public static func options( + serverIdentity: Identity + ) -> Self? { + let options = NWProtocolTLS.Options() + + // server identity + guard let secIdentity = sec_identity_create(serverIdentity.secIdentity) else { return nil } + sec_protocol_options_set_local_identity(options.securityProtocolOptions, secIdentity) + + return .init(value: .some(options)) + } + + public static func options( + clientIdentity: Identity, trustRoots: Certificates = .none, serverName: String? = nil + ) -> Self? { + let options = NWProtocolTLS.Options() + + // server identity + guard let secIdentity = sec_identity_create(clientIdentity.secIdentity) else { return nil } + sec_protocol_options_set_local_identity(options.securityProtocolOptions, secIdentity) + if let serverName { + sec_protocol_options_set_tls_server_name(options.securityProtocolOptions, serverName) + } + // sec_protocol_options_set + sec_protocol_options_set_local_identity(options.securityProtocolOptions, secIdentity) + + // add verify block to control certificate verification + if trustRoots.certificates.count > 0 { + sec_protocol_options_set_verify_block( + options.securityProtocolOptions, + { _, sec_trust, sec_protocol_verify_complete in + let trust = sec_trust_copy_ref(sec_trust).takeRetainedValue() + SecTrustSetAnchorCertificates(trust, trustRoots.certificates as CFArray) + SecTrustEvaluateAsyncWithError(trust, Self.tlsDispatchQueue) { _, result, error in + if let error { + print("Trust failed: \(error.localizedDescription)") + } + sec_protocol_verify_complete(result) + } + }, Self.tlsDispatchQueue + ) + } + return .init(value: .some(options)) + } + + /// Empty TSTLSOptions + public static var none: Self { + return .init(value: .none) + } + + var options: NWProtocolTLS.Options? { + if case .some(let options) = self.value { return options } + return nil + } + + /// Internal storage for TSTLSOptions. @unchecked Sendable while NWProtocolTLS.Options + /// is not Sendable + private enum Internal: @unchecked Sendable { + case some(NWProtocolTLS.Options) + case none + } + + private let value: Internal + private init(value: Internal) { self.value = value } + + /// Dispatch queue used by Network framework TLS to control certificate verification + static let tlsDispatchQueue = DispatchQueue(label: "WSTSTLSConfiguration") +} +#endif diff --git a/Sources/HummingbirdWSClient/Client/URI.swift b/Sources/HummingbirdWSClient/Client/URI.swift new file mode 100644 index 0000000..bd816bd --- /dev/null +++ b/Sources/HummingbirdWSClient/Client/URI.swift @@ -0,0 +1,134 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2021-2024 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +/// Simple URL parser +struct URI: Sendable, CustomStringConvertible, ExpressibleByStringLiteral { + struct Scheme: RawRepresentable, Equatable { + let rawValue: String + + init(rawValue: String) { + self.rawValue = rawValue + } + + static var http: Self { return .init(rawValue: "http") } + static var https: Self { return .init(rawValue: "https") } + static var unix: Self { return .init(rawValue: "unix") } + static var http_unix: Self { return .init(rawValue: "http_unix") } + static var https_unix: Self { return .init(rawValue: "https_unix") } + static var ws: Self { return .init(rawValue: "ws") } + static var wss: Self { return .init(rawValue: "wss") } + } + + let string: String + + /// URL scheme + var scheme: Scheme? { return self._scheme.map { .init(rawValue: $0.string) } } + /// URL host + var host: String? { return self._host.map(\.string) } + /// URL port + var port: Int? { return self._port.map { Int($0.string) } ?? nil } + /// URL path + var path: String { return self._path.map(\.string) ?? "/" } + /// URL query + var query: String? { return self._query.map { String($0.string) }} + + private let _scheme: Parser? + private let _host: Parser? + private let _port: Parser? + private let _path: Parser? + private let _query: Parser? + + var description: String { self.string } + + /// Initialize `URI` from `String` + /// - Parameter string: input string + init(_ string: String) { + enum ParsingState { + case readingScheme + case readingHost + case readingPort + case readingPath + case readingQuery + case finished + } + var scheme: Parser? + var host: Parser? + var port: Parser? + var path: Parser? + var query: Parser? + var state: ParsingState = .readingScheme + if string.first == "/" { + state = .readingPath + } + + var parser = Parser(string) + while state != .finished { + if parser.reachedEnd() { break } + switch state { + case .readingScheme: + // search for "://" to find scheme and host + scheme = try? parser.read(untilString: "://", skipToEnd: true) + if scheme != nil { + state = .readingHost + } else { + state = .readingPath + } + + case .readingHost: + let h = try! parser.read(until: Self.hostEndSet, throwOnOverflow: false) + if h.count != 0 { + host = h + } + if parser.current() == ":" { + state = .readingPort + } else if parser.current() == "?" { + state = .readingQuery + } else { + state = .readingPath + } + + case .readingPort: + parser.unsafeAdvance() + port = try! parser.read(until: Self.portEndSet, throwOnOverflow: false) + state = .readingPath + + case .readingPath: + path = try! parser.read(until: "?", throwOnOverflow: false) + state = .readingQuery + + case .readingQuery: + parser.unsafeAdvance() + query = try! parser.read(until: "#", throwOnOverflow: false) + state = .finished + + case .finished: + break + } + } + + self.string = string + self._scheme = scheme + self._host = host + self._port = port + self._path = path + self._query = query + } + + init(stringLiteral value: String) { + self.init(value) + } + + private static let hostEndSet: Set = Set(":/?") + private static let portEndSet: Set = Set("/?") +} diff --git a/Sources/HummingbirdWSClient/Exports.swift b/Sources/HummingbirdWSClient/Exports.swift new file mode 100644 index 0000000..6a282d2 --- /dev/null +++ b/Sources/HummingbirdWSClient/Exports.swift @@ -0,0 +1,15 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2024 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +@_exported import HummingbirdWSCore diff --git a/Sources/HummingbirdWebSocket/Client/WebSocketClient.swift b/Sources/HummingbirdWSClient/WebSocketClient.swift similarity index 92% rename from Sources/HummingbirdWebSocket/Client/WebSocketClient.swift rename to Sources/HummingbirdWSClient/WebSocketClient.swift index 40b1ff4..bc8681c 100644 --- a/Sources/HummingbirdWebSocket/Client/WebSocketClient.swift +++ b/Sources/HummingbirdWSClient/WebSocketClient.swift @@ -13,13 +13,12 @@ //===----------------------------------------------------------------------===// import HTTPTypes -import HummingbirdCore -import HummingbirdTLS +import HummingbirdWSCore import Logging import NIOCore import NIOPosix +import NIOSSL import NIOTransportServices -import ServiceLifecycle /// WebSocket client /// @@ -70,14 +69,14 @@ public struct WebSocketClient { /// - eventLoopGroup: EventLoopGroup to run WebSocket client on /// - logger: Logger public init( - url: URI, + url: String, configuration: WebSocketClientConfiguration = .init(), tlsConfiguration: TLSConfiguration? = nil, eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton, logger: Logger, handler: @escaping WebSocketDataHandler ) { - self.url = url + self.url = .init(url) self.handler = handler self.configuration = configuration self.eventLoopGroup = eventLoopGroup @@ -96,14 +95,14 @@ public struct WebSocketClient { /// - eventLoopGroup: EventLoopGroup to run WebSocket client on /// - logger: Logger public init( - url: URI, + url: String, configuration: WebSocketClientConfiguration = .init(), transportServicesTLSOptions: TSTLSOptions, eventLoopGroup: NIOTSEventLoopGroup = NIOTSEventLoopGroup.singleton, logger: Logger, handler: @escaping WebSocketDataHandler ) { - self.url = url + self.url = .init(url) self.handler = handler self.configuration = configuration self.eventLoopGroup = eventLoopGroup @@ -117,15 +116,14 @@ public struct WebSocketClient { guard let host = url.host else { throw WebSocketClientError.invalidURL } let requiresTLS = self.url.scheme == .wss || self.url.scheme == .https let port = self.url.port ?? (requiresTLS ? 443 : 80) - // url path must include query values as well - let urlPath = self.url.path + (self.url.query.map { "?\($0)" } ?? "") if requiresTLS { switch self.tlsConfiguration { case .niossl(let tlsConfiguration): let client = try ClientConnection( TLSClientChannel( - WebSocketClientChannel(handler: handler, url: urlPath, configuration: self.configuration), - tlsConfiguration: tlsConfiguration + WebSocketClientChannel(handler: handler, url: url, configuration: self.configuration), + tlsConfiguration: tlsConfiguration, + serverHostname: host ), address: .hostname(host, port: port), eventLoopGroup: self.eventLoopGroup, @@ -136,7 +134,7 @@ public struct WebSocketClient { #if canImport(Network) case .ts(let tlsOptions): let client = try ClientConnection( - WebSocketClientChannel(handler: handler, url: urlPath, configuration: self.configuration), + WebSocketClientChannel(handler: handler, url: url, configuration: self.configuration), address: .hostname(host, port: port), transportServicesTLSOptions: tlsOptions, eventLoopGroup: self.eventLoopGroup, @@ -150,10 +148,11 @@ public struct WebSocketClient { TLSClientChannel( WebSocketClientChannel( handler: handler, - url: urlPath, + url: url, configuration: self.configuration ), - tlsConfiguration: TLSConfiguration.makeClientConfiguration() + tlsConfiguration: TLSConfiguration.makeClientConfiguration(), + serverHostname: host ), address: .hostname(host, port: port), eventLoopGroup: self.eventLoopGroup, @@ -162,10 +161,10 @@ public struct WebSocketClient { try await client.run() } } else { - let client = ClientConnection( + let client = try ClientConnection( WebSocketClientChannel( handler: handler, - url: urlPath, + url: url, configuration: self.configuration ), address: .hostname(host, port: port), @@ -188,7 +187,7 @@ extension WebSocketClient { /// - logger: Logger /// - process: Closure handling webSocket public static func connect( - url: URI, + url: String, configuration: WebSocketClientConfiguration = .init(), tlsConfiguration: TLSConfiguration? = nil, eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton, @@ -217,7 +216,7 @@ extension WebSocketClient { /// - logger: Logger /// - process: WebSocket data handler public static func connect( - url: URI, + url: String, configuration: WebSocketClientConfiguration = .init(), transportServicesTLSOptions: TSTLSOptions, eventLoopGroup: NIOTSEventLoopGroup = NIOTSEventLoopGroup.singleton, @@ -236,6 +235,3 @@ extension WebSocketClient { } #endif } - -/// conform to Service -extension WebSocketClient: Service {} diff --git a/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift b/Sources/HummingbirdWSClient/WebSocketClientChannel.swift similarity index 80% rename from Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift rename to Sources/HummingbirdWSClient/WebSocketClientChannel.swift index 0c68e93..4215dbf 100644 --- a/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift +++ b/Sources/HummingbirdWSClient/WebSocketClientChannel.swift @@ -13,7 +13,7 @@ //===----------------------------------------------------------------------===// import HTTPTypes -import HummingbirdCore +import HummingbirdWSCore import Logging import NIOCore import NIOHTTP1 @@ -28,12 +28,15 @@ struct WebSocketClientChannel: ClientConnectionChannel { typealias Value = EventLoopFuture - let url: String + let urlPath: String + let hostHeader: String let handler: WebSocketDataHandler let configuration: WebSocketClientConfiguration - init(handler: @escaping WebSocketDataHandler, url: String, configuration: WebSocketClientConfiguration) { - self.url = url + init(handler: @escaping WebSocketDataHandler, url: URI, configuration: WebSocketClientConfiguration) throws { + guard let hostHeader = Self.urlHostHeader(for: url) else { throw WebSocketClientError.invalidURL } + self.hostHeader = hostHeader + self.urlPath = Self.urlPath(for: url) self.handler = handler self.configuration = configuration } @@ -57,8 +60,8 @@ struct WebSocketClientChannel: ClientConnectionChannel { ) var headers = HTTPHeaders() - headers.add(name: "Content-Type", value: "text/plain; charset=utf-8") headers.add(name: "Content-Length", value: "0") + headers.add(name: "Host", value: self.hostHeader) let additionalHeaders = HTTPHeaders(self.configuration.additionalHeaders) headers.add(contentsOf: additionalHeaders) // add websocket extensions to headers @@ -67,7 +70,7 @@ struct WebSocketClientChannel: ClientConnectionChannel { let requestHead = HTTPRequestHead( version: .http1_1, method: .GET, - uri: self.url, + uri: self.urlPath, headers: headers ) @@ -81,8 +84,10 @@ struct WebSocketClientChannel: ClientConnectionChannel { } ) + var pipelineConfiguration = NIOUpgradableHTTPClientPipelineConfiguration(upgradeConfiguration: clientUpgradeConfiguration) + pipelineConfiguration.leftOverBytesStrategy = .forwardBytes let negotiationResultFuture = try channel.pipeline.syncOperations.configureUpgradableHTTPClientPipeline( - configuration: .init(upgradeConfiguration: clientUpgradeConfiguration) + configuration: pipelineConfiguration ) return negotiationResultFuture @@ -108,4 +113,17 @@ struct WebSocketClientChannel: ClientConnectionChannel { throw WebSocketClientError.webSocketUpgradeFailed } } + + static func urlPath(for url: URI) -> String { + url.path + (url.query.map { "?\($0)" } ?? "") + } + + static func urlHostHeader(for url: URI) -> String? { + guard let host = url.host else { return nil } + if let port = url.port { + return "\(host):\(port)" + } else { + return host + } + } } diff --git a/Sources/HummingbirdWebSocket/Client/WebSocketClientConfiguration.swift b/Sources/HummingbirdWSClient/WebSocketClientConfiguration.swift similarity index 98% rename from Sources/HummingbirdWebSocket/Client/WebSocketClientConfiguration.swift rename to Sources/HummingbirdWSClient/WebSocketClientConfiguration.swift index 0cf3f51..6b41d83 100644 --- a/Sources/HummingbirdWebSocket/Client/WebSocketClientConfiguration.swift +++ b/Sources/HummingbirdWSClient/WebSocketClientConfiguration.swift @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// import HTTPTypes +import HummingbirdWSCore /// Configuration for a client connecting to a WebSocket public struct WebSocketClientConfiguration: Sendable { diff --git a/Sources/HummingbirdWebSocket/Client/WebSocketClientError.swift b/Sources/HummingbirdWSClient/WebSocketClientError.swift similarity index 73% rename from Sources/HummingbirdWebSocket/Client/WebSocketClientError.swift rename to Sources/HummingbirdWSClient/WebSocketClientError.swift index a22682d..cb7e763 100644 --- a/Sources/HummingbirdWebSocket/Client/WebSocketClientError.swift +++ b/Sources/HummingbirdWSClient/WebSocketClientError.swift @@ -27,5 +27,14 @@ public struct WebSocketClientError: Swift.Error, Equatable { /// Provided URL is invalid public static var invalidURL: Self { .init(.invalidURL) } /// WebSocket upgrade failed. - public static var webSocketUpgradeFailed: Self { .init(.invalidURL) } + public static var webSocketUpgradeFailed: Self { .init(.webSocketUpgradeFailed) } +} + +extension WebSocketClientError: CustomStringConvertible { + public var description: String { + switch self.value { + case .invalidURL: "Invalid URL" + case .webSocketUpgradeFailed: "WebSocket upgrade failed" + } + } } diff --git a/Sources/HummingbirdWSCompression/PerMessageDeflateExtension.swift b/Sources/HummingbirdWSCompression/PerMessageDeflateExtension.swift index ab8b7cb..89c44a4 100644 --- a/Sources/HummingbirdWSCompression/PerMessageDeflateExtension.swift +++ b/Sources/HummingbirdWSCompression/PerMessageDeflateExtension.swift @@ -13,7 +13,7 @@ //===----------------------------------------------------------------------===// import CompressNIO -import HummingbirdWebSocket +import HummingbirdWSCore import NIOCore import NIOWebSocket diff --git a/Sources/HummingbirdWebSocket/UnsafeTransfer.swift b/Sources/HummingbirdWSCore/UnsafeTransfer.swift similarity index 100% rename from Sources/HummingbirdWebSocket/UnsafeTransfer.swift rename to Sources/HummingbirdWSCore/UnsafeTransfer.swift diff --git a/Sources/HummingbirdWebSocket/WebSocketContext.swift b/Sources/HummingbirdWSCore/WebSocketContext.swift similarity index 87% rename from Sources/HummingbirdWebSocket/WebSocketContext.swift rename to Sources/HummingbirdWSCore/WebSocketContext.swift index 083532d..901c776 100644 --- a/Sources/HummingbirdWebSocket/WebSocketContext.swift +++ b/Sources/HummingbirdWSCore/WebSocketContext.swift @@ -27,4 +27,9 @@ public protocol WebSocketContext: Sendable { public struct BasicWebSocketContext: WebSocketContext { public let allocator: ByteBufferAllocator public let logger: Logger + + package init(allocator: ByteBufferAllocator, logger: Logger) { + self.allocator = allocator + self.logger = logger + } } diff --git a/Sources/HummingbirdWebSocket/WebSocketDataFrame.swift b/Sources/HummingbirdWSCore/WebSocketDataFrame.swift similarity index 100% rename from Sources/HummingbirdWebSocket/WebSocketDataFrame.swift rename to Sources/HummingbirdWSCore/WebSocketDataFrame.swift diff --git a/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift b/Sources/HummingbirdWSCore/WebSocketDataHandler.swift similarity index 88% rename from Sources/HummingbirdWebSocket/WebSocketDataHandler.swift rename to Sources/HummingbirdWSCore/WebSocketDataHandler.swift index 51daf02..07f0236 100644 --- a/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift +++ b/Sources/HummingbirdWSCore/WebSocketDataHandler.swift @@ -12,12 +12,6 @@ // //===----------------------------------------------------------------------===// -import AsyncAlgorithms -import HTTPTypes -import Logging -import NIOCore -import NIOWebSocket - /// Function that handles websocket data and text blocks public typealias WebSocketDataHandler = @Sendable (WebSocketInboundStream, WebSocketOutboundWriter, Context) async throws -> Void diff --git a/Sources/HummingbirdWebSocket/WebSocketExtension.swift b/Sources/HummingbirdWSCore/WebSocketExtension.swift similarity index 98% rename from Sources/HummingbirdWebSocket/WebSocketExtension.swift rename to Sources/HummingbirdWSCore/WebSocketExtension.swift index 9239511..9418d2f 100644 --- a/Sources/HummingbirdWebSocket/WebSocketExtension.swift +++ b/Sources/HummingbirdWSCore/WebSocketExtension.swift @@ -12,8 +12,8 @@ // //===----------------------------------------------------------------------===// +import Foundation import HTTPTypes -import NIOCore import NIOWebSocket /// Protocol for WebSocket extension @@ -117,7 +117,7 @@ public struct WebSocketExtensionHTTPParameters: Sendable, Equatable { } public let parameters: [String: Parameter] - let name: String + public let name: String /// initialise WebSocket extension parameters from string init?(from header: some StringProtocol) { @@ -156,7 +156,7 @@ public struct WebSocketExtensionHTTPParameters: Sendable, Equatable { extension WebSocketExtensionHTTPParameters { /// Initialiser used by tests - init(_ name: String, parameters: [String: Parameter]) { + package init(_ name: String, parameters: [String: Parameter]) { self.name = name self.parameters = parameters } diff --git a/Sources/HummingbirdWebSocket/WebSocketFrameSequence.swift b/Sources/HummingbirdWSCore/WebSocketFrameSequence.swift similarity index 100% rename from Sources/HummingbirdWebSocket/WebSocketFrameSequence.swift rename to Sources/HummingbirdWSCore/WebSocketFrameSequence.swift diff --git a/Sources/HummingbirdWebSocket/WebSocketHandler.swift b/Sources/HummingbirdWSCore/WebSocketHandler.swift similarity index 95% rename from Sources/HummingbirdWebSocket/WebSocketHandler.swift rename to Sources/HummingbirdWSCore/WebSocketHandler.swift index cb04ce3..a225178 100644 --- a/Sources/HummingbirdWebSocket/WebSocketHandler.swift +++ b/Sources/HummingbirdWSCore/WebSocketHandler.swift @@ -18,7 +18,7 @@ import NIOWebSocket import ServiceLifecycle /// WebSocket type -enum WebSocketType: Sendable { +package enum WebSocketType: Sendable { case client case server } @@ -45,14 +45,19 @@ public struct AutoPingSetup: Sendable { /// /// Manages ping, pong and close messages. Collates data and text messages into final frame /// and passes them onto the ``WebSocketDataHandler`` data handler setup by the user. -actor WebSocketHandler { +package actor WebSocketHandler { enum InternalError: Error { case close(WebSocketErrorCode) } - struct Configuration { + package struct Configuration { let extensions: [any WebSocketExtension] let autoPing: AutoPingSetup + + package init(extensions: [any WebSocketExtension], autoPing: AutoPingSetup) { + self.extensions = extensions + self.autoPing = autoPing + } } static let pingDataSize = 16 @@ -78,7 +83,7 @@ actor WebSocketHandler { self.closed = false } - static func handle( + package static func handle( type: WebSocketType, configuration: Configuration, asyncChannel: NIOAsyncChannel, @@ -242,7 +247,8 @@ actor WebSocketHandler { var buffer = self.context.allocator.buffer(capacity: 2) buffer.write(webSocketErrorCode: code) - try await self.outbound.write(.init(fin: true, opcode: .connectionClose, data: buffer)) + + try await self.write(frame: .init(fin: true, opcode: .connectionClose, data: buffer)) // Only server should initiate a connection close. Clients should wait for the // server to close the connection when it receives the WebSocket close packet // See https://www.rfc-editor.org/rfc/rfc6455#section-7.1.1 diff --git a/Sources/HummingbirdWebSocket/WebSocketInboundMessageStream.swift b/Sources/HummingbirdWSCore/WebSocketInboundMessageStream.swift similarity index 98% rename from Sources/HummingbirdWebSocket/WebSocketInboundMessageStream.swift rename to Sources/HummingbirdWSCore/WebSocketInboundMessageStream.swift index 5a9f754..6fa5aa1 100644 --- a/Sources/HummingbirdWebSocket/WebSocketInboundMessageStream.swift +++ b/Sources/HummingbirdWSCore/WebSocketInboundMessageStream.swift @@ -12,7 +12,6 @@ // //===----------------------------------------------------------------------===// -import NIOConcurrencyHelpers import NIOCore import NIOWebSocket diff --git a/Sources/HummingbirdWebSocket/WebSocketInboundStream.swift b/Sources/HummingbirdWSCore/WebSocketInboundStream.swift similarity index 84% rename from Sources/HummingbirdWebSocket/WebSocketInboundStream.swift rename to Sources/HummingbirdWSCore/WebSocketInboundStream.swift index 8f9ec06..b69d1c8 100644 --- a/Sources/HummingbirdWebSocket/WebSocketInboundStream.swift +++ b/Sources/HummingbirdWSCore/WebSocketInboundStream.swift @@ -62,8 +62,18 @@ public final class WebSocketInboundStream: AsyncSequence, Sendable { case .connectionClose: // we received a connection close. // send a close back if it hasn't already been send and exit - _ = try await self.handler.close(code: .normalClosure) - self.closed = true + var data = frame.unmaskedData + let dataSize = data.readableBytes + let closeCode = data.readWebSocketErrorCode() + if dataSize == 0 || closeCode != nil { + if case .unknown = closeCode { + _ = try await self.handler.close(code: .protocolError) + } else { + _ = try await self.handler.close(code: .normalClosure) + } + } else { + _ = try await self.handler.close(code: .protocolError) + } return nil case .ping: try await self.handler.onPing(frame) @@ -77,9 +87,12 @@ public final class WebSocketInboundStream: AsyncSequence, Sendable { } return .init(from: frame) default: - break + // if we receive a reserved opcode we should fail the connection + self.handler.context.logger.trace("Received reserved opcode", metadata: ["opcode": .stringConvertible(frame.opcode)]) + throw WebSocketHandler.InternalError.close(.protocolError) } } catch { + self.handler.context.logger.trace("Error: \(error)") // catch errors while processing websocket frames so responding close message // can be dealt with let errorCode = WebSocketErrorCode(error) diff --git a/Sources/HummingbirdWebSocket/WebSocketMessage.swift b/Sources/HummingbirdWSCore/WebSocketMessage.swift similarity index 100% rename from Sources/HummingbirdWebSocket/WebSocketMessage.swift rename to Sources/HummingbirdWSCore/WebSocketMessage.swift diff --git a/Sources/HummingbirdWebSocket/WebSocketOutboundWriter.swift b/Sources/HummingbirdWSCore/WebSocketOutboundWriter.swift similarity index 99% rename from Sources/HummingbirdWebSocket/WebSocketOutboundWriter.swift rename to Sources/HummingbirdWSCore/WebSocketOutboundWriter.swift index 75eadb7..79490fa 100644 --- a/Sources/HummingbirdWebSocket/WebSocketOutboundWriter.swift +++ b/Sources/HummingbirdWSCore/WebSocketOutboundWriter.swift @@ -29,7 +29,7 @@ public struct WebSocketOutboundWriter: Sendable { case custom(WebSocketFrame) } - let handler: WebSocketHandler + package let handler: WebSocketHandler /// Write WebSocket frame public func write(_ frame: OutboundFrame) async throws { diff --git a/Sources/HummingbirdWebSocket/Exports.swift b/Sources/HummingbirdWebSocket/Exports.swift new file mode 100644 index 0000000..6a282d2 --- /dev/null +++ b/Sources/HummingbirdWebSocket/Exports.swift @@ -0,0 +1,15 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2024 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +@_exported import HummingbirdWSCore diff --git a/Sources/HummingbirdWebSocket/Server/NIOWebSocketServerUpgrade+ext.swift b/Sources/HummingbirdWebSocket/NIOWebSocketServerUpgrade+ext.swift similarity index 100% rename from Sources/HummingbirdWebSocket/Server/NIOWebSocketServerUpgrade+ext.swift rename to Sources/HummingbirdWebSocket/NIOWebSocketServerUpgrade+ext.swift diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift b/Sources/HummingbirdWebSocket/WebSocketChannel.swift similarity index 99% rename from Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift rename to Sources/HummingbirdWebSocket/WebSocketChannel.swift index b273d5c..0b2ca4f 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift +++ b/Sources/HummingbirdWebSocket/WebSocketChannel.swift @@ -15,6 +15,7 @@ import HTTPTypes import Hummingbird import HummingbirdCore +import HummingbirdWSCore import Logging import NIOConcurrencyHelpers import NIOCore diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift b/Sources/HummingbirdWebSocket/WebSocketHTTPChannelBuilder.swift similarity index 98% rename from Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift rename to Sources/HummingbirdWebSocket/WebSocketHTTPChannelBuilder.swift index 5d1329b..38a4658 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift +++ b/Sources/HummingbirdWebSocket/WebSocketHTTPChannelBuilder.swift @@ -15,6 +15,7 @@ import HTTPTypes import Hummingbird import HummingbirdCore +import HummingbirdWSCore import Logging import NIOCore diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift b/Sources/HummingbirdWebSocket/WebSocketRouter.swift similarity index 99% rename from Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift rename to Sources/HummingbirdWebSocket/WebSocketRouter.swift index e5d165c..e8ed348 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift +++ b/Sources/HummingbirdWebSocket/WebSocketRouter.swift @@ -12,10 +12,10 @@ // //===----------------------------------------------------------------------===// -import Atomics import HTTPTypes import Hummingbird import HummingbirdCore +import HummingbirdWSCore import Logging import NIOConcurrencyHelpers import NIOCore diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketServerConfiguration.swift b/Sources/HummingbirdWebSocket/WebSocketServerConfiguration.swift similarity index 98% rename from Sources/HummingbirdWebSocket/Server/WebSocketServerConfiguration.swift rename to Sources/HummingbirdWebSocket/WebSocketServerConfiguration.swift index 5761cef..070bd0a 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketServerConfiguration.swift +++ b/Sources/HummingbirdWebSocket/WebSocketServerConfiguration.swift @@ -12,6 +12,8 @@ // //===----------------------------------------------------------------------===// +import HummingbirdWSCore + /// Configuration for a WebSocket server public struct WebSocketServerConfiguration: Sendable { /// Max websocket frame size that can be sent/received diff --git a/Tests/HummingbirdWebSocketTests/WebSocketExtensionTests.swift b/Tests/HummingbirdWebSocketTests/WebSocketExtensionTests.swift index 08ae903..b8bd881 100644 --- a/Tests/HummingbirdWebSocketTests/WebSocketExtensionTests.swift +++ b/Tests/HummingbirdWebSocketTests/WebSocketExtensionTests.swift @@ -14,8 +14,10 @@ import Hummingbird import HummingbirdCore -@testable import HummingbirdWebSocket +import HummingbirdWebSocket +import HummingbirdWSClient @testable import HummingbirdWSCompression +@testable import HummingbirdWSCore import Logging import NIOCore import NIOWebSocket diff --git a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift index c34551f..5b9bd44 100644 --- a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift +++ b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift @@ -18,6 +18,7 @@ import HummingbirdCore import HummingbirdTesting import HummingbirdTLS import HummingbirdWebSocket +import HummingbirdWSClient import Logging import NIOCore import NIOPosix diff --git a/scripts/autobahn-config/fuzzingserver.json b/scripts/autobahn-config/fuzzingserver.json new file mode 100644 index 0000000..4bc4fc5 --- /dev/null +++ b/scripts/autobahn-config/fuzzingserver.json @@ -0,0 +1,11 @@ +{ + "url": "ws://127.0.0.1:9001", + "outdir": "./reports/clients", + "cases": ["*"], + "exclude-cases": [ + "9.*", + "12.*", + "13.*" + ], + "exclude-agent-cases": {} +} diff --git a/scripts/autobahn.sh b/scripts/autobahn.sh new file mode 100755 index 0000000..0747186 --- /dev/null +++ b/scripts/autobahn.sh @@ -0,0 +1,6 @@ +docker run -it --rm \ + -v "${PWD}/scripts/autobahn-config:/config" \ + -v "${PWD}/.build/reports:/reports" \ + -p 9001:9001 \ + --name fuzzingserver \ + crossbario/autobahn-testsuite