Skip to content

Commit

Permalink
Add batch write buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
mtmk committed Mar 18, 2024
1 parent 15bddc9 commit 381395a
Show file tree
Hide file tree
Showing 10 changed files with 254 additions and 48 deletions.
1 change: 1 addition & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ let package = Package(
]
),

.executableTarget(name: "bench", dependencies: ["Nats"]),
.executableTarget(name: "Benchmark", dependencies: ["Nats"]),
.executableTarget(name: "BenchmarkPubSub", dependencies: ["Nats"]),
.executableTarget(name: "BenchmarkSub", dependencies: ["Nats"]),
Expand Down
4 changes: 2 additions & 2 deletions Sources/Benchmark/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ let data = "foo".data(using: .utf8)!
// Warmup
print("Warming up...")
for _ in 0..<10_000 {
try! nats.publish(data, subject: "foo")
try! await nats.publish(data, subject: "foo")
}
print("Starting benchmark...")
let now = DispatchTime.now()
let numMsgs = 10_000_000
for _ in 0..<numMsgs {
try! nats.publish(data, subject: "foo")
try! await nats.publish(data, subject: "foo")
}
try! await nats.flush()
let elapsed = DispatchTime.now().uptimeNanoseconds - now.uptimeNanoseconds
Expand Down
4 changes: 2 additions & 2 deletions Sources/BenchmarkPubSub/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ let data = "foo".data(using: .utf8)!
// Warmup
print("Warming up...")
for _ in 0..<10_000 {
try! nats.publish(data, subject: "foo")
try! await nats.publish(data, subject: "foo")
}
print("Starting benchmark...")
let now = DispatchTime.now()
Expand Down Expand Up @@ -64,7 +64,7 @@ try await withThrowingTaskGroup(of: Void.self) { group in
hm.append(try! HeaderName("foo"), HeaderValue("baz"))
hm.insert(try! HeaderName("another"), HeaderValue("one"))
for i in 0..<numMsgs {
try nats.publish("\(i)".data(using: .utf8)!, subject: "foo", headers: hm)
try await nats.publish("\(i)".data(using: .utf8)!, subject: "foo", headers: hm)
if i % 1000 == 0 {
print("published \(i) msgs")
}
Expand Down
4 changes: 2 additions & 2 deletions Sources/Example/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,12 @@ for i in 1...3 {
headers.append(try! HeaderName("X-Example"), HeaderValue("example value"))

if let data = "data\(i)".data(using: .utf8) {
try nats.publish(data, subject: "foo.\(i)", headers: headers)
try await nats.publish(data, subject: "foo.\(i)", headers: headers)
}
}

print("signalling done...")
try nats.publish(Data(), subject: "foo.done")
try await nats.publish(Data(), subject: "foo.done")

try await loop.value

Expand Down
134 changes: 134 additions & 0 deletions Sources/Nats/BatchBuffer.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// Copyright 2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation
import NIO
import NIOConcurrencyHelpers

class BatchBuffer {
private let batchSize: Int
private var buffer: ByteBuffer
private let channel: Channel
private let lock = NIOLock()
private var waitingPromises: [EventLoopPromise<Void>] = []
private var isWriteInProgress: Bool = false

init(channel: Channel, batchSize: Int = 16*1024) {
self.batchSize = batchSize
self.buffer = channel.allocator.buffer(capacity: batchSize)
self.channel = channel
}

func write<Bytes: Sequence>(_ data: Bytes) async throws where Bytes.Element == UInt8 {
#if SWIFT_NATS_BATCH_BUFFER_DISABLED
let b = channel.allocator.buffer(bytes: data)
try await channel.writeAndFlush(b)
#else
// Batch writes and if we have more than the batch size
// already in the buffer await until buffer is flushed
// to handle any back pressure
try await withCheckedThrowingContinuation { continuation in
self.lock.withLock {
guard self.buffer.readableBytes < self.batchSize else {
let promise = self.channel.eventLoop.makePromise(of: Void.self)
promise.futureResult.whenComplete { result in
switch result {
case .success:
// we should be in lock when completed here
self.buffer.writeBytes(data)
self.flushWhenIdle()
continuation.resume()
case .failure(let error):
continuation.resume(throwing: error)
}
}
waitingPromises.append(promise)
return
}

self.buffer.writeBytes(data)
continuation.resume()
}

flushWhenIdle()
}
#endif
}

func clear() {
lock.withLock {
self.buffer.clear()
}
}

private func flushWhenIdle() {
channel.eventLoop.execute {

// We have to use lock/unlock calls rather than the withLock
// since we need writeBuffer reference
self.lock.lock()

// The idea is to keep writing to the buffer while a writeAndFlush() is
// in progress, so we can batch as many messages as possible.
guard !self.isWriteInProgress else {
self.lock.unlock()
return
}

// We need a separate write buffer so we can free the message buffer for more
// messages to be collected.
guard let writeBuffer = self.getWriteBuffer() else {
self.lock.unlock()
return
}

self.isWriteInProgress = true

self.lock.unlock()

let writePromise = self.channel.eventLoop.makePromise(of: Void.self)
writePromise.futureResult.whenComplete { result in
self.lock.withLock {
self.isWriteInProgress = false
switch result {
case .success:
self.waitingPromises.forEach { $0.succeed(()) }
self.waitingPromises.removeAll()
case .failure(let error):
self.waitingPromises.forEach { $0.fail(error) }
self.waitingPromises.removeAll()
}

// Check if there are any pending flushes
if self.buffer.readableBytes > 0 {
self.flushWhenIdle()
}
}
}

self.channel.writeAndFlush(writeBuffer, promise: writePromise)
}
}

private func getWriteBuffer() -> ByteBuffer? {
guard buffer.readableBytes > 0 else {
return nil
}

var writeBuffer = channel.allocator.buffer(capacity: buffer.readableBytes)
writeBuffer.writeBytes(buffer.readableBytesView)
buffer.clear()

return writeBuffer
}
}
8 changes: 4 additions & 4 deletions Sources/Nats/NatsClient/NatsClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,12 @@ extension Client {

public func publish(
_ payload: Data, subject: String, reply: String? = nil, headers: HeaderMap? = nil
) throws {
) async throws {
logger.debug("publish")
guard let connectionHandler = self.connectionHandler else {
throw NatsClientError("internal error: empty connection handler")
}
try connectionHandler.write(operation: ClientOp.publish((subject, reply, payload, headers)))
try await connectionHandler.write(operation: ClientOp.publish((subject, reply, payload, headers)))
}

public func request(
Expand All @@ -106,7 +106,7 @@ extension Client {
let inbox = "_INBOX.\(nextNuid())"

let response = try await connectionHandler.subscribe(inbox)
try connectionHandler.write(operation: ClientOp.publish((to, inbox, payload, headers)))
try await connectionHandler.write(operation: ClientOp.publish((to, inbox, payload, headers)))
connectionHandler.channel?.flush()
if let message = await response.makeAsyncIterator().next() {
if let status = message.status, status == StatusCode.noResponders {
Expand Down Expand Up @@ -139,7 +139,7 @@ extension Client {
throw NatsClientError("internal error: empty connection handler")
}
let ping = RttCommand.makeFrom(channel: connectionHandler.channel)
connectionHandler.sendPing(ping)
await connectionHandler.sendPing(ping)
return try await ping.getRoundTripTime()
}
}
50 changes: 28 additions & 22 deletions Sources/Nats/NatsConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class ConnectionHandler: ChannelInboundHandler {
private var connectionEstablishedContinuation: CheckedContinuation<Void, Error>?

private let pingQueue = ConcurrentQueue<RttCommand>()
private var batchBuffer: BatchBuffer?

init(
inputBuffer: ByteBuffer, urls: [URL], reconnectWait: TimeInterval, maxReconnects: Int?,
Expand Down Expand Up @@ -143,12 +144,13 @@ class ConnectionHandler: ChannelInboundHandler {
switch op {
case .ping:
logger.debug("ping")
do {
try self.write(operation: .pong)
} catch {
logger.error("error sending pong: \(error)")
self.fire(.error(NatsClientError("error sending pong: \(error)")))
continue
Task {
do {
try await self.write(operation: .pong)
} catch {
logger.error("error sending pong: \(error)")
self.fire(.error(NatsClientError("error sending pong: \(error)")))
}
}
case .pong:
logger.debug("pong")
Expand Down Expand Up @@ -215,7 +217,7 @@ class ConnectionHandler: ChannelInboundHandler {
// if there are more reconnect attempts than the number of servers,
// we are after the initial connect, so sleep between servers
let shouldSleep = self.reconnectAttempts >= self.urls.count
print(self.reconnectAttempts)
logger.debug("reconnect attempts: \(self.reconnectAttempts)")
for s in servers {
if let maxReconnects {
if reconnectAttempts >= maxReconnects {
Expand Down Expand Up @@ -255,7 +257,7 @@ class ConnectionHandler: ChannelInboundHandler {
self.pingTask = channel.eventLoop.scheduleRepeatedTask(
initialDelay: pingInterval, delay: pingInterval
) { [weak self] task in
self?.sendPing()
Task { [weak self] in await self?.sendPing() }
}
logger.debug("connection established")
return
Expand All @@ -271,6 +273,10 @@ class ConnectionHandler: ChannelInboundHandler {
throw NatsConfigError("no url")
}
self.channel = try await bootstrap.connect(host: host, port: port).get()
guard let channel = self.channel else {
throw NatsClientError("internal error: empty channel")
}
self.batchBuffer = BatchBuffer(channel: channel)
} catch {
continuation.resume(throwing: error)
}
Expand Down Expand Up @@ -345,8 +351,8 @@ class ConnectionHandler: ChannelInboundHandler {
self.connectionEstablishedContinuation = continuation
Task.detached {
do {
try self.write(operation: ClientOp.connect(connect))
try self.write(operation: ClientOp.ping)
try await self.write(operation: ClientOp.connect(connect))
try await self.write(operation: ClientOp.ping)
self.channel?.flush()
} catch {
continuation.resume(throwing: error)
Expand Down Expand Up @@ -426,7 +432,7 @@ class ConnectionHandler: ChannelInboundHandler {
try await self.channel?.close().get()
}

internal func sendPing(_ rttCommand: RttCommand? = nil) {
internal func sendPing(_ rttCommand: RttCommand? = nil) async {
let pingsOut = self.outstandingPings.wrappingIncrementThenLoad(
ordering: AtomicUpdateOrdering.relaxed)
if pingsOut > 2 {
Expand All @@ -436,7 +442,7 @@ class ConnectionHandler: ChannelInboundHandler {
let ping = ClientOp.ping
do {
self.pingQueue.enqueue(rttCommand ?? RttCommand.makeFrom(channel: self.channel))
try self.write(operation: ping)
try await self.write(operation: ping)
logger.debug("sent ping: \(pingsOut)")
} catch {
logger.error("Unable to send ping: \(error)")
Expand Down Expand Up @@ -537,30 +543,30 @@ class ConnectionHandler: ChannelInboundHandler {
return
}
for (sid, sub) in self.subscriptions {
try write(operation: ClientOp.subscribe((sid, sub.subject, nil)))
try await write(operation: ClientOp.subscribe((sid, sub.subject, nil)))
}
}
}

func write(operation: ClientOp) throws {
func write(operation: ClientOp) async throws {
guard let allocator = self.channel?.allocator else {
throw NatsClientError("internal error: no allocator")
}
let payload = try operation.asBytes(using: allocator)
try self.writeMessage(payload)
try await self.writeMessage(payload)
}

func writeMessage(_ message: ByteBuffer) throws {
_ = channel?.write(message)
if channel?.isWritable ?? true {
channel?.flush()
func writeMessage(_ message: ByteBuffer) async throws {
guard let buffer = self.batchBuffer else {
throw NatsClientError("not connected")
}
try await buffer.write(message.readableBytesView)
}

internal func subscribe(_ subject: String) async throws -> Subscription {
let sid = self.subscriptionCounter.wrappingIncrementThenLoad(
ordering: AtomicUpdateOrdering.relaxed)
try write(operation: ClientOp.subscribe((sid, subject, nil)))
try await write(operation: ClientOp.subscribe((sid, subject, nil)))
let sub = Subscription(sid: sid, subject: subject, conn: self)
self.subscriptions[sid] = sub
return sub
Expand All @@ -570,12 +576,12 @@ class ConnectionHandler: ChannelInboundHandler {
if let max, sub.delivered < max {
// if max is set and the sub has not yet reached it, send unsub with max set
// and do not remove the sub from connection
try write(operation: ClientOp.unsubscribe((sid: sub.sid, max: max)))
try await write(operation: ClientOp.unsubscribe((sid: sub.sid, max: max)))
sub.max = max
} else {
// if max is not set or the subscription received at least as meny
// messages as max, send unsub command without max and remove sub from connection
try write(operation: ClientOp.unsubscribe((sid: sub.sid, max: nil)))
try await write(operation: ClientOp.unsubscribe((sid: sub.sid, max: nil)))
self.removeSub(sub: sub)
}
}
Expand Down
Loading

0 comments on commit 381395a

Please sign in to comment.