Skip to content

Commit

Permalink
Add rtt method
Browse files Browse the repository at this point in the history
  • Loading branch information
mtmk authored Feb 26, 2024
1 parent 9aa1350 commit eaef1db
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 1 deletion.
19 changes: 19 additions & 0 deletions Sources/NatsSwift/ConcurrentQueue.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import NIOConcurrencyHelpers

internal class ConcurrentQueue<T> {
private var elements: [T] = []
private let lock = NIOLock()

func enqueue(_ element: T) {
lock.lock()
defer { lock.unlock() }
elements.append(element)
}

func dequeue() -> T? {
lock.lock()
defer { lock.unlock() }
guard !elements.isEmpty else { return nil }
return elements.removeFirst()
}
}
9 changes: 9 additions & 0 deletions Sources/NatsSwift/NatsClient/NatsClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,13 @@ extension Client {
return try await connectionHandler.subscribe(subject)

}

public func rtt() async throws -> Duration {
guard let connectionHandler = self.connectionHandler else {
throw NatsClientError("internal error: empty connection handler")
}
let ping = RttCommand.makeFrom(channel: connectionHandler.channel)
connectionHandler.sendPing(ping)
return try await ping.getRoundTripTime ()
}
}
6 changes: 5 additions & 1 deletion Sources/NatsSwift/NatsConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class ConnectionHandler: ChannelInboundHandler {
private var serverInfoContinuation: CheckedContinuation<ServerInfo, Error>?
private var connectionEstablishedContinuation: CheckedContinuation<Void, Error>?

private let pingQueue = ConcurrentQueue<RttCommand>()

init(
inputBuffer: ByteBuffer, urls: [URL], reconnectWait: TimeInterval, maxReconnects: Int?,
retainServersOrder: Bool,
Expand Down Expand Up @@ -143,6 +145,7 @@ class ConnectionHandler: ChannelInboundHandler {
case .pong:
logger.debug("pong")
self.outstandingPings.store(0, ordering: AtomicStoreOrdering.relaxed)
self.pingQueue.dequeue()?.setRoundTripTime()
case .error(let err):
logger.debug("error \(err)")

Expand Down Expand Up @@ -414,7 +417,7 @@ class ConnectionHandler: ChannelInboundHandler {
try await self.channel?.close().get()
}

private func sendPing() {
internal func sendPing(_ rttCommand: RttCommand? = nil) {
let pingsOut = self.outstandingPings.wrappingIncrementThenLoad(
ordering: AtomicUpdateOrdering.relaxed)
if pingsOut > 2 {
Expand All @@ -423,6 +426,7 @@ class ConnectionHandler: ChannelInboundHandler {
}
let ping = ClientOp.ping
do {
self.pingQueue.enqueue(rttCommand ?? RttCommand.makeFrom(channel: self.channel))
try self.write(operation: ping)
logger.debug("sent ping: \(pingsOut)")
} catch {
Expand Down
24 changes: 24 additions & 0 deletions Sources/NatsSwift/RttCommand.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import NIOCore

internal class RttCommand {
let startTime = ContinuousClock().now
let promise: EventLoopPromise<Duration>?

static func makeFrom(channel: Channel?) -> RttCommand {
RttCommand(promise: channel?.eventLoop.makePromise(of: Duration.self))
}

private init(promise: EventLoopPromise<Duration>?) {
self.promise = promise
}

func setRoundTripTime() {
let now: ContinuousClock.Instant = ContinuousClock().now
let rtt: Duration = now - startTime
promise?.succeed(rtt)
}

func getRoundTripTime() async throws -> Duration {
try await promise?.futureResult.get() ?? Duration.zero
}
}
15 changes: 15 additions & 0 deletions Tests/NatsSwiftTests/Integration/ConnectionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import XCTest
class CoreNatsTests: XCTestCase {

static var allTests = [
("testRtt", testRtt),
("testPublish", testPublish),
("testPublishWithReply", testPublishWithReply),
("testSubscribe", testSubscribe),
Expand All @@ -30,6 +31,20 @@ class CoreNatsTests: XCTestCase {
natsServer.stop()
}

func testRtt() async throws {
natsServer.start()
logger.logLevel = .debug
let client = ClientOptions()
.url(URL(string: natsServer.clientURL)!)
.build()
try await client.connect()

let rtt: Duration = try await client.rtt()
XCTAssertGreaterThan(rtt, Duration.zero, "should have RTT")

try await client.close()
}

func testPublish() async throws {
natsServer.start()
logger.logLevel = .debug
Expand Down

0 comments on commit eaef1db

Please sign in to comment.