diff --git a/credentials/xds/xds_client_test.go b/credentials/xds/xds_client_test.go index 6d40b070008..82cfa5876ac 100644 --- a/credentials/xds/xds_client_test.go +++ b/credentials/xds/xds_client_test.go @@ -36,6 +36,7 @@ import ( xdsinternal "google.golang.org/grpc/internal/credentials/xds" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/xds" "google.golang.org/grpc/resolver" "google.golang.org/grpc/testdata" ) @@ -43,7 +44,7 @@ import ( const ( defaultTestTimeout = 1 * time.Second defaultTestShortTimeout = 10 * time.Millisecond - defaultTestCertSAN = "*.test.example.com" + defaultTestCertSAN = "abc.test.example.com" authority = "authority" ) @@ -214,12 +215,14 @@ func makeRootProvider(t *testing.T, caPath string) *fakeProvider { // newTestContextWithHandshakeInfo returns a copy of parent with HandshakeInfo // context value added to it. -func newTestContextWithHandshakeInfo(parent context.Context, root, identity certprovider.Provider, sans ...string) context.Context { +func newTestContextWithHandshakeInfo(parent context.Context, root, identity certprovider.Provider, sanExactMatch string) context.Context { // Creating the HandshakeInfo and adding it to the attributes is very // similar to what the CDS balancer would do when it intercepts calls to // NewSubConn(). info := xdsinternal.NewHandshakeInfo(root, identity) - info.SetAcceptedSANs(sans) + if sanExactMatch != "" { + info.SetSANMatchers([]xds.StringMatcher{xds.StringMatcherForTesting(newStringP(sanExactMatch), nil, nil, nil, nil, false)}) + } addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, info) // Moving the attributes from the resolver.Address to the context passed to @@ -293,7 +296,7 @@ func (s) TestClientCredsInvalidHandshakeInfo(t *testing.T) { pCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - ctx := newTestContextWithHandshakeInfo(pCtx, nil, &fakeProvider{}) + ctx := newTestContextWithHandshakeInfo(pCtx, nil, &fakeProvider{}, "") if _, _, err := creds.ClientHandshake(ctx, authority, nil); err == nil { t.Fatal("ClientHandshake succeeded without root certificate provider in HandshakeInfo") } @@ -330,7 +333,7 @@ func (s) TestClientCredsProviderFailure(t *testing.T) { t.Run(test.desc, func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, test.identityProvider) + ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, test.identityProvider, "") if _, _, err := creds.ClientHandshake(ctx, authority, nil); err == nil || !strings.Contains(err.Error(), test.wantErr) { t.Fatalf("ClientHandshake() returned error: %q, wantErr: %q", err, test.wantErr) } @@ -372,7 +375,7 @@ func (s) TestClientCredsSuccess(t *testing.T) { desc: "mTLS with no acceptedSANs specified", handshakeFunc: testServerMutualTLSHandshake, handshakeInfoCtx: func(ctx context.Context) context.Context { - return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem")) + return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), "") }, }, } @@ -532,7 +535,7 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) { // use the correct trust roots. root1 := makeRootProvider(t, "x509/client_ca_cert.pem") handshakeInfo := xdsinternal.NewHandshakeInfo(root1, nil) - handshakeInfo.SetAcceptedSANs([]string{defaultTestCertSAN}) + handshakeInfo.SetSANMatchers([]xds.StringMatcher{xds.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)}) // We need to repeat most of what newTestContextWithHandshakeInfo() does // here because we need access to the underlying HandshakeInfo so that we @@ -584,3 +587,7 @@ func (s) TestClientClone(t *testing.T) { t.Fatal("return value from Clone() doesn't point to new credentials instance") } } + +func newStringP(s string) *string { + return &s +} diff --git a/internal/credentials/xds/handshake_info.go b/internal/credentials/xds/handshake_info.go index db0e6542dd2..ca2e39edd6d 100644 --- a/internal/credentials/xds/handshake_info.go +++ b/internal/credentials/xds/handshake_info.go @@ -25,11 +25,13 @@ import ( "crypto/x509" "errors" "fmt" + "strings" "sync" "google.golang.org/grpc/attributes" "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/internal" + xdsinternal "google.golang.org/grpc/internal/xds" "google.golang.org/grpc/resolver" ) @@ -64,8 +66,8 @@ type HandshakeInfo struct { mu sync.Mutex rootProvider certprovider.Provider identityProvider certprovider.Provider - acceptedSANs map[string]bool // Only on the client side. - requireClientCert bool // Only on server side. + sanMatchers []xdsinternal.StringMatcher // Only on the client side. + requireClientCert bool // Only on server side. } // SetRootCertProvider updates the root certificate provider. @@ -82,13 +84,10 @@ func (hi *HandshakeInfo) SetIdentityCertProvider(identity certprovider.Provider) hi.mu.Unlock() } -// SetAcceptedSANs updates the list of accepted SANs. -func (hi *HandshakeInfo) SetAcceptedSANs(sans []string) { +// SetSANMatchers updates the list of SAN matchers. +func (hi *HandshakeInfo) SetSANMatchers(sanMatchers []xdsinternal.StringMatcher) { hi.mu.Lock() - hi.acceptedSANs = make(map[string]bool, len(sans)) - for _, san := range sans { - hi.acceptedSANs[san] = true - } + hi.sanMatchers = sanMatchers hi.mu.Unlock() } @@ -112,6 +111,14 @@ func (hi *HandshakeInfo) UseFallbackCreds() bool { return hi.identityProvider == nil && hi.rootProvider == nil } +// GetSANMatchersForTesting returns the SAN matchers stored in HandshakeInfo. +// To be used only for testing purposes. +func (hi *HandshakeInfo) GetSANMatchersForTesting() []xdsinternal.StringMatcher { + hi.mu.Lock() + defer hi.mu.Unlock() + return append([]xdsinternal.StringMatcher{}, hi.sanMatchers...) +} + // ClientSideTLSConfig constructs a tls.Config to be used in a client-side // handshake based on the contents of the HandshakeInfo. func (hi *HandshakeInfo) ClientSideTLSConfig(ctx context.Context) (*tls.Config, error) { @@ -184,37 +191,113 @@ func (hi *HandshakeInfo) ServerSideTLSConfig(ctx context.Context) (*tls.Config, return cfg, nil } -// MatchingSANExists returns true if the SAN contained in the passed in -// certificate is present in the list of accepted SANs in the HandshakeInfo. +// MatchingSANExists returns true if the SANs contained in cert match the +// criteria enforced by the list of SAN matchers in HandshakeInfo. // -// If the list of accepted SANs in the HandshakeInfo is empty, this function +// If the list of SAN matchers in the HandshakeInfo is empty, this function // returns true for all input certificates. func (hi *HandshakeInfo) MatchingSANExists(cert *x509.Certificate) bool { - if len(hi.acceptedSANs) == 0 { + hi.mu.Lock() + defer hi.mu.Unlock() + if len(hi.sanMatchers) == 0 { return true } - var sans []string // SANs can be specified in any of these four fields on the parsed cert. - sans = append(sans, cert.DNSNames...) - sans = append(sans, cert.EmailAddresses...) - for _, ip := range cert.IPAddresses { - sans = append(sans, ip.String()) + for _, san := range cert.DNSNames { + if hi.matchSAN(san, true) { + return true + } } - for _, uri := range cert.URIs { - sans = append(sans, uri.String()) + for _, san := range cert.EmailAddresses { + if hi.matchSAN(san, false) { + return true + } + } + for _, san := range cert.IPAddresses { + if hi.matchSAN(san.String(), false) { + return true + } + } + for _, san := range cert.URIs { + if hi.matchSAN(san.String(), false) { + return true + } } + return false +} - hi.mu.Lock() - defer hi.mu.Unlock() - for _, san := range sans { - if hi.acceptedSANs[san] { +// Caller must hold mu. +func (hi *HandshakeInfo) matchSAN(san string, isDNS bool) bool { + for _, matcher := range hi.sanMatchers { + if em := matcher.ExactMatch(); em != "" && isDNS { + // This is a special case which is documented in the xDS protos. + // If the DNS SAN is a wildcard entry, and the match criteria is + // `exact`, then we need to perform DNS wildcard matching + // instead of regular string comparison. + if dnsMatch(em, san) { + return true + } + continue + } + if matcher.Match(san) { return true } } return false } +// dnsMatch implements a DNS wildcard matching algorithm based on RFC2828 and +// grpc-java's implementation in `OkHostnameVerifier` class. +// +// NOTE: Here the `host` argument is the one from the set of string matchers in +// the xDS proto and the `san` argument is a DNS SAN from the certificate, and +// this is the one which can potentially contain a wildcard pattern. +func dnsMatch(host, san string) bool { + // Add trailing "." and turn them into absolute domain names. + if !strings.HasSuffix(host, ".") { + host += "." + } + if !strings.HasSuffix(san, ".") { + san += "." + } + // Domain names are case-insensitive. + host = strings.ToLower(host) + san = strings.ToLower(san) + + // If san does not contain a wildcard, do exact match. + if !strings.Contains(san, "*") { + return host == san + } + + // Wildcard dns matching rules + // - '*' is only permitted in the left-most label and must be the only + // character in that label. For example, *.example.com is permitted, while + // *a.example.com, a*.example.com, a*b.example.com, a.*.example.com are + // not permitted. + // - '*' matches a single domain name component. For example, *.example.com + // matches test.example.com but does not match sub.test.example.com. + // - Wildcard patterns for single-label domain names are not permitted. + if san == "*." || !strings.HasPrefix(san, "*.") || strings.Contains(san[1:], "*") { + return false + } + // Optimization: at this point, we know that the san contains a '*' and + // is the first domain component of san. So, the host name must be at + // least as long as the san to be able to match. + if len(host) < len(san) { + return false + } + // Hostname must end with the non-wildcard portion of san. + if !strings.HasSuffix(host, san[1:]) { + return false + } + // At this point we know that the hostName and san share the same suffix + // (the non-wildcard portion of san). Now, we just need to make sure + // that the '*' does not match across domain components. + hostPrefix := strings.TrimSuffix(host, san[1:]) + return !strings.Contains(hostPrefix, ".") +} + // NewHandshakeInfo returns a new instance of HandshakeInfo with the given root // and identity certificate providers. func NewHandshakeInfo(root, identity certprovider.Provider) *HandshakeInfo { diff --git a/internal/credentials/xds/handshake_info_test.go b/internal/credentials/xds/handshake_info_test.go new file mode 100644 index 00000000000..81906fa758a --- /dev/null +++ b/internal/credentials/xds/handshake_info_test.go @@ -0,0 +1,304 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xds + +import ( + "crypto/x509" + "net" + "net/url" + "regexp" + "testing" + + xdsinternal "google.golang.org/grpc/internal/xds" +) + +func TestDNSMatch(t *testing.T) { + tests := []struct { + desc string + host string + pattern string + wantMatch bool + }{ + { + desc: "invalid wildcard 1", + host: "aa.example.com", + pattern: "*a.example.com", + wantMatch: false, + }, + { + desc: "invalid wildcard 2", + host: "aa.example.com", + pattern: "a*.example.com", + wantMatch: false, + }, + { + desc: "invalid wildcard 3", + host: "abc.example.com", + pattern: "a*c.example.com", + wantMatch: false, + }, + { + desc: "wildcard in one of the middle components", + host: "abc.test.example.com", + pattern: "abc.*.example.com", + wantMatch: false, + }, + { + desc: "single component wildcard", + host: "a.example.com", + pattern: "*", + wantMatch: false, + }, + { + desc: "short host name", + host: "a.com", + pattern: "*.example.com", + wantMatch: false, + }, + { + desc: "suffix mismatch", + host: "a.notexample.com", + pattern: "*.example.com", + wantMatch: false, + }, + { + desc: "wildcard match across components", + host: "sub.test.example.com", + pattern: "*.example.com.", + wantMatch: false, + }, + { + desc: "host doesn't end in period", + host: "test.example.com", + pattern: "test.example.com.", + wantMatch: true, + }, + { + desc: "pattern doesn't end in period", + host: "test.example.com.", + pattern: "test.example.com", + wantMatch: true, + }, + { + desc: "case insensitive", + host: "TEST.EXAMPLE.COM.", + pattern: "test.example.com.", + wantMatch: true, + }, + { + desc: "simple match", + host: "test.example.com", + pattern: "test.example.com", + wantMatch: true, + }, + { + desc: "good wildcard", + host: "a.example.com", + pattern: "*.example.com", + wantMatch: true, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + gotMatch := dnsMatch(test.host, test.pattern) + if gotMatch != test.wantMatch { + t.Fatalf("dnsMatch(%s, %s) = %v, want %v", test.host, test.pattern, gotMatch, test.wantMatch) + } + }) + } +} + +func TestMatchingSANExists_FailureCases(t *testing.T) { + url1, err := url.Parse("http://golang.org") + if err != nil { + t.Fatalf("url.Parse() failed: %v", err) + } + url2, err := url.Parse("https://github.com/grpc/grpc-go") + if err != nil { + t.Fatalf("url.Parse() failed: %v", err) + } + inputCert := &x509.Certificate{ + DNSNames: []string{"foo.bar.example.com", "bar.baz.test.com", "*.example.com"}, + EmailAddresses: []string{"foobar@example.com", "barbaz@test.com"}, + IPAddresses: []net.IP{net.ParseIP("192.0.0.1"), net.ParseIP("2001:db8::68")}, + URIs: []*url.URL{url1, url2}, + } + + tests := []struct { + desc string + sanMatchers []xdsinternal.StringMatcher + }{ + { + desc: "exact match", + sanMatchers: []xdsinternal.StringMatcher{ + xdsinternal.StringMatcherForTesting(newStringP("abcd.test.com"), nil, nil, nil, nil, false), + xdsinternal.StringMatcherForTesting(newStringP("http://golang"), nil, nil, nil, nil, false), + xdsinternal.StringMatcherForTesting(newStringP("HTTP://GOLANG.ORG"), nil, nil, nil, nil, false), + }, + }, + { + desc: "prefix match", + sanMatchers: []xdsinternal.StringMatcher{ + xdsinternal.StringMatcherForTesting(nil, newStringP("i-aint-the-one"), nil, nil, nil, false), + xdsinternal.StringMatcherForTesting(nil, newStringP("192.168.1.1"), nil, nil, nil, false), + xdsinternal.StringMatcherForTesting(nil, newStringP("FOO.BAR"), nil, nil, nil, false), + }, + }, + { + desc: "suffix match", + sanMatchers: []xdsinternal.StringMatcher{ + xdsinternal.StringMatcherForTesting(nil, nil, newStringP("i-aint-the-one"), nil, nil, false), + xdsinternal.StringMatcherForTesting(nil, nil, newStringP("1::68"), nil, nil, false), + xdsinternal.StringMatcherForTesting(nil, nil, newStringP(".COM"), nil, nil, false), + }, + }, + { + desc: "regex match", + sanMatchers: []xdsinternal.StringMatcher{ + xdsinternal.StringMatcherForTesting(nil, nil, nil, nil, regexp.MustCompile(`.*\.examples\.com`), false), + xdsinternal.StringMatcherForTesting(nil, nil, nil, nil, regexp.MustCompile(`192\.[0-9]{1,3}\.1\.1`), false), + }, + }, + { + desc: "contains match", + sanMatchers: []xdsinternal.StringMatcher{ + xdsinternal.StringMatcherForTesting(nil, nil, nil, newStringP("i-aint-the-one"), nil, false), + xdsinternal.StringMatcherForTesting(nil, nil, nil, newStringP("2001:db8:1:1::68"), nil, false), + xdsinternal.StringMatcherForTesting(nil, nil, nil, newStringP("GRPC"), nil, false), + }, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + hi := NewHandshakeInfo(nil, nil) + hi.SetSANMatchers(test.sanMatchers) + + if hi.MatchingSANExists(inputCert) { + t.Fatalf("hi.MatchingSANExists(%+v) with SAN matchers +%v succeeded when expected to fail", inputCert, test.sanMatchers) + } + }) + } +} + +func TestMatchingSANExists_Success(t *testing.T) { + url1, err := url.Parse("http://golang.org") + if err != nil { + t.Fatalf("url.Parse() failed: %v", err) + } + url2, err := url.Parse("https://github.com/grpc/grpc-go") + if err != nil { + t.Fatalf("url.Parse() failed: %v", err) + } + inputCert := &x509.Certificate{ + DNSNames: []string{"baz.test.com", "*.example.com"}, + EmailAddresses: []string{"foobar@example.com", "barbaz@test.com"}, + IPAddresses: []net.IP{net.ParseIP("192.0.0.1"), net.ParseIP("2001:db8::68")}, + URIs: []*url.URL{url1, url2}, + } + + tests := []struct { + desc string + sanMatchers []xdsinternal.StringMatcher + }{ + { + desc: "no san matchers", + }, + { + desc: "exact match dns wildcard", + sanMatchers: []xdsinternal.StringMatcher{ + xdsinternal.StringMatcherForTesting(nil, newStringP("192.168.1.1"), nil, nil, nil, false), + xdsinternal.StringMatcherForTesting(newStringP("https://github.com/grpc/grpc-java"), nil, nil, nil, nil, false), + xdsinternal.StringMatcherForTesting(newStringP("abc.example.com"), nil, nil, nil, nil, false), + }, + }, + { + desc: "exact match ignore case", + sanMatchers: []xdsinternal.StringMatcher{ + xdsinternal.StringMatcherForTesting(newStringP("FOOBAR@EXAMPLE.COM"), nil, nil, nil, nil, true), + }, + }, + { + desc: "prefix match", + sanMatchers: []xdsinternal.StringMatcher{ + xdsinternal.StringMatcherForTesting(nil, nil, newStringP(".co.in"), nil, nil, false), + xdsinternal.StringMatcherForTesting(nil, newStringP("192.168.1.1"), nil, nil, nil, false), + xdsinternal.StringMatcherForTesting(nil, newStringP("baz.test"), nil, nil, nil, false), + }, + }, + { + desc: "prefix match ignore case", + sanMatchers: []xdsinternal.StringMatcher{ + xdsinternal.StringMatcherForTesting(nil, newStringP("BAZ.test"), nil, nil, nil, true), + }, + }, + { + desc: "suffix match", + sanMatchers: []xdsinternal.StringMatcher{ + xdsinternal.StringMatcherForTesting(nil, nil, nil, nil, regexp.MustCompile(`192\.[0-9]{1,3}\.1\.1`), false), + xdsinternal.StringMatcherForTesting(nil, nil, newStringP("192.168.1.1"), nil, nil, false), + xdsinternal.StringMatcherForTesting(nil, nil, newStringP("@test.com"), nil, nil, false), + }, + }, + { + desc: "suffix match ignore case", + sanMatchers: []xdsinternal.StringMatcher{ + xdsinternal.StringMatcherForTesting(nil, nil, newStringP("@test.COM"), nil, nil, true), + }, + }, + { + desc: "regex match", + sanMatchers: []xdsinternal.StringMatcher{ + xdsinternal.StringMatcherForTesting(nil, nil, nil, newStringP("https://github.com/grpc/grpc-java"), nil, false), + xdsinternal.StringMatcherForTesting(nil, nil, nil, nil, regexp.MustCompile(`192\.[0-9]{1,3}\.1\.1`), false), + xdsinternal.StringMatcherForTesting(nil, nil, nil, nil, regexp.MustCompile(`.*\.test\.com`), false), + }, + }, + { + desc: "contains match", + sanMatchers: []xdsinternal.StringMatcher{ + xdsinternal.StringMatcherForTesting(newStringP("https://github.com/grpc/grpc-java"), nil, nil, nil, nil, false), + xdsinternal.StringMatcherForTesting(nil, nil, nil, newStringP("2001:68::db8"), nil, false), + xdsinternal.StringMatcherForTesting(nil, nil, nil, newStringP("192.0.0"), nil, false), + }, + }, + { + desc: "contains match ignore case", + sanMatchers: []xdsinternal.StringMatcher{ + xdsinternal.StringMatcherForTesting(nil, nil, nil, newStringP("GRPC"), nil, true), + }, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + hi := NewHandshakeInfo(nil, nil) + hi.SetSANMatchers(test.sanMatchers) + + if !hi.MatchingSANExists(inputCert) { + t.Fatalf("hi.MatchingSANExists(%+v) with SAN matchers +%v failed when expected to succeed", inputCert, test.sanMatchers) + } + }) + } +} + +func newStringP(s string) *string { + return &s +} diff --git a/internal/xds/string_matcher.go b/internal/xds/string_matcher.go new file mode 100644 index 00000000000..21f15aad1b8 --- /dev/null +++ b/internal/xds/string_matcher.go @@ -0,0 +1,183 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package xds contains types that need to be shared between code under +// google.golang.org/grpc/xds/... and the rest of gRPC. +package xds + +import ( + "errors" + "fmt" + "regexp" + "strings" + + v3matcherpb "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" +) + +// StringMatcher contains match criteria for matching a string, and is an +// internal representation of the `StringMatcher` proto defined at +// https://github.com/envoyproxy/envoy/blob/main/api/envoy/type/matcher/v3/string.proto. +type StringMatcher struct { + // Since these match fields are part of a `oneof` in the corresponding xDS + // proto, only one of them is expected to be set. + exactMatch *string + prefixMatch *string + suffixMatch *string + regexMatch *regexp.Regexp + containsMatch *string + // If true, indicates the exact/prefix/suffix/contains matching should be + // case insensitive. This has no effect on the regex match. + ignoreCase bool +} + +// Match returns true if input matches the criteria in the given StringMatcher. +func (sm StringMatcher) Match(input string) bool { + if sm.ignoreCase { + input = strings.ToLower(input) + } + switch { + case sm.exactMatch != nil: + return input == *sm.exactMatch + case sm.prefixMatch != nil: + return strings.HasPrefix(input, *sm.prefixMatch) + case sm.suffixMatch != nil: + return strings.HasSuffix(input, *sm.suffixMatch) + case sm.regexMatch != nil: + return sm.regexMatch.MatchString(input) + case sm.containsMatch != nil: + return strings.Contains(input, *sm.containsMatch) + } + return false +} + +// StringMatcherFromProto is a helper function to create a StringMatcher from +// the corresponding StringMatcher proto. +// +// Returns a non-nil error if matcherProto is invalid. +func StringMatcherFromProto(matcherProto *v3matcherpb.StringMatcher) (StringMatcher, error) { + if matcherProto == nil { + return StringMatcher{}, errors.New("input StringMatcher proto is nil") + } + + matcher := StringMatcher{ignoreCase: matcherProto.GetIgnoreCase()} + switch mt := matcherProto.GetMatchPattern().(type) { + case *v3matcherpb.StringMatcher_Exact: + matcher.exactMatch = &mt.Exact + if matcher.ignoreCase { + *matcher.exactMatch = strings.ToLower(*matcher.exactMatch) + } + case *v3matcherpb.StringMatcher_Prefix: + if matcherProto.GetPrefix() == "" { + return StringMatcher{}, errors.New("empty prefix is not allowed in StringMatcher") + } + matcher.prefixMatch = &mt.Prefix + if matcher.ignoreCase { + *matcher.prefixMatch = strings.ToLower(*matcher.prefixMatch) + } + case *v3matcherpb.StringMatcher_Suffix: + if matcherProto.GetSuffix() == "" { + return StringMatcher{}, errors.New("empty suffix is not allowed in StringMatcher") + } + matcher.suffixMatch = &mt.Suffix + if matcher.ignoreCase { + *matcher.suffixMatch = strings.ToLower(*matcher.suffixMatch) + } + case *v3matcherpb.StringMatcher_SafeRegex: + regex := matcherProto.GetSafeRegex().GetRegex() + re, err := regexp.Compile(regex) + if err != nil { + return StringMatcher{}, fmt.Errorf("safe_regex matcher %q is invalid", regex) + } + matcher.regexMatch = re + case *v3matcherpb.StringMatcher_Contains: + if matcherProto.GetContains() == "" { + return StringMatcher{}, errors.New("empty contains is not allowed in StringMatcher") + } + matcher.containsMatch = &mt.Contains + if matcher.ignoreCase { + *matcher.containsMatch = strings.ToLower(*matcher.containsMatch) + } + default: + return StringMatcher{}, fmt.Errorf("unrecognized string matcher: %+v", matcherProto) + } + return matcher, nil +} + +// StringMatcherForTesting is a helper function to create a StringMatcher based +// on the given arguments. Intended only for testing purposes. +func StringMatcherForTesting(exact, prefix, suffix, contains *string, regex *regexp.Regexp, ignoreCase bool) StringMatcher { + sm := StringMatcher{ + exactMatch: exact, + prefixMatch: prefix, + suffixMatch: suffix, + regexMatch: regex, + containsMatch: contains, + ignoreCase: ignoreCase, + } + if ignoreCase { + switch { + case sm.exactMatch != nil: + *sm.exactMatch = strings.ToLower(*exact) + case sm.prefixMatch != nil: + *sm.prefixMatch = strings.ToLower(*prefix) + case sm.suffixMatch != nil: + *sm.suffixMatch = strings.ToLower(*suffix) + case sm.containsMatch != nil: + *sm.containsMatch = strings.ToLower(*contains) + } + } + return sm +} + +// ExactMatch returns the value of the configured exact match or an empty string +// if exact match criteria was not specified. +func (sm StringMatcher) ExactMatch() string { + if sm.exactMatch != nil { + return *sm.exactMatch + } + return "" +} + +// Equal returns true if other and sm are equivalent to each other. +func (sm StringMatcher) Equal(other StringMatcher) bool { + if sm.ignoreCase != other.ignoreCase { + return false + } + + if (sm.exactMatch != nil) != (other.exactMatch != nil) || + (sm.prefixMatch != nil) != (other.prefixMatch != nil) || + (sm.suffixMatch != nil) != (other.suffixMatch != nil) || + (sm.regexMatch != nil) != (other.regexMatch != nil) || + (sm.containsMatch != nil) != (other.containsMatch != nil) { + return false + } + + switch { + case sm.exactMatch != nil: + return *sm.exactMatch == *other.exactMatch + case sm.prefixMatch != nil: + return *sm.prefixMatch == *other.prefixMatch + case sm.suffixMatch != nil: + return *sm.suffixMatch == *other.suffixMatch + case sm.regexMatch != nil: + return sm.regexMatch.String() == other.regexMatch.String() + case sm.containsMatch != nil: + return *sm.containsMatch == *other.containsMatch + } + return true +} diff --git a/internal/xds/string_matcher_test.go b/internal/xds/string_matcher_test.go new file mode 100644 index 00000000000..7908ac974b2 --- /dev/null +++ b/internal/xds/string_matcher_test.go @@ -0,0 +1,316 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xds + +import ( + "regexp" + "testing" + + v3matcherpb "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" + "github.com/google/go-cmp/cmp" +) + +func TestStringMatcherFromProto(t *testing.T) { + tests := []struct { + desc string + inputProto *v3matcherpb.StringMatcher + wantMatcher StringMatcher + wantErr bool + }{ + { + desc: "nil proto", + wantErr: true, + }, + { + desc: "empty prefix", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: ""}, + }, + wantErr: true, + }, + { + desc: "empty suffix", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: ""}, + }, + wantErr: true, + }, + { + desc: "empty contains", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Contains{Contains: ""}, + }, + wantErr: true, + }, + { + desc: "invalid regex", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_SafeRegex{ + SafeRegex: &v3matcherpb.RegexMatcher{Regex: "??"}, + }, + }, + wantErr: true, + }, + { + desc: "invalid deprecated regex", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_HiddenEnvoyDeprecatedRegex{}, + }, + wantErr: true, + }, + { + desc: "happy case exact", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "exact"}, + }, + wantMatcher: StringMatcher{exactMatch: newStringP("exact")}, + }, + { + desc: "happy case exact ignore case", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "EXACT"}, + IgnoreCase: true, + }, + wantMatcher: StringMatcher{ + exactMatch: newStringP("exact"), + ignoreCase: true, + }, + }, + { + desc: "happy case prefix", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "prefix"}, + }, + wantMatcher: StringMatcher{prefixMatch: newStringP("prefix")}, + }, + { + desc: "happy case prefix ignore case", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "PREFIX"}, + IgnoreCase: true, + }, + wantMatcher: StringMatcher{ + prefixMatch: newStringP("prefix"), + ignoreCase: true, + }, + }, + { + desc: "happy case suffix", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: "suffix"}, + }, + wantMatcher: StringMatcher{suffixMatch: newStringP("suffix")}, + }, + { + desc: "happy case suffix ignore case", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: "SUFFIX"}, + IgnoreCase: true, + }, + wantMatcher: StringMatcher{ + suffixMatch: newStringP("suffix"), + ignoreCase: true, + }, + }, + { + desc: "happy case regex", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_SafeRegex{ + SafeRegex: &v3matcherpb.RegexMatcher{Regex: "good?regex?"}, + }, + }, + wantMatcher: StringMatcher{regexMatch: regexp.MustCompile("good?regex?")}, + }, + { + desc: "happy case contains", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Contains{Contains: "contains"}, + }, + wantMatcher: StringMatcher{containsMatch: newStringP("contains")}, + }, + { + desc: "happy case contains ignore case", + inputProto: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Contains{Contains: "CONTAINS"}, + IgnoreCase: true, + }, + wantMatcher: StringMatcher{ + containsMatch: newStringP("contains"), + ignoreCase: true, + }, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + gotMatcher, err := StringMatcherFromProto(test.inputProto) + if (err != nil) != test.wantErr { + t.Fatalf("StringMatcherFromProto(%+v) returned err: %v, wantErr: %v", test.inputProto, err, test.wantErr) + } + if diff := cmp.Diff(gotMatcher, test.wantMatcher, cmp.AllowUnexported(regexp.Regexp{})); diff != "" { + t.Fatalf("StringMatcherFromProto(%+v) returned unexpected diff (-got, +want):\n%s", test.inputProto, diff) + } + }) + } +} + +func TestMatch(t *testing.T) { + var ( + exactMatcher, _ = StringMatcherFromProto(&v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "exact"}}) + prefixMatcher, _ = StringMatcherFromProto(&v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "prefix"}}) + suffixMatcher, _ = StringMatcherFromProto(&v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: "suffix"}}) + regexMatcher, _ = StringMatcherFromProto(&v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_SafeRegex{SafeRegex: &v3matcherpb.RegexMatcher{Regex: "good?regex?"}}}) + containsMatcher, _ = StringMatcherFromProto(&v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Contains{Contains: "contains"}}) + exactMatcherIgnoreCase, _ = StringMatcherFromProto(&v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "exact"}, + IgnoreCase: true, + }) + prefixMatcherIgnoreCase, _ = StringMatcherFromProto(&v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "prefix"}, + IgnoreCase: true, + }) + suffixMatcherIgnoreCase, _ = StringMatcherFromProto(&v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: "suffix"}, + IgnoreCase: true, + }) + containsMatcherIgnoreCase, _ = StringMatcherFromProto(&v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Contains{Contains: "contains"}, + IgnoreCase: true, + }) + ) + + tests := []struct { + desc string + matcher StringMatcher + input string + wantMatch bool + }{ + { + desc: "exact match success", + matcher: exactMatcher, + input: "exact", + wantMatch: true, + }, + { + desc: "exact match failure", + matcher: exactMatcher, + input: "not-exact", + }, + { + desc: "exact match success with ignore case", + matcher: exactMatcherIgnoreCase, + input: "EXACT", + wantMatch: true, + }, + { + desc: "exact match failure with ignore case", + matcher: exactMatcherIgnoreCase, + input: "not-exact", + }, + { + desc: "prefix match success", + matcher: prefixMatcher, + input: "prefixIsHere", + wantMatch: true, + }, + { + desc: "prefix match failure", + matcher: prefixMatcher, + input: "not-prefix", + }, + { + desc: "prefix match success with ignore case", + matcher: prefixMatcherIgnoreCase, + input: "PREFIXisHere", + wantMatch: true, + }, + { + desc: "prefix match failure with ignore case", + matcher: prefixMatcherIgnoreCase, + input: "not-PREFIX", + }, + { + desc: "suffix match success", + matcher: suffixMatcher, + input: "hereIsThesuffix", + wantMatch: true, + }, + { + desc: "suffix match failure", + matcher: suffixMatcher, + input: "suffix-is-not-here", + }, + { + desc: "suffix match success with ignore case", + matcher: suffixMatcherIgnoreCase, + input: "hereIsTheSuFFix", + wantMatch: true, + }, + { + desc: "suffix match failure with ignore case", + matcher: suffixMatcherIgnoreCase, + input: "SUFFIX-is-not-here", + }, + { + desc: "regex match success", + matcher: regexMatcher, + input: "goodregex", + wantMatch: true, + }, + { + desc: "regex match failure", + matcher: regexMatcher, + input: "regex-is-not-here", + }, + { + desc: "contains match success", + matcher: containsMatcher, + input: "IScontainsHERE", + wantMatch: true, + }, + { + desc: "contains match failure", + matcher: containsMatcher, + input: "con-tains-is-not-here", + }, + { + desc: "contains match success with ignore case", + matcher: containsMatcherIgnoreCase, + input: "isCONTAINShere", + wantMatch: true, + }, + { + desc: "contains match failure with ignore case", + matcher: containsMatcherIgnoreCase, + input: "CON-TAINS-is-not-here", + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + if gotMatch := test.matcher.Match(test.input); gotMatch != test.wantMatch { + t.Errorf("StringMatcher.Match(%s) returned %v, want %v", test.input, gotMatch, test.wantMatch) + } + }) + } +} + +func newStringP(s string) *string { + return &s +} diff --git a/xds/internal/balancer/cdsbalancer/cdsbalancer.go b/xds/internal/balancer/cdsbalancer/cdsbalancer.go index e4d349753e1..19e075bd7b5 100644 --- a/xds/internal/balancer/cdsbalancer/cdsbalancer.go +++ b/xds/internal/balancer/cdsbalancer/cdsbalancer.go @@ -235,7 +235,7 @@ func (b *cdsBalancer) handleSecurityConfig(config *xdsclient.SecurityConfig) err // one where fallback credentials are to be used. b.xdsHI.SetRootCertProvider(nil) b.xdsHI.SetIdentityCertProvider(nil) - b.xdsHI.SetAcceptedSANs(nil) + b.xdsHI.SetSANMatchers(nil) return nil } @@ -278,7 +278,7 @@ func (b *cdsBalancer) handleSecurityConfig(config *xdsclient.SecurityConfig) err // could have been non-nil earlier. b.xdsHI.SetRootCertProvider(rootProvider) b.xdsHI.SetIdentityCertProvider(identityProvider) - b.xdsHI.SetAcceptedSANs(config.AcceptedSANs) + b.xdsHI.SetSANMatchers(config.SubjectAltNameMatchers) return nil } diff --git a/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go b/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go index fee48c262eb..73459dd6410 100644 --- a/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go +++ b/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go @@ -20,16 +20,19 @@ import ( "context" "errors" "fmt" + "regexp" "testing" + "github.com/google/go-cmp/cmp" "google.golang.org/grpc/attributes" "google.golang.org/grpc/balancer" "google.golang.org/grpc/credentials/local" "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/credentials/xds" "google.golang.org/grpc/internal" - xdsinternal "google.golang.org/grpc/internal/credentials/xds" + xdscredsinternal "google.golang.org/grpc/internal/credentials/xds" "google.golang.org/grpc/internal/testutils" + xdsinternal "google.golang.org/grpc/internal/xds" "google.golang.org/grpc/resolver" xdsclient "google.golang.org/grpc/xds/internal/client" "google.golang.org/grpc/xds/internal/client/bootstrap" @@ -41,16 +44,25 @@ const ( fakeProvider1Name = "fake-certificate-provider-1" fakeProvider2Name = "fake-certificate-provider-2" fakeConfig = "my fake config" + testSAN = "test-san" ) var ( + testSANMatchers = []xdsinternal.StringMatcher{ + xdsinternal.StringMatcherForTesting(newStringP(testSAN), nil, nil, nil, nil, true), + xdsinternal.StringMatcherForTesting(nil, newStringP(testSAN), nil, nil, nil, false), + xdsinternal.StringMatcherForTesting(nil, nil, newStringP(testSAN), nil, nil, false), + xdsinternal.StringMatcherForTesting(nil, nil, nil, nil, regexp.MustCompile(testSAN), false), + xdsinternal.StringMatcherForTesting(nil, nil, nil, newStringP(testSAN), nil, false), + } fpb1, fpb2 *fakeProviderBuilder bootstrapConfig *bootstrap.Config cdsUpdateWithGoodSecurityCfg = xdsclient.ClusterUpdate{ ServiceName: serviceName, SecurityCfg: &xdsclient.SecurityConfig{ - RootInstanceName: "default1", - IdentityInstanceName: "default2", + RootInstanceName: "default1", + IdentityInstanceName: "default2", + SubjectAltNameMatchers: testSANMatchers, }, } cdsUpdateWithMissingSecurityCfg = xdsclient.ClusterUpdate{ @@ -61,6 +73,10 @@ var ( } ) +func newStringP(s string) *string { + return &s +} + func init() { fpb1 = &fakeProviderBuilder{name: fakeProvider1Name} fpb2 = &fakeProviderBuilder{name: fakeProvider2Name} @@ -190,7 +206,7 @@ func makeNewSubConn(ctx context.Context, edsCC balancer.ClientConn, parentCC *xd if got, want := gotAddrs[0].Addr, addrs[0].Addr; got != want { return nil, fmt.Errorf("resolver.Address passed to parent ClientConn has address %q, want %q", got, want) } - getHI := internal.GetXDSHandshakeInfoForTesting.(func(attr *attributes.Attributes) *xdsinternal.HandshakeInfo) + getHI := internal.GetXDSHandshakeInfoForTesting.(func(attr *attributes.Attributes) *xdscredsinternal.HandshakeInfo) hi := getHI(gotAddrs[0].Attributes) if hi == nil { return nil, errors.New("resolver.Address passed to parent ClientConn doesn't contain attributes") @@ -198,6 +214,11 @@ func makeNewSubConn(ctx context.Context, edsCC balancer.ClientConn, parentCC *xd if gotFallback := hi.UseFallbackCreds(); gotFallback != wantFallback { return nil, fmt.Errorf("resolver.Address HandshakeInfo uses fallback creds? %v, want %v", gotFallback, wantFallback) } + if !wantFallback { + if diff := cmp.Diff(testSANMatchers, hi.GetSANMatchersForTesting(), cmp.AllowUnexported(regexp.Regexp{})); diff != "" { + return nil, fmt.Errorf("unexpected diff in the list of SAN matchers (-got, +want):\n%s", diff) + } + } } return sc, nil } @@ -507,7 +528,7 @@ func (s) TestGoodSecurityConfig(t *testing.T) { if got, want := gotAddrs[0].Addr, addrs[0].Addr; got != want { t.Fatalf("resolver.Address passed to parent ClientConn through UpdateAddresses() has address %q, want %q", got, want) } - getHI := internal.GetXDSHandshakeInfoForTesting.(func(attr *attributes.Attributes) *xdsinternal.HandshakeInfo) + getHI := internal.GetXDSHandshakeInfoForTesting.(func(attr *attributes.Attributes) *xdscredsinternal.HandshakeInfo) hi := getHI(gotAddrs[0].Attributes) if hi == nil { t.Fatal("resolver.Address passed to parent ClientConn through UpdateAddresses() doesn't contain attributes") @@ -657,7 +678,8 @@ func (s) TestSecurityConfigUpdate_GoodToGood(t *testing.T) { cdsUpdate := xdsclient.ClusterUpdate{ ServiceName: serviceName, SecurityCfg: &xdsclient.SecurityConfig{ - RootInstanceName: "default1", + RootInstanceName: "default1", + SubjectAltNameMatchers: testSANMatchers, }, } wantCCS := edsCCS(serviceName, nil, false) @@ -681,7 +703,8 @@ func (s) TestSecurityConfigUpdate_GoodToGood(t *testing.T) { cdsUpdate = xdsclient.ClusterUpdate{ ServiceName: serviceName, SecurityCfg: &xdsclient.SecurityConfig{ - RootInstanceName: "default2", + RootInstanceName: "default2", + SubjectAltNameMatchers: testSANMatchers, }, } if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { diff --git a/xds/internal/client/cds_test.go b/xds/internal/client/cds_test.go index c5f1d76d32c..bac0ef1aeaa 100644 --- a/xds/internal/client/cds_test.go +++ b/xds/internal/client/cds_test.go @@ -19,6 +19,7 @@ package client import ( + "regexp" "testing" v2xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" @@ -31,6 +32,7 @@ import ( anypb "github.com/golang/protobuf/ptypes/any" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + xdsinternal "google.golang.org/grpc/internal/xds" "google.golang.org/grpc/xds/internal/env" "google.golang.org/grpc/xds/internal/version" "google.golang.org/protobuf/types/known/wrapperspb" @@ -288,9 +290,14 @@ func (s) TestValidateClusterWithSecurityConfig(t *testing.T) { rootPluginInstance = "rootPluginInstance" rootCertName = "rootCert" serviceName = "service" - san1 = "san1" - san2 = "san2" + sanExact = "san-exact" + sanPrefix = "san-prefix" + sanSuffix = "san-suffix" + sanRegexBad = "??" + sanRegexGood = "san?regex?" + sanContains = "san-contains" ) + var sanRE = regexp.MustCompile(sanRegexGood) tests := []struct { name string @@ -435,6 +442,182 @@ func (s) TestValidateClusterWithSecurityConfig(t *testing.T) { }, wantErr: true, }, + { + name: "empty-prefix-in-matching-SAN", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: &anypb.Any{ + TypeUrl: version.V3UpstreamTLSContextURL, + Value: func() []byte { + tls := &v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_CombinedValidationContext{ + CombinedValidationContext: &v3tlspb.CommonTlsContext_CombinedCertificateValidationContext{ + DefaultValidationContext: &v3tlspb.CertificateValidationContext{ + MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ + {MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: ""}}, + }, + }, + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: rootPluginInstance, + CertificateName: rootCertName, + }, + }, + }, + }, + } + mtls, _ := proto.Marshal(tls) + return mtls + }(), + }, + }, + }, + }, + wantErr: true, + }, + { + name: "empty-suffix-in-matching-SAN", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: &anypb.Any{ + TypeUrl: version.V3UpstreamTLSContextURL, + Value: func() []byte { + tls := &v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_CombinedValidationContext{ + CombinedValidationContext: &v3tlspb.CommonTlsContext_CombinedCertificateValidationContext{ + DefaultValidationContext: &v3tlspb.CertificateValidationContext{ + MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ + {MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: ""}}, + }, + }, + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: rootPluginInstance, + CertificateName: rootCertName, + }, + }, + }, + }, + } + mtls, _ := proto.Marshal(tls) + return mtls + }(), + }, + }, + }, + }, + wantErr: true, + }, + { + name: "empty-contains-in-matching-SAN", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: &anypb.Any{ + TypeUrl: version.V3UpstreamTLSContextURL, + Value: func() []byte { + tls := &v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_CombinedValidationContext{ + CombinedValidationContext: &v3tlspb.CommonTlsContext_CombinedCertificateValidationContext{ + DefaultValidationContext: &v3tlspb.CertificateValidationContext{ + MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ + {MatchPattern: &v3matcherpb.StringMatcher_Contains{Contains: ""}}, + }, + }, + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: rootPluginInstance, + CertificateName: rootCertName, + }, + }, + }, + }, + } + mtls, _ := proto.Marshal(tls) + return mtls + }(), + }, + }, + }, + }, + wantErr: true, + }, + { + name: "invalid-regex-in-matching-SAN", + cluster: &v3clusterpb.Cluster{ + ClusterDiscoveryType: &v3clusterpb.Cluster_Type{Type: v3clusterpb.Cluster_EDS}, + EdsClusterConfig: &v3clusterpb.Cluster_EdsClusterConfig{ + EdsConfig: &v3corepb.ConfigSource{ + ConfigSourceSpecifier: &v3corepb.ConfigSource_Ads{ + Ads: &v3corepb.AggregatedConfigSource{}, + }, + }, + ServiceName: serviceName, + }, + LbPolicy: v3clusterpb.Cluster_ROUND_ROBIN, + TransportSocket: &v3corepb.TransportSocket{ + ConfigType: &v3corepb.TransportSocket_TypedConfig{ + TypedConfig: &anypb.Any{ + TypeUrl: version.V3UpstreamTLSContextURL, + Value: func() []byte { + tls := &v3tlspb.UpstreamTlsContext{ + CommonTlsContext: &v3tlspb.CommonTlsContext{ + ValidationContextType: &v3tlspb.CommonTlsContext_CombinedValidationContext{ + CombinedValidationContext: &v3tlspb.CommonTlsContext_CombinedCertificateValidationContext{ + DefaultValidationContext: &v3tlspb.CertificateValidationContext{ + MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ + {MatchPattern: &v3matcherpb.StringMatcher_SafeRegex{SafeRegex: &v3matcherpb.RegexMatcher{Regex: sanRegexBad}}}, + }, + }, + ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ + InstanceName: rootPluginInstance, + CertificateName: rootCertName, + }, + }, + }, + }, + } + mtls, _ := proto.Marshal(tls) + return mtls + }(), + }, + }, + }, + }, + wantErr: true, + }, { name: "happy-case-with-no-identity-certs", cluster: &v3clusterpb.Cluster{ @@ -560,8 +743,14 @@ func (s) TestValidateClusterWithSecurityConfig(t *testing.T) { CombinedValidationContext: &v3tlspb.CommonTlsContext_CombinedCertificateValidationContext{ DefaultValidationContext: &v3tlspb.CertificateValidationContext{ MatchSubjectAltNames: []*v3matcherpb.StringMatcher{ - {MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: san1}}, - {MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: san2}}, + { + MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: sanExact}, + IgnoreCase: true, + }, + {MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: sanPrefix}}, + {MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: sanSuffix}}, + {MatchPattern: &v3matcherpb.StringMatcher_SafeRegex{SafeRegex: &v3matcherpb.RegexMatcher{Regex: sanRegexGood}}}, + {MatchPattern: &v3matcherpb.StringMatcher_Contains{Contains: sanContains}}, }, }, ValidationContextCertificateProviderInstance: &v3tlspb.CommonTlsContext_CertificateProviderInstance{ @@ -587,7 +776,13 @@ func (s) TestValidateClusterWithSecurityConfig(t *testing.T) { RootCertName: rootCertName, IdentityInstanceName: identityPluginInstance, IdentityCertName: identityCertName, - AcceptedSANs: []string{san1, san2}, + SubjectAltNameMatchers: []xdsinternal.StringMatcher{ + xdsinternal.StringMatcherForTesting(newStringP(sanExact), nil, nil, nil, nil, true), + xdsinternal.StringMatcherForTesting(nil, newStringP(sanPrefix), nil, nil, nil, false), + xdsinternal.StringMatcherForTesting(nil, nil, newStringP(sanSuffix), nil, nil, false), + xdsinternal.StringMatcherForTesting(nil, nil, nil, nil, sanRE, false), + xdsinternal.StringMatcherForTesting(nil, nil, nil, newStringP(sanContains), nil, false), + }, }, }, }, @@ -596,8 +791,11 @@ func (s) TestValidateClusterWithSecurityConfig(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { update, err := validateCluster(test.cluster) - if ((err != nil) != test.wantErr) || !cmp.Equal(update, test.wantUpdate, cmpopts.EquateEmpty()) { - t.Errorf("validateCluster(%+v) = (%+v, %v), want: (%+v, %v)", test.cluster, update, err, test.wantUpdate, test.wantErr) + if (err != nil) != test.wantErr { + t.Errorf("validateCluster() returned err %v wantErr %v)", err, test.wantErr) + } + if diff := cmp.Diff(test.wantUpdate, update, cmpopts.EquateEmpty(), cmp.AllowUnexported(regexp.Regexp{})); diff != "" { + t.Errorf("validateCluster() returned unexpected diff (-want, +got):\n%s", diff) } }) } diff --git a/xds/internal/client/client.go b/xds/internal/client/client.go index 7e8ffe7b18b..5c0f38a9782 100644 --- a/xds/internal/client/client.go +++ b/xds/internal/client/client.go @@ -33,6 +33,7 @@ import ( "github.com/golang/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/grpc/internal/xds" "google.golang.org/grpc/xds/internal/client/load" "google.golang.org/grpc/xds/internal/httpfilter" @@ -385,11 +386,15 @@ type SecurityConfig struct { // IdentityCertName is the certificate name to be passed to the plugin // (looked up from the bootstrap file) while fetching identity certificates. IdentityCertName string - // AcceptedSANs is a list of Subject Alternative Names. During the TLS - // handshake, the SAN present in the peer certificate is compared against - // this list, and the handshake succeeds only if a match is found. Used only - // on the client-side. - AcceptedSANs []string + // SubjectAltNameMatchers is an optional list of match criteria for SANs + // specified on the peer certificate. Used only on the client-side. + // + // Some intricacies: + // - If this field is empty, then any peer certificate is accepted. + // - If the peer certificate contains a wildcard DNS SAN, and an `exact` + // matcher is configured, a wildcard DNS match is performed instead of a + // regular string comparison. + SubjectAltNameMatchers []xds.StringMatcher // RequireClientCert indicates if the server handshake process expects the // client to present a certificate. Set to true when performing mTLS. Used // only on the server-side. diff --git a/xds/internal/client/xds.go b/xds/internal/client/xds.go index 8f604d32b8f..13b4e7f76d4 100644 --- a/xds/internal/client/xds.go +++ b/xds/internal/client/xds.go @@ -40,6 +40,7 @@ import ( "google.golang.org/protobuf/types/known/anypb" "google.golang.org/grpc/internal/grpclog" + "google.golang.org/grpc/internal/xds" "google.golang.org/grpc/xds/internal" "google.golang.org/grpc/xds/internal/env" "google.golang.org/grpc/xds/internal/httpfilter" @@ -697,21 +698,24 @@ func securityConfigFromCommonTLSContext(common *v3tlspb.CommonTlsContext) (*Secu // those possible values: // - combined validation context: // - contains a default validation context which holds the list of - // accepted SANs. + // matchers for accepted SANs. // - contains certificate provider instance configuration // - certificate provider instance configuration // - in this case, we do not get a list of accepted SANs. switch t := common.GetValidationContextType().(type) { case *v3tlspb.CommonTlsContext_CombinedValidationContext: combined := common.GetCombinedValidationContext() + var matchers []xds.StringMatcher if def := combined.GetDefaultValidationContext(); def != nil { - for _, matcher := range def.GetMatchSubjectAltNames() { - // We only support exact matches for now. - if exact := matcher.GetExact(); exact != "" { - sc.AcceptedSANs = append(sc.AcceptedSANs, exact) + for _, m := range def.GetMatchSubjectAltNames() { + matcher, err := xds.StringMatcherFromProto(m) + if err != nil { + return nil, err } + matchers = append(matchers, matcher) } } + sc.SubjectAltNameMatchers = matchers if pi := combined.GetValidationContextCertificateProviderInstance(); pi != nil { sc.RootInstanceName = pi.GetInstanceName() sc.RootCertName = pi.GetCertificateName()