Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ADDED] Setting TLS config with callbacks in Connect #1413

Merged
merged 1 commit into from Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
133 changes: 84 additions & 49 deletions nats.go
Expand Up @@ -90,55 +90,56 @@ const (

// Errors
var (
ErrConnectionClosed = errors.New("nats: connection closed")
ErrConnectionDraining = errors.New("nats: connection draining")
ErrDrainTimeout = errors.New("nats: draining connection timed out")
ErrConnectionReconnecting = errors.New("nats: connection reconnecting")
ErrSecureConnRequired = errors.New("nats: secure connection required")
ErrSecureConnWanted = errors.New("nats: secure connection not available")
ErrBadSubscription = errors.New("nats: invalid subscription")
ErrTypeSubscription = errors.New("nats: invalid subscription type")
ErrBadSubject = errors.New("nats: invalid subject")
ErrBadQueueName = errors.New("nats: invalid queue name")
ErrSlowConsumer = errors.New("nats: slow consumer, messages dropped")
ErrTimeout = errors.New("nats: timeout")
ErrBadTimeout = errors.New("nats: timeout invalid")
ErrAuthorization = errors.New("nats: authorization violation")
ErrAuthExpired = errors.New("nats: authentication expired")
ErrAuthRevoked = errors.New("nats: authentication revoked")
ErrAccountAuthExpired = errors.New("nats: account authentication expired")
ErrNoServers = errors.New("nats: no servers available for connection")
ErrJsonParse = errors.New("nats: connect message, json parse error")
ErrChanArg = errors.New("nats: argument needs to be a channel type")
ErrMaxPayload = errors.New("nats: maximum payload exceeded")
ErrMaxMessages = errors.New("nats: maximum messages delivered")
ErrSyncSubRequired = errors.New("nats: illegal call on an async subscription")
ErrMultipleTLSConfigs = errors.New("nats: multiple tls.Configs not allowed")
ErrNoInfoReceived = errors.New("nats: protocol exception, INFO not received")
ErrReconnectBufExceeded = errors.New("nats: outbound buffer limit exceeded")
ErrInvalidConnection = errors.New("nats: invalid connection")
ErrInvalidMsg = errors.New("nats: invalid message or message nil")
ErrInvalidArg = errors.New("nats: invalid argument")
ErrInvalidContext = errors.New("nats: invalid context")
ErrNoDeadlineContext = errors.New("nats: context requires a deadline")
ErrNoEchoNotSupported = errors.New("nats: no echo option not supported by this server")
ErrClientIDNotSupported = errors.New("nats: client ID not supported by this server")
ErrUserButNoSigCB = errors.New("nats: user callback defined without a signature handler")
ErrNkeyButNoSigCB = errors.New("nats: nkey defined without a signature handler")
ErrNoUserCB = errors.New("nats: user callback not defined")
ErrNkeyAndUser = errors.New("nats: user callback and nkey defined")
ErrNkeysNotSupported = errors.New("nats: nkeys not supported by the server")
ErrStaleConnection = errors.New("nats: " + STALE_CONNECTION)
ErrTokenAlreadySet = errors.New("nats: token and token handler both set")
ErrMsgNotBound = errors.New("nats: message is not bound to subscription/connection")
ErrMsgNoReply = errors.New("nats: message does not have a reply")
ErrClientIPNotSupported = errors.New("nats: client IP not supported by this server")
ErrDisconnected = errors.New("nats: server is disconnected")
ErrHeadersNotSupported = errors.New("nats: headers not supported by this server")
ErrBadHeaderMsg = errors.New("nats: message could not decode headers")
ErrNoResponders = errors.New("nats: no responders available for request")
ErrMaxConnectionsExceeded = errors.New("nats: server maximum connections exceeded")
ErrConnectionNotTLS = errors.New("nats: connection is not tls")
ErrConnectionClosed = errors.New("nats: connection closed")
ErrConnectionDraining = errors.New("nats: connection draining")
ErrDrainTimeout = errors.New("nats: draining connection timed out")
ErrConnectionReconnecting = errors.New("nats: connection reconnecting")
ErrSecureConnRequired = errors.New("nats: secure connection required")
ErrSecureConnWanted = errors.New("nats: secure connection not available")
ErrBadSubscription = errors.New("nats: invalid subscription")
ErrTypeSubscription = errors.New("nats: invalid subscription type")
ErrBadSubject = errors.New("nats: invalid subject")
ErrBadQueueName = errors.New("nats: invalid queue name")
ErrSlowConsumer = errors.New("nats: slow consumer, messages dropped")
ErrTimeout = errors.New("nats: timeout")
ErrBadTimeout = errors.New("nats: timeout invalid")
ErrAuthorization = errors.New("nats: authorization violation")
ErrAuthExpired = errors.New("nats: authentication expired")
ErrAuthRevoked = errors.New("nats: authentication revoked")
ErrAccountAuthExpired = errors.New("nats: account authentication expired")
ErrNoServers = errors.New("nats: no servers available for connection")
ErrJsonParse = errors.New("nats: connect message, json parse error")
ErrChanArg = errors.New("nats: argument needs to be a channel type")
ErrMaxPayload = errors.New("nats: maximum payload exceeded")
ErrMaxMessages = errors.New("nats: maximum messages delivered")
ErrSyncSubRequired = errors.New("nats: illegal call on an async subscription")
ErrMultipleTLSConfigs = errors.New("nats: multiple tls.Configs not allowed")
ErrClientCertOrRootCAsRequired = errors.New("nats: at least one of certCB or rootCAsCB must be set")
ErrNoInfoReceived = errors.New("nats: protocol exception, INFO not received")
ErrReconnectBufExceeded = errors.New("nats: outbound buffer limit exceeded")
ErrInvalidConnection = errors.New("nats: invalid connection")
ErrInvalidMsg = errors.New("nats: invalid message or message nil")
ErrInvalidArg = errors.New("nats: invalid argument")
ErrInvalidContext = errors.New("nats: invalid context")
ErrNoDeadlineContext = errors.New("nats: context requires a deadline")
ErrNoEchoNotSupported = errors.New("nats: no echo option not supported by this server")
ErrClientIDNotSupported = errors.New("nats: client ID not supported by this server")
ErrUserButNoSigCB = errors.New("nats: user callback defined without a signature handler")
ErrNkeyButNoSigCB = errors.New("nats: nkey defined without a signature handler")
ErrNoUserCB = errors.New("nats: user callback not defined")
ErrNkeyAndUser = errors.New("nats: user callback and nkey defined")
ErrNkeysNotSupported = errors.New("nats: nkeys not supported by the server")
ErrStaleConnection = errors.New("nats: " + STALE_CONNECTION)
ErrTokenAlreadySet = errors.New("nats: token and token handler both set")
ErrMsgNotBound = errors.New("nats: message is not bound to subscription/connection")
ErrMsgNoReply = errors.New("nats: message does not have a reply")
ErrClientIPNotSupported = errors.New("nats: client IP not supported by this server")
ErrDisconnected = errors.New("nats: server is disconnected")
ErrHeadersNotSupported = errors.New("nats: headers not supported by this server")
ErrBadHeaderMsg = errors.New("nats: message could not decode headers")
ErrNoResponders = errors.New("nats: no responders available for request")
ErrMaxConnectionsExceeded = errors.New("nats: server maximum connections exceeded")
ErrConnectionNotTLS = errors.New("nats: connection is not tls")
)

// GetDefaultOptions returns default configuration options for the client.
Expand Down Expand Up @@ -864,6 +865,40 @@ func Secure(tls ...*tls.Config) Option {
}
}

// ClientTLSConfig is an Option to set the TLS configuration for secure
// connections. It can be used to e.g. set TLS config with cert and root CAs
// from memory. For simple use case of loading cert and CAs from file,
// ClientCert and RootCAs options are more convenient.
// If Secure is not already set this will set it as well.
func ClientTLSConfig(certCB TLSCertHandler, rootCAsCB RootCAsHandler) Option {
return func(o *Options) error {
o.Secure = true

if certCB == nil && rootCAsCB == nil {
return ErrClientCertOrRootCAsRequired
}

// Smoke test the callbacks to fail early
Jarema marked this conversation as resolved.
Show resolved Hide resolved
// if they are not valid.
if certCB != nil {
if _, err := certCB(); err != nil {
return err
}
}
if rootCAsCB != nil {
if _, err := rootCAsCB(); err != nil {
return err
}
}
if o.TLSConfig == nil {
o.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
}
o.TLSCertCB = certCB
o.RootCAsCB = rootCAsCB
return nil
}
}

// RootCAs is a helper option to provide the RootCAs pool from a list of filenames.
// If Secure is not already set this will set it as well.
func RootCAs(file ...string) Option {
Expand Down
98 changes: 98 additions & 0 deletions test/conn_test.go
Expand Up @@ -18,6 +18,7 @@ import (
"bytes"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"os"
Expand Down Expand Up @@ -234,6 +235,103 @@ func TestServerSecureConnections(t *testing.T) {
}
}

func TestClientTLSConfig(t *testing.T) {
s, opts := RunServerWithConfig("./configs/tlsverify.conf")
defer s.Shutdown()

endpoint := fmt.Sprintf("%s:%d", opts.Host, opts.Port)
secureURL := fmt.Sprintf("nats://%s", endpoint)

// Make sure this fails
nc, err := nats.Connect(secureURL, nats.Secure())
if err == nil {
nc.Close()
t.Fatal("Should have failed (TLS) connection without client certificate")
}
cert, err := os.ReadFile("./configs/certs/client-cert.pem")
if err != nil {
t.Fatal("Failed to read client certificate")
}
key, err := os.ReadFile("./configs/certs/client-key.pem")
if err != nil {
t.Fatal("Failed to read client key")
}
rootCAs, err := os.ReadFile("./configs/certs/ca.pem")
if err != nil {
t.Fatal("Failed to read root CAs")
}

certCB := func() (tls.Certificate, error) {
cert, err := tls.X509KeyPair(cert, key)
if err != nil {
return tls.Certificate{}, fmt.Errorf("nats: error loading client certificate: %w", err)
}
cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
if err != nil {
return tls.Certificate{}, fmt.Errorf("nats: error parsing client certificate: %w", err)
}
return cert, nil
}

caCB := func() (*x509.CertPool, error) {
pool := x509.NewCertPool()
ok := pool.AppendCertsFromPEM(rootCAs)
if !ok {
return nil, fmt.Errorf("nats: failed to parse root certificate from")
}
return pool, nil
}

// Check parameters validity
_, err = nats.Connect(secureURL, nats.ClientTLSConfig(nil, nil))
if !errors.Is(err, nats.ErrClientCertOrRootCAsRequired) {
t.Fatalf("Expected error %q, got %q", nats.ErrClientCertOrRootCAsRequired, err)
}

certErr := &tls.CertificateVerificationError{}
// Should fail because of missing CA
_, err = nats.Connect(secureURL,
nats.ClientCert("./configs/certs/client-cert.pem", "./configs/certs/client-key.pem"))
if ok := errors.As(err, &certErr); !ok {
t.Fatalf("Expected error %q, got %q", nats.ErrClientCertOrRootCAsRequired, err)
}

// Should fail because of missing certificate
_, err = nats.Connect(secureURL,
nats.ClientTLSConfig(nil, caCB))
if !strings.Contains(err.Error(), "bad certificate") && !strings.Contains(err.Error(), "certificate required") {
t.Fatalf("Expected missing certificate error; got: %s", err)
}

nc, err = nats.Connect(secureURL,
nats.ClientTLSConfig(certCB, caCB))
if err != nil {
t.Fatalf("Failed to create (TLS) connection: %v", err)
}
defer nc.Close()

omsg := []byte("Hello!")
checkRecv := make(chan bool)

received := 0
nc.Subscribe("foo", func(m *nats.Msg) {
received++
if !bytes.Equal(m.Data, omsg) {
t.Fatal("Message received does not match")
}
checkRecv <- true
})
err = nc.Publish("foo", omsg)
if err != nil {
t.Fatalf("Failed to publish on secure (TLS) connection: %v", err)
}
nc.Flush()

if err := Wait(checkRecv); err != nil {
t.Fatal("Failed to receive message")
}
}

func TestClientCertificate(t *testing.T) {
s, opts := RunServerWithConfig("./configs/tlsverify.conf")
defer s.Shutdown()
Expand Down