Skip to content

Commit

Permalink
ociregistry/ociserver: allow customization of error responses
Browse files Browse the repository at this point in the history
Currently there is no way to fully customize the errors sent by
the `ociserver` implementation. This change allows
that possibility by providing `ociserver.Options.WriteError`,
giving the user complete flexibility in that respect.

It's still possible to access the default implementation
by invoking `ociregistry.WriteError` or `ociregistry.MarshalError`
directly, allowing a simple way of adding functionality
to the error marshaling logic without the need to rewrite the
whole thing.

Signed-off-by: Roger Peppe <rogpeppe@gmail.com>
Change-Id: Ic680283a29a29461613c8d46d9f212400922d488
Dispatch-Trailer: {"type":"trybot","CL":1192981,"patchset":1,"ref":"refs/changes/81/1192981/1","targetBranch":"main"}
  • Loading branch information
rogpeppe authored and porcuepine committed Apr 11, 2024
1 parent a39bec0 commit e27348f
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 9 deletions.
10 changes: 10 additions & 0 deletions ociregistry/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,16 @@ func (e *httpError) ResponseBody() []byte {
return e.body
}

// WriteError marshals the given error as JSON using [MarshalError] and
// then writes it to w. It returns the error returned from w.Write.
func WriteError(w http.ResponseWriter, err error) error {
data, httpStatus := MarshalError(err)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(httpStatus)
_, err = w.Write(data)
return err
}

// MarshalError marshals the given error as JSON according
// to the OCI distribution specification. It also returns
// the associated HTTP status code, or [http.StatusInternalServerError]
Expand Down
8 changes: 0 additions & 8 deletions ociregistry/ociserver/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,10 @@ package ociserver

import (
"fmt"
"net/http"

"cuelabs.dev/go/oci/ociregistry"
)

func writeError(resp http.ResponseWriter, err error) {
data, httpStatus := ociregistry.MarshalError(err)
resp.Header().Set("Content-Type", "application/json")
resp.WriteHeader(httpStatus)
resp.Write(data)
}

func withHTTPCode(statusCode int, err error) error {
return ociregistry.NewHTTPError(err, statusCode, nil, nil)
}
Expand Down
18 changes: 18 additions & 0 deletions ociregistry/ociserver/error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,24 @@ import (
"github.com/go-quicktest/qt"
)

func TestCustomErrorWriter(t *testing.T) {
// Test that if an Interface method returns an HTTPError error, the
// HTTP status code is derived from the OCI error code in preference
// to the HTTPError status code.
r := New(&ociregistry.Funcs{}, &Options{
WriteError: func(w http.ResponseWriter, err error) error {
w.Header().Set("Some-Header", "a value")
return ociregistry.WriteError(w, err)
},
})
s := httptest.NewServer(r)
defer s.Close()
resp, err := http.Get(s.URL + "/v2/foo/manifests/sometag")
qt.Assert(t, qt.IsNil(err))
defer resp.Body.Close()
qt.Assert(t, qt.Equals(resp.Header.Get("Some-Header"), "a value"))
}

func TestHTTPStatusOverriddenByErrorCode(t *testing.T) {
// Test that if an Interface method returns an HTTPError error, the
// HTTP status code is derived from the OCI error code in preference
Expand Down
11 changes: 10 additions & 1 deletion ociregistry/ociserver/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ var v2 = ocispecroot.Versioned{

// Options holds options for the server.
type Options struct {
// WriteError is used to write error responses. It is passed the
// error an API call has returned and is responsible for writing
// it to w. If WriteError is nil, [ociregistry.WriteError] will
// be used.
WriteError func(w http.ResponseWriter, err error) error

// DisableReferrersAPI, when true, causes the registry to behave as if
// it does not understand the referrers API.
DisableReferrersAPI bool
Expand Down Expand Up @@ -133,6 +139,9 @@ func New(backend ociregistry.Interface, opts *Options) http.Handler {
if r.opts.DebugID == "" {
r.opts.DebugID = fmt.Sprintf("ociserver%d", atomic.AddInt32(&debugID, 1))
}
if r.opts.WriteError == nil {
r.opts.WriteError = ociregistry.WriteError
}
return r
}

Expand Down Expand Up @@ -167,7 +176,7 @@ var handlers = []func(r *registry, ctx context.Context, w http.ResponseWriter, r

func (r *registry) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
if rerr := r.v2(resp, req); rerr != nil {
writeError(resp, rerr)
r.opts.WriteError(resp, rerr)
return
}
}
Expand Down

0 comments on commit e27348f

Please sign in to comment.