From 22e6fc978118b8a830efac5dde802aa766deb6b4 Mon Sep 17 00:00:00 2001 From: roc Date: Mon, 31 Jul 2023 16:20:46 +0800 Subject: [PATCH] Support SetHeaderOrder and SetPseudoHeaderOrder for HTTP/2 Request --- internal/header/header.go | 4 +- internal/header/sort.go | 11 ++-- internal/http2/transport.go | 117 ++++++++++++++++++++++++++++-------- request.go | 4 +- 4 files changed, 102 insertions(+), 34 deletions(-) diff --git a/internal/header/header.go b/internal/header/header.go index 0aead102..33e779bc 100644 --- a/internal/header/header.go +++ b/internal/header/header.go @@ -11,6 +11,6 @@ const ( FormContentType = "application/x-www-form-urlencoded" WwwAuthenticate = "WWW-Authenticate" Authorization = "Authorization" - HeaderOderKey = "__Header_Order__" - PseudoHeaderOderKey = "__Pseudo_Header_Order__" + HeaderOderKey = "__header_order__" + PseudoHeaderOderKey = "__pseudo_header_order__" ) diff --git a/internal/header/sort.go b/internal/header/sort.go index 8c768c36..2c61fd2e 100644 --- a/internal/header/sort.go +++ b/internal/header/sort.go @@ -1,6 +1,9 @@ package header -import "sort" +import ( + "net/textproto" + "sort" +) type KeyValues struct { Key string @@ -15,10 +18,10 @@ type sorter struct { func (s *sorter) Len() int { return len(s.kvs) } func (s *sorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] } func (s *sorter) Less(i, j int) bool { - if index, ok := s.order[s.kvs[i].Key]; ok { + if index, ok := s.order[textproto.CanonicalMIMEHeaderKey(s.kvs[i].Key)]; ok { i = index } - if index, ok := s.order[s.kvs[j].Key]; ok { + if index, ok := s.order[textproto.CanonicalMIMEHeaderKey(s.kvs[j].Key)]; ok { j = index } return i < j @@ -27,7 +30,7 @@ func (s *sorter) Less(i, j int) bool { func SortKeyValues(kvs []KeyValues, orderedKeys []string) { order := make(map[string]int) for i, key := range orderedKeys { - order[key] = i + order[textproto.CanonicalMIMEHeaderKey(key)] = i } s := &sorter{ order: order, diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 21bf6099..1208cbeb 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -1748,6 +1748,25 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) } } +var reqWriteExcludeHeader = map[string]bool{ + // Host is :authority, already sent. + // Content-Length is automatic. + "host": true, + "content-length": true, + // Per 8.1.2.2 Connection-Specific Header + // Fields, don't send connection-specific + // fields. We have already checked if any + // are error-worthy so just ignore the rest. + "connection": true, + "proxy-connection": true, + "transfer-encoding": true, + "upgrade": true, + "keep-alive": true, + // Ignore header order keys which is only used internally. + header.HeaderOderKey: true, + header.PseudoHeaderOderKey: true, +} + var errNilRequestURL = errors.New("http2: Request.URI is nil") // requires cc.wmu be held. @@ -1797,40 +1816,75 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail } enumerateHeaders := func(f func(name, value string)) { + var writeHeader func(name string, value ...string) + var kvs []header.KeyValues + sort := false + + if req.Header != nil && len(req.Header[header.PseudoHeaderOderKey]) > 0 { + writeHeader = func(name string, value ...string) { + kvs = append(kvs, header.KeyValues{ + Key: name, + Values: value, + }) + } + sort = true + } else { + writeHeader = func(name string, value ...string) { + for _, v := range value { + f(name, v) + } + } + } + // 8.1.2.3 Request Pseudo-Header Fields // The :path pseudo-header field includes the path and query parts of the // target URI (the path-absolute production and optionally a '?' character // followed by the query production, see Sections 3.3 and 3.4 of // [RFC3986]). - f(":authority", host) + writeHeader(":authority", host) m := req.Method if m == "" { m = http.MethodGet } - f(":method", m) + writeHeader(":method", m) if req.Method != "CONNECT" { - f(":path", path) - f(":scheme", req.URL.Scheme) + writeHeader(":path", path) + writeHeader(":scheme", req.URL.Scheme) + } + if sort { + header.SortKeyValues(kvs, req.Header[header.PseudoHeaderOderKey]) + for _, kv := range kvs { + for _, v := range kv.Values { + f(kv.Key, v) + } + } } + + if req.Header != nil && len(req.Header[header.HeaderOderKey]) > 0 { + sort = true + kvs = nil + writeHeader = func(name string, value ...string) { + kvs = append(kvs, header.KeyValues{ + Key: name, + Values: value, + }) + } + } else { + sort = false + writeHeader = func(name string, value ...string) { + for _, v := range value { + f(name, v) + } + } + } + if trailers != "" { - f("trailer", trailers) + writeHeader("trailer", trailers) } var didUA bool for k, vv := range req.Header { - if ascii.EqualFold(k, "host") || ascii.EqualFold(k, "content-length") { - // Host is :authority, already sent. - // Content-Length is automatic, set below. - continue - } else if ascii.EqualFold(k, "connection") || - ascii.EqualFold(k, "proxy-connection") || - ascii.EqualFold(k, "transfer-encoding") || - ascii.EqualFold(k, "upgrade") || - ascii.EqualFold(k, "keep-alive") { - // Per 8.1.2.2 Connection-Specific Header - // Fields, don't send connection-specific - // fields. We have already checked if any - // are error-worthy so just ignore the rest. + if reqWriteExcludeHeader[strings.ToLower(k)] { continue } else if ascii.EqualFold(k, "user-agent") { // Match Go's http1 behavior: at most one @@ -1846,6 +1900,7 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail continue } } else if ascii.EqualFold(k, "cookie") { + var vals []string // Per 8.1.2.5 To allow for better compression efficiency, the // Cookie header field MAY be split into separate header fields, // each with one or more cookie-pairs. @@ -1855,7 +1910,8 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail if p < 0 { break } - f("cookie", v[:p]) + vals = append(vals, v[:p]) + //writeHeader("cookie", v[:p]) p++ // strip space after semicolon if any. for p+1 <= len(v) && v[p] == ' ' { @@ -1864,24 +1920,33 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail v = v[p:] } if len(v) > 0 { - f("cookie", v) + vals = append(vals, v) + //writeHeader("cookie", v) } } + writeHeader("cookie", vals...) continue } - for _, v := range vv { - f(k, v) - } + writeHeader(k, vv...) } if shouldSendReqContentLength(req.Method, contentLength) { - f("content-length", strconv.FormatInt(contentLength, 10)) + writeHeader("content-length", strconv.FormatInt(contentLength, 10)) } if addGzipHeader { - f("accept-encoding", "gzip") + writeHeader("accept-encoding", "gzip") } if !didUA { - f("user-agent", header.DefaultUserAgent) + writeHeader("user-agent", header.DefaultUserAgent) + } + + if sort { + header.SortKeyValues(kvs, req.Header[header.HeaderOderKey]) + for _, kv := range kvs { + for _, v := range kv.Values { + f(kv.Key, v) + } + } } } diff --git a/request.go b/request.go index e57d82bc..7e705bb6 100644 --- a/request.go +++ b/request.go @@ -464,8 +464,8 @@ func (r *Request) SetHeaderNonCanonical(key, value string) *Request { } const ( - HeaderOderKey = "__Header_Order__" - PseudoHeaderOderKey = "__Pseudo_Header_Order__" + HeaderOderKey = "__header_order__" + PseudoHeaderOderKey = "__pseudo_header_order__" ) func (r *Request) SetHeaderOrder(keys ...string) *Request {