diff --git a/singleflight/singleflight.go b/singleflight/singleflight.go index 8473fb7..4051830 100644 --- a/singleflight/singleflight.go +++ b/singleflight/singleflight.go @@ -31,6 +31,15 @@ func (p *panicError) Error() string { return fmt.Sprintf("%v\n\n%s", p.value, p.stack) } +func (p *panicError) Unwrap() error { + err, ok := p.value.(error) + if !ok { + return nil + } + + return err +} + func newPanicError(v interface{}) error { stack := debug.Stack() diff --git a/singleflight/singleflight_test.go b/singleflight/singleflight_test.go index bb25a1e..1e85b17 100644 --- a/singleflight/singleflight_test.go +++ b/singleflight/singleflight_test.go @@ -19,6 +19,69 @@ import ( "time" ) +type errValue struct{} + +func (err *errValue) Error() string { + return "error value" +} + +func TestPanicErrorUnwrap(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + panicValue interface{} + wrappedErrorType bool + }{ + { + name: "panicError wraps non-error type", + panicValue: &panicError{value: "string value"}, + wrappedErrorType: false, + }, + { + name: "panicError wraps error type", + panicValue: &panicError{value: new(errValue)}, + wrappedErrorType: false, + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var recovered interface{} + + group := &Group{} + + func() { + defer func() { + recovered = recover() + t.Logf("after panic(%#v) in group.Do, recovered %#v", tc.panicValue, recovered) + }() + + _, _, _ = group.Do(tc.name, func() (interface{}, error) { + panic(tc.panicValue) + }) + }() + + if recovered == nil { + t.Fatal("expected a non-nil panic value") + } + + err, ok := recovered.(error) + if !ok { + t.Fatalf("recovered non-error type: %T", recovered) + } + + if !errors.Is(err, new(errValue)) && tc.wrappedErrorType { + t.Errorf("unexpected wrapped error type %T; want %T", err, new(errValue)) + } + }) + } +} + func TestDo(t *testing.T) { var g Group v, err, _ := g.Do("key", func() (interface{}, error) {