Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Retry extra handling #161

Merged
merged 7 commits into from
Apr 13, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
54 changes: 28 additions & 26 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ var (
// ReaderFunc is the type of function that can be given natively to NewRequest
type ReaderFunc func() (io.Reader, error)

// ResponseHandlingFunc is a type of function that takes in a Response, and does something with it.
// It only runs if the initial part of the request was successful.
// If an error is returned, the client's retry policy will be used to determine whether to retry the whole request.
type ResponseHandlingFunc func(*http.Response) error
gavriel-hc marked this conversation as resolved.
Show resolved Hide resolved

// LenReader is an interface implemented by many in-memory io.Reader's. Used
// for automatically sending the right Content-Length header when possible.
type LenReader interface {
Expand All @@ -91,6 +96,8 @@ type Request struct {
// used to rewind the request data in between retries.
body ReaderFunc

responseHandler ResponseHandlingFunc

// Embed an HTTP request directly. This makes a *Request act exactly
// like an *http.Request so that all meta methods are supported.
*http.Request
Expand All @@ -100,11 +107,17 @@ type Request struct {
// with its context changed to ctx. The provided ctx must be non-nil.
func (r *Request) WithContext(ctx context.Context) *Request {
return &Request{
body: r.body,
Request: r.Request.WithContext(ctx),
body: r.body,
responseHandler: r.responseHandler,
Request: r.Request.WithContext(ctx),
}
}

// SetResponseHandler allows setting the response handler.
func (r *Request) SetResponseHandler(fn ResponseHandlingFunc) {
r.responseHandler = fn
}

// BodyBytes allows accessing the request body. It is an analogue to
// http.Request's Body variable, but it returns a copy of the underlying data
// rather than consuming it.
Expand Down Expand Up @@ -259,7 +272,7 @@ func FromRequest(r *http.Request) (*Request, error) {
return nil, err
}
// Could assert contentLength == r.ContentLength
return &Request{bodyReader, r}, nil
return &Request{body: bodyReader, Request: r}, nil
}

// NewRequest creates a new wrapped request.
Expand All @@ -283,7 +296,7 @@ func NewRequestWithContext(ctx context.Context, method, url string, rawBody inte
}
httpReq.ContentLength = contentLength

return &Request{bodyReader, httpReq}, nil
return &Request{body: bodyReader, Request: httpReq}, nil
}

// Logger interface allows to use other loggers than
Expand Down Expand Up @@ -553,11 +566,6 @@ func PassthroughErrorHandler(resp *http.Response, err error, _ int) (*http.Respo

// Do wraps calling an HTTP method with retries.
func (c *Client) Do(req *Request) (*http.Response, error) {
return c.DoWithResponseHandler(req, nil)
}

// DoWithResponseHandler wraps calling an HTTP method plus a response handler with retries.
func (c *Client) DoWithResponseHandler(req *Request, handler func(*http.Response) (shouldRetry bool)) (*http.Response, error) {
c.clientInit.Do(func() {
if c.HTTPClient == nil {
c.HTTPClient = cleanhttp.DefaultPooledClient()
Expand All @@ -578,7 +586,7 @@ func (c *Client) DoWithResponseHandler(req *Request, handler func(*http.Response
var resp *http.Response
var attempt int
var shouldRetry bool
var doErr, checkErr error
var doErr, respErr, checkErr error

for i := 0; ; i++ {
attempt++
Expand Down Expand Up @@ -636,11 +644,11 @@ func (c *Client) DoWithResponseHandler(req *Request, handler func(*http.Response

// Check if we should continue with retries.
shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, doErr)

successSoFar := !shouldRetry && doErr == nil && checkErr == nil
if successSoFar && handler != nil {
shouldRetry = handler(resp)
if !shouldRetry && doErr == nil && req.responseHandler != nil {
respErr = req.responseHandler(resp)
gavriel-hc marked this conversation as resolved.
Show resolved Hide resolved
shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, respErr)
}

if !shouldRetry {
break
}
Expand Down Expand Up @@ -686,15 +694,19 @@ func (c *Client) DoWithResponseHandler(req *Request, handler func(*http.Response
}

// this is the closest we have to success criteria
if doErr == nil && checkErr == nil && !shouldRetry {
if doErr == nil && respErr == nil && checkErr == nil && !shouldRetry {
return resp, nil
}

defer c.HTTPClient.CloseIdleConnections()

err := doErr
var err error
if checkErr != nil {
err = checkErr
} else if respErr != nil {
err = respErr
} else {
err = doErr
}

if c.ErrorHandler != nil {
Expand Down Expand Up @@ -748,16 +760,6 @@ func (c *Client) Get(url string) (*http.Response, error) {
return c.Do(req)
}

// GetWithResponseHandler is a helper for doing a GET request followed by a function on the response.
// The intention is for this to be used when errors in the response handling should also be retried.
func (c *Client) GetWithResponseHandler(url string, handler func(*http.Response) (shouldRetry bool)) (*http.Response, error) {
req, err := NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
return c.DoWithResponseHandler(req, handler)
}

// Head is a shortcut for doing a HEAD request without making a new client.
func Head(url string) (*http.Response, error) {
return defaultClient.Head(url)
Expand Down
72 changes: 45 additions & 27 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,16 +254,19 @@ func testClientDo(t *testing.T, body interface{}) {
}
}

func TestClient_DoWithHandler(t *testing.T) {
func TestClient_Do_WithResponseHandler(t *testing.T) {
// Create the client. Use short retry windows so we fail faster.
client := NewClient()
client.RetryWaitMin = 10 * time.Millisecond
client.RetryWaitMax = 10 * time.Millisecond
client.RetryMax = 2

var attempts int
var checks int
client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) {
attempts++
checks++
if err != nil && strings.Contains(err.Error(), "nonretryable") {
return false, nil
}
return DefaultRetryPolicy(context.TODO(), resp, err)
}

Expand All @@ -273,59 +276,74 @@ func TestClient_DoWithHandler(t *testing.T) {
}))
defer ts.Close()

alternatingBool := false
var shouldSucceed bool
tests := []struct {
name string
handler func(*http.Response) bool
expectedAttempts int
err string
name string
handler ResponseHandlingFunc
expectedChecks int // often 2x number of attempts since we check twice
err string
}{
{
name: "nil handler",
handler: nil,
expectedAttempts: 1,
name: "nil handler",
handler: nil,
expectedChecks: 1,
},
{
name: "handler always succeeds",
handler: func(*http.Response) error {
return nil
},
expectedChecks: 2,
},
{
name: "handler never should retry",
handler: func(*http.Response) bool { return false },
expectedAttempts: 1,
name: "handler always fails in a retryable way",
handler: func(*http.Response) error {
return errors.New("retryable failure")
},
expectedChecks: 6,
},
{
name: "handler alternates should retry",
handler: func(*http.Response) bool {
alternatingBool = !alternatingBool
return alternatingBool
name: "handler always fails in a nonretryable way",
handler: func(*http.Response) error {
return errors.New("nonretryable failure")
},
expectedAttempts: 2,
expectedChecks: 2,
},
{
name: "handler always should retry",
handler: func(*http.Response) bool { return true },
expectedAttempts: 3,
err: "giving up after 3 attempt(s)",
name: "handler succeeds on second attempt",
handler: func(*http.Response) error {
if shouldSucceed {
return nil
}
shouldSucceed = true
return errors.New("retryable failure")
},
expectedChecks: 4,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
attempts = 0
checks = 0
shouldSucceed = false
// Create the request
req, err := NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatalf("err: %v", err)
}
req.SetResponseHandler(tt.handler)

// Send the request.
_, err = client.DoWithResponseHandler(req, tt.handler)
_, err = client.Do(req)
if err != nil && !strings.Contains(err.Error(), tt.err) {
t.Fatalf("error does not match expectation, expected: %s, got: %s", tt.err, err.Error())
}
if err == nil && tt.err != "" {
t.Fatalf("no error, expected: %s", tt.err)
}

if attempts != tt.expectedAttempts {
t.Fatalf("expected %d attempts, got %d attempts", tt.expectedAttempts, attempts)
if checks != tt.expectedChecks {
t.Fatalf("expected %d attempts, got %d attempts", tt.expectedChecks, checks)
}
})
}
Expand Down
6 changes: 3 additions & 3 deletions roundtripper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,12 @@ func TestRoundTripper_TransportFailureErrorHandling(t *testing.T) {

expectedError := &url.Error{
Op: "Get",
URL: "http://asdfsa.com/",
URL: "http://999.999.999.999:999/",
Err: &net.OpError{
Op: "dial",
Net: "tcp",
Err: &net.DNSError{
Name: "asdfsa.com",
Name: "999.999.999.999",
Err: "no such host",
IsNotFound: true,
},
Expand All @@ -121,7 +121,7 @@ func TestRoundTripper_TransportFailureErrorHandling(t *testing.T) {

// Get the standard client and execute the request.
client := retryClient.StandardClient()
_, err := client.Get("http://asdfsa.com/")
_, err := client.Get("http://999.999.999.999:999/")

// assert expectations
if !reflect.DeepEqual(expectedError, normalizeError(err)) {
Expand Down