diff --git a/src/net/http/h2_bundle.go b/src/net/http/h2_bundle.go index 1a97b01db8d94..77ab0343f4222 100644 --- a/src/net/http/h2_bundle.go +++ b/src/net/http/h2_bundle.go @@ -954,75 +954,6 @@ func (p http2noDialClientConnPool) GetClientConn(req *Request, addr string) (*ht return p.getClientConn(req, addr, http2noDialOnMiss) } -func http2configureTransport(t1 *Transport) (*http2Transport, error) { - connPool := new(http2clientConnPool) - t2 := &http2Transport{ - ConnPool: http2noDialClientConnPool{connPool}, - t1: t1, - } - connPool.t = t2 - if err := http2registerHTTPSProtocol(t1, http2noDialH2RoundTripper{t2}); err != nil { - return nil, err - } - if t1.TLSClientConfig == nil { - t1.TLSClientConfig = new(tls.Config) - } - if !http2strSliceContains(t1.TLSClientConfig.NextProtos, "h2") { - t1.TLSClientConfig.NextProtos = append([]string{"h2"}, t1.TLSClientConfig.NextProtos...) - } - if !http2strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") { - t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1") - } - upgradeFn := func(authority string, c *tls.Conn) RoundTripper { - addr := http2authorityAddr("https", authority) - if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil { - go c.Close() - return http2erringRoundTripper{err} - } else if !used { - // Turns out we don't need this c. - // For example, two goroutines made requests to the same host - // at the same time, both kicking off TCP dials. (since protocol - // was unknown) - go c.Close() - } - return t2 - } - if m := t1.TLSNextProto; len(m) == 0 { - t1.TLSNextProto = map[string]func(string, *tls.Conn) RoundTripper{ - "h2": upgradeFn, - } - } else { - m["h2"] = upgradeFn - } - return t2, nil -} - -// registerHTTPSProtocol calls Transport.RegisterProtocol but -// converting panics into errors. -func http2registerHTTPSProtocol(t *Transport, rt http2noDialH2RoundTripper) (err error) { - defer func() { - if e := recover(); e != nil { - err = fmt.Errorf("%v", e) - } - }() - t.RegisterProtocol("https", rt) - return nil -} - -// noDialH2RoundTripper is a RoundTripper which only tries to complete the request -// if there's already has a cached connection to the host. -// (The field is exported so it can be accessed via reflect from net/http; tested -// by TestNoDialH2RoundTripperType) -type http2noDialH2RoundTripper struct{ *http2Transport } - -func (rt http2noDialH2RoundTripper) RoundTrip(req *Request) (*Response, error) { - res, err := rt.http2Transport.RoundTrip(req) - if http2isNoCachedConnError(err) { - return nil, ErrSkipAltProtocol - } - return res, err -} - // Buffer chunks are allocated from a pool to reduce pressure on GC. // The maximum wasted space per dataBuffer is 2x the largest size class, // which happens when the dataBuffer has multiple chunks and there is @@ -2788,7 +2719,7 @@ func (fr *http2Framer) maxHeaderStringLen() int { } // readMetaFrame returns 0 or more CONTINUATION frames from fr and -// merge them into into the provided hf and returns a MetaHeadersFrame +// merge them into the provided hf and returns a MetaHeadersFrame // with the decoded hpack values. func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame) (*http2MetaHeadersFrame, error) { if fr.AllowIllegalReads { @@ -2924,181 +2855,23 @@ func http2summarizeFrame(f http2Frame) string { return buf.String() } -func http2traceHasWroteHeaderField(trace *http2clientTrace) bool { +func http2traceHasWroteHeaderField(trace *httptrace.ClientTrace) bool { return trace != nil && trace.WroteHeaderField != nil } -func http2traceWroteHeaderField(trace *http2clientTrace, k, v string) { +func http2traceWroteHeaderField(trace *httptrace.ClientTrace, k, v string) { if trace != nil && trace.WroteHeaderField != nil { trace.WroteHeaderField(k, []string{v}) } } -func http2traceGot1xxResponseFunc(trace *http2clientTrace) func(int, textproto.MIMEHeader) error { +func http2traceGot1xxResponseFunc(trace *httptrace.ClientTrace) func(int, textproto.MIMEHeader) error { if trace != nil { return trace.Got1xxResponse } return nil } -func http2transportExpectContinueTimeout(t1 *Transport) time.Duration { - return t1.ExpectContinueTimeout -} - -type http2contextContext interface { - context.Context -} - -var http2errCanceled = context.Canceled - -func http2serverConnBaseContext(c net.Conn, opts *http2ServeConnOpts) (ctx http2contextContext, cancel func()) { - ctx, cancel = context.WithCancel(context.Background()) - ctx = context.WithValue(ctx, LocalAddrContextKey, c.LocalAddr()) - if hs := opts.baseConfig(); hs != nil { - ctx = context.WithValue(ctx, ServerContextKey, hs) - } - return -} - -func http2contextWithCancel(ctx http2contextContext) (_ http2contextContext, cancel func()) { - return context.WithCancel(ctx) -} - -func http2requestWithContext(req *Request, ctx http2contextContext) *Request { - return req.WithContext(ctx) -} - -type http2clientTrace httptrace.ClientTrace - -func http2reqContext(r *Request) context.Context { return r.Context() } - -func (t *http2Transport) idleConnTimeout() time.Duration { - if t.t1 != nil { - return t.t1.IdleConnTimeout - } - return 0 -} - -func http2setResponseUncompressed(res *Response) { res.Uncompressed = true } - -func http2traceGetConn(req *Request, hostPort string) { - trace := httptrace.ContextClientTrace(req.Context()) - if trace == nil || trace.GetConn == nil { - return - } - trace.GetConn(hostPort) -} - -func http2traceGotConn(req *Request, cc *http2ClientConn) { - trace := httptrace.ContextClientTrace(req.Context()) - if trace == nil || trace.GotConn == nil { - return - } - ci := httptrace.GotConnInfo{Conn: cc.tconn} - cc.mu.Lock() - ci.Reused = cc.nextStreamID > 1 - ci.WasIdle = len(cc.streams) == 0 && ci.Reused - if ci.WasIdle && !cc.lastActive.IsZero() { - ci.IdleTime = time.Now().Sub(cc.lastActive) - } - cc.mu.Unlock() - - trace.GotConn(ci) -} - -func http2traceWroteHeaders(trace *http2clientTrace) { - if trace != nil && trace.WroteHeaders != nil { - trace.WroteHeaders() - } -} - -func http2traceGot100Continue(trace *http2clientTrace) { - if trace != nil && trace.Got100Continue != nil { - trace.Got100Continue() - } -} - -func http2traceWait100Continue(trace *http2clientTrace) { - if trace != nil && trace.Wait100Continue != nil { - trace.Wait100Continue() - } -} - -func http2traceWroteRequest(trace *http2clientTrace, err error) { - if trace != nil && trace.WroteRequest != nil { - trace.WroteRequest(httptrace.WroteRequestInfo{Err: err}) - } -} - -func http2traceFirstResponseByte(trace *http2clientTrace) { - if trace != nil && trace.GotFirstResponseByte != nil { - trace.GotFirstResponseByte() - } -} - -func http2requestTrace(req *Request) *http2clientTrace { - trace := httptrace.ContextClientTrace(req.Context()) - return (*http2clientTrace)(trace) -} - -// Ping sends a PING frame to the server and waits for the ack. -func (cc *http2ClientConn) Ping(ctx context.Context) error { - return cc.ping(ctx) -} - -// Shutdown gracefully closes the client connection, waiting for running streams to complete. -func (cc *http2ClientConn) Shutdown(ctx context.Context) error { - return cc.shutdown(ctx) -} - -func http2cloneTLSConfig(c *tls.Config) *tls.Config { - c2 := c.Clone() - c2.GetClientCertificate = c.GetClientCertificate // golang.org/issue/19264 - return c2 -} - -var _ Pusher = (*http2responseWriter)(nil) - -// Push implements http.Pusher. -func (w *http2responseWriter) Push(target string, opts *PushOptions) error { - internalOpts := http2pushOptions{} - if opts != nil { - internalOpts.Method = opts.Method - internalOpts.Header = opts.Header - } - return w.push(target, internalOpts) -} - -func http2configureServer18(h1 *Server, h2 *http2Server) error { - if h2.IdleTimeout == 0 { - if h1.IdleTimeout != 0 { - h2.IdleTimeout = h1.IdleTimeout - } else { - h2.IdleTimeout = h1.ReadTimeout - } - } - return nil -} - -func http2shouldLogPanic(panicValue interface{}) bool { - return panicValue != nil && panicValue != ErrAbortHandler -} - -func http2reqGetBody(req *Request) func() (io.ReadCloser, error) { - return req.GetBody -} - -func http2reqBodyIsNoBody(body io.ReadCloser) bool { - return body == NoBody -} - -func http2go18httpNoBody() io.ReadCloser { return NoBody } // for tests only - -func http2configureServer19(s *Server, conf *http2Server) error { - s.RegisterOnShutdown(conf.state.startGracefulShutdown) - return nil -} - var http2DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1" type http2goroutineLock uint64 @@ -3252,12 +3025,17 @@ func http2cutoff64(base int) uint64 { } var ( - http2commonLowerHeader = map[string]string{} // Go-Canonical-Case -> lower-case - http2commonCanonHeader = map[string]string{} // lower-case -> Go-Canonical-Case + http2commonBuildOnce sync.Once + http2commonLowerHeader map[string]string // Go-Canonical-Case -> lower-case + http2commonCanonHeader map[string]string // lower-case -> Go-Canonical-Case ) -func init() { - for _, v := range []string{ +func http2buildCommonHeaderMapsOnce() { + http2commonBuildOnce.Do(http2buildCommonHeaderMaps) +} + +func http2buildCommonHeaderMaps() { + common := []string{ "accept", "accept-charset", "accept-encoding", @@ -3305,7 +3083,10 @@ func init() { "vary", "via", "www-authenticate", - } { + } + http2commonLowerHeader = make(map[string]string, len(common)) + http2commonCanonHeader = make(map[string]string, len(common)) + for _, v := range common { chk := CanonicalHeaderKey(v) http2commonLowerHeader[chk] = v http2commonCanonHeader[v] = chk @@ -3313,6 +3094,7 @@ func init() { } func http2lowerHeader(v string) string { + http2buildCommonHeaderMapsOnce() if s, ok := http2commonLowerHeader[v]; ok { return s } @@ -3488,19 +3270,12 @@ func http2validWireHeaderFieldName(v string) bool { return true } -var http2httpCodeStringCommon = map[int]string{} // n -> strconv.Itoa(n) - -func init() { - for i := 100; i <= 999; i++ { - if v := StatusText(i); v != "" { - http2httpCodeStringCommon[i] = strconv.Itoa(i) - } - } -} - func http2httpCodeString(code int) string { - if s, ok := http2httpCodeStringCommon[code]; ok { - return s + switch code { + case 200: + return "200" + case 404: + return "404" } return strconv.Itoa(code) } @@ -3993,12 +3768,14 @@ func http2ConfigureServer(s *Server, conf *http2Server) error { conf = new(http2Server) } conf.state = &http2serverInternalState{activeConns: make(map[*http2serverConn]struct{})} - if err := http2configureServer18(s, conf); err != nil { - return err - } - if err := http2configureServer19(s, conf); err != nil { - return err + if h1, h2 := s, conf; h2.IdleTimeout == 0 { + if h1.IdleTimeout != 0 { + h2.IdleTimeout = h1.IdleTimeout + } else { + h2.IdleTimeout = h1.ReadTimeout + } } + s.RegisterOnShutdown(conf.state.startGracefulShutdown) if s.TLSConfig == nil { s.TLSConfig = new(tls.Config) @@ -4219,6 +3996,15 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { sc.serve() } +func http2serverConnBaseContext(c net.Conn, opts *http2ServeConnOpts) (ctx context.Context, cancel func()) { + ctx, cancel = context.WithCancel(context.Background()) + ctx = context.WithValue(ctx, LocalAddrContextKey, c.LocalAddr()) + if hs := opts.baseConfig(); hs != nil { + ctx = context.WithValue(ctx, ServerContextKey, hs) + } + return +} + func (sc *http2serverConn) rejectConn(err http2ErrCode, debug string) { sc.vlogf("http2: server rejecting conn: %v, %s", err, debug) // ignoring errors. hanging up anyway. @@ -4234,7 +4020,7 @@ type http2serverConn struct { conn net.Conn bw *http2bufferedWriter // writing to conn handler Handler - baseCtx http2contextContext + baseCtx context.Context framer *http2Framer doneServing chan struct{} // closed when serverConn.serve ends readFrameCh chan http2readFrameResult // written by serverConn.readFrames @@ -4314,7 +4100,7 @@ type http2stream struct { id uint32 body *http2pipe // non-nil if expecting DATA frames cw http2closeWaiter // closed wait stream transitions to closed state - ctx http2contextContext + ctx context.Context cancelCtx func() // owned by serverConn's serve loop: @@ -4450,6 +4236,7 @@ func (sc *http2serverConn) condlogf(err error, format string, args ...interface{ func (sc *http2serverConn) canonicalHeader(v string) string { sc.serveG.check() + http2buildCommonHeaderMapsOnce() cv, ok := http2commonCanonHeader[v] if ok { return cv @@ -4898,7 +4685,7 @@ func (sc *http2serverConn) startFrameWrite(wr http2FrameWriteRequest) { // errHandlerPanicked is the error given to any callers blocked in a read from // Request.Body when the main goroutine panics. Since most handlers read in the -// the main ServeHTTP goroutine, this will show up rarely. +// main ServeHTTP goroutine, this will show up rarely. var http2errHandlerPanicked = errors.New("http2: handler panicked") // wroteFrame is called on the serve goroutine with the result of @@ -5670,7 +5457,7 @@ func (sc *http2serverConn) newStream(id, pusherID uint32, state http2streamState panic("internal error: cannot create stream with id 0") } - ctx, cancelCtx := http2contextWithCancel(sc.baseCtx) + ctx, cancelCtx := context.WithCancel(sc.baseCtx) st := &http2stream{ sc: sc, id: id, @@ -5836,7 +5623,7 @@ func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2re Body: body, Trailer: trailer, } - req = http2requestWithContext(req, st.ctx) + req = req.WithContext(st.ctx) rws := http2responseWriterStatePool.Get().(*http2responseWriterState) bwSave := rws.bw @@ -5864,7 +5651,7 @@ func (sc *http2serverConn) runHandler(rw *http2responseWriter, req *Request, han stream: rw.rws.stream, }) // Same as net/http: - if http2shouldLogPanic(e) { + if e != nil && e != ErrAbortHandler { const size = 64 << 10 buf := make([]byte, size) buf = buf[:runtime.Stack(buf, false)] @@ -6426,14 +6213,9 @@ var ( http2ErrPushLimitReached = errors.New("http2: push would exceed peer's SETTINGS_MAX_CONCURRENT_STREAMS") ) -// pushOptions is the internal version of http.PushOptions, which we -// cannot include here because it's only defined in Go 1.8 and later. -type http2pushOptions struct { - Method string - Header Header -} +var _ Pusher = (*http2responseWriter)(nil) -func (w *http2responseWriter) push(target string, opts http2pushOptions) error { +func (w *http2responseWriter) Push(target string, opts *PushOptions) error { st := w.rws.stream sc := st.sc sc.serveG.checkNotOn() @@ -6444,6 +6226,10 @@ func (w *http2responseWriter) push(target string, opts http2pushOptions) error { return http2ErrRecursivePush } + if opts == nil { + opts = new(PushOptions) + } + // Default options. if opts.Method == "" { opts.Method = "GET" @@ -6739,6 +6525,16 @@ type http2Transport struct { // to mean no limit. MaxHeaderListSize uint32 + // StrictMaxConcurrentStreams controls whether the server's + // SETTINGS_MAX_CONCURRENT_STREAMS should be respected + // globally. If false, new TCP connections are created to the + // server as needed to keep each under the per-connection + // SETTINGS_MAX_CONCURRENT_STREAMS limit. If true, the + // server's SETTINGS_MAX_CONCURRENT_STREAMS is interpreted as + // a global limit and callers of RoundTrip block when needed, + // waiting for their turn. + StrictMaxConcurrentStreams bool + // t1, if non-nil, is the standard library Transport using // this transport. Its settings are used (but not its // RoundTrip method, etc). @@ -6762,16 +6558,56 @@ func (t *http2Transport) disableCompression() bool { return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression) } -var http2errTransportVersion = errors.New("http2: ConfigureTransport is only supported starting at Go 1.6") - // ConfigureTransport configures a net/http HTTP/1 Transport to use HTTP/2. -// It requires Go 1.6 or later and returns an error if the net/http package is too old -// or if t1 has already been HTTP/2-enabled. +// It returns an error if t1 has already been HTTP/2-enabled. func http2ConfigureTransport(t1 *Transport) error { - _, err := http2configureTransport(t1) // in configure_transport.go (go1.6) or not_go16.go + _, err := http2configureTransport(t1) return err } +func http2configureTransport(t1 *Transport) (*http2Transport, error) { + connPool := new(http2clientConnPool) + t2 := &http2Transport{ + ConnPool: http2noDialClientConnPool{connPool}, + t1: t1, + } + connPool.t = t2 + if err := http2registerHTTPSProtocol(t1, http2noDialH2RoundTripper{t2}); err != nil { + return nil, err + } + if t1.TLSClientConfig == nil { + t1.TLSClientConfig = new(tls.Config) + } + if !http2strSliceContains(t1.TLSClientConfig.NextProtos, "h2") { + t1.TLSClientConfig.NextProtos = append([]string{"h2"}, t1.TLSClientConfig.NextProtos...) + } + if !http2strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") { + t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1") + } + upgradeFn := func(authority string, c *tls.Conn) RoundTripper { + addr := http2authorityAddr("https", authority) + if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil { + go c.Close() + return http2erringRoundTripper{err} + } else if !used { + // Turns out we don't need this c. + // For example, two goroutines made requests to the same host + // at the same time, both kicking off TCP dials. (since protocol + // was unknown) + go c.Close() + } + return t2 + } + if m := t1.TLSNextProto; len(m) == 0 { + t1.TLSNextProto = map[string]func(string, *tls.Conn) RoundTripper{ + "h2": upgradeFn, + } + } else { + m["h2"] = upgradeFn + } + return t2, nil +} + func (t *http2Transport) connPool() http2ClientConnPool { t.connPoolOnce.Do(t.initConnPool) return t.connPoolOrDef @@ -6836,7 +6672,7 @@ type http2ClientConn struct { type http2clientStream struct { cc *http2ClientConn req *Request - trace *http2clientTrace // or nil + trace *httptrace.ClientTrace // or nil ID uint32 resc chan http2resAndError bufPipe http2pipe // buffered pipe with the flow-controlled response payload @@ -6870,7 +6706,7 @@ type http2clientStream struct { // channel to be signaled. A non-nil error is returned only if the request was // canceled. func http2awaitRequestCancel(req *Request, done <-chan struct{}) error { - ctx := http2reqContext(req) + ctx := req.Context() if req.Cancel == nil && ctx.Done() == nil { return nil } @@ -7046,8 +6882,8 @@ func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Res select { case <-time.After(time.Second * time.Duration(backoff)): continue - case <-http2reqContext(req).Done(): - return nil, http2reqContext(req).Err() + case <-req.Context().Done(): + return nil, req.Context().Err() } } } @@ -7084,16 +6920,15 @@ func http2shouldRetryRequest(req *Request, err error, afterBodyWrite bool) (*Req } // If the Body is nil (or http.NoBody), it's safe to reuse // this request and its Body. - if req.Body == nil || http2reqBodyIsNoBody(req.Body) { + if req.Body == nil || req.Body == NoBody { return req, nil } // If the request body can be reset back to its original // state via the optional req.GetBody, do that. - getBody := http2reqGetBody(req) // Go 1.8: getBody = req.GetBody - if getBody != nil { + if req.GetBody != nil { // TODO: consider a req.Body.Close here? or audit that all caller paths do? - body, err := getBody() + body, err := req.GetBody() if err != nil { return nil, err } @@ -7139,7 +6974,7 @@ func (t *http2Transport) dialClientConn(addr string, singleUse bool) (*http2Clie func (t *http2Transport) newTLSConfig(host string) *tls.Config { cfg := new(tls.Config) if t.TLSClientConfig != nil { - *cfg = *http2cloneTLSConfig(t.TLSClientConfig) + *cfg = *t.TLSClientConfig.Clone() } if !http2strSliceContains(cfg.NextProtos, http2NextProtoTLS) { cfg.NextProtos = append([]string{http2NextProtoTLS}, cfg.NextProtos...) @@ -7190,7 +7025,7 @@ func (t *http2Transport) expectContinueTimeout() time.Duration { if t.t1 == nil { return 0 } - return http2transportExpectContinueTimeout(t.t1) + return t.t1.ExpectContinueTimeout } func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) { @@ -7315,8 +7150,19 @@ func (cc *http2ClientConn) idleStateLocked() (st http2clientConnIdleState) { if cc.singleUse && cc.nextStreamID > 1 { return } - st.canTakeNewRequest = cc.goAway == nil && !cc.closed && !cc.closing && - int64(cc.nextStreamID)+int64(cc.pendingRequests) < math.MaxInt32 + var maxConcurrentOkay bool + if cc.t.StrictMaxConcurrentStreams { + // We'll tell the caller we can take a new request to + // prevent the caller from dialing a new TCP + // connection, but then we'll block later before + // writing it. + maxConcurrentOkay = true + } else { + maxConcurrentOkay = int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) + } + + st.canTakeNewRequest = cc.goAway == nil && !cc.closed && !cc.closing && maxConcurrentOkay && + int64(cc.nextStreamID)+2*int64(cc.pendingRequests) < math.MaxInt32 st.freshConn = cc.nextStreamID == 1 && st.canTakeNewRequest return } @@ -7356,8 +7202,7 @@ func (cc *http2ClientConn) closeIfIdle() { var http2shutdownEnterWaitStateHook = func() {} // Shutdown gracefully close the client connection, waiting for running streams to complete. -// Public implementation is in go17.go and not_go17.go -func (cc *http2ClientConn) shutdown(ctx http2contextContext) error { +func (cc *http2ClientConn) Shutdown(ctx context.Context) error { if err := cc.sendGoAway(); err != nil { return err } @@ -7527,7 +7372,7 @@ func http2checkConnHeaders(req *Request) error { // req.ContentLength, where 0 actually means zero (not unknown) and -1 // means unknown. func http2actualContentLength(req *Request) int64 { - if req.Body == nil || http2reqBodyIsNoBody(req.Body) { + if req.Body == nil || req.Body == NoBody { return 0 } if req.ContentLength != 0 { @@ -7597,7 +7442,7 @@ func (cc *http2ClientConn) roundTrip(req *Request) (res *Response, gotErrAfterRe cs := cc.newStream() cs.req = req - cs.trace = http2requestTrace(req) + cs.trace = httptrace.ContextClientTrace(req.Context()) cs.requestedGzip = requestedGzip bodyWriter := cc.t.getBodyWriterState(cs, body) cs.on100 = bodyWriter.on100 @@ -7635,7 +7480,7 @@ func (cc *http2ClientConn) roundTrip(req *Request) (res *Response, gotErrAfterRe readLoopResCh := cs.resc bodyWritten := false - ctx := http2reqContext(req) + ctx := req.Context() handleReadLoopResponse := func(re http2resAndError) (*Response, bool, error) { res := re.res @@ -7705,6 +7550,7 @@ func (cc *http2ClientConn) roundTrip(req *Request) (res *Response, gotErrAfterRe default: } if err != nil { + cc.forgetStreamID(cs.ID) return nil, cs.getStartedWrite(), err } bodyWritten = true @@ -7826,6 +7672,7 @@ func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Clos sawEOF = true err = nil } else if err != nil { + cc.writeStreamReset(cs.ID, http2ErrCodeCancel, err) return err } @@ -8061,7 +7908,7 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail return nil, http2errRequestHeaderListSize } - trace := http2requestTrace(req) + trace := httptrace.ContextClientTrace(req.Context()) traceHeaders := http2traceHasWroteHeaderField(trace) // Header list size is ok. Write the headers. @@ -8484,7 +8331,7 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http res.Header.Del("Content-Length") res.ContentLength = -1 res.Body = &http2gzipReader{body: res.Body} - http2setResponseUncompressed(res) + res.Uncompressed = true } return res, nil } @@ -8861,8 +8708,7 @@ func (rl *http2clientConnReadLoop) processResetStream(f *http2RSTStreamFrame) er } // Ping sends a PING frame to the server and waits for the ack. -// Public implementation is in go17.go and not_go17.go -func (cc *http2ClientConn) ping(ctx http2contextContext) error { +func (cc *http2ClientConn) Ping(ctx context.Context) error { c := make(chan struct{}) // Generate a random payload var p [8]byte @@ -9097,6 +8943,94 @@ func http2isConnectionCloseRequest(req *Request) bool { return req.Close || httpguts.HeaderValuesContainsToken(req.Header["Connection"], "close") } +// registerHTTPSProtocol calls Transport.RegisterProtocol but +// converting panics into errors. +func http2registerHTTPSProtocol(t *Transport, rt http2noDialH2RoundTripper) (err error) { + defer func() { + if e := recover(); e != nil { + err = fmt.Errorf("%v", e) + } + }() + t.RegisterProtocol("https", rt) + return nil +} + +// noDialH2RoundTripper is a RoundTripper which only tries to complete the request +// if there's already has a cached connection to the host. +// (The field is exported so it can be accessed via reflect from net/http; tested +// by TestNoDialH2RoundTripperType) +type http2noDialH2RoundTripper struct{ *http2Transport } + +func (rt http2noDialH2RoundTripper) RoundTrip(req *Request) (*Response, error) { + res, err := rt.http2Transport.RoundTrip(req) + if http2isNoCachedConnError(err) { + return nil, ErrSkipAltProtocol + } + return res, err +} + +func (t *http2Transport) idleConnTimeout() time.Duration { + if t.t1 != nil { + return t.t1.IdleConnTimeout + } + return 0 +} + +func http2traceGetConn(req *Request, hostPort string) { + trace := httptrace.ContextClientTrace(req.Context()) + if trace == nil || trace.GetConn == nil { + return + } + trace.GetConn(hostPort) +} + +func http2traceGotConn(req *Request, cc *http2ClientConn) { + trace := httptrace.ContextClientTrace(req.Context()) + if trace == nil || trace.GotConn == nil { + return + } + ci := httptrace.GotConnInfo{Conn: cc.tconn} + cc.mu.Lock() + ci.Reused = cc.nextStreamID > 1 + ci.WasIdle = len(cc.streams) == 0 && ci.Reused + if ci.WasIdle && !cc.lastActive.IsZero() { + ci.IdleTime = time.Now().Sub(cc.lastActive) + } + cc.mu.Unlock() + + trace.GotConn(ci) +} + +func http2traceWroteHeaders(trace *httptrace.ClientTrace) { + if trace != nil && trace.WroteHeaders != nil { + trace.WroteHeaders() + } +} + +func http2traceGot100Continue(trace *httptrace.ClientTrace) { + if trace != nil && trace.Got100Continue != nil { + trace.Got100Continue() + } +} + +func http2traceWait100Continue(trace *httptrace.ClientTrace) { + if trace != nil && trace.Wait100Continue != nil { + trace.Wait100Continue() + } +} + +func http2traceWroteRequest(trace *httptrace.ClientTrace, err error) { + if trace != nil && trace.WroteRequest != nil { + trace.WroteRequest(httptrace.WroteRequestInfo{Err: err}) + } +} + +func http2traceFirstResponseByte(trace *httptrace.ClientTrace) { + if trace != nil && trace.GotFirstResponseByte != nil { + trace.GotFirstResponseByte() + } +} + // writeFramer is implemented by any type that is used to write frames. type http2writeFramer interface { writeFrame(http2writeContext) error @@ -9283,7 +9217,7 @@ func (w *http2writeResHeaders) staysWithinBuffer(max int) bool { // TODO: this is a common one. It'd be nice to return true // here and get into the fast path if we could be clever and // calculate the size fast enough, or at least a conservative - // uppper bound that usually fires. (Maybe if w.h and + // upper bound that usually fires. (Maybe if w.h and // w.trailers are nil, so we don't need to enumerate it.) // Otherwise I'm afraid that just calculating the length to // answer this question would be slower than the ~2µs benefit. @@ -9413,7 +9347,7 @@ func (wu http2writeWindowUpdate) writeFrame(ctx http2writeContext) error { } // encodeHeaders encodes an http.Header. If keys is not nil, then (k, h[k]) -// is encoded only only if k is in keys. +// is encoded only if k is in keys. func http2encodeHeaders(enc *hpack.Encoder, h Header, keys []string) { if keys == nil { sorter := http2sorterPool.Get().(*http2sorter)