Skip to content

Commit

Permalink
Support SetHeaderOrder for HTTP/1.1 Request
Browse files Browse the repository at this point in the history
  • Loading branch information
imroc committed Jul 31, 2023
1 parent 7c46314 commit bc158ce
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 72 deletions.
70 changes: 35 additions & 35 deletions header.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package req

import (
"github.com/imroc/req/v3/internal/header"
"golang.org/x/net/http/httpguts"
"io"
"net/http"
"net/http/httptrace"
"net/textproto"
"sort"
"strings"
Expand All @@ -22,21 +22,16 @@ func (w stringWriter) WriteString(s string) (n int, err error) {
return w.w.Write([]byte(s))
}

type keyValues struct {
key string
values []string
}

// A headerSorter implements sort.Interface by sorting a []keyValues
// by key. It's used as a pointer, so it can fit in a sort.Interface
// interface value without allocation.
type headerSorter struct {
kvs []keyValues
kvs []header.KeyValues
}

func (s *headerSorter) Len() int { return len(s.kvs) }
func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] }
func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key }
func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].Key < s.kvs[j].Key }

var headerSorterPool = sync.Pool{
New: func() interface{} { return new(headerSorter) },
Expand All @@ -60,59 +55,64 @@ func headerHas(h http.Header, key string) bool {
// sortedKeyValues returns h's keys sorted in the returned kvs
// slice. The headerSorter used to sort is also returned, for possible
// return to headerSorterCache.
func headerSortedKeyValues(h http.Header, exclude map[string]bool) (kvs []keyValues, hs *headerSorter) {
func headerSortedKeyValues(h http.Header, exclude map[string]bool) (kvs []header.KeyValues, hs *headerSorter) {
hs = headerSorterPool.Get().(*headerSorter)
if cap(hs.kvs) < len(h) {
hs.kvs = make([]keyValues, 0, len(h))
hs.kvs = make([]header.KeyValues, 0, len(h))
}
kvs = hs.kvs[:0]
for k, vv := range h {
if !exclude[k] {
kvs = append(kvs, keyValues{k, vv})
kvs = append(kvs, header.KeyValues{k, vv})
}
}
hs.kvs = kvs
sort.Sort(hs)
return kvs, hs
}

func headerWrite(h http.Header, w io.Writer, trace *httptrace.ClientTrace) error {
return headerWriteSubset(h, w, nil, trace)
func headerWrite(h http.Header, writeHeader func(key string, values ...string) error, sort bool) error {
return headerWriteSubset(h, nil, writeHeader, sort)
}

func headerWriteSubset(h http.Header, w io.Writer, exclude map[string]bool, trace *httptrace.ClientTrace) error {
ws, ok := w.(io.StringWriter)
if !ok {
ws = stringWriter{w}
func headerWriteSubset(h http.Header, exclude map[string]bool, writeHeader func(key string, values ...string) error, sort bool) error {
var kvs []header.KeyValues
var hs *headerSorter
if sort {
kvs = make([]header.KeyValues, 0, len(h))
for k, v := range h {
if !exclude[k] {
kvs = append(kvs, header.KeyValues{k, v})
}
}
} else {
kvs, hs = headerSortedKeyValues(h, exclude)
}
kvs, sorter := headerSortedKeyValues(h, exclude)
var formattedVals []string
for _, kv := range kvs {
if !httpguts.ValidHeaderFieldName(kv.key) {
if !httpguts.ValidHeaderFieldName(kv.Key) {
// This could be an error. In the common case of
// writing response headers, however, we have no good
// way to provide the error back to the server
// handler, so just drop invalid headers instead.
continue
}
for _, v := range kv.values {
v = headerNewlineToSpace.Replace(v)
v = textproto.TrimString(v)
for _, s := range []string{kv.key, ": ", v, "\r\n"} {
if _, err := ws.WriteString(s); err != nil {
headerSorterPool.Put(sorter)
return err
}
}
if trace != nil && trace.WroteHeaderField != nil {
formattedVals = append(formattedVals, v)
for i, v := range kv.Values {
vv := headerNewlineToSpace.Replace(v)
vv = textproto.TrimString(v)
if vv != v {
kv.Values[i] = vv
}
}
if trace != nil && trace.WroteHeaderField != nil {
trace.WroteHeaderField(kv.key, formattedVals)
formattedVals = nil
err := writeHeader(kv.Key, kv.Values...)
if err != nil {
if hs != nil {
headerSorterPool.Put(hs)
}
return err
}
}
headerSorterPool.Put(sorter)
if hs != nil {
headerSorterPool.Put(hs)
}
return nil
}
12 changes: 7 additions & 5 deletions http_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package req
import (
"errors"
"github.com/imroc/req/v3/internal/ascii"
"github.com/imroc/req/v3/internal/header"
"golang.org/x/net/http/httpguts"
"net/http"
"strings"
Expand Down Expand Up @@ -87,11 +88,12 @@ func closeRequestBody(r *http.Request) error {

// Headers that Request.Write handles itself and should be skipped.
var reqWriteExcludeHeader = map[string]bool{
"Host": true, // not in Header map anyway
"User-Agent": true,
"Content-Length": true,
"Transfer-Encoding": true,
"Trailer": true,
"Host": true, // not in Header map anyway
"User-Agent": true,
"Content-Length": true,
"Transfer-Encoding": true,
"Trailer": true,
header.HeaderOderKey: true,
}

// requestMethodUsuallyLacksBody reports whether the given request
Expand Down
2 changes: 2 additions & 0 deletions internal/header/header.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,6 @@ const (
FormContentType = "application/x-www-form-urlencoded"
WwwAuthenticate = "WWW-Authenticate"
Authorization = "Authorization"
HeaderOderKey = "__Header_Order__"
PseudoHeaderOderKey = "__Pseudo_Header_Order__"
)
37 changes: 37 additions & 0 deletions internal/header/sort.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package header

import "sort"

type KeyValues struct {
Key string
Values []string
}

type sorter struct {
order map[string]int
kvs []KeyValues
}

func (s *sorter) Len() int { return len(s.kvs) }
func (s *sorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] }
func (s *sorter) Less(i, j int) bool {
if index, ok := s.order[s.kvs[i].Key]; ok {
i = index
}
if index, ok := s.order[s.kvs[j].Key]; ok {
j = index
}
return i < j
}

func SortKeyValues(kvs []KeyValues, orderedKeys []string) {
order := make(map[string]int)
for i, key := range orderedKeys {
order[key] = i
}
s := &sorter{
order: order,
kvs: kvs,
}
sort.Sort(s)
}
21 changes: 21 additions & 0 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,27 @@ func (r *Request) SetHeaderNonCanonical(key, value string) *Request {
return r
}

const (
HeaderOderKey = "__Header_Order__"
PseudoHeaderOderKey = "__Pseudo_Header_Order__"
)

func (r *Request) SetHeaderOrder(keys ...string) *Request {
if r.Headers == nil {
r.Headers = make(http.Header)
}
r.Headers[HeaderOderKey] = append(r.Headers[HeaderOderKey], keys...)
return r
}

func (r *Request) SetPseudoHeaderOrder(keys ...string) *Request {
if r.Headers == nil {
r.Headers = make(http.Header)
}
r.Headers[PseudoHeaderOderKey] = append(r.Headers[PseudoHeaderOderKey], keys...)
return r
}

// SetOutputFile set the file that response Body will be downloaded to.
func (r *Request) SetOutputFile(file string) *Request {
r.isSaveResponse = true
Expand Down
30 changes: 9 additions & 21 deletions transfer.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"github.com/imroc/req/v3/internal/dump"
"io"
"net/http"
"net/http/httptrace"
"net/textproto"
"reflect"
"sort"
Expand Down Expand Up @@ -245,36 +244,27 @@ func (t *transferWriter) shouldSendContentLength() bool {
return false
}

func (t *transferWriter) writeHeader(w io.Writer, trace *httptrace.ClientTrace) error {
func (t *transferWriter) writeHeader(writeHeader func(key string, values ...string) error) error {
if t.Close && !hasToken(headerGet(t.Header, "Connection"), "close") {
if _, err := io.WriteString(w, "Connection: close\r\n"); err != nil {
err := writeHeader("Connection", "close")
if err != nil {
return err
}
if trace != nil && trace.WroteHeaderField != nil {
trace.WroteHeaderField("Connection", []string{"close"})
}
}

// Write Content-Length and/or Transfer-Encoding whose values are a
// function of the sanitized field triple (Body, ContentLength,
// TransferEncoding)
if t.shouldSendContentLength() {
if _, err := io.WriteString(w, "Content-Length: "); err != nil {
return err
}
if _, err := io.WriteString(w, strconv.FormatInt(t.ContentLength, 10)+"\r\n"); err != nil {
err := writeHeader("Content-Length", strconv.FormatInt(t.ContentLength, 10))
if err != nil {
return err
}
if trace != nil && trace.WroteHeaderField != nil {
trace.WroteHeaderField("Content-Length", []string{strconv.FormatInt(t.ContentLength, 10)})
}
} else if chunked(t.TransferEncoding) {
if _, err := io.WriteString(w, "Transfer-Encoding: chunked\r\n"); err != nil {
err := writeHeader("Transfer-Encoding", "chunked")
if err != nil {
return err
}
if trace != nil && trace.WroteHeaderField != nil {
trace.WroteHeaderField("Transfer-Encoding", []string{"chunked"})
}
}

// Write Trailer header
Expand All @@ -292,12 +282,10 @@ func (t *transferWriter) writeHeader(w io.Writer, trace *httptrace.ClientTrace)
sort.Strings(keys)
// TODO: could do better allocation-wise here, but trailers are rare,
// so being lazy for now.
if _, err := io.WriteString(w, "Trailer: "+strings.Join(keys, ",")+"\r\n"); err != nil {
err := writeHeader("Trailer", strings.Join(keys, ","))
if err != nil {
return err
}
if trace != nil && trace.WroteHeaderField != nil {
trace.WroteHeaderField("Trailer", keys)
}
}
}

Expand Down
52 changes: 41 additions & 11 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -2966,14 +2966,40 @@ func (pc *persistConn) writeRequest(r *http.Request, w io.Writer, usingProxy boo
return err
}

_writeHeader := func(key string, values ...string) error {
for _, value := range values {
_, err := fmt.Fprintf(w, "%s: %s\r\n", key, value)
if err != nil {
return err
}
}
if trace != nil && trace.WroteHeaderField != nil {
trace.WroteHeaderField(key, values)
}
return nil
}

var writeHeader func(key string, values ...string) error
var kvs []header.KeyValues
sort := false

if r.Header != nil && len(r.Header[header.HeaderOderKey]) > 0 {
writeHeader = func(key string, values ...string) error {
kvs = append(kvs, header.KeyValues{
Key: key,
Values: values,
})
return nil
}
sort = true
} else {
writeHeader = _writeHeader
}
// Header lines
_, err = fmt.Fprintf(w, "Host: %s\r\n", host)
err = writeHeader("Host", host)
if err != nil {
return err
}
if trace != nil && trace.WroteHeaderField != nil {
trace.WroteHeaderField("Host", []string{host})
}

// Use the defaultUserAgent unless the Header contains one, which
// may be blank to not send the header.
Expand All @@ -2982,37 +3008,41 @@ func (pc *persistConn) writeRequest(r *http.Request, w io.Writer, usingProxy boo
userAgent = r.Header.Get("User-Agent")
}
if userAgent != "" {
_, err = fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent)
err = writeHeader("User-Agent", userAgent)
if err != nil {
return err
}
if trace != nil && trace.WroteHeaderField != nil {
trace.WroteHeaderField("User-Agent", []string{userAgent})
}
}

// Process Body,ContentLength,Close,Trailer
tw, err := newTransferWriter(r)
if err != nil {
return err
}
err = tw.writeHeader(w, trace)
err = tw.writeHeader(writeHeader)
if err != nil {
return err
}

err = headerWriteSubset(r.Header, w, reqWriteExcludeHeader, trace)
err = headerWriteSubset(r.Header, reqWriteExcludeHeader, writeHeader, sort)
if err != nil {
return err
}

if extraHeaders != nil {
err = headerWrite(extraHeaders, w, trace)
err = headerWrite(extraHeaders, writeHeader, sort)
if err != nil {
return err
}
}

if sort { // sort and write headers
header.SortKeyValues(kvs, r.Header[header.HeaderOderKey])
for _, kv := range kvs {
_writeHeader(kv.Key, kv.Values...)
}
}

_, err = io.WriteString(w, "\r\n")
if err != nil {
return err
Expand Down

0 comments on commit bc158ce

Please sign in to comment.