Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

xDS: Atomically read and write xDS security configuration client side #6796

Merged
merged 2 commits into from Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 4 additions & 1 deletion credentials/xds/xds.go
Expand Up @@ -27,6 +27,7 @@ import (
"errors"
"fmt"
"net"
"sync/atomic"
"time"

"google.golang.org/grpc/credentials"
Expand Down Expand Up @@ -114,7 +115,9 @@ func (c *credsImpl) ClientHandshake(ctx context.Context, authority string, rawCo
if chi.Attributes == nil {
return c.fallback.ClientHandshake(ctx, authority, rawConn)
}
hi := xdsinternal.GetHandshakeInfo(chi.Attributes)

uPtr := xdsinternal.GetHandshakeInfo(chi.Attributes)
hi := (*xdsinternal.HandshakeInfo)(atomic.LoadPointer(uPtr))
if hi.UseFallbackCreds() {
return c.fallback.ClientHandshake(ctx, authority, rawConn)
}
Expand Down
21 changes: 13 additions & 8 deletions credentials/xds/xds_client_test.go
Expand Up @@ -29,6 +29,7 @@ import (
"strings"
"testing"
"time"
"unsafe"

"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/tls/certprovider"
Expand Down Expand Up @@ -219,11 +220,13 @@ func newTestContextWithHandshakeInfo(parent context.Context, root, identity cert
// Creating the HandshakeInfo and adding it to the attributes is very
// similar to what the CDS balancer would do when it intercepts calls to
// NewSubConn().
info := xdsinternal.NewHandshakeInfo(root, identity)
var sm []matcher.StringMatcher
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit/optional: sms to pluralize string matcher.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair. Switched.

if sanExactMatch != "" {
info.SetSANMatchers([]matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(sanExactMatch), nil, nil, nil, nil, false)})
sm = []matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(sanExactMatch), nil, nil, nil, nil, false)}
}
addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, info)
info := xdsinternal.NewHandshakeInfo(root, identity, sm, false)
uPtr := unsafe.Pointer(info)
addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, &uPtr)

// Moving the attributes from the resolver.Address to the context passed to
// the handshaker is done in the transport layer. Since we directly call the
Expand Down Expand Up @@ -533,13 +536,12 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) {
// Create a root provider which will fail the handshake because it does not
// use the correct trust roots.
root1 := makeRootProvider(t, "x509/client_ca_cert.pem")
handshakeInfo := xdsinternal.NewHandshakeInfo(root1, nil)
handshakeInfo.SetSANMatchers([]matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)})

handshakeInfo := xdsinternal.NewHandshakeInfo(root1, nil, []matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)}, false)
// We need to repeat most of what newTestContextWithHandshakeInfo() does
// here because we need access to the underlying HandshakeInfo so that we
// can update it before the next call to ClientHandshake().
addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, handshakeInfo)
uPtr := unsafe.Pointer(handshakeInfo)
addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, &uPtr)
ctx = icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil {
t.Fatal("ClientHandshake() succeeded when expected to fail")
Expand All @@ -560,7 +562,10 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) {
// Create a new root provider which uses the correct trust roots. And update
// the HandshakeInfo with the new provider.
root2 := makeRootProvider(t, "x509/server_ca_cert.pem")
handshakeInfo.SetRootCertProvider(root2)
handshakeInfo = xdsinternal.NewHandshakeInfo(root2, nil, []matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)}, false)
uPtr = unsafe.Pointer(handshakeInfo)
addr = xdsinternal.SetHandshakeInfo(resolver.Address{}, &uPtr)
ctx = icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just update the existing pointer instead of creating a new entry in the context? That's how the production code will work.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, good point. Changed.

_, ai, err := creds.ClientHandshake(ctx, authority, conn)
if err != nil {
t.Fatalf("ClientHandshake() returned failed: %q", err)
Expand Down
19 changes: 7 additions & 12 deletions credentials/xds/xds_server_test.go
Expand Up @@ -122,7 +122,7 @@ func (s) TestServerCredsInvalidHandshakeInfo(t *testing.T) {
t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err)
}

info := xdsinternal.NewHandshakeInfo(&fakeProvider{}, nil)
info := xdsinternal.NewHandshakeInfo(&fakeProvider{}, nil, nil, false)
conn := newWrappedConn(nil, info, time.Time{})
if _, _, err := creds.ServerHandshake(conn); err == nil {
t.Fatal("ServerHandshake succeeded without identity certificate provider in HandshakeInfo")
Expand Down Expand Up @@ -158,7 +158,7 @@ func (s) TestServerCredsProviderFailure(t *testing.T) {
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
info := xdsinternal.NewHandshakeInfo(test.rootProvider, test.identityProvider)
info := xdsinternal.NewHandshakeInfo(test.rootProvider, test.identityProvider, nil, false)
conn := newWrappedConn(nil, info, time.Time{})
if _, _, err := creds.ServerHandshake(conn); err == nil || !strings.Contains(err.Error(), test.wantErr) {
t.Fatalf("ServerHandshake() returned error: %q, wantErr: %q", err, test.wantErr)
Expand Down Expand Up @@ -232,8 +232,7 @@ func (s) TestServerCredsHandshakeTimeout(t *testing.T) {
// Create a test server which uses the xDS server credentials created above
// to perform TLS handshake on incoming connections.
ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server2_cert.pem", "x509/server2_key.pem"))
hi.SetRequireClientCert(true)
hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server2_cert.pem", "x509/server2_key.pem"), nil, true)

// Create a wrapped conn which can return the HandshakeInfo created
// above with a very small deadline.
Expand Down Expand Up @@ -285,8 +284,7 @@ func (s) TestServerCredsHandshakeFailure(t *testing.T) {
ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
// Create a HandshakeInfo which has a root provider which does not match
// the certificate sent by the client.
hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem"))
hi.SetRequireClientCert(true)
hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem"), nil, true)

// Create a wrapped conn which can return the HandshakeInfo and
// configured deadline to the xDS credentials' ServerHandshake()
Expand Down Expand Up @@ -367,8 +365,7 @@ func (s) TestServerCredsHandshakeSuccess(t *testing.T) {
// created above to perform TLS handshake on incoming connections.
ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
// Create a HandshakeInfo with information from the test table.
hi := xdsinternal.NewHandshakeInfo(test.rootProvider, test.identityProvider)
hi.SetRequireClientCert(test.requireClientCert)
hi := xdsinternal.NewHandshakeInfo(test.rootProvider, test.identityProvider, nil, test.requireClientCert)

// Create a wrapped conn which can return the HandshakeInfo and
// configured deadline to the xDS credentials' ServerHandshake()
Expand Down Expand Up @@ -448,8 +445,7 @@ func (s) TestServerCredsProviderSwitch(t *testing.T) {
if cnt == 1 {
// Create a HandshakeInfo which has a root provider which does not match
// the certificate sent by the client.
hi = xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem"))
hi.SetRequireClientCert(true)
hi = xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem"), nil, true)

// Create a wrapped conn which can return the HandshakeInfo and
// configured deadline to the xDS credentials' ServerHandshake()
Expand All @@ -463,8 +459,7 @@ func (s) TestServerCredsProviderSwitch(t *testing.T) {
return handshakeResult{}
}

hi = xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"))
hi.SetRequireClientCert(true)
hi = xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), nil, true)

// Create a wrapped conn which can return the HandshakeInfo and
// configured deadline to the xDS credentials' ServerHandshake()
Expand Down
66 changes: 15 additions & 51 deletions internal/credentials/xds/handshake_info.go
Expand Up @@ -26,7 +26,7 @@ import (
"errors"
"fmt"
"strings"
"sync"
"unsafe"

"google.golang.org/grpc/attributes"
"google.golang.org/grpc/credentials/tls/certprovider"
Expand Down Expand Up @@ -66,16 +66,16 @@ func (hi *HandshakeInfo) Equal(other *HandshakeInfo) bool {
}

// SetHandshakeInfo returns a copy of addr in which the Attributes field is
// updated with hInfo.
func SetHandshakeInfo(addr resolver.Address, hInfo *HandshakeInfo) resolver.Address {
addr.Attributes = addr.Attributes.WithValue(handshakeAttrKey{}, hInfo)
// updated with hiPtr.
func SetHandshakeInfo(addr resolver.Address, hiPtr *unsafe.Pointer) resolver.Address {
addr.Attributes = addr.Attributes.WithValue(handshakeAttrKey{}, hiPtr)
return addr
}

// GetHandshakeInfo returns a pointer to the HandshakeInfo stored in attr.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe "pointer to the *HandshakeInfo"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah fair. Switched.

func GetHandshakeInfo(attr *attributes.Attributes) *HandshakeInfo {
func GetHandshakeInfo(attr *attributes.Attributes) *unsafe.Pointer {
v := attr.Value(handshakeAttrKey{})
hi, _ := v.(*HandshakeInfo)
hi, _ := v.(*unsafe.Pointer)
return hi
}

Expand All @@ -85,40 +85,21 @@ func GetHandshakeInfo(attr *attributes.Attributes) *HandshakeInfo {
//
// Safe for concurrent access.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Delete this now? Or say it's immutable?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like technically it's fine to call any methods on it concurrently since it's all read only and set at init time, but since it's not really important to call out/doesn't apply I'll delete it.

type HandshakeInfo struct {
mu sync.Mutex
// All fields written at init time and read only after that, so no
// synchronization needed.
rootProvider certprovider.Provider
identityProvider certprovider.Provider
sanMatchers []matcher.StringMatcher // Only on the client side.
requireClientCert bool // Only on server side.
}

// SetRootCertProvider updates the root certificate provider.
func (hi *HandshakeInfo) SetRootCertProvider(root certprovider.Provider) {
hi.mu.Lock()
hi.rootProvider = root
hi.mu.Unlock()
}

// SetIdentityCertProvider updates the identity certificate provider.
func (hi *HandshakeInfo) SetIdentityCertProvider(identity certprovider.Provider) {
hi.mu.Lock()
hi.identityProvider = identity
hi.mu.Unlock()
}

// SetSANMatchers updates the list of SAN matchers.
func (hi *HandshakeInfo) SetSANMatchers(sanMatchers []matcher.StringMatcher) {
hi.mu.Lock()
hi.sanMatchers = sanMatchers
hi.mu.Unlock()
}

// SetRequireClientCert updates whether a client cert is required during the
// ServerHandshake(). A value of true indicates that we are performing mTLS.
func (hi *HandshakeInfo) SetRequireClientCert(require bool) {
hi.mu.Lock()
hi.requireClientCert = require
hi.mu.Unlock()
func NewHandshakeInfo(rootProvider certprovider.Provider, identityProvider certprovider.Provider, sanMatchers []matcher.StringMatcher, requireClientCert bool) *HandshakeInfo {
return &HandshakeInfo{
rootProvider: rootProvider,
identityProvider: identityProvider,
sanMatchers: sanMatchers,
requireClientCert: requireClientCert,
}
}

// UseFallbackCreds returns true when fallback credentials are to be used based
Expand All @@ -127,24 +108,18 @@ func (hi *HandshakeInfo) UseFallbackCreds() bool {
if hi == nil {
return true
}

hi.mu.Lock()
defer hi.mu.Unlock()
return hi.identityProvider == nil && hi.rootProvider == nil
}

// GetSANMatchersForTesting returns the SAN matchers stored in HandshakeInfo.
// To be used only for testing purposes.
func (hi *HandshakeInfo) GetSANMatchersForTesting() []matcher.StringMatcher {
hi.mu.Lock()
defer hi.mu.Unlock()
return append([]matcher.StringMatcher{}, hi.sanMatchers...)
}

// ClientSideTLSConfig constructs a tls.Config to be used in a client-side
// handshake based on the contents of the HandshakeInfo.
func (hi *HandshakeInfo) ClientSideTLSConfig(ctx context.Context) (*tls.Config, error) {
hi.mu.Lock()
// On the client side, rootProvider is mandatory. IdentityProvider is
// optional based on whether the client is doing TLS or mTLS.
if hi.rootProvider == nil {
Expand All @@ -153,7 +128,6 @@ func (hi *HandshakeInfo) ClientSideTLSConfig(ctx context.Context) (*tls.Config,
// Since the call to KeyMaterial() can block, we read the providers under
// the lock but call the actual function after releasing the lock.
rootProv, idProv := hi.rootProvider, hi.identityProvider
hi.mu.Unlock()

// InsecureSkipVerify needs to be set to true because we need to perform
// custom verification to check the SAN on the received certificate.
Expand Down Expand Up @@ -188,7 +162,6 @@ func (hi *HandshakeInfo) ServerSideTLSConfig(ctx context.Context) (*tls.Config,
ClientAuth: tls.NoClientCert,
NextProtos: []string{"h2"},
}
hi.mu.Lock()
// On the server side, identityProvider is mandatory. RootProvider is
// optional based on whether the server is doing TLS or mTLS.
if hi.identityProvider == nil {
Expand All @@ -200,7 +173,6 @@ func (hi *HandshakeInfo) ServerSideTLSConfig(ctx context.Context) (*tls.Config,
if hi.requireClientCert {
cfg.ClientAuth = tls.RequireAndVerifyClientCert
}
hi.mu.Unlock()

// identityProvider is mandatory on the server side.
km, err := idProv.KeyMaterial(ctx)
Expand All @@ -225,8 +197,6 @@ func (hi *HandshakeInfo) ServerSideTLSConfig(ctx context.Context) (*tls.Config,
// If the list of SAN matchers in the HandshakeInfo is empty, this function
// returns true for all input certificates.
func (hi *HandshakeInfo) MatchingSANExists(cert *x509.Certificate) bool {
hi.mu.Lock()
defer hi.mu.Unlock()
if len(hi.sanMatchers) == 0 {
return true
}
Expand Down Expand Up @@ -325,9 +295,3 @@ func dnsMatch(host, san string) bool {
hostPrefix := strings.TrimSuffix(host, san[1:])
return !strings.Contains(hostPrefix, ".")
}

// NewHandshakeInfo returns a new instance of HandshakeInfo with the given root
// and identity certificate providers.
func NewHandshakeInfo(root, identity certprovider.Provider) *HandshakeInfo {
return &HandshakeInfo{rootProvider: root, identityProvider: identity}
}
6 changes: 2 additions & 4 deletions internal/credentials/xds/handshake_info_test.go
Expand Up @@ -188,8 +188,7 @@ func TestMatchingSANExists_FailureCases(t *testing.T) {

for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
hi := NewHandshakeInfo(nil, nil)
hi.SetSANMatchers(test.sanMatchers)
hi := NewHandshakeInfo(nil, nil, test.sanMatchers, false)

if hi.MatchingSANExists(inputCert) {
t.Fatalf("hi.MatchingSANExists(%+v) with SAN matchers +%v succeeded when expected to fail", inputCert, test.sanMatchers)
Expand Down Expand Up @@ -289,8 +288,7 @@ func TestMatchingSANExists_Success(t *testing.T) {

for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
hi := NewHandshakeInfo(nil, nil)
hi.SetSANMatchers(test.sanMatchers)
hi := NewHandshakeInfo(nil, nil, test.sanMatchers, false)

if !hi.MatchingSANExists(inputCert) {
t.Fatalf("hi.MatchingSANExists(%+v) with SAN matchers +%v failed when expected to succeed", inputCert, test.sanMatchers)
Expand Down
2 changes: 1 addition & 1 deletion internal/internal.go
Expand Up @@ -57,7 +57,7 @@ var (
// GetXDSHandshakeInfoForTesting returns a pointer to the xds.HandshakeInfo
// stored in the passed in attributes. This is set by
// credentials/xds/xds.go.
GetXDSHandshakeInfoForTesting any // func (*attributes.Attributes) *xds.HandshakeInfo
GetXDSHandshakeInfoForTesting any // func (*attributes.Attributes) *unsafe.Pointer
// GetServerCredentials returns the transport credentials configured on a
// gRPC server. An xDS-enabled server needs to know what type of credentials
// is configured on the underlying gRPC server. This is set by server.go.
Expand Down