Skip to content

Commit

Permalink
Sets request's GetBody field on wrapper
Browse files Browse the repository at this point in the history
This commit addresses two key issues:
- Upon receiving a temporary redirect, net/http will only preserve the request's body if GetBody() is defined.
- Upon receiving a GOAWAY frame, the client will create a new connection. We define GetBody() in order to reuse the body sent in the last stream on the now terminated connection.
  • Loading branch information
sebasslash committed Nov 6, 2023
1 parent 571a88b commit 9bb2062
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 5 deletions.
25 changes: 20 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,20 @@ func (r *Request) SetBody(rawBody interface{}) error {
}
r.body = bodyReader
r.ContentLength = contentLength
if bodyReader != nil {
r.GetBody = func() (io.ReadCloser, error) {
body, err := bodyReader()
if err != nil {
return nil, err
}
if rc, ok := body.(io.ReadCloser); ok {
return rc, nil
}
return io.NopCloser(body), nil
}
} else {
r.GetBody = func() (io.ReadCloser, error) { return http.NoBody, nil }
}
return nil
}

Expand Down Expand Up @@ -302,18 +316,19 @@ func NewRequest(method, url string, rawBody interface{}) (*Request, error) {
// The context controls the entire lifetime of a request and its response:
// obtaining a connection, sending the request, and reading the response headers and body.
func NewRequestWithContext(ctx context.Context, method, url string, rawBody interface{}) (*Request, error) {
bodyReader, contentLength, err := getBodyReaderAndContentLength(rawBody)
httpReq, err := http.NewRequestWithContext(ctx, method, url, nil)
if err != nil {
return nil, err
}

httpReq, err := http.NewRequestWithContext(ctx, method, url, nil)
if err != nil {
req := &Request{
Request: httpReq,
}
if err := req.SetBody(rawBody); err != nil {
return nil, err
}
httpReq.ContentLength = contentLength

return &Request{body: bodyReader, Request: httpReq}, nil
return req, nil
}

// Logger interface allows to use other loggers than
Expand Down
59 changes: 59 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -978,3 +978,62 @@ func TestClient_StandardClient(t *testing.T) {
t.Fatalf("expected %v, got %v", client, v)
}
}

func TestClient_RedirectWithBody(t *testing.T) {
var redirects int32
// Mock server which always responds 200.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.RequestURI {
case "/redirect":
w.Header().Set("Location", "/target")
w.WriteHeader(http.StatusTemporaryRedirect)
case "/target":
atomic.AddInt32(&redirects, 1)
w.WriteHeader(http.StatusCreated)
default:
t.Fatalf("bad uri: %s", r.RequestURI)
}
}))
defer ts.Close()

client := NewClient()
client.RequestLogHook = func(logger Logger, req *http.Request, retryNumber int) {
if _, err := req.GetBody(); err != nil {
t.Fatalf("unexpected error with GetBody: %v", err)
}
}
// create a request with a body
req, err := NewRequest(http.MethodPost, ts.URL+"/redirect", strings.NewReader(`{"foo":"bar"}`))
if err != nil {
t.Fatalf("err: %v", err)
}

resp, err := client.Do(req)
if err != nil {
t.Fatalf("err: %v", err)
}
resp.Body.Close()

if resp.StatusCode != http.StatusCreated {
t.Fatalf("expected status code 201, got: %d", resp.StatusCode)
}

// now one without a body
if err := req.SetBody(nil); err != nil {
t.Fatalf("err: %v", err)
}

resp, err = client.Do(req)
if err != nil {
t.Fatalf("err: %v", err)
}
resp.Body.Close()

if resp.StatusCode != http.StatusCreated {
t.Fatalf("expected status code 201, got: %d", resp.StatusCode)
}

if atomic.LoadInt32(&redirects) != 2 {
t.Fatalf("Expected the client to be redirected 2 times, got: %d", atomic.LoadInt32(&redirects))
}
}

0 comments on commit 9bb2062

Please sign in to comment.