Skip to content

Commit

Permalink
xds: support matchers for accepted SANs
Browse files Browse the repository at this point in the history
  • Loading branch information
easwars committed Mar 5, 2021
1 parent 930c791 commit 96376a2
Show file tree
Hide file tree
Showing 9 changed files with 806 additions and 63 deletions.
23 changes: 16 additions & 7 deletions credentials/xds/xds_client_test.go
Expand Up @@ -36,14 +36,15 @@ import (
xdsinternal "google.golang.org/grpc/internal/credentials/xds"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/internal/xds"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/testdata"
)

const (
defaultTestTimeout = 10 * time.Second
defaultTestShortTimeout = 10 * time.Millisecond
defaultTestCertSAN = "*.test.example.com"
defaultTestCertSAN = "abc.test.example.com"
authority = "authority"
)

Expand Down Expand Up @@ -214,11 +215,14 @@ func makeRootProvider(t *testing.T, caPath string) *fakeProvider {

// newTestContextWithHandshakeInfo returns a copy of parent with HandshakeInfo
// context value added to it.
func newTestContextWithHandshakeInfo(parent context.Context, root, identity certprovider.Provider, sans ...string) context.Context {
func newTestContextWithHandshakeInfo(parent context.Context, root, identity certprovider.Provider, sanExactMatch string) context.Context {
// Creating the HandshakeInfo and adding it to the attributes is very
// similar to what the CDS balancer would do when it intercepts calls to
// NewSubConn().
info := xdsinternal.NewHandshakeInfo(root, identity, sans...)
info := xdsinternal.NewHandshakeInfo(root, identity)
if sanExactMatch != "" {
info.SetSANMatchers([]xds.StringMatcher{{ExactMatch: newStringP(sanExactMatch)}})
}
addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, info)

// Moving the attributes from the resolver.Address to the context passed to
Expand Down Expand Up @@ -292,7 +296,7 @@ func (s) TestClientCredsInvalidHandshakeInfo(t *testing.T) {

pCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
ctx := newTestContextWithHandshakeInfo(pCtx, nil, &fakeProvider{})
ctx := newTestContextWithHandshakeInfo(pCtx, nil, &fakeProvider{}, "")
if _, _, err := creds.ClientHandshake(ctx, authority, nil); err == nil {
t.Fatal("ClientHandshake succeeded without root certificate provider in HandshakeInfo")
}
Expand Down Expand Up @@ -329,7 +333,7 @@ func (s) TestClientCredsProviderFailure(t *testing.T) {
t.Run(test.desc, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, test.identityProvider)
ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, test.identityProvider, "")
if _, _, err := creds.ClientHandshake(ctx, authority, nil); err == nil || !strings.Contains(err.Error(), test.wantErr) {
t.Fatalf("ClientHandshake() returned error: %q, wantErr: %q", err, test.wantErr)
}
Expand Down Expand Up @@ -371,7 +375,7 @@ func (s) TestClientCredsSuccess(t *testing.T) {
desc: "mTLS with no acceptedSANs specified",
handshakeFunc: testServerMutualTLSHandshake,
handshakeInfoCtx: func(ctx context.Context) context.Context {
return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"))
return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), "")
},
},
}
Expand Down Expand Up @@ -530,7 +534,8 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) {
// Create a root provider which will fail the handshake because it does not
// use the correct trust roots.
root1 := makeRootProvider(t, "x509/client_ca_cert.pem")
handshakeInfo := xdsinternal.NewHandshakeInfo(root1, nil, defaultTestCertSAN)
handshakeInfo := xdsinternal.NewHandshakeInfo(root1, nil)
handshakeInfo.SetSANMatchers([]xds.StringMatcher{{ExactMatch: newStringP(defaultTestCertSAN)}})

// We need to repeat most of what newTestContextWithHandshakeInfo() does
// here because we need access to the underlying HandshakeInfo so that we
Expand Down Expand Up @@ -582,3 +587,7 @@ func (s) TestClientClone(t *testing.T) {
t.Fatal("return value from Clone() doesn't point to new credentials instance")
}
}

func newStringP(s string) *string {
return &s
}
173 changes: 143 additions & 30 deletions internal/credentials/xds/handshake_info.go
Expand Up @@ -25,11 +25,13 @@ import (
"crypto/x509"
"errors"
"fmt"
"strings"
"sync"

"google.golang.org/grpc/attributes"
"google.golang.org/grpc/credentials/tls/certprovider"
"google.golang.org/grpc/internal"
xdsinternal "google.golang.org/grpc/internal/xds"
"google.golang.org/grpc/resolver"
)

Expand Down Expand Up @@ -64,8 +66,8 @@ type HandshakeInfo struct {
mu sync.Mutex
rootProvider certprovider.Provider
identityProvider certprovider.Provider
acceptedSANs map[string]bool // Only on the client side.
requireClientCert bool // Only on server side.
sanMatchers []xdsinternal.StringMatcher // Only on the client side.
requireClientCert bool // Only on server side.
}

// SetRootCertProvider updates the root certificate provider.
Expand All @@ -82,13 +84,10 @@ func (hi *HandshakeInfo) SetIdentityCertProvider(identity certprovider.Provider)
hi.mu.Unlock()
}

// SetAcceptedSANs updates the list of accepted SANs.
func (hi *HandshakeInfo) SetAcceptedSANs(sans []string) {
// SetSANMatchers updates the list of SAN matchers.
func (hi *HandshakeInfo) SetSANMatchers(sanMatchers []xdsinternal.StringMatcher) {
hi.mu.Lock()
hi.acceptedSANs = make(map[string]bool, len(sans))
for _, san := range sans {
hi.acceptedSANs[san] = true
}
hi.sanMatchers = sanMatchers
hi.mu.Unlock()
}

Expand All @@ -112,6 +111,14 @@ func (hi *HandshakeInfo) UseFallbackCreds() bool {
return hi.identityProvider == nil && hi.rootProvider == nil
}

// GetSANMatchersForTesting returns the SAN matchers stored in HandshakeInfo.
// To be used only for testing purposes.
func (hi *HandshakeInfo) GetSANMatchersForTesting() []xdsinternal.StringMatcher {
hi.mu.Lock()
defer hi.mu.Unlock()
return append([]xdsinternal.StringMatcher{}, hi.sanMatchers...)
}

// ClientSideTLSConfig constructs a tls.Config to be used in a client-side
// handshake based on the contents of the HandshakeInfo.
func (hi *HandshakeInfo) ClientSideTLSConfig(ctx context.Context) (*tls.Config, error) {
Expand Down Expand Up @@ -184,47 +191,153 @@ func (hi *HandshakeInfo) ServerSideTLSConfig(ctx context.Context) (*tls.Config,
return cfg, nil
}

// MatchingSANExists returns true if the SAN contained in the passed in
// certificate is present in the list of accepted SANs in the HandshakeInfo.
// MatchingSANExists returns true if the SANs contained in cert match the
// criteria enforced by the list of SAN matchers in HandshakeInfo.
//
// If the list of accepted SANs in the HandshakeInfo is empty, this function
// If the list of SAN matchers in the HandshakeInfo is empty, this function
// returns true for all input certificates.
func (hi *HandshakeInfo) MatchingSANExists(cert *x509.Certificate) bool {
if len(hi.acceptedSANs) == 0 {
hi.mu.Lock()
defer hi.mu.Unlock()
if len(hi.sanMatchers) == 0 {
return true
}

var sans []string
// SANs can be specified in any of these four fields on the parsed cert.
sans = append(sans, cert.DNSNames...)
sans = append(sans, cert.EmailAddresses...)
for _, ip := range cert.IPAddresses {
sans = append(sans, ip.String())
for _, san := range cert.DNSNames {
if hi.matchSAN(san, true) {
return true
}
}
for _, uri := range cert.URIs {
sans = append(sans, uri.String())
for _, san := range cert.EmailAddresses {
if hi.matchSAN(san, false) {
return true
}
}

hi.mu.Lock()
defer hi.mu.Unlock()
for _, san := range sans {
if hi.acceptedSANs[san] {
for _, san := range cert.IPAddresses {
if hi.matchSAN(san.String(), false) {
return true
}
}
for _, san := range cert.URIs {
if hi.matchSAN(san.String(), false) {
return true
}
}
return false
}

// Caller must hold mu.
func (hi *HandshakeInfo) matchSAN(san string, isDNS bool) bool {
for _, matcher := range hi.sanMatchers {
if matcher.IgnoreCase {
san = strings.ToLower(san)
}
switch {
case matcher.ExactMatch != nil:
if isDNS {
// This is a special case which is documented in the xDS protos.
// If the DNS SAN is a wildcard entry, and the match criteria is
// `exact`, then we need to perform DNS wildcard matching
// instead of regular string comparison.
if dnsMatch(*matcher.ExactMatch, san) {
return true
}
continue
}

pattern := *matcher.ExactMatch
if matcher.IgnoreCase {
pattern = strings.ToLower(pattern)
}
if san == pattern {
return true
}
case matcher.PrefixMatch != nil:
pattern := *matcher.PrefixMatch
if matcher.IgnoreCase {
pattern = strings.ToLower(pattern)
}
if strings.HasPrefix(san, pattern) {
return true
}
case matcher.SuffixMatch != nil:
pattern := *matcher.SuffixMatch
if matcher.IgnoreCase {
pattern = strings.ToLower(pattern)
}
if strings.HasSuffix(san, pattern) {
return true
}
case matcher.RegexMatch != nil:
if matcher.RegexMatch.MatchString(san) {
return true
}
case matcher.ContainsMatch != nil:
pattern := *matcher.ContainsMatch
if matcher.IgnoreCase {
pattern = strings.ToLower(pattern)
}
if strings.Contains(san, pattern) {
return true
}
}
}
return false
}

// dnsMatch implements a DNS wildcard matching algorithm based on RFC2828 and
// grpc-java's implementation in `OkHostnameVerifier` class.
func dnsMatch(host, pattern string) bool {
// Add trailing "." and turn them into absolute domain names.
if !strings.HasSuffix(host, ".") {
host += "."
}
if !strings.HasSuffix(pattern, ".") {
pattern += "."
}
// Domain names are case-insensitive.
host = strings.ToLower(host)
pattern = strings.ToLower(pattern)

// If pattern does not contain a wildcard pattern, do exact match.
if !strings.Contains(pattern, "*") {
return host == pattern
}

// Wildcard pattern rules
// - '*' is only permitted in the left-most label and must be the only
// character in that label. For example, *.example.com is permitted, while
// *a.example.com, a*.example.com, a*b.example.com, a.*.example.com are
// not permitted.
// - '*' matches a single domain name component. For example, *.example.com
// matches test.example.com but does not match sub.test.example.com.
// - Wildcard patterns for single-label domain names are not permitted.
if pattern == "*." || !strings.HasPrefix(pattern, "*.") || strings.Contains(pattern[1:], "*") {
return false
}
// Optimization: at this point, we know that the pattern contains a '*' and
// is the first domain component of pattern. So, the host name must be at
// least as long as the pattern to be able to match.
if len(host) < len(pattern) {
return false
}
// Hostname must end with the non-wildcard portion of pattern.
if !strings.HasSuffix(host, pattern[1:]) {
return false
}
// At this point we know that the hostName and pattern share the same suffix
// (the non-wildcard portion of pattern). Now, we just need to make sure
// that the '*' does not match across domain components.
hostPrefix := strings.TrimSuffix(host, pattern[1:])
return !strings.Contains(hostPrefix, ".")
}

// NewHandshakeInfo returns a new instance of HandshakeInfo with the given root
// and identity certificate providers.
func NewHandshakeInfo(root, identity certprovider.Provider, sans ...string) *HandshakeInfo {
acceptedSANs := make(map[string]bool, len(sans))
for _, san := range sans {
acceptedSANs[san] = true
}
func NewHandshakeInfo(root, identity certprovider.Provider) *HandshakeInfo {
return &HandshakeInfo{
rootProvider: root,
identityProvider: identity,
acceptedSANs: acceptedSANs,
}
}

0 comments on commit 96376a2

Please sign in to comment.