From 1d4a97f7bea2cafdb21086714a56d63dbc04440a Mon Sep 17 00:00:00 2001 From: Paul DRAPPIER Date: Mon, 11 Dec 2023 11:29:19 +0100 Subject: [PATCH] fix: Leave RunCtx func backward compatible --- retrier/retrier.go | 16 +++++++++++-- retrier/retrier_test.go | 53 +++++++++++++++++++++++++++++++---------- 2 files changed, 54 insertions(+), 15 deletions(-) diff --git a/retrier/retrier.go b/retrier/retrier.go index bb8dc37..35f2c47 100644 --- a/retrier/retrier.go +++ b/retrier/retrier.go @@ -43,7 +43,7 @@ func (r *Retrier) WithInfiniteRetry() *Retrier { // Run executes the given work function by executing RunCtx without context.Context. func (r *Retrier) Run(work func() error) error { - return r.RunCtx(context.Background(), func(ctx context.Context, retries int) error { + return r.RunFn(context.Background(), func(c context.Context, r int) error { // never use ctx return work() }) @@ -54,7 +54,19 @@ func (r *Retrier) Run(work func() error) error { // returned to the caller. If the result is Retry, then Run sleeps according to the its backoff policy // before retrying. If the total number of retries is exceeded then the return value of the work function // is returned to the caller regardless. -func (r *Retrier) RunCtx(ctx context.Context, work func(ctx context.Context, retries int) error) error { +func (r *Retrier) RunCtx(ctx context.Context, work func(ctx context.Context) error) error { + return r.RunFn(ctx, func(c context.Context, r int) error { + return work(c) + }) +} + +// RunFn executes the given work function, then classifies its return value based on the classifier used +// to construct the Retrier. If the result is Succeed or Fail, the return value of the work function is +// returned to the caller. If the result is Retry, then Run sleeps according to the backoff policy +// before retrying. If the total number of retries is exceeded then the return value of the work function +// is returned to the caller regardless. The work function takes 2 args, the context and +// the number of attempted retries. +func (r *Retrier) RunFn(ctx context.Context, work func(ctx context.Context, retries int) error) error { retries := 0 for { ret := work(ctx, retries) diff --git a/retrier/retrier_test.go b/retrier/retrier_test.go index afe7e4f..c23407d 100644 --- a/retrier/retrier_test.go +++ b/retrier/retrier_test.go @@ -20,9 +20,9 @@ func genWork(returns []error) func() error { } } -func genWorkWithCtx() func(ctx context.Context, retries int) error { +func genWorkWithCtx() func(ctx context.Context) error { i = 0 - return func(ctx context.Context, retries int) error { + return func(ctx context.Context) error { select { case <-ctx.Done(): return errFoo @@ -33,15 +33,6 @@ func genWorkWithCtx() func(ctx context.Context, retries int) error { } } -func genWorkWithCtxError(returns []error) func(ctx context.Context, retries int) error { - return func(ctx context.Context, retries int) error { - if retries > len(returns) { - return nil - } - return returns[retries-1] - } -} - func TestRetrier(t *testing.T) { r := New([]time.Duration{0, 10 * time.Millisecond}, WhitelistClassifier{errFoo}) @@ -98,8 +89,26 @@ func TestRetrierCtxError(t *testing.T) { ctx := context.Background() r := New([]time.Duration{0, 10 * time.Millisecond}, nil) errExpected := []error{errFoo, errFoo, errBar, errBaz} + retries := 0 + err := r.RunCtx(ctx, func(ctx context.Context) error { + if retries >= len(errExpected) { + return nil + } + err := errExpected[retries] + retries++ + return err + }) + if err != errBar { + t.Error(err) + } +} - err := r.RunCtx(ctx, func(ctx context.Context, retries int) error { +func TestRetrierRunFnError(t *testing.T) { + ctx := context.Background() + r := New([]time.Duration{0, 10 * time.Millisecond}, nil) + errExpected := []error{errFoo, errFoo, errBar, errBaz} + + err := r.RunFn(ctx, func(ctx context.Context, retries int) error { if retries >= len(errExpected) { return nil } @@ -114,8 +123,26 @@ func TestRetrierCtxWithInfinite(t *testing.T) { ctx := context.Background() r := New([]time.Duration{0, 10 * time.Millisecond}, nil).WithInfiniteRetry() errExpected := []error{errFoo, errFoo, errFoo, errBar, errBaz} + retries := 0 + err := r.RunCtx(ctx, func(ctx context.Context) error { + if retries >= len(errExpected) { + return nil + } + err := errExpected[retries] + retries++ + return err + }) + if err != nil { + t.Error(err) + } +} + +func TestRetrierRunFnWithInfinite(t *testing.T) { + ctx := context.Background() + r := New([]time.Duration{0, 10 * time.Millisecond}, nil).WithInfiniteRetry() + errExpected := []error{errFoo, errFoo, errFoo, errBar, errBaz} - err := r.RunCtx(ctx, func(ctx context.Context, retries int) error { + err := r.RunFn(ctx, func(ctx context.Context, retries int) error { if retries >= len(errExpected) { return nil }