Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/handlers/azdo_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (h *AzureDevOpsAPIHandler) HandleRequest(req *http.Request, ctx *goproxy.Pr
}

logging.RequestLogf(ctx, "* authenticating azure devops api request with token for %s", host)
req.SetBasicAuth(creds[0].username, creds[0].password)
helpers.SetAuthorization(req, helpers.BasicAuth(creds[0].username, creds[0].password))

// Azure DevOps requires an api-version to be set for requests. Add it if it is not present.
var queryParams = req.URL.Query()
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/cargo_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (h *CargoRegistryHandler) HandleRequest(req *http.Request, ctx *goproxy.Pro
}

logging.RequestLogf(ctx, "* authenticating cargo registry request (url: %s)", cred.url)
req.Header.Set("Authorization", cred.authorization)
helpers.SetAuthorization(req, helpers.RawAuth(cred.authorization))

return req, nil
}
Expand Down
4 changes: 2 additions & 2 deletions internal/handlers/composer.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ func (h *ComposerHandler) HandleRequest(req *http.Request, ctx *goproxy.ProxyCtx

if cred.token != "" {
logging.RequestLogf(ctx, "* authenticating composer registry request (host: %s, token auth)", req.URL.Hostname())
req.Header.Set("Authorization", "Bearer "+cred.token)
helpers.SetAuthorization(req, helpers.BearerAuth(cred.token))
} else {
logging.RequestLogf(ctx, "* authenticating composer registry request (host: %s, basic auth)", req.URL.Hostname())
req.SetBasicAuth(cred.username, cred.password)
helpers.SetAuthorization(req, helpers.BasicAuth(cred.username, cred.password))
}

return req, nil
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/docker_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func (h *DockerRegistryHandler) HandleRequest(req *http.Request, ctx *goproxy.Pr

if cred.getECRCredentials(ctx) {
logging.RequestLogf(ctx, "* authenticating docker ecr request (host: %s)", req.URL.Hostname())
req.SetBasicAuth(cred.ecrUsername, cred.ecrPassword)
helpers.SetAuthorization(req, helpers.BasicAuth(cred.ecrUsername, cred.ecrPassword))
} else {
logging.RequestLogf(ctx, "* authenticating docker registry request (host: %s)", req.URL.Hostname())
transport := &registry.BasicTransport{
Expand Down
4 changes: 2 additions & 2 deletions internal/handlers/git_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ func (h *GitServerHandler) HandleRequest(req *http.Request, ctx *goproxy.ProxyCt

logging.RequestLogf(ctx, "* authenticating git server request (host: %s)", helpers.GetHost(req))
credsToUse := creds[0]
req.SetBasicAuth(credsToUse.username, credsToUse.password)
helpers.SetAuthorization(req, helpers.BasicAuth(credsToUse.username, credsToUse.password))
if ctx != nil {
ctxdata.SetValue(ctx, addedAuthCtxKey, credsToUse)
}
Expand Down Expand Up @@ -472,7 +472,7 @@ func (h *GitServerHandler) requestWithAlternativeAuth(ctx *goproxy.ProxyCtx, bod
newReq.Body = io.NopCloser(bytes.NewReader(body))
}

newReq.SetBasicAuth(creds.username, creds.password)
helpers.SetAuthorization(newReq, helpers.BasicAuth(creds.username, creds.password))
newRsp, err := ctx.RoundTrip(newReq)
if err != nil {
return nil
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/goproxy_server_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (h *GoProxyServerHandler) HandleRequest(req *http.Request, ctx *goproxy.Pro
}

logging.RequestLogf(ctx, "* authenticating goproxy request (host: %s)", req.URL.Hostname())
req.SetBasicAuth(cred.username, cred.password)
helpers.SetAuthorization(req, helpers.BasicAuth(cred.username, cred.password))

return req, nil
}
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/helm_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (h *HelmRegistryHandler) HandleRequest(req *http.Request, ctx *goproxy.Prox
}

logging.RequestLogf(ctx, "* authenticating helm registry request (host: %s)", req.URL.Hostname())
req.SetBasicAuth(cred.username, cred.password)
helpers.SetAuthorization(req, helpers.BasicAuth(cred.username, cred.password))

return req, nil
}
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/hex_organization.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (h *HexOrganizationHandler) HandleRequest(req *http.Request, ctx *goproxy.P
for _, cred := range h.credentials {
if cred.organization == reqOrg {
logging.RequestLogf(ctx, "* authenticating hex request (org: %s)", reqOrg)
req.Header.Set("authorization", cred.key)
helpers.SetAuthorization(req, helpers.RawAuth(cred.key))
return req, nil
}
}
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/hex_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func (h *HexRepositoryHandler) HandleRequest(req *http.Request, ctx *goproxy.Pro
}

logging.RequestLogf(ctx, "* authenticating hex repository request (host: %s)", req.URL.Hostname())
req.Header.Set("authorization", cred.authKey)
helpers.SetAuthorization(req, helpers.RawAuth(cred.authKey))

return req, nil
}
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/maven_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func (h *MavenRepositoryHandler) HandleRequest(req *http.Request, ctx *goproxy.P
}

logging.RequestLogf(ctx, "* authenticating maven repository request (host: %s)", req.URL.Hostname())
req.SetBasicAuth(cred.username, cred.password)
helpers.SetAuthorization(req, helpers.BasicAuth(cred.username, cred.password))

return req, nil
}
Expand Down
4 changes: 2 additions & 2 deletions internal/handlers/npm_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,10 @@ func (h *NPMRegistryHandler) HandleRequest(req *http.Request, ctx *goproxy.Proxy
username, password, found := strings.Cut(cred.token, ":")
if found {
logging.RequestLogf(ctx, "* authenticating npm registry request (host: %s, basic auth)", reqHost)
req.SetBasicAuth(username, password)
helpers.SetAuthorization(req, helpers.BasicAuth(username, password))
} else {
logging.RequestLogf(ctx, "* authenticating npm registry request (host: %s, token auth)", reqHost)
req.Header.Set("authorization", "Bearer "+cred.token)
helpers.SetAuthorization(req, helpers.BearerAuth(cred.token))
}
return req, nil
}
Expand Down
6 changes: 3 additions & 3 deletions internal/handlers/nuget_feed.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,14 +300,14 @@ func authenticateNugetRequest(req *http.Request, cred nugetFeedCredentials, ctx
username, password, found := strings.Cut(token, ":")
if found {
logging.RequestLogf(ctx, "* authenticating nuget feed request (host: %s, basic auth)", req.URL.Hostname())
req.SetBasicAuth(username, password)
helpers.SetAuthorization(req, helpers.BasicAuth(username, password))
} else if token != "" {
if shouldTreatTokenAsPassword(req.URL) {
logging.RequestLogf(ctx, "* authenticating nuget feed request (host: %s, basic auth for Azure DevOps)", req.URL.Hostname())
req.SetBasicAuth("", token)
helpers.SetAuthorization(req, helpers.BasicAuth("", token))
} else {
logging.RequestLogf(ctx, "* authenticating nuget feed request (host: %s, bearer auth)", req.URL.Hostname())
req.Header.Set("authorization", "Bearer "+token)
helpers.SetAuthorization(req, helpers.BearerAuth(token))
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/pub_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func (h *PubRepositoryHandler) HandleRequest(req *http.Request, ctx *goproxy.Pro
}

logging.RequestLogf(ctx, "* authenticating pub repository request (url: %s)", cred.url)
req.Header.Set("Authorization", "Bearer "+cred.token)
helpers.SetAuthorization(req, helpers.BearerAuth(cred.token))

return req, nil
}
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/python_index.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func (h *PythonIndexHandler) HandleRequest(req *http.Request, ctx *goproxy.Proxy
}
// ignore `found` because it's okay for the password to be an empty string
username, password, _ := strings.Cut(token, ":")
req.SetBasicAuth(username, password)
helpers.SetAuthorization(req, helpers.BasicAuth(username, password))

return req, nil
}
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/rubygems_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (h *RubyGemsServerHandler) HandleRequest(req *http.Request, ctx *goproxy.Pr

// ignore `found` because it's okay for the password to be an empty string
username, password, _ := strings.Cut(cred.token, ":")
req.SetBasicAuth(username, password)
helpers.SetAuthorization(req, helpers.BasicAuth(username, password))

return req, nil
}
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/terraform_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (h *TerraformRegistryHandler) HandleRequest(request *http.Request, context
}

logging.RequestLogf(context, "* authenticating terraform registry request (host: %s)", request.URL.Hostname())
request.Header.Set("Authorization", "Bearer "+cred.token)
helpers.SetAuthorization(request, helpers.BearerAuth(cred.token))
return request, nil
}

Expand Down
47 changes: 47 additions & 0 deletions internal/helpers/helpers.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package helpers

import (
"encoding/base64"
"fmt"
"io"
"net/http"
"net/url"
Expand All @@ -10,6 +12,51 @@ import (
"golang.org/x/net/idna"
)

// authorization holds a pre-formatted Authorization header value.
// Obtain instances via BasicAuth, BearerAuth, TokenAuth, or RawAuth.
type authorization struct {
value string
}

func (a authorization) asHeader() string { return a.value }

// BasicAuth returns an authorization for "Basic <base64(username:password)>".
func BasicAuth(username, password string) authorization {
encoded := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", username, password)))
return authorization{fmt.Sprintf("Basic %s", encoded)}
}
Comment on lines +23 to +27

// BearerAuth returns an authorization for "Bearer <token>".
func BearerAuth(token string) authorization {
return authorization{fmt.Sprintf("Bearer %s", token)}
}

// TokenAuth returns an authorization for "token <value>", used at least by Github API.
func TokenAuth(value string) authorization {
return authorization{fmt.Sprintf("token %s", value)}
}

// RawAuth returns an authorization whose header value is the given string as-is.
// Use only when the credential is already a fully-formed header value.
func RawAuth(value string) authorization {
return authorization{value}
}

// SetAuthorization clears the existing authorization header on req and sets it to the value
// described by auth. The header key defaults to "Authorization" if not provided.
//
// Note: The "Authorization" header is always cleared.
func SetAuthorization(req *http.Request, auth authorization, key ...string) {
h := "Authorization"
if len(key) > 0 {
Comment on lines +49 to +51
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not too familiar on how to cleanly do optional arguments in golang. Would appreciate feedback.

h = key[0]
}
// Clear any auth passed by Dependabot Core
req.Header.Del("Authorization")
req.Header.Del(h)
req.Header.Set(h, auth.asHeader())
}
Comment thread
joniumGit marked this conversation as resolved.

func CheckGitHubAPIHost(r *http.Request) bool {
hostname := GetHost(r)
// Check if the hostname is a GitHub API hostname and will return true
Expand Down
148 changes: 148 additions & 0 deletions internal/helpers/helpers_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,159 @@
package helpers

import (
"encoding/base64"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)

// newRequest builds a GET request to the given raw URL for use in tests.
func newRequest(t *testing.T, rawURL string) *http.Request {
t.Helper()
return httptest.NewRequest(http.MethodGet, rawURL, nil)
}

// newRequestWithAuth builds a request that already carries an Authorization header,
// simulating a client that sent credentials which should be replaced.
func newRequestWithAuth(t *testing.T, rawURL, existing string) *http.Request {
t.Helper()
req := newRequest(t, rawURL)
req.Header.Set("Authorization", existing)
return req
}

func TestSetAuthorization_BasicAuth(t *testing.T) {
t.Run("sets correct Basic header", func(t *testing.T) {
req := newRequest(t, "https://example.com")
SetAuthorization(req, BasicAuth("user", "pass"))

want := "Basic " + base64.StdEncoding.EncodeToString([]byte("user:pass"))
if got := req.Header.Get("Authorization"); got != want {
t.Errorf("Authorization = %q, want %q", got, want)
}
})

t.Run("clears pre-existing Authorization header", func(t *testing.T) {
req := newRequestWithAuth(t, "https://example.com", "Bearer old-token")
SetAuthorization(req, BasicAuth("user", "pass"))

want := "Basic " + base64.StdEncoding.EncodeToString([]byte("user:pass"))
if got := req.Header.Get("Authorization"); got != want {
t.Errorf("Authorization = %q, want %q", got, want)
}
if vals := req.Header["Authorization"]; len(vals) != 1 {
t.Errorf("expected exactly 1 Authorization value, got %d: %v", len(vals), vals)
}
})

t.Run("encodes empty username correctly", func(t *testing.T) {
req := newRequest(t, "https://example.com")
SetAuthorization(req, BasicAuth("", "token"))

want := "Basic " + base64.StdEncoding.EncodeToString([]byte(":token"))
if got := req.Header.Get("Authorization"); got != want {
t.Errorf("Authorization = %q, want %q", got, want)
}
})
}

func TestSetAuthorization_BearerAuth(t *testing.T) {
t.Run("sets correct Bearer header", func(t *testing.T) {
req := newRequest(t, "https://example.com")
SetAuthorization(req, BearerAuth("my-token"))

if got := req.Header.Get("Authorization"); got != "Bearer my-token" {
t.Errorf("Authorization = %q, want %q", got, "Bearer my-token")
}
})

t.Run("clears pre-existing Authorization header", func(t *testing.T) {
req := newRequestWithAuth(t, "https://example.com", "Basic dXNlcjpwYXNz")
SetAuthorization(req, BearerAuth("new-token"))

if got := req.Header.Get("Authorization"); got != "Bearer new-token" {
t.Errorf("Authorization = %q, want %q", got, "Bearer new-token")
}
if vals := req.Header["Authorization"]; len(vals) != 1 {
t.Errorf("expected exactly 1 Authorization value, got %d: %v", len(vals), vals)
}
})
}

func TestSetAuthorization_TokenAuth(t *testing.T) {
t.Run("sets correct token header", func(t *testing.T) {
req := newRequest(t, "https://api.github.com")
SetAuthorization(req, TokenAuth("ghp_abc123"))

if got := req.Header.Get("Authorization"); got != "token ghp_abc123" {
t.Errorf("Authorization = %q, want %q", got, "token ghp_abc123")
}
})

t.Run("clears pre-existing Authorization header", func(t *testing.T) {
req := newRequestWithAuth(t, "https://api.github.com", "token old-token")
SetAuthorization(req, TokenAuth("new-token"))

if got := req.Header.Get("Authorization"); got != "token new-token" {
t.Errorf("Authorization = %q, want %q", got, "token new-token")
}
if vals := req.Header["Authorization"]; len(vals) != 1 {
t.Errorf("expected exactly 1 Authorization value, got %d: %v", len(vals), vals)
}
})
}

func TestSetAuthorization_RawAuth(t *testing.T) {
t.Run("sets pre-formatted value as-is", func(t *testing.T) {
req := newRequest(t, "https://example.com")
SetAuthorization(req, RawAuth("Bearer already-formatted"))

if got := req.Header.Get("Authorization"); got != "Bearer already-formatted" {
t.Errorf("Authorization = %q, want %q", got, "Bearer already-formatted")
}
})

t.Run("clears pre-existing Authorization header", func(t *testing.T) {
req := newRequestWithAuth(t, "https://example.com", "Bearer stale")
SetAuthorization(req, RawAuth("token new-raw"))

if got := req.Header.Get("Authorization"); got != "token new-raw" {
t.Errorf("Authorization = %q, want %q", got, "token new-raw")
}
if vals := req.Header["Authorization"]; len(vals) != 1 {
t.Errorf("expected exactly 1 Authorization value, got %d: %v", len(vals), vals)
}
})
}

func TestSetAuthorization_CustomKey(t *testing.T) {
t.Run("sets value on custom header key", func(t *testing.T) {
req := newRequest(t, "https://cloudsmith.example.com")
SetAuthorization(req, RawAuth("my-api-key"), "X-Api-Key")

if got := req.Header.Get("X-Api-Key"); got != "my-api-key" {
t.Errorf("X-Api-Key = %q, want %q", got, "my-api-key")
}
if got := req.Header.Get("Authorization"); got != "" {
t.Errorf("Authorization should be empty, got %q", got)
}
})

t.Run("clears pre-existing custom header before setting", func(t *testing.T) {
req := newRequest(t, "https://cloudsmith.example.com")
req.Header.Set("X-Api-Key", "old-key")
SetAuthorization(req, RawAuth("new-key"), "X-Api-Key")

if got := req.Header.Get("X-Api-Key"); got != "new-key" {
t.Errorf("X-Api-Key = %q, want %q", got, "new-key")
}
if vals := req.Header["X-Api-Key"]; len(vals) != 1 {
t.Errorf("expected exactly 1 X-Api-Key value, got %d: %v", len(vals), vals)
}
})
}

func TestUrlMatchesRequest(t *testing.T) {
tests := []struct {
name string
Expand Down
Loading
Loading