diff --git a/.gitignore b/.gitignore index 0d38404..bd396ea 100644 --- a/.gitignore +++ b/.gitignore @@ -69,3 +69,4 @@ fastlane/screenshots fastlane/test_output /.swiftpm .DS_Store +.vscode diff --git a/.swift-format b/.swift-format index 05e92a7..55e0330 100644 --- a/.swift-format +++ b/.swift-format @@ -13,6 +13,7 @@ "XCTAssertNoThrow" ] }, + "prioritizeKeepingFunctionOutputTogether": true, "respectsExistingLineBreaks": true, "rules": { "AllPublicDeclarationsHaveDocumentation": true, diff --git a/Sources/NatsSwift/Extensions/Data+Parser.swift b/Sources/NatsSwift/Extensions/Data+Parser.swift index d360ebe..6e41c27 100644 --- a/Sources/NatsSwift/Extensions/Data+Parser.swift +++ b/Sources/NatsSwift/Extensions/Data+Parser.swift @@ -17,7 +17,9 @@ extension Data { return self.dropFirst(prefix.count) } - func split(separator: Data, maxSplits: Int = .max, omittingEmptySubsequences: Bool = true) + func split( + separator: Data, maxSplits: Int = .max, omittingEmptySubsequences: Bool = true + ) -> [Data] { var chunks: [Data] = [] @@ -149,7 +151,7 @@ extension Data { let headerParts = header.split(separator: ":") if headerParts.count == 2 { headers.append( - try! HeaderName(String(headerParts[0])), + try HeaderName(String(headerParts[0])), HeaderValue(String(headerParts[1]))) } else { logger.error("Error parsing header: \(header)") diff --git a/Sources/NatsSwift/NatsClient/NatsClient+Events.swift b/Sources/NatsSwift/NatsClient/NatsClient+Events.swift new file mode 100644 index 0000000..f551ad9 --- /dev/null +++ b/Sources/NatsSwift/NatsClient/NatsClient+Events.swift @@ -0,0 +1,32 @@ +// +// NatsClient+Events.swift +// +// NatsSwift +// + +import Foundation + +extension Client { + @discardableResult + public func on(_ events: [NatsEventKind], _ handler: @escaping (NatsEvent) -> Void) -> String { + guard let connectionHandler = self.connectionHandler else { + return "" + } + return connectionHandler.addListeners(for: events, using: handler) + } + + @discardableResult + public func on(_ event: NatsEventKind, _ handler: @escaping (NatsEvent) -> Void) -> String { + guard let connectionHandler = self.connectionHandler else { + return "" + } + return connectionHandler.addListeners(for: [event], using: handler) + } + + func off(_ id: String) { + guard let connectionHandler = self.connectionHandler else { + return + } + connectionHandler.removeListener(id) + } +} diff --git a/Sources/NatsSwift/NatsClient/NatsClient.swift b/Sources/NatsSwift/NatsClient/NatsClient.swift index 15232f6..939005b 100755 --- a/Sources/NatsSwift/NatsClient/NatsClient.swift +++ b/Sources/NatsSwift/NatsClient/NatsClient.swift @@ -44,26 +44,12 @@ public struct Auth { } public class Client { - var urls: [URL] = [] - var pingInteval: TimeInterval = 1.0 - var reconnectWait: TimeInterval = 2.0 - var maxReconnects: Int? = nil - var auth: Auth? = nil - internal let allocator = ByteBufferAllocator() internal var buffer: ByteBuffer internal var connectionHandler: ConnectionHandler? internal init() { self.buffer = allocator.buffer(capacity: 1024) - self.connectionHandler = ConnectionHandler( - inputBuffer: buffer, - urls: urls, - reconnectWait: reconnectWait, - maxReconnects: maxReconnects, - pingInterval: pingInteval, - auth: auth - ) } } @@ -73,8 +59,7 @@ extension Client { //TODO(jrm): handle response logger.debug("connect") guard let connectionHandler = self.connectionHandler else { - throw NSError( - domain: "nats_swift", code: 1, userInfo: ["message": "empty connection handler"]) + throw NatsClientError("internal error: empty connection handler") } try await connectionHandler.connect() } @@ -82,8 +67,7 @@ extension Client { public func close() async throws { logger.debug("close") guard let connectionHandler = self.connectionHandler else { - throw NSError( - domain: "nats_swift", code: 1, userInfo: ["message": "empty connection handler"]) + throw NatsClientError("internal error: empty connection handler") } try await connectionHandler.close() } @@ -93,8 +77,7 @@ extension Client { ) throws { logger.debug("publish") guard let connectionHandler = self.connectionHandler else { - throw NSError( - domain: "nats_swift", code: 1, userInfo: ["message": "empty connection handler"]) + throw NatsClientError("internal error: empty connection handler") } try connectionHandler.write(operation: ClientOp.publish((subject, reply, payload, headers))) } @@ -102,8 +85,7 @@ extension Client { public func flush() async throws { logger.debug("flush") guard let connectionHandler = self.connectionHandler else { - throw NSError( - domain: "nats_swift", code: 1, userInfo: ["message": "empty connection handler"]) + throw NatsClientError("internal error: empty connection handler") } connectionHandler.channel?.flush() } @@ -111,8 +93,7 @@ extension Client { public func subscribe(to subject: String) async throws -> Subscription { logger.info("subscribe to subject \(subject)") guard let connectionHandler = self.connectionHandler else { - throw NSError( - domain: "nats_swift", code: 1, userInfo: ["message": "empty connection handler"]) + throw NatsClientError("internal error: empty connection handler") } return try await connectionHandler.subscribe(subject) diff --git a/Sources/NatsSwift/NatsConnection.swift b/Sources/NatsSwift/NatsConnection.swift index 784acd8..96d4d4d 100644 --- a/Sources/NatsSwift/NatsConnection.swift +++ b/Sources/NatsSwift/NatsConnection.swift @@ -17,6 +17,8 @@ class ConnectionHandler: ChannelInboundHandler { internal let allocator = ByteBufferAllocator() internal var inputBuffer: ByteBuffer + internal var eventHandlerStore: [NatsEventKind: [NatsEventHandler]] = [:] + // Connection options internal var urls: [URL] // nanoseconds representation of TimeInterval @@ -40,7 +42,6 @@ class ConnectionHandler: ChannelInboundHandler { inputBuffer.writeBuffer(&byteBuffer) } - // TODO(pp): errors in parser should trigger context.fireErrorCaught() which invokes errorCaught() and invokes reconnect func channelReadComplete(context: ChannelHandlerContext) { var inputChunk = Data(buffer: inputBuffer) @@ -97,6 +98,7 @@ class ConnectionHandler: ChannelInboundHandler { } catch { // TODO(pp): handle async error logger.error("error sending pong: \(error)") + self.fire(.error(NatsClientError("error sending pong: \(error)"))) continue } case .pong: @@ -112,6 +114,8 @@ class ConnectionHandler: ChannelInboundHandler { { inputBuffer.clear() context.fireErrorCaught(err) + } else { + self.fire(.error(err)) } // TODO(pp): handle auth errors here case .message(let msg): @@ -164,16 +168,15 @@ class ConnectionHandler: ChannelInboundHandler { channel.pipeline.addHandler(self).whenComplete { result in switch result { case .success(): - print("success") + logger.debug("success") case .failure(let error): - print("error: \(error)") + logger.debug("error: \(error)") } } return channel.eventLoop.makeSucceededFuture(()) }.connectTimeout(.seconds(5)) guard let url = self.urls.first, let host = url.host, let port = url.port else { - throw NSError( - domain: "nats_swift", code: 1, userInfo: ["message": "no url"]) + throw NatsClientError("no url") } self.channel = try await bootstrap.connect(host: host, port: port).get() } catch { @@ -195,18 +198,13 @@ class ConnectionHandler: ChannelInboundHandler { if let credentialsPath = auth.credentialsPath { let credentials = try await URLSession.shared.data(from: credentialsPath).0 guard let jwt = JwtUtils.parseDecoratedJWT(contents: credentials) else { - throw NSError( - domain: "nats_swift", code: 1, - userInfo: ["message": "failed to extract JWT from credentials file"]) + throw NatsClientError("failed to extract JWT from credentials file") } guard let nkey = JwtUtils.parseDecoratedNKey(contents: credentials) else { - throw NSError( - domain: "nats_swift", code: 1, - userInfo: ["message": "failed to extract NKEY from credentials file"]) + throw NatsClientError("failed to extract NKEY from credentials file") } guard let nonce = self.serverInfo?.nonce else { - throw NSError( - domain: "nats_swift", code: 1, userInfo: ["message": "missing nonce"]) + throw NatsClientError("missing nonce") } let keypair = try KeyPair(seed: String(data: nkey, encoding: .utf8)!) let nonceData = nonce.data(using: .utf8)! @@ -229,9 +227,10 @@ class ConnectionHandler: ChannelInboundHandler { } } } - self.state = .pending + self.state = .connected + self.fire(.connected) guard let channel = self.channel else { - throw NSError(domain: "nats_swift", code: 1, userInfo: ["message": "empty channel"]) + throw NatsClientError("internal error: empty channel") } // Schedule the task to send a PING periodically let pingInterval = TimeAmount.nanoseconds(Int64(self.pingInterval * 1_000_000_000)) @@ -246,6 +245,7 @@ class ConnectionHandler: ChannelInboundHandler { func close() async throws { self.state = .closed try await disconnect() + self.fire(.closed) try await self.group.shutdownGracefully() } @@ -280,15 +280,19 @@ class ConnectionHandler: ChannelInboundHandler { func channelInactive(context: ChannelHandlerContext) { logger.debug("TCP channel inactive") - if self.state == .pending { + if self.state == .connected { handleDisconnect() } } func errorCaught(context: ChannelHandlerContext, error: Error) { - // TODO(pp): implement Close() on the connection and call it here logger.debug("Encountered error on the channel: \(error)") context.close(promise: nil) + if let natsErr = error as? NatsError { + self.fire(.error(natsErr)) + } else { + logger.error("unexpected error: \(error)") + } if self.state == .pending { handleDisconnect() } else if self.state == .disconnected { @@ -304,6 +308,9 @@ class ConnectionHandler: ChannelInboundHandler { do { try await self.disconnect() promise.succeed() + } catch ChannelError.alreadyClosed { + // if the channel was already closed, no need to return error + promise.succeed() } catch { promise.fail(error) } @@ -311,6 +318,7 @@ class ConnectionHandler: ChannelInboundHandler { promise.futureResult.whenComplete { result in do { try result.get() + self.fire(.disconnected) } catch { logger.error("Error closing connection: \(error)") } @@ -336,6 +344,17 @@ class ConnectionHandler: ChannelInboundHandler { logger.debug("reconnected") break } + if self.state != .connected { + logger.error("could not reconnect; maxReconnects exceeded") + logger.debug("closing connection") + do { + try await self.close() + } catch { + logger.error("error closing connection: \(error)") + return + } + return + } for (sid, sub) in self.subscriptions { try write(operation: ClientOp.subscribe((sid, sub.subject, nil))) } @@ -362,7 +381,7 @@ class ConnectionHandler: ChannelInboundHandler { func write(operation: ClientOp) throws { guard let allocator = self.channel?.allocator else { - throw NSError(domain: "nats_swift", code: 1, userInfo: ["message": "no allocator"]) + throw NatsClientError("internal error: no allocator") } let payload = try operation.asBytes(using: allocator) try self.writeMessage(payload) @@ -384,3 +403,86 @@ class ConnectionHandler: ChannelInboundHandler { return sub } } + +extension ConnectionHandler { + + internal func fire(_ event: NatsEvent) { + let eventKind = event.kind() + guard let handlerStore = self.eventHandlerStore[eventKind] else { return } + + handlerStore.forEach { + $0.handler(event) + } + + } + + internal func addListeners( + for events: [NatsEventKind], using handler: @escaping (NatsEvent) -> Void + ) -> String { + + let id = String.hash() + + for event in events { + if self.eventHandlerStore[event] == nil { + self.eventHandlerStore[event] = [] + } + self.eventHandlerStore[event]?.append( + NatsEventHandler(lid: id, handler: handler)) + } + + return id + + } + + internal func removeListener(_ id: String) { + + for event in NatsEventKind.all { + + let handlerStore = self.eventHandlerStore[event] + if let store = handlerStore { + self.eventHandlerStore[event] = store.filter { $0.listenerId != id } + } + + } + + } + +} + +/// Nats events +public enum NatsEventKind: String { + case connected = "connected" + case disconnected = "disconnected" + case closed = "closed" + case error = "error" + static let all = [connected, disconnected, closed, error] +} + +public enum NatsEvent { + case connected + case disconnected + case closed + case error(NatsError) + + func kind() -> NatsEventKind { + switch self { + case .connected: + return .connected + case .disconnected: + return .disconnected + case .closed: + return .closed + case .error(_): + return .error + } + } +} + +internal struct NatsEventHandler { + let listenerId: String + let handler: (NatsEvent) -> Void + init(lid: String, handler: @escaping (NatsEvent) -> Void) { + self.listenerId = lid + self.handler = handler + } +} diff --git a/Sources/NatsSwift/NatsError.swift b/Sources/NatsSwift/NatsError.swift index 7e0846d..c09f1dd 100644 --- a/Sources/NatsSwift/NatsError.swift +++ b/Sources/NatsSwift/NatsError.swift @@ -4,11 +4,11 @@ // // TODO(pp): For now we're using error implementation from old codebase, consider changing -protocol NatsError: Error { - var description: String { get set } +public protocol NatsError: Error { + var description: String { get } } -struct NatsConnectionError: NatsError { +struct NatsServerError: NatsError { var description: String var normalizedError: String { return description.trimWhitespacesAndApostrophes().lowercased() @@ -18,21 +18,14 @@ struct NatsConnectionError: NatsError { } } -struct NatsSubscribeError: NatsError { +struct NatsParserError: NatsError { var description: String init(_ description: String) { self.description = description } } -struct NatsPublishError: NatsError { - var description: String - init(_ description: String) { - self.description = description - } -} - -struct NatsTimeoutError: NatsError { +struct NatsClientError: NatsError { var description: String init(_ description: String) { self.description = description diff --git a/Sources/NatsSwift/NatsHeaders.swift b/Sources/NatsSwift/NatsHeaders.swift index 9647fdb..9c8d348 100644 --- a/Sources/NatsSwift/NatsHeaders.swift +++ b/Sources/NatsSwift/NatsHeaders.swift @@ -13,19 +13,7 @@ public struct HeaderValue: Equatable, CustomStringConvertible { } } -// Errors for parsing HeaderValue and HeaderName -public enum ParseHeaderValueError: Error, CustomStringConvertible { - case invalidCharacter - - public var description: String { - switch self { - case .invalidCharacter: - return "Invalid character found in header value (value cannot contain '\\r' or '\\n')" - } - } -} - -public enum ParseHeaderNameError: Error, CustomStringConvertible { +public enum ParseHeaderNameError: NatsError { case invalidCharacter public var description: String { diff --git a/Sources/NatsSwift/NatsProto.swift b/Sources/NatsSwift/NatsProto.swift index 28e06be..792426f 100644 --- a/Sources/NatsSwift/NatsProto.swift +++ b/Sources/NatsSwift/NatsProto.swift @@ -37,15 +37,14 @@ enum ServerOp { case info(ServerInfo) case ping case pong - case error(NatsConnectionError) + case error(NatsServerError) case message(MessageInbound) case hMessage(HMessageInbound) static func parse(from msg: Data) throws -> ServerOp { guard msg.count > 2 else { - throw NSError( - domain: "nats_swift", code: 1, - userInfo: ["message": "unable to parse inbound message: \(message)"]) + throw NatsParserError( + "unable to parse inbound message: \(String(data: msg, encoding: .utf8)!)") } let msgType = msg.getMessageType() switch msgType { @@ -59,17 +58,15 @@ enum ServerOp { return ok case .error: if let errMsg = msg.removePrefix(Data(NatsOperation.error.rawBytes)).toString() { - return error(NatsConnectionError(errMsg)) + return error(NatsServerError(errMsg)) } - return error(NatsConnectionError("unexpected error")) + return error(NatsServerError("unexpected error")) case .ping: return ping case .pong: return pong default: - throw NSError( - domain: "nats_swift", code: 1, - userInfo: ["message": "Unknown server op: \(message)"]) + throw NatsParserError("unknown server op: \(String(data: msg, encoding: .utf8)!)") } } } @@ -98,9 +95,7 @@ internal struct HMessageInbound: Equatable { subjectData, sidData, replyData, lengthHeaders, lengthData in let subject = String(decoding: subjectData, as: UTF8.self) guard let sid = UInt64(String(decoding: sidData, as: UTF8.self)) else { - throw NSError( - domain: "nats_swift", code: 1, - userInfo: ["message": "unable to parse subscription ID as number"]) + throw NatsParserError("unable to parse subscription ID as number") } var replySubject: String? = nil if let replyData = replyData { @@ -124,9 +119,7 @@ internal struct HMessageInbound: Equatable { protoComponents[0], protoComponents[1], protoComponents[2], protoComponents[3], protoComponents[4]) default: - throw NSError( - domain: "nats_swift", code: 1, - userInfo: ["message": "unable to parse inbound message header"]) + throw NatsParserError("unable to parse inbound message header") } return msg } @@ -154,9 +147,7 @@ internal struct MessageInbound: Equatable { subjectData, sidData, replyData, lengthData in let subject = String(decoding: subjectData, as: UTF8.self) guard let sid = UInt64(String(decoding: sidData, as: UTF8.self)) else { - throw NSError( - domain: "nats_swift", code: 1, - userInfo: ["message": "unable to parse subscription ID as number"]) + throw NatsParserError("unable to parse subscription ID as number") } var replySubject: String? = nil if let replyData = replyData { @@ -175,9 +166,7 @@ internal struct MessageInbound: Equatable { msg = try parseArgs( protoComponents[0], protoComponents[1], protoComponents[2], protoComponents[3]) default: - throw NSError( - domain: "nats_swift", code: 1, - userInfo: ["message": "unable to parse inbound message header"]) + throw NatsParserError("unable to parse inbound message header") } return msg } diff --git a/Tests/NatsSwiftTests/Integration/EventsTests.swift b/Tests/NatsSwiftTests/Integration/EventsTests.swift new file mode 100644 index 0000000..9254ea7 --- /dev/null +++ b/Tests/NatsSwiftTests/Integration/EventsTests.swift @@ -0,0 +1,95 @@ +// +// EventsTests.swift +// NatsSwiftTests +// + +import Foundation +import Logging +import XCTest + +@testable import NatsSwift + +class TestNatsEvents: XCTestCase { + + static var allTests = [ + ("testClientConnectedEvent", testClientConnectedEvent), + ("testClientConnectedEvent", testClientConnectedEvent), + ("testClientConnectedEvent", testClientConnectedEvent), + ] + + var natsServer = NatsServer() + + override func tearDown() { + super.tearDown() + natsServer.stop() + } + + func testClientConnectedEvent() async throws { + natsServer.start() + logger.logLevel = .debug + + let client = ClientOptions().url(URL(string: natsServer.clientURL)!).build() + + let expectation = XCTestExpectation( + description: "client was not notified of connection established event") + client.on(.connected) { event in + XCTAssertEqual(event.kind(), NatsEventKind.connected) + expectation.fulfill() + } + try await client.connect() + + await fulfillment(of: [expectation], timeout: 1.0) + try await client.close() + } + + func testClientClosedEvent() async throws { + natsServer.start() + logger.logLevel = .debug + + let client = ClientOptions().url(URL(string: natsServer.clientURL)!).build() + + let expectation = XCTestExpectation( + description: "client was not notified of connection closed event") + client.on(.closed) { event in + XCTAssertEqual(event.kind(), NatsEventKind.closed) + expectation.fulfill() + } + try await client.connect() + + try await client.close() + await fulfillment(of: [expectation], timeout: 1.0) + } + + func testClientReconnectEvent() async throws { + natsServer.start() + let port = natsServer.port! + logger.logLevel = .debug + + let client = ClientOptions() + .url(URL(string: natsServer.clientURL)!) + .reconnectWait(1) + .build() + + let disconnected = XCTestExpectation( + description: "client was not notified of disconnection event") + client.on(.disconnected) { event in + XCTAssertEqual(event.kind(), NatsEventKind.disconnected) + disconnected.fulfill() + } + try await client.connect() + natsServer.stop() + + let reconnected = XCTestExpectation( + description: "client was not notified of reconnection event") + client.on(.connected) { event in + XCTAssertEqual(event.kind(), NatsEventKind.connected) + reconnected.fulfill() + } + await fulfillment(of: [disconnected], timeout: 5.0) + + natsServer.start(port: port) + await fulfillment(of: [reconnected], timeout: 5.0) + + try await client.close() + } +}