Skip to content

Commit

Permalink
ociregistry/ociauth: return 403 instead of 401 after new token acquired
Browse files Browse the repository at this point in the history
Some registries will return a 401 Unauthorized error, indicating that no
valid auth credentials have been provided, even when valid auth
credentials _have_ been provided. Although this goes against [the HTTP
status code conventions](https://stackoverflow.com/a/6937030) we need to
deal with this somehow, because otherwise a client cannot distinguish
between a "bad auth credentials error" (meaning that if the user does
authenticate, the error might go away) and an "auth credentials not
valid for resource error" (meaning that the user is authenticated but
cannot access the resource despite that).

The `ociauth` package is in a unique position to be able to make this
determination because it the only place that knows that auth credentials
have been freshly acquired, therefore a subsequent 401 error is almost
certainly because the privileges were insufficient for the authenticated
user rather than because there is no authenticated user.

So, we change `ociauth` to return 403 Forbidden in this case.

Fixes cue-lang/cue#2955

Signed-off-by: Roger Peppe <rogpeppe@gmail.com>
Change-Id: Ie50bb826e266d3b26f06881d41da36f740bc43ab
Reviewed-on: https://review.gerrithub.io/c/cue-labs/oci/+/1188182
TryBot-Result: CUE porcuepine <cue.porcuepine@gmail.com>
Reviewed-by: Tianon Gravi <admwiggin@gmail.com>
Reviewed-by: Paul Jolly <paul@myitcv.io>
  • Loading branch information
rogpeppe committed Mar 28, 2024
1 parent a074250 commit 7eb5fc6
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 2 deletions.
34 changes: 32 additions & 2 deletions ociregistry/ociauth/auth.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ociauth

import (
"bytes"
"context"
"encoding/json"
"errors"
Expand Down Expand Up @@ -184,7 +185,33 @@ func (a *stdTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return nil, err
}
}
return r.transport.RoundTrip(req)
resp, err = r.transport.RoundTrip(req)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusUnauthorized {
return resp, nil
}
deniedErr := ociregistry.ErrDenied
// The server has responded with Unauthorized (401) even though we've just
// provided a token that it gave us. Treat it as Forbidden (403) instead.
// TODO include the original body/error as part of the message or message detail?
resp.Body.Close()
data, err := json.Marshal(&ociregistry.WireErrors{
Errors: []ociregistry.WireError{{
Code_: deniedErr.Code(),
Message: "unauthorized response with freshly acquired auth token",
}},
})
if err != nil {
return nil, fmt.Errorf("cannot marshal response body: %v", err)
}
resp.Header.Set("Content-Type", "application/json")
resp.ContentLength = int64(len(data))
resp.Body = io.NopCloser(bytes.NewReader(data))
resp.StatusCode = http.StatusForbidden
resp.Status = http.StatusText(resp.StatusCode)
return resp, nil
}

// setAuthorization sets up authorization on the given request using any
Expand Down Expand Up @@ -220,7 +247,10 @@ func (r *registry) setAuthorization(ctx context.Context, req *http.Request, requ

accessToken, err := r.acquireAccessToken(ctx, requiredScope, wantScope)
if err != nil {
return err
// Avoid using %w to wrap the error because we don't want the
// caller of RoundTrip (usually ociclient) to assume that the
// error applies to the target server rather than the token server.
return fmt.Errorf("cannot acquire access token: %v", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
return nil
Expand Down
72 changes: 72 additions & 0 deletions ociregistry/ociauth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,78 @@ func TestAuthNotAvailableAfterChallenge(t *testing.T) {
qt.Check(t, qt.Equals(requestCount, 1))
}

func Test401ResponseWithJustAcquiredToken(t *testing.T) {
// This tests the scenario where a server returns a 401 response
// when the client has just successfully acquired a token from
// the auth server.
//
// In this case, a "correct" server should return
// either 403 (access to the resource is forbidden because the
// client's credentials are not sufficient) or 404 (either the
// repository really doesn't exist or the credentials are insufficient
// and the server doesn't allow clients to see whether repositories
// they don't have access to might exist).
//
// However, some real-world servers instead return a 401 response
// erroneously indicating that the client needs to acquire
// authorization credentials, even though they have in fact just
// done so.
//
// As a workaround for this case, we treat the response as a 404.

testScope := ParseScope("repository:foo:pull")
authSrv := newAuthServer(t, func(req *http.Request) (any, *httpError) {
requestedScope := ParseScope(req.Form.Get("scope"))
if !runNonFatal(t, func(t testing.TB) {
qt.Assert(t, qt.DeepEquals(requestedScope, testScope))
qt.Assert(t, qt.DeepEquals(req.Form["service"], []string{"someService"}))
}) {
return nil, &httpError{
statusCode: http.StatusInternalServerError,
}
}
return &wireToken{
Token: token{requestedScope}.String(),
}, nil
})
ts := newTargetServer(t, func(req *http.Request) *httpError {
if req.Header.Get("Authorization") == "" {
return &httpError{
statusCode: http.StatusUnauthorized,
header: http.Header{
"Www-Authenticate": []string{fmt.Sprintf("Bearer realm=%q,service=someService,scope=%q", authSrv, testScope)},
},
}
}
if !runNonFatal(t, func(t testing.TB) {
qt.Assert(t, qt.DeepEquals(authScopeFromRequest(t, req), testScope))
}) {
return &httpError{
statusCode: http.StatusInternalServerError,
}
}
return &httpError{
statusCode: http.StatusUnauthorized,
header: http.Header{
"Www-Authenticate": []string{fmt.Sprintf("Bearer realm=%q,service=someService,scope=%q", authSrv, testScope)},
},
}
})
client := &http.Client{
Transport: NewStdTransport(StdTransportParams{
Config: configFunc(func(host string) (ConfigEntry, error) {
return ConfigEntry{}, nil
}),
}),
}
req, err := http.NewRequestWithContext(context.Background(), "GET", ts.String()+"/test", nil)
qt.Assert(t, qt.IsNil(err))
resp, err := client.Do(req)
qt.Assert(t, qt.IsNil(err))
defer resp.Body.Close()
qt.Assert(t, qt.Equals(resp.StatusCode, http.StatusForbidden))
}

func TestConfigHasAccessToken(t *testing.T) {
accessToken := "somevalue"
ts := newTargetServer(t, func(req *http.Request) *httpError {
Expand Down

0 comments on commit 7eb5fc6

Please sign in to comment.