diff --git a/auth/credentials/downscope/downscope.go b/auth/credentials/downscope/downscope.go index 51b5bafc7e32..4ad76c6193d4 100644 --- a/auth/credentials/downscope/downscope.go +++ b/auth/credentials/downscope/downscope.go @@ -20,13 +20,17 @@ import ( "fmt" "net/http" "net/url" + "strings" "time" "cloud.google.com/go/auth" "cloud.google.com/go/auth/internal" ) -var identityBindingEndpoint = "https://sts.googleapis.com/v1/token" +const ( + universeDomainPlaceholder = "UNIVERSE_DOMAIN" + identityBindingEndpointTemplate = "https://sts.UNIVERSE_DOMAIN/v1/token" +) // Options for configuring [NewCredentials]. type Options struct { @@ -42,15 +46,27 @@ type Options struct { // Client configures the underlying client used to make network requests // when fetching tokens. Optional. Client *http.Client + // UniverseDomain is the default service domain for a given Cloud universe. + // The default value is "googleapis.com". Optional. + UniverseDomain string } -func (c Options) client() *http.Client { - if c.Client != nil { - return c.Client +func (o *Options) client() *http.Client { + if o.Client != nil { + return o.Client } return internal.CloneDefaultClient() } +// identityBindingEndpoint returns the identity binding endpoint with the +// configured universe domain. +func (o *Options) identityBindingEndpoint() string { + if o.UniverseDomain == "" { + return strings.Replace(identityBindingEndpointTemplate, universeDomainPlaceholder, internal.DefaultUniverseDomain, 1) + } + return strings.Replace(identityBindingEndpointTemplate, universeDomainPlaceholder, o.UniverseDomain, 1) +} + // An AccessBoundaryRule Sets the permissions (and optionally conditions) that // the new token has on given resource. type AccessBoundaryRule struct { @@ -108,10 +124,14 @@ func NewCredentials(opts *Options) (*auth.Credentials, error) { } } return auth.NewCredentials(&auth.CredentialsOptions{ - TokenProvider: &downscopedTokenProvider{Options: opts, Client: opts.client()}, + TokenProvider: &downscopedTokenProvider{ + Options: opts, + Client: opts.client(), + identityBindingEndpoint: opts.identityBindingEndpoint(), + }, ProjectIDProvider: auth.CredentialsPropertyFunc(opts.Credentials.ProjectID), QuotaProjectIDProvider: auth.CredentialsPropertyFunc(opts.Credentials.QuotaProjectID), - UniverseDomainProvider: auth.CredentialsPropertyFunc(opts.Credentials.UniverseDomain), + UniverseDomainProvider: internal.StaticCredentialsProperty(opts.UniverseDomain), }), nil } @@ -119,6 +139,9 @@ func NewCredentials(opts *Options) (*auth.Credentials, error) { type downscopedTokenProvider struct { Options *Options Client *http.Client + // identityBindingEndpoint is the identity binding endpoint with the + // configured universe domain. + identityBindingEndpoint string } type downscopedOptions struct { @@ -159,7 +182,7 @@ func (dts *downscopedTokenProvider) Token(ctx context.Context) (*auth.Token, err form.Add("subject_token", tok.Value) form.Add("options", string(b)) - resp, err := dts.Client.PostForm(identityBindingEndpoint, form) + resp, err := dts.Client.PostForm(dts.identityBindingEndpoint, form) if err != nil { return nil, err } diff --git a/auth/credentials/downscope/downscope_test.go b/auth/credentials/downscope/downscope_test.go index 0dbf442d3459..1fffc6920d70 100644 --- a/auth/credentials/downscope/downscope_test.go +++ b/auth/credentials/downscope/downscope_test.go @@ -61,9 +61,6 @@ func TestNewTokenProvider(t *testing.T) { })) defer ts.Close() - oldEndpoint := identityBindingEndpoint - identityBindingEndpoint = ts.URL - t.Cleanup(func() { identityBindingEndpoint = oldEndpoint }) creds, err := NewCredentials(&Options{ Credentials: staticCredentials("token_base"), Rules: []AccessBoundaryRule{ @@ -76,6 +73,9 @@ func TestNewTokenProvider(t *testing.T) { if err != nil { t.Fatalf("NewTokenProvider() = %v", err) } + // Replace the default STS endpoint on the TokenProvider with the test server URL. + creds.TokenProvider.(*downscopedTokenProvider).identityBindingEndpoint = ts.URL + tok, err := creds.Token(context.Background()) if err != nil { t.Fatalf("Token failed with error: %v", err) @@ -85,7 +85,7 @@ func TestNewTokenProvider(t *testing.T) { } } -func TestTestNewCredentials_Validations(t *testing.T) { +func TestNewCredentials_Validations(t *testing.T) { tests := []struct { name string opts *Options @@ -136,3 +136,22 @@ func TestTestNewCredentials_Validations(t *testing.T) { }) } } + +func TestOptions_UniverseDomain(t *testing.T) { + tests := []struct { + universeDomain string + want string + }{ + {"", "https://sts.googleapis.com/v1/token"}, + {"googleapis.com", "https://sts.googleapis.com/v1/token"}, + {"example.com", "https://sts.example.com/v1/token"}, + } + for _, tt := range tests { + c := Options{ + UniverseDomain: tt.universeDomain, + } + if got := c.identityBindingEndpoint(); got != tt.want { + t.Errorf("got %q, want %q", got, tt.want) + } + } +}