From 59c0aec9dc720d1788b4d9258658c8f02762b40e Mon Sep 17 00:00:00 2001 From: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Wed, 15 Nov 2023 13:54:29 -0500 Subject: [PATCH] xDS: Atomically read and write xDS security configuration client side (#6796) --- credentials/xds/xds.go | 5 +- credentials/xds/xds_client_test.go | 22 +++--- credentials/xds/xds_server_test.go | 19 ++--- internal/credentials/xds/handshake_info.go | 70 +++++-------------- .../credentials/xds/handshake_info_test.go | 6 +- internal/internal.go | 2 +- .../balancer/cdsbalancer/cdsbalancer.go | 47 +++++++------ .../cdsbalancer/cdsbalancer_security_test.go | 10 +-- xds/internal/server/conn_wrapper.go | 6 +- 9 files changed, 76 insertions(+), 111 deletions(-) diff --git a/credentials/xds/xds.go b/credentials/xds/xds.go index d232e678674..2b5a5e58ec3 100644 --- a/credentials/xds/xds.go +++ b/credentials/xds/xds.go @@ -27,6 +27,7 @@ import ( "errors" "fmt" "net" + "sync/atomic" "time" "google.golang.org/grpc/credentials" @@ -114,7 +115,9 @@ func (c *credsImpl) ClientHandshake(ctx context.Context, authority string, rawCo if chi.Attributes == nil { return c.fallback.ClientHandshake(ctx, authority, rawConn) } - hi := xdsinternal.GetHandshakeInfo(chi.Attributes) + + uPtr := xdsinternal.GetHandshakeInfo(chi.Attributes) + hi := (*xdsinternal.HandshakeInfo)(atomic.LoadPointer(uPtr)) if hi.UseFallbackCreds() { return c.fallback.ClientHandshake(ctx, authority, rawConn) } diff --git a/credentials/xds/xds_client_test.go b/credentials/xds/xds_client_test.go index 2fd2e21cdd7..cce40fc46f3 100644 --- a/credentials/xds/xds_client_test.go +++ b/credentials/xds/xds_client_test.go @@ -27,8 +27,10 @@ import ( "net" "os" "strings" + "sync/atomic" "testing" "time" + "unsafe" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/tls/certprovider" @@ -219,11 +221,13 @@ func newTestContextWithHandshakeInfo(parent context.Context, root, identity cert // Creating the HandshakeInfo and adding it to the attributes is very // similar to what the CDS balancer would do when it intercepts calls to // NewSubConn(). - info := xdsinternal.NewHandshakeInfo(root, identity) + var sms []matcher.StringMatcher if sanExactMatch != "" { - info.SetSANMatchers([]matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(sanExactMatch), nil, nil, nil, nil, false)}) + sms = []matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(sanExactMatch), nil, nil, nil, nil, false)} } - addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, info) + info := xdsinternal.NewHandshakeInfo(root, identity, sms, false) + uPtr := unsafe.Pointer(info) + addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, &uPtr) // Moving the attributes from the resolver.Address to the context passed to // the handshaker is done in the transport layer. Since we directly call the @@ -533,13 +537,12 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) { // Create a root provider which will fail the handshake because it does not // use the correct trust roots. root1 := makeRootProvider(t, "x509/client_ca_cert.pem") - handshakeInfo := xdsinternal.NewHandshakeInfo(root1, nil) - handshakeInfo.SetSANMatchers([]matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)}) - + handshakeInfo := xdsinternal.NewHandshakeInfo(root1, nil, []matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)}, false) // We need to repeat most of what newTestContextWithHandshakeInfo() does // here because we need access to the underlying HandshakeInfo so that we // can update it before the next call to ClientHandshake(). - addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, handshakeInfo) + uPtr := unsafe.Pointer(handshakeInfo) + addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, &uPtr) ctx = icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes}) if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil { t.Fatal("ClientHandshake() succeeded when expected to fail") @@ -560,7 +563,10 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) { // Create a new root provider which uses the correct trust roots. And update // the HandshakeInfo with the new provider. root2 := makeRootProvider(t, "x509/server_ca_cert.pem") - handshakeInfo.SetRootCertProvider(root2) + handshakeInfo = xdsinternal.NewHandshakeInfo(root2, nil, []matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)}, false) + // Update the existing pointer, which address attribute will continue to + // point to. + atomic.StorePointer(&uPtr, unsafe.Pointer(handshakeInfo)) _, ai, err := creds.ClientHandshake(ctx, authority, conn) if err != nil { t.Fatalf("ClientHandshake() returned failed: %q", err) diff --git a/credentials/xds/xds_server_test.go b/credentials/xds/xds_server_test.go index bc32a04e69a..dd3d83aab89 100644 --- a/credentials/xds/xds_server_test.go +++ b/credentials/xds/xds_server_test.go @@ -122,7 +122,7 @@ func (s) TestServerCredsInvalidHandshakeInfo(t *testing.T) { t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) } - info := xdsinternal.NewHandshakeInfo(&fakeProvider{}, nil) + info := xdsinternal.NewHandshakeInfo(&fakeProvider{}, nil, nil, false) conn := newWrappedConn(nil, info, time.Time{}) if _, _, err := creds.ServerHandshake(conn); err == nil { t.Fatal("ServerHandshake succeeded without identity certificate provider in HandshakeInfo") @@ -158,7 +158,7 @@ func (s) TestServerCredsProviderFailure(t *testing.T) { } for _, test := range tests { t.Run(test.desc, func(t *testing.T) { - info := xdsinternal.NewHandshakeInfo(test.rootProvider, test.identityProvider) + info := xdsinternal.NewHandshakeInfo(test.rootProvider, test.identityProvider, nil, false) conn := newWrappedConn(nil, info, time.Time{}) if _, _, err := creds.ServerHandshake(conn); err == nil || !strings.Contains(err.Error(), test.wantErr) { t.Fatalf("ServerHandshake() returned error: %q, wantErr: %q", err, test.wantErr) @@ -232,8 +232,7 @@ func (s) TestServerCredsHandshakeTimeout(t *testing.T) { // Create a test server which uses the xDS server credentials created above // to perform TLS handshake on incoming connections. ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult { - hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server2_cert.pem", "x509/server2_key.pem")) - hi.SetRequireClientCert(true) + hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server2_cert.pem", "x509/server2_key.pem"), nil, true) // Create a wrapped conn which can return the HandshakeInfo created // above with a very small deadline. @@ -285,8 +284,7 @@ func (s) TestServerCredsHandshakeFailure(t *testing.T) { ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult { // Create a HandshakeInfo which has a root provider which does not match // the certificate sent by the client. - hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem")) - hi.SetRequireClientCert(true) + hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem"), nil, true) // Create a wrapped conn which can return the HandshakeInfo and // configured deadline to the xDS credentials' ServerHandshake() @@ -367,8 +365,7 @@ func (s) TestServerCredsHandshakeSuccess(t *testing.T) { // created above to perform TLS handshake on incoming connections. ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult { // Create a HandshakeInfo with information from the test table. - hi := xdsinternal.NewHandshakeInfo(test.rootProvider, test.identityProvider) - hi.SetRequireClientCert(test.requireClientCert) + hi := xdsinternal.NewHandshakeInfo(test.rootProvider, test.identityProvider, nil, test.requireClientCert) // Create a wrapped conn which can return the HandshakeInfo and // configured deadline to the xDS credentials' ServerHandshake() @@ -448,8 +445,7 @@ func (s) TestServerCredsProviderSwitch(t *testing.T) { if cnt == 1 { // Create a HandshakeInfo which has a root provider which does not match // the certificate sent by the client. - hi = xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem")) - hi.SetRequireClientCert(true) + hi = xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem"), nil, true) // Create a wrapped conn which can return the HandshakeInfo and // configured deadline to the xDS credentials' ServerHandshake() @@ -463,8 +459,7 @@ func (s) TestServerCredsProviderSwitch(t *testing.T) { return handshakeResult{} } - hi = xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem")) - hi.SetRequireClientCert(true) + hi = xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), nil, true) // Create a wrapped conn which can return the HandshakeInfo and // configured deadline to the xDS credentials' ServerHandshake() diff --git a/internal/credentials/xds/handshake_info.go b/internal/credentials/xds/handshake_info.go index b6f1fa520fc..c657baeb321 100644 --- a/internal/credentials/xds/handshake_info.go +++ b/internal/credentials/xds/handshake_info.go @@ -26,7 +26,7 @@ import ( "errors" "fmt" "strings" - "sync" + "unsafe" "google.golang.org/grpc/attributes" "google.golang.org/grpc/credentials/tls/certprovider" @@ -66,59 +66,38 @@ func (hi *HandshakeInfo) Equal(other *HandshakeInfo) bool { } // SetHandshakeInfo returns a copy of addr in which the Attributes field is -// updated with hInfo. -func SetHandshakeInfo(addr resolver.Address, hInfo *HandshakeInfo) resolver.Address { - addr.Attributes = addr.Attributes.WithValue(handshakeAttrKey{}, hInfo) +// updated with hiPtr. +func SetHandshakeInfo(addr resolver.Address, hiPtr *unsafe.Pointer) resolver.Address { + addr.Attributes = addr.Attributes.WithValue(handshakeAttrKey{}, hiPtr) return addr } -// GetHandshakeInfo returns a pointer to the HandshakeInfo stored in attr. -func GetHandshakeInfo(attr *attributes.Attributes) *HandshakeInfo { +// GetHandshakeInfo returns a pointer to the *HandshakeInfo stored in attr. +func GetHandshakeInfo(attr *attributes.Attributes) *unsafe.Pointer { v := attr.Value(handshakeAttrKey{}) - hi, _ := v.(*HandshakeInfo) + hi, _ := v.(*unsafe.Pointer) return hi } // HandshakeInfo wraps all the security configuration required by client and // server handshake methods in xds credentials. The xDS implementation will be // responsible for populating these fields. -// -// Safe for concurrent access. type HandshakeInfo struct { - mu sync.Mutex + // All fields written at init time and read only after that, so no + // synchronization needed. rootProvider certprovider.Provider identityProvider certprovider.Provider sanMatchers []matcher.StringMatcher // Only on the client side. requireClientCert bool // Only on server side. } -// SetRootCertProvider updates the root certificate provider. -func (hi *HandshakeInfo) SetRootCertProvider(root certprovider.Provider) { - hi.mu.Lock() - hi.rootProvider = root - hi.mu.Unlock() -} - -// SetIdentityCertProvider updates the identity certificate provider. -func (hi *HandshakeInfo) SetIdentityCertProvider(identity certprovider.Provider) { - hi.mu.Lock() - hi.identityProvider = identity - hi.mu.Unlock() -} - -// SetSANMatchers updates the list of SAN matchers. -func (hi *HandshakeInfo) SetSANMatchers(sanMatchers []matcher.StringMatcher) { - hi.mu.Lock() - hi.sanMatchers = sanMatchers - hi.mu.Unlock() -} - -// SetRequireClientCert updates whether a client cert is required during the -// ServerHandshake(). A value of true indicates that we are performing mTLS. -func (hi *HandshakeInfo) SetRequireClientCert(require bool) { - hi.mu.Lock() - hi.requireClientCert = require - hi.mu.Unlock() +func NewHandshakeInfo(rootProvider certprovider.Provider, identityProvider certprovider.Provider, sanMatchers []matcher.StringMatcher, requireClientCert bool) *HandshakeInfo { + return &HandshakeInfo{ + rootProvider: rootProvider, + identityProvider: identityProvider, + sanMatchers: sanMatchers, + requireClientCert: requireClientCert, + } } // UseFallbackCreds returns true when fallback credentials are to be used based @@ -127,24 +106,18 @@ func (hi *HandshakeInfo) UseFallbackCreds() bool { if hi == nil { return true } - - hi.mu.Lock() - defer hi.mu.Unlock() return hi.identityProvider == nil && hi.rootProvider == nil } // GetSANMatchersForTesting returns the SAN matchers stored in HandshakeInfo. // To be used only for testing purposes. func (hi *HandshakeInfo) GetSANMatchersForTesting() []matcher.StringMatcher { - hi.mu.Lock() - defer hi.mu.Unlock() return append([]matcher.StringMatcher{}, hi.sanMatchers...) } // ClientSideTLSConfig constructs a tls.Config to be used in a client-side // handshake based on the contents of the HandshakeInfo. func (hi *HandshakeInfo) ClientSideTLSConfig(ctx context.Context) (*tls.Config, error) { - hi.mu.Lock() // On the client side, rootProvider is mandatory. IdentityProvider is // optional based on whether the client is doing TLS or mTLS. if hi.rootProvider == nil { @@ -153,7 +126,6 @@ func (hi *HandshakeInfo) ClientSideTLSConfig(ctx context.Context) (*tls.Config, // Since the call to KeyMaterial() can block, we read the providers under // the lock but call the actual function after releasing the lock. rootProv, idProv := hi.rootProvider, hi.identityProvider - hi.mu.Unlock() // InsecureSkipVerify needs to be set to true because we need to perform // custom verification to check the SAN on the received certificate. @@ -188,7 +160,6 @@ func (hi *HandshakeInfo) ServerSideTLSConfig(ctx context.Context) (*tls.Config, ClientAuth: tls.NoClientCert, NextProtos: []string{"h2"}, } - hi.mu.Lock() // On the server side, identityProvider is mandatory. RootProvider is // optional based on whether the server is doing TLS or mTLS. if hi.identityProvider == nil { @@ -200,7 +171,6 @@ func (hi *HandshakeInfo) ServerSideTLSConfig(ctx context.Context) (*tls.Config, if hi.requireClientCert { cfg.ClientAuth = tls.RequireAndVerifyClientCert } - hi.mu.Unlock() // identityProvider is mandatory on the server side. km, err := idProv.KeyMaterial(ctx) @@ -225,8 +195,6 @@ func (hi *HandshakeInfo) ServerSideTLSConfig(ctx context.Context) (*tls.Config, // If the list of SAN matchers in the HandshakeInfo is empty, this function // returns true for all input certificates. func (hi *HandshakeInfo) MatchingSANExists(cert *x509.Certificate) bool { - hi.mu.Lock() - defer hi.mu.Unlock() if len(hi.sanMatchers) == 0 { return true } @@ -325,9 +293,3 @@ func dnsMatch(host, san string) bool { hostPrefix := strings.TrimSuffix(host, san[1:]) return !strings.Contains(hostPrefix, ".") } - -// NewHandshakeInfo returns a new instance of HandshakeInfo with the given root -// and identity certificate providers. -func NewHandshakeInfo(root, identity certprovider.Provider) *HandshakeInfo { - return &HandshakeInfo{rootProvider: root, identityProvider: identity} -} diff --git a/internal/credentials/xds/handshake_info_test.go b/internal/credentials/xds/handshake_info_test.go index 91257a1925d..4a791d8df58 100644 --- a/internal/credentials/xds/handshake_info_test.go +++ b/internal/credentials/xds/handshake_info_test.go @@ -188,8 +188,7 @@ func TestMatchingSANExists_FailureCases(t *testing.T) { for _, test := range tests { t.Run(test.desc, func(t *testing.T) { - hi := NewHandshakeInfo(nil, nil) - hi.SetSANMatchers(test.sanMatchers) + hi := NewHandshakeInfo(nil, nil, test.sanMatchers, false) if hi.MatchingSANExists(inputCert) { t.Fatalf("hi.MatchingSANExists(%+v) with SAN matchers +%v succeeded when expected to fail", inputCert, test.sanMatchers) @@ -289,8 +288,7 @@ func TestMatchingSANExists_Success(t *testing.T) { for _, test := range tests { t.Run(test.desc, func(t *testing.T) { - hi := NewHandshakeInfo(nil, nil) - hi.SetSANMatchers(test.sanMatchers) + hi := NewHandshakeInfo(nil, nil, test.sanMatchers, false) if !hi.MatchingSANExists(inputCert) { t.Fatalf("hi.MatchingSANExists(%+v) with SAN matchers +%v failed when expected to succeed", inputCert, test.sanMatchers) diff --git a/internal/internal.go b/internal/internal.go index f28791b89b0..2eef978c8d3 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -57,7 +57,7 @@ var ( // GetXDSHandshakeInfoForTesting returns a pointer to the xds.HandshakeInfo // stored in the passed in attributes. This is set by // credentials/xds/xds.go. - GetXDSHandshakeInfoForTesting any // func (*attributes.Attributes) *xds.HandshakeInfo + GetXDSHandshakeInfoForTesting any // func (*attributes.Attributes) *unsafe.Pointer // GetServerCredentials returns the transport credentials configured on a // gRPC server. An xDS-enabled server needs to know what type of credentials // is configured on the underlying gRPC server. This is set by server.go. diff --git a/xds/internal/balancer/cdsbalancer/cdsbalancer.go b/xds/internal/balancer/cdsbalancer/cdsbalancer.go index 42c7665abcd..9eb27aa1d62 100644 --- a/xds/internal/balancer/cdsbalancer/cdsbalancer.go +++ b/xds/internal/balancer/cdsbalancer/cdsbalancer.go @@ -21,6 +21,8 @@ import ( "context" "encoding/json" "fmt" + "sync/atomic" + "unsafe" "google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer/base" @@ -89,19 +91,21 @@ func (bb) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Bal } ctx, cancel := context.WithCancel(context.Background()) + hi := xdsinternal.NewHandshakeInfo(nil, nil, nil, false) + xdsHIPtr := unsafe.Pointer(hi) b := &cdsBalancer{ bOpts: opts, childConfigParser: parser, serializer: grpcsync.NewCallbackSerializer(ctx), serializerCancel: cancel, - xdsHI: xdsinternal.NewHandshakeInfo(nil, nil), + xdsHIPtr: &xdsHIPtr, watchers: make(map[string]*watcherState), } b.ccw = &ccWrapper{ ClientConn: cc, - xdsHI: b.xdsHI, + xdsHIPtr: b.xdsHIPtr, } - b.logger = prefixLogger((b)) + b.logger = prefixLogger(b) b.logger.Infof("Created") var creds credentials.TransportCredentials @@ -149,11 +153,13 @@ type cdsBalancer struct { // The following fields are initialized at build time and are either // read-only after that or provide their own synchronization, and therefore // do not need to be guarded by a mutex. - ccw *ccWrapper // ClientConn interface passed to child LB. - bOpts balancer.BuildOptions // BuildOptions passed to child LB. - childConfigParser balancer.ConfigParser // Config parser for cluster_resolver LB policy. - xdsHI *xdsinternal.HandshakeInfo // Handshake info from security configuration. - logger *grpclog.PrefixLogger // Prefix logger for all logging. + ccw *ccWrapper // ClientConn interface passed to child LB. + bOpts balancer.BuildOptions // BuildOptions passed to child LB. + childConfigParser balancer.ConfigParser // Config parser for cluster_resolver LB policy. + logger *grpclog.PrefixLogger // Prefix logger for all logging. + xdsCredsInUse bool + + xdsHIPtr *unsafe.Pointer // Accessed atomically. // The serializer and its cancel func are initialized at build time, and the // rest of the fields here are only accessed from serializer callbacks (or @@ -170,7 +176,6 @@ type cdsBalancer struct { // a new provider is to be created. cachedRoot certprovider.Provider cachedIdentity certprovider.Provider - xdsCredsInUse bool } // handleSecurityConfig processes the security configuration received from the @@ -186,6 +191,7 @@ func (b *cdsBalancer) handleSecurityConfig(config *xdsresource.SecurityConfig) e if !b.xdsCredsInUse { return nil } + var xdsHI *xdsinternal.HandshakeInfo // Security config being nil is a valid case where the management server has // not sent any security configuration. The xdsCredentials implementation @@ -194,10 +200,10 @@ func (b *cdsBalancer) handleSecurityConfig(config *xdsresource.SecurityConfig) e // We need to explicitly set the fields to nil here since this might be // a case of switching from a good security configuration to an empty // one where fallback credentials are to be used. - b.xdsHI.SetRootCertProvider(nil) - b.xdsHI.SetIdentityCertProvider(nil) - b.xdsHI.SetSANMatchers(nil) + xdsHI = xdsinternal.NewHandshakeInfo(nil, nil, nil, false) + atomic.StorePointer(b.xdsHIPtr, unsafe.Pointer(xdsHI)) return nil + } bc := b.xdsClient.BootstrapConfig() @@ -234,12 +240,8 @@ func (b *cdsBalancer) handleSecurityConfig(config *xdsresource.SecurityConfig) e } b.cachedRoot = rootProvider b.cachedIdentity = identityProvider - - // We set all fields here, even if some of them are nil, since they - // could have been non-nil earlier. - b.xdsHI.SetRootCertProvider(rootProvider) - b.xdsHI.SetIdentityCertProvider(identityProvider) - b.xdsHI.SetSANMatchers(config.SubjectAltNameMatchers) + xdsHI = xdsinternal.NewHandshakeInfo(rootProvider, identityProvider, config.SubjectAltNameMatchers, false) + atomic.StorePointer(b.xdsHIPtr, unsafe.Pointer(xdsHI)) return nil } @@ -660,9 +662,7 @@ func (b *cdsBalancer) generateDMsForCluster(name string, depth int, dms []cluste type ccWrapper struct { balancer.ClientConn - // The certificate providers in this HandshakeInfo are updated based on the - // received security configuration in the Cluster resource. - xdsHI *xdsinternal.HandshakeInfo + xdsHIPtr *unsafe.Pointer } // NewSubConn intercepts NewSubConn() calls from the child policy and adds an @@ -671,8 +671,9 @@ type ccWrapper struct { func (ccw *ccWrapper) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) { newAddrs := make([]resolver.Address, len(addrs)) for i, addr := range addrs { - newAddrs[i] = xdsinternal.SetHandshakeInfo(addr, ccw.xdsHI) + newAddrs[i] = xdsinternal.SetHandshakeInfo(addr, ccw.xdsHIPtr) } + // No need to override opts.StateListener; just forward all calls to the // child that created the SubConn. return ccw.ClientConn.NewSubConn(newAddrs, opts) @@ -681,7 +682,7 @@ func (ccw *ccWrapper) NewSubConn(addrs []resolver.Address, opts balancer.NewSubC func (ccw *ccWrapper) UpdateAddresses(sc balancer.SubConn, addrs []resolver.Address) { newAddrs := make([]resolver.Address, len(addrs)) for i, addr := range addrs { - newAddrs[i] = xdsinternal.SetHandshakeInfo(addr, ccw.xdsHI) + newAddrs[i] = xdsinternal.SetHandshakeInfo(addr, ccw.xdsHIPtr) } ccw.ClientConn.UpdateAddresses(sc, newAddrs) } diff --git a/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go b/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go index 10953010592..8f3cd9e9f7c 100644 --- a/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go +++ b/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go @@ -25,6 +25,7 @@ import ( "os" "strings" "testing" + "unsafe" "github.com/google/go-cmp/cmp" "google.golang.org/grpc" @@ -75,14 +76,15 @@ func (tcc *testCCWrapper) NewSubConn(addrs []resolver.Address, opts balancer.New if len(addrs) != 1 { return nil, fmt.Errorf("NewSubConn got %d addresses, want 1", len(addrs)) } - getHI := internal.GetXDSHandshakeInfoForTesting.(func(attr *attributes.Attributes) *xdscredsinternal.HandshakeInfo) + getHI := internal.GetXDSHandshakeInfoForTesting.(func(attr *attributes.Attributes) *unsafe.Pointer) hi := getHI(addrs[0].Attributes) if hi == nil { return nil, fmt.Errorf("NewSubConn got address without xDS handshake info") } + sc, err := tcc.ClientConn.NewSubConn(addrs, opts) select { - case tcc.handshakeInfoCh <- hi: + case tcc.handshakeInfoCh <- (*xdscredsinternal.HandshakeInfo)(*hi): default: } return sc, err @@ -292,7 +294,7 @@ func (s) TestSecurityConfigWithoutXDSCreds(t *testing.T) { case <-ctx.Done(): t.Fatal("Timeout when waiting to read handshake info passed to NewSubConn") } - wantHI := xdscredsinternal.NewHandshakeInfo(nil, nil) + wantHI := xdscredsinternal.NewHandshakeInfo(nil, nil, nil, false) if !cmp.Equal(gotHI, wantHI) { t.Fatalf("NewSubConn got handshake info %+v, want %+v", gotHI, wantHI) } @@ -343,7 +345,7 @@ func (s) TestNoSecurityConfigWithXDSCreds(t *testing.T) { case <-ctx.Done(): t.Fatal("Timeout when waiting to read handshake info passed to NewSubConn") } - wantHI := xdscredsinternal.NewHandshakeInfo(nil, nil) + wantHI := xdscredsinternal.NewHandshakeInfo(nil, nil, nil, false) if !cmp.Equal(gotHI, wantHI) { t.Fatalf("NewSubConn got handshake info %+v, want %+v", gotHI, wantHI) } diff --git a/xds/internal/server/conn_wrapper.go b/xds/internal/server/conn_wrapper.go index ec6da32fad1..c622455455f 100644 --- a/xds/internal/server/conn_wrapper.go +++ b/xds/internal/server/conn_wrapper.go @@ -106,7 +106,7 @@ func (c *connWrapper) XDSHandshakeInfo() (*xdsinternal.HandshakeInfo, error) { // did not provide any security configuration and therefore we should // return an empty HandshakeInfo here so that the xdsCreds can use the // configured fallback credentials. - return xdsinternal.NewHandshakeInfo(nil, nil), nil + return xdsinternal.NewHandshakeInfo(nil, nil, nil, false), nil } cpc := c.parent.xdsC.BootstrapConfig().CertProviderConfigs @@ -128,9 +128,7 @@ func (c *connWrapper) XDSHandshakeInfo() (*xdsinternal.HandshakeInfo, error) { c.identityProvider = ip c.rootProvider = rp - xdsHI := xdsinternal.NewHandshakeInfo(c.rootProvider, c.identityProvider) - xdsHI.SetRequireClientCert(secCfg.RequireClientCert) - return xdsHI, nil + return xdsinternal.NewHandshakeInfo(c.rootProvider, c.identityProvider, nil, secCfg.RequireClientCert), nil } // Close closes the providers and the underlying connection.