Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add connection events #32

Merged
merged 2 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,4 @@ fastlane/screenshots
fastlane/test_output
/.swiftpm
.DS_Store
.vscode
1 change: 1 addition & 0 deletions .swift-format
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"XCTAssertNoThrow"
]
},
"prioritizeKeepingFunctionOutputTogether": true,
"respectsExistingLineBreaks": true,
"rules": {
"AllPublicDeclarationsHaveDocumentation": true,
Expand Down
6 changes: 4 additions & 2 deletions Sources/NatsSwift/Extensions/Data+Parser.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ extension Data {
return self.dropFirst(prefix.count)
}

func split(separator: Data, maxSplits: Int = .max, omittingEmptySubsequences: Bool = true)
func split(
separator: Data, maxSplits: Int = .max, omittingEmptySubsequences: Bool = true
)
-> [Data]
{
var chunks: [Data] = []
Expand Down Expand Up @@ -149,7 +151,7 @@ extension Data {
let headerParts = header.split(separator: ":")
if headerParts.count == 2 {
headers.append(
try! HeaderName(String(headerParts[0])),
try HeaderName(String(headerParts[0])),
HeaderValue(String(headerParts[1])))
} else {
logger.error("Error parsing header: \(header)")
Expand Down
32 changes: 32 additions & 0 deletions Sources/NatsSwift/NatsClient/NatsClient+Events.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
//
// NatsClient+Events.swift
//
// NatsSwift
//

import Foundation

extension Client {
@discardableResult
public func on(_ events: [NatsEventKind], _ handler: @escaping (NatsEvent) -> Void) -> String {
guard let connectionHandler = self.connectionHandler else {
return ""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does returning this mean?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nothing really, the alternative would be to throw an error here and I thought it might be better to avoid it... Normally this method returns a callback ID so that it can be closed with off().

}
return connectionHandler.addListeners(for: events, using: handler)
}

@discardableResult
public func on(_ event: NatsEventKind, _ handler: @escaping (NatsEvent) -> Void) -> String {
guard let connectionHandler = self.connectionHandler else {
return ""
}
return connectionHandler.addListeners(for: [event], using: handler)
}

func off(_ id: String) {
guard let connectionHandler = self.connectionHandler else {
return
}
connectionHandler.removeListener(id)
}
}
29 changes: 5 additions & 24 deletions Sources/NatsSwift/NatsClient/NatsClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -44,26 +44,12 @@ public struct Auth {
}

public class Client {
var urls: [URL] = []
var pingInteval: TimeInterval = 1.0
var reconnectWait: TimeInterval = 2.0
var maxReconnects: Int? = nil
var auth: Auth? = nil

internal let allocator = ByteBufferAllocator()
internal var buffer: ByteBuffer
internal var connectionHandler: ConnectionHandler?

internal init() {
self.buffer = allocator.buffer(capacity: 1024)
self.connectionHandler = ConnectionHandler(
inputBuffer: buffer,
urls: urls,
reconnectWait: reconnectWait,
maxReconnects: maxReconnects,
pingInterval: pingInteval,
auth: auth
)
}
}

Expand All @@ -73,17 +59,15 @@ extension Client {
//TODO(jrm): handle response
logger.debug("connect")
guard let connectionHandler = self.connectionHandler else {
throw NSError(
domain: "nats_swift", code: 1, userInfo: ["message": "empty connection handler"])
throw NatsClientError("internal error: empty connection handler")
}
try await connectionHandler.connect()
}

public func close() async throws {
logger.debug("close")
guard let connectionHandler = self.connectionHandler else {
throw NSError(
domain: "nats_swift", code: 1, userInfo: ["message": "empty connection handler"])
throw NatsClientError("internal error: empty connection handler")
}
try await connectionHandler.close()
}
Expand All @@ -93,26 +77,23 @@ extension Client {
) throws {
logger.debug("publish")
guard let connectionHandler = self.connectionHandler else {
throw NSError(
domain: "nats_swift", code: 1, userInfo: ["message": "empty connection handler"])
throw NatsClientError("internal error: empty connection handler")
}
try connectionHandler.write(operation: ClientOp.publish((subject, reply, payload, headers)))
}

public func flush() async throws {
logger.debug("flush")
guard let connectionHandler = self.connectionHandler else {
throw NSError(
domain: "nats_swift", code: 1, userInfo: ["message": "empty connection handler"])
throw NatsClientError("internal error: empty connection handler")
}
connectionHandler.channel?.flush()
}

public func subscribe(to subject: String) async throws -> Subscription {
logger.info("subscribe to subject \(subject)")
guard let connectionHandler = self.connectionHandler else {
throw NSError(
domain: "nats_swift", code: 1, userInfo: ["message": "empty connection handler"])
throw NatsClientError("internal error: empty connection handler")
}
return try await connectionHandler.subscribe(subject)

Expand Down
138 changes: 120 additions & 18 deletions Sources/NatsSwift/NatsConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class ConnectionHandler: ChannelInboundHandler {
internal let allocator = ByteBufferAllocator()
internal var inputBuffer: ByteBuffer

internal var eventHandlerStore: [NatsEventKind: [NatsEventHandler]] = [:]

// Connection options
internal var urls: [URL]
// nanoseconds representation of TimeInterval
Expand All @@ -40,7 +42,6 @@ class ConnectionHandler: ChannelInboundHandler {
inputBuffer.writeBuffer(&byteBuffer)
}

// TODO(pp): errors in parser should trigger context.fireErrorCaught() which invokes errorCaught() and invokes reconnect
func channelReadComplete(context: ChannelHandlerContext) {
var inputChunk = Data(buffer: inputBuffer)

Expand Down Expand Up @@ -97,6 +98,7 @@ class ConnectionHandler: ChannelInboundHandler {
} catch {
// TODO(pp): handle async error
logger.error("error sending pong: \(error)")
self.fire(.error(NatsClientError("error sending pong: \(error)")))
continue
}
case .pong:
Expand All @@ -112,6 +114,8 @@ class ConnectionHandler: ChannelInboundHandler {
{
inputBuffer.clear()
context.fireErrorCaught(err)
} else {
self.fire(.error(err))
}
// TODO(pp): handle auth errors here
case .message(let msg):
Expand Down Expand Up @@ -164,16 +168,15 @@ class ConnectionHandler: ChannelInboundHandler {
channel.pipeline.addHandler(self).whenComplete { result in
switch result {
case .success():
print("success")
logger.debug("success")
case .failure(let error):
print("error: \(error)")
logger.debug("error: \(error)")
}
}
return channel.eventLoop.makeSucceededFuture(())
}.connectTimeout(.seconds(5))
guard let url = self.urls.first, let host = url.host, let port = url.port else {
throw NSError(
domain: "nats_swift", code: 1, userInfo: ["message": "no url"])
throw NatsClientError("no url")
}
self.channel = try await bootstrap.connect(host: host, port: port).get()
} catch {
Expand All @@ -195,18 +198,13 @@ class ConnectionHandler: ChannelInboundHandler {
if let credentialsPath = auth.credentialsPath {
let credentials = try await URLSession.shared.data(from: credentialsPath).0
guard let jwt = JwtUtils.parseDecoratedJWT(contents: credentials) else {
throw NSError(
domain: "nats_swift", code: 1,
userInfo: ["message": "failed to extract JWT from credentials file"])
throw NatsClientError("failed to extract JWT from credentials file")
}
guard let nkey = JwtUtils.parseDecoratedNKey(contents: credentials) else {
throw NSError(
domain: "nats_swift", code: 1,
userInfo: ["message": "failed to extract NKEY from credentials file"])
throw NatsClientError("failed to extract NKEY from credentials file")
}
guard let nonce = self.serverInfo?.nonce else {
throw NSError(
domain: "nats_swift", code: 1, userInfo: ["message": "missing nonce"])
throw NatsClientError("missing nonce")
}
let keypair = try KeyPair(seed: String(data: nkey, encoding: .utf8)!)
let nonceData = nonce.data(using: .utf8)!
Expand All @@ -229,9 +227,10 @@ class ConnectionHandler: ChannelInboundHandler {
}
}
}
self.state = .pending
self.state = .connected
self.fire(.connected)
guard let channel = self.channel else {
throw NSError(domain: "nats_swift", code: 1, userInfo: ["message": "empty channel"])
throw NatsClientError("internal error: empty channel")
}
// Schedule the task to send a PING periodically
let pingInterval = TimeAmount.nanoseconds(Int64(self.pingInterval * 1_000_000_000))
Expand All @@ -246,6 +245,7 @@ class ConnectionHandler: ChannelInboundHandler {
func close() async throws {
self.state = .closed
try await disconnect()
self.fire(.closed)
try await self.group.shutdownGracefully()
}

Expand Down Expand Up @@ -280,15 +280,19 @@ class ConnectionHandler: ChannelInboundHandler {
func channelInactive(context: ChannelHandlerContext) {
logger.debug("TCP channel inactive")

if self.state == .pending {
if self.state == .connected {
handleDisconnect()
}
}

func errorCaught(context: ChannelHandlerContext, error: Error) {
// TODO(pp): implement Close() on the connection and call it here
logger.debug("Encountered error on the channel: \(error)")
context.close(promise: nil)
if let natsErr = error as? NatsError {
self.fire(.error(natsErr))
} else {
logger.error("unexpected error: \(error)")
}
if self.state == .pending {
handleDisconnect()
} else if self.state == .disconnected {
Expand All @@ -304,13 +308,17 @@ class ConnectionHandler: ChannelInboundHandler {
do {
try await self.disconnect()
promise.succeed()
} catch ChannelError.alreadyClosed {
// if the channel was already closed, no need to return error
promise.succeed()
} catch {
promise.fail(error)
}
}
promise.futureResult.whenComplete { result in
do {
try result.get()
self.fire(.disconnected)
} catch {
logger.error("Error closing connection: \(error)")
}
Expand All @@ -336,6 +344,17 @@ class ConnectionHandler: ChannelInboundHandler {
logger.debug("reconnected")
break
}
if self.state != .connected {
logger.error("could not reconnect; maxReconnects exceeded")
logger.debug("closing connection")
do {
try await self.close()
} catch {
logger.error("error closing connection: \(error)")
return
}
return
}
for (sid, sub) in self.subscriptions {
try write(operation: ClientOp.subscribe((sid, sub.subject, nil)))
}
Expand All @@ -362,7 +381,7 @@ class ConnectionHandler: ChannelInboundHandler {

func write(operation: ClientOp) throws {
guard let allocator = self.channel?.allocator else {
throw NSError(domain: "nats_swift", code: 1, userInfo: ["message": "no allocator"])
throw NatsClientError("internal error: no allocator")
}
let payload = try operation.asBytes(using: allocator)
try self.writeMessage(payload)
Expand All @@ -384,3 +403,86 @@ class ConnectionHandler: ChannelInboundHandler {
return sub
}
}

extension ConnectionHandler {

internal func fire(_ event: NatsEvent) {
let eventKind = event.kind()
guard let handlerStore = self.eventHandlerStore[eventKind] else { return }

handlerStore.forEach {
$0.handler(event)
}

}

internal func addListeners(
for events: [NatsEventKind], using handler: @escaping (NatsEvent) -> Void
) -> String {

let id = String.hash()

for event in events {
if self.eventHandlerStore[event] == nil {
self.eventHandlerStore[event] = []
}
self.eventHandlerStore[event]?.append(
NatsEventHandler(lid: id, handler: handler))
}

return id

}

internal func removeListener(_ id: String) {

for event in NatsEventKind.all {

let handlerStore = self.eventHandlerStore[event]
if let store = handlerStore {
self.eventHandlerStore[event] = store.filter { $0.listenerId != id }
}

}

}

}

/// Nats events
public enum NatsEventKind: String {
case connected = "connected"
case disconnected = "disconnected"
case closed = "closed"
case error = "error"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this variant also ontain the error?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just an event kind so that it can be easily registered. I went with 2 types - NatsEvent and NatsEventKind because NatsEvent may contain additional data for an event (like in the case of .error), but the kind is just a simple string for registering.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, I see it now.

static let all = [connected, disconnected, closed, error]
}

public enum NatsEvent {
case connected
case disconnected
case closed
case error(NatsError)

func kind() -> NatsEventKind {
switch self {
case .connected:
return .connected
case .disconnected:
return .disconnected
case .closed:
return .closed
case .error(_):
return .error
}
}
}

internal struct NatsEventHandler {
let listenerId: String
let handler: (NatsEvent) -> Void
init(lid: String, handler: @escaping (NatsEvent) -> Void) {
self.listenerId = lid
self.handler = handler
}
}
Loading
Loading