diff --git a/internal/http3/request_writer.go b/internal/http3/request_writer.go index 84330c6b..85226415 100644 --- a/internal/http3/request_writer.go +++ b/internal/http3/request_writer.go @@ -7,6 +7,7 @@ import ( "io" "net" "net/http" + "net/http/httptrace" "strconv" "strings" "sync" @@ -70,6 +71,25 @@ func (w *requestWriter) writeHeaders(wr io.Writer, req *http.Request, gzip bool, return err } +var reqWriteExcludeHeader = map[string]bool{ + // Host is :authority, already sent. + // Content-Length is automatic. + "host": true, + "content-length": true, + // Per 8.1.2.2 Connection-Specific Header + // Fields, don't send connection-specific + // fields. We have already checked if any + // are error-worthy so just ignore the rest. + "connection": true, + "proxy-connection": true, + "transfer-encoding": true, + "upgrade": true, + "keep-alive": true, + // Ignore header order keys which is only used internally. + header.HeaderOderKey: true, + header.PseudoHeaderOderKey: true, +} + // copied from net/transport.go // Modified to support Extended CONNECT: // Contrary to what the godoc for the http.Request says, @@ -121,37 +141,73 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra } enumerateHeaders := func(f func(name, value string)) { + var writeHeader func(name string, value ...string) + var kvs []header.KeyValues + sort := false + if req.Header != nil && len(req.Header[header.PseudoHeaderOderKey]) > 0 { + writeHeader = func(name string, value ...string) { + kvs = append(kvs, header.KeyValues{ + Key: name, + Values: value, + }) + } + sort = true + } else { + writeHeader = func(name string, value ...string) { + for _, v := range value { + f(name, v) + } + } + } // 8.1.2.3 Request Pseudo-Header Fields // The :path pseudo-header field includes the path and query parts of the // target URI (the path-absolute production and optionally a '?' character // followed by the query production (see Sections 3.3 and 3.4 of // [RFC3986]). - f(":authority", host) - f(":method", req.Method) + writeHeader(":authority", host) + writeHeader(":method", req.Method) if req.Method != http.MethodConnect || isExtendedConnect { - f(":path", path) - f(":scheme", req.URL.Scheme) + writeHeader(":path", path) + writeHeader(":scheme", req.URL.Scheme) } if isExtendedConnect { - f(":protocol", req.Proto) + writeHeader(":protocol", req.Proto) } + + if sort { + header.SortKeyValues(kvs, req.Header[header.PseudoHeaderOderKey]) + for _, kv := range kvs { + for _, v := range kv.Values { + f(kv.Key, v) + } + } + } + + if req.Header != nil && len(req.Header[header.HeaderOderKey]) > 0 { + sort = true + kvs = nil + writeHeader = func(name string, value ...string) { + kvs = append(kvs, header.KeyValues{ + Key: name, + Values: value, + }) + } + } else { + sort = false + writeHeader = func(name string, value ...string) { + for _, v := range value { + f(name, v) + } + } + } + if trailers != "" { - f("trailer", trailers) + writeHeader("trailer", trailers) } var didUA bool for k, vv := range req.Header { - if strings.EqualFold(k, "host") || strings.EqualFold(k, "content-length") { - // Host is :authority, already sent. - // Content-Length is automatic, set below. - continue - } else if strings.EqualFold(k, "connection") || strings.EqualFold(k, "proxy-connection") || - strings.EqualFold(k, "transfer-encoding") || strings.EqualFold(k, "upgrade") || - strings.EqualFold(k, "keep-alive") { - // Per 8.1.2.2 Connection-Specific Header - // Fields, don't send connection-specific - // fields. We have already checked if any - // are error-worthy so just ignore the rest. + if reqWriteExcludeHeader[strings.ToLower(k)] { continue } else if strings.EqualFold(k, "user-agent") { // Match Go's http1 behavior: at most one @@ -170,17 +226,26 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra } for _, v := range vv { - f(k, v) + writeHeader(k, v) } } if shouldSendReqContentLength(req.Method, contentLength) { - f("content-length", strconv.FormatInt(contentLength, 10)) + writeHeader("content-length", strconv.FormatInt(contentLength, 10)) } if addGzipHeader { - f("accept-encoding", "gzip") + writeHeader("accept-encoding", "gzip") } if !didUA { - f("user-agent", header.DefaultUserAgent) + writeHeader("user-agent", header.DefaultUserAgent) + } + + if sort { + header.SortKeyValues(kvs, req.Header[header.HeaderOderKey]) + for _, kv := range kvs { + for _, v := range kv.Values { + f(kv.Key, v) + } + } } } @@ -199,19 +264,19 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra // return errRequestHeaderListSize // } - // trace := httptrace.ContextClientTrace(req.Context()) - // traceHeaders := traceHasWroteHeaderField(trace) + trace := httptrace.ContextClientTrace(req.Context()) + traceHeaders := traceHasWroteHeaderField(trace) // Header list size is ok. Write the headers. - enumerateHeaders(func(name, value string) { + enumerateHeaders(func(name, vlue string) { name = strings.ToLower(name) for _, dump := range dumps { dump.DumpRequestHeader([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) } w.encoder.WriteField(qpack.HeaderField{Name: name, Value: value}) - // if traceHeaders { - // traceWroteHeaderField(trace, name, value) - // } + if traceHeaders { + trace.WroteHeaderField(name, []string{value}) + } }) for _, dump := range dumps { diff --git a/internal/http3/trace.go b/internal/http3/trace.go new file mode 100644 index 00000000..710072de --- /dev/null +++ b/internal/http3/trace.go @@ -0,0 +1,7 @@ +package http3 + +import "net/http/httptrace" + +func traceHasWroteHeaderField(trace *httptrace.ClientTrace) bool { + return trace != nil && trace.WroteHeaderField != nil +}