Skip to content

Commit

Permalink
Support contexts
Browse files Browse the repository at this point in the history
  • Loading branch information
arnaud-lb committed Feb 1, 2017
1 parent b02f2bb commit eff3c63
Show file tree
Hide file tree
Showing 7 changed files with 207 additions and 4 deletions.
59 changes: 59 additions & 0 deletions context.go
@@ -0,0 +1,59 @@
package backoff

import (
"context"
"time"
)

// BackOffContext is a backoff policy that stops retrying after the context
// is canceled.
type BackOffContext interface {
BackOff
Context() context.Context
}

type backOffContext struct {
BackOff
ctx context.Context
}

// WithContext returns BackOffContext with context ctx
//
// ctx must not be nil
func WithContext(b BackOff, ctx context.Context) BackOffContext {
if ctx == nil {
panic("nil context")
}

if b, ok := b.(*backOffContext); ok {
return &backOffContext{
BackOff: b.BackOff,
ctx: ctx,
}
}

return &backOffContext{
BackOff: b,
ctx: ctx,
}
}

func ensureContext(b BackOff) BackOffContext {
if cb, ok := b.(BackOffContext); ok {
return cb
}
return WithContext(b, context.Background())
}

func (b *backOffContext) Context() context.Context {
return b.ctx
}

func (b *backOffContext) NextBackOff() time.Duration {
select {
case <-b.Context().Done():
return Stop
default:
return b.BackOff.NextBackOff()
}
}
25 changes: 25 additions & 0 deletions context_test.go
@@ -0,0 +1,25 @@
package backoff

import (
"context"
"testing"
"time"
)

func TestContext(t *testing.T) {
b := NewConstantBackOff(time.Millisecond)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

cb := WithContext(b, ctx)

if cb.Context() != ctx {
t.Error("invalid context")
}

cancel()

if cb.NextBackOff() != Stop {
t.Error("invalid next back off")
}
}
25 changes: 24 additions & 1 deletion example_test.go
@@ -1,6 +1,9 @@
package backoff

import "log"
import (
"context"
"log"
)

func ExampleRetry() {
// An operation that may fail.
Expand All @@ -17,6 +20,26 @@ func ExampleRetry() {
// Operation is successful.
}

func ExampleRetryContext() {
// A context
ctx := context.Background()

// An operation that may fail.
operation := func() error {
return nil // or an error
}

b := WithContext(NewExponentialBackOff(), ctx)

err := Retry(operation, b)
if err != nil {
// Handle error.
return
}

// Operation is successful.
}

func ExampleTicker() {
// An operation that may fail.
operation := func() error {
Expand Down
11 changes: 10 additions & 1 deletion retry.go
Expand Up @@ -27,6 +27,8 @@ func RetryNotify(operation Operation, b BackOff, notify Notify) error {
var err error
var next time.Duration

cb := ensureContext(b)

b.Reset()
for {
if err = operation(); err == nil {
Expand All @@ -41,6 +43,13 @@ func RetryNotify(operation Operation, b BackOff, notify Notify) error {
notify(err, next)
}

time.Sleep(next)
t := time.NewTimer(next)

select {
case <-cb.Context().Done():
t.Stop()
return err
case <-t.C:
}
}
}
37 changes: 37 additions & 0 deletions retry_test.go
@@ -1,9 +1,12 @@
package backoff

import (
"context"
"errors"
"fmt"
"log"
"testing"
"time"
)

func TestRetry(t *testing.T) {
Expand Down Expand Up @@ -32,3 +35,37 @@ func TestRetry(t *testing.T) {
t.Errorf("invalid number of retries: %d", i)
}
}

func TestRetryContext(t *testing.T) {
var cancelOn = 3
var i = 0

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// This function cancels context on "cancelOn" calls.
f := func() error {
i++
log.Printf("function is called %d. time\n", i)

// cancelling the context in the operation function is not a typical
// use-case, however it allows to get predictable test results.
if i == cancelOn {
cancel()
}

log.Println("error")
return fmt.Errorf("error (%d)", i)
}

err := Retry(f, WithContext(NewConstantBackOff(time.Millisecond), ctx))
if err == nil {
t.Errorf("error is unexpectedly nil")
}
if err.Error() != "error (3)" {
t.Errorf("unexpected error: %s", err.Error())
}
if i != cancelOn {
t.Errorf("invalid number of retries: %d", i)
}
}
6 changes: 4 additions & 2 deletions ticker.go
Expand Up @@ -13,7 +13,7 @@ import (
type Ticker struct {
C <-chan time.Time
c chan time.Time
b BackOff
b BackOffContext
stop chan struct{}
stopOnce sync.Once
}
Expand All @@ -26,7 +26,7 @@ func NewTicker(b BackOff) *Ticker {
t := &Ticker{
C: c,
c: c,
b: b,
b: ensureContext(b),
stop: make(chan struct{}),
}
go t.run()
Expand Down Expand Up @@ -58,6 +58,8 @@ func (t *Ticker) run() {
case <-t.stop:
t.c = nil // Prevent future ticks from being sent to the channel.
return
case <-t.b.Context().Done():
return
}
}
}
Expand Down
48 changes: 48 additions & 0 deletions ticker_test.go
@@ -1,9 +1,12 @@
package backoff

import (
"context"
"errors"
"fmt"
"log"
"testing"
"time"
)

func TestTicker(t *testing.T) {
Expand Down Expand Up @@ -43,3 +46,48 @@ func TestTicker(t *testing.T) {
t.Errorf("invalid number of retries: %d", i)
}
}

func TestTickerContext(t *testing.T) {
const cancelOn = 3
var i = 0

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// This function cancels context on "cancelOn" calls.
f := func() error {
i++
log.Printf("function is called %d. time\n", i)

// cancelling the context in the operation function is not a typical
// use-case, however it allows to get predictable test results.
if i == cancelOn {
cancel()
}

log.Println("error")
return fmt.Errorf("error (%d)", i)
}

b := WithContext(NewConstantBackOff(time.Millisecond), ctx)
ticker := NewTicker(b)

var err error
for _ = range ticker.C {
if err = f(); err != nil {
t.Log(err)
continue
}

break
}
if err == nil {
t.Errorf("error is unexpectedly nil")
}
if err.Error() != "error (3)" {
t.Errorf("unexpected error: %s", err.Error())
}
if i != cancelOn {
t.Errorf("invalid number of retries: %d", i)
}
}

0 comments on commit eff3c63

Please sign in to comment.