Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tls support #33

Merged
merged 8 commits into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import PackageDescription
let package = Package(
name: "NatsSwift",
platforms: [
.macOS(.v10_15)
.macOS(.v13)
],
products: [
.library(name: "NatsSwift", targets: ["NatsSwift"])
Expand All @@ -14,12 +14,14 @@ let package = Package(
.package(url: "https://github.com/apple/swift-nio.git", from: "2.0.0"),
.package(url: "https://github.com/apple/swift-log.git", from: "1.4.2"),
.package(url: "https://github.com/nats-io/nkeys.swift.git", from: "0.1.1"),
.package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.0.0"),
],
targets: [
.target(
name: "NatsSwift",
dependencies: [
.product(name: "NIO", package: "swift-nio"),
.product(name: "NIOSSL", package: "swift-nio-ssl"),
.product(name: "Logging", package: "swift-log"),
.product(name: "NIOFoundationCompat", package: "swift-nio"),
.product(name: "NKeys", package: "nkeys.swift"),
Expand Down
33 changes: 32 additions & 1 deletion Sources/NatsSwift/NatsClient/NatsClientOptions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ public class ClientOptions {
private var reconnectWait: TimeInterval = 2.0
private var maxReconnects: Int = 60
private var auth: Auth? = nil
private var withTls: Bool = false
private var tlsFirst: Bool = false
private var rootCertificate: URL? = nil
private var clientCertificate: URL? = nil
private var clientKey: URL? = nil

public init() {}

Expand Down Expand Up @@ -70,6 +75,27 @@ public class ClientOptions {
return self
}

public func requireTls() -> ClientOptions {
self.withTls = true
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and maybe for consistency rename then that one too?

return self
}

public func withTlsFirst() -> ClientOptions {
self.tlsFirst = true
return self
}

public func rootCertificates(_ rootCertificate: URL) -> ClientOptions {
self.rootCertificate = rootCertificate
return self
}

public func clientCertificate(_ clientCertificate: URL, _ clientKey: URL) -> ClientOptions {
self.clientCertificate = clientCertificate
self.clientKey = clientKey
return self
}

public func build() -> Client {
let client = Client()
client.connectionHandler = ConnectionHandler(
Expand All @@ -78,7 +104,12 @@ public class ClientOptions {
reconnectWait: reconnectWait,
maxReconnects: maxReconnects,
pingInterval: pingInterval,
auth: auth
auth: auth,
requireTls: withTls,
tlsFirst: tlsFirst,
clientCertificate: clientCertificate,
clientKey: clientKey,
rootCertificate: rootCertificate
)

return client
Expand Down
105 changes: 95 additions & 10 deletions Sources/NatsSwift/NatsConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import Dispatch
import Foundation
import NIO
import NIOFoundationCompat
import NIOSSL
import NKeys

class ConnectionHandler: ChannelInboundHandler {
Expand All @@ -25,6 +26,11 @@ class ConnectionHandler: ChannelInboundHandler {
internal let reconnectWait: UInt64
internal let maxReconnects: Int?
internal let pingInterval: TimeInterval
internal let requireTls: Bool
internal let tlsFirst: Bool
internal var rootCertificate: URL?
internal var clientCertificate: URL?
internal var clientKey: URL?

typealias InboundIn = ByteBuffer
internal var state: NatsState = .pending
Expand Down Expand Up @@ -96,7 +102,6 @@ class ConnectionHandler: ChannelInboundHandler {
do {
try self.write(operation: .pong)
} catch {
// TODO(pp): handle async error
logger.error("error sending pong: \(error)")
self.fire(.error(NatsClientError("error sending pong: \(error)")))
continue
Expand All @@ -117,7 +122,6 @@ class ConnectionHandler: ChannelInboundHandler {
} else {
self.fire(.error(err))
}
// TODO(pp): handle auth errors here
case .message(let msg):
self.handleIncomingMessage(msg)
case .hMessage(let msg):
Expand All @@ -133,7 +137,9 @@ class ConnectionHandler: ChannelInboundHandler {
}
init(
inputBuffer: ByteBuffer, urls: [URL], reconnectWait: TimeInterval, maxReconnects: Int?,
pingInterval: TimeInterval, auth: Auth?
pingInterval: TimeInterval, auth: Auth?, requireTls: Bool, tlsFirst: Bool,
clientCertificate: URL?, clientKey: URL?,
rootCertificate: URL?
) {
self.inputBuffer = self.allocator.buffer(capacity: 1024)
self.urls = urls
Expand All @@ -144,6 +150,11 @@ class ConnectionHandler: ChannelInboundHandler {
self.maxReconnects = maxReconnects
self.auth = auth
self.pingInterval = pingInterval
self.requireTls = requireTls
self.tlsFirst = tlsFirst
self.clientCertificate = clientCertificate
self.clientKey = clientKey
self.rootCertificate = rootCertificate
}

internal var group: MultiThreadedEventLoopGroup
Expand All @@ -164,16 +175,58 @@ class ConnectionHandler: ChannelInboundHandler {
value: 1
)
.channelInitializer { channel in
//Fixme(jrm): do not ignore error from addHandler future.
channel.pipeline.addHandler(self).whenComplete { result in
switch result {
case .success():
if self.requireTls && self.tlsFirst {
var tlsConfiguration = TLSConfiguration.makeClientConfiguration()
do {
if let rootCertificate = self.rootCertificate {
tlsConfiguration.trustRoots = .file(rootCertificate.path)
}
if let clientCertificate = self.clientCertificate,
let clientKey = self.clientKey
{
// Load the client certificate from the PEM file
let certificate = try NIOSSLCertificate.fromPEMFile(
clientCertificate.path
).map { NIOSSLCertificateSource.certificate($0) }
tlsConfiguration.certificateChain = certificate

// Load the private key from the file
let privateKey = try NIOSSLPrivateKey(
file: clientKey.path, format: .pem)
tlsConfiguration.privateKey = .privateKey(privateKey)
}
let sslContext = try NIOSSLContext(
configuration: tlsConfiguration)
// FIXME(jrm): Consider better way to pick hostname.
let sslHandler = try NIOSSLClientHandler(
context: sslContext, serverHostname: self.urls[0].host()!)
//Fixme(jrm): do not ignore error from addHandler future.
channel.pipeline.addHandler(sslHandler).flatMap { _ in
channel.pipeline.addHandler(self)
}.whenComplete { result in
switch result {
case .success():
print("success")
case .failure(let error):
print("error: \(error)")
}
}
return channel.eventLoop.makeSucceededFuture(())
} catch {
return channel.eventLoop.makeFailedFuture(error)
}
} else {
//Fixme(jrm): do not ignore error from addHandler future.
channel.pipeline.addHandler(self).whenComplete { result in
switch result {
case .success():
logger.debug("success")
case .failure(let error):
case .failure(let error):
logger.debug("error: \(error)")
}
}
return channel.eventLoop.makeSucceededFuture(())
}
return channel.eventLoop.makeSucceededFuture(())
}.connectTimeout(.seconds(5))
guard let url = self.urls.first, let host = url.host, let port = url.port else {
throw NatsClientError("no url")
Expand All @@ -186,7 +239,28 @@ class ConnectionHandler: ChannelInboundHandler {
// Wait for the first message after sending the connect request
}
self.serverInfo = info
// TODO(jrm): Add rest of auth here.
if (info.tlsRequired ?? false || self.requireTls) && !self.tlsFirst {
var tlsConfiguration = TLSConfiguration.makeClientConfiguration()
if let rootCertificate = self.rootCertificate {
tlsConfiguration.trustRoots = .file(rootCertificate.path)
}
if let clientCertificate = self.clientCertificate, let clientKey = self.clientKey {
// Load the client certificate from the PEM file
let certificate = try NIOSSLCertificate.fromPEMFile(clientCertificate.path).map {
NIOSSLCertificateSource.certificate($0)
}
tlsConfiguration.certificateChain = certificate

// Load the private key from the file
let privateKey = try NIOSSLPrivateKey(file: clientKey.path, format: .pem)
tlsConfiguration.privateKey = .privateKey(privateKey)
}
let sslContext = try NIOSSLContext(configuration: tlsConfiguration)
// FIXME(jrm): Consider better way to pick hostname.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we pick the hostname in Go?
We discussed it a bit - but maybe we should pick the hostname of the URL we are currently trying to connect to?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, but currently we always pick the first url anyway. I'll add url randomization in the next PR and use that one.

let sslHandler = try NIOSSLClientHandler(
context: sslContext, serverHostname: self.urls[0].host()!)
try await self.channel?.pipeline.addHandler(sslHandler, position: .first)
}

var initialConnect = ConnectInfo(
verbose: false, pedantic: false, userJwt: nil, nkey: "", name: "", echo: true,
Expand Down Expand Up @@ -293,6 +367,17 @@ class ConnectionHandler: ChannelInboundHandler {
} else {
logger.error("unexpected error: \(error)")
}
if let continuation = self.serverInfoContinuation {
self.serverInfoContinuation = nil
continuation.resume(throwing: error)
return
}

if let continuation = self.connectionEstablishedContinuation {
self.connectionEstablishedContinuation = nil
continuation.resume(throwing: error)
return
}
if self.state == .pending {
handleDisconnect()
} else if self.state == .disconnected {
Expand Down
112 changes: 112 additions & 0 deletions Tests/NatsSwiftTests/Integration/ConnectionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ class CoreNatsTests: XCTestCase {
("testReconnect", testReconnect),
("testUsernameAndPassword", testUsernameAndPassword),
("testTokenAuth", testTokenAuth),
("testCredentialsAuth", testCredentialsAuth),
("testTlsWithDemoServer", testTlsWithDemoServer),
("testMutualTls", testMutualTls),
("testTlsFirst", testTlsFirst),
]
var natsServer = NatsServer()

Expand Down Expand Up @@ -236,6 +240,7 @@ class CoreNatsTests: XCTestCase {
testsDir
.appendingPathComponent("Integration/Resources/jwt.conf", isDirectory: false)
natsServer.start(cfg: resourceURL.path)
print("server started with file: \(resourceURL.path)")

let credsURL = testsDir.appendingPathComponent(
"Integration/Resources/TestUser.creds", isDirectory: false)
Expand All @@ -248,6 +253,113 @@ class CoreNatsTests: XCTestCase {
try client.publish("data".data(using: .utf8)!, subject: "foo")
let message = await subscribe.next()
print("message: \(message!.subject)")
}

func testTlsWithDemoServer() async throws {
logger.logLevel = .debug

natsServer.start()
let client = ClientOptions()
.url(URL(string: "tls://demo.nats.io:4222")!)
.requireTls()
.build()

try await client.connect()
print("connected")
try client.publish("msg".data(using: .utf8)!, subject: "test")
try await client.flush()
}

func testMutualTls() async throws {
logger.logLevel = .debug
let currentFile = URL(fileURLWithPath: #file)
// Navigate up to the Tests directory
let testsDir = currentFile.deletingLastPathComponent().deletingLastPathComponent()
// Construct the path to the resource
let resourceURL =
testsDir
.appendingPathComponent("Integration/Resources/tls.conf", isDirectory: false)
natsServer.start(cfg: resourceURL.path)
let certsURL = testsDir.appendingPathComponent(
"Integration/Resources/certs/rootCA.pem", isDirectory: false)
let client = ClientOptions()
.url(URL(string: natsServer.clientURL)!)
.requireTls()
.rootCertificates(certsURL)
.clientCertificate(
testsDir.appendingPathComponent(
"Integration/Resources/certs/client-cert.pem", isDirectory: false),
testsDir.appendingPathComponent(
"Integration/Resources/certs/client-key.pem", isDirectory: false)
)
.build()
try await client.connect()
try client.publish("msg".data(using: .utf8)!, subject: "test")
try await client.flush()
_ = try await client.subscribe(to: "test")
XCTAssertNotNil(client, "Client should not be nil")
}

func testTlsFirst() async throws {
logger.logLevel = .debug
let currentFile = URL(fileURLWithPath: #file)
// Navigate up to the Tests directory
let testsDir = currentFile.deletingLastPathComponent().deletingLastPathComponent()
// Construct the path to the resource
let resourceURL =
testsDir
.appendingPathComponent("Integration/Resources/tls_first.conf", isDirectory: false)
natsServer.start(cfg: resourceURL.path)
let certsURL = testsDir.appendingPathComponent(
"Integration/Resources/certs/rootCA.pem", isDirectory: false)
let client = ClientOptions()
.url(URL(string: natsServer.clientURL)!)
.requireTls()
.rootCertificates(certsURL)
.clientCertificate(
testsDir.appendingPathComponent(
"Integration/Resources/certs/client-cert.pem", isDirectory: false),
testsDir.appendingPathComponent(
"Integration/Resources/certs/client-key.pem", isDirectory: false)
)
.withTlsFirst()
.build()
try await client.connect()
try client.publish("msg".data(using: .utf8)!, subject: "test")
try await client.flush()
_ = try await client.subscribe(to: "test")
XCTAssertNotNil(client, "Client should not be nil")
}

func testInvalidCertificate() async throws {
logger.logLevel = .debug
let currentFile = URL(fileURLWithPath: #file)
// Navigate up to the Tests directory
let testsDir = currentFile.deletingLastPathComponent().deletingLastPathComponent()
// Construct the path to the resource
let resourceURL =
testsDir
.appendingPathComponent("Integration/Resources/tls.conf", isDirectory: false)
natsServer.start(cfg: resourceURL.path)
let certsURL = testsDir.appendingPathComponent(
"Integration/Resources/certs/rootCA.pem", isDirectory: false)
let client = ClientOptions()
.url(URL(string: natsServer.clientURL)!)
.requireTls()
.rootCertificates(certsURL)
.clientCertificate(
testsDir.appendingPathComponent(
"Integration/Resources/certs/client-cert-invalid.pem", isDirectory: false),
testsDir.appendingPathComponent(
"Integration/Resources/certs/client-key-invalid.pem", isDirectory: false)
)
.build()
do {
try await client.connect()
} catch {
return
}
XCTFail("Expected error from connect")
}
}

Loading
Loading