diff --git a/internal/creds.go b/internal/creds.go index 69b186b70a7..63c66092203 100644 --- a/internal/creds.go +++ b/internal/creds.go @@ -13,6 +13,7 @@ import ( "io/ioutil" "net" "net/http" + "os" "time" "golang.org/x/oauth2" @@ -21,6 +22,8 @@ import ( "golang.org/x/oauth2/google" ) +const quotaProjectEnvVar = "GOOGLE_CLOUD_QUOTA_PROJECT" + // Creds returns credential information obtained from DialSettings, or if none, then // it returns default credential information. func Creds(ctx context.Context, ds *DialSettings) (*google.Credentials, error) { @@ -152,14 +155,22 @@ func selfSignedJWTTokenSource(data []byte, ds *DialSettings) (oauth2.TokenSource } } -// QuotaProjectFromCreds returns the quota project from the JSON blob in the provided credentials. -// -// NOTE(cbro): consider promoting this to a field on google.Credentials. -func QuotaProjectFromCreds(cred *google.Credentials) string { +// GetQuotaProject retrieves quota project with precedence being: client option, +// environment variable, creds file. +func GetQuotaProject(creds *google.Credentials, clientOpt string) string { + if clientOpt != "" { + return clientOpt + } + if env := os.Getenv(quotaProjectEnvVar); env != "" { + return env + } + if creds == nil { + return "" + } var v struct { QuotaProject string `json:"quota_project_id"` } - if err := json.Unmarshal(cred.JSON, &v); err != nil { + if err := json.Unmarshal(creds.JSON, &v); err != nil { return "" } return v.QuotaProject diff --git a/internal/creds_test.go b/internal/creds_test.go index 34b052dcb29..7b5bf2235d3 100644 --- a/internal/creds_test.go +++ b/internal/creds_test.go @@ -6,6 +6,7 @@ package internal import ( "context" + "os" "testing" "github.com/google/go-cmp/cmp" @@ -199,10 +200,9 @@ const validServiceAccountJSON = `{ "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/dumba-504%40appspot.gserviceaccount.com" }` -func TestQuotaProjectFromCreds(t *testing.T) { +func TestGetQuotaProject(t *testing.T) { ctx := context.Background() - - cred, err := credentialsFromJSON( + emptyCred, err := credentialsFromJSON( ctx, []byte(validServiceAccountJSON), &DialSettings{ @@ -212,17 +212,13 @@ func TestQuotaProjectFromCreds(t *testing.T) { if err != nil { t.Fatalf("got %v, wanted no error", err) } - if want, got := "", QuotaProjectFromCreds(cred); want != got { - t.Errorf("QuotaProjectFromCreds(validServiceAccountJSON): want %q, got %q", want, got) - } - quotaProjectJSON := []byte(` { "type": "authorized_user", "quota_project_id": "foobar" }`) - cred, err = credentialsFromJSON( + quotaCred, err := credentialsFromJSON( ctx, []byte(quotaProjectJSON), &DialSettings{ @@ -232,8 +228,53 @@ func TestQuotaProjectFromCreds(t *testing.T) { if err != nil { t.Fatalf("got %v, wanted no error", err) } - if want, got := "foobar", QuotaProjectFromCreds(cred); want != got { - t.Errorf("QuotaProjectFromCreds(quotaProjectJSON): want %q, got %q", want, got) + + tests := []struct { + name string + cred *google.Credentials + clientOpt string + env string + want string + }{ + { + name: "empty all", + cred: nil, + want: "", + }, + { + name: "empty cred", + cred: emptyCred, + want: "", + }, + { + name: "from cred", + cred: quotaCred, + want: "foobar", + }, + { + name: "from opt", + cred: quotaCred, + clientOpt: "clientopt", + want: "clientopt", + }, + { + name: "from env", + cred: quotaCred, + env: "envProject", + want: "envProject", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + oldEnv := os.Getenv(quotaProjectEnvVar) + if tt.env != "" { + os.Setenv(quotaProjectEnvVar, tt.env) + } + if want, got := tt.want, GetQuotaProject(tt.cred, tt.clientOpt); want != got { + t.Errorf("GetQuotaProject(%v, %q): want %q, got %q", tt.cred, tt.clientOpt, want, got) + } + os.Setenv(quotaProjectEnvVar, oldEnv) + }) } } diff --git a/transport/grpc/dial.go b/transport/grpc/dial.go index c76894ff4c6..20c94fa640b 100644 --- a/transport/grpc/dial.go +++ b/transport/grpc/dial.go @@ -154,14 +154,10 @@ func dial(ctx context.Context, insecure bool, o *internal.DialSettings) (*grpc.C return nil, err } - if o.QuotaProject == "" { - o.QuotaProject = internal.QuotaProjectFromCreds(creds) - } - grpcOpts = append(grpcOpts, grpc.WithPerRPCCredentials(grpcTokenSource{ TokenSource: oauth.TokenSource{creds.TokenSource}, - quotaProject: o.QuotaProject, + quotaProject: internal.GetQuotaProject(creds, o.QuotaProject), requestReason: o.RequestReason, }), ) diff --git a/transport/http/dial.go b/transport/http/dial.go index 4f7f44e8dbf..403509d08f6 100644 --- a/transport/http/dial.go +++ b/transport/http/dial.go @@ -65,7 +65,6 @@ func newTransport(ctx context.Context, base http.RoundTripper, settings *interna paramTransport := ¶meterTransport{ base: base, userAgent: settings.UserAgent, - quotaProject: settings.QuotaProject, requestReason: settings.RequestReason, } var trans http.RoundTripper = paramTransport @@ -74,6 +73,7 @@ func newTransport(ctx context.Context, base http.RoundTripper, settings *interna case settings.NoAuth: // Do nothing. case settings.APIKey != "": + paramTransport.quotaProject = internal.GetQuotaProject(nil, settings.QuotaProject) trans = &transport.APIKey{ Transport: trans, Key: settings.APIKey, @@ -83,10 +83,7 @@ func newTransport(ctx context.Context, base http.RoundTripper, settings *interna if err != nil { return nil, err } - if paramTransport.quotaProject == "" { - paramTransport.quotaProject = internal.QuotaProjectFromCreds(creds) - } - + paramTransport.quotaProject = internal.GetQuotaProject(creds, settings.QuotaProject) ts := creds.TokenSource if settings.ImpersonationConfig == nil && settings.TokenSource != nil { ts = settings.TokenSource