diff --git a/client.go b/client.go index c713423408..2179692702 100644 --- a/client.go +++ b/client.go @@ -48,6 +48,11 @@ type Response = fasthttp.Response // Copy from fasthttp type Args = fasthttp.Args +// RetryIfFunc signature of retry if function +// Request argument passed to RetryIfFunc, if there are any request errors. +// Copy from fasthttp +type RetryIfFunc = fasthttp.RetryIfFunc + var defaultClient Client // Client implements http client. @@ -718,6 +723,14 @@ func (a *Agent) Dest(dest []byte) *Agent { return a } +// RetryIf controls whether a retry should be attempted after an error. +// +// By default, will use isIdempotent function from fasthttp +func (a *Agent) RetryIf(retryIf RetryIfFunc) *Agent { + a.HostClient.RetryIf = retryIf + return a +} + /************************** End Agent Setting **************************/ var warnOnce sync.Once diff --git a/client_test.go b/client_test.go index 730ba2c76a..81480b29ec 100644 --- a/client_test.go +++ b/client_test.go @@ -6,6 +6,7 @@ import ( "encoding/base64" stdjson "encoding/json" "errors" + "fmt" "io" "io/ioutil" "mime/multipart" @@ -562,6 +563,66 @@ func Test_Client_Agent_Dest(t *testing.T) { }) } +// readErrorConn is a struct for testing retryIf +type readErrorConn struct { + net.Conn +} + +func (r *readErrorConn) Read(p []byte) (int, error) { + return 0, fmt.Errorf("error") +} + +func (r *readErrorConn) Write(p []byte) (int, error) { + return len(p), nil +} + +func (r *readErrorConn) Close() error { + return nil +} + +func (r *readErrorConn) LocalAddr() net.Addr { + return nil +} + +func (r *readErrorConn) RemoteAddr() net.Addr { + return nil +} +func Test_Client_Agent_RetryIf(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + app := New(Config{DisableStartupMessage: true}) + + go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + + a := Post("http://example.com"). + RetryIf(func(req *Request) bool { + return true + }) + dialsCount := 0 + a.HostClient.Dial = func(addr string) (net.Conn, error) { + dialsCount++ + switch dialsCount { + case 1: + return &readErrorConn{}, nil + case 2: + return &readErrorConn{}, nil + case 3: + return &readErrorConn{}, nil + case 4: + return ln.Dial() + default: + t.Fatalf("unexpected number of dials: %d", dialsCount) + } + panic("unreachable") + } + + _, _, errs := a.String() + utils.AssertEqual(t, dialsCount, 4) + utils.AssertEqual(t, 0, len(errs)) +} + func Test_Client_Stdjson_Gojson(t *testing.T) { type User struct { Account *string `json:"account"`