diff --git a/ociregistry/ociauth/auth.go b/ociregistry/ociauth/auth.go index 60fc50c..87d8a8e 100644 --- a/ociregistry/ociauth/auth.go +++ b/ociregistry/ociauth/auth.go @@ -1,6 +1,7 @@ package ociauth import ( + "bytes" "context" "encoding/json" "errors" @@ -184,7 +185,30 @@ func (a *stdTransport) RoundTrip(req *http.Request) (*http.Response, error) { return nil, err } } - return r.transport.RoundTrip(req) + resp, err = r.transport.RoundTrip(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusUnauthorized { + return resp, nil + } + deniedErr := ociregistry.ErrDenied + // The server has responded with Unauthorized (401) even though we've just + // provided a token that it gave us. Treat it as Forbidden (403) instead. + // TODO include the original error as part of the message or message detail? + resp.Body.Close() + data, err := json.Marshal(&ociregistry.WireErrors{ + Errors: []ociregistry.WireError{{ + Code_: deniedErr.Code(), + Message: "unauthorized response with freshly acquired auth token", + }}, + }) + resp.Header.Set("Content-Type", "application/json") + resp.ContentLength = int64(len(data)) + resp.Body = io.NopCloser(bytes.NewReader(data)) + resp.StatusCode = http.StatusForbidden + resp.Status = http.StatusText(resp.StatusCode) + return resp, nil } // setAuthorization sets up authorization on the given request using any @@ -220,7 +244,7 @@ func (r *registry) setAuthorization(ctx context.Context, req *http.Request, requ accessToken, err := r.acquireAccessToken(ctx, requiredScope, wantScope) if err != nil { - return err + return fmt.Errorf("cannot acquire access token: %v", err) } req.Header.Set("Authorization", "Bearer "+accessToken) return nil diff --git a/ociregistry/ociauth/auth_test.go b/ociregistry/ociauth/auth_test.go index 3b9a8f3..2a32845 100644 --- a/ociregistry/ociauth/auth_test.go +++ b/ociregistry/ociauth/auth_test.go @@ -235,6 +235,78 @@ func TestAuthNotAvailableAfterChallenge(t *testing.T) { qt.Check(t, qt.Equals(requestCount, 1)) } +func Test401ResponseWithJustAcquiredToken(t *testing.T) { + // This tests the scenario where a server returns a 401 response + // when the client has just successfully acquired a token from + // the auth server. + // + // In this case, a "correct" server should return + // either 403 (access to the resource is forbidden because the + // client's credentials are not sufficient) or 404 (either the + // repository really doesn't exist or the credentials are insufficient + // and the server doesn't allow clients to see whether repositories + // they don't have access to might exist). + // + // However, some real-world servers instead return a 401 response + // erroneously indicating that the client needs to acquire + // authorization credentials, even though they have in fact just + // done so. + // + // As a workaround for this case, we treat the response as a 404. + + testScope := ParseScope("repository:foo:pull") + authSrv := newAuthServer(t, func(req *http.Request) (any, *httpError) { + requestedScope := ParseScope(req.Form.Get("scope")) + if !runNonFatal(t, func(t testing.TB) { + qt.Assert(t, qt.DeepEquals(requestedScope, testScope)) + qt.Assert(t, qt.DeepEquals(req.Form["service"], []string{"someService"})) + }) { + return nil, &httpError{ + statusCode: http.StatusInternalServerError, + } + } + return &wireToken{ + Token: token{requestedScope}.String(), + }, nil + }) + ts := newTargetServer(t, func(req *http.Request) *httpError { + if req.Header.Get("Authorization") == "" { + return &httpError{ + statusCode: http.StatusUnauthorized, + header: http.Header{ + "Www-Authenticate": []string{fmt.Sprintf("Bearer realm=%q,service=someService,scope=%q", authSrv, testScope)}, + }, + } + } + if !runNonFatal(t, func(t testing.TB) { + qt.Assert(t, qt.DeepEquals(authScopeFromRequest(t, req), testScope)) + }) { + return &httpError{ + statusCode: http.StatusInternalServerError, + } + } + return &httpError{ + statusCode: http.StatusUnauthorized, + header: http.Header{ + "Www-Authenticate": []string{fmt.Sprintf("Bearer realm=%q,service=someService,scope=%q", authSrv, testScope)}, + }, + } + }) + client := &http.Client{ + Transport: NewStdTransport(StdTransportParams{ + Config: configFunc(func(host string) (ConfigEntry, error) { + return ConfigEntry{}, nil + }), + }), + } + req, err := http.NewRequestWithContext(context.Background(), "GET", ts.String()+"/test", nil) + qt.Assert(t, qt.IsNil(err)) + resp, err := client.Do(req) + qt.Assert(t, qt.IsNil(err)) + defer resp.Body.Close() + qt.Assert(t, qt.Equals(resp.StatusCode, http.StatusForbidden)) +} + func TestConfigHasAccessToken(t *testing.T) { accessToken := "somevalue" ts := newTargetServer(t, func(req *http.Request) *httpError {