diff --git a/client.go b/client.go index f68d0547..a88b6930 100644 --- a/client.go +++ b/client.go @@ -84,6 +84,12 @@ type ( // ErrorHook type is for reacting to request errors, called after all retries were attempted ErrorHook func(*Request, error) + + // Executor executes a Request + Executor func(req *Request) (*Response, error) + + // ExecutorMiddleware type wraps the execution of a request + ExecutorMiddleware func(req *Request, next Executor) (*Response, error) ) // Client struct is used to create Resty client with client level settings, @@ -136,6 +142,7 @@ type Client struct { requestLog RequestLogCallback responseLog ResponseLogCallback errorHooks []ErrorHook + executor Executor } // User type is to hold an username and password information @@ -423,6 +430,28 @@ func (c *Client) OnError(h ErrorHook) *Client { return c } +// WrapExecutor wraps the execution of a request, granting full access to the request, response, and error. +// Runs on every request attempt, before any request hook and after any response or error hook. +// Can be useful to introduce throttling or add hooks that always fire, regardless of success or error. +// +// c.WrapExecutor(func(req *Request, next Executor) (*Response, error) { +// // do something with the Request +// // e.g. Acquire a lock +// +// resp, err := next(req) +// // do something with the Response or error +// // e.g. Release a lock +// +// return resp, err +// }) +func (c *Client) WrapExecutor(e ExecutorMiddleware) *Client { + next := c.executor + c.executor = func(req *Request) (*Response, error) { + return e(req, next) + } + return c +} + // SetPreRequestHook method sets the given pre-request function into resty client. // It is called right before the request is fired. // @@ -1068,6 +1097,8 @@ func createClient(hc *http.Client) *Client { // Logger c.SetLogger(createLogger()) + c.executor = c.execute + // default before request middlewares c.beforeRequest = []RequestMiddleware{ parseRequestURL, diff --git a/client_test.go b/client_test.go index 84ae715e..1e3ce2d8 100644 --- a/client_test.go +++ b/client_test.go @@ -735,6 +735,45 @@ func TestClientOnResponseError(t *testing.T) { } } +func TestWrapExecutor(t *testing.T) { + ts := createGetServer(t) + defer ts.Close() + + t.Run("abort", func(t *testing.T) { + c := dc() + c.WrapExecutor(func(req *Request, next Executor) (*Response, error) { + return nil, fmt.Errorf("abort") + }) + + resp, err := c.R().Get(ts.URL) + assertNil(t, resp) + assertEqual(t, "abort", err.Error()) + }) + + t.Run("noop", func(t *testing.T) { + c := dc() + c.WrapExecutor(func(req *Request, next Executor) (*Response, error) { + return next(req) + }) + + resp, err := c.R().Get(ts.URL) + assertNil(t, err) + assertEqual(t, 200, resp.StatusCode()) + }) + + t.Run("add error", func(t *testing.T) { + c := dc() + c.WrapExecutor(func(req *Request, next Executor) (*Response, error) { + resp, _ := next(req) + return resp, fmt.Errorf("error") + }) + + resp, err := c.R().Get(ts.URL) + assertEqual(t, "error", err.Error()) + assertEqual(t, 200, resp.StatusCode()) + }) +} + func TestResponseError(t *testing.T) { err := errors.New("error message") re := &ResponseError{ diff --git a/example_test.go b/example_test.go index 2d8d3f74..9522855c 100644 --- a/example_test.go +++ b/example_test.go @@ -12,6 +12,7 @@ import ( "net/http" "os" "strconv" + "sync" "time" "golang.org/x/net/proxy" @@ -241,3 +242,35 @@ func Example_socks5Proxy() { func printOutput(resp *resty.Response, err error) { fmt.Println(resp, err) } + +// +// Throttling +// + +func ExampleClient_throttling() { + // Consider the use of proper throttler, possibly waiting for resources to free up + // e.g. https://github.com/throttled/throttled or https://pkg.go.dev/golang.org/x/time/rate + var lock sync.Mutex + currentConcurrent := 0 + maxConcurrent := 10 + + resty.New().WrapExecutor(func(req *resty.Request, next resty.Executor) (*resty.Response, error) { + lock.Lock() + current := currentConcurrent + if current == maxConcurrent { + lock.Unlock() + return nil, fmt.Errorf("max concurrency exceeded") + } + + current++ + lock.Unlock() + + defer func() { + lock.Lock() + current-- + lock.Unlock() + }() + + return next(req) + }) +} diff --git a/request.go b/request.go index 52166e69..5cdc41e3 100644 --- a/request.go +++ b/request.go @@ -753,7 +753,7 @@ func (r *Request) Execute(method, url string) (*Response, error) { if r.client.RetryCount == 0 { r.Attempt = 1 - resp, err = r.client.execute(r) + resp, err = r.client.executor(r) r.client.onErrorHooks(r, resp, unwrapNoRetryErr(err)) return resp, unwrapNoRetryErr(err) } @@ -764,7 +764,7 @@ func (r *Request) Execute(method, url string) (*Response, error) { r.URL = r.selectAddr(addrs, url, r.Attempt) - resp, err = r.client.execute(r) + resp, err = r.client.executor(r) if err != nil { r.client.log.Errorf("%v, Attempt %v", err, r.Attempt) }