diff --git a/ociregistry/ociauth/auth.go b/ociregistry/ociauth/auth.go index d30b1e59..955356c2 100644 --- a/ociregistry/ociauth/auth.go +++ b/ociregistry/ociauth/auth.go @@ -168,7 +168,7 @@ func (a *stdTransport) RoundTrip(req *http.Request) (*http.Response, error) { if challenge == nil { return resp, nil } - authAdded, err := r.setAuthorizationFromChallenge(ctx, req, challenge, requiredScope, wantScope) + authAdded, tokenAcquired, err := r.setAuthorizationFromChallenge(ctx, req, challenge, requiredScope, wantScope) if err != nil { resp.Body.Close() return nil, err @@ -189,17 +189,16 @@ func (a *stdTransport) RoundTrip(req *http.Request) (*http.Response, error) { if err != nil { return nil, err } - if resp.StatusCode != http.StatusUnauthorized { + if resp.StatusCode != http.StatusUnauthorized || !tokenAcquired { return resp, nil } - deniedErr := ociregistry.ErrDenied // The server has responded with Unauthorized (401) even though we've just // provided a token that it gave us. Treat it as Forbidden (403) instead. // TODO include the original body/error as part of the message or message detail? resp.Body.Close() data, err := json.Marshal(&ociregistry.WireErrors{ Errors: []ociregistry.WireError{{ - Code_: deniedErr.Code(), + Code_: ociregistry.ErrDenied.Code(), Message: "unauthorized response with freshly acquired auth token", }}, }) @@ -262,7 +261,7 @@ func (r *registry) setAuthorization(ctx context.Context, req *http.Request, requ return nil } -func (r *registry) setAuthorizationFromChallenge(ctx context.Context, req *http.Request, challenge *authHeader, requiredScope, wantScope Scope) (bool, error) { +func (r *registry) setAuthorizationFromChallenge(ctx context.Context, req *http.Request, challenge *authHeader, requiredScope, wantScope Scope) (authAdded, tokenAcquired bool, _ error) { r.mu.Lock() defer r.mu.Unlock() r.wwwAuthenticate = challenge @@ -272,15 +271,15 @@ func (r *registry) setAuthorizationFromChallenge(ctx context.Context, req *http. scope := ParseScope(r.wwwAuthenticate.params["scope"]) accessToken, err := r.acquireAccessToken(ctx, scope, wantScope.Union(requiredScope)) if err != nil { - return false, err + return false, false, err } req.Header.Set("Authorization", "Bearer "+accessToken) - return true, nil + return true, true, nil case r.basic != nil: req.SetBasicAuth(r.basic.username, r.basic.password) - return true, nil + return true, false, nil } - return false, nil + return false, false, nil } // init initializes the registry instance by acquiring auth information from diff --git a/ociregistry/ociauth/auth_test.go b/ociregistry/ociauth/auth_test.go index 2a328455..f997399f 100644 --- a/ociregistry/ociauth/auth_test.go +++ b/ociregistry/ociauth/auth_test.go @@ -307,6 +307,52 @@ func Test401ResponseWithJustAcquiredToken(t *testing.T) { qt.Assert(t, qt.Equals(resp.StatusCode, http.StatusForbidden)) } +func Test401ResponseWithNonAcquiredToken(t *testing.T) { + // This tests the scenario where a server returns a 401 response + // when the client has provided credentials already present in + // the configuration file. + // + // In this case, we don't want to trigger the fake-403-response + // behaviour test for in Test401ResponseWithJustAcquiredToken. + + ts := newTargetServer(t, func(req *http.Request) *httpError { + if req.Header.Get("Authorization") == "" { + return &httpError{ + statusCode: http.StatusUnauthorized, + header: http.Header{ + "Www-Authenticate": []string{"Basic"}, + }, + body: "no auth creds provided", + } + } + return &httpError{ + statusCode: http.StatusUnauthorized, + header: http.Header{ + "Www-Authenticate": []string{"Basic"}, + }, + body: "password mismatch", + } + }) + client := &http.Client{ + Transport: NewStdTransport(StdTransportParams{ + Config: configFunc(func(host string) (ConfigEntry, error) { + return ConfigEntry{ + Username: "someuser", + Password: "somepassword", + }, nil + }), + }), + } + req, err := http.NewRequestWithContext(context.Background(), "GET", ts.String()+"/test", nil) + qt.Assert(t, qt.IsNil(err)) + resp, err := client.Do(req) + qt.Assert(t, qt.IsNil(err)) + defer resp.Body.Close() + data, _ := io.ReadAll(resp.Body) + qt.Assert(t, qt.Equals(resp.StatusCode, http.StatusUnauthorized)) + qt.Assert(t, qt.Equals(string(data), "password mismatch")) +} + func TestConfigHasAccessToken(t *testing.T) { accessToken := "somevalue" ts := newTargetServer(t, func(req *http.Request) *httpError {