From 61a922afcf12784281757402c8e0b61686ff855d Mon Sep 17 00:00:00 2001 From: Mahmood Ali Date: Thu, 15 Jul 2021 11:19:34 -0400 Subject: [PATCH] Apply authZ for nomad Raft RPC layer When mTLS is enabled, only nomad servers of the region should access the Raft RPC layer. Clients and servers in other regions should only use the Nomad RPC endpoints. Co-authored-by: Michael Schurter Co-authored-by: Seth Hoenig --- .changelog/11084.txt | 3 + helper/tlsutil/generate.go | 298 ++++++++++++++++++++++++++++++++ helper/tlsutil/generate_test.go | 159 +++++++++++++++++ nomad/rpc.go | 40 +++++ nomad/rpc_test.go | 229 ++++++++++++++++++++++++ 5 files changed, 729 insertions(+) create mode 100644 .changelog/11084.txt create mode 100644 helper/tlsutil/generate.go create mode 100644 helper/tlsutil/generate_test.go diff --git a/.changelog/11084.txt b/.changelog/11084.txt new file mode 100644 index 00000000000..01971be6f45 --- /dev/null +++ b/.changelog/11084.txt @@ -0,0 +1,3 @@ +```release-note:security +Restricted access to the Raft RPC layer, so only servers within the region can issue Raft RPC requests. Previously, local clients and federated servers could issue Raft RPC requests directly. 
CVE-2021-37218 +``` diff --git a/helper/tlsutil/generate.go b/helper/tlsutil/generate.go new file mode 100644 index 00000000000..ecbba85cb57 --- /dev/null +++ b/helper/tlsutil/generate.go @@ -0,0 +1,298 @@ +package tlsutil + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "net" + "time" +) + +// GenerateSerialNumber returns random bigint generated with crypto/rand +func GenerateSerialNumber() (*big.Int, error) { + l := new(big.Int).Lsh(big.NewInt(1), 128) + s, err := rand.Int(rand.Reader, l) + if err != nil { + return nil, err + } + return s, nil +} + +// GeneratePrivateKey generates a new ecdsa private key +func GeneratePrivateKey() (crypto.Signer, string, error) { + curve := elliptic.P256() + + pk, err := ecdsa.GenerateKey(curve, rand.Reader) + if err != nil { + return nil, "", fmt.Errorf("error generating ECDSA private key: %s", err) + } + + bs, err := x509.MarshalECPrivateKey(pk) + if err != nil { + return nil, "", fmt.Errorf("error marshaling ECDSA private key: %s", err) + } + + pemBlock, err := pemEncodeKey(bs, "EC PRIVATE KEY") + if err != nil { + return nil, "", err + } + + return pk, pemBlock, nil +} + +func pemEncodeKey(key []byte, blockType string) (string, error) { + var buf bytes.Buffer + + if err := pem.Encode(&buf, &pem.Block{Type: blockType, Bytes: key}); err != nil { + return "", fmt.Errorf("error encoding private key: %s", err) + } + return buf.String(), nil +} + +type CAOpts struct { + Signer crypto.Signer + Serial *big.Int + Days int + PermittedDNSDomains []string + Domain string + Name string +} + +type CertOpts struct { + Signer crypto.Signer + CA string + Serial *big.Int + Name string + Days int + DNSNames []string + IPAddresses []net.IP + ExtKeyUsage []x509.ExtKeyUsage +} + +// GenerateCA generates a new CA for agent TLS (not to be confused with Connect TLS) +func GenerateCA(opts CAOpts) (string, 
string, error) { + signer := opts.Signer + var pk string + if signer == nil { + var err error + signer, pk, err = GeneratePrivateKey() + if err != nil { + return "", "", err + } + } + + id, err := keyID(signer.Public()) + if err != nil { + return "", "", err + } + + sn := opts.Serial + if sn == nil { + var err error + sn, err = GenerateSerialNumber() + if err != nil { + return "", "", err + } + } + name := opts.Name + if name == "" { + name = fmt.Sprintf("Consul Agent CA %d", sn) + } + + days := opts.Days + if opts.Days == 0 { + days = 365 + } + + // Create the CA cert + template := x509.Certificate{ + SerialNumber: sn, + Subject: pkix.Name{ + Country: []string{"US"}, + PostalCode: []string{"94105"}, + Province: []string{"CA"}, + Locality: []string{"San Francisco"}, + StreetAddress: []string{"101 Second Street"}, + Organization: []string{"HashiCorp Inc."}, + CommonName: name, + }, + BasicConstraintsValid: true, + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign | x509.KeyUsageDigitalSignature, + IsCA: true, + NotAfter: time.Now().AddDate(0, 0, days), + NotBefore: time.Now(), + AuthorityKeyId: id, + SubjectKeyId: id, + } + + if len(opts.PermittedDNSDomains) > 0 { + template.PermittedDNSDomainsCritical = true + template.PermittedDNSDomains = opts.PermittedDNSDomains + } + bs, err := x509.CreateCertificate( + rand.Reader, &template, &template, signer.Public(), signer) + if err != nil { + return "", "", fmt.Errorf("error generating CA certificate: %s", err) + } + + var buf bytes.Buffer + err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs}) + if err != nil { + return "", "", fmt.Errorf("error encoding private key: %s", err) + } + + return buf.String(), pk, nil +} + +// GenerateCert generates a new certificate for agent TLS (not to be confused with Connect TLS) +func GenerateCert(opts CertOpts) (string, string, error) { + parent, err := parseCert(opts.CA) + if err != nil { + return "", "", err + } + + signee, pk, err := GeneratePrivateKey() + if err 
!= nil { + return "", "", err + } + + id, err := keyID(signee.Public()) + if err != nil { + return "", "", err + } + + sn := opts.Serial + if sn == nil { + var err error + sn, err = GenerateSerialNumber() + if err != nil { + return "", "", err + } + } + + template := x509.Certificate{ + SerialNumber: sn, + Subject: pkix.Name{CommonName: opts.Name}, + BasicConstraintsValid: true, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: opts.ExtKeyUsage, + IsCA: false, + NotAfter: time.Now().AddDate(0, 0, opts.Days), + NotBefore: time.Now(), + SubjectKeyId: id, + DNSNames: opts.DNSNames, + IPAddresses: opts.IPAddresses, + } + + bs, err := x509.CreateCertificate(rand.Reader, &template, parent, signee.Public(), opts.Signer) + if err != nil { + return "", "", err + } + + var buf bytes.Buffer + err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: bs}) + if err != nil { + return "", "", fmt.Errorf("error encoding private key: %s", err) + } + + return buf.String(), pk, nil +} + +// keyID returns an X.509 key identifier derived from the given public key. +func keyID(raw interface{}) ([]byte, error) { + switch raw.(type) { + case *ecdsa.PublicKey: + case *rsa.PublicKey: + default: + return nil, fmt.Errorf("invalid key type: %T", raw) + } + + // This is not standard; RFC allows any unique identifier as long as they + // match in subject/authority chains but suggests specific hashing of DER + // bytes of public key including DER tags. + bs, err := x509.MarshalPKIXPublicKey(raw) + if err != nil { + return nil, err + } + + // Use the SHA-256 digest of the marshaled public key as the identifier + kID := sha256.Sum256(bs) + return kID[:], nil +} + +func parseCert(pemValue string) (*x509.Certificate, error) { + // The _ result below is not an error but the remaining PEM bytes. 
+ block, _ := pem.Decode([]byte(pemValue)) + if block == nil { + return nil, fmt.Errorf("no PEM-encoded data found") + } + + if block.Type != "CERTIFICATE" { + return nil, fmt.Errorf("first PEM-block should be CERTIFICATE type") + } + + return x509.ParseCertificate(block.Bytes) +} + +// ParseSigner parses a crypto.Signer from a PEM-encoded key. The private key +// is expected to be the first block in the PEM value. +func ParseSigner(pemValue string) (crypto.Signer, error) { + // The _ result below is not an error but the remaining PEM bytes. + block, _ := pem.Decode([]byte(pemValue)) + if block == nil { + return nil, fmt.Errorf("no PEM-encoded data found") + } + + switch block.Type { + case "EC PRIVATE KEY": + return x509.ParseECPrivateKey(block.Bytes) + + case "RSA PRIVATE KEY": + return x509.ParsePKCS1PrivateKey(block.Bytes) + + case "PRIVATE KEY": + signer, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, err + } + pk, ok := signer.(crypto.Signer) + if !ok { + return nil, fmt.Errorf("private key is not a valid format") + } + + return pk, nil + + default: + return nil, fmt.Errorf("unknown PEM block type for signing key: %s", block.Type) + } +} + +func Verify(caString, certString, dns string) error { + roots := x509.NewCertPool() + ok := roots.AppendCertsFromPEM([]byte(caString)) + if !ok { + return fmt.Errorf("failed to parse root certificate") + } + + cert, err := parseCert(certString) + if err != nil { + return fmt.Errorf("failed to parse certificate") + } + + opts := x509.VerifyOptions{ + DNSName: fmt.Sprint(dns), + Roots: roots, + } + + _, err = cert.Verify(opts) + return err +} diff --git a/helper/tlsutil/generate_test.go b/helper/tlsutil/generate_test.go new file mode 100644 index 00000000000..5be9f7e2b5f --- /dev/null +++ b/helper/tlsutil/generate_test.go @@ -0,0 +1,159 @@ +package tlsutil + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "io" + "net" + 
"testing" + "time" + + "strings" + + "github.com/stretchr/testify/require" +) + +func TestSerialNumber(t *testing.T) { + n1, err := GenerateSerialNumber() + require.Nil(t, err) + + n2, err := GenerateSerialNumber() + require.Nil(t, err) + require.NotEqual(t, n1, n2) + + n3, err := GenerateSerialNumber() + require.Nil(t, err) + require.NotEqual(t, n1, n3) + require.NotEqual(t, n2, n3) + +} + +func TestGeneratePrivateKey(t *testing.T) { + t.Parallel() + _, p, err := GeneratePrivateKey() + require.Nil(t, err) + require.NotEmpty(t, p) + require.Contains(t, p, "BEGIN EC PRIVATE KEY") + require.Contains(t, p, "END EC PRIVATE KEY") + + block, _ := pem.Decode([]byte(p)) + pk, err := x509.ParseECPrivateKey(block.Bytes) + + require.Nil(t, err) + require.NotNil(t, pk) + require.Equal(t, 256, pk.Params().BitSize) +} + +type TestSigner struct { + public interface{} +} + +func (s *TestSigner) Public() crypto.PublicKey { + return s.public +} + +func (s *TestSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { + return []byte{}, nil +} + +func TestGenerateCA(t *testing.T) { + t.Run("no signer", func(t *testing.T) { + ca, pk, err := GenerateCA(CAOpts{Signer: &TestSigner{}}) + require.Error(t, err) + require.Empty(t, ca) + require.Empty(t, pk) + }) + + t.Run("wrong key", func(t *testing.T) { + ca, pk, err := GenerateCA(CAOpts{Signer: &TestSigner{public: &rsa.PublicKey{}}}) + require.Error(t, err) + require.Empty(t, ca) + require.Empty(t, pk) + }) + + t.Run("valid key", func(t *testing.T) { + ca, pk, err := GenerateCA(CAOpts{}) + require.Nil(t, err) + require.NotEmpty(t, ca) + require.NotEmpty(t, pk) + + cert, err := parseCert(ca) + require.Nil(t, err) + require.True(t, strings.HasPrefix(cert.Subject.CommonName, "Consul Agent CA")) + require.Equal(t, true, cert.IsCA) + require.Equal(t, true, cert.BasicConstraintsValid) + + require.WithinDuration(t, cert.NotBefore, time.Now(), time.Minute) + require.WithinDuration(t, cert.NotAfter, 
time.Now().AddDate(0, 0, 365), time.Minute) + + require.Equal(t, x509.KeyUsageCertSign|x509.KeyUsageCRLSign|x509.KeyUsageDigitalSignature, cert.KeyUsage) + }) + + t.Run("RSA key", func(t *testing.T) { + ca, pk, err := GenerateCA(CAOpts{}) + require.NoError(t, err) + require.NotEmpty(t, ca) + require.NotEmpty(t, pk) + + cert, err := parseCert(ca) + require.NoError(t, err) + require.True(t, strings.HasPrefix(cert.Subject.CommonName, "Consul Agent CA")) + require.Equal(t, true, cert.IsCA) + require.Equal(t, true, cert.BasicConstraintsValid) + + require.WithinDuration(t, cert.NotBefore, time.Now(), time.Minute) + require.WithinDuration(t, cert.NotAfter, time.Now().AddDate(0, 0, 365), time.Minute) + + require.Equal(t, x509.KeyUsageCertSign|x509.KeyUsageCRLSign|x509.KeyUsageDigitalSignature, cert.KeyUsage) + }) +} + +func TestGenerateCert(t *testing.T) { + t.Parallel() + signer, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.Nil(t, err) + ca, _, err := GenerateCA(CAOpts{Signer: signer}) + require.Nil(t, err) + + DNSNames := []string{"server.dc1.consul"} + IPAddresses := []net.IP{net.ParseIP("123.234.243.213")} + extKeyUsage := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth} + name := "Cert Name" + certificate, pk, err := GenerateCert(CertOpts{ + Signer: signer, CA: ca, Name: name, Days: 365, + DNSNames: DNSNames, IPAddresses: IPAddresses, ExtKeyUsage: extKeyUsage, + }) + require.Nil(t, err) + require.NotEmpty(t, certificate) + require.NotEmpty(t, pk) + + cert, err := parseCert(certificate) + require.Nil(t, err) + require.Equal(t, name, cert.Subject.CommonName) + require.Equal(t, true, cert.BasicConstraintsValid) + signee, err := ParseSigner(pk) + require.Nil(t, err) + certID, err := keyID(signee.Public()) + require.Nil(t, err) + require.Equal(t, certID, cert.SubjectKeyId) + caID, err := keyID(signer.Public()) + require.Nil(t, err) + require.Equal(t, caID, cert.AuthorityKeyId) + require.Contains(t, cert.Issuer.CommonName, "Consul Agent CA") + 
require.Equal(t, false, cert.IsCA) + + require.WithinDuration(t, cert.NotBefore, time.Now(), time.Minute) + require.WithinDuration(t, cert.NotAfter, time.Now().AddDate(0, 0, 365), time.Minute) + + require.Equal(t, x509.KeyUsageDigitalSignature|x509.KeyUsageKeyEncipherment, cert.KeyUsage) + require.Equal(t, extKeyUsage, cert.ExtKeyUsage) + + // https://github.com/golang/go/blob/10538a8f9e2e718a47633ac5a6e90415a2c3f5f1/src/crypto/x509/verify.go#L414 + require.Equal(t, DNSNames, cert.DNSNames) + require.True(t, IPAddresses[0].Equal(cert.IPAddresses[0])) +} diff --git a/nomad/rpc.go b/nomad/rpc.go index 5f1b83d127c..973161cb96d 100644 --- a/nomad/rpc.go +++ b/nomad/rpc.go @@ -238,6 +238,11 @@ func (r *rpcHandler) handleConn(ctx context.Context, conn net.Conn, rpcCtx *RPCC case pool.RpcRaft: metrics.IncrCounter([]string{"nomad", "rpc", "raft_handoff"}, 1) + // Ensure that when TLS is configured, only certificates from `server.<region>.nomad` are accepted for Raft connections. + if err := r.validateRaftTLS(rpcCtx); err != nil { + conn.Close() + return + } r.raftLayer.Handoff(ctx, conn) case pool.RpcMultiplex: @@ -825,3 +830,38 @@ RUN_QUERY: } return err } + +func (r *rpcHandler) validateRaftTLS(rpcCtx *RPCContext) error { + // TLS is not configured or not to be enforced + tlsConf := r.config.TLSConfig + if !tlsConf.EnableRPC || !tlsConf.VerifyServerHostname || tlsConf.RPCUpgradeMode { + return nil + } + + // defensive conditions: these should have already been enforced by handleConn + if rpcCtx == nil || !rpcCtx.TLS { + return errors.New("non-TLS connection attempted") + } + if len(rpcCtx.VerifiedChains) == 0 || len(rpcCtx.VerifiedChains[0]) == 0 { + // this should never happen, as rpcNameAndRegionValidate should have enforced it + return errors.New("missing cert info") + } + + // check that `server.<region>.nomad` is present in cert + expected := "server." 
+ r.Region() + ".nomad" + + cert := rpcCtx.VerifiedChains[0][0] + for _, dnsName := range cert.DNSNames { + if dnsName == expected { + // Certificate is valid for the expected name + return nil + } + } + if cert.Subject.CommonName == expected { + // Certificate is valid for the expected name + return nil + } + + r.logger.Warn("unauthorized connection", "required_hostname", expected, "found", cert.DNSNames) + return fmt.Errorf("certificate is invalid for expected role or region: %q", expected) +} diff --git a/nomad/rpc_test.go b/nomad/rpc_test.go index f92087892d0..f08313144f7 100644 --- a/nomad/rpc_test.go +++ b/nomad/rpc_test.go @@ -3,13 +3,17 @@ package nomad import ( "context" "crypto/tls" + "crypto/x509" + "encoding/pem" "errors" "fmt" "io" + "io/ioutil" "net" "net/rpc" "os" "path" + "path/filepath" "testing" "time" @@ -1014,3 +1018,228 @@ func TestRPC_Limits_Streaming(t *testing.T) { require.NoError(t, err) }) } + +func TestRPC_TLS_Enforcement(t *testing.T) { + t.Parallel() + + defer func() { + //TODO Avoid panics from logging during shutdown + time.Sleep(1 * time.Second) + }() + + dir := tmpDir(t) + defer os.RemoveAll(dir) + + caPEM, pk, err := tlsutil.GenerateCA(tlsutil.CAOpts{Days: 5, Domain: "nomad"}) + require.NoError(t, err) + + err = ioutil.WriteFile(filepath.Join(dir, "ca.pem"), []byte(caPEM), 0600) + require.NoError(t, err) + + nodeID := 1 + newCert := func(t *testing.T, name string) string { + t.Helper() + + node := fmt.Sprintf("node%d", nodeID) + nodeID++ + signer, err := tlsutil.ParseSigner(pk) + require.NoError(t, err) + + pem, key, err := tlsutil.GenerateCert(tlsutil.CertOpts{ + Signer: signer, + CA: caPEM, + Name: name, + Days: 5, + DNSNames: []string{node + "." 
+ name, name, "localhost"}, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + }) + require.NoError(t, err) + + err = ioutil.WriteFile(filepath.Join(dir, node+"-"+name+".pem"), []byte(pem), 0600) + require.NoError(t, err) + err = ioutil.WriteFile(filepath.Join(dir, node+"-"+name+".key"), []byte(key), 0600) + require.NoError(t, err) + + return filepath.Join(dir, node+"-"+name) + } + + connect := func(t *testing.T, s *Server, c *config.TLSConfig) net.Conn { + conn, err := net.DialTimeout("tcp", s.config.RPCAddr.String(), time.Second) + require.NoError(t, err) + + // configure TLS + _, err = conn.Write([]byte{byte(pool.RpcTLS)}) + require.NoError(t, err) + + // Client TLS verification isn't necessary for + // our assertions + tlsConf, err := tlsutil.NewTLSConfiguration(c, true, true) + require.NoError(t, err) + outTLSConf, err := tlsConf.OutgoingTLSConfig() + require.NoError(t, err) + outTLSConf.InsecureSkipVerify = true + + tlsConn := tls.Client(conn, outTLSConf) + require.NoError(t, tlsConn.Handshake()) + + return tlsConn + } + + nomadRPC := func(t *testing.T, s *Server, c *config.TLSConfig) error { + conn := connect(t, s, c) + defer conn.Close() + _, err := conn.Write([]byte{byte(pool.RpcNomad)}) + require.NoError(t, err) + + codec := pool.NewClientCodec(conn) + + arg := struct{}{} + var out struct{} + return msgpackrpc.CallWithCodec(codec, "Status.Ping", arg, &out) + } + + parseCert := func(t *testing.T, path string) *x509.Certificate { + bytes, err := ioutil.ReadFile(path) + require.NoError(t, err) + block, _ := pem.Decode([]byte(bytes)) + require.NoError(t, err) + cert, err := x509.ParseCertificate(block.Bytes) + require.NoError(t, err) + + return cert + } + + raftRPC := func(t *testing.T, s *Server, c *config.TLSConfig) error { + // TODO: Actually make an RPC Call + // Raft Layer requires a wellformed RPC, otherwise, the connection + // is closed immediately - similar to the RaftTLS failure + + clientCert := 
parseCert(t, c.CertFile) + rootCert := parseCert(t, c.CAFile) + + rpcCtx := &RPCContext{ + TLS: true, + VerifiedChains: [][]*x509.Certificate{ + {clientCert, rootCert}, + }, + } + + return s.validateRaftTLS(rpcCtx) + } + + // generate server cert + serverCert := newCert(t, "server.global.nomad") + + mtlsS, cleanup := TestServer(t, func(c *Config) { + c.TLSConfig = &config.TLSConfig{ + EnableRPC: true, + VerifyServerHostname: true, + CAFile: filepath.Join(dir, "ca.pem"), + CertFile: serverCert + ".pem", + KeyFile: serverCert + ".key", + } + }) + defer cleanup() + + nonVerifyS, cleanup := TestServer(t, func(c *Config) { + c.TLSConfig = &config.TLSConfig{ + EnableRPC: true, + VerifyServerHostname: false, + CAFile: filepath.Join(dir, "ca.pem"), + CertFile: serverCert + ".pem", + KeyFile: serverCert + ".key", + } + }) + defer cleanup() + + // When VerifyServerHostname is enabled: + // Only all servers and local clients can make RPC requests + // Only local servers can connect to the Raft layer + cases := []struct { + name string + cn string + canRPC bool + canRaft bool + }{ + { + name: "local server", + cn: "server.global.nomad", + canRPC: true, + canRaft: true, + }, + { + name: "local client", + cn: "client.global.nomad", + canRPC: true, + canRaft: false, + }, + { + name: "other region server", + cn: "server.other.nomad", + canRPC: true, + canRaft: false, + }, + { + name: "other client server", + cn: "client.other.nomad", + canRPC: false, + canRaft: false, + }, + { + name: "irrelevant cert", + cn: "nomad.example.com", + canRPC: false, + canRaft: false, + }, + { + name: "globs", + cn: "*.global.nomad", + canRPC: false, + canRaft: false, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + certPath := newCert(t, tc.cn) + + cfg := &config.TLSConfig{ + EnableRPC: true, + VerifyServerHostname: true, + CAFile: filepath.Join(dir, "ca.pem"), + CertFile: certPath + ".pem", + KeyFile: certPath + ".key", + } + + t.Run("nomad RPC: verify_hostname=true", 
func(t *testing.T) { + err := nomadRPC(t, mtlsS, cfg) + + if tc.canRPC { + require.NoError(t, err) + } else { + require.Error(t, err) + require.Contains(t, err.Error(), "bad certificate") + } + }) + t.Run("nomad RPC: verify_hostname=false", func(t *testing.T) { + err := nomadRPC(t, nonVerifyS, cfg) + require.NoError(t, err) + }) + + t.Run("Raft RPC: verify_hostname=true", func(t *testing.T) { + err := raftRPC(t, mtlsS, cfg) + + if tc.canRaft { + require.NoError(t, err) + } else { + require.Error(t, err) + require.Contains(t, err.Error(), "invalid for expected role or region") + } + }) + t.Run("Raft RPC: verify_hostname=false", func(t *testing.T) { + err := raftRPC(t, nonVerifyS, cfg) + require.NoError(t, err) + }) + }) + } +}