Skip to content

Commit

Permalink
Merge pull request #161 from hashicorp/retry-extra-handling
Browse files Browse the repository at this point in the history
Retry extra handling
  • Loading branch information
gavriel-hc committed Apr 13, 2022
2 parents ff6d014 + 5bd1a6f commit 98169fe
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 18 deletions.
19 changes: 19 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,25 @@ The returned response object is an `*http.Response`, the same thing you would
usually get from `net/http`. Had the request failed one or more times, the above
call would block and retry with exponential backoff.

## Retrying cases that fail after a seeming success

It's possible for a request to succeed in the sense that the expected response headers are received, but then to encounter network-level errors while reading the response body. In go-retryablehttp's most basic usage, this error would not be retryable, due to the out-of-band handling of the response body. In some cases it may be desirable to handle the response body as part of the retryable operation.

A toy example (which will retry the full request and succeed on the second attempt) is shown below:

```go
c := retryablehttp.NewClient()
r := retryablehttp.NewRequest("GET", "://foo", nil)
handlerShouldRetry := true
r.SetResponseHandler(func(*http.Response) error {
if !handlerShouldRetry {
return nil
}
handlerShouldRetry = false
return errors.New("retryable error")
})
```

## Getting a stdlib `*http.Client` with retries

It's possible to convert a `*retryablehttp.Client` directly to a `*http.Client`.
Expand Down
54 changes: 44 additions & 10 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,21 @@ var (
// scheme specified in the URL is invalid. This error isn't typed
// specifically so we resort to matching on the error string.
schemeErrorRe = regexp.MustCompile(`unsupported protocol scheme`)

// A regular expression to match the error returned by net/http when the
// TLS certificate is not trusted. This error isn't typed
// specifically so we resort to matching on the error string.
notTrustedErrorRe = regexp.MustCompile(`certificate is not trusted`)
)

// ReaderFunc is the type of function that can be given natively to NewRequest
type ReaderFunc func() (io.Reader, error)

// ResponseHandlerFunc is a type of function that takes in a Response, and does something with it.
// It only runs if the initial part of the request was successful.
// If an error is returned, the client's retry policy will be used to determine whether to retry the whole request.
type ResponseHandlerFunc func(*http.Response) error

// LenReader is an interface implemented by many in-memory io.Reader's. Used
// for automatically sending the right Content-Length header when possible.
type LenReader interface {
Expand All @@ -86,6 +96,8 @@ type Request struct {
// used to rewind the request data in between retries.
body ReaderFunc

responseHandler ResponseHandlerFunc

// Embed an HTTP request directly. This makes a *Request act exactly
// like an *http.Request so that all meta methods are supported.
*http.Request
Expand All @@ -95,11 +107,17 @@ type Request struct {
// with its context changed to ctx. The provided ctx must be non-nil.
func (r *Request) WithContext(ctx context.Context) *Request {
return &Request{
body: r.body,
Request: r.Request.WithContext(ctx),
body: r.body,
responseHandler: r.responseHandler,
Request: r.Request.WithContext(ctx),
}
}

// SetResponseHandler allows setting the response handler.
func (r *Request) SetResponseHandler(fn ResponseHandlerFunc) {
r.responseHandler = fn
}

// BodyBytes allows accessing the request body. It is an analogue to
// http.Request's Body variable, but it returns a copy of the underlying data
// rather than consuming it.
Expand Down Expand Up @@ -254,7 +272,7 @@ func FromRequest(r *http.Request) (*Request, error) {
return nil, err
}
// Could assert contentLength == r.ContentLength
return &Request{bodyReader, r}, nil
return &Request{body: bodyReader, Request: r}, nil
}

// NewRequest creates a new wrapped request.
Expand All @@ -278,7 +296,7 @@ func NewRequestWithContext(ctx context.Context, method, url string, rawBody inte
}
httpReq.ContentLength = contentLength

return &Request{bodyReader, httpReq}, nil
return &Request{body: bodyReader, Request: httpReq}, nil
}

// Logger interface allows to use other loggers than
Expand Down Expand Up @@ -445,6 +463,9 @@ func baseRetryPolicy(resp *http.Response, err error) (bool, error) {
}

// Don't retry if the error was due to TLS cert verification failure.
if notTrustedErrorRe.MatchString(v.Error()) {
return false, v
}
if _, ok := v.Err.(x509.UnknownAuthorityError); ok {
return false, v
}
Expand Down Expand Up @@ -565,9 +586,10 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
var resp *http.Response
var attempt int
var shouldRetry bool
var doErr, checkErr error
var doErr, respErr, checkErr error

for i := 0; ; i++ {
doErr, respErr = nil, nil
attempt++

// Always rewind the request body when non-nil.
Expand Down Expand Up @@ -600,13 +622,21 @@ func (c *Client) Do(req *Request) (*http.Response, error) {

// Check if we should continue with retries.
shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, doErr)
if !shouldRetry && doErr == nil && req.responseHandler != nil {
respErr = req.responseHandler(resp)
shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, respErr)
}

if doErr != nil {
err := doErr
if respErr != nil {
err = respErr
}
if err != nil {
switch v := logger.(type) {
case LeveledLogger:
v.Error("request failed", "error", doErr, "method", req.Method, "url", req.URL)
v.Error("request failed", "error", err, "method", req.Method, "url", req.URL)
case Logger:
v.Printf("[ERR] %s %s request failed: %v", req.Method, req.URL, doErr)
v.Printf("[ERR] %s %s request failed: %v", req.Method, req.URL, err)
}
} else {
// Call this here to maintain the behavior of logging all requests,
Expand Down Expand Up @@ -669,15 +699,19 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
}

// this is the closest we have to success criteria
if doErr == nil && checkErr == nil && !shouldRetry {
if doErr == nil && respErr == nil && checkErr == nil && !shouldRetry {
return resp, nil
}

defer c.HTTPClient.CloseIdleConnections()

err := doErr
var err error
if checkErr != nil {
err = checkErr
} else if respErr != nil {
err = respErr
} else {
err = doErr
}

if c.ErrorHandler != nil {
Expand Down
108 changes: 104 additions & 4 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,13 +167,13 @@ func testClientDo(t *testing.T, body interface{}) {
// Send the request
var resp *http.Response
doneCh := make(chan struct{})
errCh := make(chan error, 1)
go func() {
defer close(doneCh)
defer close(errCh)
var err error
resp, err = client.Do(req)
if err != nil {
t.Fatalf("err: %v", err)
}
errCh <- err
}()

select {
Expand Down Expand Up @@ -247,6 +247,106 @@ func testClientDo(t *testing.T, body interface{}) {
if retryCount < 0 {
t.Fatal("request log hook was not called")
}

err = <-errCh
if err != nil {
t.Fatalf("err: %v", err)
}
}

func TestClient_Do_WithResponseHandler(t *testing.T) {
// Create the client. Use short retry windows so we fail faster.
client := NewClient()
client.RetryWaitMin = 10 * time.Millisecond
client.RetryWaitMax = 10 * time.Millisecond
client.RetryMax = 2

var checks int
client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) {
checks++
if err != nil && strings.Contains(err.Error(), "nonretryable") {
return false, nil
}
return DefaultRetryPolicy(context.TODO(), resp, err)
}

// Mock server which always responds 200.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
}))
defer ts.Close()

var shouldSucceed bool
tests := []struct {
name string
handler ResponseHandlerFunc
expectedChecks int // often 2x number of attempts since we check twice
err string
}{
{
name: "nil handler",
handler: nil,
expectedChecks: 1,
},
{
name: "handler always succeeds",
handler: func(*http.Response) error {
return nil
},
expectedChecks: 2,
},
{
name: "handler always fails in a retryable way",
handler: func(*http.Response) error {
return errors.New("retryable failure")
},
expectedChecks: 6,
},
{
name: "handler always fails in a nonretryable way",
handler: func(*http.Response) error {
return errors.New("nonretryable failure")
},
expectedChecks: 2,
},
{
name: "handler succeeds on second attempt",
handler: func(*http.Response) error {
if shouldSucceed {
return nil
}
shouldSucceed = true
return errors.New("retryable failure")
},
expectedChecks: 4,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
checks = 0
shouldSucceed = false
// Create the request
req, err := NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatalf("err: %v", err)
}
req.SetResponseHandler(tt.handler)

// Send the request.
_, err = client.Do(req)
if err != nil && !strings.Contains(err.Error(), tt.err) {
t.Fatalf("error does not match expectation, expected: %s, got: %s", tt.err, err.Error())
}
if err == nil && tt.err != "" {
t.Fatalf("no error, expected: %s", tt.err)
}

if checks != tt.expectedChecks {
t.Fatalf("expected %d attempts, got %d attempts", tt.expectedChecks, checks)
}
})
}
}

func TestClient_Do_fails(t *testing.T) {
Expand Down Expand Up @@ -598,7 +698,7 @@ func TestClient_DefaultRetryPolicy_TLS(t *testing.T) {

func TestClient_DefaultRetryPolicy_redirects(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/", 302)
http.Redirect(w, r, "/", http.StatusFound)
}))
defer ts.Close()

Expand Down
8 changes: 4 additions & 4 deletions roundtripper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,12 @@ func TestRoundTripper_TransportFailureErrorHandling(t *testing.T) {

expectedError := &url.Error{
Op: "Get",
URL: "http://this-url-does-not-exist-ed2fb.com/",
URL: "http://999.999.999.999:999/",
Err: &net.OpError{
Op: "dial",
Net: "tcp",
Err: &net.DNSError{
Name: "this-url-does-not-exist-ed2fb.com",
Name: "999.999.999.999",
Err: "no such host",
IsNotFound: true,
},
Expand All @@ -121,10 +121,10 @@ func TestRoundTripper_TransportFailureErrorHandling(t *testing.T) {

// Get the standard client and execute the request.
client := retryClient.StandardClient()
_, err := client.Get("http://this-url-does-not-exist-ed2fb.com/")
_, err := client.Get("http://999.999.999.999:999/")

// assert expectations
if !reflect.DeepEqual(normalizeError(err), expectedError) {
if !reflect.DeepEqual(expectedError, normalizeError(err)) {
t.Fatalf("expected %q, got %q", expectedError, err)
}
}
Expand Down

0 comments on commit 98169fe

Please sign in to comment.