From aeef705367825316a6fc1751160a3e718db19d04 Mon Sep 17 00:00:00 2001 From: Easwar Swaminathan Date: Mon, 14 Sep 2020 16:11:59 -0700 Subject: [PATCH 1/8] credentials/xds: Implementation of client-side xDS credentials. --- credentials/xds/xds.go | 339 +++++++++++++++++++++++ credentials/xds/xds_test.go | 532 ++++++++++++++++++++++++++++++++++++ 2 files changed, 871 insertions(+) create mode 100644 credentials/xds/xds.go create mode 100644 credentials/xds/xds_test.go diff --git a/credentials/xds/xds.go b/credentials/xds/xds.go new file mode 100644 index 000000000000..7e8e35439dd3 --- /dev/null +++ b/credentials/xds/xds.go @@ -0,0 +1,339 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package xds provides a transport credentials implementation where the +// security configuration is pushed by a management server using xDS APIs. +// +// All APIs in this package are EXPERIMENTAL. +package xds + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "net" + "sync" + + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/tls/certprovider" + credinternal "google.golang.org/grpc/internal/credentials" +) + +// ClientOptions contains parameters to configure a new client-side xDS +// credentials implementation. +type ClientOptions struct { + // FallbackCreds specifies the fallback credentials to be used when either + // the `xds` scheme is not used in the user's dial target or when the xDS + // server does not return any security configuration. Attempts to create + // client credentials without a fallback credentials will fail. + FallbackCreds credentials.TransportCredentials +} + +// NewClientCredentials returns a new client-side transport credentials +// implementation which uses xDS APIs to fetch its security configuration. +func NewClientCredentials(opts ClientOptions) (credentials.TransportCredentials, error) { + if opts.FallbackCreds == nil { + return nil, errors.New("missing fallback credentials") + } + return &credsImpl{ + isClient: true, + fallback: opts.FallbackCreds, + }, nil +} + +// credsImpl is an implementation of the credentials.TransportCredentials +// interface which uses xDS APIs to fetch its security configuration. +type credsImpl struct { + isClient bool + fallback credentials.TransportCredentials +} + +// handshakeCtxKey is the context key used to store HandshakeInfo values. +type handshakeCtxKey struct{} + +// HandshakeInfo wraps all the security configuration required by client and +// server handshake methods in credsImpl. The xDS implementation will be +// responsible for populating these fields. +// +// Safe for concurrent access. +type HandshakeInfo struct { + mu sync.Mutex + rootProvider certprovider.Provider + identityProvider certprovider.Provider + acceptedSANs []string // Only on the client side. +} + +// SetRootCertProvider updates the root certificate provider. +func (chi *HandshakeInfo) SetRootCertProvider(root certprovider.Provider) { + chi.mu.Lock() + chi.rootProvider = root + chi.mu.Unlock() +} + +// SetIdentityCertProvider updates the identity certificate provider. +func (chi *HandshakeInfo) SetIdentityCertProvider(identity certprovider.Provider) { + chi.mu.Lock() + chi.identityProvider = identity + chi.mu.Unlock() +} + +// SetAcceptedSANs updates the list of accepted SANs. +func (chi *HandshakeInfo) SetAcceptedSANs(sans []string) { + chi.mu.Lock() + chi.acceptedSANs = sans + chi.mu.Unlock() +} + +func (chi *HandshakeInfo) validate(isClient bool) error { + chi.mu.Lock() + defer chi.mu.Unlock() + + // On the client side, rootProvider is mandatory. IdentityProvider is + // optional based on whether the client is doing TLS or mTLS. + if isClient && chi.rootProvider == nil { + return errors.New("root certificate provider is missing") + } + + // On the server side, identityProvider is mandatory. RootProvider is + // optional based on whether the server is doing TLS or mTLS. + if !isClient && chi.identityProvider == nil { + return errors.New("identity certificate provider is missing") + } + + return nil +} + +func (chi *HandshakeInfo) matchingSANExists(cert *x509.Certificate) bool { + var sans []string + // SANs can be specified in any of these four fields on the parsed cert. + sans = append(sans, cert.DNSNames...) + sans = append(sans, cert.EmailAddresses...) + for _, ip := range cert.IPAddresses { + sans = append(sans, ip.String()) + } + for _, uri := range cert.URIs { + sans = append(sans, uri.String()) + } + + chi.mu.Lock() + defer chi.mu.Unlock() + for _, san := range sans { + for _, asan := range chi.acceptedSANs { + if san == asan { + return true + } + } + } + return false +} + +// NewHandshakeInfo returns a new instance of HandshakeInfo with the given root +// and identity certificate providers. +func NewHandshakeInfo(root, identity certprovider.Provider, sans ...string) *HandshakeInfo { + return &HandshakeInfo{ + rootProvider: root, + identityProvider: identity, + acceptedSANs: sans, + } +} + +// NewContextWithHandshakeInfo returns a copy of the parent context with the +// provided HandshakeInfo stored as a value. +func NewContextWithHandshakeInfo(parent context.Context, info *HandshakeInfo) context.Context { + return context.WithValue(parent, handshakeCtxKey{}, info) +} + +// handshakeInfoFromCtx returns a pointer to the HandshakeInfo stored in ctx. +func handshakeInfoFromCtx(ctx context.Context) *HandshakeInfo { + val, ok := ctx.Value(handshakeCtxKey{}).(*HandshakeInfo) + if !ok { + return nil + } + return val +} + +// ClientHandshake performs the TLS handshake on the client-side. +// +// It looks for the presence of a HandshakeInfo value in the passed in context +// (added using a call to NewContextWithHandshakeInfo()), and retrieves identity +// and root certificates from there. It also retrieves a list of acceptable SANs +// and uses a custom verification function to validate the certificate presented +// by the peer. It uses fallback credentials if no HandshakeInfo is present in +// the passed in context. +func (c *credsImpl) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + if !c.isClient { + return nil, nil, errors.New("ClientHandshake() is not supported for server credentials") + } + + chi := handshakeInfoFromCtx(ctx) + if chi == nil { + // A missing handshake info in the provided context could mean either + // the user did not specify an `xds` scheme in their dial target or that + // the xDS server did not provide any security configuration. In both of + // these cases, we use the fallback credentials specified by the user. + return c.fallback.ClientHandshake(ctx, authority, rawConn) + } + if err := chi.validate(c.isClient); err != nil { + return nil, nil, err + } + + // We build the tls.Config with the following values + // 1. Root certificate as returned by the root provider. + // 2. Identity certificate as returned by the identity provider. This may be + // empty on the client side, if the client is not doing mTLS. + // 3. InsecureSkipVerify to true. Certificates used in Mesh environments + // usually contains the identity of the workload presenting the + // certificate as a SAN (instead of a hostname in the CommonName field). + // This means that normal certificate verification as done by the + // standard library will fail. + // 4. Key usage to match whether client/server usage. + // 5. A `VerifyPeerCertificate` function which performs normal peer + // cert verification using configured roots, and the custom SAN checks. + var certs []tls.Certificate + var roots *x509.CertPool + err := func() error { + // We use this anonymous function trick to be able to defer the unlock. + chi.mu.Lock() + defer chi.mu.Unlock() + + if chi.rootProvider != nil { + km, err := chi.rootProvider.KeyMaterial(ctx) + if err != nil { + return fmt.Errorf("fetching root certificates failed: %v", err) + } + roots = km.Roots + } + if chi.identityProvider != nil { + km, err := chi.identityProvider.KeyMaterial(ctx) + if err != nil { + return fmt.Errorf("fetching identity certificates failed: %v", err) + } + certs = km.Certs + } + return nil + }() + if err != nil { + return nil, nil, err + } + + cfg := &tls.Config{ + Certificates: certs, + InsecureSkipVerify: true, + } + var keyUsages []x509.ExtKeyUsage + if c.isClient { + cfg.RootCAs = roots + keyUsages = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth} + } else { + cfg.ClientCAs = roots + keyUsages = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth} + } + cfg.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + // Parse all raw certificates presented by the peer. + var certs []*x509.Certificate + for _, rc := range rawCerts { + cert, err := x509.ParseCertificate(rc) + if err != nil { + return err + } + certs = append(certs, cert) + } + + // Build the intermediates list and verify that the leaf certificate + // is signed by one of the root certificates. + intermediates := x509.NewCertPool() + for _, cert := range certs[1:] { + intermediates.AddCert(cert) + } + opts := x509.VerifyOptions{ + Roots: roots, + Intermediates: intermediates, + KeyUsages: keyUsages, + } + if _, err := certs[0].Verify(opts); err != nil { + return err + } + // The SANs sent by the MeshCA are encoded as SPIFFE IDs. We need to + // only look at the SANs on the leaf cert. + if !chi.matchingSANExists(certs[0]) { + return fmt.Errorf("SANs received in leaf certificate %+v does not match any of the accepted SANs", certs[0]) + } + return nil + } + + // Perform the TLS handshake with the tls.Config that we have. We run the + // actual Handshake() function in a goroutine because we need to respect the + // deadline specified on the passed in context, and we need a way to cancel + // the handshake if the context is cancelled. + var conn *tls.Conn + if c.isClient { + conn = tls.Client(rawConn, cfg) + } else { + conn = tls.Server(rawConn, cfg) + } + + errCh := make(chan error, 1) + go func() { + errCh <- conn.Handshake() + close(errCh) + }() + select { + case err := <-errCh: + if err != nil { + conn.Close() + return nil, nil, err + } + case <-ctx.Done(): + conn.Close() + return nil, nil, ctx.Err() + } + info := credentials.TLSInfo{ + State: conn.ConnectionState(), + CommonAuthInfo: credentials.CommonAuthInfo{ + SecurityLevel: credentials.PrivacyAndIntegrity, + }, + SPIFFEID: credinternal.SPIFFEIDFromState(conn.ConnectionState()), + } + return credinternal.WrapSyscallConn(rawConn, conn), info, nil +} + +// ServerHandshake performs the TLS handshake on the server-side. +func (c *credsImpl) ServerHandshake(net.Conn) (net.Conn, credentials.AuthInfo, error) { + if c.isClient { + return nil, nil, errors.New("ServerHandshake is not supported for client credentials") + } + // TODO(easwars): Implement along with server side xDS implementation. + return nil, nil, errors.New("not implemented") +} + +// Info provides the ProtocolInfo of this TransportCredentials. +func (c *credsImpl) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{SecurityProtocol: "tls"} +} + +// Clone makes a copy of this TransportCredentials. +func (c *credsImpl) Clone() credentials.TransportCredentials { + clone := *c + return &clone +} + +func (c *credsImpl) OverrideServerName(_ string) error { + return errors.New("serverName for peer validation must be configured as a list of acceptable SANs") +} diff --git a/credentials/xds/xds_test.go b/credentials/xds/xds_test.go new file mode 100644 index 000000000000..b08d18a9ab38 --- /dev/null +++ b/credentials/xds/xds_test.go @@ -0,0 +1,532 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xds + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "io/ioutil" + "net" + "strings" + "testing" + "time" + + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/tls/certprovider" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/testdata" +) + +const ( + defaultTestTimeout = 1 * time.Second + defaultTestCertSAN = "*.test.example.com" + authority = "authority" +) + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +// Helper function to create a real TLS client credentials which is used as +// fallback credentials from multiple tests. +func makeFallbackClientCreds(t *testing.T) credentials.TransportCredentials { + creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com") + if err != nil { + t.Fatal(err) + } + return creds +} + +// testServer is a no-op server which listens on a local TCP port for incoming +// connections, and performs a manual TLS handshake on the received raw +// connection using a user specified handshake function. It then makes the +// result of the handshake operation available through a channel for tests to +// inspect. Tests should stop the testServer as part of their cleanup. +type testServer struct { + lis net.Listener + address string // Listening address of the test server. + handshakeFunc testHandshakeFunc // Test specified handshake function. + hsResult *testutils.Channel // Channel to deliver handshake results. +} + +// handshakeResult wraps the result of the handshake operation on the test +// server. It consists of TLS connection state and an error, if the handshake +// failed. This result is delivered on the `hsResult` channel on the testServer. +type handshakeResult struct { + connState tls.ConnectionState + err error +} + +// Configurable handshake function for the testServer. Tests can set this to +// simulate different conditions like handshake success, failure, timeout etc. +type testHandshakeFunc func(net.Conn) handshakeResult + +// newTestServerWithHandshakeFunc starts a new testServer which listens for +// connections on a local TCP port, and uses the provided custom handshake +// function to perform TLS handshake. +func newTestServerWithHandshakeFunc(f testHandshakeFunc) *testServer { + ts := &testServer{ + handshakeFunc: f, + hsResult: testutils.NewChannel(), + } + ts.start() + return ts +} + +// starts actually starts listening on a local TCP port, and spawns a goroutine +// to handle new connections. +func (ts *testServer) start() error { + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + return err + } + ts.lis = lis + ts.address = lis.Addr().String() + go ts.handleConn() + return nil +} + +// handleconn accepts a new raw connection, and invokes the test provided +// handshake function to perform TLS handshake, and returns the result on the +// `hsResult` channel. +func (ts *testServer) handleConn() { + for { + rawConn, err := ts.lis.Accept() + if err != nil { + // Once the listeners closed, Accept() will return with an error. + return + } + hsr := ts.handshakeFunc(rawConn) + ts.hsResult.Send(hsr) + } +} + +// stop closes the associated listener which causes the connection handling +// goroutine to exit. +func (ts *testServer) stop() { + ts.lis.Close() +} + +// A handshake function which simulates a handshake timeout. Tests usually pass +// `defaultTestTimeout` to the ClientHandshake() method. This function just +// hangs around for twice that duration, thus making sure that the context +// passes to the credentials code times out. +func testServerTLSHandshakeTimeout(_ net.Conn) handshakeResult { + ctx, cancel := context.WithTimeout(context.Background(), 2*defaultTestTimeout) + <-ctx.Done() + cancel() + return handshakeResult{err: ctx.Err()} +} + +// A handshake function which simulates a successful handshake without client +// authentication (server does not request for client certificate during the +// handshake here). +func testServerTLSHandshake(rawConn net.Conn) handshakeResult { + cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) + if err != nil { + return handshakeResult{err: err} + } + cfg := &tls.Config{Certificates: []tls.Certificate{cert}} + conn := tls.Server(rawConn, cfg) + if err := conn.Handshake(); err != nil { + return handshakeResult{err: err} + } + return handshakeResult{connState: conn.ConnectionState()} +} + +// A handshake function which simulates a successful handshake with mutual +// authentication. +func testServerMutualTLSHandshake(rawConn net.Conn) handshakeResult { + cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) + if err != nil { + return handshakeResult{err: err} + } + pemData, err := ioutil.ReadFile(testdata.Path("x509/client_ca_cert.pem")) + if err != nil { + return handshakeResult{err: err} + } + roots := x509.NewCertPool() + roots.AppendCertsFromPEM(pemData) + cfg := &tls.Config{ + Certificates: []tls.Certificate{cert}, + ClientCAs: roots, + } + conn := tls.Server(rawConn, cfg) + if err := conn.Handshake(); err != nil { + return handshakeResult{err: err} + } + return handshakeResult{connState: conn.ConnectionState()} +} + +// fakeProvider is an implementation of the certprovider.Provider interface +// which returns the configured key material and error in calls to +// KeyMaterial(). +type fakeProvider struct { + km *certprovider.KeyMaterial + err error +} + +func (f *fakeProvider) KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) { + return f.km, f.err +} + +func (f *fakeProvider) Close() {} + +// makeIdentityProvider creates a new instance of the fakeProvider returning the +// identity key material specified in the provider file paths. +func makeIdentityProvider(t *testing.T, certPath, keyPath string) certprovider.Provider { + t.Helper() + cert, err := tls.LoadX509KeyPair(testdata.Path(certPath), testdata.Path(keyPath)) + if err != nil { + t.Fatal(err) + } + return &fakeProvider{km: &certprovider.KeyMaterial{Certs: []tls.Certificate{cert}}} +} + +// makeRootProvider creates a new instance of the fakeProvider returning the +// root key material specified in the provider file paths. +func makeRootProvider(t *testing.T, caPath string) *fakeProvider { + pemData, err := ioutil.ReadFile(testdata.Path(caPath)) + if err != nil { + t.Fatal(err) + } + roots := x509.NewCertPool() + roots.AppendCertsFromPEM(pemData) + return &fakeProvider{km: &certprovider.KeyMaterial{Roots: roots}} +} + +// newTestContextWithHandshakeInfo returns a copy of the passed in context with +// HandshakeInfo context value added to it. +func newTestContextWithHandshakeInfo(ctx context.Context, root, identity certprovider.Provider, sans ...string) context.Context { + info := NewHandshakeInfo(root, identity, sans...) + return NewContextWithHandshakeInfo(ctx, info) +} + +// compareAuthInfo compares the AuthInfo received on the client side after a +// successful handshake with the authInfo available on the testServer. +func compareAuthInfo(ts *testServer, ai credentials.AuthInfo) error { + if ai.AuthType() != "tls" { + return fmt.Errorf("ClientHandshake returned authType %q, want %q", ai.AuthType(), "tls") + } + info, ok := ai.(credentials.TLSInfo) + if !ok { + return fmt.Errorf("ClientHandshake returned authInfo of type %T, want %T", ai, credentials.TLSInfo{}) + } + gotState := info.State + + // Read the handshake result from the testServer which contains the TLS + // connection state and compare it with the one received on the client-side. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + val, err := ts.hsResult.Receive(ctx) + if err != nil { + return fmt.Errorf("testServer failed to return handshake result: %v", err) + } + hsr := val.(handshakeResult) + if hsr.err != nil { + return fmt.Errorf("testServer handshake failure: %v", hsr.err) + } + // AuthInfo contains a variety of information. We only verify a subset here. + // This is the same subset which is verified in TLS credentials tests. + if err := compareConnState(gotState, hsr.connState); err != nil { + return err + } + return nil +} + +func compareConnState(got, want tls.ConnectionState) error { + switch { + case got.Version != want.Version: + return fmt.Errorf("TLS.ConnectionState got Version: %v, want: %v", got.Version, want.Version) + case got.HandshakeComplete != want.HandshakeComplete: + return fmt.Errorf("TLS.ConnectionState got HandshakeComplete: %v, want: %v", got.HandshakeComplete, want.HandshakeComplete) + case got.CipherSuite != want.CipherSuite: + return fmt.Errorf("TLS.ConnectionState got CipherSuite: %v, want: %v", got.CipherSuite, want.CipherSuite) + case got.NegotiatedProtocol != want.NegotiatedProtocol: + return fmt.Errorf("TLS.ConnectionState got NegotiatedProtocol: %v, want: %v", got.NegotiatedProtocol, want.NegotiatedProtocol) + } + return nil +} + +// TestClientCredsWithoutFallback verifies that the call to +// NewClientCredentials() fails when no fallback is specified. +func (s) TestClientCredsWithoutFallback(t *testing.T) { + if _, err := NewClientCredentials(ClientOptions{}); err == nil { + t.Fatal("NewClientCredentials() succeeded without specifying fallback") + } +} + +// TestClientCredsInvalidHandshakeInfo verifies scenarios where the passed in +// HandshakeInfo is invalid because it does not contain the expected certificate +// providers. +func (s) TestClientCredsInvalidHandshakeInfo(t *testing.T) { + opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)} + creds, err := NewClientCredentials(opts) + if err != nil { + t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err) + } + + pCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + ctx := newTestContextWithHandshakeInfo(pCtx, nil, nil) + if _, _, err := creds.ClientHandshake(ctx, authority, nil); err == nil { + t.Fatal("ClientHandshake succeeded without certificate providers in HandshakeInfo") + } + + ctx = newTestContextWithHandshakeInfo(pCtx, nil, &fakeProvider{}) + if _, _, err := creds.ClientHandshake(ctx, authority, nil); err == nil { + t.Fatal("ClientHandshake succeeded without root certificate provider in HandshakeInfo") + } +} + +// TestClientCredsProviderFailure verifies the cases where an expected +// certificate provider is missing in the HandshakeInfo value in the context. +func (s) TestClientCredsProviderFailure(t *testing.T) { + opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)} + creds, err := NewClientCredentials(opts) + if err != nil { + t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err) + } + + tests := []struct { + desc string + rootProvider certprovider.Provider + identityProvider certprovider.Provider + wantErr string + }{ + { + desc: "erroring root provider", + rootProvider: &fakeProvider{err: errors.New("root provider error")}, + wantErr: "root provider error", + }, + { + desc: "erroring identity provider", + rootProvider: &fakeProvider{km: &certprovider.KeyMaterial{}}, + identityProvider: &fakeProvider{err: errors.New("identity provider error")}, + wantErr: "identity provider error", + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, test.identityProvider) + if _, _, err := creds.ClientHandshake(ctx, authority, nil); !strings.Contains(err.Error(), test.wantErr) { + t.Fatalf("ClientHandshake() returned error: %q, wantErr: %q", err, test.wantErr) + } + }) + } +} + +// TestClientCredsSuccess verifies successful client handshake cases. +func (s) TestClientCredsSuccess(t *testing.T) { + tests := []struct { + desc string + handshakeFunc testHandshakeFunc + rootProvider certprovider.Provider + identityProvider certprovider.Provider + }{ + { + // Since we don't specify rootProvider and identityProvider here, + // the test does not add a HandshakeInfo context value, and thereby + // the ClientHandshake() method will delegate to the fallback. + desc: "fallback", + handshakeFunc: testServerTLSHandshake, + }, + { + desc: "TLS", + handshakeFunc: testServerTLSHandshake, + rootProvider: makeRootProvider(t, "x509/server_ca_cert.pem"), + }, + { + desc: "mTLS", + handshakeFunc: testServerMutualTLSHandshake, + rootProvider: makeRootProvider(t, "x509/server_ca_cert.pem"), + identityProvider: makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + ts := newTestServerWithHandshakeFunc(test.handshakeFunc) + defer ts.stop() + + opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)} + creds, err := NewClientCredentials(opts) + if err != nil { + t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err) + } + + conn, err := net.Dial("tcp", ts.address) + if err != nil { + t.Fatalf("net.Dial(%s) failed: %v", ts.address, err) + } + defer conn.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if test.rootProvider != nil || test.identityProvider != nil { + ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, test.identityProvider, defaultTestCertSAN) + } + _, ai, err := creds.ClientHandshake(ctx, authority, conn) + if err != nil { + t.Fatalf("ClientHandshake() returned failed: %q", err) + } + if err := compareAuthInfo(ts, ai); err != nil { + t.Fatal(err) + } + }) + } +} + +// TestClientCredsHandshakeFailure verifies different handshake failure cases. +func (s) TestClientCredsHandshakeFailure(t *testing.T) { + tests := []struct { + desc string + handshakeFunc testHandshakeFunc + rootProvider certprovider.Provider + san string + wantErr string + }{ + { + desc: "cert validation failure", + handshakeFunc: testServerTLSHandshake, + rootProvider: makeRootProvider(t, "x509/client_ca_cert.pem"), + san: defaultTestCertSAN, + wantErr: "x509: certificate signed by unknown authority", + }, + { + desc: "handshake times out", + handshakeFunc: testServerTLSHandshakeTimeout, + rootProvider: makeRootProvider(t, "x509/server_ca_cert.pem"), + san: defaultTestCertSAN, + wantErr: "context deadline exceeded", + }, + { + desc: "SAN mismatch", + handshakeFunc: testServerTLSHandshake, + rootProvider: makeRootProvider(t, "x509/server_ca_cert.pem"), + san: "bad-san", + wantErr: "does not match any of the accepted SANs", + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + ts := newTestServerWithHandshakeFunc(test.handshakeFunc) + defer ts.stop() + + opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)} + creds, err := NewClientCredentials(opts) + if err != nil { + t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err) + } + + conn, err := net.Dial("tcp", ts.address) + if err != nil { + t.Fatalf("net.Dial(%s) failed: %v", ts.address, err) + } + defer conn.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, nil, test.san) + if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil || !strings.Contains(err.Error(), test.wantErr) { + t.Fatalf("ClientHandshake() returned %q, wantErr %q", err, test.wantErr) + } + }) + } +} + +// TestClientCredsProviderSwitch verifies the case where the first attempt of +// ClientHandshake fails because of a handshake failure. Then we update the +// certificate provider and the second attempt succeeds. This is an +// approximation of the flow of events when the control plane specifies new +// security config which results in new certificate providers being used. +func (s) TestClientCredsProviderSwitch(t *testing.T) { + ts := newTestServerWithHandshakeFunc(testServerTLSHandshake) + defer ts.stop() + + opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)} + creds, err := NewClientCredentials(opts) + if err != nil { + t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err) + } + + conn, err := net.Dial("tcp", ts.address) + if err != nil { + t.Fatalf("net.Dial(%s) failed: %v", ts.address, err) + } + defer conn.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + root1 := makeRootProvider(t, "x509/client_ca_cert.pem") + handshakeInfo := NewHandshakeInfo(root1, nil, defaultTestCertSAN) + ctx = NewContextWithHandshakeInfo(ctx, handshakeInfo) + if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil { + t.Fatal("ClientHandshake() succeeded when expected to fail") + } + // Drain the result channel on the test server so that we can inspect the + // result for the next handshake. + _, err = ts.hsResult.Receive(ctx) + if err != nil { + t.Errorf("testServer failed to return handshake result: %v", err) + } + + conn, err = net.Dial("tcp", ts.address) + if err != nil { + t.Fatalf("net.Dial(%s) failed: %v", ts.address, err) + } + defer conn.Close() + + root2 := makeRootProvider(t, "x509/server_ca_cert.pem") + handshakeInfo.SetRootCertProvider(root2) + _, ai, err := creds.ClientHandshake(ctx, authority, conn) + if err != nil { + t.Fatalf("ClientHandshake() returned failed: %q", err) + } + if err := compareAuthInfo(ts, ai); err != nil { + t.Fatal(err) + } +} + +// TestClone verifies the Clone() method. +func (s) TestClone(t *testing.T) { + opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)} + orig, err := NewClientCredentials(opts) + if err != nil { + t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err) + } + + // The credsImpl does not have any exported fields, and it does not make + // sense to use any cmp options to look deep into. So, all we make sure here + // is that the cloned object points to a different locaiton in memory. + if clone := orig.Clone(); clone == orig { + t.Fatal("return value from Clone() doesn't point to new credentials instance") + } +} From fc53eda4756b4101f9cfb776dfcbfc413f416cd9 Mon Sep 17 00:00:00 2001 From: Easwar Swaminathan Date: Tue, 22 Sep 2020 09:27:06 -0700 Subject: [PATCH 2/8] Remove code for server side handling in ClientHandshake. --- credentials/xds/xds.go | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/credentials/xds/xds.go b/credentials/xds/xds.go index 7e8e35439dd3..490b096de356 100644 --- a/credentials/xds/xds.go +++ b/credentials/xds/xds.go @@ -236,14 +236,7 @@ func (c *credsImpl) ClientHandshake(ctx context.Context, authority string, rawCo cfg := &tls.Config{ Certificates: certs, InsecureSkipVerify: true, - } - var keyUsages []x509.ExtKeyUsage - if c.isClient { - cfg.RootCAs = roots - keyUsages = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth} - } else { - cfg.ClientCAs = roots - keyUsages = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth} + RootCAs: roots, } cfg.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { // Parse all raw certificates presented by the peer. @@ -265,7 +258,7 @@ func (c *credsImpl) ClientHandshake(ctx context.Context, authority string, rawCo opts := x509.VerifyOptions{ Roots: roots, Intermediates: intermediates, - KeyUsages: keyUsages, + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, } if _, err := certs[0].Verify(opts); err != nil { return err From 71eaa870436849625a58c5f71d72c430506387e3 Mon Sep 17 00:00:00 2001 From: Easwar Swaminathan Date: Tue, 22 Sep 2020 10:31:23 -0700 Subject: [PATCH 3/8] Use a set instead of slice for acceptedSANs. --- credentials/xds/xds.go | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/credentials/xds/xds.go b/credentials/xds/xds.go index 490b096de356..3f831af1ef37 100644 --- a/credentials/xds/xds.go +++ b/credentials/xds/xds.go @@ -77,7 +77,7 @@ type HandshakeInfo struct { mu sync.Mutex rootProvider certprovider.Provider identityProvider certprovider.Provider - acceptedSANs []string // Only on the client side. + acceptedSANs map[string]bool // Only on the client side. } // SetRootCertProvider updates the root certificate provider. @@ -97,7 +97,10 @@ func (chi *HandshakeInfo) SetIdentityCertProvider(identity certprovider.Provider // SetAcceptedSANs updates the list of accepted SANs. func (chi *HandshakeInfo) SetAcceptedSANs(sans []string) { chi.mu.Lock() - chi.acceptedSANs = sans + chi.acceptedSANs = make(map[string]bool) + for _, san := range sans { + chi.acceptedSANs[san] = true + } chi.mu.Unlock() } @@ -135,10 +138,8 @@ func (chi *HandshakeInfo) matchingSANExists(cert *x509.Certificate) bool { chi.mu.Lock() defer chi.mu.Unlock() for _, san := range sans { - for _, asan := range chi.acceptedSANs { - if san == asan { - return true - } + if chi.acceptedSANs[san] { + return true } } return false @@ -147,10 +148,14 @@ func (chi *HandshakeInfo) matchingSANExists(cert *x509.Certificate) bool { // NewHandshakeInfo returns a new instance of HandshakeInfo with the given root // and identity certificate providers. func NewHandshakeInfo(root, identity certprovider.Provider, sans ...string) *HandshakeInfo { + acceptedSANs := make(map[string]bool) + for _, san := range sans { + acceptedSANs[san] = true + } return &HandshakeInfo{ rootProvider: root, identityProvider: identity, - acceptedSANs: sans, + acceptedSANs: acceptedSANs, } } From 8c5ba188b7d8da952e82299ed5a77da84ad1456d Mon Sep 17 00:00:00 2001 From: Easwar Swaminathan Date: Thu, 24 Sep 2020 15:38:53 -0700 Subject: [PATCH 4/8] Add a function to create the TLS config. --- credentials/xds/xds.go | 81 +++++++++++++++++++++--------------------- 1 file changed, 40 insertions(+), 41 deletions(-) diff --git a/credentials/xds/xds.go b/credentials/xds/xds.go index 3f831af1ef37..77c2b389c7a3 100644 --- a/credentials/xds/xds.go +++ b/credentials/xds/xds.go @@ -19,7 +19,10 @@ // Package xds provides a transport credentials implementation where the // security configuration is pushed by a management server using xDS APIs. // -// All APIs in this package are EXPERIMENTAL. +// Experimental +// +// Notice: All APIs in this package are EXPERIMENTAL and may be removed in a +// later release. package xds import ( @@ -97,7 +100,7 @@ func (chi *HandshakeInfo) SetIdentityCertProvider(identity certprovider.Provider // SetAcceptedSANs updates the list of accepted SANs. func (chi *HandshakeInfo) SetAcceptedSANs(sans []string) { chi.mu.Lock() - chi.acceptedSANs = make(map[string]bool) + chi.acceptedSANs = make(map[string]bool, len(sans)) for _, san := range sans { chi.acceptedSANs[san] = true } @@ -111,18 +114,48 @@ func (chi *HandshakeInfo) validate(isClient bool) error { // On the client side, rootProvider is mandatory. IdentityProvider is // optional based on whether the client is doing TLS or mTLS. if isClient && chi.rootProvider == nil { - return errors.New("root certificate provider is missing") + return errors.New("xds: CertificateProvider to fetch trusted roots is missing, cannot perform TLS handshake") } // On the server side, identityProvider is mandatory. RootProvider is // optional based on whether the server is doing TLS or mTLS. if !isClient && chi.identityProvider == nil { - return errors.New("identity certificate provider is missing") + return errors.New("xds: CertificateProvider to fetch identity certificate is missing, cannot perform TLS handshake") } return nil } +func (chi *HandshakeInfo) makeTLSConfig(ctx context.Context) (*tls.Config, error) { + chi.mu.Lock() + // Since the call to KeyMaterial() can block, we read the providers under + // the lock but call the actual function after releasing the lock. + rootProv, idProv := chi.rootProvider, chi.identityProvider + chi.mu.Unlock() + + // InsecureSkipVerify needs to be set to true because we need to perform + // custom verification to check the SAN on the received certificate. + // Currently the Go stdlib does complete verification of the cert (which + // includes hostname verification) or none. We are forced to go with the + // latter and perform the normal cert validation ourselves. + cfg := &tls.Config{InsecureSkipVerify: true} + if rootProv != nil { + km, err := rootProv.KeyMaterial(ctx) + if err != nil { + return nil, fmt.Errorf("xds: fetching trusted roots from CertificateProvider failed: %v", err) + } + cfg.RootCAs = km.Roots + } + if idProv != nil { + km, err := idProv.KeyMaterial(ctx) + if err != nil { + return nil, fmt.Errorf("xds: fetching identity certificates from CertificateProvider failed: %v", err) + } + cfg.Certificates = km.Certs + } + return cfg, nil +} + func (chi *HandshakeInfo) matchingSANExists(cert *x509.Certificate) bool { var sans []string // SANs can be specified in any of these four fields on the parsed cert. @@ -211,38 +244,10 @@ func (c *credsImpl) ClientHandshake(ctx context.Context, authority string, rawCo // 4. Key usage to match whether client/server usage. // 5. A `VerifyPeerCertificate` function which performs normal peer // cert verification using configured roots, and the custom SAN checks. - var certs []tls.Certificate - var roots *x509.CertPool - err := func() error { - // We use this anonymous function trick to be able to defer the unlock. - chi.mu.Lock() - defer chi.mu.Unlock() - - if chi.rootProvider != nil { - km, err := chi.rootProvider.KeyMaterial(ctx) - if err != nil { - return fmt.Errorf("fetching root certificates failed: %v", err) - } - roots = km.Roots - } - if chi.identityProvider != nil { - km, err := chi.identityProvider.KeyMaterial(ctx) - if err != nil { - return fmt.Errorf("fetching identity certificates failed: %v", err) - } - certs = km.Certs - } - return nil - }() + cfg, err := chi.makeTLSConfig(ctx) if err != nil { return nil, nil, err } - - cfg := &tls.Config{ - Certificates: certs, - InsecureSkipVerify: true, - RootCAs: roots, - } cfg.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { // Parse all raw certificates presented by the peer. var certs []*x509.Certificate @@ -261,7 +266,7 @@ func (c *credsImpl) ClientHandshake(ctx context.Context, authority string, rawCo intermediates.AddCert(cert) } opts := x509.VerifyOptions{ - Roots: roots, + Roots: cfg.RootCAs, Intermediates: intermediates, KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, } @@ -280,13 +285,7 @@ func (c *credsImpl) ClientHandshake(ctx context.Context, authority string, rawCo // actual Handshake() function in a goroutine because we need to respect the // deadline specified on the passed in context, and we need a way to cancel // the handshake if the context is cancelled. - var conn *tls.Conn - if c.isClient { - conn = tls.Client(rawConn, cfg) - } else { - conn = tls.Server(rawConn, cfg) - } - + conn := tls.Client(rawConn, cfg) errCh := make(chan error, 1) go func() { errCh <- conn.Handshake() From 2c4475ba4e7c9a4c8be284442ef39d97f0e111bd Mon Sep 17 00:00:00 2001 From: Easwar Swaminathan Date: Thu, 24 Sep 2020 16:31:26 -0700 Subject: [PATCH 5/8] Pass handshake info through address attributes. --- credentials/xds/xds.go | 117 ++++++++++++++++++++---------------- credentials/xds/xds_test.go | 30 +++++++-- 2 files changed, 90 insertions(+), 57 deletions(-) diff --git a/credentials/xds/xds.go b/credentials/xds/xds.go index 77c2b389c7a3..f2cf8e5095cf 100644 --- a/credentials/xds/xds.go +++ b/credentials/xds/xds.go @@ -34,6 +34,7 @@ import ( "net" "sync" + "google.golang.org/grpc/attributes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/tls/certprovider" credinternal "google.golang.org/grpc/internal/credentials" @@ -68,8 +69,21 @@ type credsImpl struct { fallback credentials.TransportCredentials } -// handshakeCtxKey is the context key used to store HandshakeInfo values. -type handshakeCtxKey struct{} +// handshakeAttrKey is the type used as the key to store HandshakeInfo in +// the Attributes field of resolver.Address. +type handshakeAttrKey struct{} + +// SetHandshakeInfo returns a copy of attr in which is updated with hInfo. +func SetHandshakeInfo(attr *attributes.Attributes, hInfo *HandshakeInfo) *attributes.Attributes { + return attr.WithValues(handshakeAttrKey{}, hInfo) +} + +// GetHandshakeInfo returns a pointer to the HandshakeInfo stored in attr. +func GetHandshakeInfo(attr *attributes.Attributes) *HandshakeInfo { + v := attr.Value(handshakeAttrKey{}) + hi, _ := v.(*HandshakeInfo) + return hi +} // HandshakeInfo wraps all the security configuration required by client and // server handshake methods in credsImpl. The xDS implementation will be @@ -84,54 +98,54 @@ type HandshakeInfo struct { } // SetRootCertProvider updates the root certificate provider. -func (chi *HandshakeInfo) SetRootCertProvider(root certprovider.Provider) { - chi.mu.Lock() - chi.rootProvider = root - chi.mu.Unlock() +func (hi *HandshakeInfo) SetRootCertProvider(root certprovider.Provider) { + hi.mu.Lock() + hi.rootProvider = root + hi.mu.Unlock() } // SetIdentityCertProvider updates the identity certificate provider. -func (chi *HandshakeInfo) SetIdentityCertProvider(identity certprovider.Provider) { - chi.mu.Lock() - chi.identityProvider = identity - chi.mu.Unlock() +func (hi *HandshakeInfo) SetIdentityCertProvider(identity certprovider.Provider) { + hi.mu.Lock() + hi.identityProvider = identity + hi.mu.Unlock() } // SetAcceptedSANs updates the list of accepted SANs. -func (chi *HandshakeInfo) SetAcceptedSANs(sans []string) { - chi.mu.Lock() - chi.acceptedSANs = make(map[string]bool, len(sans)) +func (hi *HandshakeInfo) SetAcceptedSANs(sans []string) { + hi.mu.Lock() + hi.acceptedSANs = make(map[string]bool, len(sans)) for _, san := range sans { - chi.acceptedSANs[san] = true + hi.acceptedSANs[san] = true } - chi.mu.Unlock() + hi.mu.Unlock() } -func (chi *HandshakeInfo) validate(isClient bool) error { - chi.mu.Lock() - defer chi.mu.Unlock() +func (hi *HandshakeInfo) validate(isClient bool) error { + hi.mu.Lock() + defer hi.mu.Unlock() // On the client side, rootProvider is mandatory. IdentityProvider is // optional based on whether the client is doing TLS or mTLS. - if isClient && chi.rootProvider == nil { + if isClient && hi.rootProvider == nil { return errors.New("xds: CertificateProvider to fetch trusted roots is missing, cannot perform TLS handshake") } // On the server side, identityProvider is mandatory. RootProvider is // optional based on whether the server is doing TLS or mTLS. - if !isClient && chi.identityProvider == nil { + if !isClient && hi.identityProvider == nil { return errors.New("xds: CertificateProvider to fetch identity certificate is missing, cannot perform TLS handshake") } return nil } -func (chi *HandshakeInfo) makeTLSConfig(ctx context.Context) (*tls.Config, error) { - chi.mu.Lock() +func (hi *HandshakeInfo) makeTLSConfig(ctx context.Context) (*tls.Config, error) { + hi.mu.Lock() // Since the call to KeyMaterial() can block, we read the providers under // the lock but call the actual function after releasing the lock. - rootProv, idProv := chi.rootProvider, chi.identityProvider - chi.mu.Unlock() + rootProv, idProv := hi.rootProvider, hi.identityProvider + hi.mu.Unlock() // InsecureSkipVerify needs to be set to true because we need to perform // custom verification to check the SAN on the received certificate. @@ -156,7 +170,7 @@ func (chi *HandshakeInfo) makeTLSConfig(ctx context.Context) (*tls.Config, error return cfg, nil } -func (chi *HandshakeInfo) matchingSANExists(cert *x509.Certificate) bool { +func (hi *HandshakeInfo) matchingSANExists(cert *x509.Certificate) bool { var sans []string // SANs can be specified in any of these four fields on the parsed cert. sans = append(sans, cert.DNSNames...) @@ -168,10 +182,10 @@ func (chi *HandshakeInfo) matchingSANExists(cert *x509.Certificate) bool { sans = append(sans, uri.String()) } - chi.mu.Lock() - defer chi.mu.Unlock() + hi.mu.Lock() + defer hi.mu.Unlock() for _, san := range sans { - if chi.acceptedSANs[san] { + if hi.acceptedSANs[san] { return true } } @@ -192,21 +206,6 @@ func NewHandshakeInfo(root, identity certprovider.Provider, sans ...string) *Han } } -// NewContextWithHandshakeInfo returns a copy of the parent context with the -// provided HandshakeInfo stored as a value. -func NewContextWithHandshakeInfo(parent context.Context, info *HandshakeInfo) context.Context { - return context.WithValue(parent, handshakeCtxKey{}, info) -} - -// handshakeInfoFromCtx returns a pointer to the HandshakeInfo stored in ctx. -func handshakeInfoFromCtx(ctx context.Context) *HandshakeInfo { - val, ok := ctx.Value(handshakeCtxKey{}).(*HandshakeInfo) - if !ok { - return nil - } - return val -} - // ClientHandshake performs the TLS handshake on the client-side. // // It looks for the presence of a HandshakeInfo value in the passed in context @@ -220,15 +219,29 @@ func (c *credsImpl) ClientHandshake(ctx context.Context, authority string, rawCo return nil, nil, errors.New("ClientHandshake() is not supported for server credentials") } - chi := handshakeInfoFromCtx(ctx) - if chi == nil { - // A missing handshake info in the provided context could mean either - // the user did not specify an `xds` scheme in their dial target or that - // the xDS server did not provide any security configuration. In both of - // these cases, we use the fallback credentials specified by the user. + // The CDS balancer constructs a new HandshakeInfo using a call to + // NewHandshakeInfo(), and then adds it to the attributes field of the + // resolver.Address when handling calls to NewSubConn(). The transport layer + // takes care of shipping these attributes in the context to this handshake + // function. We first read the credentials.ClientHandshakeInfo type from the + // context, which contains the attributes added by the CDS balancer. We then + // read the HandshakeInfo from the attributes to get to the actual data that + // we need here for the handshake. + chi := credentials.ClientHandshakeInfoFromContext(ctx) + // If there are no attributes in the received context or the attributes does + // not contain a HandshakeInfo, it could either mean that the user did not + // specify an `xds` scheme in their dial target or that the xDS server did + // not provide any security configuration. In both of these cases, we use + // the fallback credentials specified by the user. + if chi.Attributes == nil { return c.fallback.ClientHandshake(ctx, authority, rawConn) } - if err := chi.validate(c.isClient); err != nil { + hi := GetHandshakeInfo(chi.Attributes) + if hi == nil { + return c.fallback.ClientHandshake(ctx, authority, rawConn) + } + + if err := hi.validate(c.isClient); err != nil { return nil, nil, err } @@ -244,7 +257,7 @@ func (c *credsImpl) ClientHandshake(ctx context.Context, authority string, rawCo // 4. Key usage to match whether client/server usage. // 5. A `VerifyPeerCertificate` function which performs normal peer // cert verification using configured roots, and the custom SAN checks. - cfg, err := chi.makeTLSConfig(ctx) + cfg, err := hi.makeTLSConfig(ctx) if err != nil { return nil, nil, err } @@ -275,7 +288,7 @@ func (c *credsImpl) ClientHandshake(ctx context.Context, authority string, rawCo } // The SANs sent by the MeshCA are encoded as SPIFFE IDs. We need to // only look at the SANs on the leaf cert. - if !chi.matchingSANExists(certs[0]) { + if !hi.matchingSANExists(certs[0]) { return fmt.Errorf("SANs received in leaf certificate %+v does not match any of the accepted SANs", certs[0]) } return nil diff --git a/credentials/xds/xds_test.go b/credentials/xds/xds_test.go index b08d18a9ab38..68e0ed22caec 100644 --- a/credentials/xds/xds_test.go +++ b/credentials/xds/xds_test.go @@ -32,6 +32,7 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/tls/certprovider" + "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/testdata" @@ -219,11 +220,20 @@ func makeRootProvider(t *testing.T, caPath string) *fakeProvider { return &fakeProvider{km: &certprovider.KeyMaterial{Roots: roots}} } -// newTestContextWithHandshakeInfo returns a copy of the passed in context with -// HandshakeInfo context value added to it. -func newTestContextWithHandshakeInfo(ctx context.Context, root, identity certprovider.Provider, sans ...string) context.Context { +// newTestContextWithHandshakeInfo returns a copy of parent with HandshakeInfo +// context value added to it. +func newTestContextWithHandshakeInfo(parent context.Context, root, identity certprovider.Provider, sans ...string) context.Context { + // Creating the HandshakeInfo and adding it to the attributes is very + // similar to what the CDS balancer would do when it intercepts calls to + // NewSubConn(). info := NewHandshakeInfo(root, identity, sans...) - return NewContextWithHandshakeInfo(ctx, info) + attr := SetHandshakeInfo(nil, info) + + // Moving the attributes from the resolver.Address to the context passed to + // the handshaker is done in the transport layer. Since we directly call the + // handshaker in these tests, we need to do the same here. + contextWithHandshakeInfo := internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context) + return contextWithHandshakeInfo(parent, credentials.ClientHandshakeInfo{Attributes: attr}) } // compareAuthInfo compares the AuthInfo received on the client side after a @@ -485,9 +495,17 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() + // Create a root provider which will fail the handshake because it does not + // use the correct trust roots. root1 := makeRootProvider(t, "x509/client_ca_cert.pem") handshakeInfo := NewHandshakeInfo(root1, nil, defaultTestCertSAN) - ctx = NewContextWithHandshakeInfo(ctx, handshakeInfo) + + // We need to repeat most of what newTestContextWithHandshakeInfo() does + // here because we need access to the underlying HandshakeInfo so that we + // can update it before the next call to ClientHandshake(). + attr := SetHandshakeInfo(nil, handshakeInfo) + contextWithHandshakeInfo := internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context) + ctx = contextWithHandshakeInfo(ctx, credentials.ClientHandshakeInfo{Attributes: attr}) if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil { t.Fatal("ClientHandshake() succeeded when expected to fail") } @@ -504,6 +522,8 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) { } defer conn.Close() + // Create a new root provider which uses the correct trust roots. And update + // the HandshakeInfo with the new provider. root2 := makeRootProvider(t, "x509/server_ca_cert.pem") handshakeInfo.SetRootCertProvider(root2) _, ai, err := creds.ClientHandshake(ctx, authority, conn) From f1dcc3a835919c9684ed6ad8bcc0e45200031fde Mon Sep 17 00:00:00 2001 From: Easwar Swaminathan Date: Thu, 24 Sep 2020 16:40:27 -0700 Subject: [PATCH 6/8] Minor nits. --- credentials/xds/xds.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/credentials/xds/xds.go b/credentials/xds/xds.go index f2cf8e5095cf..55f93a8299fa 100644 --- a/credentials/xds/xds.go +++ b/credentials/xds/xds.go @@ -73,7 +73,7 @@ type credsImpl struct { // the Attributes field of resolver.Address. type handshakeAttrKey struct{} -// SetHandshakeInfo returns a copy of attr in which is updated with hInfo. +// SetHandshakeInfo returns a copy of attr which is updated with hInfo. func SetHandshakeInfo(attr *attributes.Attributes, hInfo *HandshakeInfo) *attributes.Attributes { return attr.WithValues(handshakeAttrKey{}, hInfo) } @@ -195,7 +195,7 @@ func (hi *HandshakeInfo) matchingSANExists(cert *x509.Certificate) bool { // NewHandshakeInfo returns a new instance of HandshakeInfo with the given root // and identity certificate providers. func NewHandshakeInfo(root, identity certprovider.Provider, sans ...string) *HandshakeInfo { - acceptedSANs := make(map[string]bool) + acceptedSANs := make(map[string]bool, len(sans)) for _, san := range sans { acceptedSANs[san] = true } From 7b7d29507886696a2c42017ee4fe0b758e486a0a Mon Sep 17 00:00:00 2001 From: Easwar Swaminathan Date: Tue, 29 Sep 2020 09:37:17 -0700 Subject: [PATCH 7/8] Make SetHandshakeInfo work with resolver.Address --- credentials/xds/xds.go | 13 ++++++++----- credentials/xds/xds_test.go | 9 +++++---- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/credentials/xds/xds.go b/credentials/xds/xds.go index 55f93a8299fa..67666340a6ab 100644 --- a/credentials/xds/xds.go +++ b/credentials/xds/xds.go @@ -38,6 +38,7 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/tls/certprovider" credinternal "google.golang.org/grpc/internal/credentials" + "google.golang.org/grpc/resolver" ) // ClientOptions contains parameters to configure a new client-side xDS @@ -73,9 +74,11 @@ type credsImpl struct { // the Attributes field of resolver.Address. type handshakeAttrKey struct{} -// SetHandshakeInfo returns a copy of attr which is updated with hInfo. -func SetHandshakeInfo(attr *attributes.Attributes, hInfo *HandshakeInfo) *attributes.Attributes { - return attr.WithValues(handshakeAttrKey{}, hInfo) +// SetHandshakeInfo returns a copy of addr in which the Attributes field is +// updated with hInfo. +func SetHandshakeInfo(addr resolver.Address, hInfo *HandshakeInfo) resolver.Address { + addr.Attributes = addr.Attributes.WithValues(handshakeAttrKey{}, hInfo) + return addr } // GetHandshakeInfo returns a pointer to the HandshakeInfo stored in attr. @@ -128,13 +131,13 @@ func (hi *HandshakeInfo) validate(isClient bool) error { // On the client side, rootProvider is mandatory. IdentityProvider is // optional based on whether the client is doing TLS or mTLS. if isClient && hi.rootProvider == nil { - return errors.New("xds: CertificateProvider to fetch trusted roots is missing, cannot perform TLS handshake") + return errors.New("xds: CertificateProvider to fetch trusted roots is missing, cannot perform TLS handshake. Please check configuration on the management server") } // On the server side, identityProvider is mandatory. RootProvider is // optional based on whether the server is doing TLS or mTLS. if !isClient && hi.identityProvider == nil { - return errors.New("xds: CertificateProvider to fetch identity certificate is missing, cannot perform TLS handshake") + return errors.New("xds: CertificateProvider to fetch identity certificate is missing, cannot perform TLS handshake. Please check configuration on the management server") } return nil diff --git a/credentials/xds/xds_test.go b/credentials/xds/xds_test.go index 68e0ed22caec..a2adcf4558ea 100644 --- a/credentials/xds/xds_test.go +++ b/credentials/xds/xds_test.go @@ -35,6 +35,7 @@ import ( "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/resolver" "google.golang.org/grpc/testdata" ) @@ -227,13 +228,13 @@ func newTestContextWithHandshakeInfo(parent context.Context, root, identity cert // similar to what the CDS balancer would do when it intercepts calls to // NewSubConn(). info := NewHandshakeInfo(root, identity, sans...) - attr := SetHandshakeInfo(nil, info) + addr := SetHandshakeInfo(resolver.Address{}, info) // Moving the attributes from the resolver.Address to the context passed to // the handshaker is done in the transport layer. Since we directly call the // handshaker in these tests, we need to do the same here. contextWithHandshakeInfo := internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context) - return contextWithHandshakeInfo(parent, credentials.ClientHandshakeInfo{Attributes: attr}) + return contextWithHandshakeInfo(parent, credentials.ClientHandshakeInfo{Attributes: addr.Attributes}) } // compareAuthInfo compares the AuthInfo received on the client side after a @@ -503,9 +504,9 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) { // We need to repeat most of what newTestContextWithHandshakeInfo() does // here because we need access to the underlying HandshakeInfo so that we // can update it before the next call to ClientHandshake(). - attr := SetHandshakeInfo(nil, handshakeInfo) + addr := SetHandshakeInfo(resolver.Address{}, handshakeInfo) contextWithHandshakeInfo := internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context) - ctx = contextWithHandshakeInfo(ctx, credentials.ClientHandshakeInfo{Attributes: attr}) + ctx = contextWithHandshakeInfo(ctx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes}) if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil { t.Fatal("ClientHandshake() succeeded when expected to fail") } From 8d6f6d69dfdb90700d511b899f8d3cef6eed411a Mon Sep 17 00:00:00 2001 From: Easwar Swaminathan Date: Tue, 29 Sep 2020 10:57:17 -0700 Subject: [PATCH 8/8] Unexport GetHandshakeInfo. --- credentials/xds/xds.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/credentials/xds/xds.go b/credentials/xds/xds.go index 67666340a6ab..cecc27d14559 100644 --- a/credentials/xds/xds.go +++ b/credentials/xds/xds.go @@ -81,8 +81,8 @@ func SetHandshakeInfo(addr resolver.Address, hInfo *HandshakeInfo) resolver.Addr return addr } -// GetHandshakeInfo returns a pointer to the HandshakeInfo stored in attr. -func GetHandshakeInfo(attr *attributes.Attributes) *HandshakeInfo { +// getHandshakeInfo returns a pointer to the HandshakeInfo stored in attr. +func getHandshakeInfo(attr *attributes.Attributes) *HandshakeInfo { v := attr.Value(handshakeAttrKey{}) hi, _ := v.(*HandshakeInfo) return hi @@ -239,7 +239,7 @@ func (c *credsImpl) ClientHandshake(ctx context.Context, authority string, rawCo if chi.Attributes == nil { return c.fallback.ClientHandshake(ctx, authority, rawConn) } - hi := GetHandshakeInfo(chi.Attributes) + hi := getHandshakeInfo(chi.Attributes) if hi == nil { return c.fallback.ClientHandshake(ctx, authority, rawConn) }