diff --git a/go/pkg/errcode/error.go b/go/pkg/errcode/error.go index 283f701829..c38211c438 100644 --- a/go/pkg/errcode/error.go +++ b/go/pkg/errcode/error.go @@ -7,7 +7,7 @@ type WithCode interface { Code() int32 } -// Code returns the code of the +// Code returns the code of the actual error without trying to unwrap it, or -1. func Code(err error) int32 { typed, ok := err.(WithCode) if ok { @@ -16,6 +16,53 @@ func Code(err error) int32 { return -1 } +// LastCode walks the passed error and returns the code of the latest ErrCode, or -1. +func LastCode(err error) int32 { + if err == nil { + return -1 + } + + if cause := genericCause(err); cause != nil { + if ret := LastCode(cause); ret != -1 { + return ret + } + } + + return Code(err) +} + +// FirstCode walks the passed error and returns the code of the first ErrCode met, or -1. +func FirstCode(err error) int32 { + if err == nil { + return -1 + } + + if code := Code(err); code != -1 { + return code + } + + if cause := genericCause(err); cause != nil { + return FirstCode(cause) + } + + return -1 +} + +func genericCause(err error) error { + type causer interface{ Cause() error } + type wrapper interface{ Unwrap() error } + + if causer, ok := err.(causer); ok { + return causer.Cause() + } + + if wrapper, ok := err.(wrapper); ok { + return wrapper.Unwrap() + } + + return nil +} + // // Error // diff --git a/go/pkg/errcode/error_test.go b/go/pkg/errcode/error_test.go index 32b185f2f6..1bb7c67b0c 100644 --- a/go/pkg/errcode/error_test.go +++ b/go/pkg/errcode/error_test.go @@ -21,72 +21,94 @@ func TestError(t *testing.T) { errCodeUndef = ErrCode(65530) // simulate a client receiving an error generated from a more recent API ) var tests = []struct { - name string - input error - expectedString string - expectedCode int32 - expectedCause error + name string + input error + expectedString string + expectedCause error + expectedCode int32 + expectedFirstCode int32 + expectedLastCode int32 }{ { "ErrNotImplemented", ErrNotImplemented, "ErrNotImplemented(#777)", - 777, ErrNotImplemented, + 777, + 777, + 777, }, { "ErrInternal", ErrInternal, "ErrInternal(#999)", - 999, ErrInternal, + 999, + 999, + 999, }, { "ErrNotImplemented.Wrap(errStdHello)", ErrNotImplemented.Wrap(errStdHello), "ErrNotImplemented(#777): hello", - 777, errStdHello, + 777, + 777, + 777, }, { "ErrNotImplemented.Wrap(ErrInternal)", ErrNotImplemented.Wrap(ErrInternal), "ErrNotImplemented(#777): ErrInternal(#999)", - 777, ErrInternal, + 777, + 777, + 999, }, { "ErrNotImplemented.Wrap(ErrInternal.Wrap(errStdHello))", ErrNotImplemented.Wrap(ErrInternal.Wrap(errStdHello)), "ErrNotImplemented(#777): ErrInternal(#999): hello", - 777, errStdHello, + 777, + 777, + 999, }, { - `errors.Wrap(ErrNotImplemented, "blah")`, + `errors.Wrap(ErrNotImplemented,blah)`, errors.Wrap(ErrNotImplemented, "blah"), "blah: ErrNotImplemented(#777)", - -1, ErrNotImplemented, + -1, + 777, + 777, }, { - `errors.Wrap(ErrNotImplemented.Wrap(ErrInternal), "blah")`, + `errors.Wrap(ErrNotImplemented.Wrap(ErrInternal),blah)`, errors.Wrap(ErrNotImplemented.Wrap(ErrInternal), "blah"), "blah: ErrNotImplemented(#777): ErrInternal(#999)", - -1, ErrInternal, + -1, + 777, + 999, }, { "nil", nil, "", - -1, nil, + -1, + -1, + -1, }, { "errStdHello", errStdHello, "hello", - -1, errStdHello, + -1, + -1, + -1, }, { "errCodeUndef", errCodeUndef, "UNKNOWN_ERRCODE(#65530)", - 65530, errCodeUndef, + 65530, + 65530, + 65530, }, } @@ -102,6 +124,16 @@ func TestError(t *testing.T) { t.Errorf("Expected code to be %d, got %d.", test.expectedCode, actualCode) } + actualCode = FirstCode(test.input) + if test.expectedFirstCode != actualCode { + t.Errorf("Expected first-code to be %d, got %d.", test.expectedCode, actualCode) + } + + actualCode = LastCode(test.input) + if test.expectedLastCode != actualCode { + t.Errorf("Expected last-code to be %d, got %d.", test.expectedCode, actualCode) + } + actualCause := errors.Cause(test.input) if test.expectedCause != actualCause { t.Errorf("Expected cause to be %v, got %v.", test.expectedCause, actualCause)