diff --git a/credentials/xds/xds_client_test.go b/credentials/xds/xds_client_test.go index 219d0aefcba6..f4e9540a60b6 100644 --- a/credentials/xds/xds_client_test.go +++ b/credentials/xds/xds_client_test.go @@ -36,6 +36,7 @@ import ( xdsinternal "google.golang.org/grpc/internal/credentials/xds" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/xds" "google.golang.org/grpc/resolver" "google.golang.org/grpc/testdata" ) @@ -43,7 +44,7 @@ import ( const ( defaultTestTimeout = 10 * time.Second defaultTestShortTimeout = 10 * time.Millisecond - defaultTestCertSAN = "*.test.example.com" + defaultTestCertSAN = "abc.test.example.com" authority = "authority" ) @@ -214,11 +215,14 @@ func makeRootProvider(t *testing.T, caPath string) *fakeProvider { // newTestContextWithHandshakeInfo returns a copy of parent with HandshakeInfo // context value added to it. -func newTestContextWithHandshakeInfo(parent context.Context, root, identity certprovider.Provider, sans ...string) context.Context { +func newTestContextWithHandshakeInfo(parent context.Context, root, identity certprovider.Provider, sanExactMatch string) context.Context { // Creating the HandshakeInfo and adding it to the attributes is very // similar to what the CDS balancer would do when it intercepts calls to // NewSubConn(). - info := xdsinternal.NewHandshakeInfo(root, identity, sans...) + info := xdsinternal.NewHandshakeInfo(root, identity) + if sanExactMatch != "" { + info.SetSANMatchers([]xds.StringMatcher{{ExactMatch: newStringP(sanExactMatch)}}) + } addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, info) // Moving the attributes from the resolver.Address to the context passed to @@ -292,7 +296,7 @@ func (s) TestClientCredsInvalidHandshakeInfo(t *testing.T) { pCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - ctx := newTestContextWithHandshakeInfo(pCtx, nil, &fakeProvider{}) + ctx := newTestContextWithHandshakeInfo(pCtx, nil, &fakeProvider{}, "") if _, _, err := creds.ClientHandshake(ctx, authority, nil); err == nil { t.Fatal("ClientHandshake succeeded without root certificate provider in HandshakeInfo") } @@ -329,7 +333,7 @@ func (s) TestClientCredsProviderFailure(t *testing.T) { t.Run(test.desc, func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, test.identityProvider) + ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, test.identityProvider, "") if _, _, err := creds.ClientHandshake(ctx, authority, nil); err == nil || !strings.Contains(err.Error(), test.wantErr) { t.Fatalf("ClientHandshake() returned error: %q, wantErr: %q", err, test.wantErr) } @@ -371,7 +375,7 @@ func (s) TestClientCredsSuccess(t *testing.T) { desc: "mTLS with no acceptedSANs specified", handshakeFunc: testServerMutualTLSHandshake, handshakeInfoCtx: func(ctx context.Context) context.Context { - return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem")) + return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), "") }, }, } @@ -530,7 +534,8 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) { // Create a root provider which will fail the handshake because it does not // use the correct trust roots. root1 := makeRootProvider(t, "x509/client_ca_cert.pem") - handshakeInfo := xdsinternal.NewHandshakeInfo(root1, nil, defaultTestCertSAN) + handshakeInfo := xdsinternal.NewHandshakeInfo(root1, nil) + handshakeInfo.SetSANMatchers([]xds.StringMatcher{{ExactMatch: newStringP(defaultTestCertSAN)}}) // We need to repeat most of what newTestContextWithHandshakeInfo() does // here because we need access to the underlying HandshakeInfo so that we @@ -582,3 +587,7 @@ func (s) TestClientClone(t *testing.T) { t.Fatal("return value from Clone() doesn't point to new credentials instance") } } + +func newStringP(s string) *string { + return &s +} diff --git a/internal/credentials/xds/handshake_info.go b/internal/credentials/xds/handshake_info.go index 8b2035660639..5b4994757f35 100644 --- a/internal/credentials/xds/handshake_info.go +++ b/internal/credentials/xds/handshake_info.go @@ -25,11 +25,13 @@ import ( "crypto/x509" "errors" "fmt" + "strings" "sync" "google.golang.org/grpc/attributes" "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/internal" + xdsinternal "google.golang.org/grpc/internal/xds" "google.golang.org/grpc/resolver" ) @@ -64,8 +66,8 @@ type HandshakeInfo struct { mu sync.Mutex rootProvider certprovider.Provider identityProvider certprovider.Provider - acceptedSANs map[string]bool // Only on the client side. - requireClientCert bool // Only on server side. + sanMatchers []xdsinternal.StringMatcher // Only on the client side. + requireClientCert bool // Only on server side. } // SetRootCertProvider updates the root certificate provider. @@ -82,13 +84,10 @@ func (hi *HandshakeInfo) SetIdentityCertProvider(identity certprovider.Provider) hi.mu.Unlock() } -// SetAcceptedSANs updates the list of accepted SANs. -func (hi *HandshakeInfo) SetAcceptedSANs(sans []string) { +// SetSANMatchers updates the list of SAN matchers. +func (hi *HandshakeInfo) SetSANMatchers(sanMatchers []xdsinternal.StringMatcher) { hi.mu.Lock() - hi.acceptedSANs = make(map[string]bool, len(sans)) - for _, san := range sans { - hi.acceptedSANs[san] = true - } + hi.sanMatchers = sanMatchers hi.mu.Unlock() } @@ -112,6 +111,14 @@ func (hi *HandshakeInfo) UseFallbackCreds() bool { return hi.identityProvider == nil && hi.rootProvider == nil } +// GetSANMatchersForTesting returns the SAN matchers stored in HandshakeInfo. +// To be used only for testing purposes. +func (hi *HandshakeInfo) GetSANMatchersForTesting() []xdsinternal.StringMatcher { + hi.mu.Lock() + defer hi.mu.Unlock() + return append([]xdsinternal.StringMatcher{}, hi.sanMatchers...) +} + // ClientSideTLSConfig constructs a tls.Config to be used in a client-side // handshake based on the contents of the HandshakeInfo. func (hi *HandshakeInfo) ClientSideTLSConfig(ctx context.Context) (*tls.Config, error) { @@ -184,47 +191,153 @@ func (hi *HandshakeInfo) ServerSideTLSConfig(ctx context.Context) (*tls.Config, return cfg, nil } -// MatchingSANExists returns true if the SAN contained in the passed in -// certificate is present in the list of accepted SANs in the HandshakeInfo. +// MatchingSANExists returns true if the SANs contained in cert match the +// criteria enforced by the list of SAN matchers in HandshakeInfo. // -// If the list of accepted SANs in the HandshakeInfo is empty, this function +// If the list of SAN matchers in the HandshakeInfo is empty, this function // returns true for all input certificates. func (hi *HandshakeInfo) MatchingSANExists(cert *x509.Certificate) bool { - if len(hi.acceptedSANs) == 0 { + hi.mu.Lock() + defer hi.mu.Unlock() + if len(hi.sanMatchers) == 0 { return true } - var sans []string // SANs can be specified in any of these four fields on the parsed cert. - sans = append(sans, cert.DNSNames...) - sans = append(sans, cert.EmailAddresses...) - for _, ip := range cert.IPAddresses { - sans = append(sans, ip.String()) + for _, san := range cert.DNSNames { + if hi.matchSAN(san, true) { + return true + } } - for _, uri := range cert.URIs { - sans = append(sans, uri.String()) + for _, san := range cert.EmailAddresses { + if hi.matchSAN(san, false) { + return true + } } - - hi.mu.Lock() - defer hi.mu.Unlock() - for _, san := range sans { - if hi.acceptedSANs[san] { + for _, san := range cert.IPAddresses { + if hi.matchSAN(san.String(), false) { + return true + } + } + for _, san := range cert.URIs { + if hi.matchSAN(san.String(), false) { return true } } return false } +// Caller must hold mu. +func (hi *HandshakeInfo) matchSAN(san string, isDNS bool) bool { + for _, matcher := range hi.sanMatchers { + if matcher.IgnoreCase { + san = strings.ToLower(san) + } + switch { + case matcher.ExactMatch != nil: + if isDNS { + // This is a special case which is documented in the xDS protos. + // If the DNS SAN is a wildcard entry, and the match criteria is + // `exact`, then we need to perform DNS wildcard matching + // instead of regular string comparison. + if dnsMatch(*matcher.ExactMatch, san) { + return true + } + continue + } + + pattern := *matcher.ExactMatch + if matcher.IgnoreCase { + pattern = strings.ToLower(pattern) + } + if san == pattern { + return true + } + case matcher.PrefixMatch != nil: + pattern := *matcher.PrefixMatch + if matcher.IgnoreCase { + pattern = strings.ToLower(pattern) + } + if strings.HasPrefix(san, pattern) { + return true + } + case matcher.SuffixMatch != nil: + pattern := *matcher.SuffixMatch + if matcher.IgnoreCase { + pattern = strings.ToLower(pattern) + } + if strings.HasSuffix(san, pattern) { + return true + } + case matcher.RegexMatch != nil: + if matcher.RegexMatch.MatchString(san) { + return true + } + case matcher.ContainsMatch != nil: + pattern := *matcher.ContainsMatch + if matcher.IgnoreCase { + pattern = strings.ToLower(pattern) + } + if strings.Contains(san, pattern) { + return true + } + } + } + return false +} + +// dnsMatch implements a DNS wildcard matching algorithm based on RFC2828 and +// grpc-java's implementation in `OkHostnameVerifier` class. +func dnsMatch(host, pattern string) bool { + // Add trailing "." and turn them into absolute domain names. + if !strings.HasSuffix(host, ".") { + host += "." + } + if !strings.HasSuffix(pattern, ".") { + pattern += "." + } + // Domain names are case-insensitive. + host = strings.ToLower(host) + pattern = strings.ToLower(pattern) + + // If pattern does not contain a wildcard pattern, do exact match. + if !strings.Contains(pattern, "*") { + return host == pattern + } + + // Wildcard pattern rules + // - '*' is only permitted in the left-most label and must be the only + // character in that label. For example, *.example.com is permitted, while + // *a.example.com, a*.example.com, a*b.example.com, a.*.example.com are + // not permitted. + // - '*' matches a single domain name component. For example, *.example.com + // matches test.example.com but does not match sub.test.example.com. + // - Wildcard patterns for single-label domain names are not permitted. + if pattern == "*." || !strings.HasPrefix(pattern, "*.") || strings.Contains(pattern[1:], "*") { + return false + } + // Optimization: at this point, we know that the pattern contains a '*' and + // is the first domain component of pattern. So, the host name must be at + // least as long as the pattern to be able to match. + if len(host) < len(pattern) { + return false + } + // Hostname must end with the non-wildcard portion of pattern. + if !strings.HasSuffix(host, pattern[1:]) { + return false + } + // At this point we know that the hostName and pattern share the same suffix + // (the non-wildcard portion of pattern). Now, we just need to make sure + // that the '*' does not match across domain components. + hostPrefix := strings.TrimSuffix(host, pattern[1:]) + return !strings.Contains(hostPrefix, ".") +} + // NewHandshakeInfo returns a new instance of HandshakeInfo with the given root // and identity certificate providers. -func NewHandshakeInfo(root, identity certprovider.Provider, sans ...string) *HandshakeInfo { - acceptedSANs := make(map[string]bool, len(sans)) - for _, san := range sans { - acceptedSANs[san] = true - } +func NewHandshakeInfo(root, identity certprovider.Provider) *HandshakeInfo { return &HandshakeInfo{ rootProvider: root, identityProvider: identity, - acceptedSANs: acceptedSANs, } } diff --git a/internal/credentials/xds/handshake_info_test.go b/internal/credentials/xds/handshake_info_test.go new file mode 100644 index 000000000000..003f01cae779 --- /dev/null +++ b/internal/credentials/xds/handshake_info_test.go @@ -0,0 +1,317 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xds + +import ( + "crypto/x509" + "net" + "net/url" + "regexp" + "testing" + + xdsinternal "google.golang.org/grpc/internal/xds" +) + +func TestDNSMatch(t *testing.T) { + tests := []struct { + desc string + host string + pattern string + wantMatch bool + }{ + { + desc: "invalid wildcard 1", + host: "aa.example.com", + pattern: "*a.example.com", + wantMatch: false, + }, + { + desc: "invalid wildcard 2", + host: "aa.example.com", + pattern: "a*.example.com", + wantMatch: false, + }, + { + desc: "invalid wildcard 3", + host: "abc.example.com", + pattern: "a*c.example.com", + wantMatch: false, + }, + { + desc: "wildcard in one of the middle components", + host: "abc.test.example.com", + pattern: "abc.*.example.com", + wantMatch: false, + }, + { + desc: "single component wildcard", + host: "a.example.com", + pattern: "*", + wantMatch: false, + }, + { + desc: "short host name", + host: "a.com", + pattern: "*.example.com", + wantMatch: false, + }, + { + desc: "suffix mismatch", + host: "a.notexample.com", + pattern: "*.example.com", + wantMatch: false, + }, + { + desc: "wildcard match across components", + host: "sub.test.example.com", + pattern: "*.example.com.", + wantMatch: false, + }, + { + desc: "host doesn't end in period", + host: "test.example.com", + pattern: "test.example.com.", + wantMatch: true, + }, + { + desc: "pattern doesn't end in period", + host: "test.example.com.", + pattern: "test.example.com", + wantMatch: true, + }, + { + desc: "case insensitive", + host: "TEST.EXAMPLE.COM.", + pattern: "test.example.com.", + wantMatch: true, + }, + { + desc: "simple match", + host: "test.example.com", + pattern: "test.example.com", + wantMatch: true, + }, + { + desc: "good wildcard", + host: "a.example.com", + pattern: "*.example.com", + wantMatch: true, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + gotMatch := dnsMatch(test.host, test.pattern) + if gotMatch != test.wantMatch { + t.Fatalf("dnsMatch(%s, %s) = %v, want %v", test.host, test.pattern, gotMatch, test.wantMatch) + } + + }) + } +} + +func TestMatchingSANExists_FailureCases(t *testing.T) { + url1, err := url.Parse("http://golang.org") + if err != nil { + t.Fatalf("url.Parse() failed: %v", err) + } + url2, err := url.Parse("https://github.com/grpc/grpc-go") + if err != nil { + t.Fatalf("url.Parse() failed: %v", err) + } + inputCert := &x509.Certificate{ + DNSNames: []string{"foo.bar.example.com", "bar.baz.test.com", "*.example.com"}, + EmailAddresses: []string{"foobar@example.com", "barbaz@test.com"}, + IPAddresses: []net.IP{net.ParseIP("192.0.0.1"), net.ParseIP("2001:db8::68")}, + URIs: []*url.URL{url1, url2}, + } + + tests := []struct { + desc string + sanMatchers []xdsinternal.StringMatcher + }{ + { + desc: "exact match", + sanMatchers: []xdsinternal.StringMatcher{ + {ExactMatch: newStringP("abcd.test.com")}, + {ExactMatch: newStringP("http://golang")}, + {ExactMatch: newStringP("HTTP://GOLANG.ORG")}, + }, + }, + { + desc: "prefix match", + sanMatchers: []xdsinternal.StringMatcher{ + {PrefixMatch: newStringP("i-aint-the-one")}, + {PrefixMatch: newStringP("192.168.1.1")}, + {PrefixMatch: newStringP("FOO.BAR")}, + }, + }, + { + desc: "suffix match", + sanMatchers: []xdsinternal.StringMatcher{ + {SuffixMatch: newStringP("i-aint-the-one")}, + {SuffixMatch: newStringP("1::68")}, + {SuffixMatch: newStringP(".COM")}, + }, + }, + { + desc: "regex match", + sanMatchers: []xdsinternal.StringMatcher{ + {RegexMatch: regexp.MustCompile(`.*\.examples\.com`)}, + {RegexMatch: regexp.MustCompile(`192\.[0-9]{1,3}\.1\.1`)}, + }, + }, + { + desc: "contains match", + sanMatchers: []xdsinternal.StringMatcher{ + {ContainsMatch: newStringP("i-aint-the-one")}, + {ContainsMatch: newStringP("2001:db8:1:1::68")}, + {ContainsMatch: newStringP("GRPC")}, + }, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + hi := NewHandshakeInfo(nil, nil) + hi.SetSANMatchers(test.sanMatchers) + + if hi.MatchingSANExists(inputCert) { + t.Fatalf("hi.MatchingSANExists(%+v) with SAN matchers +%v succeeded when expected to fail", inputCert, test.sanMatchers) + } + }) + } +} + +func TestMatchingSANExists_Success(t *testing.T) { + url1, err := url.Parse("http://golang.org") + if err != nil { + t.Fatalf("url.Parse() failed: %v", err) + } + url2, err := url.Parse("https://github.com/grpc/grpc-go") + if err != nil { + t.Fatalf("url.Parse() failed: %v", err) + } + inputCert := &x509.Certificate{ + DNSNames: []string{"baz.test.com", "*.example.com"}, + EmailAddresses: []string{"foobar@example.com", "barbaz@test.com"}, + IPAddresses: []net.IP{net.ParseIP("192.0.0.1"), net.ParseIP("2001:db8::68")}, + URIs: []*url.URL{url1, url2}, + } + + tests := []struct { + desc string + sanMatchers []xdsinternal.StringMatcher + }{ + { + desc: "no san matchers", + }, + { + desc: "exact match dns wildcard", + sanMatchers: []xdsinternal.StringMatcher{ + {PrefixMatch: newStringP("192.168.1.1")}, + {ExactMatch: newStringP("https://github.com/grpc/grpc-java")}, + {ExactMatch: newStringP("abc.example.com")}, + }, + }, + { + desc: "exact match ignore case", + sanMatchers: []xdsinternal.StringMatcher{ + { + ExactMatch: newStringP("FOOBAR@EXAMPLE.COM"), + IgnoreCase: true, + }, + }, + }, + { + desc: "prefix match", + sanMatchers: []xdsinternal.StringMatcher{ + {SuffixMatch: newStringP(".co.in")}, + {PrefixMatch: newStringP("192.168.1.1")}, + {PrefixMatch: newStringP("baz.test")}, + }, + }, + { + desc: "prefix match ignore case", + sanMatchers: []xdsinternal.StringMatcher{ + { + PrefixMatch: newStringP("BAZ.test"), + IgnoreCase: true, + }, + }, + }, + { + desc: "suffix match", + sanMatchers: []xdsinternal.StringMatcher{ + {RegexMatch: regexp.MustCompile(`192\.[0-9]{1,3}\.1\.1`)}, + {SuffixMatch: newStringP("192.168.1.1")}, + {SuffixMatch: newStringP("@test.com")}, + }, + }, + { + desc: "suffix match ignore case", + sanMatchers: []xdsinternal.StringMatcher{ + { + SuffixMatch: newStringP("@test.COM"), + IgnoreCase: true, + }, + }, + }, + { + desc: "regex match", + sanMatchers: []xdsinternal.StringMatcher{ + {ContainsMatch: newStringP("https://github.com/grpc/grpc-java")}, + {RegexMatch: regexp.MustCompile(`192\.[0-9]{1,3}\.1\.1`)}, + {RegexMatch: regexp.MustCompile(`.*\.test\.com`)}, + }, + }, + { + desc: "contains match", + sanMatchers: []xdsinternal.StringMatcher{ + {ExactMatch: newStringP("https://github.com/grpc/grpc-java")}, + {ContainsMatch: newStringP("2001:68::db8")}, + {ContainsMatch: newStringP("192.0.0")}, + }, + }, + { + desc: "contains match ignore case", + sanMatchers: []xdsinternal.StringMatcher{ + { + ContainsMatch: newStringP("GRPC"), + IgnoreCase: true, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + hi := NewHandshakeInfo(nil, nil) + hi.SetSANMatchers(test.sanMatchers) + + if !hi.MatchingSANExists(inputCert) { + t.Fatalf("hi.MatchingSANExists(%+v) with SAN matchers +%v failed when expected to succeed", inputCert, test.sanMatchers) + } + }) + } +} + +func newStringP(s string) *string { + return &s +} diff --git a/internal/xds/string_matcher.go b/internal/xds/string_matcher.go new file mode 100644 index 000000000000..b4cd173fa9c3 --- /dev/null +++ b/internal/xds/string_matcher.go @@ -0,0 +1,39 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package xds contains types that need to be shared between code under +// google.golang.org/grpc/xds/... and the rest of gRPC. +package xds + +import "regexp" + +// StringMatcher contains match criteria for matching a string, and is an +// internal representation of the `StringMatcher` proto defined at +// https://github.com/envoyproxy/envoy/blob/main/api/envoy/type/matcher/v3/string.proto. +type StringMatcher struct { + // Since these match fields are part of a `oneof` in the corresponding xDS + // proto, only one of them is expected to be set. + ExactMatch *string + PrefixMatch *string + SuffixMatch *string + RegexMatch *regexp.Regexp + ContainsMatch *string + // If true, indicates the exact/prefix/suffix/contains matching should be + // case insensitive. This has no effect on the regex match. + IgnoreCase bool +} diff --git a/xds/internal/balancer/cdsbalancer/cdsbalancer.go b/xds/internal/balancer/cdsbalancer/cdsbalancer.go index e4d349753e13..19e075bd7b5d 100644 --- a/xds/internal/balancer/cdsbalancer/cdsbalancer.go +++ b/xds/internal/balancer/cdsbalancer/cdsbalancer.go @@ -235,7 +235,7 @@ func (b *cdsBalancer) handleSecurityConfig(config *xdsclient.SecurityConfig) err // one where fallback credentials are to be used. b.xdsHI.SetRootCertProvider(nil) b.xdsHI.SetIdentityCertProvider(nil) - b.xdsHI.SetAcceptedSANs(nil) + b.xdsHI.SetSANMatchers(nil) return nil } @@ -278,7 +278,7 @@ func (b *cdsBalancer) handleSecurityConfig(config *xdsclient.SecurityConfig) err // could have been non-nil earlier. b.xdsHI.SetRootCertProvider(rootProvider) b.xdsHI.SetIdentityCertProvider(identityProvider) - b.xdsHI.SetAcceptedSANs(config.AcceptedSANs) + b.xdsHI.SetSANMatchers(config.SubjectAltNameMatchers) return nil } diff --git a/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go b/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go index fee48c262ebe..bd64a2a2f703 100644 --- a/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go +++ b/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go @@ -20,16 +20,20 @@ import ( "context" "errors" "fmt" + "regexp" "testing" + "github.com/google/go-cmp/cmp" + "google.golang.org/grpc/attributes" "google.golang.org/grpc/balancer" "google.golang.org/grpc/credentials/local" "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/credentials/xds" "google.golang.org/grpc/internal" - xdsinternal "google.golang.org/grpc/internal/credentials/xds" + xdscredsinternal "google.golang.org/grpc/internal/credentials/xds" "google.golang.org/grpc/internal/testutils" + xdsinternal "google.golang.org/grpc/internal/xds" "google.golang.org/grpc/resolver" xdsclient "google.golang.org/grpc/xds/internal/client" "google.golang.org/grpc/xds/internal/client/bootstrap" @@ -41,16 +45,28 @@ const ( fakeProvider1Name = "fake-certificate-provider-1" fakeProvider2Name = "fake-certificate-provider-2" fakeConfig = "my fake config" + testSAN = "test-san" ) var ( + testSANMatchers = []xdsinternal.StringMatcher{ + { + ExactMatch: newStringP(testSAN), + IgnoreCase: true, + }, + {PrefixMatch: newStringP(testSAN)}, + {SuffixMatch: newStringP(testSAN)}, + {RegexMatch: regexp.MustCompile(testSAN)}, + {ContainsMatch: newStringP(testSAN)}, + } fpb1, fpb2 *fakeProviderBuilder bootstrapConfig *bootstrap.Config cdsUpdateWithGoodSecurityCfg = xdsclient.ClusterUpdate{ ServiceName: serviceName, SecurityCfg: &xdsclient.SecurityConfig{ - RootInstanceName: "default1", - IdentityInstanceName: "default2", + RootInstanceName: "default1", + IdentityInstanceName: "default2", + SubjectAltNameMatchers: testSANMatchers, }, } cdsUpdateWithMissingSecurityCfg = xdsclient.ClusterUpdate{ @@ -61,6 +77,10 @@ var ( } ) +func newStringP(s string) *string { + return &s +} + func init() { fpb1 = &fakeProviderBuilder{name: fakeProvider1Name} fpb2 = &fakeProviderBuilder{name: fakeProvider2Name} @@ -190,7 +210,7 @@ func makeNewSubConn(ctx context.Context, edsCC balancer.ClientConn, parentCC *xd if got, want := gotAddrs[0].Addr, addrs[0].Addr; got != want { return nil, fmt.Errorf("resolver.Address passed to parent ClientConn has address %q, want %q", got, want) } - getHI := internal.GetXDSHandshakeInfoForTesting.(func(attr *attributes.Attributes) *xdsinternal.HandshakeInfo) + getHI := internal.GetXDSHandshakeInfoForTesting.(func(attr *attributes.Attributes) *xdscredsinternal.HandshakeInfo) hi := getHI(gotAddrs[0].Attributes) if hi == nil { return nil, errors.New("resolver.Address passed to parent ClientConn doesn't contain attributes") @@ -198,6 +218,11 @@ func makeNewSubConn(ctx context.Context, edsCC balancer.ClientConn, parentCC *xd if gotFallback := hi.UseFallbackCreds(); gotFallback != wantFallback { return nil, fmt.Errorf("resolver.Address HandshakeInfo uses fallback creds? %v, want %v", gotFallback, wantFallback) } + if !wantFallback { + if diff := cmp.Diff(testSANMatchers, hi.GetSANMatchersForTesting(), cmp.AllowUnexported(regexp.Regexp{})); diff != "" { + return nil, fmt.Errorf("unexpected diff in the list of SAN matchers (-got, +want):\n%s", diff) + } + } } return sc, nil } @@ -507,7 +532,7 @@ func (s) TestGoodSecurityConfig(t *testing.T) { if got, want := gotAddrs[0].Addr, addrs[0].Addr; got != want { t.Fatalf("resolver.Address passed to parent ClientConn through UpdateAddresses() has address %q, want %q", got, want) } - getHI := internal.GetXDSHandshakeInfoForTesting.(func(attr *attributes.Attributes) *xdsinternal.HandshakeInfo) + getHI := internal.GetXDSHandshakeInfoForTesting.(func(attr *attributes.Attributes) *xdscredsinternal.HandshakeInfo) hi := getHI(gotAddrs[0].Attributes) if hi == nil { t.Fatal("resolver.Address passed to parent ClientConn through UpdateAddresses() doesn't contain attributes") @@ -657,7 +682,8 @@ func (s) TestSecurityConfigUpdate_GoodToGood(t *testing.T) { cdsUpdate := xdsclient.ClusterUpdate{ ServiceName: serviceName, SecurityCfg: &xdsclient.SecurityConfig{ - RootInstanceName: "default1", + RootInstanceName: "default1", + SubjectAltNameMatchers: testSANMatchers, }, } wantCCS := edsCCS(serviceName, nil, false) @@ -681,7 +707,8 @@ func (s) TestSecurityConfigUpdate_GoodToGood(t *testing.T) { cdsUpdate = xdsclient.ClusterUpdate{ ServiceName: serviceName, SecurityCfg: &xdsclient.SecurityConfig{ - RootInstanceName: "default2", + RootInstanceName: "default2", + SubjectAltNameMatchers: testSANMatchers, }, } if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { diff --git a/xds/internal/client/cds_test.go b/xds/internal/client/cds_test.go index 6ad0b0fa88a0..a52c56a19296 100644 --- a/xds/internal/client/cds_test.go +++ b/xds/internal/client/cds_test.go @@ -19,6 +19,7 @@ package client import ( + "regexp" "testing" v2xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" @@ -31,6 +32,7 @@ import ( anypb "github.com/golang/protobuf/ptypes/any" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + xds "google.golang.org/grpc/internal/xds" "google.golang.org/grpc/xds/internal/env" "google.golang.org/grpc/xds/internal/version" "google.golang.org/protobuf/types/known/wrapperspb" @@ -230,9 +232,14 @@ func (s) TestValidateClusterWithSecurityConfig(t *testing.T) { rootPluginInstance = "rootPluginInstance" rootCertName = "rootCert" serviceName = "service" - san1 = "san1" - san2 = "san2" + sanExact = "san-exact" + sanPrefix = "san-prefix" + sanSuffix = "san-suffix" + sanRegexBad = "??" + sanRegexGood = "san?regex?" + sanContains = "san-contains" ) + var sanRE = regexp.MustCompile(sanRegexGood) tests := []struct { name string @@ -377,6 +384,182 @@ func (s) TestValidateClusterWithSecurityConfig(t *testing.T) { }, wantErr: true, }, + { + name: "empty-prefix-in-matching-SAN", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: &anypb.Any{ + TypeUrl: version.V3UpstreamTLSContextURL, + Value: func() []byte { + tls := &v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_CombinedValidationContext{ + CombinedValidationContext: &v3tlspb.CommonTlsContext_CombinedCertificateValidationContext{ + DefaultValidationContext: &v3tlspb.CertificateValidationContext{ + MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ + {MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: ""}}, + }, + }, + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: rootPluginInstance, + CertificateName: rootCertName, + }, + }, + }, + }, + } + mtls, _ := proto.Marshal(tls) + return mtls + }(), + }, + }, + }, + }, + wantErr: true, + }, + { + name: "empty-suffix-in-matching-SAN", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: &anypb.Any{ + TypeUrl: version.V3UpstreamTLSContextURL, + Value: func() []byte { + tls := &v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_CombinedValidationContext{ + CombinedValidationContext: &v3tlspb.CommonTlsContext_CombinedCertificateValidationContext{ + DefaultValidationContext: &v3tlspb.CertificateValidationContext{ + MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ + {MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: ""}}, + }, + }, + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: rootPluginInstance, + CertificateName: rootCertName, + }, + }, + }, + }, + } + mtls, _ := proto.Marshal(tls) + return mtls + }(), + }, + }, + }, + }, + wantErr: true, + }, + { + name: "empty-contains-in-matching-SAN", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: &anypb.Any{ + TypeUrl: version.V3UpstreamTLSContextURL, + Value: func() []byte { + tls := &v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_CombinedValidationContext{ + CombinedValidationContext: &v3tlspb.CommonTlsContext_CombinedCertificateValidationContext{ + DefaultValidationContext: &v3tlspb.CertificateValidationContext{ + MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ + {MatchPattern: &v3matcherpb.StringMatcher_Contains{Contains: ""}}, + }, + }, + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: rootPluginInstance, + CertificateName: rootCertName, + }, + }, + }, + }, + } + mtls, _ := proto.Marshal(tls) + return mtls + }(), + }, + }, + }, + }, + wantErr: true, + }, + { + name: "invalid-regex-in-matching-SAN", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: &anypb.Any{ + TypeUrl: version.V3UpstreamTLSContextURL, + Value: func() []byte { + tls := &v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_CombinedValidationContext{ + CombinedValidationContext: &v3tlspb.CommonTlsContext_CombinedCertificateValidationContext{ + DefaultValidationContext: &v3tlspb.CertificateValidationContext{ + MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ + {MatchPattern: &v3matcherpb.StringMatcher_SafeRegex{SafeRegex: &v3matcherpb.RegexMatcher{Regex: sanRegexBad}}}, + }, + }, + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: rootPluginInstance, + CertificateName: rootCertName, + }, + }, + }, + }, + } + mtls, _ := proto.Marshal(tls) + return mtls + }(), + }, + }, + }, + }, + wantErr: true, + }, { name: "happy-case-with-no-identity-certs", cluster: &v3clusterpb.Cluster{ @@ -502,8 +685,14 @@ func (s) TestValidateClusterWithSecurityConfig(t *testing.T) { CombinedValidationContext: &v3tlspb.CommonTlsContext_CombinedCertificateValidationContext{ DefaultValidationContext: &v3tlspb.CertificateValidationContext{ MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ - {MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: san1}}, - {MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: san2}}, + { + MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: sanExact}, + IgnoreCase: true, + }, + {MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: sanPrefix}}, + {MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: sanSuffix}}, + {MatchPattern: &v3matcherpb.StringMatcher_SafeRegex{SafeRegex: &v3matcherpb.RegexMatcher{Regex: sanRegexGood}}}, + {MatchPattern: &v3matcherpb.StringMatcher_Contains{Contains: sanContains}}, }, }, ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ @@ -529,7 +718,16 @@ func (s) TestValidateClusterWithSecurityConfig(t *testing.T) { RootCertName: rootCertName, IdentityInstanceName: identityPluginInstance, IdentityCertName: identityCertName, - AcceptedSANs: []string{san1, san2}, + SubjectAltNameMatchers: []xds.StringMatcher{ + { + ExactMatch: newStringP(sanExact), + IgnoreCase: true, + }, + {PrefixMatch: newStringP(sanPrefix)}, + {SuffixMatch: newStringP(sanSuffix)}, + {RegexMatch: sanRE}, + {ContainsMatch: newStringP(sanContains)}, + }, }, }, }, @@ -538,8 +736,11 @@ func (s) TestValidateClusterWithSecurityConfig(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { update, err := validateCluster(test.cluster) - if ((err != nil) != test.wantErr) || !cmp.Equal(update, test.wantUpdate, cmpopts.EquateEmpty()) { - t.Errorf("validateCluster(%+v) = (%+v, %v), want: (%+v, %v)", test.cluster, update, err, test.wantUpdate, test.wantErr) + if (err != nil) != test.wantErr { + t.Errorf("validateCluster() returned err %v wantErr %v)", err, test.wantErr) + } + if diff := cmp.Diff(test.wantUpdate, update, cmpopts.EquateEmpty(), cmp.AllowUnexported(regexp.Regexp{})); diff != "" { + t.Errorf("validateCluster() returned unexpected diff (-want, +got):\n%s", diff) } }) } diff --git a/xds/internal/client/client.go b/xds/internal/client/client.go index 21881cc6eae6..c1698fdab700 100644 --- a/xds/internal/client/client.go +++ b/xds/internal/client/client.go @@ -32,6 +32,7 @@ import ( "github.com/golang/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/grpc/internal/xds" "google.golang.org/grpc/xds/internal/client/load" "google.golang.org/grpc/xds/internal/httpfilter" @@ -322,11 +323,15 @@ type SecurityConfig struct { // IdentityCertName is the certificate name to be passed to the plugin // (looked up from the bootstrap file) while fetching identity certificates. IdentityCertName string - // AcceptedSANs is a list of Subject Alternative Names. During the TLS - // handshake, the SAN present in the peer certificate is compared against - // this list, and the handshake succeeds only if a match is found. Used only - // on the client-side. - AcceptedSANs []string + // SubjectAltNameMatchers is an optional list of match criteria for SANs + // specified on the peer certificate. Used only on the client-side. + // + // Some intricacies: + // - If this field is empty, then any peer certificate is accepted. + // - If the peer certificate contains a wildcard DNS SAN, and an `exact` + // matcher is configured, a wildcard DNS match is performed instead of a + // regular string comparison. + SubjectAltNameMatchers []xds.StringMatcher // RequireClientCert indicates if the server handshake process expects the // client to present a certificate. Set to true when performing mTLS. Used // only on the server-side. diff --git a/xds/internal/client/xds.go b/xds/internal/client/xds.go index 8fc50ea3056c..2706fb840484 100644 --- a/xds/internal/client/xds.go +++ b/xds/internal/client/xds.go @@ -22,6 +22,7 @@ import ( "errors" "fmt" "net" + "regexp" "strconv" "strings" "time" @@ -34,12 +35,14 @@ import ( v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" + v3matcherpb "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" v3typepb "github.com/envoyproxy/go-control-plane/envoy/type/v3" "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" "google.golang.org/protobuf/types/known/anypb" "google.golang.org/grpc/internal/grpclog" + "google.golang.org/grpc/internal/xds" "google.golang.org/grpc/xds/internal" "google.golang.org/grpc/xds/internal/env" "google.golang.org/grpc/xds/internal/httpfilter" @@ -631,21 +634,50 @@ func securityConfigFromCommonTLSContext(common *v3tlspb.CommonTlsContext) (*Secu // those possible values: // - combined validation context: // - contains a default validation context which holds the list of - // accepted SANs. + // matchers for accepted SANs. // - contains certificate provider instance configuration // - certificate provider instance configuration // - in this case, we do not get a list of accepted SANs. switch t := common.GetValidationContextType().(type) { case *v3tlspb.CommonTlsContext_CombinedValidationContext: combined := common.GetCombinedValidationContext() + var matchers []xds.StringMatcher if def := combined.GetDefaultValidationContext(); def != nil { - for _, matcher := range def.GetMatchSubjectAltNames() { - // We only support exact matches for now. - if exact := matcher.GetExact(); exact != "" { - sc.AcceptedSANs = append(sc.AcceptedSANs, exact) + for _, m := range def.GetMatchSubjectAltNames() { + matcher := xds.StringMatcher{} + switch mt := m.GetMatchPattern().(type) { + case *v3matcherpb.StringMatcher_Exact: + matcher.ExactMatch = &mt.Exact + case *v3matcherpb.StringMatcher_Prefix: + if m.GetPrefix() == "" { + return nil, errors.New("empty prefix is not allowed in StringMatcher") + } + matcher.PrefixMatch = &mt.Prefix + case *v3matcherpb.StringMatcher_Suffix: + if m.GetSuffix() == "" { + return nil, errors.New("empty suffix is not allowed in StringMatcher") + } + matcher.SuffixMatch = &mt.Suffix + case *v3matcherpb.StringMatcher_SafeRegex: + regex := m.GetSafeRegex().GetRegex() + re, err := regexp.Compile(regex) + if err != nil { + return nil, fmt.Errorf("safe_regex matcher %q is invalid", regex) + } + matcher.RegexMatch = re + case *v3matcherpb.StringMatcher_Contains: + if m.GetContains() == "" { + return nil, errors.New("empty contains is not allowed in StringMatcher") + } + matcher.ContainsMatch = &mt.Contains + default: + return nil, fmt.Errorf("combined validation context has unrecognized string matcher: %+v", m) } + matcher.IgnoreCase = m.GetIgnoreCase() + matchers = append(matchers, matcher) } } + sc.SubjectAltNameMatchers = matchers if pi := combined.GetValidationContextCertificateProviderInstance(); pi != nil { sc.RootInstanceName = pi.GetInstanceName() sc.RootCertName = pi.GetCertificateName()