Skip to content

Commit

Permalink
fix transport middleware cannot access common header and cookies
Browse files Browse the repository at this point in the history
  • Loading branch information
imroc committed Aug 12, 2023
1 parent ad9fd1e commit abc4962
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 36 deletions.
2 changes: 2 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1486,6 +1486,8 @@ func C() *Client {
Timeout: 2 * time.Minute,
}
beforeRequest := []RequestMiddleware{
parseRequestHeader,
parseRequestCookie,
parseRequestURL,
parseRequestBody,
}
Expand Down
23 changes: 23 additions & 0 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -472,3 +472,26 @@ func parseRequestURL(c *Client, r *Request) error {
r.URL = reqURL
return nil
}

func parseRequestHeader(c *Client, r *Request) error {
if c.Headers == nil {
return nil
}
if r.Headers == nil {
r.Headers = make(http.Header)
}
for k, vs := range c.Headers {
if len(r.Headers[k]) == 0 {
r.Headers[k] = vs
}
}
return nil
}

func parseRequestCookie(c *Client, r *Request) error {
if len(c.Cookies) == 0 {
return nil
}
r.Cookies = append(r.Cookies, c.Cookies...)
return nil
}
57 changes: 21 additions & 36 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -854,39 +854,6 @@ func (t *Transport) roundTripAltSvc(req *http.Request, as *altsvc.AltSvc) (resp
return
}

func (t *Transport) ensureHeaderAndCookie(req *http.Request, isHTTP bool) error {
if req.Header == nil {
closeBody(req)
return errors.New("http: nil Request.Header")
}
for k, vs := range t.Headers {
if len(req.Header[k]) == 0 {
req.Header[k] = vs
}
}
for _, c := range t.Cookies {
req.AddCookie(c)
}
if !isHTTP {
return nil
}
// TODO: is h2c should also check this?
for k, vv := range req.Header {
if !httpguts.ValidHeaderFieldName(k) {
closeBody(req)
return fmt.Errorf("net/http: invalid header field name %q", k)
}
for _, v := range vv {
if !httpguts.ValidHeaderFieldValue(v) {
closeBody(req)
// Don't include the value in the error, because it may be sensitive.
return fmt.Errorf("net/http: invalid header field value for %q", k)
}
}
}
return nil
}

func (t *Transport) checkAltSvc(req *http.Request) (resp *http.Response, err error) {
if t.altSvcJar == nil {
return
Expand Down Expand Up @@ -938,9 +905,27 @@ func (t *Transport) roundTrip(req *http.Request) (resp *http.Response, err error
scheme := req.URL.Scheme
isHTTP := scheme == "http" || scheme == "https"

err = t.ensureHeaderAndCookie(req, isHTTP)
if err != nil {
return
if isHTTP {
// TODO: is h2c should also check this?
for k, vv := range req.Header {
if !httpguts.ValidHeaderFieldName(k) {
closeBody(req)
err = fmt.Errorf("net/http: invalid header field name %q", k)
return
}
for _, v := range vv {
if !httpguts.ValidHeaderFieldValue(v) {
closeBody(req)
// Don't include the value in the error, because it may be sensitive.
err = fmt.Errorf("net/http: invalid header field value for %q", k)
return
}
}
}
}

if req.Header == nil {
req.Header = make(http.Header)
}

if t.forceHttpVersion != "" {
Expand Down

0 comments on commit abc4962

Please sign in to comment.