diff --git a/CHANGELOG.md b/CHANGELOG.md index d3320e4..f97be82 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Add UnwrapFrame function to extract a single frame from an error. You can use this to implement your own trace formatting logic. +- Support extracting trace frames from custom errors. + Any error value that implements `TracePC() uintptr` will now + contribute to the trace. - cmd/errtrace: Add `-no-wrapn` option to disable wrapping with generic `WrapN` functions. This is only useful for toolexec mode due to tooling limitations. diff --git a/errtrace.go b/errtrace.go index 62f5b02..0633b1f 100644 --- a/errtrace.go +++ b/errtrace.go @@ -69,7 +69,7 @@ func wrap(err error, callerPC uintptr) error { } // Format writes the return trace for given error to the writer. -// The output takes a fromat similar to the following: +// The output takes a format similar to the following: // // // @@ -79,6 +79,8 @@ func wrap(err error, callerPC uintptr) error { // : // [...] // +// Any error that has a method `TracePC() uintptr` will +// contribute to the trace. // If the error doesn't have a return trace attached to it, // only the error message is reported. // If the error is comprised of multiple errors (e.g. with [errors.Join]), @@ -90,6 +92,8 @@ func Format(w io.Writer, target error) (err error) { } // FormatString writes the return trace for err to a string. +// Any error that has a method `TracePC() uintptr` will +// contribute to the trace. // See [Format] for details of the output format. func FormatString(target error) string { var s strings.Builder @@ -118,3 +122,16 @@ func (e *errTrace) Format(s fmt.State, verb rune) { fmt.Fprintf(s, fmt.FormatString(s, verb), e.err) } + +// TracePC returns the program counter for the location +// in the frame that the error originated with. +// +// The returned PC is intended to be used with +// runtime.CallersFrames or runtime.FuncForPC +// to aid in generating the error return trace +func (e *errTrace) TracePC() uintptr { + return e.pc +} + +// compile time tracePCprovider interface check +var _ interface{ TracePC() uintptr } = &errTrace{} diff --git a/unwrap.go b/unwrap.go index 4a1844c..63261a8 100644 --- a/unwrap.go +++ b/unwrap.go @@ -1,6 +1,9 @@ package errtrace -import "runtime" +import ( + "errors" + "runtime" +) // UnwrapFrame unwraps the outermost frame from the given error, // returning it and the inner error. @@ -8,19 +11,23 @@ import "runtime" // and false otherwise, or if the error is not an errtrace error. // // You can use this for structured access to trace information. +// +// Any error that has a method `TracePC() uintptr` will +// contribute a frame to the trace. func UnwrapFrame(err error) (frame runtime.Frame, inner error, ok bool) { //nolint:revive // error is intentionally middle return - e, ok := err.(*errTrace) + e, ok := err.(interface{ TracePC() uintptr }) if !ok { return runtime.Frame{}, err, false } - frames := runtime.CallersFrames([]uintptr{e.pc}) + inner = errors.Unwrap(err) + frames := runtime.CallersFrames([]uintptr{e.TracePC()}) f, _ := frames.Next() if f == (runtime.Frame{}) { // Unlikely, but if PC didn't yield a frame, // just return the inner error. - return runtime.Frame{}, e.err, false + return runtime.Frame{}, inner, false } - return f, e.err, true + return f, inner, true } diff --git a/unwrap_test.go b/unwrap_test.go index 2f1ff20..70e8ad3 100644 --- a/unwrap_test.go +++ b/unwrap_test.go @@ -3,6 +3,7 @@ package errtrace import ( "errors" "path/filepath" + "reflect" "strings" "testing" ) @@ -40,6 +41,26 @@ func TestUnwrapFrame(t *testing.T) { t.Errorf("frame.File: got %v, want %v", got, want) } }) + + t.Run("custom error", func(t *testing.T) { + wrapped := wrapCustomTrace(giveErr) + frame, inner, ok := UnwrapFrame(wrapped) + if got, want := ok, true; got != want { + t.Errorf("ok: got %v, want %v", got, want) + } + + if got, want := inner, giveErr; got != want { + t.Errorf("inner: got %v, want %v", inner, giveErr) + } + + if got, want := frame.Function, ".wrapCustomTrace"; !strings.HasSuffix(got, want) { + t.Errorf("frame.Func: got %q, does not contain %q", got, want) + } + + if got, want := filepath.Base(frame.File), "unwrap_test.go"; got != want { + t.Errorf("frame.File: got %v, want %v", got, want) + } + }) } func TestUnwrapFrame_badPC(t *testing.T) { @@ -53,3 +74,27 @@ func TestUnwrapFrame_badPC(t *testing.T) { t.Errorf("inner: got %v, want %v", inner, giveErr) } } + +type customTraceError struct { + err error + pc uintptr +} + +func wrapCustomTrace(err error) error { + return &customTraceError{ + err: err, + pc: reflect.ValueOf(wrapCustomTrace).Pointer(), + } +} + +func (e *customTraceError) Error() string { + return e.err.Error() +} + +func (e *customTraceError) TracePC() uintptr { + return e.pc +} + +func (e *customTraceError) Unwrap() error { + return e.err +}