diff --git a/middleware.go b/middleware.go index c4ed8d96..91eed9b2 100644 --- a/middleware.go +++ b/middleware.go @@ -263,17 +263,19 @@ func createHTTPRequest(c *Client, r *Request) (err error) { r.RawRequest = r.RawRequest.WithContext(r.ctx) } - bodyCopy, err := getBodyCopy(r) - if err != nil { - return err - } - // assign get body func for the underlying raw request instance - r.RawRequest.GetBody = func() (io.ReadCloser, error) { + if r.RawRequest.GetBody == nil { + bodyCopy, err := getBodyCopy(r) + if err != nil { + return err + } if bodyCopy != nil { - return io.NopCloser(bytes.NewReader(bodyCopy.Bytes())), nil + buf := bodyCopy.Bytes() + r.RawRequest.GetBody = func() (io.ReadCloser, error) { + b := bytes.NewReader(buf) + return io.NopCloser(b), nil + } } - return nil, nil } return diff --git a/retry_test.go b/retry_test.go index 84c12a48..cef9e81c 100644 --- a/retry_test.go +++ b/retry_test.go @@ -219,8 +219,8 @@ func TestClientRetryWait(t *testing.T) { retryIntervals := make([]uint64, retryCount+1) // Set retry wait times that do not intersect with default ones - retryWaitTime := time.Duration(3) * time.Second - retryMaxWaitTime := time.Duration(9) * time.Second + retryWaitTime := time.Duration(50) * time.Millisecond + retryMaxWaitTime := time.Duration(150) * time.Millisecond c := dc(). SetRetryCount(retryCount). @@ -262,7 +262,7 @@ func TestClientRetryWaitMaxInfinite(t *testing.T) { retryIntervals := make([]uint64, retryCount+1) // Set retry wait times that do not intersect with default ones - retryWaitTime := time.Duration(3) * time.Second + retryWaitTime := time.Duration(100) * time.Millisecond retryMaxWaitTime := time.Duration(-1.0) // negative value c := dc(). @@ -319,8 +319,8 @@ func TestClientRetryWaitCallbackError(t *testing.T) { retryIntervals := make([]uint64, retryCount+1) // Set retry wait times that do not intersect with default ones - retryWaitTime := 3 * time.Second - retryMaxWaitTime := 9 * time.Second + retryWaitTime := 50 * time.Millisecond + retryMaxWaitTime := 150 * time.Millisecond retryAfter := func(client *Client, resp *Response) (time.Duration, error) { return 0, errors.New("quota exceeded") @@ -359,11 +359,11 @@ func TestClientRetryWaitCallback(t *testing.T) { retryIntervals := make([]uint64, retryCount+1) // Set retry wait times that do not intersect with default ones - retryWaitTime := 3 * time.Second - retryMaxWaitTime := 9 * time.Second + retryWaitTime := 50 * time.Millisecond + retryMaxWaitTime := 150 * time.Millisecond retryAfter := func(client *Client, resp *Response) (time.Duration, error) { - return 5 * time.Second, nil + return 50 * time.Millisecond, nil } c := dc(). @@ -407,11 +407,11 @@ func TestClientRetryWaitCallbackTooShort(t *testing.T) { retryIntervals := make([]uint64, retryCount+1) // Set retry wait times that do not intersect with default ones - retryWaitTime := 3 * time.Second - retryMaxWaitTime := 9 * time.Second + retryWaitTime := 50 * time.Millisecond + retryMaxWaitTime := 150 * time.Millisecond retryAfter := func(client *Client, resp *Response) (time.Duration, error) { - return 2 * time.Second, nil // too short duration + return 10 * time.Millisecond, nil // too short duration } c := dc(). @@ -455,11 +455,11 @@ func TestClientRetryWaitCallbackTooLong(t *testing.T) { retryIntervals := make([]uint64, retryCount+1) // Set retry wait times that do not intersect with default ones - retryWaitTime := 1 * time.Second - retryMaxWaitTime := 3 * time.Second + retryWaitTime := 10 * time.Millisecond + retryMaxWaitTime := 100 * time.Millisecond retryAfter := func(client *Client, resp *Response) (time.Duration, error) { - return 4 * time.Second, nil // too long duration + return 150 * time.Millisecond, nil // too long duration } c := dc().