-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add tls support #33
Add tls support #33
Changes from all commits
e835516
08288e1
0bc491a
96c41cf
4590450
b8e6329
8ca251b
cf2703f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ import Dispatch | |
import Foundation | ||
import NIO | ||
import NIOFoundationCompat | ||
import NIOSSL | ||
import NKeys | ||
|
||
class ConnectionHandler: ChannelInboundHandler { | ||
|
@@ -25,6 +26,11 @@ class ConnectionHandler: ChannelInboundHandler { | |
internal let reconnectWait: UInt64 | ||
internal let maxReconnects: Int? | ||
internal let pingInterval: TimeInterval | ||
internal let requireTls: Bool | ||
internal let tlsFirst: Bool | ||
internal var rootCertificate: URL? | ||
internal var clientCertificate: URL? | ||
internal var clientKey: URL? | ||
|
||
typealias InboundIn = ByteBuffer | ||
internal var state: NatsState = .pending | ||
|
@@ -96,7 +102,6 @@ class ConnectionHandler: ChannelInboundHandler { | |
do { | ||
try self.write(operation: .pong) | ||
} catch { | ||
// TODO(pp): handle async error | ||
logger.error("error sending pong: \(error)") | ||
self.fire(.error(NatsClientError("error sending pong: \(error)"))) | ||
continue | ||
|
@@ -117,7 +122,6 @@ class ConnectionHandler: ChannelInboundHandler { | |
} else { | ||
self.fire(.error(err)) | ||
} | ||
// TODO(pp): handle auth errors here | ||
case .message(let msg): | ||
self.handleIncomingMessage(msg) | ||
case .hMessage(let msg): | ||
|
@@ -133,7 +137,9 @@ class ConnectionHandler: ChannelInboundHandler { | |
} | ||
init( | ||
inputBuffer: ByteBuffer, urls: [URL], reconnectWait: TimeInterval, maxReconnects: Int?, | ||
pingInterval: TimeInterval, auth: Auth? | ||
pingInterval: TimeInterval, auth: Auth?, requireTls: Bool, tlsFirst: Bool, | ||
clientCertificate: URL?, clientKey: URL?, | ||
rootCertificate: URL? | ||
) { | ||
self.inputBuffer = self.allocator.buffer(capacity: 1024) | ||
self.urls = urls | ||
|
@@ -144,6 +150,11 @@ class ConnectionHandler: ChannelInboundHandler { | |
self.maxReconnects = maxReconnects | ||
self.auth = auth | ||
self.pingInterval = pingInterval | ||
self.requireTls = requireTls | ||
self.tlsFirst = tlsFirst | ||
self.clientCertificate = clientCertificate | ||
self.clientKey = clientKey | ||
self.rootCertificate = rootCertificate | ||
} | ||
|
||
internal var group: MultiThreadedEventLoopGroup | ||
|
@@ -164,16 +175,58 @@ class ConnectionHandler: ChannelInboundHandler { | |
value: 1 | ||
) | ||
.channelInitializer { channel in | ||
//Fixme(jrm): do not ignore error from addHandler future. | ||
channel.pipeline.addHandler(self).whenComplete { result in | ||
switch result { | ||
case .success(): | ||
if self.requireTls && self.tlsFirst { | ||
var tlsConfiguration = TLSConfiguration.makeClientConfiguration() | ||
do { | ||
if let rootCertificate = self.rootCertificate { | ||
tlsConfiguration.trustRoots = .file(rootCertificate.path) | ||
} | ||
if let clientCertificate = self.clientCertificate, | ||
let clientKey = self.clientKey | ||
{ | ||
// Load the client certificate from the PEM file | ||
let certificate = try NIOSSLCertificate.fromPEMFile( | ||
clientCertificate.path | ||
).map { NIOSSLCertificateSource.certificate($0) } | ||
tlsConfiguration.certificateChain = certificate | ||
|
||
// Load the private key from the file | ||
let privateKey = try NIOSSLPrivateKey( | ||
file: clientKey.path, format: .pem) | ||
tlsConfiguration.privateKey = .privateKey(privateKey) | ||
} | ||
let sslContext = try NIOSSLContext( | ||
configuration: tlsConfiguration) | ||
// FIXME(jrm): Consider better way to pick hostname. | ||
let sslHandler = try NIOSSLClientHandler( | ||
context: sslContext, serverHostname: self.urls[0].host()!) | ||
//Fixme(jrm): do not ignore error from addHandler future. | ||
channel.pipeline.addHandler(sslHandler).flatMap { _ in | ||
channel.pipeline.addHandler(self) | ||
}.whenComplete { result in | ||
switch result { | ||
case .success(): | ||
print("success") | ||
case .failure(let error): | ||
print("error: \(error)") | ||
} | ||
} | ||
return channel.eventLoop.makeSucceededFuture(()) | ||
} catch { | ||
return channel.eventLoop.makeFailedFuture(error) | ||
} | ||
} else { | ||
//Fixme(jrm): do not ignore error from addHandler future. | ||
channel.pipeline.addHandler(self).whenComplete { result in | ||
switch result { | ||
case .success(): | ||
logger.debug("success") | ||
case .failure(let error): | ||
case .failure(let error): | ||
logger.debug("error: \(error)") | ||
} | ||
} | ||
return channel.eventLoop.makeSucceededFuture(()) | ||
} | ||
return channel.eventLoop.makeSucceededFuture(()) | ||
}.connectTimeout(.seconds(5)) | ||
guard let url = self.urls.first, let host = url.host, let port = url.port else { | ||
throw NatsClientError("no url") | ||
|
@@ -186,7 +239,28 @@ class ConnectionHandler: ChannelInboundHandler { | |
// Wait for the first message after sending the connect request | ||
} | ||
self.serverInfo = info | ||
// TODO(jrm): Add rest of auth here. | ||
if (info.tlsRequired ?? false || self.requireTls) && !self.tlsFirst { | ||
var tlsConfiguration = TLSConfiguration.makeClientConfiguration() | ||
if let rootCertificate = self.rootCertificate { | ||
tlsConfiguration.trustRoots = .file(rootCertificate.path) | ||
} | ||
if let clientCertificate = self.clientCertificate, let clientKey = self.clientKey { | ||
// Load the client certificate from the PEM file | ||
let certificate = try NIOSSLCertificate.fromPEMFile(clientCertificate.path).map { | ||
NIOSSLCertificateSource.certificate($0) | ||
} | ||
tlsConfiguration.certificateChain = certificate | ||
|
||
// Load the private key from the file | ||
let privateKey = try NIOSSLPrivateKey(file: clientKey.path, format: .pem) | ||
tlsConfiguration.privateKey = .privateKey(privateKey) | ||
} | ||
let sslContext = try NIOSSLContext(configuration: tlsConfiguration) | ||
// FIXME(jrm): Consider better way to pick hostname. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How do we pick the hostname in Go? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed, but currently we always pick the first url anyway. I'll add url randomization in the next PR and use that one. |
||
let sslHandler = try NIOSSLClientHandler( | ||
context: sslContext, serverHostname: self.urls[0].host()!) | ||
try await self.channel?.pipeline.addHandler(sslHandler, position: .first) | ||
} | ||
|
||
var initialConnect = ConnectInfo( | ||
verbose: false, pedantic: false, userJwt: nil, nkey: "", name: "", echo: true, | ||
|
@@ -293,6 +367,17 @@ class ConnectionHandler: ChannelInboundHandler { | |
} else { | ||
logger.error("unexpected error: \(error)") | ||
} | ||
if let continuation = self.serverInfoContinuation { | ||
self.serverInfoContinuation = nil | ||
continuation.resume(throwing: error) | ||
return | ||
} | ||
|
||
if let continuation = self.connectionEstablishedContinuation { | ||
self.connectionEstablishedContinuation = nil | ||
continuation.resume(throwing: error) | ||
return | ||
} | ||
if self.state == .pending { | ||
handleDisconnect() | ||
} else if self.state == .disconnected { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
and maybe for consistency rename then that one too?