diff --git a/internal/handlers/azdo_api.go b/internal/handlers/azdo_api.go index bef305f..f16bf21 100644 --- a/internal/handlers/azdo_api.go +++ b/internal/handlers/azdo_api.go @@ -71,7 +71,7 @@ func (h *AzureDevOpsAPIHandler) HandleRequest(req *http.Request, ctx *goproxy.Pr } logging.RequestLogf(ctx, "* authenticating azure devops api request with token for %s", host) - req.SetBasicAuth(creds[0].username, creds[0].password) + helpers.SetAuthorization(req, helpers.BasicAuth(creds[0].username, creds[0].password)) // Azure DevOps requires an api-version to be set for requests. Add it if it is not present. var queryParams = req.URL.Query() diff --git a/internal/handlers/cargo_registry.go b/internal/handlers/cargo_registry.go index b875d84..b2be65d 100644 --- a/internal/handlers/cargo_registry.go +++ b/internal/handlers/cargo_registry.go @@ -101,7 +101,7 @@ func (h *CargoRegistryHandler) HandleRequest(req *http.Request, ctx *goproxy.Pro } logging.RequestLogf(ctx, "* authenticating cargo registry request (url: %s)", cred.url) - req.Header.Set("Authorization", cred.authorization) + helpers.SetAuthorization(req, helpers.RawAuth(cred.authorization)) return req, nil } diff --git a/internal/handlers/composer.go b/internal/handlers/composer.go index 2865fad..96178f6 100644 --- a/internal/handlers/composer.go +++ b/internal/handlers/composer.go @@ -81,10 +81,10 @@ func (h *ComposerHandler) HandleRequest(req *http.Request, ctx *goproxy.ProxyCtx if cred.token != "" { logging.RequestLogf(ctx, "* authenticating composer registry request (host: %s, token auth)", req.URL.Hostname()) - req.Header.Set("Authorization", "Bearer "+cred.token) + helpers.SetAuthorization(req, helpers.BearerAuth(cred.token)) } else { logging.RequestLogf(ctx, "* authenticating composer registry request (host: %s, basic auth)", req.URL.Hostname()) - req.SetBasicAuth(cred.username, cred.password) + helpers.SetAuthorization(req, helpers.BasicAuth(cred.username, cred.password)) } return req, nil diff --git a/internal/handlers/docker_registry.go b/internal/handlers/docker_registry.go index 4d6ce1a..1b00062 100644 --- a/internal/handlers/docker_registry.go +++ b/internal/handlers/docker_registry.go @@ -120,7 +120,7 @@ func (h *DockerRegistryHandler) HandleRequest(req *http.Request, ctx *goproxy.Pr if cred.getECRCredentials(ctx) { logging.RequestLogf(ctx, "* authenticating docker ecr request (host: %s)", req.URL.Hostname()) - req.SetBasicAuth(cred.ecrUsername, cred.ecrPassword) + helpers.SetAuthorization(req, helpers.BasicAuth(cred.ecrUsername, cred.ecrPassword)) } else { logging.RequestLogf(ctx, "* authenticating docker registry request (host: %s)", req.URL.Hostname()) transport := ®istry.BasicTransport{ diff --git a/internal/handlers/git_server.go b/internal/handlers/git_server.go index 784aeb9..452a70a 100644 --- a/internal/handlers/git_server.go +++ b/internal/handlers/git_server.go @@ -282,7 +282,7 @@ func (h *GitServerHandler) HandleRequest(req *http.Request, ctx *goproxy.ProxyCt logging.RequestLogf(ctx, "* authenticating git server request (host: %s)", helpers.GetHost(req)) credsToUse := creds[0] - req.SetBasicAuth(credsToUse.username, credsToUse.password) + helpers.SetAuthorization(req, helpers.BasicAuth(credsToUse.username, credsToUse.password)) if ctx != nil { ctxdata.SetValue(ctx, addedAuthCtxKey, credsToUse) } @@ -472,7 +472,7 @@ func (h *GitServerHandler) requestWithAlternativeAuth(ctx *goproxy.ProxyCtx, bod newReq.Body = io.NopCloser(bytes.NewReader(body)) } - newReq.SetBasicAuth(creds.username, creds.password) + helpers.SetAuthorization(newReq, helpers.BasicAuth(creds.username, creds.password)) newRsp, err := ctx.RoundTrip(newReq) if err != nil { return nil diff --git a/internal/handlers/goproxy_server_handler.go b/internal/handlers/goproxy_server_handler.go index 64e6eff..e7d6632 100644 --- a/internal/handlers/goproxy_server_handler.go +++ b/internal/handlers/goproxy_server_handler.go @@ -77,7 +77,7 @@ func (h *GoProxyServerHandler) HandleRequest(req *http.Request, ctx *goproxy.Pro } logging.RequestLogf(ctx, "* authenticating goproxy request (host: %s)", req.URL.Hostname()) - req.SetBasicAuth(cred.username, cred.password) + helpers.SetAuthorization(req, helpers.BasicAuth(cred.username, cred.password)) return req, nil } diff --git a/internal/handlers/helm_registry.go b/internal/handlers/helm_registry.go index 981c6a6..8f71cab 100644 --- a/internal/handlers/helm_registry.go +++ b/internal/handlers/helm_registry.go @@ -74,7 +74,7 @@ func (h *HelmRegistryHandler) HandleRequest(req *http.Request, ctx *goproxy.Prox } logging.RequestLogf(ctx, "* authenticating helm registry request (host: %s)", req.URL.Hostname()) - req.SetBasicAuth(cred.username, cred.password) + helpers.SetAuthorization(req, helpers.BasicAuth(cred.username, cred.password)) return req, nil } diff --git a/internal/handlers/hex_organization.go b/internal/handlers/hex_organization.go index 74a8d18..174a759 100644 --- a/internal/handlers/hex_organization.go +++ b/internal/handlers/hex_organization.go @@ -69,7 +69,7 @@ func (h *HexOrganizationHandler) HandleRequest(req *http.Request, ctx *goproxy.P for _, cred := range h.credentials { if cred.organization == reqOrg { logging.RequestLogf(ctx, "* authenticating hex request (org: %s)", reqOrg) - req.Header.Set("authorization", cred.key) + helpers.SetAuthorization(req, helpers.RawAuth(cred.key)) return req, nil } } diff --git a/internal/handlers/hex_repository.go b/internal/handlers/hex_repository.go index 94c0f80..f58b135 100644 --- a/internal/handlers/hex_repository.go +++ b/internal/handlers/hex_repository.go @@ -85,7 +85,7 @@ func (h *HexRepositoryHandler) HandleRequest(req *http.Request, ctx *goproxy.Pro } logging.RequestLogf(ctx, "* authenticating hex repository request (host: %s)", req.URL.Hostname()) - req.Header.Set("authorization", cred.authKey) + helpers.SetAuthorization(req, helpers.RawAuth(cred.authKey)) return req, nil } diff --git a/internal/handlers/maven_repository.go b/internal/handlers/maven_repository.go index 40f59fe..a6c804a 100644 --- a/internal/handlers/maven_repository.go +++ b/internal/handlers/maven_repository.go @@ -79,7 +79,7 @@ func (h *MavenRepositoryHandler) HandleRequest(req *http.Request, ctx *goproxy.P } logging.RequestLogf(ctx, "* authenticating maven repository request (host: %s)", req.URL.Hostname()) - req.SetBasicAuth(cred.username, cred.password) + helpers.SetAuthorization(req, helpers.BasicAuth(cred.username, cred.password)) return req, nil } diff --git a/internal/handlers/npm_registry.go b/internal/handlers/npm_registry.go index 04b04ff..40769ec 100644 --- a/internal/handlers/npm_registry.go +++ b/internal/handlers/npm_registry.go @@ -108,10 +108,10 @@ func (h *NPMRegistryHandler) HandleRequest(req *http.Request, ctx *goproxy.Proxy username, password, found := strings.Cut(cred.token, ":") if found { logging.RequestLogf(ctx, "* authenticating npm registry request (host: %s, basic auth)", reqHost) - req.SetBasicAuth(username, password) + helpers.SetAuthorization(req, helpers.BasicAuth(username, password)) } else { logging.RequestLogf(ctx, "* authenticating npm registry request (host: %s, token auth)", reqHost) - req.Header.Set("authorization", "Bearer "+cred.token) + helpers.SetAuthorization(req, helpers.BearerAuth(cred.token)) } return req, nil } diff --git a/internal/handlers/nuget_feed.go b/internal/handlers/nuget_feed.go index 4191dc0..6296216 100644 --- a/internal/handlers/nuget_feed.go +++ b/internal/handlers/nuget_feed.go @@ -300,14 +300,14 @@ func authenticateNugetRequest(req *http.Request, cred nugetFeedCredentials, ctx username, password, found := strings.Cut(token, ":") if found { logging.RequestLogf(ctx, "* authenticating nuget feed request (host: %s, basic auth)", req.URL.Hostname()) - req.SetBasicAuth(username, password) + helpers.SetAuthorization(req, helpers.BasicAuth(username, password)) } else if token != "" { if shouldTreatTokenAsPassword(req.URL) { logging.RequestLogf(ctx, "* authenticating nuget feed request (host: %s, basic auth for Azure DevOps)", req.URL.Hostname()) - req.SetBasicAuth("", token) + helpers.SetAuthorization(req, helpers.BasicAuth("", token)) } else { logging.RequestLogf(ctx, "* authenticating nuget feed request (host: %s, bearer auth)", req.URL.Hostname()) - req.Header.Set("authorization", "Bearer "+token) + helpers.SetAuthorization(req, helpers.BearerAuth(token)) } } } diff --git a/internal/handlers/pub_repository.go b/internal/handlers/pub_repository.go index a3bae5a..d7a5c80 100644 --- a/internal/handlers/pub_repository.go +++ b/internal/handlers/pub_repository.go @@ -83,7 +83,7 @@ func (h *PubRepositoryHandler) HandleRequest(req *http.Request, ctx *goproxy.Pro } logging.RequestLogf(ctx, "* authenticating pub repository request (url: %s)", cred.url) - req.Header.Set("Authorization", "Bearer "+cred.token) + helpers.SetAuthorization(req, helpers.BearerAuth(cred.token)) return req, nil } diff --git a/internal/handlers/python_index.go b/internal/handlers/python_index.go index ef69f9a..f366088 100644 --- a/internal/handlers/python_index.go +++ b/internal/handlers/python_index.go @@ -107,7 +107,7 @@ func (h *PythonIndexHandler) HandleRequest(req *http.Request, ctx *goproxy.Proxy } // ignore `found` because it's okay for the password to be an empty string username, password, _ := strings.Cut(token, ":") - req.SetBasicAuth(username, password) + helpers.SetAuthorization(req, helpers.BasicAuth(username, password)) return req, nil } diff --git a/internal/handlers/rubygems_server.go b/internal/handlers/rubygems_server.go index cb9829a..05b5ddf 100644 --- a/internal/handlers/rubygems_server.go +++ b/internal/handlers/rubygems_server.go @@ -80,7 +80,7 @@ func (h *RubyGemsServerHandler) HandleRequest(req *http.Request, ctx *goproxy.Pr // ignore `found` because it's okay for the password to be an empty string username, password, _ := strings.Cut(cred.token, ":") - req.SetBasicAuth(username, password) + helpers.SetAuthorization(req, helpers.BasicAuth(username, password)) return req, nil } diff --git a/internal/handlers/terraform_registry.go b/internal/handlers/terraform_registry.go index df6b09c..aab7878 100644 --- a/internal/handlers/terraform_registry.go +++ b/internal/handlers/terraform_registry.go @@ -88,7 +88,7 @@ func (h *TerraformRegistryHandler) HandleRequest(request *http.Request, context } logging.RequestLogf(context, "* authenticating terraform registry request (host: %s)", request.URL.Hostname()) - request.Header.Set("Authorization", "Bearer "+cred.token) + helpers.SetAuthorization(request, helpers.BearerAuth(cred.token)) return request, nil } diff --git a/internal/helpers/helpers.go b/internal/helpers/helpers.go index cf529ae..3125e6c 100644 --- a/internal/helpers/helpers.go +++ b/internal/helpers/helpers.go @@ -1,6 +1,8 @@ package helpers import ( + "encoding/base64" + "fmt" "io" "net/http" "net/url" @@ -10,6 +12,51 @@ import ( "golang.org/x/net/idna" ) +// authorization holds a pre-formatted Authorization header value. +// Obtain instances via BasicAuth, BearerAuth, TokenAuth, or RawAuth. +type authorization struct { + value string +} + +func (a authorization) asHeader() string { return a.value } + +// BasicAuth returns an authorization for "Basic ". +func BasicAuth(username, password string) authorization { + encoded := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", username, password))) + return authorization{fmt.Sprintf("Basic %s", encoded)} +} + +// BearerAuth returns an authorization for "Bearer ". +func BearerAuth(token string) authorization { + return authorization{fmt.Sprintf("Bearer %s", token)} +} + +// TokenAuth returns an authorization for "token ", used at least by Github API. +func TokenAuth(value string) authorization { + return authorization{fmt.Sprintf("token %s", value)} +} + +// RawAuth returns an authorization whose header value is the given string as-is. +// Use only when the credential is already a fully-formed header value. +func RawAuth(value string) authorization { + return authorization{value} +} + +// SetAuthorization clears the existing authorization header on req and sets it to the value +// described by auth. The header key defaults to "Authorization" if not provided. +// +// Note: The "Authorization" header is always cleared. +func SetAuthorization(req *http.Request, auth authorization, key ...string) { + h := "Authorization" + if len(key) > 0 { + h = key[0] + } + // Clear any auth passed by Dependabot Core + req.Header.Del("Authorization") + req.Header.Del(h) + req.Header.Set(h, auth.asHeader()) +} + func CheckGitHubAPIHost(r *http.Request) bool { hostname := GetHost(r) // Check if the hostname is a GitHub API hostname and will return true diff --git a/internal/helpers/helpers_test.go b/internal/helpers/helpers_test.go index 74d59e8..f5de841 100644 --- a/internal/helpers/helpers_test.go +++ b/internal/helpers/helpers_test.go @@ -1,11 +1,159 @@ package helpers import ( + "encoding/base64" "net/http" + "net/http/httptest" "net/url" "testing" ) +// newRequest builds a GET request to the given raw URL for use in tests. +func newRequest(t *testing.T, rawURL string) *http.Request { + t.Helper() + return httptest.NewRequest(http.MethodGet, rawURL, nil) +} + +// newRequestWithAuth builds a request that already carries an Authorization header, +// simulating a client that sent credentials which should be replaced. +func newRequestWithAuth(t *testing.T, rawURL, existing string) *http.Request { + t.Helper() + req := newRequest(t, rawURL) + req.Header.Set("Authorization", existing) + return req +} + +func TestSetAuthorization_BasicAuth(t *testing.T) { + t.Run("sets correct Basic header", func(t *testing.T) { + req := newRequest(t, "https://example.com") + SetAuthorization(req, BasicAuth("user", "pass")) + + want := "Basic " + base64.StdEncoding.EncodeToString([]byte("user:pass")) + if got := req.Header.Get("Authorization"); got != want { + t.Errorf("Authorization = %q, want %q", got, want) + } + }) + + t.Run("clears pre-existing Authorization header", func(t *testing.T) { + req := newRequestWithAuth(t, "https://example.com", "Bearer old-token") + SetAuthorization(req, BasicAuth("user", "pass")) + + want := "Basic " + base64.StdEncoding.EncodeToString([]byte("user:pass")) + if got := req.Header.Get("Authorization"); got != want { + t.Errorf("Authorization = %q, want %q", got, want) + } + if vals := req.Header["Authorization"]; len(vals) != 1 { + t.Errorf("expected exactly 1 Authorization value, got %d: %v", len(vals), vals) + } + }) + + t.Run("encodes empty username correctly", func(t *testing.T) { + req := newRequest(t, "https://example.com") + SetAuthorization(req, BasicAuth("", "token")) + + want := "Basic " + base64.StdEncoding.EncodeToString([]byte(":token")) + if got := req.Header.Get("Authorization"); got != want { + t.Errorf("Authorization = %q, want %q", got, want) + } + }) +} + +func TestSetAuthorization_BearerAuth(t *testing.T) { + t.Run("sets correct Bearer header", func(t *testing.T) { + req := newRequest(t, "https://example.com") + SetAuthorization(req, BearerAuth("my-token")) + + if got := req.Header.Get("Authorization"); got != "Bearer my-token" { + t.Errorf("Authorization = %q, want %q", got, "Bearer my-token") + } + }) + + t.Run("clears pre-existing Authorization header", func(t *testing.T) { + req := newRequestWithAuth(t, "https://example.com", "Basic dXNlcjpwYXNz") + SetAuthorization(req, BearerAuth("new-token")) + + if got := req.Header.Get("Authorization"); got != "Bearer new-token" { + t.Errorf("Authorization = %q, want %q", got, "Bearer new-token") + } + if vals := req.Header["Authorization"]; len(vals) != 1 { + t.Errorf("expected exactly 1 Authorization value, got %d: %v", len(vals), vals) + } + }) +} + +func TestSetAuthorization_TokenAuth(t *testing.T) { + t.Run("sets correct token header", func(t *testing.T) { + req := newRequest(t, "https://api.github.com") + SetAuthorization(req, TokenAuth("ghp_abc123")) + + if got := req.Header.Get("Authorization"); got != "token ghp_abc123" { + t.Errorf("Authorization = %q, want %q", got, "token ghp_abc123") + } + }) + + t.Run("clears pre-existing Authorization header", func(t *testing.T) { + req := newRequestWithAuth(t, "https://api.github.com", "token old-token") + SetAuthorization(req, TokenAuth("new-token")) + + if got := req.Header.Get("Authorization"); got != "token new-token" { + t.Errorf("Authorization = %q, want %q", got, "token new-token") + } + if vals := req.Header["Authorization"]; len(vals) != 1 { + t.Errorf("expected exactly 1 Authorization value, got %d: %v", len(vals), vals) + } + }) +} + +func TestSetAuthorization_RawAuth(t *testing.T) { + t.Run("sets pre-formatted value as-is", func(t *testing.T) { + req := newRequest(t, "https://example.com") + SetAuthorization(req, RawAuth("Bearer already-formatted")) + + if got := req.Header.Get("Authorization"); got != "Bearer already-formatted" { + t.Errorf("Authorization = %q, want %q", got, "Bearer already-formatted") + } + }) + + t.Run("clears pre-existing Authorization header", func(t *testing.T) { + req := newRequestWithAuth(t, "https://example.com", "Bearer stale") + SetAuthorization(req, RawAuth("token new-raw")) + + if got := req.Header.Get("Authorization"); got != "token new-raw" { + t.Errorf("Authorization = %q, want %q", got, "token new-raw") + } + if vals := req.Header["Authorization"]; len(vals) != 1 { + t.Errorf("expected exactly 1 Authorization value, got %d: %v", len(vals), vals) + } + }) +} + +func TestSetAuthorization_CustomKey(t *testing.T) { + t.Run("sets value on custom header key", func(t *testing.T) { + req := newRequest(t, "https://cloudsmith.example.com") + SetAuthorization(req, RawAuth("my-api-key"), "X-Api-Key") + + if got := req.Header.Get("X-Api-Key"); got != "my-api-key" { + t.Errorf("X-Api-Key = %q, want %q", got, "my-api-key") + } + if got := req.Header.Get("Authorization"); got != "" { + t.Errorf("Authorization should be empty, got %q", got) + } + }) + + t.Run("clears pre-existing custom header before setting", func(t *testing.T) { + req := newRequest(t, "https://cloudsmith.example.com") + req.Header.Set("X-Api-Key", "old-key") + SetAuthorization(req, RawAuth("new-key"), "X-Api-Key") + + if got := req.Header.Get("X-Api-Key"); got != "new-key" { + t.Errorf("X-Api-Key = %q, want %q", got, "new-key") + } + if vals := req.Header["X-Api-Key"]; len(vals) != 1 { + t.Errorf("expected exactly 1 X-Api-Key value, got %d: %v", len(vals), vals) + } + }) +} + func TestUrlMatchesRequest(t *testing.T) { tests := []struct { name string diff --git a/internal/oidc/oidc_registry.go b/internal/oidc/oidc_registry.go index ce8f702..bbc0c3a 100644 --- a/internal/oidc/oidc_registry.go +++ b/internal/oidc/oidc_registry.go @@ -1,7 +1,6 @@ package oidc import ( - "fmt" "net/http" "strings" "sync" @@ -143,18 +142,18 @@ func (r *OIDCRegistry) TryAuth(req *http.Request, ctx *goproxy.ProxyCtx) bool { switch matched.parameters.(type) { case *CloudsmithOIDCParameters: logging.RequestLogf(ctx, "* authenticating request with OIDC API key (host: %s)", host) - req.Header.Set("X-Api-Key", token) + helpers.SetAuthorization(req, helpers.RawAuth(token), "X-Api-Key") case *GCPOIDCParameters: if strings.HasSuffix(host, "-docker.pkg.dev") { logging.RequestLogf(ctx, "* authenticating request with OIDC oauth2accesstoken (host: %s)", host) - req.SetBasicAuth("oauth2accesstoken", token) + helpers.SetAuthorization(req, helpers.BasicAuth("oauth2accesstoken", token)) } else { logging.RequestLogf(ctx, "* authenticating request with OIDC token (host: %s)", host) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + helpers.SetAuthorization(req, helpers.BearerAuth(token)) } default: logging.RequestLogf(ctx, "* authenticating request with OIDC token (host: %s)", host) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + helpers.SetAuthorization(req, helpers.BearerAuth(token)) } return true