Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 67 additions & 40 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,19 @@ package errors

import (
"encoding/json"
"errors"
"fmt"
"net/http"
"reflect"
"strings"
)

// DefaultHTTPCode is used when the error Code cannot be used as an HTTP code.
//
//nolint:gochecknoglobals // it should have been a constant in the first place, but now it is mutable so we have to leave it here or introduce a breaking change.
var DefaultHTTPCode = http.StatusUnprocessableEntity

// Error represents a error interface all swagger framework errors implement
// Error represents a error interface all swagger framework errors implement.
type Error interface {
error
Code() int32
Expand All @@ -33,15 +36,15 @@ func (a *apiError) Code() int32 {
return a.code
}

// MarshalJSON implements the JSON encoding interface
// MarshalJSON implements the JSON encoding interface.
func (a apiError) MarshalJSON() ([]byte, error) {
return json.Marshal(map[string]any{
"code": a.code,
"message": a.message,
})
}

// New creates a new API error with a code and a message
// New creates a new API error with a code and a message.
func New(code int32, message string, args ...any) Error {
if len(args) > 0 {
return &apiError{
Expand All @@ -55,20 +58,20 @@ func New(code int32, message string, args ...any) Error {
}
}

// NotFound creates a new not found error
// NotFound creates a new not found error.
func NotFound(message string, args ...any) Error {
if message == "" {
message = "Not found"
}
return New(http.StatusNotFound, message, args...)
}

// NotImplemented creates a new not implemented error
// NotImplemented creates a new not implemented error.
func NotImplemented(message string) Error {
return New(http.StatusNotImplemented, "%s", message)
}

// MethodNotAllowedError represents an error for when the path matches but the method doesn't
// MethodNotAllowedError represents an error for when the path matches but the method doesn't.
type MethodNotAllowedError struct {
code int32
Allowed []string
Expand All @@ -79,12 +82,12 @@ func (m *MethodNotAllowedError) Error() string {
return m.message
}

// Code the error code
// Code the error code.
func (m *MethodNotAllowedError) Code() int32 {
return m.code
}

// MarshalJSON implements the JSON encoding interface
// MarshalJSON implements the JSON encoding interface.
func (m MethodNotAllowedError) MarshalJSON() ([]byte, error) {
return json.Marshal(map[string]any{
"code": m.code,
Expand All @@ -104,25 +107,33 @@ func errorAsJSON(err Error) []byte {

func flattenComposite(errs *CompositeError) *CompositeError {
var res []error
for _, er := range errs.Errors {
switch e := er.(type) {
case *CompositeError:
if e != nil && len(e.Errors) > 0 {
flat := flattenComposite(e)
if len(flat.Errors) > 0 {
res = append(res, flat.Errors...)
}
}
default:
if e != nil {
res = append(res, e)
}

for _, err := range errs.Errors {
if err == nil {
continue
}

e := &CompositeError{}
if !errors.As(err, &e) {
res = append(res, err)

continue
}

if len(e.Errors) == 0 {
res = append(res, e)

continue
}

flat := flattenComposite(e)
res = append(res, flat.Errors...)
}

return CompositeValidationError(res...)
}

// MethodNotAllowed creates a new method not allowed error
// MethodNotAllowed creates a new method not allowed error.
func MethodNotAllowed(requested string, allow []string) Error {
msg := fmt.Sprintf("method %s is not allowed, but [%s] are", requested, strings.Join(allow, ","))
return &MethodNotAllowedError{
Expand All @@ -132,39 +143,55 @@ func MethodNotAllowed(requested string, allow []string) Error {
}
}

// ServeError implements the http error handler interface
// ServeError implements the http error handler interface.
func ServeError(rw http.ResponseWriter, r *http.Request, err error) {
rw.Header().Set("Content-Type", "application/json")
switch e := err.(type) {
case *CompositeError:
er := flattenComposite(e)

if err == nil {
rw.WriteHeader(http.StatusInternalServerError)
_, _ = rw.Write(errorAsJSON(New(http.StatusInternalServerError, "Unknown error")))

return
}

errComposite := &CompositeError{}
errMethodNotAllowed := &MethodNotAllowedError{}
var errError Error

switch {
case errors.As(err, &errComposite):
er := flattenComposite(errComposite)
// strips composite errors to first element only
if len(er.Errors) > 0 {
ServeError(rw, r, er.Errors[0])
} else {
// guard against empty CompositeError (invalid construct)
ServeError(rw, r, nil)

return
}
case *MethodNotAllowedError:
rw.Header().Add("Allow", strings.Join(e.Allowed, ","))
rw.WriteHeader(asHTTPCode(int(e.Code())))

// guard against empty CompositeError (invalid construct)
ServeError(rw, r, nil)

case errors.As(err, &errMethodNotAllowed):
rw.Header().Add("Allow", strings.Join(errMethodNotAllowed.Allowed, ","))
rw.WriteHeader(asHTTPCode(int(errMethodNotAllowed.Code())))
if r == nil || r.Method != http.MethodHead {
_, _ = rw.Write(errorAsJSON(e))
_, _ = rw.Write(errorAsJSON(errMethodNotAllowed))
}
case Error:
value := reflect.ValueOf(e)

case errors.As(err, &errError):
value := reflect.ValueOf(errError)
if value.Kind() == reflect.Ptr && value.IsNil() {
rw.WriteHeader(http.StatusInternalServerError)
_, _ = rw.Write(errorAsJSON(New(http.StatusInternalServerError, "Unknown error")))

return
}
rw.WriteHeader(asHTTPCode(int(e.Code())))

rw.WriteHeader(asHTTPCode(int(errError.Code())))
if r == nil || r.Method != http.MethodHead {
_, _ = rw.Write(errorAsJSON(e))
_, _ = rw.Write(errorAsJSON(errError))
}
case nil:
rw.WriteHeader(http.StatusInternalServerError)
_, _ = rw.Write(errorAsJSON(New(http.StatusInternalServerError, "Unknown error")))

default:
rw.WriteHeader(http.StatusInternalServerError)
if r == nil || r.Method != http.MethodHead {
Expand Down
24 changes: 23 additions & 1 deletion api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func TestServeError(t *testing.T) {
)
})

t.Run("with composite erors", func(t *testing.T) {
t.Run("with composite errors", func(t *testing.T) {
t.Run("unrecognized - return internal error with first error only - the second error is ignored", func(t *testing.T) {
compositeErr := &CompositeError{
Errors: []error{
Expand Down Expand Up @@ -169,6 +169,28 @@ func TestServeError(t *testing.T) {
)
})

t.Run("check guard against nil members in a CompositeError", func(t *testing.T) {
compositeErr := &CompositeError{
Errors: []error{
New(600, "myApiError"),
nil,
New(601, "myOtherApiError"),
},
}
t.Run("flatten CompositeError should strip nil members", func(t *testing.T) {
flat := flattenComposite(compositeErr)
require.Len(t, flat.Errors, 2)
})

recorder := httptest.NewRecorder()
ServeError(recorder, nil, compositeErr)
assert.Equal(t, CompositeErrorCode, recorder.Code)
assert.JSONEq(t,
`{"code":600,"message":"myApiError"}`,
recorder.Body.String(),
)
})

t.Run("check guard against nil type", func(t *testing.T) {
recorder := httptest.NewRecorder()
ServeError(recorder, nil, nil)
Expand Down
2 changes: 1 addition & 1 deletion auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package errors

import "net/http"

// Unauthenticated returns an unauthenticated error
// Unauthenticated returns an unauthenticated error.
func Unauthenticated(scheme string) Error {
return New(http.StatusUnauthorized, "unauthenticated for %s", scheme)
}
14 changes: 7 additions & 7 deletions headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
"net/http"
)

// Validation represents a failure of a precondition
type Validation struct { //nolint: errname
// Validation represents a failure of a precondition.
type Validation struct { //nolint: errname // changing the name to abide by the naming rule would bring a breaking change.
code int32
Name string
In string
Expand All @@ -23,12 +23,12 @@ func (e *Validation) Error() string {
return e.message
}

// Code the error code
// Code the error code.
func (e *Validation) Code() int32 {
return e.code
}

// MarshalJSON implements the JSON encoding interface
// MarshalJSON implements the JSON encoding interface.
func (e Validation) MarshalJSON() ([]byte, error) {
return json.Marshal(map[string]any{
"code": e.code,
Expand All @@ -40,7 +40,7 @@ func (e Validation) MarshalJSON() ([]byte, error) {
})
}

// ValidateName sets the name for a validation or updates it for a nested property
// ValidateName sets the name for a validation or updates it for a nested property.
func (e *Validation) ValidateName(name string) *Validation {
if name != "" {
if e.Name == "" {
Expand All @@ -59,7 +59,7 @@ const (
responseFormatFail = `unsupported media type requested, only %v are available`
)

// InvalidContentType error for an invalid content type
// InvalidContentType error for an invalid content type.
func InvalidContentType(value string, allowed []string) *Validation {
values := make([]any, 0, len(allowed))
for _, v := range allowed {
Expand All @@ -75,7 +75,7 @@ func InvalidContentType(value string, allowed []string) *Validation {
}
}

// InvalidResponseFormat error for an unacceptable response format request
// InvalidResponseFormat error for an unacceptable response format request.
func InvalidResponseFormat(value string, allowed []string) *Validation {
values := make([]any, 0, len(allowed))
for _, v := range allowed {
Expand Down
2 changes: 1 addition & 1 deletion middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
)

// APIVerificationFailed is an error that contains all the missing info for a mismatched section
// between the api registrations and the api spec
// between the api registrations and the api spec.
type APIVerificationFailed struct { //nolint: errname
Section string `json:"section,omitempty"`
MissingSpecification []string `json:"missingSpecification,omitempty"`
Expand Down
8 changes: 4 additions & 4 deletions parsing.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"net/http"
)

// ParseError represents a parsing error
// ParseError represents a parsing error.
type ParseError struct {
code int32
Name string
Expand All @@ -19,7 +19,7 @@ type ParseError struct {
message string
}

// NewParseError creates a new parse error
// NewParseError creates a new parse error.
func NewParseError(name, in, value string, reason error) *ParseError {
var msg string
if in == "" {
Expand All @@ -41,12 +41,12 @@ func (e *ParseError) Error() string {
return e.message
}

// Code returns the http status code for this error
// Code returns the http status code for this error.
func (e *ParseError) Code() int32 {
return e.code
}

// MarshalJSON implements the JSON encoding interface
// MarshalJSON implements the JSON encoding interface.
func (e ParseError) MarshalJSON() ([]byte, error) {
var reason string
if e.Reason != nil {
Expand Down
Loading