From b50119bc378cb32851e1a1bbfcafc5b5394d2023 Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Wed, 9 Dec 2020 23:45:20 +0800 Subject: [PATCH] fix(mono): add panic recover for most of Mono operations (#30) * fix(mono): add panic recover for most of the Mono operations * fix: lint --- flux/op_filter.go | 17 +- internal/misc.go | 15 -- internal/misc_test.go | 8 - internal/subscribers/switch_if_empty.go | 14 +- internal/subscribers/switch_value_if_error.go | 4 +- mono/mono_create.go | 25 ++- mono/mono_just.go | 13 +- mono/mono_test.go | 4 +- mono/mono_zip_test.go | 38 ++++ mono/op_filter.go | 11 +- mono/op_flatmap.go | 28 ++- mono/op_map.go | 18 +- mono/op_peek.go | 11 +- mono/op_switch_if_error.go | 24 +- mono/panic_test.go | 210 ++++++++++++++++++ 15 files changed, 386 insertions(+), 54 deletions(-) create mode 100644 mono/panic_test.go diff --git a/flux/op_filter.go b/flux/op_filter.go index 3b48f34..fbe3af6 100644 --- a/flux/op_filter.go +++ b/flux/op_filter.go @@ -5,6 +5,7 @@ import ( "github.com/jjeffcaii/reactor-go" "github.com/jjeffcaii/reactor-go/internal" + "github.com/pkg/errors" ) type fluxFilter struct { @@ -42,11 +43,23 @@ func (p *filterSubscriber) OnError(err error) { } func (p *filterSubscriber) OnNext(v Any) { + if p.f == nil { + p.OnError(errors.New("the Filter predicate is nil")) + return + } + defer func() { - if err := internal.TryRecoverError(recover()); err != nil { - p.OnError(err) + rec := recover() + if rec == nil { + return + } + if e, ok := rec.(error); ok { + p.OnError(errors.WithStack(e)) + } else { + p.OnError(errors.Errorf("%v", rec)) } }() + if p.f(v) { p.actual.OnNext(v) return diff --git a/internal/misc.go b/internal/misc.go index 467a53f..950b4b4 100644 --- a/internal/misc.go +++ b/internal/misc.go @@ -2,26 +2,11 @@ package internal import ( "errors" - "fmt" ) var ErrCallOnSubscribeDuplicated = errors.New("call OnSubscribe duplicated") var EmptySubscription = emptySubscription{} -func TryRecoverError(re interface{}) error { - if re == nil { - return nil - } - switch e := re.(type) { - case error: - return e - case string: - return errors.New(e) - default: - return fmt.Errorf("%s", e) - } -} - type emptySubscription struct { } diff --git a/internal/misc_test.go b/internal/misc_test.go index fc9747c..95a9ba5 100644 --- a/internal/misc_test.go +++ b/internal/misc_test.go @@ -1,20 +1,12 @@ package internal_test import ( - "errors" "testing" "github.com/jjeffcaii/reactor-go/internal" "github.com/stretchr/testify/assert" ) -func TestTryRecoverError(t *testing.T) { - fakeErr := errors.New("fake error") - assert.Equal(t, fakeErr, internal.TryRecoverError(fakeErr)) - assert.Error(t, internal.TryRecoverError("fake error")) - assert.Error(t, internal.TryRecoverError(123)) -} - func TestEmptySubscription(t *testing.T) { assert.NotPanics(t, func() { internal.EmptySubscription.Cancel() diff --git a/internal/subscribers/switch_if_empty.go b/internal/subscribers/switch_if_empty.go index 4efcd6e..fbcf905 100644 --- a/internal/subscribers/switch_if_empty.go +++ b/internal/subscribers/switch_if_empty.go @@ -2,6 +2,7 @@ package subscribers import ( "context" + "errors" "github.com/jjeffcaii/reactor-go" ) @@ -54,12 +55,17 @@ func (s *SwitchIfEmptySubscriber) OnSubscribe(ctx context.Context, su reactor.Su } func (s *SwitchIfEmptySubscriber) OnComplete() { - if !s.nextOnce { - s.nextOnce = true - s.other.SubscribeWith(s.ctx, s) - } else { + if s.nextOnce { s.actual.OnComplete() + return } + s.nextOnce = true + if s.other == nil { + s.actual.OnError(errors.New("the alternative SwitchIfEmpty Mono is nil")) + } else { + s.other.SubscribeWith(s.ctx, s) + } + } func NewSwitchIfEmptySubscriber(alternative reactor.RawPublisher, actual reactor.Subscriber) *SwitchIfEmptySubscriber { diff --git a/internal/subscribers/switch_value_if_error.go b/internal/subscribers/switch_value_if_error.go index 262225a..574ffd4 100644 --- a/internal/subscribers/switch_value_if_error.go +++ b/internal/subscribers/switch_value_if_error.go @@ -41,7 +41,9 @@ func (s *SwitchValueIfErrorSubscriber) Cancel() { func (s *SwitchValueIfErrorSubscriber) OnError(err error) { if atomic.AddInt32(&s.errorCalls, 1) == 1 { hooks.Global().OnErrorDrop(err) - s.actual.OnNext(s.v) + if s.v != nil { + s.actual.OnNext(s.v) + } s.actual.OnComplete() } else { s.actual.OnError(err) diff --git a/mono/mono_create.go b/mono/mono_create.go index 9e4ce0a..f3266c3 100644 --- a/mono/mono_create.go +++ b/mono/mono_create.go @@ -2,17 +2,14 @@ package mono import ( "context" - "errors" "sync" "sync/atomic" "github.com/jjeffcaii/reactor-go" "github.com/jjeffcaii/reactor-go/hooks" - "github.com/jjeffcaii/reactor-go/internal" + "github.com/pkg/errors" ) -var _errRunSinkFailed = errors.New("execute creation func failed") - var _sinkPool = sync.Pool{ New: func() interface{} { return new(sink) @@ -49,8 +46,14 @@ func newMonoCreate(gen func(context.Context, Sink)) monoCreate { return monoCreate{ sinker: func(ctx context.Context, sink Sink) { defer func() { - if e := recover(); e != nil { - sink.Error(_errRunSinkFailed) + rec := recover() + if rec == nil { + return + } + if e, ok := rec.(error); ok { + sink.Error(errors.WithStack(e)) + } else { + sink.Error(errors.Errorf("%v", rec)) } }() @@ -103,8 +106,14 @@ func (s *sink) Error(err error) { func (s *sink) Next(v Any) { defer func() { - if err := internal.TryRecoverError(recover()); err != nil { - s.Error(err) + rec := recover() + if rec == nil { + return + } + if e, ok := rec.(error); ok { + s.Error(errors.WithStack(e)) + } else { + s.Error(errors.Errorf("%v", rec)) } }() s.actual.OnNext(v) diff --git a/mono/mono_just.go b/mono/mono_just.go index 518f13b..0f39802 100644 --- a/mono/mono_just.go +++ b/mono/mono_just.go @@ -6,7 +6,7 @@ import ( "sync/atomic" "github.com/jjeffcaii/reactor-go" - "github.com/jjeffcaii/reactor-go/internal" + "github.com/pkg/errors" ) var _justSubscriptionPool = sync.Pool{ @@ -64,10 +64,15 @@ func (j *justSubscription) Request(n int) { return } defer func() { - if err := internal.TryRecoverError(recover()); err != nil { - j.actual.OnError(err) - } else { + rec := recover() + if rec == nil { j.actual.OnComplete() + return + } + if e, ok := rec.(error); ok { + j.actual.OnError(errors.WithStack(e)) + } else { + j.actual.OnError(errors.Errorf("%v", rec)) } }() j.actual.OnNext(j.parent.value) diff --git a/mono/mono_test.go b/mono/mono_test.go index 517567c..129aca0 100644 --- a/mono/mono_test.go +++ b/mono/mono_test.go @@ -2,7 +2,6 @@ package mono_test import ( "context" - "errors" "fmt" "sync/atomic" "testing" @@ -12,6 +11,7 @@ import ( "github.com/jjeffcaii/reactor-go/hooks" "github.com/jjeffcaii/reactor-go/mono" "github.com/jjeffcaii/reactor-go/scheduler" + "github.com/pkg/errors" "github.com/stretchr/testify/assert" ) @@ -144,7 +144,7 @@ func testPanic(m mono.Mono, t *testing.T) { in.DoOnError(func(e error) { catches = e }).Subscribe(context.Background()) - assert.Equal(t, fakeErr, catches, "not that error") + assert.Equal(t, fakeErr, errors.Cause(catches), "not that error") } checker(m.DoOnNext(func(v Any) error { return fakeErr diff --git a/mono/mono_zip_test.go b/mono/mono_zip_test.go index c4ffc33..3f6c3ff 100644 --- a/mono/mono_zip_test.go +++ b/mono/mono_zip_test.go @@ -121,3 +121,41 @@ func TestZip_context(t *testing.T) { assert.Error(t, err) assert.True(t, reactor.IsCancelledError(err)) } + +func TestZip_EdgeCase(t *testing.T) { + var ( + nextCnt = new(int32) + completeCnt = new(int32) + errorCnt = new(int32) + ) + mono.Zip(mono.JustOneshot("1"), mono.JustOneshot("2")). + FlatMap(func(any reactor.Any) mono.Mono { + if any != nil { + return mono.Zip(mono.JustOneshot("333"), mono.JustOneshot("44444444")). + Filter(func(any reactor.Any) bool { + panic("fake panic") + }). + Map(func(any reactor.Any) (reactor.Any, error) { + panic("ddddddd") + }) + } + return mono.JustOneshot("dddd") + }). + Subscribe(context.Background(), + reactor.OnNext(func(v reactor.Any) error { + atomic.AddInt32(nextCnt, 1) + return nil + }), + reactor.OnError(func(e error) { + atomic.AddInt32(errorCnt, 1) + t.Logf("%v", e) + }), + reactor.OnComplete(func() { + atomic.AddInt32(completeCnt, 1) + }), + ) + + assert.Equal(t, int32(0), atomic.LoadInt32(nextCnt), "next count should be zero") + assert.Equal(t, int32(1), atomic.LoadInt32(errorCnt), "error count should be 1") + assert.Equal(t, int32(0), atomic.LoadInt32(completeCnt), "complete count should be zero") +} diff --git a/mono/op_filter.go b/mono/op_filter.go index 9bea5c1..b47292e 100644 --- a/mono/op_filter.go +++ b/mono/op_filter.go @@ -5,6 +5,7 @@ import ( "github.com/jjeffcaii/reactor-go" "github.com/jjeffcaii/reactor-go/internal" + "github.com/pkg/errors" ) type filterSubscriber struct { @@ -46,8 +47,14 @@ func (f *filterSubscriber) OnError(err error) { func (f *filterSubscriber) OnNext(v Any) { defer func() { - if err := internal.TryRecoverError(recover()); err != nil { - f.OnError(err) + rec := recover() + if rec == nil { + return + } + if e, ok := rec.(error); ok { + f.OnError(errors.WithStack(e)) + } else { + f.OnError(errors.Errorf("%v", rec)) } }() if f.predicate(v) { diff --git a/mono/op_flatmap.go b/mono/op_flatmap.go index 0098ff8..1be23e5 100644 --- a/mono/op_flatmap.go +++ b/mono/op_flatmap.go @@ -5,6 +5,7 @@ import ( "sync/atomic" "github.com/jjeffcaii/reactor-go" + "github.com/pkg/errors" ) const ( @@ -71,11 +72,34 @@ func (p *flatMapSubscriber) OnNext(v Any) { if atomic.LoadInt32(&p.stat) != 0 { return } - m := p.mapper(v) + nextMono, err := p.computeNextMono(v) + if err != nil { + p.actual.OnError(err) + return + } inner := &innerFlatMapSubscriber{ parent: p, } - m.SubscribeWith(p.ctx, inner) + nextMono.SubscribeWith(p.ctx, inner) +} + +func (p *flatMapSubscriber) computeNextMono(v Any) (next Mono, err error) { + defer func() { + rec := recover() + if rec == nil { + return + } + if e, ok := rec.(error); ok { + err = errors.WithStack(e) + } else { + err = errors.Errorf("%v", rec) + } + }() + next = p.mapper(v) + if next == nil { + err = errors.New("the FlatMap result is nil") + } + return } func (p *flatMapSubscriber) OnSubscribe(ctx context.Context, s reactor.Subscription) { diff --git a/mono/op_map.go b/mono/op_map.go index 3637a5b..e21be19 100644 --- a/mono/op_map.go +++ b/mono/op_map.go @@ -5,6 +5,7 @@ import ( "sync" "github.com/jjeffcaii/reactor-go" + "github.com/pkg/errors" ) var _mapSubscriberPool = sync.Pool{ @@ -68,12 +69,25 @@ func (m *mapSubscriber) OnError(err error) { func (m *mapSubscriber) OnNext(v Any) { if m == nil || m.actual == nil || m.t == nil { + // TODO: return } - if o, err := m.t(v); err != nil { + defer func() { + rec := recover() + if rec == nil { + return + } + if e, ok := rec.(error); ok { + m.actual.OnError(errors.WithStack(e)) + } else { + m.actual.OnError(errors.Errorf("%v", rec)) + } + }() + + if transformed, err := m.t(v); err != nil { m.actual.OnError(err) } else { - m.actual.OnNext(o) + m.actual.OnNext(transformed) } } diff --git a/mono/op_peek.go b/mono/op_peek.go index a921b42..c187e7c 100644 --- a/mono/op_peek.go +++ b/mono/op_peek.go @@ -6,6 +6,7 @@ import ( "github.com/jjeffcaii/reactor-go" "github.com/jjeffcaii/reactor-go/internal" + "github.com/pkg/errors" ) type monoPeek struct { @@ -87,8 +88,14 @@ func (p *peekSubscriber) OnNext(v Any) { } if call := p.parent.onNextCall; call != nil { defer func() { - if err := internal.TryRecoverError(recover()); err != nil { - p.OnError(err) + rec := recover() + if rec == nil { + return + } + if e, ok := rec.(error); ok { + p.OnError(errors.WithStack(e)) + } else { + p.OnError(errors.Errorf("%v", rec)) } }() if err := call(v); err != nil { diff --git a/mono/op_switch_if_error.go b/mono/op_switch_if_error.go index d169bc7..138b7a6 100644 --- a/mono/op_switch_if_error.go +++ b/mono/op_switch_if_error.go @@ -5,6 +5,7 @@ import ( "github.com/jjeffcaii/reactor-go" "github.com/jjeffcaii/reactor-go/internal/subscribers" + "github.com/pkg/errors" ) type monoSwitchIfError struct { @@ -17,8 +18,27 @@ func (m monoSwitchIfError) Parent() reactor.RawPublisher { } func (m monoSwitchIfError) SubscribeWith(ctx context.Context, actual reactor.Subscriber) { - alternative := func(err error) reactor.RawPublisher { - return m.sw(err) + alternative := func(err error) (pub reactor.RawPublisher) { + if m.sw == nil { + pub = newMonoError(errors.New("the SwitchIfError transform is nil")) + return + } + defer func() { + rec := recover() + if rec == nil { + return + } + if e, ok := rec.(error); ok { + pub = newMonoError(errors.WithStack(e)) + } else { + pub = newMonoError(errors.Errorf("%v", rec)) + } + }() + pub = m.sw(err) + if pub == nil { + pub = newMonoError(errors.New("the SwitchIfError returns nil Mono")) + } + return } s := subscribers.NewSwitchIfErrorSubscriber(alternative, actual) actual.OnSubscribe(ctx, s) diff --git a/mono/panic_test.go b/mono/panic_test.go new file mode 100644 index 0000000..df20945 --- /dev/null +++ b/mono/panic_test.go @@ -0,0 +1,210 @@ +package mono_test + +import ( + "context" + "testing" + + "github.com/jjeffcaii/reactor-go" + "github.com/jjeffcaii/reactor-go/mono" + "github.com/stretchr/testify/assert" +) + +const ( + cntDoOnNext int = iota + cntDoOnError + cntDoOnComplete + cntOnNext + cntOnError + cntOnComplete + cntTotal +) + +func TestPanic_Map(t *testing.T) { + var cnt [cntTotal]int + runPanicTest(t, func(m mono.Mono) mono.Mono { + return m.Map(func(any reactor.Any) (reactor.Any, error) { + panic("fake panic") + }) + }, cnt[:]) + normalPanicCheck(t, cnt[:]) +} + +func TestPanic_Filter(t *testing.T) { + var cnt [cntTotal]int + runPanicTest(t, func(m mono.Mono) mono.Mono { + return m.Filter(func(any reactor.Any) bool { + panic("fake panic") + }) + }, cnt[:]) + normalPanicCheck(t, cnt[:]) +} + +func TestPanic_DoOnNext(t *testing.T) { + var cnt [cntTotal]int + runPanicTest(t, func(m mono.Mono) mono.Mono { + return m.DoOnNext(func(v reactor.Any) error { + panic("fake panic") + }) + }, cnt[:]) + normalPanicCheck(t, cnt[:]) +} + +func TestPanic_FlatMap(t *testing.T) { + var cnt [cntTotal]int + runPanicTest(t, func(m mono.Mono) mono.Mono { + return m.FlatMap(func(any reactor.Any) mono.Mono { + panic("fake panic") + }) + }, cnt[:]) + normalPanicCheck(t, cnt[:]) + for i := 0; i < len(cnt); i++ { + cnt[i] = 0 + } + runPanicTest(t, func(m mono.Mono) mono.Mono { + return m.FlatMap(func(any reactor.Any) mono.Mono { + return nil + }) + }, cnt[:]) + normalPanicCheck(t, cnt[:]) +} + +func TestPanic_SwitchIfError(t *testing.T) { + run := func(sw func(err error) mono.Mono) { + var cnt [cntTotal]int + var one int + mono.Error(fakeErr). + DoOnError(func(e error) { + one++ + }). + SwitchIfError(sw). + DoOnNext(func(v reactor.Any) error { + cnt[cntDoOnNext]++ + return nil + }). + DoOnError(func(e error) { + cnt[cntDoOnError]++ + }). + Subscribe(context.Background(), reactor.OnError(func(e error) { + cnt[cntOnError]++ + }), reactor.OnNext(func(v reactor.Any) error { + cnt[cntOnNext]++ + return nil + }), reactor.OnComplete(func() { + cnt[cntOnComplete]++ + })) + normalPanicCheck(t, cnt[:]) + } + + run(func(err error) mono.Mono { + panic("fake panic") + }) + + run(func(err error) mono.Mono { + return nil + }) + + run(nil) + + run(func(err error) mono.Mono { + return mono.Just(1).Map(func(any reactor.Any) (reactor.Any, error) { + panic("fake panic") + }) + }) + +} + +func TestPanic_SwitchIfEmpty(t *testing.T) { + run := func(replace mono.Mono) { + var zero int + var cnt [cntTotal]int + mono.Empty(). + DoOnNext(func(v reactor.Any) error { + zero++ + return nil + }). + DoOnError(func(e error) { + zero++ + }). + SwitchIfEmpty(replace). + DoOnNext(func(v reactor.Any) error { + cnt[cntDoOnNext]++ + return nil + }). + DoOnError(func(e error) { + cnt[cntDoOnError]++ + t.Log("[PANIC] DoOnError:", e) + }). + Subscribe(context.Background(), + reactor.OnNext(func(v reactor.Any) error { + cnt[cntOnNext]++ + return nil + }), + reactor.OnError(func(e error) { + cnt[cntOnError]++ + t.Log("[PANIC] DoOnError:", e) + }), + reactor.OnComplete(func() { + cnt[cntOnComplete]++ + }), + ) + assert.Zero(t, zero) + normalPanicCheck(t, cnt[:]) + } + + run(mono.Just(1). + Map(func(any reactor.Any) (reactor.Any, error) { + panic("fake panic") + })) + + run(nil) +} + +func normalPanicCheck(t *testing.T, cnt []int) { + assert.Equal(t, 0, cnt[cntDoOnNext]) + assert.Equal(t, 1, cnt[cntDoOnError]) + assert.Equal(t, 0, cnt[cntDoOnComplete]) + assert.Equal(t, 1, cnt[cntOnError]) + assert.Equal(t, 0, cnt[cntOnNext]) + assert.Equal(t, 0, cnt[cntOnComplete]) +} + +func runPanicTest(t *testing.T, convert func(mono.Mono) mono.Mono, cnt []int) { + var one int + var zero int + source := convert(mono.Just(1). + DoOnNext(func(v reactor.Any) error { + one++ + return nil + }). + DoOnError(func(e error) { + zero++ + })) + + source. + DoOnNext(func(v reactor.Any) error { + cnt[cntDoOnNext]++ + return nil + }). + DoOnError(func(e error) { + cnt[cntDoOnError]++ + t.Log("[PANIC] DoOnError:", e) + }). + DoOnComplete(func() { + cnt[cntDoOnComplete]++ + }). + Subscribe(context.Background(), + reactor.OnNext(func(v reactor.Any) error { + cnt[cntOnNext]++ + return nil + }), + reactor.OnError(func(e error) { + cnt[cntOnError]++ + t.Log("[PANIC] OnError:", e) + }), + reactor.OnComplete(func() { + cnt[cntOnComplete]++ + }), + ) + assert.Equal(t, 1, one) + assert.Equal(t, 0, zero) +}