Skip to content

Commit

Permalink
fix: Leave RunCtx func backward compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
lbcPolo committed Dec 11, 2023
1 parent 42cb227 commit 1d4a97f
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 15 deletions.
16 changes: 14 additions & 2 deletions retrier/retrier.go
Expand Up @@ -43,7 +43,7 @@ func (r *Retrier) WithInfiniteRetry() *Retrier {

// Run executes the given work function by executing RunCtx without context.Context.
func (r *Retrier) Run(work func() error) error {
return r.RunCtx(context.Background(), func(ctx context.Context, retries int) error {
return r.RunFn(context.Background(), func(c context.Context, r int) error {
// never use ctx
return work()
})
Expand All @@ -54,7 +54,19 @@ func (r *Retrier) Run(work func() error) error {
// returned to the caller. If the result is Retry, then Run sleeps according to the its backoff policy
// before retrying. If the total number of retries is exceeded then the return value of the work function
// is returned to the caller regardless.
func (r *Retrier) RunCtx(ctx context.Context, work func(ctx context.Context, retries int) error) error {
func (r *Retrier) RunCtx(ctx context.Context, work func(ctx context.Context) error) error {
return r.RunFn(ctx, func(c context.Context, r int) error {
return work(c)
})
}

// RunFn executes the given work function, then classifies its return value based on the classifier used
// to construct the Retrier. If the result is Succeed or Fail, the return value of the work function is
// returned to the caller. If the result is Retry, then Run sleeps according to the backoff policy
// before retrying. If the total number of retries is exceeded then the return value of the work function
// is returned to the caller regardless. The work function takes 2 args, the context and
// the number of attempted retries.
func (r *Retrier) RunFn(ctx context.Context, work func(ctx context.Context, retries int) error) error {
retries := 0
for {
ret := work(ctx, retries)
Expand Down
53 changes: 40 additions & 13 deletions retrier/retrier_test.go
Expand Up @@ -20,9 +20,9 @@ func genWork(returns []error) func() error {
}
}

func genWorkWithCtx() func(ctx context.Context, retries int) error {
func genWorkWithCtx() func(ctx context.Context) error {
i = 0
return func(ctx context.Context, retries int) error {
return func(ctx context.Context) error {
select {
case <-ctx.Done():
return errFoo
Expand All @@ -33,15 +33,6 @@ func genWorkWithCtx() func(ctx context.Context, retries int) error {
}
}

func genWorkWithCtxError(returns []error) func(ctx context.Context, retries int) error {
return func(ctx context.Context, retries int) error {
if retries > len(returns) {
return nil
}
return returns[retries-1]
}
}

func TestRetrier(t *testing.T) {
r := New([]time.Duration{0, 10 * time.Millisecond}, WhitelistClassifier{errFoo})

Expand Down Expand Up @@ -98,8 +89,26 @@ func TestRetrierCtxError(t *testing.T) {
ctx := context.Background()
r := New([]time.Duration{0, 10 * time.Millisecond}, nil)
errExpected := []error{errFoo, errFoo, errBar, errBaz}
retries := 0
err := r.RunCtx(ctx, func(ctx context.Context) error {
if retries >= len(errExpected) {
return nil
}
err := errExpected[retries]
retries++
return err
})
if err != errBar {
t.Error(err)
}
}

err := r.RunCtx(ctx, func(ctx context.Context, retries int) error {
func TestRetrierRunFnError(t *testing.T) {
ctx := context.Background()
r := New([]time.Duration{0, 10 * time.Millisecond}, nil)
errExpected := []error{errFoo, errFoo, errBar, errBaz}

err := r.RunFn(ctx, func(ctx context.Context, retries int) error {
if retries >= len(errExpected) {
return nil
}
Expand All @@ -114,8 +123,26 @@ func TestRetrierCtxWithInfinite(t *testing.T) {
ctx := context.Background()
r := New([]time.Duration{0, 10 * time.Millisecond}, nil).WithInfiniteRetry()
errExpected := []error{errFoo, errFoo, errFoo, errBar, errBaz}
retries := 0
err := r.RunCtx(ctx, func(ctx context.Context) error {
if retries >= len(errExpected) {
return nil
}
err := errExpected[retries]
retries++
return err
})
if err != nil {
t.Error(err)
}
}

func TestRetrierRunFnWithInfinite(t *testing.T) {
ctx := context.Background()
r := New([]time.Duration{0, 10 * time.Millisecond}, nil).WithInfiniteRetry()
errExpected := []error{errFoo, errFoo, errFoo, errBar, errBaz}

err := r.RunCtx(ctx, func(ctx context.Context, retries int) error {
err := r.RunFn(ctx, func(ctx context.Context, retries int) error {
if retries >= len(errExpected) {
return nil
}
Expand Down

0 comments on commit 1d4a97f

Please sign in to comment.