New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
xDS: Atomically read and write xDS security configuration client side #6796
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,6 +29,7 @@ import ( | |
"strings" | ||
"testing" | ||
"time" | ||
"unsafe" | ||
|
||
"google.golang.org/grpc/credentials" | ||
"google.golang.org/grpc/credentials/tls/certprovider" | ||
|
@@ -219,11 +220,13 @@ func newTestContextWithHandshakeInfo(parent context.Context, root, identity cert | |
// Creating the HandshakeInfo and adding it to the attributes is very | ||
// similar to what the CDS balancer would do when it intercepts calls to | ||
// NewSubConn(). | ||
info := xdsinternal.NewHandshakeInfo(root, identity) | ||
var sm []matcher.StringMatcher | ||
if sanExactMatch != "" { | ||
info.SetSANMatchers([]matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(sanExactMatch), nil, nil, nil, nil, false)}) | ||
sm = []matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(sanExactMatch), nil, nil, nil, nil, false)} | ||
} | ||
addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, info) | ||
info := xdsinternal.NewHandshakeInfo(root, identity, sm, false) | ||
uPtr := unsafe.Pointer(info) | ||
addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, &uPtr) | ||
|
||
// Moving the attributes from the resolver.Address to the context passed to | ||
// the handshaker is done in the transport layer. Since we directly call the | ||
|
@@ -533,13 +536,12 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) { | |
// Create a root provider which will fail the handshake because it does not | ||
// use the correct trust roots. | ||
root1 := makeRootProvider(t, "x509/client_ca_cert.pem") | ||
handshakeInfo := xdsinternal.NewHandshakeInfo(root1, nil) | ||
handshakeInfo.SetSANMatchers([]matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)}) | ||
|
||
handshakeInfo := xdsinternal.NewHandshakeInfo(root1, nil, []matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)}, false) | ||
// We need to repeat most of what newTestContextWithHandshakeInfo() does | ||
// here because we need access to the underlying HandshakeInfo so that we | ||
// can update it before the next call to ClientHandshake(). | ||
addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, handshakeInfo) | ||
uPtr := unsafe.Pointer(handshakeInfo) | ||
addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, &uPtr) | ||
ctx = icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes}) | ||
if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil { | ||
t.Fatal("ClientHandshake() succeeded when expected to fail") | ||
|
@@ -560,7 +562,10 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) { | |
// Create a new root provider which uses the correct trust roots. And update | ||
// the HandshakeInfo with the new provider. | ||
root2 := makeRootProvider(t, "x509/server_ca_cert.pem") | ||
handshakeInfo.SetRootCertProvider(root2) | ||
handshakeInfo = xdsinternal.NewHandshakeInfo(root2, nil, []matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)}, false) | ||
uPtr = unsafe.Pointer(handshakeInfo) | ||
addr = xdsinternal.SetHandshakeInfo(resolver.Address{}, &uPtr) | ||
ctx = icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we just update the existing pointer instead of creating a new entry in the context? That's how the production code will work. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, good point. Changed. |
||
_, ai, err := creds.ClientHandshake(ctx, authority, conn) | ||
if err != nil { | ||
t.Fatalf("ClientHandshake() returned failed: %q", err) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,7 +26,7 @@ import ( | |
"errors" | ||
"fmt" | ||
"strings" | ||
"sync" | ||
"unsafe" | ||
|
||
"google.golang.org/grpc/attributes" | ||
"google.golang.org/grpc/credentials/tls/certprovider" | ||
|
@@ -66,16 +66,16 @@ func (hi *HandshakeInfo) Equal(other *HandshakeInfo) bool { | |
} | ||
|
||
// SetHandshakeInfo returns a copy of addr in which the Attributes field is | ||
// updated with hInfo. | ||
func SetHandshakeInfo(addr resolver.Address, hInfo *HandshakeInfo) resolver.Address { | ||
addr.Attributes = addr.Attributes.WithValue(handshakeAttrKey{}, hInfo) | ||
// updated with hiPtr. | ||
func SetHandshakeInfo(addr resolver.Address, hiPtr *unsafe.Pointer) resolver.Address { | ||
addr.Attributes = addr.Attributes.WithValue(handshakeAttrKey{}, hiPtr) | ||
return addr | ||
} | ||
|
||
// GetHandshakeInfo returns a pointer to the HandshakeInfo stored in attr. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe "pointer to the *HandshakeInfo"? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah fair. Switched. |
||
func GetHandshakeInfo(attr *attributes.Attributes) *HandshakeInfo { | ||
func GetHandshakeInfo(attr *attributes.Attributes) *unsafe.Pointer { | ||
v := attr.Value(handshakeAttrKey{}) | ||
hi, _ := v.(*HandshakeInfo) | ||
hi, _ := v.(*unsafe.Pointer) | ||
return hi | ||
} | ||
|
||
|
@@ -85,40 +85,21 @@ func GetHandshakeInfo(attr *attributes.Attributes) *HandshakeInfo { | |
// | ||
// Safe for concurrent access. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Delete this now? Or say it's immutable? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel like technically it's fine to call any methods on it concurrently since it's all read only and set at init time, but since it's not really important to call out/doesn't apply I'll delete it. |
||
type HandshakeInfo struct { | ||
mu sync.Mutex | ||
// All fields written at init time and read only after that, so no | ||
// synchronization needed. | ||
rootProvider certprovider.Provider | ||
identityProvider certprovider.Provider | ||
sanMatchers []matcher.StringMatcher // Only on the client side. | ||
requireClientCert bool // Only on server side. | ||
} | ||
|
||
// SetRootCertProvider updates the root certificate provider. | ||
func (hi *HandshakeInfo) SetRootCertProvider(root certprovider.Provider) { | ||
hi.mu.Lock() | ||
hi.rootProvider = root | ||
hi.mu.Unlock() | ||
} | ||
|
||
// SetIdentityCertProvider updates the identity certificate provider. | ||
func (hi *HandshakeInfo) SetIdentityCertProvider(identity certprovider.Provider) { | ||
hi.mu.Lock() | ||
hi.identityProvider = identity | ||
hi.mu.Unlock() | ||
} | ||
|
||
// SetSANMatchers updates the list of SAN matchers. | ||
func (hi *HandshakeInfo) SetSANMatchers(sanMatchers []matcher.StringMatcher) { | ||
hi.mu.Lock() | ||
hi.sanMatchers = sanMatchers | ||
hi.mu.Unlock() | ||
} | ||
|
||
// SetRequireClientCert updates whether a client cert is required during the | ||
// ServerHandshake(). A value of true indicates that we are performing mTLS. | ||
func (hi *HandshakeInfo) SetRequireClientCert(require bool) { | ||
hi.mu.Lock() | ||
hi.requireClientCert = require | ||
hi.mu.Unlock() | ||
func NewHandshakeInfo(rootProvider certprovider.Provider, identityProvider certprovider.Provider, sanMatchers []matcher.StringMatcher, requireClientCert bool) *HandshakeInfo { | ||
return &HandshakeInfo{ | ||
rootProvider: rootProvider, | ||
identityProvider: identityProvider, | ||
sanMatchers: sanMatchers, | ||
requireClientCert: requireClientCert, | ||
} | ||
} | ||
|
||
// UseFallbackCreds returns true when fallback credentials are to be used based | ||
|
@@ -127,24 +108,18 @@ func (hi *HandshakeInfo) UseFallbackCreds() bool { | |
if hi == nil { | ||
return true | ||
} | ||
|
||
hi.mu.Lock() | ||
defer hi.mu.Unlock() | ||
return hi.identityProvider == nil && hi.rootProvider == nil | ||
} | ||
|
||
// GetSANMatchersForTesting returns the SAN matchers stored in HandshakeInfo. | ||
// To be used only for testing purposes. | ||
func (hi *HandshakeInfo) GetSANMatchersForTesting() []matcher.StringMatcher { | ||
hi.mu.Lock() | ||
defer hi.mu.Unlock() | ||
return append([]matcher.StringMatcher{}, hi.sanMatchers...) | ||
} | ||
|
||
// ClientSideTLSConfig constructs a tls.Config to be used in a client-side | ||
// handshake based on the contents of the HandshakeInfo. | ||
func (hi *HandshakeInfo) ClientSideTLSConfig(ctx context.Context) (*tls.Config, error) { | ||
hi.mu.Lock() | ||
// On the client side, rootProvider is mandatory. IdentityProvider is | ||
// optional based on whether the client is doing TLS or mTLS. | ||
if hi.rootProvider == nil { | ||
|
@@ -153,7 +128,6 @@ func (hi *HandshakeInfo) ClientSideTLSConfig(ctx context.Context) (*tls.Config, | |
// Since the call to KeyMaterial() can block, we read the providers under | ||
// the lock but call the actual function after releasing the lock. | ||
rootProv, idProv := hi.rootProvider, hi.identityProvider | ||
hi.mu.Unlock() | ||
|
||
// InsecureSkipVerify needs to be set to true because we need to perform | ||
// custom verification to check the SAN on the received certificate. | ||
|
@@ -188,7 +162,6 @@ func (hi *HandshakeInfo) ServerSideTLSConfig(ctx context.Context) (*tls.Config, | |
ClientAuth: tls.NoClientCert, | ||
NextProtos: []string{"h2"}, | ||
} | ||
hi.mu.Lock() | ||
// On the server side, identityProvider is mandatory. RootProvider is | ||
// optional based on whether the server is doing TLS or mTLS. | ||
if hi.identityProvider == nil { | ||
|
@@ -200,7 +173,6 @@ func (hi *HandshakeInfo) ServerSideTLSConfig(ctx context.Context) (*tls.Config, | |
if hi.requireClientCert { | ||
cfg.ClientAuth = tls.RequireAndVerifyClientCert | ||
} | ||
hi.mu.Unlock() | ||
|
||
// identityProvider is mandatory on the server side. | ||
km, err := idProv.KeyMaterial(ctx) | ||
|
@@ -225,8 +197,6 @@ func (hi *HandshakeInfo) ServerSideTLSConfig(ctx context.Context) (*tls.Config, | |
// If the list of SAN matchers in the HandshakeInfo is empty, this function | ||
// returns true for all input certificates. | ||
func (hi *HandshakeInfo) MatchingSANExists(cert *x509.Certificate) bool { | ||
hi.mu.Lock() | ||
defer hi.mu.Unlock() | ||
if len(hi.sanMatchers) == 0 { | ||
return true | ||
} | ||
|
@@ -325,9 +295,3 @@ func dnsMatch(host, san string) bool { | |
hostPrefix := strings.TrimSuffix(host, san[1:]) | ||
return !strings.Contains(hostPrefix, ".") | ||
} | ||
|
||
// NewHandshakeInfo returns a new instance of HandshakeInfo with the given root | ||
// and identity certificate providers. | ||
func NewHandshakeInfo(root, identity certprovider.Provider) *HandshakeInfo { | ||
return &HandshakeInfo{rootProvider: root, identityProvider: identity} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit/optional:
sms
to pluralize string matcher.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fair. Switched.