159 changes: 159 additions & 0 deletions helper/tlsutil/generate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
package tlsutil

import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"io"
"net"
"testing"
"time"

"strings"

"github.com/stretchr/testify/require"
)

func TestSerialNumber(t *testing.T) {
n1, err := GenerateSerialNumber()
require.Nil(t, err)

n2, err := GenerateSerialNumber()
require.Nil(t, err)
require.NotEqual(t, n1, n2)

n3, err := GenerateSerialNumber()
require.Nil(t, err)
require.NotEqual(t, n1, n3)
require.NotEqual(t, n2, n3)

}

func TestGeneratePrivateKey(t *testing.T) {
t.Parallel()
_, p, err := GeneratePrivateKey()
require.Nil(t, err)
require.NotEmpty(t, p)
require.Contains(t, p, "BEGIN EC PRIVATE KEY")
require.Contains(t, p, "END EC PRIVATE KEY")

block, _ := pem.Decode([]byte(p))
pk, err := x509.ParseECPrivateKey(block.Bytes)

require.Nil(t, err)
require.NotNil(t, pk)
require.Equal(t, 256, pk.Params().BitSize)
}

type TestSigner struct {
public interface{}
}

func (s *TestSigner) Public() crypto.PublicKey {
return s.public
}

func (s *TestSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
return []byte{}, nil
}

func TestGenerateCA(t *testing.T) {
t.Run("no signer", func(t *testing.T) {
ca, pk, err := GenerateCA(CAOpts{Signer: &TestSigner{}})
require.Error(t, err)
require.Empty(t, ca)
require.Empty(t, pk)
})

t.Run("wrong key", func(t *testing.T) {
ca, pk, err := GenerateCA(CAOpts{Signer: &TestSigner{public: &rsa.PublicKey{}}})
require.Error(t, err)
require.Empty(t, ca)
require.Empty(t, pk)
})

t.Run("valid key", func(t *testing.T) {
ca, pk, err := GenerateCA(CAOpts{})
require.Nil(t, err)
require.NotEmpty(t, ca)
require.NotEmpty(t, pk)

cert, err := parseCert(ca)
require.Nil(t, err)
require.True(t, strings.HasPrefix(cert.Subject.CommonName, "Consul Agent CA"))
require.Equal(t, true, cert.IsCA)
require.Equal(t, true, cert.BasicConstraintsValid)

require.WithinDuration(t, cert.NotBefore, time.Now(), time.Minute)
require.WithinDuration(t, cert.NotAfter, time.Now().AddDate(0, 0, 365), time.Minute)

require.Equal(t, x509.KeyUsageCertSign|x509.KeyUsageCRLSign|x509.KeyUsageDigitalSignature, cert.KeyUsage)
})

t.Run("RSA key", func(t *testing.T) {
ca, pk, err := GenerateCA(CAOpts{})
require.NoError(t, err)
require.NotEmpty(t, ca)
require.NotEmpty(t, pk)

cert, err := parseCert(ca)
require.NoError(t, err)
require.True(t, strings.HasPrefix(cert.Subject.CommonName, "Consul Agent CA"))
require.Equal(t, true, cert.IsCA)
require.Equal(t, true, cert.BasicConstraintsValid)

require.WithinDuration(t, cert.NotBefore, time.Now(), time.Minute)
require.WithinDuration(t, cert.NotAfter, time.Now().AddDate(0, 0, 365), time.Minute)

require.Equal(t, x509.KeyUsageCertSign|x509.KeyUsageCRLSign|x509.KeyUsageDigitalSignature, cert.KeyUsage)
})
}

func TestGenerateCert(t *testing.T) {
t.Parallel()
signer, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.Nil(t, err)
ca, _, err := GenerateCA(CAOpts{Signer: signer})
require.Nil(t, err)

DNSNames := []string{"server.dc1.consul"}
IPAddresses := []net.IP{net.ParseIP("123.234.243.213")}
extKeyUsage := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
name := "Cert Name"
certificate, pk, err := GenerateCert(CertOpts{
Signer: signer, CA: ca, Name: name, Days: 365,
DNSNames: DNSNames, IPAddresses: IPAddresses, ExtKeyUsage: extKeyUsage,
})
require.Nil(t, err)
require.NotEmpty(t, certificate)
require.NotEmpty(t, pk)

cert, err := parseCert(certificate)
require.Nil(t, err)
require.Equal(t, name, cert.Subject.CommonName)
require.Equal(t, true, cert.BasicConstraintsValid)
signee, err := ParseSigner(pk)
require.Nil(t, err)
certID, err := keyID(signee.Public())
require.Nil(t, err)
require.Equal(t, certID, cert.SubjectKeyId)
caID, err := keyID(signer.Public())
require.Nil(t, err)
require.Equal(t, caID, cert.AuthorityKeyId)
require.Contains(t, cert.Issuer.CommonName, "Consul Agent CA")
require.Equal(t, false, cert.IsCA)

require.WithinDuration(t, cert.NotBefore, time.Now(), time.Minute)
require.WithinDuration(t, cert.NotAfter, time.Now().AddDate(0, 0, 365), time.Minute)

require.Equal(t, x509.KeyUsageDigitalSignature|x509.KeyUsageKeyEncipherment, cert.KeyUsage)
require.Equal(t, extKeyUsage, cert.ExtKeyUsage)

// https://github.com/golang/go/blob/10538a8f9e2e718a47633ac5a6e90415a2c3f5f1/src/crypto/x509/verify.go#L414
require.Equal(t, DNSNames, cert.DNSNames)
require.True(t, IPAddresses[0].Equal(cert.IPAddresses[0]))
}
40 changes: 40 additions & 0 deletions nomad/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,11 @@ func (r *rpcHandler) handleConn(ctx context.Context, conn net.Conn, rpcCtx *RPCC

case pool.RpcRaft:
metrics.IncrCounter([]string{"nomad", "rpc", "raft_handoff"}, 1)
// Ensure that when TLS is configured, only certificates from `server.<region>.nomad` are accepted for Raft connections.
if err := r.validateRaftTLS(rpcCtx); err != nil {
conn.Close()
return
}
r.raftLayer.Handoff(ctx, conn)

case pool.RpcMultiplex:
Expand Down Expand Up @@ -825,3 +830,38 @@ RUN_QUERY:
}
return err
}

func (r *rpcHandler) validateRaftTLS(rpcCtx *RPCContext) error {
// TLS is not configured or not to be enforced
tlsConf := r.config.TLSConfig
if !tlsConf.EnableRPC || !tlsConf.VerifyServerHostname || tlsConf.RPCUpgradeMode {
return nil
}

// defensive conditions: these should have already been enforced by handleConn
if rpcCtx == nil || !rpcCtx.TLS {
return errors.New("non-TLS connection attempted")
}
if len(rpcCtx.VerifiedChains) == 0 || len(rpcCtx.VerifiedChains[0]) == 0 {
// this should never happen, as rpcNameAndRegionValidate should have enforced it
return errors.New("missing cert info")
}

// check that `server.<region>.nomad` is present in cert
expected := "server." + r.Region() + ".nomad"

cert := rpcCtx.VerifiedChains[0][0]
for _, dnsName := range cert.DNSNames {
if dnsName == expected {
// Certificate is valid for the expected name
return nil
}
}
if cert.Subject.CommonName == expected {
// Certificate is valid for the expected name
return nil
}

r.logger.Warn("unauthorized connection", "required_hostname", expected, "found", cert.DNSNames)
return fmt.Errorf("certificate is invalid for expected role or region: %q", expected)
}
229 changes: 229 additions & 0 deletions nomad/rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@ package nomad
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/rpc"
"os"
"path"
"path/filepath"
"testing"
"time"

Expand Down Expand Up @@ -1014,3 +1018,228 @@ func TestRPC_Limits_Streaming(t *testing.T) {
require.NoError(t, err)
})
}

func TestRPC_TLS_Enforcement(t *testing.T) {
t.Parallel()

defer func() {
//TODO Avoid panics from logging during shutdown
time.Sleep(1 * time.Second)
}()

dir := tmpDir(t)
defer os.RemoveAll(dir)

caPEM, pk, err := tlsutil.GenerateCA(tlsutil.CAOpts{Days: 5, Domain: "nomad"})
require.NoError(t, err)

err = ioutil.WriteFile(filepath.Join(dir, "ca.pem"), []byte(caPEM), 0600)
require.NoError(t, err)

nodeID := 1
newCert := func(t *testing.T, name string) string {
t.Helper()

node := fmt.Sprintf("node%d", nodeID)
nodeID++
signer, err := tlsutil.ParseSigner(pk)
require.NoError(t, err)

pem, key, err := tlsutil.GenerateCert(tlsutil.CertOpts{
Signer: signer,
CA: caPEM,
Name: name,
Days: 5,
DNSNames: []string{node + "." + name, name, "localhost"},
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
})
require.NoError(t, err)

err = ioutil.WriteFile(filepath.Join(dir, node+"-"+name+".pem"), []byte(pem), 0600)
require.NoError(t, err)
err = ioutil.WriteFile(filepath.Join(dir, node+"-"+name+".key"), []byte(key), 0600)
require.NoError(t, err)

return filepath.Join(dir, node+"-"+name)
}

connect := func(t *testing.T, s *Server, c *config.TLSConfig) net.Conn {
conn, err := net.DialTimeout("tcp", s.config.RPCAddr.String(), time.Second)
require.NoError(t, err)

// configure TLS
_, err = conn.Write([]byte{byte(pool.RpcTLS)})
require.NoError(t, err)

// Client TLS verification isn't necessary for
// our assertions
tlsConf, err := tlsutil.NewTLSConfiguration(c, true, true)
require.NoError(t, err)
outTLSConf, err := tlsConf.OutgoingTLSConfig()
require.NoError(t, err)
outTLSConf.InsecureSkipVerify = true

tlsConn := tls.Client(conn, outTLSConf)
require.NoError(t, tlsConn.Handshake())

return tlsConn
}

nomadRPC := func(t *testing.T, s *Server, c *config.TLSConfig) error {
conn := connect(t, s, c)
defer conn.Close()
_, err := conn.Write([]byte{byte(pool.RpcNomad)})
require.NoError(t, err)

codec := pool.NewClientCodec(conn)

arg := struct{}{}
var out struct{}
return msgpackrpc.CallWithCodec(codec, "Status.Ping", arg, &out)
}

parseCert := func(t *testing.T, path string) *x509.Certificate {
bytes, err := ioutil.ReadFile(path)
require.NoError(t, err)
block, _ := pem.Decode([]byte(bytes))
require.NoError(t, err)
cert, err := x509.ParseCertificate(block.Bytes)
require.NoError(t, err)

return cert
}

raftRPC := func(t *testing.T, s *Server, c *config.TLSConfig) error {
// TODO: Actually make an RPC Call
// Raft Layer requires a wellformed RPC, otherwise, the connection
// is closed immediately - similar to the RaftTLS failure

clientCert := parseCert(t, c.CertFile)
rootCert := parseCert(t, c.CAFile)

rpcCtx := &RPCContext{
TLS: true,
VerifiedChains: [][]*x509.Certificate{
{clientCert, rootCert},
},
}

return s.validateRaftTLS(rpcCtx)
}

// generate server cert
serverCert := newCert(t, "server.global.nomad")

mtlsS, cleanup := TestServer(t, func(c *Config) {
c.TLSConfig = &config.TLSConfig{
EnableRPC: true,
VerifyServerHostname: true,
CAFile: filepath.Join(dir, "ca.pem"),
CertFile: serverCert + ".pem",
KeyFile: serverCert + ".key",
}
})
defer cleanup()

nonVerifyS, cleanup := TestServer(t, func(c *Config) {
c.TLSConfig = &config.TLSConfig{
EnableRPC: true,
VerifyServerHostname: false,
CAFile: filepath.Join(dir, "ca.pem"),
CertFile: serverCert + ".pem",
KeyFile: serverCert + ".key",
}
})
defer cleanup()

// When VerifyServerHostname is enabled:
// Only all servers and local clients can make RPC requests
// Only local servers can connect to the Raft layer
cases := []struct {
name string
cn string
canRPC bool
canRaft bool
}{
{
name: "local server",
cn: "server.global.nomad",
canRPC: true,
canRaft: true,
},
{
name: "local client",
cn: "client.global.nomad",
canRPC: true,
canRaft: false,
},
{
name: "other region server",
cn: "server.other.nomad",
canRPC: true,
canRaft: false,
},
{
name: "other client server",
cn: "client.other.nomad",
canRPC: false,
canRaft: false,
},
{
name: "irrelevant cert",
cn: "nomad.example.com",
canRPC: false,
canRaft: false,
},
{
name: "globs",
cn: "*.global.nomad",
canRPC: false,
canRaft: false,
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
certPath := newCert(t, tc.cn)

cfg := &config.TLSConfig{
EnableRPC: true,
VerifyServerHostname: true,
CAFile: filepath.Join(dir, "ca.pem"),
CertFile: certPath + ".pem",
KeyFile: certPath + ".key",
}

t.Run("nomad RPC: verify_hostname=true", func(t *testing.T) {
err := nomadRPC(t, mtlsS, cfg)

if tc.canRPC {
require.NoError(t, err)
} else {
require.Error(t, err)
require.Contains(t, err.Error(), "bad certificate")
}
})
t.Run("nomad RPC: verify_hostname=false", func(t *testing.T) {
err := nomadRPC(t, nonVerifyS, cfg)
require.NoError(t, err)
})

t.Run("Raft RPC: verify_hostname=true", func(t *testing.T) {
err := raftRPC(t, mtlsS, cfg)

if tc.canRaft {
require.NoError(t, err)
} else {
require.Error(t, err)
require.Contains(t, err.Error(), "invalid for expected role or region")
}
})
t.Run("Raft RPC: verify_hostname=false", func(t *testing.T) {
err := raftRPC(t, nonVerifyS, cfg)
require.NoError(t, err)
})
})
}
}