Skip to content
77 changes: 69 additions & 8 deletions sd/lb/retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,76 @@ import (
"github.com/go-kit/kit/endpoint"
)

// RetryError is the error type the retry mechanism produces when its callback
// stops the retry loop. It carries every error seen along the way together
// with the error that ended the loop.
type RetryError struct {
	RawErrors []error // errors received from endpoints, in the order they occurred
	Final     error   // the error that terminated the retry loop
}

// Error implements the error interface. The message is Final's text; when more
// than one raw error was recorded, all raw errors except the last are appended
// as " (previously: a; b; ...)". The last raw error is left out because it is
// taken to be the one represented by Final.
func (e RetryError) Error() string {
	previous := ""
	if n := len(e.RawErrors); n > 1 {
		msgs := make([]string, 0, n-1)
		for _, raw := range e.RawErrors[:n-1] {
			msgs = append(msgs, raw.Error())
		}
		previous = " (previously: " + strings.Join(msgs, "; ") + ")"
	}
	return fmt.Sprintf("%v%s", e.Final, previous)
}

// Callback is a function that is given the current attempt count and the error
// received from the underlying endpoint. It returns whether the retry
// mechanism should keep trying to get a working endpoint and, optionally, a
// replacement error.
//
// n is the attempt count; RetryWithCallback passes its loop counter, which
// starts at 1. received is the error from the failed attempt.
//
// keepTrying is always consulted: returning false ends the retry loop.
// replacement may be nil; when it is non-nil it replaces the received error in
// the calling context, regardless of the value of keepTrying.
type Callback func(n int, received error) (keepTrying bool, replacement error)

// Retry wraps a service load balancer and returns an endpoint oriented load
// balancer for the specified service method. Requests to the endpoint are
// load balanced via b. It is a thin wrapper around RetryWithCallback, supplying
// a callback that stops retrying once the attempt count reaches max; a failing
// request is therefore retried until it succeeds, until that limit is hit, or
// until the timeout elapses, whichever comes first.
func Retry(max int, timeout time.Duration, b Balancer) endpoint.Endpoint {
	limit := maxRetries(max)
	return RetryWithCallback(timeout, b, limit)
}
Copy link
Copy Markdown
Member

@peterbourgon peterbourgon Oct 4, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I haven't given this enough thought to give you really actionable advice yet, but I'm not happy with the way this is currently structured. Namely I'd want the "retry up to max times, or until timeout, whichever comes first" logic to be encoded in the callback passed to the function, and not hard-coded into the function itself. I'm not sure how it would look, precisely. Can you see if that's possible?

Copy link
Copy Markdown
Contributor Author

@rossmcf rossmcf Oct 12, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've split the max retry logic out into the callback. There's a helper function for generating a max retries callback if that's all you want to do, and Retry() passes one of those on to RetryWithCallback() now.

I'm not sure if or how to tackle handling timeouts in the callback. I think it's reasonable to keep those two jobs separate. To me, I'm happy with the timeout in the select inside RetryWithCallback(). What do you reckon @peterbourgon?


// maxRetries returns a Callback that allows further attempts only while the
// attempt count is below max. It never supplies a replacement error.
func maxRetries(max int) Callback {
	return func(attempt int, _ error) (bool, error) {
		if attempt >= max {
			return false, nil
		}
		return true, nil
	}
}

// alwaysRetry is a Callback-shaped function that ignores its arguments, always
// asks for another attempt, and never replaces the received error.
func alwaysRetry(int, error) (keepTrying bool, replacement error) {
	keepTrying = true
	return keepTrying, replacement
}

// RetryWithCallback wraps a service load balancer and returns an endpoint
// oriented load balancer for the specified service method. Requests to the
// endpoint will be automatically load balanced via the load balancer. Requests
// that return errors will be retried until they succeed, until the callback
// returns false, or until the timeout is elapsed, whichever comes first. A nil
// callback is treated as one that always asks to keep trying.
func RetryWithCallback(timeout time.Duration, b Balancer, cb Callback) endpoint.Endpoint {
if cb == nil {
cb = alwaysRetry
}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can be safer, here. If there's a nil callback, we can substitute a no-op or always-pass implementation instead.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Substituting a no-op now, with a test.

if b == nil {
panic("nil Balancer")
}

return func(ctx context.Context, request interface{}) (response interface{}, err error) {
var (
newctx, cancel = context.WithTimeout(ctx, timeout)
responses = make(chan interface{}, 1)
errs = make(chan error, 1)
a = []string{}
final RetryError
)
defer cancel()
for i := 1; i <= max; i++ {

for i := 1; ; i++ {
go func() {
e, err := b.Endpoint()
if err != nil {
Expand All @@ -45,13 +97,22 @@ func Retry(max int, timeout time.Duration, b Balancer) endpoint.Endpoint {
select {
case <-newctx.Done():
return nil, newctx.Err()

case response := <-responses:
return response, nil

case err := <-errs:
a = append(a, err.Error())
final.RawErrors = append(final.RawErrors, err)
keepTrying, replacement := cb(i, err)
if replacement != nil {
err = replacement
}
if !keepTrying {
final.Final = err
return nil, final
}
continue
}
}
return nil, fmt.Errorf("retry attempts exceeded (%s)", strings.Join(a, "; "))
}
}
70 changes: 61 additions & 9 deletions sd/lb/retry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ import (

"github.com/go-kit/kit/endpoint"
"github.com/go-kit/kit/sd"
loadbalancer "github.com/go-kit/kit/sd/lb"
"github.com/go-kit/kit/sd/lb"
)

func TestRetryMaxTotalFail(t *testing.T) {
var (
endpoints = sd.FixedSubscriber{} // no endpoints
lb = loadbalancer.NewRoundRobin(endpoints)
retry = loadbalancer.Retry(999, time.Second, lb) // lots of retries
rr = lb.NewRoundRobin(endpoints)
retry = lb.Retry(999, time.Second, rr) // lots of retries
ctx = context.Background()
)
if _, err := retry(ctx, struct{}{}); err == nil {
Expand All @@ -37,11 +37,11 @@ func TestRetryMaxPartialFail(t *testing.T) {
2: endpoints[2],
}
retries = len(endpoints) - 1 // not quite enough retries
lb = loadbalancer.NewRoundRobin(subscriber)
rr = lb.NewRoundRobin(subscriber)
ctx = context.Background()
)
if _, err := loadbalancer.Retry(retries, time.Second, lb)(ctx, struct{}{}); err == nil {
t.Errorf("expected error, got none")
if _, err := lb.Retry(retries, time.Second, rr)(ctx, struct{}{}); err == nil {
t.Errorf("expected error two, got none")
}
}

Expand All @@ -58,10 +58,10 @@ func TestRetryMaxSuccess(t *testing.T) {
2: endpoints[2],
}
retries = len(endpoints) // exactly enough retries
lb = loadbalancer.NewRoundRobin(subscriber)
rr = lb.NewRoundRobin(subscriber)
ctx = context.Background()
)
if _, err := loadbalancer.Retry(retries, time.Second, lb)(ctx, struct{}{}); err != nil {
if _, err := lb.Retry(retries, time.Second, rr)(ctx, struct{}{}); err != nil {
t.Error(err)
}
}
Expand All @@ -71,7 +71,7 @@ func TestRetryTimeout(t *testing.T) {
step = make(chan struct{})
e = func(context.Context, interface{}) (interface{}, error) { <-step; return struct{}{}, nil }
timeout = time.Millisecond
retry = loadbalancer.Retry(999, timeout, loadbalancer.NewRoundRobin(sd.FixedSubscriber{0: e}))
retry = lb.Retry(999, timeout, lb.NewRoundRobin(sd.FixedSubscriber{0: e}))
errs = make(chan error, 1)
invoke = func() { _, err := retry(context.Background(), struct{}{}); errs <- err }
)
Expand All @@ -88,3 +88,55 @@ func TestRetryTimeout(t *testing.T) {
t.Errorf("wanted %v, got none", context.DeadlineExceeded)
}
}

// TestAbortEarlyCustomMessage checks that when the callback aborts on the
// first failure and supplies a replacement error, that replacement is what
// ends up in RetryError.Final.
func TestAbortEarlyCustomMessage(t *testing.T) {
	var (
		myErr     = errors.New("aborting early")
		cb        = func(int, error) (bool, error) { return false, myErr }
		endpoints = sd.FixedSubscriber{} // no endpoints
		rr        = lb.NewRoundRobin(endpoints)
		retry     = lb.RetryWithCallback(time.Second, rr, cb) // cb stops after the first failure
		ctx       = context.Background()
	)
	_, err := retry(ctx, struct{}{})
	// Use the comma-ok form so an unexpected error type (or a nil error) is
	// reported as a test failure rather than panicking the test binary.
	retryErr, ok := err.(lb.RetryError)
	if !ok {
		t.Fatalf("want lb.RetryError, have %T (%v)", err, err)
	}
	if want, have := myErr, retryErr.Final; want != have {
		t.Errorf("want %v, have %v", want, have)
	}
}

// TestErrorPassedUnchangedToCallback checks that the callback receives exactly
// the error returned by the endpoint, and that the same error becomes
// RetryError.Final when the callback supplies no replacement.
func TestErrorPassedUnchangedToCallback(t *testing.T) {
	myErr := errors.New("my custom error")
	cb := func(_ int, err error) (bool, error) {
		if want, have := myErr, err; want != have {
			t.Errorf("want %v, have %v", want, have)
		}
		return false, nil // stop after the first failure, keep the original error
	}
	// Named "failing" rather than "endpoint" so the imported endpoint package
	// is not shadowed inside this test.
	failing := func(ctx context.Context, request interface{}) (interface{}, error) {
		return nil, myErr
	}
	var (
		endpoints = sd.FixedSubscriber{failing} // a single endpoint that always fails
		rr        = lb.NewRoundRobin(endpoints)
		retry     = lb.RetryWithCallback(time.Second, rr, cb)
		ctx       = context.Background()
	)
	_, err := retry(ctx, struct{}{})
	// Use the comma-ok form so an unexpected error type (or a nil error) is
	// reported as a test failure rather than panicking the test binary.
	retryErr, ok := err.(lb.RetryError)
	if !ok {
		t.Fatalf("want lb.RetryError, have %T (%v)", err, err)
	}
	if want, have := myErr, retryErr.Final; want != have {
		t.Errorf("want %v, have %v", want, have)
	}
}

func TestHandleNilCallback(t *testing.T) {
var (
subscriber = sd.FixedSubscriber{
func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ },
}
rr = lb.NewRoundRobin(subscriber)
ctx = context.Background()
)
retry := lb.RetryWithCallback(time.Second, rr, nil)
if _, err := retry(ctx, struct{}{}); err != nil {
t.Error(err)
}
}