From bc158ce7f151c3d966cbfb6fbe4cb6d758f05e56 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 31 Jul 2023 11:59:03 +0800 Subject: [PATCH] Support SetHeaderOrder for HTTP/1.1 Request --- header.go | 70 +++++++++++++++++++-------------------- http_request.go | 12 ++++--- internal/header/header.go | 2 ++ internal/header/sort.go | 37 +++++++++++++++++++++ request.go | 21 ++++++++++++ transfer.go | 30 +++++------------ transport.go | 52 +++++++++++++++++++++++------ 7 files changed, 152 insertions(+), 72 deletions(-) create mode 100644 internal/header/sort.go diff --git a/header.go b/header.go index eb991c49..d31da8d7 100644 --- a/header.go +++ b/header.go @@ -1,10 +1,10 @@ package req import ( + "github.com/imroc/req/v3/internal/header" "golang.org/x/net/http/httpguts" "io" "net/http" - "net/http/httptrace" "net/textproto" "sort" "strings" @@ -22,21 +22,16 @@ func (w stringWriter) WriteString(s string) (n int, err error) { return w.w.Write([]byte(s)) } -type keyValues struct { - key string - values []string -} - // A headerSorter implements sort.Interface by sorting a []keyValues // by key. It's used as a pointer, so it can fit in a sort.Interface // interface value without allocation. type headerSorter struct { - kvs []keyValues + kvs []header.KeyValues } func (s *headerSorter) Len() int { return len(s.kvs) } func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] } -func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key } +func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].Key < s.kvs[j].Key } var headerSorterPool = sync.Pool{ New: func() interface{} { return new(headerSorter) }, @@ -60,15 +55,15 @@ func headerHas(h http.Header, key string) bool { // sortedKeyValues returns h's keys sorted in the returned kvs // slice. The headerSorter used to sort is also returned, for possible // return to headerSorterCache. -func headerSortedKeyValues(h http.Header, exclude map[string]bool) (kvs []keyValues, hs *headerSorter) { +func headerSortedKeyValues(h http.Header, exclude map[string]bool) (kvs []header.KeyValues, hs *headerSorter) { hs = headerSorterPool.Get().(*headerSorter) if cap(hs.kvs) < len(h) { - hs.kvs = make([]keyValues, 0, len(h)) + hs.kvs = make([]header.KeyValues, 0, len(h)) } kvs = hs.kvs[:0] for k, vv := range h { if !exclude[k] { - kvs = append(kvs, keyValues{k, vv}) + kvs = append(kvs, header.KeyValues{k, vv}) } } hs.kvs = kvs @@ -76,43 +71,48 @@ func headerSortedKeyValues(h http.Header, exclude map[string]bool) (kvs []keyVal return kvs, hs } -func headerWrite(h http.Header, w io.Writer, trace *httptrace.ClientTrace) error { - return headerWriteSubset(h, w, nil, trace) +func headerWrite(h http.Header, writeHeader func(key string, values ...string) error, sort bool) error { + return headerWriteSubset(h, nil, writeHeader, sort) } -func headerWriteSubset(h http.Header, w io.Writer, exclude map[string]bool, trace *httptrace.ClientTrace) error { - ws, ok := w.(io.StringWriter) - if !ok { - ws = stringWriter{w} +func headerWriteSubset(h http.Header, exclude map[string]bool, writeHeader func(key string, values ...string) error, sort bool) error { + var kvs []header.KeyValues + var hs *headerSorter + if sort { + kvs = make([]header.KeyValues, 0, len(h)) + for k, v := range h { + if !exclude[k] { + kvs = append(kvs, header.KeyValues{k, v}) + } + } + } else { + kvs, hs = headerSortedKeyValues(h, exclude) } - kvs, sorter := headerSortedKeyValues(h, exclude) - var formattedVals []string for _, kv := range kvs { - if !httpguts.ValidHeaderFieldName(kv.key) { + if !httpguts.ValidHeaderFieldName(kv.Key) { // This could be an error. In the common case of // writing response headers, however, we have no good // way to provide the error back to the server // handler, so just drop invalid headers instead. continue } - for _, v := range kv.values { - v = headerNewlineToSpace.Replace(v) - v = textproto.TrimString(v) - for _, s := range []string{kv.key, ": ", v, "\r\n"} { - if _, err := ws.WriteString(s); err != nil { - headerSorterPool.Put(sorter) - return err - } - } - if trace != nil && trace.WroteHeaderField != nil { - formattedVals = append(formattedVals, v) + for i, v := range kv.Values { + vv := headerNewlineToSpace.Replace(v) + vv = textproto.TrimString(v) + if vv != v { + kv.Values[i] = vv } } - if trace != nil && trace.WroteHeaderField != nil { - trace.WroteHeaderField(kv.key, formattedVals) - formattedVals = nil + err := writeHeader(kv.Key, kv.Values...) + if err != nil { + if hs != nil { + headerSorterPool.Put(hs) + } + return err } } - headerSorterPool.Put(sorter) + if hs != nil { + headerSorterPool.Put(hs) + } return nil } diff --git a/http_request.go b/http_request.go index 2d471b4d..c737a6b3 100644 --- a/http_request.go +++ b/http_request.go @@ -3,6 +3,7 @@ package req import ( "errors" "github.com/imroc/req/v3/internal/ascii" + "github.com/imroc/req/v3/internal/header" "golang.org/x/net/http/httpguts" "net/http" "strings" @@ -87,11 +88,12 @@ func closeRequestBody(r *http.Request) error { // Headers that Request.Write handles itself and should be skipped. var reqWriteExcludeHeader = map[string]bool{ - "Host": true, // not in Header map anyway - "User-Agent": true, - "Content-Length": true, - "Transfer-Encoding": true, - "Trailer": true, + "Host": true, // not in Header map anyway + "User-Agent": true, + "Content-Length": true, + "Transfer-Encoding": true, + "Trailer": true, + header.HeaderOderKey: true, } // requestMethodUsuallyLacksBody reports whether the given request diff --git a/internal/header/header.go b/internal/header/header.go index b6064558..0aead102 100644 --- a/internal/header/header.go +++ b/internal/header/header.go @@ -11,4 +11,6 @@ const ( FormContentType = "application/x-www-form-urlencoded" WwwAuthenticate = "WWW-Authenticate" Authorization = "Authorization" + HeaderOderKey = "__Header_Order__" + PseudoHeaderOderKey = "__Pseudo_Header_Order__" ) diff --git a/internal/header/sort.go b/internal/header/sort.go new file mode 100644 index 00000000..8c768c36 --- /dev/null +++ b/internal/header/sort.go @@ -0,0 +1,37 @@ +package header + +import "sort" + +type KeyValues struct { + Key string + Values []string +} + +type sorter struct { + order map[string]int + kvs []KeyValues +} + +func (s *sorter) Len() int { return len(s.kvs) } +func (s *sorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] } +func (s *sorter) Less(i, j int) bool { + if index, ok := s.order[s.kvs[i].Key]; ok { + i = index + } + if index, ok := s.order[s.kvs[j].Key]; ok { + j = index + } + return i < j +} + +func SortKeyValues(kvs []KeyValues, orderedKeys []string) { + order := make(map[string]int) + for i, key := range orderedKeys { + order[key] = i + } + s := &sorter{ + order: order, + kvs: kvs, + } + sort.Sort(s) +} diff --git a/request.go b/request.go index 63df395d..e57d82bc 100644 --- a/request.go +++ b/request.go @@ -463,6 +463,27 @@ func (r *Request) SetHeaderNonCanonical(key, value string) *Request { return r } +const ( + HeaderOderKey = "__Header_Order__" + PseudoHeaderOderKey = "__Pseudo_Header_Order__" +) + +func (r *Request) SetHeaderOrder(keys ...string) *Request { + if r.Headers == nil { + r.Headers = make(http.Header) + } + r.Headers[HeaderOderKey] = append(r.Headers[HeaderOderKey], keys...) + return r +} + +func (r *Request) SetPseudoHeaderOrder(keys ...string) *Request { + if r.Headers == nil { + r.Headers = make(http.Header) + } + r.Headers[PseudoHeaderOderKey] = append(r.Headers[PseudoHeaderOderKey], keys...) + return r +} + // SetOutputFile set the file that response Body will be downloaded to. func (r *Request) SetOutputFile(file string) *Request { r.isSaveResponse = true diff --git a/transfer.go b/transfer.go index 58e64793..c7c623f8 100644 --- a/transfer.go +++ b/transfer.go @@ -14,7 +14,6 @@ import ( "github.com/imroc/req/v3/internal/dump" "io" "net/http" - "net/http/httptrace" "net/textproto" "reflect" "sort" @@ -245,36 +244,27 @@ func (t *transferWriter) shouldSendContentLength() bool { return false } -func (t *transferWriter) writeHeader(w io.Writer, trace *httptrace.ClientTrace) error { +func (t *transferWriter) writeHeader(writeHeader func(key string, values ...string) error) error { if t.Close && !hasToken(headerGet(t.Header, "Connection"), "close") { - if _, err := io.WriteString(w, "Connection: close\r\n"); err != nil { + err := writeHeader("Connection", "close") + if err != nil { return err } - if trace != nil && trace.WroteHeaderField != nil { - trace.WroteHeaderField("Connection", []string{"close"}) - } } // Write Content-Length and/or Transfer-Encoding whose values are a // function of the sanitized field triple (Body, ContentLength, // TransferEncoding) if t.shouldSendContentLength() { - if _, err := io.WriteString(w, "Content-Length: "); err != nil { - return err - } - if _, err := io.WriteString(w, strconv.FormatInt(t.ContentLength, 10)+"\r\n"); err != nil { + err := writeHeader("Content-Length", strconv.FormatInt(t.ContentLength, 10)) + if err != nil { return err } - if trace != nil && trace.WroteHeaderField != nil { - trace.WroteHeaderField("Content-Length", []string{strconv.FormatInt(t.ContentLength, 10)}) - } } else if chunked(t.TransferEncoding) { - if _, err := io.WriteString(w, "Transfer-Encoding: chunked\r\n"); err != nil { + err := writeHeader("Transfer-Encoding", "chunked") + if err != nil { return err } - if trace != nil && trace.WroteHeaderField != nil { - trace.WroteHeaderField("Transfer-Encoding", []string{"chunked"}) - } } // Write Trailer header @@ -292,12 +282,10 @@ func (t *transferWriter) writeHeader(w io.Writer, trace *httptrace.ClientTrace) sort.Strings(keys) // TODO: could do better allocation-wise here, but trailers are rare, // so being lazy for now. - if _, err := io.WriteString(w, "Trailer: "+strings.Join(keys, ",")+"\r\n"); err != nil { + err := writeHeader("Trailer", strings.Join(keys, ",")) + if err != nil { return err } - if trace != nil && trace.WroteHeaderField != nil { - trace.WroteHeaderField("Trailer", keys) - } } } diff --git a/transport.go b/transport.go index e18ca80f..f3b4acac 100644 --- a/transport.go +++ b/transport.go @@ -2966,14 +2966,40 @@ func (pc *persistConn) writeRequest(r *http.Request, w io.Writer, usingProxy boo return err } + _writeHeader := func(key string, values ...string) error { + for _, value := range values { + _, err := fmt.Fprintf(w, "%s: %s\r\n", key, value) + if err != nil { + return err + } + } + if trace != nil && trace.WroteHeaderField != nil { + trace.WroteHeaderField(key, values) + } + return nil + } + + var writeHeader func(key string, values ...string) error + var kvs []header.KeyValues + sort := false + + if r.Header != nil && len(r.Header[header.HeaderOderKey]) > 0 { + writeHeader = func(key string, values ...string) error { + kvs = append(kvs, header.KeyValues{ + Key: key, + Values: values, + }) + return nil + } + sort = true + } else { + writeHeader = _writeHeader + } // Header lines - _, err = fmt.Fprintf(w, "Host: %s\r\n", host) + err = writeHeader("Host", host) if err != nil { return err } - if trace != nil && trace.WroteHeaderField != nil { - trace.WroteHeaderField("Host", []string{host}) - } // Use the defaultUserAgent unless the Header contains one, which // may be blank to not send the header. @@ -2982,13 +3008,10 @@ func (pc *persistConn) writeRequest(r *http.Request, w io.Writer, usingProxy boo userAgent = r.Header.Get("User-Agent") } if userAgent != "" { - _, err = fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent) + err = writeHeader("User-Agent", userAgent) if err != nil { return err } - if trace != nil && trace.WroteHeaderField != nil { - trace.WroteHeaderField("User-Agent", []string{userAgent}) - } } // Process Body,ContentLength,Close,Trailer @@ -2996,23 +3019,30 @@ func (pc *persistConn) writeRequest(r *http.Request, w io.Writer, usingProxy boo if err != nil { return err } - err = tw.writeHeader(w, trace) + err = tw.writeHeader(writeHeader) if err != nil { return err } - err = headerWriteSubset(r.Header, w, reqWriteExcludeHeader, trace) + err = headerWriteSubset(r.Header, reqWriteExcludeHeader, writeHeader, sort) if err != nil { return err } if extraHeaders != nil { - err = headerWrite(extraHeaders, w, trace) + err = headerWrite(extraHeaders, writeHeader, sort) if err != nil { return err } } + if sort { // sort and write headers + header.SortKeyValues(kvs, r.Header[header.HeaderOderKey]) + for _, kv := range kvs { + _writeHeader(kv.Key, kv.Values...) + } + } + _, err = io.WriteString(w, "\r\n") if err != nil { return err