Skip to content

Commit

Permalink
google: add Credentials.GetUniverseDomain with GCE MDS support
Browse files Browse the repository at this point in the history
* Deprecate Credentials.UniverseDomain

Change-Id: I1cbc842fbfce35540c8dff99fec09e036b9e2cdf
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/554215
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Cody Oss <codyoss@google.com>
Auto-Submit: Cody Oss <codyoss@google.com>
Reviewed-by: Cody Oss <codyoss@google.com>
Reviewed-by: Viacheslav Rostovtsev <virost@google.com>
  • Loading branch information
quartzmo authored and gopherbot committed Jan 5, 2024
1 parent 1e6999b commit 4ce7bbb
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 0 deletions.
58 changes: 58 additions & 0 deletions google/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"os"
"path/filepath"
"runtime"
"sync"
"time"

"cloud.google.com/go/compute/metadata"
Expand Down Expand Up @@ -41,19 +42,76 @@ type Credentials struct {
// running on Google Cloud Platform.
JSON []byte

udMu sync.Mutex // guards universeDomain
// universeDomain is the default service domain for a given Cloud universe.
universeDomain string
}

// UniverseDomain returns the default service domain for a given Cloud universe.
//
// The default value is "googleapis.com".
//
// Deprecated: Use instead (*Credentials).GetUniverseDomain(), which supports
// obtaining the universe domain when authenticating via the GCE metadata server.
// Unlike GetUniverseDomain, this method, UniverseDomain, will always return the
// default value when authenticating via the GCE metadata server.
// See also [The attached service account](https://cloud.google.com/docs/authentication/application-default-credentials#attached-sa).
func (c *Credentials) UniverseDomain() string {
if c.universeDomain == "" {
return universeDomainDefault
}
return c.universeDomain
}

// GetUniverseDomain returns the default service domain for a given Cloud
// universe.
//
// The default value is "googleapis.com".
//
// It obtains the universe domain from the attached service account on GCE when
// authenticating via the GCE metadata server. See also [The attached service
// account](https://cloud.google.com/docs/authentication/application-default-credentials#attached-sa).
// If the GCE metadata server returns a 404 error, the default value is
// returned. If the GCE metadata server returns an error other than 404, the
// error is returned.
func (c *Credentials) GetUniverseDomain() (string, error) {
c.udMu.Lock()
defer c.udMu.Unlock()
if c.universeDomain == "" && metadata.OnGCE() {
// If we're on Google Compute Engine, an App Engine standard second
// generation runtime, or App Engine flexible, use the metadata server.
err := c.computeUniverseDomain()
if err != nil {
return "", err
}
}
// If not on Google Compute Engine, or in case of any non-error path in
// computeUniverseDomain that did not set universeDomain, set the default
// universe domain.
if c.universeDomain == "" {
c.universeDomain = universeDomainDefault
}
return c.universeDomain, nil
}

// computeUniverseDomain fetches the default service domain for a given Cloud
// universe from Google Compute Engine (GCE)'s metadata server. It's only valid
// to use this method if your program is running on a GCE instance.
func (c *Credentials) computeUniverseDomain() error {
var err error
c.universeDomain, err = metadata.Get("universe/universe_domain")
if err != nil {
if _, ok := err.(metadata.NotDefinedError); ok {
// http.StatusNotFound (404)
c.universeDomain = universeDomainDefault
return nil
} else {
return err
}
}
return nil
}

// DefaultCredentials is the old name of Credentials.
//
// Deprecated: use Credentials instead.
Expand Down
95 changes: 95 additions & 0 deletions google/default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ package google

import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
)

Expand Down Expand Up @@ -74,6 +77,9 @@ func TestCredentialsFromJSONWithParams_SA(t *testing.T) {
if want := "googleapis.com"; creds.UniverseDomain() != want {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
}
if want := "googleapis.com"; creds.UniverseDomain() != want {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
}
}

func TestCredentialsFromJSONWithParams_SA_Params_UniverseDomain(t *testing.T) {
Expand All @@ -94,6 +100,9 @@ func TestCredentialsFromJSONWithParams_SA_Params_UniverseDomain(t *testing.T) {
if creds.UniverseDomain() != universeDomain2 {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), universeDomain2)
}
if creds.UniverseDomain() != universeDomain2 {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), universeDomain2)
}
}

func TestCredentialsFromJSONWithParams_SA_UniverseDomain(t *testing.T) {
Expand All @@ -113,6 +122,13 @@ func TestCredentialsFromJSONWithParams_SA_UniverseDomain(t *testing.T) {
if creds.UniverseDomain() != universeDomain {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), universeDomain)
}
got, err := creds.GetUniverseDomain()
if err != nil {
t.Fatal(err)
}
if got != universeDomain {
t.Fatalf("got %q, want %q", got, universeDomain)
}
}

func TestCredentialsFromJSONWithParams_SA_UniverseDomain_Params_UniverseDomain(t *testing.T) {
Expand All @@ -133,6 +149,13 @@ func TestCredentialsFromJSONWithParams_SA_UniverseDomain_Params_UniverseDomain(t
if creds.UniverseDomain() != universeDomain2 {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), universeDomain2)
}
got, err := creds.GetUniverseDomain()
if err != nil {
t.Fatal(err)
}
if got != universeDomain2 {
t.Fatalf("got %q, want %q", got, universeDomain2)
}
}

func TestCredentialsFromJSONWithParams_User(t *testing.T) {
Expand All @@ -149,6 +172,13 @@ func TestCredentialsFromJSONWithParams_User(t *testing.T) {
if want := "googleapis.com"; creds.UniverseDomain() != want {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
}
got, err := creds.GetUniverseDomain()
if err != nil {
t.Fatal(err)
}
if want := "googleapis.com"; got != want {
t.Fatalf("got %q, want %q", got, want)
}
}

func TestCredentialsFromJSONWithParams_User_Params_UniverseDomain(t *testing.T) {
Expand All @@ -166,6 +196,13 @@ func TestCredentialsFromJSONWithParams_User_Params_UniverseDomain(t *testing.T)
if want := "googleapis.com"; creds.UniverseDomain() != want {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
}
got, err := creds.GetUniverseDomain()
if err != nil {
t.Fatal(err)
}
if want := "googleapis.com"; got != want {
t.Fatalf("got %q, want %q", got, want)
}
}

func TestCredentialsFromJSONWithParams_User_UniverseDomain(t *testing.T) {
Expand All @@ -182,6 +219,13 @@ func TestCredentialsFromJSONWithParams_User_UniverseDomain(t *testing.T) {
if want := "googleapis.com"; creds.UniverseDomain() != want {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
}
got, err := creds.GetUniverseDomain()
if err != nil {
t.Fatal(err)
}
if want := "googleapis.com"; got != want {
t.Fatalf("got %q, want %q", got, want)
}
}

func TestCredentialsFromJSONWithParams_User_UniverseDomain_Params_UniverseDomain(t *testing.T) {
Expand All @@ -199,4 +243,55 @@ func TestCredentialsFromJSONWithParams_User_UniverseDomain_Params_UniverseDomain
if want := "googleapis.com"; creds.UniverseDomain() != want {
t.Fatalf("got %q, want %q", creds.UniverseDomain(), want)
}
got, err := creds.GetUniverseDomain()
if err != nil {
t.Fatal(err)
}
if want := "googleapis.com"; got != want {
t.Fatalf("got %q, want %q", got, want)
}
}

func TestComputeUniverseDomain(t *testing.T) {
universeDomainPath := "/computeMetadata/v1/universe/universe_domain"
universeDomainResponseBody := "example.com"
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != universeDomainPath {
t.Errorf("got %s, want %s", r.URL.Path, universeDomainPath)
}
w.Write([]byte(universeDomainResponseBody))
}))
defer s.Close()
t.Setenv("GCE_METADATA_HOST", strings.TrimPrefix(s.URL, "http://"))

scope := "https://www.googleapis.com/auth/cloud-platform"
params := CredentialsParams{
Scopes: []string{scope},
}
// Copied from FindDefaultCredentialsWithParams, metadata.OnGCE() = true block
creds := &Credentials{
ProjectID: "fake_project",
TokenSource: computeTokenSource("", params.EarlyTokenRefresh, params.Scopes...),
universeDomain: params.UniverseDomain, // empty
}
c := make(chan bool)
go func() {
got, err := creds.GetUniverseDomain() // First conflicting access.
if err != nil {
t.Error(err)
}
if want := universeDomainResponseBody; got != want {
t.Errorf("got %q, want %q", got, want)
}
c <- true
}()
got, err := creds.GetUniverseDomain() // Second conflicting access.
<-c
if err != nil {
t.Error(err)
}
if want := universeDomainResponseBody; got != want {
t.Errorf("got %q, want %q", got, want)
}

}
20 changes: 20 additions & 0 deletions google/google_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
package google

import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
Expand Down Expand Up @@ -137,3 +139,21 @@ func TestJWTConfigFromJSONNoAudience(t *testing.T) {
t.Errorf("Audience = %q; want %q", got, want)
}
}

func TestComputeTokenSource(t *testing.T) {
tokenPath := "/computeMetadata/v1/instance/service-accounts/default/token"
tokenResponseBody := `{"access_token":"Sample.Access.Token","token_type":"Bearer","expires_in":3600}`
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != tokenPath {
t.Errorf("got %s, want %s", r.URL.Path, tokenPath)
}
w.Write([]byte(tokenResponseBody))
}))
defer s.Close()
t.Setenv("GCE_METADATA_HOST", strings.TrimPrefix(s.URL, "http://"))
ts := ComputeTokenSource("")
_, err := ts.Token()
if err != nil {
t.Errorf("ts.Token() = %v", err)
}
}

0 comments on commit 4ce7bbb

Please sign in to comment.