Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Retry extra handling #161

Merged
merged 7 commits into from
Apr 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 19 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,25 @@ The returned response object is an `*http.Response`, the same thing you would
usually get from `net/http`. Had the request failed one or more times, the above
call would block and retry with exponential backoff.

## Retrying cases that fail after a seeming success

It's possible for a request to succeed in the sense that the expected response headers are received, but then to encounter network-level errors while reading the response body. In go-retryablehttp's most basic usage, this error would not be retryable, due to the out-of-band handling of the response body. In some cases it may be desirable to handle the response body as part of the retryable operation.

A toy example (which will retry the full request and succeed on the second attempt) is shown below:

```go
c := retryablehttp.NewClient()
r := retryablehttp.NewRequest("GET", "://foo", nil)
handlerShouldRetry := true
r.SetResponseHandler(func(*http.Response) error {
if !handlerShouldRetry {
return nil
}
handlerShouldRetry = false
return errors.New("retryable error")
})
```

## Getting a stdlib `*http.Client` with retries

It's possible to convert a `*retryablehttp.Client` directly to a `*http.Client`.
Expand Down
54 changes: 44 additions & 10 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,21 @@ var (
// scheme specified in the URL is invalid. This error isn't typed
// specifically so we resort to matching on the error string.
schemeErrorRe = regexp.MustCompile(`unsupported protocol scheme`)

// A regular expression to match the error returned by net/http when the
// TLS certificate is not trusted. This error isn't typed
// specifically so we resort to matching on the error string.
notTrustedErrorRe = regexp.MustCompile(`certificate is not trusted`)
)

// ReaderFunc is the type of function that can be given natively to NewRequest
type ReaderFunc func() (io.Reader, error)

// ResponseHandlerFunc is a type of function that takes in a Response, and does something with it.
// It only runs if the initial part of the request was successful.
// If an error is returned, the client's retry policy will be used to determine whether to retry the whole request.
type ResponseHandlerFunc func(*http.Response) error

// LenReader is an interface implemented by many in-memory io.Reader's. Used
// for automatically sending the right Content-Length header when possible.
type LenReader interface {
Expand All @@ -86,6 +96,8 @@ type Request struct {
// used to rewind the request data in between retries.
body ReaderFunc

responseHandler ResponseHandlerFunc

// Embed an HTTP request directly. This makes a *Request act exactly
// like an *http.Request so that all meta methods are supported.
*http.Request
Expand All @@ -95,11 +107,17 @@ type Request struct {
// with its context changed to ctx. The provided ctx must be non-nil.
func (r *Request) WithContext(ctx context.Context) *Request {
return &Request{
body: r.body,
Request: r.Request.WithContext(ctx),
body: r.body,
responseHandler: r.responseHandler,
Request: r.Request.WithContext(ctx),
}
}

// SetResponseHandler allows setting the response handler.
func (r *Request) SetResponseHandler(fn ResponseHandlerFunc) {
r.responseHandler = fn
}

// BodyBytes allows accessing the request body. It is an analogue to
// http.Request's Body variable, but it returns a copy of the underlying data
// rather than consuming it.
Expand Down Expand Up @@ -254,7 +272,7 @@ func FromRequest(r *http.Request) (*Request, error) {
return nil, err
}
// Could assert contentLength == r.ContentLength
return &Request{bodyReader, r}, nil
return &Request{body: bodyReader, Request: r}, nil
}

// NewRequest creates a new wrapped request.
Expand All @@ -278,7 +296,7 @@ func NewRequestWithContext(ctx context.Context, method, url string, rawBody inte
}
httpReq.ContentLength = contentLength

return &Request{bodyReader, httpReq}, nil
return &Request{body: bodyReader, Request: httpReq}, nil
}

// Logger interface allows to use other loggers than
Expand Down Expand Up @@ -445,6 +463,9 @@ func baseRetryPolicy(resp *http.Response, err error) (bool, error) {
}

// Don't retry if the error was due to TLS cert verification failure.
if notTrustedErrorRe.MatchString(v.Error()) {
gavriel-hc marked this conversation as resolved.
Show resolved Hide resolved
return false, v
}
if _, ok := v.Err.(x509.UnknownAuthorityError); ok {
return false, v
}
Expand Down Expand Up @@ -565,9 +586,10 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
var resp *http.Response
var attempt int
var shouldRetry bool
var doErr, checkErr error
var doErr, respErr, checkErr error

for i := 0; ; i++ {
doErr, respErr = nil, nil
attempt++

// Always rewind the request body when non-nil.
Expand Down Expand Up @@ -600,13 +622,21 @@ func (c *Client) Do(req *Request) (*http.Response, error) {

// Check if we should continue with retries.
shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, doErr)
if !shouldRetry && doErr == nil && req.responseHandler != nil {
respErr = req.responseHandler(resp)
shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, respErr)
}

if doErr != nil {
err := doErr
if respErr != nil {
err = respErr
}
if err != nil {
switch v := logger.(type) {
case LeveledLogger:
v.Error("request failed", "error", doErr, "method", req.Method, "url", req.URL)
v.Error("request failed", "error", err, "method", req.Method, "url", req.URL)
case Logger:
v.Printf("[ERR] %s %s request failed: %v", req.Method, req.URL, doErr)
v.Printf("[ERR] %s %s request failed: %v", req.Method, req.URL, err)
}
} else {
// Call this here to maintain the behavior of logging all requests,
Expand Down Expand Up @@ -669,15 +699,19 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
}

// this is the closest we have to success criteria
if doErr == nil && checkErr == nil && !shouldRetry {
if doErr == nil && respErr == nil && checkErr == nil && !shouldRetry {
return resp, nil
}

defer c.HTTPClient.CloseIdleConnections()

err := doErr
var err error
if checkErr != nil {
err = checkErr
} else if respErr != nil {
err = respErr
} else {
err = doErr
}

if c.ErrorHandler != nil {
Expand Down
108 changes: 104 additions & 4 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,13 +167,13 @@ func testClientDo(t *testing.T, body interface{}) {
// Send the request
var resp *http.Response
doneCh := make(chan struct{})
errCh := make(chan error, 1)
go func() {
defer close(doneCh)
defer close(errCh)
var err error
resp, err = client.Do(req)
if err != nil {
t.Fatalf("err: %v", err)
}
errCh <- err
}()

select {
Expand Down Expand Up @@ -247,6 +247,106 @@ func testClientDo(t *testing.T, body interface{}) {
if retryCount < 0 {
t.Fatal("request log hook was not called")
}

err = <-errCh
if err != nil {
t.Fatalf("err: %v", err)
}
}

func TestClient_Do_WithResponseHandler(t *testing.T) {
// Create the client. Use short retry windows so we fail faster.
client := NewClient()
client.RetryWaitMin = 10 * time.Millisecond
client.RetryWaitMax = 10 * time.Millisecond
client.RetryMax = 2

var checks int
client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) {
checks++
if err != nil && strings.Contains(err.Error(), "nonretryable") {
return false, nil
}
return DefaultRetryPolicy(context.TODO(), resp, err)
}

// Mock server which always responds 200.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
}))
defer ts.Close()

var shouldSucceed bool
tests := []struct {
name string
handler ResponseHandlerFunc
expectedChecks int // often 2x number of attempts since we check twice
err string
}{
{
name: "nil handler",
handler: nil,
expectedChecks: 1,
},
{
name: "handler always succeeds",
handler: func(*http.Response) error {
return nil
},
expectedChecks: 2,
},
{
name: "handler always fails in a retryable way",
handler: func(*http.Response) error {
return errors.New("retryable failure")
},
expectedChecks: 6,
},
{
name: "handler always fails in a nonretryable way",
handler: func(*http.Response) error {
return errors.New("nonretryable failure")
},
expectedChecks: 2,
},
{
name: "handler succeeds on second attempt",
handler: func(*http.Response) error {
if shouldSucceed {
return nil
}
shouldSucceed = true
return errors.New("retryable failure")
},
expectedChecks: 4,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
checks = 0
shouldSucceed = false
// Create the request
req, err := NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatalf("err: %v", err)
}
req.SetResponseHandler(tt.handler)

// Send the request.
_, err = client.Do(req)
if err != nil && !strings.Contains(err.Error(), tt.err) {
t.Fatalf("error does not match expectation, expected: %s, got: %s", tt.err, err.Error())
}
if err == nil && tt.err != "" {
t.Fatalf("no error, expected: %s", tt.err)
}

if checks != tt.expectedChecks {
t.Fatalf("expected %d attempts, got %d attempts", tt.expectedChecks, checks)
}
})
}
}

func TestClient_Do_fails(t *testing.T) {
Expand Down Expand Up @@ -598,7 +698,7 @@ func TestClient_DefaultRetryPolicy_TLS(t *testing.T) {

func TestClient_DefaultRetryPolicy_redirects(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/", 302)
http.Redirect(w, r, "/", http.StatusFound)
}))
defer ts.Close()

Expand Down
8 changes: 4 additions & 4 deletions roundtripper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,12 @@ func TestRoundTripper_TransportFailureErrorHandling(t *testing.T) {

expectedError := &url.Error{
Op: "Get",
URL: "http://this-url-does-not-exist-ed2fb.com/",
URL: "http://999.999.999.999:999/",
Err: &net.OpError{
Op: "dial",
Net: "tcp",
Err: &net.DNSError{
Name: "this-url-does-not-exist-ed2fb.com",
Name: "999.999.999.999",
Err: "no such host",
IsNotFound: true,
},
Expand All @@ -121,10 +121,10 @@ func TestRoundTripper_TransportFailureErrorHandling(t *testing.T) {

// Get the standard client and execute the request.
client := retryClient.StandardClient()
_, err := client.Get("http://this-url-does-not-exist-ed2fb.com/")
_, err := client.Get("http://999.999.999.999:999/")

// assert expectations
if !reflect.DeepEqual(normalizeError(err), expectedError) {
if !reflect.DeepEqual(expectedError, normalizeError(err)) {
t.Fatalf("expected %q, got %q", expectedError, err)
}
}
Expand Down