Skip to content

Commit

Permalink
feat!: Make IsAggregate work with Go-wrapped errors (#97)
Browse files Browse the repository at this point in the history
Originally missed from #95.

I've done a few other tweaks to existing code, but only is IsAggregate and
Errors are visibly-changed (other methods may be slightly more performant).

I've also added "IsError" tests for aggregated errors and fixed an issue where
our custom IsNotFound logic didn't call `As(any) bool`.
  • Loading branch information
codingllama committed Aug 4, 2023
1 parent a6ba0d5 commit 1cff453
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 55 deletions.
51 changes: 29 additions & 22 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,11 @@ import (
"os"
)

// traceDepth is the depth to be used by error constructors.
const traceDepth = 2

// NotFound returns new instance of not found error
func NotFound(message string, args ...interface{}) Error {
return newTrace(&NotFoundError{
Message: fmt.Sprintf(message, args...),
}, traceDepth)
})
}

// NotFoundError indicates that object has not been found
Expand Down Expand Up @@ -76,12 +73,22 @@ func (e *NotFoundError) Is(target error) bool {
// IsNotFound returns true if `e` contains a [NotFoundError] in its chain.
func IsNotFound(e error) bool {
for e != nil {
if _, ok := e.(*NotFoundError); ok {
switch e := e.(type) {
case *NotFoundError:
return true

// Aggregates and other errors.
case interface{ As(interface{}) bool }:
nfe := &NotFoundError{}
if e.As(&nfe) {
return true
}
}

if os.IsNotExist(e) {
return true
}

e = errors.Unwrap(e)
}
return false
Expand All @@ -91,7 +98,7 @@ func IsNotFound(e error) bool {
func AlreadyExists(message string, args ...interface{}) Error {
return newTrace(&AlreadyExistsError{
Message: fmt.Sprintf(message, args...),
}, traceDepth)
})
}

// AlreadyExistsError indicates that there's a duplicate object that already
Expand Down Expand Up @@ -139,7 +146,7 @@ func IsAlreadyExists(e error) bool {
func BadParameter(message string, args ...interface{}) Error {
return newTrace(&BadParameterError{
Message: fmt.Sprintf(message, args...),
}, traceDepth)
})
}

// BadParameterError indicates that something is wrong with passed
Expand Down Expand Up @@ -184,7 +191,7 @@ func IsBadParameter(e error) bool {
func NotImplemented(message string, args ...interface{}) Error {
return newTrace(&NotImplementedError{
Message: fmt.Sprintf(message, args...),
}, traceDepth)
})
}

// NotImplementedError defines an error condition to describe the result
Expand Down Expand Up @@ -229,7 +236,7 @@ func IsNotImplemented(e error) bool {
func CompareFailed(message string, args ...interface{}) Error {
return newTrace(&CompareFailedError{
Message: fmt.Sprintf(message, args...),
}, traceDepth)
})
}

// CompareFailedError indicates a failed comparison (e.g. bad password or hash)
Expand Down Expand Up @@ -277,7 +284,7 @@ func IsCompareFailed(e error) bool {
func AccessDenied(message string, args ...interface{}) Error {
return newTrace(&AccessDeniedError{
Message: fmt.Sprintf(message, args...),
}, traceDepth)
})
}

// AccessDeniedError indicates denied access
Expand Down Expand Up @@ -328,35 +335,35 @@ func ConvertSystemError(err error) error {
if os.IsExist(innerError) {
return newTrace(&AlreadyExistsError{
Message: innerError.Error(),
}, traceDepth)
})
}
if os.IsNotExist(innerError) {
return newTrace(&NotFoundError{
Message: innerError.Error(),
}, traceDepth)
})
}
if os.IsPermission(innerError) {
return newTrace(&AccessDeniedError{
Message: innerError.Error(),
}, traceDepth)
})
}
switch realErr := innerError.(type) {
case *net.OpError:
return newTrace(&ConnectionProblemError{
Err: realErr,
}, traceDepth)
})
case *os.PathError:
message := fmt.Sprintf("failed to execute command %v error: %v", realErr.Path, realErr.Err)
return newTrace(&AccessDeniedError{
Message: message,
}, traceDepth)
})
case x509.SystemRootsError, x509.UnknownAuthorityError:
return newTrace(&TrustError{Err: innerError}, traceDepth)
return newTrace(&TrustError{Err: innerError})
}
if _, ok := innerError.(net.Error); ok {
return newTrace(&ConnectionProblemError{
Err: innerError,
}, traceDepth)
})
}
return err
}
Expand All @@ -366,7 +373,7 @@ func ConnectionProblem(err error, message string, args ...interface{}) Error {
return newTrace(&ConnectionProblemError{
Message: fmt.Sprintf(message, args...),
Err: err,
}, traceDepth)
})
}

// ConnectionProblemError indicates a network related problem
Expand Down Expand Up @@ -422,7 +429,7 @@ func IsConnectionProblem(e error) bool {
func LimitExceeded(message string, args ...interface{}) Error {
return newTrace(&LimitExceededError{
Message: fmt.Sprintf(message, args...),
}, traceDepth)
})
}

// LimitExceededError indicates rate limit or connection limit problem
Expand Down Expand Up @@ -467,7 +474,7 @@ func Trust(err error, message string, args ...interface{}) Error {
return newTrace(&TrustError{
Message: fmt.Sprintf(message, args...),
Err: err,
}, traceDepth)
})
}

// TrustError indicates trust-related validation error (e.g. untrusted cert)
Expand Down Expand Up @@ -525,7 +532,7 @@ func OAuth2(code, message string, query url.Values) Error {
Code: code,
Message: message,
Query: query,
}, traceDepth)
})
}

// OAuth2Error defined an error used in OpenID Connect Flow (OIDC)
Expand Down Expand Up @@ -592,7 +599,7 @@ func Retry(err error, message string, args ...interface{}) Error {
return newTrace(&RetryError{
Message: fmt.Sprintf(message, args...),
Err: err,
}, traceDepth)
})
}

// RetryError indicates a transient error type
Expand Down
3 changes: 2 additions & 1 deletion errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -471,10 +471,11 @@ func TestGoErrorWrap_IsError_allTypes(t *testing.T) {
err2 := Wrap(err1)
err3 := fmt.Errorf("go wrap: %w", err1)
err4 := fmt.Errorf("go plus trace wrap: %w", err2)
err5 := NewAggregate(errors.New("some other error"), err4)
errUnrelated := fmt.Errorf("go wrap: %w", Wrap(errors.New("unrelated")))

// Verify positive matches.
for _, testErr := range []error{err1, err2, err3, err4} {
for _, testErr := range []error{err1, err2, err3, err4, err5} {
if !test.isError(testErr) {
t.Errorf("Is%v failed, err=%#v", test.name, testErr)
}
Expand Down
48 changes: 29 additions & 19 deletions trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func Wrap(err error, args ...interface{}) Error {
if traceErr, ok := err.(Error); ok {
trace = traceErr
} else {
trace = newTrace(err, 2)
trace = newTrace(err)
}
if len(args) > 0 {
trace = WithUserMessage(trace, args[0], args[1:]...)
Expand Down Expand Up @@ -178,7 +178,7 @@ func WrapWithMessage(err error, message interface{}, args ...interface{}) Error
if traceErr, ok := err.(Error); ok {
trace = traceErr
} else {
trace = newTrace(err, 2)
trace = newTrace(err)
}
return WithUserMessage(trace, message, args...)
}
Expand All @@ -188,7 +188,7 @@ func WrapWithMessage(err error, message interface{}, args ...interface{}) Error
// callee, line number and function that simplifies debugging
func Errorf(format string, args ...interface{}) (err error) {
err = fmt.Errorf(format, args...)
return newTrace(err, 2)
return newTrace(err)
}

// Fatalf - If debug is false Fatalf calls Errorf. If debug is
Expand All @@ -201,7 +201,15 @@ func Fatalf(format string, args ...interface{}) error {
}
}

func newTrace(err error, depth int) *TraceErr {
func newTrace(err error) *TraceErr {
// newTrace does not call newTraceWithDepth so the depth value is consistent
// between both methods.
const depth = 2
traces := internal.CaptureTraces(depth)
return &TraceErr{Err: err, Traces: traces}
}

func newTraceWithDepth(err error, depth int) *TraceErr {
traces := internal.CaptureTraces(depth)
return &TraceErr{Err: err, Traces: traces}
}
Expand Down Expand Up @@ -427,8 +435,7 @@ func WithFields(err Error, fields map[string]interface{}) *TraceErr {
// NewAggregate creates a new aggregate instance from the specified
// list of errors
func NewAggregate(errs ...error) error {
// filter out possible nil values
var nonNils []error
nonNils := make([]error, 0, len(errs))
for _, err := range errs {
if err != nil {
nonNils = append(nonNils, err)
Expand All @@ -437,7 +444,7 @@ func NewAggregate(errs ...error) error {
if len(nonNils) == 0 {
return nil
}
return newTrace(aggregate(nonNils), 2)
return newTrace(aggregate(nonNils))
}

// NewAggregateFromChannel creates a new aggregate instance from the provided
Expand Down Expand Up @@ -476,14 +483,14 @@ type aggregate []error

// Error implements the error interface
func (r aggregate) Error() string {
if len(r) == 0 {
return ""
}
output := r[0].Error()
for i := 1; i < len(r); i++ {
output = fmt.Sprintf("%v, %v", output, r[i])
buf := &strings.Builder{}
for i, e := range r {
if i > 0 {
buf.WriteString(", ")
}
buf.WriteString(e.Error())
}
return output
return buf.String()
}

// Is implements the `Is` interface, by iterating through each error in the
Expand All @@ -510,13 +517,16 @@ func (r aggregate) As(t interface{}) bool {

// Errors obtains the list of errors this aggregate combines
func (r aggregate) Errors() []error {
return []error(r)
cp := make([]error, len(r))
copy(cp, r)
return cp
}

// IsAggregate returns whether this error of Aggregate error type
// IsAggregate returns true if `err` contains an [Aggregate] error in its
// chain.
func IsAggregate(err error) bool {
_, ok := Unwrap(err).(Aggregate)
return ok
var other Aggregate
return errors.As(err, &other)
}

// wrapProxy wraps the specified error as a new error trace
Expand All @@ -526,7 +536,7 @@ func wrapProxy(err error) Error {
}
return proxyError{
// Do not include ReadError in the trace
TraceErr: newTrace(err, 3),
TraceErr: newTraceWithDepth(err, 3),
}
}

Expand Down
Loading

0 comments on commit 1cff453

Please sign in to comment.