Skip to content

Commit

Permalink
Add connection events (#32)
Browse files Browse the repository at this point in the history
Signed-off-by: Piotr Piotrowski <piotr@synadia.com>
  • Loading branch information
piotrpio committed Feb 12, 2024
1 parent 67e12c6 commit b194ec7
Show file tree
Hide file tree
Showing 10 changed files with 274 additions and 90 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,4 @@ fastlane/screenshots
fastlane/test_output
/.swiftpm
.DS_Store
.vscode
1 change: 1 addition & 0 deletions .swift-format
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"XCTAssertNoThrow"
]
},
"prioritizeKeepingFunctionOutputTogether": true,
"respectsExistingLineBreaks": true,
"rules": {
"AllPublicDeclarationsHaveDocumentation": true,
Expand Down
6 changes: 4 additions & 2 deletions Sources/NatsSwift/Extensions/Data+Parser.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ extension Data {
return self.dropFirst(prefix.count)
}

func split(separator: Data, maxSplits: Int = .max, omittingEmptySubsequences: Bool = true)
func split(
separator: Data, maxSplits: Int = .max, omittingEmptySubsequences: Bool = true
)
-> [Data]
{
var chunks: [Data] = []
Expand Down Expand Up @@ -149,7 +151,7 @@ extension Data {
let headerParts = header.split(separator: ":")
if headerParts.count == 2 {
headers.append(
try! HeaderName(String(headerParts[0])),
try HeaderName(String(headerParts[0])),
HeaderValue(String(headerParts[1])))
} else {
logger.error("Error parsing header: \(header)")
Expand Down
32 changes: 32 additions & 0 deletions Sources/NatsSwift/NatsClient/NatsClient+Events.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
//
// NatsClient+Events.swift
//
// NatsSwift
//

import Foundation

extension Client {
@discardableResult
public func on(_ events: [NatsEventKind], _ handler: @escaping (NatsEvent) -> Void) -> String {
guard let connectionHandler = self.connectionHandler else {
return ""
}
return connectionHandler.addListeners(for: events, using: handler)
}

@discardableResult
public func on(_ event: NatsEventKind, _ handler: @escaping (NatsEvent) -> Void) -> String {
guard let connectionHandler = self.connectionHandler else {
return ""
}
return connectionHandler.addListeners(for: [event], using: handler)
}

func off(_ id: String) {
guard let connectionHandler = self.connectionHandler else {
return
}
connectionHandler.removeListener(id)
}
}
29 changes: 5 additions & 24 deletions Sources/NatsSwift/NatsClient/NatsClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -44,26 +44,12 @@ public struct Auth {
}

public class Client {
var urls: [URL] = []
var pingInteval: TimeInterval = 1.0
var reconnectWait: TimeInterval = 2.0
var maxReconnects: Int? = nil
var auth: Auth? = nil

internal let allocator = ByteBufferAllocator()
internal var buffer: ByteBuffer
internal var connectionHandler: ConnectionHandler?

internal init() {
self.buffer = allocator.buffer(capacity: 1024)
self.connectionHandler = ConnectionHandler(
inputBuffer: buffer,
urls: urls,
reconnectWait: reconnectWait,
maxReconnects: maxReconnects,
pingInterval: pingInteval,
auth: auth
)
}
}

Expand All @@ -73,17 +59,15 @@ extension Client {
//TODO(jrm): handle response
logger.debug("connect")
guard let connectionHandler = self.connectionHandler else {
throw NSError(
domain: "nats_swift", code: 1, userInfo: ["message": "empty connection handler"])
throw NatsClientError("internal error: empty connection handler")
}
try await connectionHandler.connect()
}

public func close() async throws {
logger.debug("close")
guard let connectionHandler = self.connectionHandler else {
throw NSError(
domain: "nats_swift", code: 1, userInfo: ["message": "empty connection handler"])
throw NatsClientError("internal error: empty connection handler")
}
try await connectionHandler.close()
}
Expand All @@ -93,26 +77,23 @@ extension Client {
) throws {
logger.debug("publish")
guard let connectionHandler = self.connectionHandler else {
throw NSError(
domain: "nats_swift", code: 1, userInfo: ["message": "empty connection handler"])
throw NatsClientError("internal error: empty connection handler")
}
try connectionHandler.write(operation: ClientOp.publish((subject, reply, payload, headers)))
}

public func flush() async throws {
logger.debug("flush")
guard let connectionHandler = self.connectionHandler else {
throw NSError(
domain: "nats_swift", code: 1, userInfo: ["message": "empty connection handler"])
throw NatsClientError("internal error: empty connection handler")
}
connectionHandler.channel?.flush()
}

public func subscribe(to subject: String) async throws -> Subscription {
logger.info("subscribe to subject \(subject)")
guard let connectionHandler = self.connectionHandler else {
throw NSError(
domain: "nats_swift", code: 1, userInfo: ["message": "empty connection handler"])
throw NatsClientError("internal error: empty connection handler")
}
return try await connectionHandler.subscribe(subject)

Expand Down
138 changes: 120 additions & 18 deletions Sources/NatsSwift/NatsConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class ConnectionHandler: ChannelInboundHandler {
internal let allocator = ByteBufferAllocator()
internal var inputBuffer: ByteBuffer

internal var eventHandlerStore: [NatsEventKind: [NatsEventHandler]] = [:]

// Connection options
internal var urls: [URL]
// nanoseconds representation of TimeInterval
Expand All @@ -40,7 +42,6 @@ class ConnectionHandler: ChannelInboundHandler {
inputBuffer.writeBuffer(&byteBuffer)
}

// TODO(pp): errors in parser should trigger context.fireErrorCaught() which invokes errorCaught() and invokes reconnect
func channelReadComplete(context: ChannelHandlerContext) {
var inputChunk = Data(buffer: inputBuffer)

Expand Down Expand Up @@ -97,6 +98,7 @@ class ConnectionHandler: ChannelInboundHandler {
} catch {
// TODO(pp): handle async error
logger.error("error sending pong: \(error)")
self.fire(.error(NatsClientError("error sending pong: \(error)")))
continue
}
case .pong:
Expand All @@ -112,6 +114,8 @@ class ConnectionHandler: ChannelInboundHandler {
{
inputBuffer.clear()
context.fireErrorCaught(err)
} else {
self.fire(.error(err))
}
// TODO(pp): handle auth errors here
case .message(let msg):
Expand Down Expand Up @@ -164,16 +168,15 @@ class ConnectionHandler: ChannelInboundHandler {
channel.pipeline.addHandler(self).whenComplete { result in
switch result {
case .success():
print("success")
logger.debug("success")
case .failure(let error):
print("error: \(error)")
logger.debug("error: \(error)")
}
}
return channel.eventLoop.makeSucceededFuture(())
}.connectTimeout(.seconds(5))
guard let url = self.urls.first, let host = url.host, let port = url.port else {
throw NSError(
domain: "nats_swift", code: 1, userInfo: ["message": "no url"])
throw NatsClientError("no url")
}
self.channel = try await bootstrap.connect(host: host, port: port).get()
} catch {
Expand All @@ -195,18 +198,13 @@ class ConnectionHandler: ChannelInboundHandler {
if let credentialsPath = auth.credentialsPath {
let credentials = try await URLSession.shared.data(from: credentialsPath).0
guard let jwt = JwtUtils.parseDecoratedJWT(contents: credentials) else {
throw NSError(
domain: "nats_swift", code: 1,
userInfo: ["message": "failed to extract JWT from credentials file"])
throw NatsClientError("failed to extract JWT from credentials file")
}
guard let nkey = JwtUtils.parseDecoratedNKey(contents: credentials) else {
throw NSError(
domain: "nats_swift", code: 1,
userInfo: ["message": "failed to extract NKEY from credentials file"])
throw NatsClientError("failed to extract NKEY from credentials file")
}
guard let nonce = self.serverInfo?.nonce else {
throw NSError(
domain: "nats_swift", code: 1, userInfo: ["message": "missing nonce"])
throw NatsClientError("missing nonce")
}
let keypair = try KeyPair(seed: String(data: nkey, encoding: .utf8)!)
let nonceData = nonce.data(using: .utf8)!
Expand All @@ -229,9 +227,10 @@ class ConnectionHandler: ChannelInboundHandler {
}
}
}
self.state = .pending
self.state = .connected
self.fire(.connected)
guard let channel = self.channel else {
throw NSError(domain: "nats_swift", code: 1, userInfo: ["message": "empty channel"])
throw NatsClientError("internal error: empty channel")
}
// Schedule the task to send a PING periodically
let pingInterval = TimeAmount.nanoseconds(Int64(self.pingInterval * 1_000_000_000))
Expand All @@ -246,6 +245,7 @@ class ConnectionHandler: ChannelInboundHandler {
func close() async throws {
self.state = .closed
try await disconnect()
self.fire(.closed)
try await self.group.shutdownGracefully()
}

Expand Down Expand Up @@ -280,15 +280,19 @@ class ConnectionHandler: ChannelInboundHandler {
func channelInactive(context: ChannelHandlerContext) {
logger.debug("TCP channel inactive")

if self.state == .pending {
if self.state == .connected {
handleDisconnect()
}
}

func errorCaught(context: ChannelHandlerContext, error: Error) {
// TODO(pp): implement Close() on the connection and call it here
logger.debug("Encountered error on the channel: \(error)")
context.close(promise: nil)
if let natsErr = error as? NatsError {
self.fire(.error(natsErr))
} else {
logger.error("unexpected error: \(error)")
}
if self.state == .pending {
handleDisconnect()
} else if self.state == .disconnected {
Expand All @@ -304,13 +308,17 @@ class ConnectionHandler: ChannelInboundHandler {
do {
try await self.disconnect()
promise.succeed()
} catch ChannelError.alreadyClosed {
// if the channel was already closed, no need to return error
promise.succeed()
} catch {
promise.fail(error)
}
}
promise.futureResult.whenComplete { result in
do {
try result.get()
self.fire(.disconnected)
} catch {
logger.error("Error closing connection: \(error)")
}
Expand All @@ -336,6 +344,17 @@ class ConnectionHandler: ChannelInboundHandler {
logger.debug("reconnected")
break
}
if self.state != .connected {
logger.error("could not reconnect; maxReconnects exceeded")
logger.debug("closing connection")
do {
try await self.close()
} catch {
logger.error("error closing connection: \(error)")
return
}
return
}
for (sid, sub) in self.subscriptions {
try write(operation: ClientOp.subscribe((sid, sub.subject, nil)))
}
Expand All @@ -362,7 +381,7 @@ class ConnectionHandler: ChannelInboundHandler {

func write(operation: ClientOp) throws {
guard let allocator = self.channel?.allocator else {
throw NSError(domain: "nats_swift", code: 1, userInfo: ["message": "no allocator"])
throw NatsClientError("internal error: no allocator")
}
let payload = try operation.asBytes(using: allocator)
try self.writeMessage(payload)
Expand All @@ -384,3 +403,86 @@ class ConnectionHandler: ChannelInboundHandler {
return sub
}
}

extension ConnectionHandler {

internal func fire(_ event: NatsEvent) {
let eventKind = event.kind()
guard let handlerStore = self.eventHandlerStore[eventKind] else { return }

handlerStore.forEach {
$0.handler(event)
}

}

internal func addListeners(
for events: [NatsEventKind], using handler: @escaping (NatsEvent) -> Void
) -> String {

let id = String.hash()

for event in events {
if self.eventHandlerStore[event] == nil {
self.eventHandlerStore[event] = []
}
self.eventHandlerStore[event]?.append(
NatsEventHandler(lid: id, handler: handler))
}

return id

}

internal func removeListener(_ id: String) {

for event in NatsEventKind.all {

let handlerStore = self.eventHandlerStore[event]
if let store = handlerStore {
self.eventHandlerStore[event] = store.filter { $0.listenerId != id }
}

}

}

}

/// Nats events
public enum NatsEventKind: String {
case connected = "connected"
case disconnected = "disconnected"
case closed = "closed"
case error = "error"
static let all = [connected, disconnected, closed, error]
}

public enum NatsEvent {
case connected
case disconnected
case closed
case error(NatsError)

func kind() -> NatsEventKind {
switch self {
case .connected:
return .connected
case .disconnected:
return .disconnected
case .closed:
return .closed
case .error(_):
return .error
}
}
}

internal struct NatsEventHandler {
let listenerId: String
let handler: (NatsEvent) -> Void
init(lid: String, handler: @escaping (NatsEvent) -> Void) {
self.listenerId = lid
self.handler = handler
}
}
Loading

0 comments on commit b194ec7

Please sign in to comment.