Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ linters:
- r
- w
- f
- h
- q
- err

exclusions:
Expand Down
64 changes: 40 additions & 24 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,38 +22,52 @@ const (
maxBodySize = 10 * 1024 * 1024 // 10MB
defaultTimeout = 30 * time.Second
metadataFlavor = "Google"
//nolint:revive // GCP metadata server only supports HTTP
defaultMetadataURL = "http://metadata.google.internal/computeMetadata/v1"
)

var (
metadataURL = "http://metadata.google.internal/computeMetadata/v1" //nolint:revive // GCP metadata server only supports HTTP
isTestMode = false
var httpClient = &http.Client{
Timeout: defaultTimeout,
}

httpClient = &http.Client{
Timeout: defaultTimeout,
}
)
// Config holds auth configuration.
type Config struct {
// MetadataURL is the URL for the GCP metadata server.
// Defaults to the production metadata server if empty.
MetadataURL string

// SkipADC skips Application Default Credentials and goes straight to metadata server.
// Useful for testing to ensure mock servers are used.
SkipADC bool
}

// SetMetadataURL sets a custom metadata server URL for testing.
// Returns a function that restores the original URL.
// WARNING: This function should only be called in test code.
// Set DS9_ALLOW_TEST_OVERRIDES=true to enable in non-test environments.
func SetMetadataURL(urlStr string) func() {
old := metadataURL
oldTestMode := isTestMode
metadataURL = urlStr
isTestMode = true // Enable test mode to skip ADC
return func() {
metadataURL = old
isTestMode = oldTestMode
// configKey is the key for storing Config in context.
type configKey struct{}

// WithConfig returns a new context with the given auth config.
func WithConfig(ctx context.Context, cfg *Config) context.Context {
return context.WithValue(ctx, configKey{}, cfg)
}

// getConfig retrieves the auth config from context, or returns defaults.
func getConfig(ctx context.Context) *Config {
if cfg, ok := ctx.Value(configKey{}).(*Config); ok && cfg != nil {
return cfg
}
return &Config{
MetadataURL: defaultMetadataURL,
SkipADC: false,
}
}

// AccessToken retrieves a GCP access token.
// It tries Application Default Credentials first, then falls back to the metadata server.
// In test mode, ADC is skipped to ensure mock servers are used.
// Configuration can be provided via auth.WithConfig in the context.
func AccessToken(ctx context.Context) (string, error) {
// Skip ADC in test mode to ensure tests use mock metadata server
if !isTestMode {
cfg := getConfig(ctx)

// Skip ADC if configured (useful for testing to ensure mock metadata server is used)
if !cfg.SkipADC {
// Try Application Default Credentials first (for local development)
token, err := accessTokenFromADC(ctx)
if err == nil {
Expand Down Expand Up @@ -165,7 +179,8 @@ func exchangeRefreshToken(ctx context.Context, clientID, clientSecret, refreshTo
// accessTokenFromMetadata retrieves an access token from the GCP metadata server.
// This is used when running on GCP (GCE, GKE, Cloud Run, etc.).
func accessTokenFromMetadata(ctx context.Context) (string, error) {
reqURL := metadataURL + "/instance/service-accounts/default/token"
cfg := getConfig(ctx)
reqURL := cfg.MetadataURL + "/instance/service-accounts/default/token"

req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, http.NoBody)
if err != nil {
Expand Down Expand Up @@ -206,7 +221,8 @@ func accessTokenFromMetadata(ctx context.Context) (string, error) {

// ProjectID retrieves the project ID from the GCP metadata server.
func ProjectID(ctx context.Context) (string, error) {
reqURL := metadataURL + "/project/project-id"
cfg := getConfig(ctx)
reqURL := cfg.MetadataURL + "/project/project-id"

req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, http.NoBody)
if err != nil {
Expand Down
137 changes: 68 additions & 69 deletions auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,27 @@ import (
"testing"
)

func TestSetMetadataURL(t *testing.T) {
originalURL := metadataURL
originalTestMode := isTestMode

// Set custom URL
restore := SetMetadataURL("http://custom-metadata")

if metadataURL != "http://custom-metadata" {
t.Errorf("expected metadataURL to be http://custom-metadata, got %s", metadataURL)
func TestWithConfig(t *testing.T) {
// Test that config can be set in context
cfg := &Config{
MetadataURL: "http://custom-metadata",
SkipADC: true,
}
ctx := WithConfig(context.Background(), cfg)

if !isTestMode {
t.Error("expected isTestMode to be true")
// Context should be non-nil
if ctx == nil {
t.Fatal("expected non-nil context")
}

// Restore
restore()

if metadataURL != originalURL {
t.Errorf("expected metadataURL to be restored to %s, got %s", originalURL, metadataURL)
// Verify config is retrievable
retrievedCfg := getConfig(ctx)
if retrievedCfg.MetadataURL != "http://custom-metadata" {
t.Errorf("expected MetadataURL to be http://custom-metadata, got %s", retrievedCfg.MetadataURL)
}

if isTestMode != originalTestMode {
t.Errorf("expected isTestMode to be restored to %v, got %v", originalTestMode, isTestMode)
if !retrievedCfg.SkipADC {
t.Error("expected SkipADC to be true")
}
}

Expand Down Expand Up @@ -111,10 +108,11 @@ func TestAccessTokenFromMetadata(t *testing.T) {
}))
defer server.Close()

restore := SetMetadataURL(server.URL)
defer restore()
ctx := WithConfig(context.Background(), &Config{
MetadataURL: server.URL,
SkipADC: true,
})

ctx := context.Background()
token, err := accessTokenFromMetadata(ctx)

if tt.wantErr {
Expand Down Expand Up @@ -153,12 +151,12 @@ func TestAccessToken(t *testing.T) {
}))
defer server.Close()

restore := SetMetadataURL(server.URL)
defer restore()
ctx := WithConfig(context.Background(), &Config{
MetadataURL: server.URL,
SkipADC: true,
})

ctx := context.Background()

// In test mode, should use metadata server
// With SkipADC=true, should use metadata server
token, err := AccessToken(ctx)
if err != nil {
t.Fatalf("AccessToken failed: %v", err)
Expand Down Expand Up @@ -353,10 +351,10 @@ func TestProjectID(t *testing.T) {
}))
defer server.Close()

restore := SetMetadataURL(server.URL)
defer restore()

ctx := context.Background()
ctx := WithConfig(context.Background(), &Config{
MetadataURL: server.URL,
SkipADC: true,
})
projectID, err := ProjectID(ctx)

if tt.wantErr {
Expand All @@ -379,10 +377,10 @@ func TestProjectID(t *testing.T) {

func TestAccessTokenMetadataServerDown(t *testing.T) {
// Point to non-existent server
restore := SetMetadataURL("http://localhost:59999")
defer restore()

ctx := context.Background()
ctx := WithConfig(context.Background(), &Config{
MetadataURL: "http://localhost:59999",
SkipADC: true,
})
_, err := accessTokenFromMetadata(ctx)

if err == nil {
Expand All @@ -396,10 +394,10 @@ func TestAccessTokenMetadataServerDown(t *testing.T) {

func TestProjectIDMetadataServerDown(t *testing.T) {
// Point to non-existent server
restore := SetMetadataURL("http://localhost:59998")
defer restore()

ctx := context.Background()
ctx := WithConfig(context.Background(), &Config{
MetadataURL: "http://localhost:59998",
SkipADC: true,
})
_, err := ProjectID(ctx)

if err == nil {
Expand Down Expand Up @@ -465,10 +463,10 @@ func TestAccessTokenFromMetadataReadError(t *testing.T) {
}))
defer server.Close()

restore := SetMetadataURL(server.URL)
defer restore()

ctx := context.Background()
ctx := WithConfig(context.Background(), &Config{
MetadataURL: server.URL,
SkipADC: true,
})
_, err := accessTokenFromMetadata(ctx)

if err == nil {
Expand All @@ -488,10 +486,10 @@ func TestProjectIDReadError(t *testing.T) {
}))
defer server.Close()

restore := SetMetadataURL(server.URL)
defer restore()

ctx := context.Background()
ctx := WithConfig(context.Background(), &Config{
MetadataURL: server.URL,
SkipADC: true,
})
_, err := ProjectID(ctx)

if err == nil {
Expand Down Expand Up @@ -521,10 +519,10 @@ func TestAccessTokenFromMetadataWithMalformedJSON(t *testing.T) {
}))
defer server.Close()

restore := SetMetadataURL(server.URL)
defer restore()

ctx := context.Background()
ctx := WithConfig(context.Background(), &Config{
MetadataURL: server.URL,
SkipADC: true,
})
_, err := accessTokenFromMetadata(ctx)
// Should either succeed (if parser is lenient) or fail with parse error
if err != nil {
Expand All @@ -547,10 +545,10 @@ func TestProjectIDWithEmptyResponse(t *testing.T) {
}))
defer server.Close()

restore := SetMetadataURL(server.URL)
defer restore()

ctx := context.Background()
ctx := WithConfig(context.Background(), &Config{
MetadataURL: server.URL,
SkipADC: true,
})
projectID, err := ProjectID(ctx)
if err != nil {
t.Fatalf("ProjectID with empty response failed: %v", err)
Expand Down Expand Up @@ -650,15 +648,16 @@ func TestAccessTokenFallbackToMetadata(t *testing.T) {
}))
defer server.Close()

restore := SetMetadataURL(server.URL)
defer restore()
ctx := WithConfig(context.Background(), &Config{
MetadataURL: server.URL,
SkipADC: true,
})

// Ensure no ADC credentials are available
if err := os.Unsetenv("GOOGLE_APPLICATION_CREDENTIALS"); err != nil {
t.Fatalf("failed to unset env var: %v", err)
}

ctx := context.Background()
token, err := AccessToken(ctx)
if err != nil {
t.Fatalf("AccessToken failed: %v", err)
Expand Down Expand Up @@ -688,10 +687,10 @@ func TestProjectIDInvalidJSON(t *testing.T) {
}))
defer server.Close()

restore := SetMetadataURL(server.URL)
defer restore()

ctx := context.Background()
ctx := WithConfig(context.Background(), &Config{
MetadataURL: server.URL,
SkipADC: true,
})
projectID, err := ProjectID(ctx)

if err != nil {
Expand Down Expand Up @@ -750,10 +749,10 @@ func TestAccessTokenFromADCDefaultLocation(t *testing.T) {
// Test ProjectID with request error
func TestProjectIDRequestError(t *testing.T) {
// Set invalid URL to trigger request error
restore := SetMetadataURL("http://invalid-host-that-does-not-exist-12345")
defer restore()

ctx := context.Background()
ctx := WithConfig(context.Background(), &Config{
MetadataURL: "http://invalid-host-that-does-not-exist-12345",
SkipADC: true,
})
_, err := ProjectID(ctx)

if err == nil {
Expand All @@ -777,10 +776,10 @@ func TestAccessTokenFromMetadataTypeError(t *testing.T) {
}))
defer server.Close()

restore := SetMetadataURL(server.URL)
defer restore()

ctx := context.Background()
ctx := WithConfig(context.Background(), &Config{
MetadataURL: server.URL,
SkipADC: true,
})
_, err := accessTokenFromMetadata(ctx)

if err == nil {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
module github.com/codeGROOVE-dev/ds9

go 1.23
go 1.25
Loading
Loading