From 67d353beefe3b607c08c891876fbd95ab89e5fe3 Mon Sep 17 00:00:00 2001 From: Chris Smith Date: Tue, 2 Apr 2024 12:29:02 -0600 Subject: [PATCH] feat(auth): add universe domain to grpctransport and httptransport (#9663) * add universe domain to grpctransport and endpoint * replace deprecated DefaultEndpoint usages with DefaultEndpointTemplate * remove DefaultUniverseDomain usages * fix EXPERIMENTAL_GOOGLE_API_USE_S2A env var detection fixes: #9670 --- auth/credentials/detect_test.go | 8 +- .../credentials/downscope/integration_test.go | 12 +- auth/credentials/idtoken/integration_test.go | 8 +- auth/credentials/impersonate/impersonate.go | 40 +- .../impersonate/impersonate_test.go | 130 ++++-- .../impersonate/integration_test.go | 4 +- auth/credentials/impersonate/user.go | 2 + auth/credentials/impersonate/user_test.go | 12 + auth/grpctransport/grpctransport.go | 12 +- auth/grpctransport/grpctransport_test.go | 16 +- auth/httptransport/httptransport.go | 8 +- auth/httptransport/httptransport_test.go | 6 +- auth/httptransport/transport_test.go | 12 +- auth/internal/credsfile/credsfile.go | 8 +- auth/internal/internal.go | 6 +- auth/internal/internal_test.go | 4 +- auth/internal/transport/cba.go | 99 ++-- auth/internal/transport/cba_test.go | 437 ++++++++++++++---- auth/internal/transport/s2a.go | 10 +- auth/internal/transport/s2a_test.go | 38 +- 20 files changed, 661 insertions(+), 211 deletions(-) diff --git a/auth/credentials/detect_test.go b/auth/credentials/detect_test.go index 902ac6321fcd..8219a9b23919 100644 --- a/auth/credentials/detect_test.go +++ b/auth/credentials/detect_test.go @@ -667,7 +667,7 @@ func TestDefaultCredentials_ExternalAccountAuthorizedUserKey(t *testing.T) { } func TestDefaultCredentials_Fails(t *testing.T) { - t.Setenv("GOOGLE_APPLICATION_CREDENTIALS", "nothingToSeeHere") + t.Setenv(credsfile.GoogleAppCredsEnvVar, "nothingToSeeHere") t.Setenv("HOME", "nothingToSeeHere") t.Setenv("APPDATA", "nothingToSeeHere") allowOnGCECheck = false @@ -890,14 +890,14 @@ func TestDefaultCredentials_UniverseDomain(t *testing.T) { t.Run(tt.name, func(t *testing.T) { creds, err := DetectDefault(tt.opts) if err != nil { - t.Fatalf("%s: %v", tt.name, err) + t.Fatalf("%v", err) } ud, err := creds.UniverseDomain(ctx) if err != nil { - t.Fatalf("%s: %v", tt.name, err) + t.Fatal(err) } if ud != tt.want { - t.Fatalf("%s: got %q, want %q", tt.name, ud, tt.want) + t.Fatalf("got %q, want %q", ud, tt.want) } }) } diff --git a/auth/credentials/downscope/integration_test.go b/auth/credentials/downscope/integration_test.go index 8b6151c7296f..eff9da604137 100644 --- a/auth/credentials/downscope/integration_test.go +++ b/auth/credentials/downscope/integration_test.go @@ -25,22 +25,22 @@ import ( "cloud.google.com/go/auth" "cloud.google.com/go/auth/credentials" "cloud.google.com/go/auth/credentials/downscope" + "cloud.google.com/go/auth/internal/credsfile" "cloud.google.com/go/auth/internal/testutil" "cloud.google.com/go/auth/internal/testutil/testgcs" ) const ( - rootTokenScope = "https://www.googleapis.com/auth/cloud-platform" - envServiceAccountFile = "GOOGLE_APPLICATION_CREDENTIALS" - object1 = "cab-first-c45wknuy.txt" - object2 = "cab-second-c45wknuy.txt" - bucket = "dulcet-port-762" + rootTokenScope = "https://www.googleapis.com/auth/cloud-platform" + object1 = "cab-first-c45wknuy.txt" + object2 = "cab-second-c45wknuy.txt" + bucket = "dulcet-port-762" ) func TestDownscopedToken(t *testing.T) { testutil.IntegrationTestCheck(t) creds, err := credentials.DetectDefault(&credentials.DetectOptions{ - CredentialsFile: os.Getenv(envServiceAccountFile), + CredentialsFile: os.Getenv(credsfile.GoogleAppCredsEnvVar), Scopes: []string{rootTokenScope}, }) if err != nil { diff --git a/auth/credentials/idtoken/integration_test.go b/auth/credentials/idtoken/integration_test.go index ad147bb73193..715ae23cf97f 100644 --- a/auth/credentials/idtoken/integration_test.go +++ b/auth/credentials/idtoken/integration_test.go @@ -24,12 +24,12 @@ import ( "cloud.google.com/go/auth/credentials/idtoken" "cloud.google.com/go/auth/httptransport" + "cloud.google.com/go/auth/internal/credsfile" "cloud.google.com/go/auth/internal/testutil" ) const ( - envCredentialFile = "GOOGLE_APPLICATION_CREDENTIALS" - aud = "http://example.com" + aud = "http://example.com" ) func TestNewCredentials_CredentialsFile(t *testing.T) { @@ -37,7 +37,7 @@ func TestNewCredentials_CredentialsFile(t *testing.T) { ctx := context.Background() ts, err := idtoken.NewCredentials(&idtoken.Options{ Audience: "http://example.com", - CredentialsFile: os.Getenv(envCredentialFile), + CredentialsFile: os.Getenv(credsfile.GoogleAppCredsEnvVar), }) if err != nil { t.Fatalf("unable to create credentials: %v", err) @@ -63,7 +63,7 @@ func TestNewCredentials_CredentialsFile(t *testing.T) { func TestNewCredentials_CredentialsJSON(t *testing.T) { testutil.IntegrationTestCheck(t) ctx := context.Background() - b, err := os.ReadFile(os.Getenv(envCredentialFile)) + b, err := os.ReadFile(os.Getenv(credsfile.GoogleAppCredsEnvVar)) if err != nil { log.Fatal(err) } diff --git a/auth/credentials/impersonate/impersonate.go b/auth/credentials/impersonate/impersonate.go index 79eb15b4bdf9..a0045db45fd1 100644 --- a/auth/credentials/impersonate/impersonate.go +++ b/auth/credentials/impersonate/impersonate.go @@ -30,8 +30,13 @@ import ( ) var ( - iamCredentialsEndpoint = "https://iamcredentials.googleapis.com" - oauth2Endpoint = "https://oauth2.googleapis.com" + iamCredentialsEndpoint = "https://iamcredentials.googleapis.com" + oauth2Endpoint = "https://oauth2.googleapis.com" + errMissingTargetPrincipal = errors.New("impersonate: target service account must be provided") + errMissingScopes = errors.New("impersonate: scopes must be provided") + errLifetimeOverMax = errors.New("impersonate: max lifetime is 12 hours") + errUniverseNotSupportedDomainWideDelegation = errors.New("impersonate: service account user is configured for the credential. " + + "Domain-wide delegation is not supported in universes other than googleapis.com") ) // TODO(codyoss): plumb through base for this and idtoken @@ -82,9 +87,12 @@ func NewCredentials(opts *CredentialsOptions) (*auth.Credentials, error) { client = opts.Client } - // If a subject is specified a different auth-flow is initiated to - // impersonate as the provided subject (user). + // If a subject is specified a domain-wide delegation auth-flow is initiated + // to impersonate as the provided subject (user). if opts.Subject != "" { + if !opts.isUniverseDomainGDU() { + return nil, errUniverseNotSupportedDomainWideDelegation + } tp, err := user(opts, client, lifetime, isStaticToken) if err != nil { return nil, err @@ -158,6 +166,9 @@ type CredentialsOptions struct { // when fetching tokens. If provided the client should provide it's own // credentials at call time. Optional. Client *http.Client + // UniverseDomain is the default service domain for a given Cloud universe. + // The default value is "googleapis.com". Optional. + UniverseDomain string } func (o *CredentialsOptions) validate() error { @@ -165,17 +176,32 @@ func (o *CredentialsOptions) validate() error { return errors.New("impersonate: options must be provided") } if o.TargetPrincipal == "" { - return errors.New("impersonate: target service account must be provided") + return errMissingTargetPrincipal } if len(o.Scopes) == 0 { - return errors.New("impersonate: scopes must be provided") + return errMissingScopes } if o.Lifetime.Hours() > 12 { - return errors.New("impersonate: max lifetime is 12 hours") + return errLifetimeOverMax } return nil } +// getUniverseDomain is the default service domain for a given Cloud universe. +// The default value is "googleapis.com". +func (o *CredentialsOptions) getUniverseDomain() string { + if o.UniverseDomain == "" { + return internal.DefaultUniverseDomain + } + return o.UniverseDomain +} + +// isUniverseDomainGDU returns true if the universe domain is the default Google +// universe. +func (o *CredentialsOptions) isUniverseDomainGDU() bool { + return o.getUniverseDomain() == internal.DefaultUniverseDomain +} + func formatIAMServiceAccountName(name string) string { return fmt.Sprintf("projects/-/serviceAccounts/%s", name) } diff --git a/auth/credentials/impersonate/impersonate_test.go b/auth/credentials/impersonate/impersonate_test.go index a3c869e20838..cd468a43a1f2 100644 --- a/auth/credentials/impersonate/impersonate_test.go +++ b/auth/credentials/impersonate/impersonate_test.go @@ -30,33 +30,47 @@ import ( func TestNewCredentials_serviceAccount(t *testing.T) { ctx := context.Background() tests := []struct { - name string - targetPrincipal string - scopes []string - lifetime time.Duration - wantErr bool + name string + config CredentialsOptions + wantErr error }{ { name: "missing targetPrincipal", - wantErr: true, + wantErr: errMissingTargetPrincipal, }, { - name: "missing scopes", - targetPrincipal: "foo@project-id.iam.gserviceaccount.com", - wantErr: true, + name: "missing scopes", + config: CredentialsOptions{ + TargetPrincipal: "foo@project-id.iam.gserviceaccount.com", + }, + wantErr: errMissingScopes, }, { - name: "lifetime over max", - targetPrincipal: "foo@project-id.iam.gserviceaccount.com", - scopes: []string{"scope"}, - lifetime: 13 * time.Hour, - wantErr: true, + name: "lifetime over max", + config: CredentialsOptions{ + TargetPrincipal: "foo@project-id.iam.gserviceaccount.com", + Scopes: []string{"scope"}, + Lifetime: 13 * time.Hour, + }, + wantErr: errLifetimeOverMax, }, { - name: "works", - targetPrincipal: "foo@project-id.iam.gserviceaccount.com", - scopes: []string{"scope"}, - wantErr: false, + name: "works", + config: CredentialsOptions{ + TargetPrincipal: "foo@project-id.iam.gserviceaccount.com", + Scopes: []string{"scope"}, + }, + wantErr: nil, + }, + { + name: "universe domain", + config: CredentialsOptions{ + TargetPrincipal: "foo@project-id.iam.gserviceaccount.com", + Scopes: []string{"scope"}, + Subject: "admin@example.com", + UniverseDomain: "example.com", + }, + wantErr: errUniverseNotSupportedDomainWideDelegation, }, } @@ -76,11 +90,11 @@ func TestNewCredentials_serviceAccount(t *testing.T) { if err := json.Unmarshal(b, &r); err != nil { t.Error(err) } - if !cmp.Equal(r.Scope, tt.scopes) { - t.Errorf("got %v, want %v", r.Scope, tt.scopes) + if !cmp.Equal(r.Scope, tt.config.Scopes) { + t.Errorf("got %v, want %v", r.Scope, tt.config.Scopes) } - if !strings.Contains(req.URL.Path, tt.targetPrincipal) { - t.Errorf("got %q, want %q", req.URL.Path, tt.targetPrincipal) + if !strings.Contains(req.URL.Path, tt.config.TargetPrincipal) { + t.Errorf("got %q, want %q", req.URL.Path, tt.config.TargetPrincipal) } resp := generateAccessTokenResponse{ @@ -100,24 +114,20 @@ func TestNewCredentials_serviceAccount(t *testing.T) { return nil }), } - ts, err := NewCredentials(&CredentialsOptions{ - TargetPrincipal: tt.targetPrincipal, - Scopes: tt.scopes, - Lifetime: tt.lifetime, - Client: client, - }) - if tt.wantErr && err != nil { - return - } - if err != nil { - t.Fatal(err) - } - tok, err := ts.Token(ctx) + tt.config.Client = client + ts, err := NewCredentials(&tt.config) if err != nil { - t.Fatal(err) - } - if tok.Value != saTok { - t.Fatalf("got %q, want %q", tok.Value, saTok) + if err != tt.wantErr { + t.Fatalf("err: %v", err) + } + } else { + tok, err := ts.Token(ctx) + if err != nil { + t.Fatal(err) + } + if tok.Value != saTok { + t.Fatalf("got %q, want %q", tok.Value, saTok) + } } }) } @@ -126,3 +136,45 @@ func TestNewCredentials_serviceAccount(t *testing.T) { type RoundTripFn func(req *http.Request) *http.Response func (f RoundTripFn) RoundTrip(req *http.Request) (*http.Response, error) { return f(req), nil } + +func TestCredentialsOptions_UniverseDomain(t *testing.T) { + testCases := []struct { + name string + opts *CredentialsOptions + wantUniverseDomain string + wantIsGDU bool + }{ + { + name: "empty", + opts: &CredentialsOptions{}, + wantUniverseDomain: "googleapis.com", + wantIsGDU: true, + }, + { + name: "defaults", + opts: &CredentialsOptions{ + UniverseDomain: "googleapis.com", + }, + wantUniverseDomain: "googleapis.com", + wantIsGDU: true, + }, + { + name: "non-GDU", + opts: &CredentialsOptions{ + UniverseDomain: "example.com", + }, + wantUniverseDomain: "example.com", + wantIsGDU: false, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if got := tc.opts.getUniverseDomain(); got != tc.wantUniverseDomain { + t.Errorf("got %v, want %v", got, tc.wantUniverseDomain) + } + if got := tc.opts.isUniverseDomainGDU(); got != tc.wantIsGDU { + t.Errorf("got %v, want %v", got, tc.wantIsGDU) + } + }) + } +} diff --git a/auth/credentials/impersonate/integration_test.go b/auth/credentials/impersonate/integration_test.go index 83cae4f14c51..fd12b4fffc75 100644 --- a/auth/credentials/impersonate/integration_test.go +++ b/auth/credentials/impersonate/integration_test.go @@ -28,12 +28,12 @@ import ( "cloud.google.com/go/auth/credentials" "cloud.google.com/go/auth/credentials/idtoken" "cloud.google.com/go/auth/credentials/impersonate" + "cloud.google.com/go/auth/internal/credsfile" "cloud.google.com/go/auth/internal/testutil" "cloud.google.com/go/auth/internal/testutil/testgcs" ) const ( - envAppCreds = "GOOGLE_APPLICATION_CREDENTIALS" envProjectID = "GCLOUD_TESTS_GOLANG_PROJECT_ID" envReaderCreds = "GCLOUD_TESTS_IMPERSONATE_READER_KEY" envReaderEmail = "GCLOUD_TESTS_IMPERSONATE_READER_EMAIL" @@ -52,7 +52,7 @@ var ( func TestMain(m *testing.M) { flag.Parse() random = rand.New(rand.NewSource(time.Now().UnixNano())) - baseKeyFile = os.Getenv(envAppCreds) + baseKeyFile = os.Getenv(credsfile.GoogleAppCredsEnvVar) projectID = os.Getenv(envProjectID) readerKeyFile = os.Getenv(envReaderCreds) readerEmail = os.Getenv(envReaderEmail) diff --git a/auth/credentials/impersonate/user.go b/auth/credentials/impersonate/user.go index 09283b91a46d..5aefa2a8e301 100644 --- a/auth/credentials/impersonate/user.go +++ b/auth/credentials/impersonate/user.go @@ -28,6 +28,8 @@ import ( "cloud.google.com/go/auth/internal" ) +// user provides an auth flow for domain-wide delegation, setting +// CredentialsConfig.Subject to be the impersonated user. func user(opts *CredentialsOptions, client *http.Client, lifetime time.Duration, isStaticToken bool) (auth.TokenProvider, error) { u := userTokenProvider{ client: client, diff --git a/auth/credentials/impersonate/user_test.go b/auth/credentials/impersonate/user_test.go index 02ab889a55fe..adb4612d5eca 100644 --- a/auth/credentials/impersonate/user_test.go +++ b/auth/credentials/impersonate/user_test.go @@ -37,6 +37,7 @@ func TestNewCredentials_user(t *testing.T) { lifetime time.Duration subject string wantErr bool + universeDomain string }{ { name: "missing targetPrincipal", @@ -61,6 +62,16 @@ func TestNewCredentials_user(t *testing.T) { subject: "admin@example.com", wantErr: false, }, + { + name: "universeDomain", + targetPrincipal: "foo@project-id.iam.gserviceaccount.com", + scopes: []string{"scope"}, + subject: "admin@example.com", + wantErr: true, + // Non-GDU Universe Domain should result in error if + // CredentialsConfig.Subject is present for domain-wide delegation. + universeDomain: "example.com", + }, } for _, tt := range tests { @@ -132,6 +143,7 @@ func TestNewCredentials_user(t *testing.T) { Lifetime: tt.lifetime, Subject: tt.subject, Client: client, + UniverseDomain: tt.universeDomain, }) if tt.wantErr && err != nil { return diff --git a/auth/grpctransport/grpctransport.go b/auth/grpctransport/grpctransport.go index 17e83d9b34a0..5cfa0a1fe032 100644 --- a/auth/grpctransport/grpctransport.go +++ b/auth/grpctransport/grpctransport.go @@ -143,8 +143,9 @@ type InternalOptions struct { // DefaultAudience specifies a default audience to be used as the audience // field ("aud") for the JWT token authentication. DefaultAudience string - // DefaultEndpoint specifies the default endpoint. - DefaultEndpoint string + // DefaultEndpointTemplate combined with UniverseDomain specifies + // the default endpoint. + DefaultEndpointTemplate string // DefaultMTLSEndpoint specifies the default mTLS endpoint. DefaultMTLSEndpoint string // DefaultScopes specifies the default OAuth2 scopes to be used for a @@ -182,11 +183,12 @@ func Dial(ctx context.Context, secure bool, opts *Options) (GRPCClientConnPool, // return a GRPCClientConnPool if pool == 1 or else a pool of of them if >1 func dial(ctx context.Context, secure bool, opts *Options) (*grpc.ClientConn, error) { tOpts := &transport.Options{ - Endpoint: opts.Endpoint, - Client: opts.client(), + Endpoint: opts.Endpoint, + Client: opts.client(), + UniverseDomain: opts.UniverseDomain, } if io := opts.InternalOptions; io != nil { - tOpts.DefaultEndpoint = io.DefaultEndpoint + tOpts.DefaultEndpointTemplate = io.DefaultEndpointTemplate tOpts.DefaultMTLSEndpoint = io.DefaultMTLSEndpoint } transportCreds, endpoint, err := transport.GetGRPCTransportCredsAndEndpoint(tOpts) diff --git a/auth/grpctransport/grpctransport_test.go b/auth/grpctransport/grpctransport_test.go index e27571dd0f46..d60c3f6854fa 100644 --- a/auth/grpctransport/grpctransport_test.go +++ b/auth/grpctransport/grpctransport_test.go @@ -262,18 +262,20 @@ func TestGrpcCredentialsProvider_GetClientUniverseDomain(t *testing.T) { }, } for _, tt := range tests { - at := &grpcCredentialsProvider{clientUniverseDomain: tt.universeDomain} - got := at.getClientUniverseDomain() - if got != tt.want { - t.Errorf("%s: got %q, want %q", tt.name, got, tt.want) - } + t.Run(tt.name, func(t *testing.T) { + at := &grpcCredentialsProvider{clientUniverseDomain: tt.universeDomain} + got := at.getClientUniverseDomain() + if got != tt.want { + t.Errorf("got %q, want %q", got, tt.want) + } + }) } } func TestNewClient_DetectedServiceAccount(t *testing.T) { testQuota := "testquota" wantHeader := "bar" - t.Setenv("GOOGLE_CLOUD_QUOTA_PROJECT", testQuota) + t.Setenv(internal.QuotaProjectEnvVar, testQuota) l, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatal(err) @@ -308,7 +310,7 @@ func TestNewClient_DetectedServiceAccount(t *testing.T) { pool, err := Dial(context.Background(), false, &Options{ Metadata: map[string]string{"Foo": wantHeader}, InternalOptions: &InternalOptions{ - DefaultEndpoint: l.Addr().String(), + DefaultEndpointTemplate: l.Addr().String(), }, DetectOpts: &credentials.DetectOptions{ Audience: l.Addr().String(), diff --git a/auth/httptransport/httptransport.go b/auth/httptransport/httptransport.go index 5fc3f93f5b89..d2d476908541 100644 --- a/auth/httptransport/httptransport.go +++ b/auth/httptransport/httptransport.go @@ -123,8 +123,9 @@ type InternalOptions struct { // DefaultAudience specifies a default audience to be used as the audience // field ("aud") for the JWT token authentication. DefaultAudience string - // DefaultEndpoint specifies the default endpoint. - DefaultEndpoint string + // DefaultEndpointTemplate combined with UniverseDomain specifies the + // default endpoint. + DefaultEndpointTemplate string // DefaultMTLSEndpoint specifies the default mTLS endpoint. DefaultMTLSEndpoint string // DefaultScopes specifies the default OAuth2 scopes to be used for a @@ -164,9 +165,10 @@ func NewClient(opts *Options) (*http.Client, error) { Endpoint: opts.Endpoint, ClientCertProvider: opts.ClientCertProvider, Client: opts.client(), + UniverseDomain: opts.UniverseDomain, } if io := opts.InternalOptions; io != nil { - tOpts.DefaultEndpoint = io.DefaultEndpoint + tOpts.DefaultEndpointTemplate = io.DefaultEndpointTemplate tOpts.DefaultMTLSEndpoint = io.DefaultMTLSEndpoint } clientCertProvider, dialTLSContext, err := transport.GetHTTPTransportConfig(tOpts) diff --git a/auth/httptransport/httptransport_test.go b/auth/httptransport/httptransport_test.go index 63e8d2d37a43..e3897486f71e 100644 --- a/auth/httptransport/httptransport_test.go +++ b/auth/httptransport/httptransport_test.go @@ -263,7 +263,7 @@ func TestOptions_ResolveDetectOptions(t *testing.T) { func TestNewClient_DetectedServiceAccount(t *testing.T) { testQuota := "testquota" wantHeader := "bar" - t.Setenv("GOOGLE_CLOUD_QUOTA_PROJECT", testQuota) + t.Setenv(internal.QuotaProjectEnvVar, testQuota) ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if got := r.Header.Get("Authorization"); got == "" { t.Errorf(`got "", want an auth token`) @@ -279,7 +279,7 @@ func TestNewClient_DetectedServiceAccount(t *testing.T) { client, err := NewClient(&Options{ Headers: http.Header{"Foo": []string{wantHeader}}, InternalOptions: &InternalOptions{ - DefaultEndpoint: ts.URL, + DefaultEndpointTemplate: ts.URL, }, DetectOpts: &credentials.DetectOptions{ Audience: ts.URL, @@ -303,7 +303,7 @@ func TestNewClient_APIKey(t *testing.T) { testQuota := "testquota" apiKey := "thereisnospoon" wantHeader := "bar" - t.Setenv("GOOGLE_CLOUD_QUOTA_PROJECT", testQuota) + t.Setenv(internal.QuotaProjectEnvVar, testQuota) ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { got := r.URL.Query().Get("key") if got != apiKey { diff --git a/auth/httptransport/transport_test.go b/auth/httptransport/transport_test.go index da4d69befb3c..5c06b43f3fd7 100644 --- a/auth/httptransport/transport_test.go +++ b/auth/httptransport/transport_test.go @@ -39,10 +39,12 @@ func TestAuthTransport_GetClientUniverseDomain(t *testing.T) { }, } for _, tt := range tests { - at := &authTransport{clientUniverseDomain: tt.universeDomain} - got := at.getClientUniverseDomain() - if got != tt.want { - t.Errorf("%s: got %q, want %q", tt.name, got, tt.want) - } + t.Run(tt.name, func(t *testing.T) { + at := &authTransport{clientUniverseDomain: tt.universeDomain} + got := at.getClientUniverseDomain() + if got != tt.want { + t.Errorf("got %q, want %q", got, tt.want) + } + }) } } diff --git a/auth/internal/credsfile/credsfile.go b/auth/internal/credsfile/credsfile.go index 41c68a66b0ad..9cd4bed61b5c 100644 --- a/auth/internal/credsfile/credsfile.go +++ b/auth/internal/credsfile/credsfile.go @@ -26,8 +26,10 @@ import ( ) const ( - envGoogApCreds = "GOOGLE_APPLICATION_CREDENTIALS" - userCredsFilename = "application_default_credentials.json" + // GoogleAppCredsEnvVar is the environment variable for setting the + // application default credentials. + GoogleAppCredsEnvVar = "GOOGLE_APPLICATION_CREDENTIALS" + userCredsFilename = "application_default_credentials.json" ) // CredentialType represents different credential filetypes Google credentials @@ -80,7 +82,7 @@ func GetFileNameFromEnv(override string) string { if override != "" { return override } - return os.Getenv(envGoogApCreds) + return os.Getenv(GoogleAppCredsEnvVar) } // GetWellKnownFileName tries to locate the filepath for the user credential diff --git a/auth/internal/internal.go b/auth/internal/internal.go index 21dd2f020bf3..70534e809a4a 100644 --- a/auth/internal/internal.go +++ b/auth/internal/internal.go @@ -35,7 +35,9 @@ const ( // TokenTypeBearer is the auth header prefix for bearer tokens. TokenTypeBearer = "Bearer" - quotaProjectEnvVar = "GOOGLE_CLOUD_QUOTA_PROJECT" + // QuotaProjectEnvVar is the environment variable for setting the quota + // project. + QuotaProjectEnvVar = "GOOGLE_CLOUD_QUOTA_PROJECT" projectEnvVar = "GOOGLE_CLOUD_PROJECT" maxBodySize = 1 << 20 @@ -82,7 +84,7 @@ func GetQuotaProject(b []byte, override string) string { if override != "" { return override } - if env := os.Getenv(quotaProjectEnvVar); env != "" { + if env := os.Getenv(QuotaProjectEnvVar); env != "" { return env } if b == nil { diff --git a/auth/internal/internal_test.go b/auth/internal/internal_test.go index b0eda2e5cea6..98734a1c520c 100644 --- a/auth/internal/internal_test.go +++ b/auth/internal/internal_test.go @@ -67,10 +67,10 @@ func TestComputeUniverseDomainProvider(t *testing.T) { c := ComputeUniverseDomainProvider{} got, err := c.GetProperty(context.Background()) if err != tc.wantErr { - t.Errorf("%s: got error %v; want error %v", tc.name, err, tc.wantErr) + t.Errorf("got error %v; want error %v", err, tc.wantErr) } if got != tc.want { - t.Errorf("%s: got %v; want %v", tc.name, got, tc.want) + t.Errorf("got %v; want %v", got, tc.want) } }) } diff --git a/auth/internal/transport/cba.go b/auth/internal/transport/cba.go index 2ebaea824519..7ee02c6f61e4 100644 --- a/auth/internal/transport/cba.go +++ b/auth/internal/transport/cba.go @@ -17,6 +17,7 @@ package transport import ( "context" "crypto/tls" + "errors" "net" "net/http" "net/url" @@ -24,6 +25,7 @@ import ( "strconv" "strings" + "cloud.google.com/go/auth/internal" "cloud.google.com/go/auth/internal/transport/cert" "github.com/google/s2a-go" "github.com/google/s2a-go/fallback" @@ -40,21 +42,68 @@ const ( googleAPIUseCertSource = "GOOGLE_API_USE_CLIENT_CERTIFICATE" googleAPIUseMTLS = "GOOGLE_API_USE_MTLS_ENDPOINT" googleAPIUseMTLSOld = "GOOGLE_API_USE_MTLS" + + universeDomainPlaceholder = "UNIVERSE_DOMAIN" ) var ( - mdsMTLSAutoConfigSource mtlsConfigSource + mdsMTLSAutoConfigSource mtlsConfigSource + errUniverseNotSupportedMTLS = errors.New("mTLS is not supported in any universe other than googleapis.com") ) // Options is a struct that is duplicated information from the individual // transport packages in order to avoid cyclic deps. It correlates 1:1 with // fields on httptransport.Options and grpctransport.Options. type Options struct { - Endpoint string - DefaultEndpoint string - DefaultMTLSEndpoint string - ClientCertProvider cert.Provider - Client *http.Client + Endpoint string + DefaultMTLSEndpoint string + DefaultEndpointTemplate string + ClientCertProvider cert.Provider + Client *http.Client + UniverseDomain string +} + +// getUniverseDomain returns the default service domain for a given Cloud +// universe. +func (o *Options) getUniverseDomain() string { + if o.UniverseDomain == "" { + return internal.DefaultUniverseDomain + } + return o.UniverseDomain +} + +// isUniverseDomainGDU returns true if the universe domain is the default Google +// universe. +func (o *Options) isUniverseDomainGDU() bool { + return o.getUniverseDomain() == internal.DefaultUniverseDomain +} + +// defaultEndpoint returns the DefaultEndpointTemplate merged with the +// universe domain if the DefaultEndpointTemplate is set, otherwise returns an +// empty string. +func (o *Options) defaultEndpoint() string { + if o.DefaultEndpointTemplate == "" { + return "" + } + return strings.Replace(o.DefaultEndpointTemplate, universeDomainPlaceholder, o.getUniverseDomain(), 1) +} + +// mergedEndpoint merges a user-provided Endpoint of format host[:port] with the +// default endpoint. +func (o *Options) mergedEndpoint() (string, error) { + defaultEndpoint := o.defaultEndpoint() + u, err := url.Parse(fixScheme(defaultEndpoint)) + if err != nil { + return "", err + } + return strings.Replace(defaultEndpoint, u.Host, o.Endpoint, 1), nil +} + +func fixScheme(baseURL string) string { + if !strings.Contains(baseURL, "://") { + baseURL = "https://" + baseURL + } + return baseURL } // GetGRPCTransportCredsAndEndpoint returns an instance of @@ -127,11 +176,11 @@ func GetHTTPTransportConfig(opts *Options) (cert.Provider, func(context.Context, func getTransportConfig(opts *Options) (*transportConfig, error) { clientCertSource, err := getClientCertificateSource(opts) if err != nil { - return &transportConfig{}, err + return nil, err } endpoint, err := getEndpoint(opts, clientCertSource) if err != nil { - return &transportConfig{}, err + return nil, err } defaultTransportConfig := transportConfig{ clientCertSource: clientCertSource, @@ -141,6 +190,9 @@ func getTransportConfig(opts *Options) (*transportConfig, error) { if !shouldUseS2A(clientCertSource, opts) { return &defaultTransportConfig, nil } + if !opts.isUniverseDomainGDU() { + return nil, errUniverseNotSupportedMTLS + } s2aMTLSEndpoint := opts.DefaultMTLSEndpoint // If there is endpoint override, honor it. @@ -209,27 +261,31 @@ type transportConfig struct { // If the endpoint override is an address (host:port) rather than full base // URL (ex. https://...), then the user-provided address will be merged into // the default endpoint. For example, WithEndpoint("myhost:8000") and -// WithDefaultEndpoint("https://foo.com/bar/baz") will return "https://myhost:8080/bar/baz" +// DefaultEndpointTemplate("https://UNIVERSE_DOMAIN/bar/baz") will return "https://myhost:8080/bar/baz" func getEndpoint(opts *Options, clientCertSource cert.Provider) (string, error) { if opts.Endpoint == "" { mtlsMode := getMTLSMode() if mtlsMode == mTLSModeAlways || (clientCertSource != nil && mtlsMode == mTLSModeAuto) { + if !opts.isUniverseDomainGDU() { + return "", errUniverseNotSupportedMTLS + } return opts.DefaultMTLSEndpoint, nil } - return opts.DefaultEndpoint, nil + return opts.defaultEndpoint(), nil } if strings.Contains(opts.Endpoint, "://") { // User passed in a full URL path, use it verbatim. return opts.Endpoint, nil } - if opts.DefaultEndpoint == "" { - // If DefaultEndpoint is not configured, use the user provided endpoint verbatim. - // This allows a naked "host[:port]" URL to be used with GRPC Direct Path. + if opts.defaultEndpoint() == "" { + // If DefaultEndpointTemplate is not configured, + // use the user provided endpoint verbatim. This allows a naked + // "host[:port]" URL to be used with GRPC Direct Path. return opts.Endpoint, nil } // Assume user-provided endpoint is host[:port], merge it with the default endpoint. - return mergeEndpoints(opts.DefaultEndpoint, opts.Endpoint) + return opts.mergedEndpoint() } func getMTLSMode() string { @@ -242,18 +298,3 @@ func getMTLSMode() string { } return strings.ToLower(mode) } - -func mergeEndpoints(baseURL, newHost string) (string, error) { - u, err := url.Parse(fixScheme(baseURL)) - if err != nil { - return "", err - } - return strings.Replace(baseURL, u.Host, newHost, 1), nil -} - -func fixScheme(baseURL string) string { - if !strings.Contains(baseURL, "://") { - baseURL = "https://" + baseURL - } - return baseURL -} diff --git a/auth/internal/transport/cba_test.go b/auth/internal/transport/cba_test.go index 509bc18e6639..f5920f0cd161 100644 --- a/auth/internal/transport/cba_test.go +++ b/auth/internal/transport/cba_test.go @@ -24,9 +24,12 @@ import ( ) const ( - testMTLSEndpoint = "test.mtls.endpoint" - testRegularEndpoint = "test.endpoint" - testOverrideEndpoint = "test.override.endpoint" + testMTLSEndpoint = "https://test.mtls.googleapis.com/" + testEndpointTemplate = "https://test.UNIVERSE_DOMAIN/" + testRegularEndpoint = "https://test.googleapis.com/" + testOverrideEndpoint = "https://test.override.example.com/" + testUniverseDomain = "example.com" + testUniverseDomainEndpoint = "https://test.example.com/" ) var ( @@ -58,49 +61,123 @@ var ( fakeClientCertSource = func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { return nil, nil } ) +func TestOptions_UniverseDomain(t *testing.T) { + testCases := []struct { + name string + opts *Options + wantUniverseDomain string + wantDefaultEndpoint string + wantIsGDU bool + wantMergedEndpoint string + }{ + { + name: "empty", + opts: &Options{}, + wantUniverseDomain: "googleapis.com", + wantDefaultEndpoint: "", + wantIsGDU: true, + wantMergedEndpoint: "", + }, + { + name: "defaults", + opts: &Options{ + DefaultEndpointTemplate: "https://test.UNIVERSE_DOMAIN/", + }, + wantUniverseDomain: "googleapis.com", + wantDefaultEndpoint: "https://test.googleapis.com/", + wantIsGDU: true, + wantMergedEndpoint: "", + }, + { + name: "non-GDU", + opts: &Options{ + DefaultEndpointTemplate: "https://test.UNIVERSE_DOMAIN/", + UniverseDomain: "example.com", + }, + wantUniverseDomain: "example.com", + wantDefaultEndpoint: "https://test.example.com/", + wantIsGDU: false, + wantMergedEndpoint: "", + }, + { + name: "merged endpoint", + opts: &Options{ + DefaultEndpointTemplate: "https://test.UNIVERSE_DOMAIN/bar/baz", + Endpoint: "myhost:8000", + }, + wantUniverseDomain: "googleapis.com", + wantDefaultEndpoint: "https://test.googleapis.com/bar/baz", + wantIsGDU: true, + wantMergedEndpoint: "https://myhost:8000/bar/baz", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if got := tc.opts.getUniverseDomain(); got != tc.wantUniverseDomain { + t.Errorf("got %v, want %v", got, tc.wantUniverseDomain) + } + if got := tc.opts.isUniverseDomainGDU(); got != tc.wantIsGDU { + t.Errorf("got %v, want %v", got, tc.wantIsGDU) + } + if got := tc.opts.defaultEndpoint(); got != tc.wantDefaultEndpoint { + t.Errorf("got %v, want %v", got, tc.wantDefaultEndpoint) + } + if tc.opts.Endpoint != "" { + got, err := tc.opts.mergedEndpoint() + if err != nil { + t.Fatalf("%v", err) + } + if got != tc.wantMergedEndpoint { + t.Errorf("got %v, want %v", got, tc.wantMergedEndpoint) + } + } + }) + } +} + func TestGetEndpoint(t *testing.T) { testCases := []struct { - endpoint string - defaultEndpoint string - want string - wantErr bool + endpoint string + defaultEndpointTemplate string + want string + wantErr bool }{ { - defaultEndpoint: "https://foo.googleapis.com/bar/baz", - want: "https://foo.googleapis.com/bar/baz", + defaultEndpointTemplate: "https://foo.UNIVERSE_DOMAIN/bar/baz", + want: "https://foo.googleapis.com/bar/baz", }, { - endpoint: "myhost:3999", - defaultEndpoint: "https://foo.googleapis.com/bar/baz", - want: "https://myhost:3999/bar/baz", + endpoint: "myhost:3999", + defaultEndpointTemplate: "https://foo.UNIVERSE_DOMAIN/bar/baz", + want: "https://myhost:3999/bar/baz", }, { - endpoint: "https://host/path/to/bar", - defaultEndpoint: "https://foo.googleapis.com/bar/baz", - want: "https://host/path/to/bar", + endpoint: "https://host/path/to/bar", + defaultEndpointTemplate: "https://foo.UNIVERSE_DOMAIN/bar/baz", + want: "https://host/path/to/bar", }, { - endpoint: "host:123", - defaultEndpoint: "", - want: "host:123", + endpoint: "host:123", + defaultEndpointTemplate: "", + want: "host:123", }, { - endpoint: "host:123", - defaultEndpoint: "default:443", - want: "host:123", + endpoint: "host:123", + defaultEndpointTemplate: "default:443", + want: "host:123", }, { - endpoint: "host:123", - defaultEndpoint: "default:443/bar/baz", - want: "host:123/bar/baz", + endpoint: "host:123", + defaultEndpointTemplate: "default:443/bar/baz", + want: "host:123/bar/baz", }, } for _, tc := range testCases { t.Run(tc.want, func(t *testing.T) { got, err := getEndpoint(&Options{ - Endpoint: tc.endpoint, - DefaultEndpoint: tc.defaultEndpoint, + Endpoint: tc.endpoint, + DefaultEndpointTemplate: tc.defaultEndpointTemplate, }, nil) if tc.wantErr && err == nil { t.Fatalf("want err, got nil err") @@ -109,7 +186,7 @@ func TestGetEndpoint(t *testing.T) { t.Fatalf("want nil err, got %v", err) } if tc.want != got { - t.Errorf("getEndpoint(%q, %q): got %v; want %v", tc.endpoint, tc.defaultEndpoint, got, tc.want) + t.Errorf("getEndpoint(%q, %q): got %v; want %v", tc.endpoint, tc.defaultEndpointTemplate, got, tc.want) } }) } @@ -117,45 +194,45 @@ func TestGetEndpoint(t *testing.T) { func TestGetEndpointWithClientCertSource(t *testing.T) { testCases := []struct { - endpoint string - defaultEndpoint string - defaultMTLSEndpoint string - want string - wantErr bool + endpoint string + defaultEndpointTemplate string + defaultMTLSEndpoint string + want string + wantErr bool }{ { - defaultEndpoint: "https://foo.googleapis.com/bar/baz", - defaultMTLSEndpoint: "https://foo.mtls.googleapis.com/bar/baz", - want: "https://foo.mtls.googleapis.com/bar/baz", + defaultEndpointTemplate: "https://foo.UNIVERSE_DOMAIN/bar/baz", + defaultMTLSEndpoint: "https://foo.mtls.googleapis.com/bar/baz", + want: "https://foo.mtls.googleapis.com/bar/baz", }, { - defaultEndpoint: "https://staging-foo.sandbox.googleapis.com/bar/baz", - defaultMTLSEndpoint: "https://staging-foo.mtls.sandbox.googleapis.com/bar/baz", - want: "https://staging-foo.mtls.sandbox.googleapis.com/bar/baz", + defaultEndpointTemplate: "https://staging-foo.sandbox.UNIVERSE_DOMAIN/bar/baz", + defaultMTLSEndpoint: "https://staging-foo.mtls.sandbox.googleapis.com/bar/baz", + want: "https://staging-foo.mtls.sandbox.googleapis.com/bar/baz", }, { - endpoint: "myhost:3999", - defaultEndpoint: "https://foo.googleapis.com/bar/baz", - want: "https://myhost:3999/bar/baz", + endpoint: "myhost:3999", + defaultEndpointTemplate: "https://foo.UNIVERSE_DOMAIN/bar/baz", + want: "https://myhost:3999/bar/baz", }, { - endpoint: "https://host/path/to/bar", - defaultEndpoint: "https://foo.googleapis.com/bar/baz", - want: "https://host/path/to/bar", + endpoint: "https://host/path/to/bar", + defaultEndpointTemplate: "https://foo.UNIVERSE_DOMAIN/bar/baz", + want: "https://host/path/to/bar", }, { - endpoint: "host:port", - defaultEndpoint: "", - want: "host:port", + endpoint: "host:port", + defaultEndpointTemplate: "", + want: "host:port", }, } for _, tc := range testCases { t.Run(tc.want, func(t *testing.T) { got, err := getEndpoint(&Options{ - Endpoint: tc.endpoint, - DefaultEndpoint: tc.defaultEndpoint, - DefaultMTLSEndpoint: tc.defaultMTLSEndpoint, + Endpoint: tc.endpoint, + DefaultEndpointTemplate: tc.defaultEndpointTemplate, + DefaultMTLSEndpoint: tc.defaultMTLSEndpoint, }, fakeClientCertSource) if tc.wantErr && err == nil { t.Fatalf("want err, got nil err") @@ -164,7 +241,7 @@ func TestGetEndpointWithClientCertSource(t *testing.T) { t.Fatalf("want nil err, got %v", err) } if tc.want != got { - t.Fatalf("getEndpoint(%q, %q): got %v; want %v", tc.endpoint, tc.defaultEndpoint, got, tc.want) + t.Fatalf("getEndpoint(%q, %q): got %v; want %v", tc.endpoint, tc.defaultEndpointTemplate, got, tc.want) } }) } @@ -181,8 +258,8 @@ func TestGetGRPCTransportConfigAndEndpoint(t *testing.T) { { name: "no client cert, endpoint is MTLS enabled, S2A address not empty", opts: &Options{ - DefaultMTLSEndpoint: testMTLSEndpoint, - DefaultEndpoint: testRegularEndpoint, + DefaultMTLSEndpoint: testMTLSEndpoint, + DefaultEndpointTemplate: testEndpointTemplate, }, s2ARespFn: validConfigResp, mtlsEnabledFn: func() bool { return true }, @@ -191,9 +268,9 @@ func TestGetGRPCTransportConfigAndEndpoint(t *testing.T) { { name: "has client cert", opts: &Options{ - DefaultMTLSEndpoint: testMTLSEndpoint, - DefaultEndpoint: testRegularEndpoint, - ClientCertProvider: fakeClientCertSource, + DefaultMTLSEndpoint: testMTLSEndpoint, + DefaultEndpointTemplate: testEndpointTemplate, + ClientCertProvider: fakeClientCertSource, }, s2ARespFn: validConfigResp, mtlsEnabledFn: func() bool { return true }, @@ -202,8 +279,8 @@ func TestGetGRPCTransportConfigAndEndpoint(t *testing.T) { { name: "no client cert, endpoint is not MTLS enabled", opts: &Options{ - DefaultMTLSEndpoint: testMTLSEndpoint, - DefaultEndpoint: testRegularEndpoint, + DefaultMTLSEndpoint: testMTLSEndpoint, + DefaultEndpointTemplate: testEndpointTemplate, }, s2ARespFn: validConfigResp, mtlsEnabledFn: func() bool { return false }, @@ -212,8 +289,8 @@ func TestGetGRPCTransportConfigAndEndpoint(t *testing.T) { { name: "no client cert, endpoint is MTLS enabled, S2A address empty", opts: &Options{ - DefaultMTLSEndpoint: testMTLSEndpoint, - DefaultEndpoint: testRegularEndpoint, + DefaultMTLSEndpoint: testMTLSEndpoint, + DefaultEndpointTemplate: testEndpointTemplate, }, s2ARespFn: invalidConfigResp, mtlsEnabledFn: func() bool { return true }, @@ -222,9 +299,9 @@ func TestGetGRPCTransportConfigAndEndpoint(t *testing.T) { { name: "no client cert, endpoint is MTLS enabled, S2A address not empty, override endpoint", opts: &Options{ - DefaultMTLSEndpoint: testMTLSEndpoint, - DefaultEndpoint: testRegularEndpoint, - Endpoint: testOverrideEndpoint, + DefaultMTLSEndpoint: testMTLSEndpoint, + DefaultEndpointTemplate: testEndpointTemplate, + Endpoint: testOverrideEndpoint, }, s2ARespFn: validConfigResp, mtlsEnabledFn: func() bool { return true }, @@ -237,13 +314,13 @@ func TestGetGRPCTransportConfigAndEndpoint(t *testing.T) { httpGetMetadataMTLSConfig = tc.s2ARespFn mtlsEndpointEnabledForS2A = tc.mtlsEnabledFn if tc.opts.ClientCertProvider != nil { - t.Setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "true") + t.Setenv(googleAPIUseCertSource, "true") } else { - t.Setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + t.Setenv(googleAPIUseCertSource, "false") } _, endpoint, _ := GetGRPCTransportCredsAndEndpoint(tc.opts) if tc.want != endpoint { - t.Fatalf("%s: want endpoint: [%s], got [%s]", tc.name, tc.want, endpoint) + t.Fatalf("want endpoint: %s, got %s", tc.want, endpoint) } // Let the cached MTLS config expire at the end of each test case. time.Sleep(2 * time.Millisecond) @@ -251,7 +328,7 @@ func TestGetGRPCTransportConfigAndEndpoint(t *testing.T) { } } -func TestGetHTTPTransportConfig(t *testing.T) { +func TestGetHTTPTransportConfig_S2a(t *testing.T) { testCases := []struct { name string opts *Options @@ -263,8 +340,8 @@ func TestGetHTTPTransportConfig(t *testing.T) { { name: "no client cert, endpoint is MTLS enabled, S2A address not empty", opts: &Options{ - DefaultMTLSEndpoint: testMTLSEndpoint, - DefaultEndpoint: testRegularEndpoint, + DefaultMTLSEndpoint: testMTLSEndpoint, + DefaultEndpointTemplate: testEndpointTemplate, }, s2aFn: validConfigResp, mtlsEnabledFn: func() bool { return true }, @@ -273,9 +350,9 @@ func TestGetHTTPTransportConfig(t *testing.T) { { name: "has client cert", opts: &Options{ - DefaultMTLSEndpoint: testMTLSEndpoint, - DefaultEndpoint: testRegularEndpoint, - ClientCertProvider: fakeClientCertSource, + DefaultMTLSEndpoint: testMTLSEndpoint, + DefaultEndpointTemplate: testEndpointTemplate, + ClientCertProvider: fakeClientCertSource, }, s2aFn: validConfigResp, mtlsEnabledFn: func() bool { return true }, @@ -285,8 +362,8 @@ func TestGetHTTPTransportConfig(t *testing.T) { { name: "no client cert, endpoint is not MTLS enabled", opts: &Options{ - DefaultMTLSEndpoint: testMTLSEndpoint, - DefaultEndpoint: testRegularEndpoint, + DefaultMTLSEndpoint: testMTLSEndpoint, + DefaultEndpointTemplate: testEndpointTemplate, }, s2aFn: validConfigResp, mtlsEnabledFn: func() bool { return false }, @@ -296,8 +373,8 @@ func TestGetHTTPTransportConfig(t *testing.T) { { name: "no client cert, endpoint is MTLS enabled, S2A address empty", opts: &Options{ - DefaultMTLSEndpoint: testMTLSEndpoint, - DefaultEndpoint: testRegularEndpoint, + DefaultMTLSEndpoint: testMTLSEndpoint, + DefaultEndpointTemplate: testEndpointTemplate, }, s2aFn: invalidConfigResp, mtlsEnabledFn: func() bool { return true }, @@ -307,9 +384,9 @@ func TestGetHTTPTransportConfig(t *testing.T) { { name: "no client cert, endpoint is MTLS enabled, S2A address not empty, override endpoint", opts: &Options{ - DefaultMTLSEndpoint: testMTLSEndpoint, - DefaultEndpoint: testRegularEndpoint, - Endpoint: testOverrideEndpoint, + DefaultMTLSEndpoint: testMTLSEndpoint, + DefaultEndpointTemplate: testEndpointTemplate, + Endpoint: testOverrideEndpoint, }, s2aFn: validConfigResp, mtlsEnabledFn: func() bool { return true }, @@ -318,8 +395,8 @@ func TestGetHTTPTransportConfig(t *testing.T) { { name: "no client cert, S2A address not empty, but DefaultMTLSEndpoint is not set", opts: &Options{ - DefaultMTLSEndpoint: "", - DefaultEndpoint: testRegularEndpoint, + DefaultMTLSEndpoint: "", + DefaultEndpointTemplate: testEndpointTemplate, }, s2aFn: validConfigResp, mtlsEnabledFn: func() bool { return true }, @@ -339,9 +416,9 @@ func TestGetHTTPTransportConfig(t *testing.T) { { name: "no client cert, endpoint is MTLS enabled, S2A address not empty, custom HTTP client", opts: &Options{ - DefaultMTLSEndpoint: testMTLSEndpoint, - DefaultEndpoint: testRegularEndpoint, - Client: http.DefaultClient, + DefaultMTLSEndpoint: testMTLSEndpoint, + DefaultEndpointTemplate: testEndpointTemplate, + Client: http.DefaultClient, }, s2aFn: validConfigResp, mtlsEnabledFn: func() bool { return true }, @@ -355,13 +432,16 @@ func TestGetHTTPTransportConfig(t *testing.T) { httpGetMetadataMTLSConfig = tc.s2aFn mtlsEndpointEnabledForS2A = tc.mtlsEnabledFn if tc.opts.ClientCertProvider != nil { - t.Setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "true") + t.Setenv(googleAPIUseCertSource, "true") } else { - t.Setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + t.Setenv(googleAPIUseCertSource, "false") + } + _, dialFunc, err := GetHTTPTransportConfig(tc.opts) + if err != nil { + t.Fatalf("err: %v", err) } - _, dialFunc, _ := GetHTTPTransportConfig(tc.opts) if want, got := tc.isDialFnNil, dialFunc == nil; want != got { - t.Errorf("%s: expecting returned dialFunc is nil: [%v], got [%v]", tc.name, tc.isDialFnNil, got) + t.Errorf("expecting returned dialFunc is nil: [%v], got [%v]", tc.isDialFnNil, got) } // Let MTLS config expire at end of each test case. time.Sleep(2 * time.Millisecond) @@ -383,3 +463,184 @@ func setupTest(t *testing.T) func() { configExpiry = oldExpiry } } + +func TestGetTransportConfig_UniverseDomain(t *testing.T) { + testCases := []struct { + name string + opts *Options + wantEndpoint string + wantErr error + }{ + { + name: "google default universe (GDU), no client cert, template is regular endpoint", + opts: &Options{ + DefaultEndpointTemplate: testRegularEndpoint, + DefaultMTLSEndpoint: testMTLSEndpoint, + }, + wantEndpoint: testRegularEndpoint, + }, + { + name: "google default universe (GDU), no client cert", + opts: &Options{ + DefaultEndpointTemplate: testEndpointTemplate, + DefaultMTLSEndpoint: testMTLSEndpoint, + }, + wantEndpoint: testRegularEndpoint, + }, + { + name: "google default universe (GDU), client cert", + opts: &Options{ + DefaultEndpointTemplate: testEndpointTemplate, + DefaultMTLSEndpoint: testMTLSEndpoint, + ClientCertProvider: fakeClientCertSource, + }, + wantEndpoint: testMTLSEndpoint, + }, + { + name: "UniverseDomain, no client cert", + opts: &Options{ + DefaultEndpointTemplate: testEndpointTemplate, + DefaultMTLSEndpoint: testMTLSEndpoint, + UniverseDomain: testUniverseDomain, + }, + wantEndpoint: testUniverseDomainEndpoint, + }, + { + name: "UniverseDomain, client cert", + opts: &Options{ + DefaultEndpointTemplate: testEndpointTemplate, + DefaultMTLSEndpoint: testMTLSEndpoint, + UniverseDomain: testUniverseDomain, + ClientCertProvider: fakeClientCertSource, + }, + wantEndpoint: testUniverseDomainEndpoint, + wantErr: errUniverseNotSupportedMTLS, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.opts.ClientCertProvider != nil { + t.Setenv(googleAPIUseCertSource, "true") + } else { + t.Setenv(googleAPIUseCertSource, "false") + } + config, err := getTransportConfig(tc.opts) + if err != nil { + if err != tc.wantErr { + t.Fatalf("err: %v", err) + } + } else { + if tc.wantEndpoint != config.endpoint { + t.Errorf("want endpoint: %s, got %s", tc.wantEndpoint, config.endpoint) + } + } + }) + } +} + +func TestGetGRPCTransportCredsAndEndpoint_UniverseDomain(t *testing.T) { + testCases := []struct { + name string + opts *Options + wantEndpoint string + wantErr error + }{ + { + name: "google default universe (GDU), no client cert", + opts: &Options{ + DefaultEndpointTemplate: testEndpointTemplate, + DefaultMTLSEndpoint: testMTLSEndpoint, + }, + wantEndpoint: testRegularEndpoint, + }, + { + name: "google default universe (GDU), no client cert, endpoint", + opts: &Options{ + DefaultEndpointTemplate: testEndpointTemplate, + DefaultMTLSEndpoint: testMTLSEndpoint, + Endpoint: testOverrideEndpoint, + }, + wantEndpoint: testOverrideEndpoint, + }, + { + name: "google default universe (GDU), client cert", + opts: &Options{ + DefaultEndpointTemplate: testEndpointTemplate, + DefaultMTLSEndpoint: testMTLSEndpoint, + ClientCertProvider: fakeClientCertSource, + }, + wantEndpoint: testMTLSEndpoint, + }, + { + name: "google default universe (GDU), client cert, endpoint", + opts: &Options{ + DefaultEndpointTemplate: testEndpointTemplate, + DefaultMTLSEndpoint: testMTLSEndpoint, + ClientCertProvider: fakeClientCertSource, + Endpoint: testOverrideEndpoint, + }, + wantEndpoint: testOverrideEndpoint, + }, + { + name: "UniverseDomain, no client cert", + opts: &Options{ + DefaultEndpointTemplate: testEndpointTemplate, + DefaultMTLSEndpoint: testMTLSEndpoint, + UniverseDomain: testUniverseDomain, + }, + wantEndpoint: testUniverseDomainEndpoint, + }, + { + name: "UniverseDomain, no client cert, endpoint", + opts: &Options{ + DefaultEndpointTemplate: testEndpointTemplate, + DefaultMTLSEndpoint: testMTLSEndpoint, + UniverseDomain: testUniverseDomain, + Endpoint: testOverrideEndpoint, + }, + wantEndpoint: testOverrideEndpoint, + }, + { + name: "UniverseDomain, client cert", + opts: &Options{ + DefaultEndpointTemplate: testEndpointTemplate, + DefaultMTLSEndpoint: testMTLSEndpoint, + UniverseDomain: testUniverseDomain, + ClientCertProvider: fakeClientCertSource, + }, + wantErr: errUniverseNotSupportedMTLS, + }, + { + name: "UniverseDomain, client cert, endpoint", + opts: &Options{ + DefaultEndpointTemplate: testEndpointTemplate, + DefaultMTLSEndpoint: testMTLSEndpoint, + UniverseDomain: testUniverseDomain, + ClientCertProvider: fakeClientCertSource, + Endpoint: testOverrideEndpoint, + }, + wantEndpoint: testOverrideEndpoint, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.opts.ClientCertProvider != nil { + t.Setenv(googleAPIUseCertSource, "true") + } else { + t.Setenv(googleAPIUseCertSource, "false") + } + _, endpoint, err := GetGRPCTransportCredsAndEndpoint(tc.opts) + if err != nil { + if err != tc.wantErr { + t.Fatalf("err: %v", err) + } + } else { + if tc.wantEndpoint != endpoint { + t.Errorf("want endpoint: %s, got %s", tc.wantEndpoint, endpoint) + } + } + }) + } +} diff --git a/auth/internal/transport/s2a.go b/auth/internal/transport/s2a.go index 45ac578b2653..e9e4793523b1 100644 --- a/auth/internal/transport/s2a.go +++ b/auth/internal/transport/s2a.go @@ -162,7 +162,7 @@ func shouldUseS2A(clientCertSource cert.Provider, opts *Options) bool { return false } // If EXPERIMENTAL_GOOGLE_API_USE_S2A is not set to true, skip S2A. - if b, err := strconv.ParseBool(os.Getenv(googleAPIUseS2AEnv)); err == nil && !b { + if !isGoogleS2AEnabled() { return false } // If DefaultMTLSEndpoint is not set and no endpoint override, skip S2A. @@ -179,3 +179,11 @@ func shouldUseS2A(clientCertSource cert.Provider, opts *Options) bool { } return true } + +func isGoogleS2AEnabled() bool { + b, err := strconv.ParseBool(os.Getenv(googleAPIUseS2AEnv)) + if err != nil { + return false + } + return b +} diff --git a/auth/internal/transport/s2a_test.go b/auth/internal/transport/s2a_test.go index 0a6e780b903b..51a495fd66fd 100644 --- a/auth/internal/transport/s2a_test.go +++ b/auth/internal/transport/s2a_test.go @@ -62,7 +62,7 @@ func TestGetS2AAddress(t *testing.T) { t.Run(tc.name, func(t *testing.T) { httpGetMetadataMTLSConfig = tc.respFn if want, got := tc.want, GetS2AAddress(); got != want { - t.Errorf("%s: want address [%s], got address [%s]", tc.name, want, got) + t.Errorf("want address [%s], got address [%s]", want, got) } // Let the MTLS config expire at the end of each test case. time.Sleep(2 * time.Millisecond) @@ -93,3 +93,39 @@ func TestMTLSConfigExpiry(t *testing.T) { // Let the MTLS config expire before running other tests. time.Sleep(1 * time.Second) } + +func TestIsGoogleS2AEnabled(t *testing.T) { + testCases := []struct { + name string + useS2AEnv string + want bool + }{ + { + name: "true", + useS2AEnv: "true", + want: true, + }, + { + name: "false", + useS2AEnv: "false", + want: false, + }, + { + name: "empty", + useS2AEnv: "", + want: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.useS2AEnv != "" { + t.Setenv(googleAPIUseS2AEnv, tc.useS2AEnv) + } + + if got := isGoogleS2AEnabled(); got != tc.want { + t.Errorf("got %t, want %t", got, tc.want) + } + }) + } +}