From 35263d8babc7ae7fefa6f2170d4288f693c20e66 Mon Sep 17 00:00:00 2001 From: Jeff Swenson Date: Fri, 3 Oct 2025 15:00:54 -0400 Subject: [PATCH 1/2] benchmark: add a benchmark for the errors package This is a benchmark created by @rafi. It demonstrates `errors.Is` is very inefficient when the reference error does not match the input error. --- benchmark_test.go | 181 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 181 insertions(+) create mode 100644 benchmark_test.go diff --git a/benchmark_test.go b/benchmark_test.go new file mode 100644 index 0000000..0c403fc --- /dev/null +++ b/benchmark_test.go @@ -0,0 +1,181 @@ +package errors_test + +import ( + "context" + "fmt" + "net" + "testing" + + "github.com/cockroachdb/errors" +) + +func BenchmarkErrorsIs(b *testing.B) { + b.Run("NilError", func(b *testing.B) { + var err error + for range b.N { + errors.Is(err, context.Canceled) + } + }) + + b.Run("SimpleError", func(b *testing.B) { + err := errors.New("test") + for range b.N { + errors.Is(err, context.Canceled) + } + }) + + b.Run("WrappedError", func(b *testing.B) { + baseErr := errors.New("test") + err := errors.Wrap(baseErr, "wrapped error") + for range b.N { + errors.Is(err, context.Canceled) + } + }) + + b.Run("WrappedWithStack", func(b *testing.B) { + baseErr := errors.New("test") + err := errors.WithStack(baseErr) + for range b.N { + errors.Is(err, context.Canceled) + } + }) + + b.Run("NetworkError", func(b *testing.B) { + netErr := &net.OpError{ + Op: "dial", + Net: "tcp", + Addr: &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 26257}, + Err: fmt.Errorf("connection refused"), + } + err := errors.Wrap(netErr, "network connection failed") + for range b.N { + errors.Is(err, context.Canceled) + } + }) + + b.Run("DeeplyWrappedNetworkError", func(b *testing.B) { + netErr := &net.OpError{ + Op: "dial", + Net: "tcp", + Addr: &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 26257}, + Err: fmt.Errorf("connection refused"), + } + err := errors.WithStack(netErr) + err = errors.Wrap(err, "failed to connect to database") + err = errors.Wrap(err, "unable to establish connection") + err = errors.WithStack(err) + for range b.N { + errors.Is(err, context.Canceled) + } + }) + + b.Run("MultipleWrappedErrors", func(b *testing.B) { + baseErr := errors.New("internal error") + err := errors.WithStack(baseErr) + err = errors.Wrap(err, "operation failed") + err = errors.WithStack(err) + err = errors.Wrap(err, "transaction failed") + err = errors.WithStack(err) + for range b.N { + errors.Is(err, context.Canceled) + } + }) + + b.Run("NetworkErrorWithLongAddress", func(b *testing.B) { + netErr := &net.OpError{ + Op: "read", + Net: "tcp", + Addr: &net.TCPAddr{ + IP: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), + Port: 26257, + }, + Err: fmt.Errorf("i/o timeout"), + } + err := errors.WithStack(netErr) + err = errors.Wrap(err, "failed to read from connection") + for range b.N { + errors.Is(err, context.Canceled) + } + }) + + b.Run("WithMessage", func(b *testing.B) { + baseErr := errors.New("test") + err := errors.WithMessage(baseErr, "additional context") + for range b.N { + errors.Is(err, context.Canceled) + } + }) + + b.Run("MultipleWithMessage", func(b *testing.B) { + baseErr := errors.New("internal error") + err := errors.WithMessage(baseErr, "first message") + err = errors.WithMessage(err, "second message") + err = errors.WithMessage(err, "third message") + for range b.N { + errors.Is(err, context.Canceled) + } + }) + + b.Run("WithMessageAndStack", func(b *testing.B) { + baseErr := errors.New("test") + err := errors.WithStack(baseErr) + err = errors.WithMessage(err, "operation context") + err = errors.WithStack(err) + for range b.N { + errors.Is(err, context.Canceled) + } + }) + + b.Run("NetworkErrorWithMessage", func(b *testing.B) { + netErr := &net.OpError{ + Op: "dial", + Net: "tcp", + Addr: &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 26257}, + Err: fmt.Errorf("connection refused"), + } + err := errors.WithMessage(netErr, "database connection failed") + err = errors.WithMessage(err, "unable to reach server") + for range b.N { + errors.Is(err, context.Canceled) + } + }) + + b.Run("NetworkErrorWithEverything", func(b *testing.B) { + netErr := &net.OpError{ + Op: "dial", + Net: "tcp", + Addr: &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 26257}, + Err: fmt.Errorf("connection refused"), + } + err := errors.WithStack(netErr) + err = errors.WithMessage(err, "database connection failed") + err = errors.Wrap(err, "failed to establish TCP connection") + err = errors.WithStack(err) + err = errors.WithMessage(err, "unable to reach CockroachDB server") + err = errors.Wrap(err, "connection attempt failed") + for range b.N { + errors.Is(err, context.Canceled) + } + }) + + b.Run("DeeplyNested100Levels", func(b *testing.B) { + baseErr := errors.New("base error") + err := baseErr + + // Create a 100-level deep error chain + for i := 0; i < 100; i++ { + switch i % 3 { + case 0: + err = errors.Wrap(err, fmt.Sprintf("wrap level %d", i)) + case 1: + err = errors.WithMessage(err, fmt.Sprintf("message level %d", i)) + case 2: + err = errors.WithStack(err) + } + } + + for range b.N { + errors.Is(err, context.Canceled) + } + }) +} From 3a21e3db17636866852baa8b4637a6295681f808 Mon Sep 17 00:00:00 2001 From: Jeff Swenson Date: Mon, 6 Oct 2025 14:12:21 -0400 Subject: [PATCH 2/2] errbase: optimize errors.Is Previously, `errors.Is` was very inefficient if the reference error did not match any errors in the chain. There were two significant sources of inefficiency: 1. The code would pessimistically construct the error mark for every error in the input error chain. This is O(chain_length^2). This was a lot of unnecessary allocations and caused the runtime to be O(n^2) in the average case instead of O(n). 2. The code compared the `Error()` message before comparing the types. It is possible to compare error types with zero allocations whereas computing the message often requires an allocation. --- errbase/encode.go | 28 +++++++++-- markers/markers.go | 114 ++++++++++++++++----------------------------- 2 files changed, 65 insertions(+), 77 deletions(-) diff --git a/errbase/encode.go b/errbase/encode.go index 418cc50..61ea0bb 100644 --- a/errbase/encode.go +++ b/errbase/encode.go @@ -305,6 +305,30 @@ func GetTypeMark(err error) errorspb.ErrorTypeMark { return errorspb.ErrorTypeMark{FamilyName: familyName, Extension: extension} } +// EqualTypeMark checks whether `GetTypeMark(e1).Equals(GetTypeMark(e2))`. It +// is written to be be optimized for the case where neither error has +// serialized type information. +func EqualTypeMark(e1, e2 error) bool { + slowPath := func(err error) bool { + switch err.(type) { + case *opaqueLeaf: + return true + case *opaqueLeafCauses: + return true + case *opaqueWrapper: + return true + case TypeKeyMarker: + return true + } + return false + } + if slowPath(e1) || slowPath(e2) { + return GetTypeMark(e1).Equals(GetTypeMark(e2)) + } + + return reflect.TypeOf(e1) == reflect.TypeOf(e2) +} + // RegisterLeafEncoder can be used to register new leaf error types to // the library. Registered types will be encoded using their own // Go type when an error is encoded. Wrappers that have not been @@ -385,9 +409,7 @@ func RegisterWrapperEncoder(theType TypeKey, encoder WrapperEncoder) { // Note: if the error type has been migrated from a previous location // or a different type, ensure that RegisterTypeMigration() was called // prior to RegisterWrapperEncoder(). -func RegisterWrapperEncoderWithMessageType( - theType TypeKey, encoder WrapperEncoderWithMessageType, -) { +func RegisterWrapperEncoderWithMessageType(theType TypeKey, encoder WrapperEncoderWithMessageType) { if encoder == nil { delete(encoders, theType) } else { diff --git a/markers/markers.go b/markers/markers.go index 3f81794..50a16e8 100644 --- a/markers/markers.go +++ b/markers/markers.go @@ -68,29 +68,42 @@ func Is(err, reference error) bool { } } - if err == nil { - // Err is nil and reference is non-nil, so it cannot match. We - // want to short-circuit the loop below in this case, otherwise - // we're paying the expense of getMark() without need. - return false - } - - // Not directly equal. Try harder, using error marks. We don't do - // this during the loop above as it may be more expensive. - // - // Note: there is a more effective recursive algorithm that ensures - // that any pair of string only gets compared once. Should the - // following code become a performance bottleneck, that algorithm - // can be considered instead. - refMark := getMark(reference) - for c := err; c != nil; c = errbase.UnwrapOnce(c) { - if equalMarks(getMark(c), refMark) { + for errNext := err; errNext != nil; errNext = errbase.UnwrapOnce(errNext) { + if isMarkEqual(errNext, reference) { return true } } + return false } +func isMarkEqual(err, reference error) bool { + _, errIsMark := err.(*withMark) + _, refIsMark := reference.(*withMark) + if errIsMark || refIsMark { + // If either error is a mark, use the more general + // equalMarks() function. + return equalMarks(getMark(err), getMark(reference)) + } + + m1 := err + m2 := reference + for m1 != nil && m2 != nil { + if !errbase.EqualTypeMark(m1, m2) { + return false + } + m1 = errbase.UnwrapOnce(m1) + m2 = errbase.UnwrapOnce(m2) + } + + // The two chains have different lengths, so they cannot be equal. + if m1 != nil || m2 != nil { + return false + } + + return safeGetErrMsg(err) == safeGetErrMsg(reference) +} + func tryDelegateToIsMethod(err, reference error) bool { if x, ok := err.(interface{ Is(error) bool }); ok && x.Is(reference) { return true @@ -150,62 +163,9 @@ func If(err error, pred func(err error) (interface{}, bool)) (interface{}, bool) // package location or a different type, ensure that // RegisterTypeMigration() was called prior to IsAny(). func IsAny(err error, references ...error) bool { - if err == nil { - for _, refErr := range references { - if refErr == nil { - return true - } - } - // The mark-based comparison below will never match anything if - // the error is nil, so don't bother with computing the marks in - // that case. This avoids the computational expense of computing - // the reference marks upfront. - return false - } - - // First try using direct reference comparison. - for c := err; c != nil; c = errbase.UnwrapOnce(c) { - for _, refErr := range references { - if refErr == nil { - continue - } - isComparable := reflect.TypeOf(refErr).Comparable() - if isComparable && c == refErr { - return true - } - // Compatibility with std go errors: if the error object itself - // implements Is(), try to use that. - if tryDelegateToIsMethod(c, refErr) { - return true - } - } - - // Recursively try multi-error causes, if applicable. - for _, me := range errbase.UnwrapMulti(c) { - if IsAny(me, references...) { - return true - } - } - } - - // Try harder with marks. - // Note: there is a more effective recursive algorithm that ensures - // that any pair of string only gets compared once. Should this - // become a performance bottleneck, that algorithm can be considered - // instead. - refMarks := make([]errorMark, 0, len(references)) - for _, refErr := range references { - if refErr == nil { - continue - } - refMarks = append(refMarks, getMark(refErr)) - } - for c := err; c != nil; c = errbase.UnwrapOnce(c) { - errMark := getMark(c) - for _, refMark := range refMarks { - if equalMarks(errMark, refMark) { - return true - } + for _, reference := range references { + if Is(err, reference) { + return true } } return false @@ -221,6 +181,9 @@ func equalMarks(m1, m2 errorMark) bool { if m1.msg != m2.msg { return false } + if len(m1.types) != len(m2.types) { + return false + } for i, t := range m1.types { if !t.Equals(m2.types[i]) { return false @@ -234,7 +197,10 @@ func getMark(err error) errorMark { if m, ok := err.(*withMark); ok { return m.mark } - m := errorMark{msg: safeGetErrMsg(err), types: []errorspb.ErrorTypeMark{errbase.GetTypeMark(err)}} + m := errorMark{ + msg: safeGetErrMsg(err), + types: []errorspb.ErrorTypeMark{errbase.GetTypeMark(err)}, + } for c := errbase.UnwrapOnce(err); c != nil; c = errbase.UnwrapOnce(c) { m.types = append(m.types, errbase.GetTypeMark(c)) }