From 509daddce11c4769d4d6eae34d7ec6afef8d81d5 Mon Sep 17 00:00:00 2001 From: Raphael 'kena' Poss Date: Tue, 20 Sep 2022 16:11:00 +0200 Subject: [PATCH] markers: avoid panic on non-comparable structs If an error struct implements `error` by value and the struct is incomparable, the previous implementation of `Is` would panic. This patch fixes it. Inspired from https://go-review.googlesource.com/c/go/+/175260 --- markers/markers.go | 38 ++++++++++++++++++++++---------------- markers/markers_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 16 deletions(-) diff --git a/markers/markers.go b/markers/markers.go index f718d10..d270125 100644 --- a/markers/markers.go +++ b/markers/markers.go @@ -46,10 +46,12 @@ func Is(err, reference error) bool { return err == nil } + isComparable := reflect.TypeOf(reference).Comparable() + // Direct reference comparison is the fastest, and most // likely to be true, so do this first. for c := err; c != nil; c = errbase.UnwrapOnce(c) { - if c == reference { + if isComparable && c == reference { return true } // Compatibility with std go errors: if the error object itself @@ -141,10 +143,27 @@ func If(err error, pred func(err error) (interface{}, bool)) (interface{}, bool) // package location or a different type, ensure that // RegisterTypeMigration() was called prior to IsAny(). func IsAny(err error, references ...error) bool { + if err == nil { + for _, refErr := range references { + if refErr == nil { + return true + } + } + // The mark-based comparison below will never match anything if + // the error is nil, so don't bother with computing the marks in + // that case. This avoids the computational expense of computing + // the reference marks upfront. + return false + } + // First try using direct reference comparison. - for c := err; ; c = errbase.UnwrapOnce(c) { + for c := err; c != nil; c = errbase.UnwrapOnce(c) { for _, refErr := range references { - if c == refErr { + if refErr == nil { + continue + } + isComparable := reflect.TypeOf(refErr).Comparable() + if isComparable && c == refErr { return true } // Compatibility with std go errors: if the error object itself @@ -153,19 +172,6 @@ func IsAny(err error, references ...error) bool { return true } } - if c == nil { - // This special case is to support a comparison to a nil - // reference. - break - } - } - - if err == nil { - // The mark-based comparison below will never match anything if - // the error is nil, so don't bother with computing the marks in - // that case. This avoids the computational expense of computing - // the reference marks upfront. - return false } // Try harder with marks. diff --git a/markers/markers_test.go b/markers/markers_test.go index 773d994..28ecf13 100644 --- a/markers/markers_test.go +++ b/markers/markers_test.go @@ -599,3 +599,42 @@ func (e *errWithIs) Is(o error) bool { } return false } + +func TestCompareUncomparable(t *testing.T) { + tt := testutils.T{T: t} + + err1 := errors.New("hello") + var nilErr error + f := []string{"woo"} + tt.Check(markers.Is(errorUncomparable{f}, errorUncomparable{})) + tt.Check(markers.IsAny(errorUncomparable{f}, errorUncomparable{})) + tt.Check(markers.IsAny(errorUncomparable{f}, nilErr, errorUncomparable{})) + tt.Check(!markers.Is(errorUncomparable{f}, &errorUncomparable{})) + tt.Check(!markers.IsAny(errorUncomparable{f}, &errorUncomparable{})) + tt.Check(!markers.IsAny(errorUncomparable{f}, nilErr, &errorUncomparable{})) + tt.Check(markers.Is(&errorUncomparable{f}, errorUncomparable{})) + tt.Check(markers.IsAny(&errorUncomparable{f}, errorUncomparable{})) + tt.Check(markers.IsAny(&errorUncomparable{f}, nilErr, errorUncomparable{})) + tt.Check(!markers.Is(&errorUncomparable{f}, &errorUncomparable{})) + tt.Check(!markers.IsAny(&errorUncomparable{f}, &errorUncomparable{})) + tt.Check(!markers.IsAny(&errorUncomparable{f}, nilErr, &errorUncomparable{})) + tt.Check(!markers.Is(errorUncomparable{f}, err1)) + tt.Check(!markers.IsAny(errorUncomparable{f}, err1)) + tt.Check(!markers.IsAny(errorUncomparable{f}, nilErr, err1)) + tt.Check(!markers.Is(&errorUncomparable{f}, err1)) + tt.Check(!markers.IsAny(&errorUncomparable{f}, err1)) + tt.Check(!markers.IsAny(&errorUncomparable{f}, nilErr, err1)) +} + +type errorUncomparable struct { + f []string +} + +func (e errorUncomparable) Error() string { + return fmt.Sprintf("uncomparable error %d", len(e.f)) +} + +func (errorUncomparable) Is(target error) bool { + _, ok := target.(errorUncomparable) + return ok +}