Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 22 additions & 16 deletions markers/markers.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,12 @@ func Is(err, reference error) bool {
return err == nil
}

isComparable := reflect.TypeOf(reference).Comparable()

// Direct reference comparison is the fastest, and most
// likely to be true, so do this first.
for c := err; c != nil; c = errbase.UnwrapOnce(c) {
if c == reference {
if isComparable && c == reference {
return true
}
// Compatibility with std go errors: if the error object itself
Expand Down Expand Up @@ -141,10 +143,27 @@ func If(err error, pred func(err error) (interface{}, bool)) (interface{}, bool)
// package location or a different type, ensure that
// RegisterTypeMigration() was called prior to IsAny().
func IsAny(err error, references ...error) bool {
if err == nil {
for _, refErr := range references {
if refErr == nil {
return true
}
}
// The mark-based comparison below will never match anything if
// the error is nil, so don't bother with computing the marks in
// that case. This avoids the computational expense of computing
// the reference marks upfront.
return false
}

// First try using direct reference comparison.
for c := err; ; c = errbase.UnwrapOnce(c) {
for c := err; c != nil; c = errbase.UnwrapOnce(c) {
for _, refErr := range references {
if c == refErr {
if refErr == nil {
continue
}
isComparable := reflect.TypeOf(refErr).Comparable()
if isComparable && c == refErr {
return true
}
// Compatibility with std go errors: if the error object itself
Expand All @@ -153,19 +172,6 @@ func IsAny(err error, references ...error) bool {
return true
}
}
if c == nil {
// This special case is to support a comparison to a nil
// reference.
break
}
}

if err == nil {
// The mark-based comparison below will never match anything if
// the error is nil, so don't bother with computing the marks in
// that case. This avoids the computational expense of computing
// the reference marks upfront.
return false
}

// Try harder with marks.
Expand Down
39 changes: 39 additions & 0 deletions markers/markers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -599,3 +599,42 @@ func (e *errWithIs) Is(o error) bool {
}
return false
}

func TestCompareUncomparable(t *testing.T) {
tt := testutils.T{T: t}

err1 := errors.New("hello")
var nilErr error
f := []string{"woo"}
tt.Check(markers.Is(errorUncomparable{f}, errorUncomparable{}))
tt.Check(markers.IsAny(errorUncomparable{f}, errorUncomparable{}))
tt.Check(markers.IsAny(errorUncomparable{f}, nilErr, errorUncomparable{}))
tt.Check(!markers.Is(errorUncomparable{f}, &errorUncomparable{}))
tt.Check(!markers.IsAny(errorUncomparable{f}, &errorUncomparable{}))
tt.Check(!markers.IsAny(errorUncomparable{f}, nilErr, &errorUncomparable{}))
tt.Check(markers.Is(&errorUncomparable{f}, errorUncomparable{}))
tt.Check(markers.IsAny(&errorUncomparable{f}, errorUncomparable{}))
tt.Check(markers.IsAny(&errorUncomparable{f}, nilErr, errorUncomparable{}))
tt.Check(!markers.Is(&errorUncomparable{f}, &errorUncomparable{}))
tt.Check(!markers.IsAny(&errorUncomparable{f}, &errorUncomparable{}))
tt.Check(!markers.IsAny(&errorUncomparable{f}, nilErr, &errorUncomparable{}))
tt.Check(!markers.Is(errorUncomparable{f}, err1))
tt.Check(!markers.IsAny(errorUncomparable{f}, err1))
tt.Check(!markers.IsAny(errorUncomparable{f}, nilErr, err1))
tt.Check(!markers.Is(&errorUncomparable{f}, err1))
tt.Check(!markers.IsAny(&errorUncomparable{f}, err1))
tt.Check(!markers.IsAny(&errorUncomparable{f}, nilErr, err1))
}

type errorUncomparable struct {
f []string
}

func (e errorUncomparable) Error() string {
return fmt.Sprintf("uncomparable error %d", len(e.f))
}

func (errorUncomparable) Is(target error) bool {
_, ok := target.(errorUncomparable)
return ok
}