From 438097d76259d4670fb324c834c2426a7f1946f8 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Tue, 1 Dec 2015 22:16:42 +0000 Subject: [PATCH] http2: make the Transport write request body data as it's available Unlike HTTP/1, we now permit streaming the write of a request body as we read the response body, since HTTP/2's framing makes it possible. Our behavior however is based on a heuristic: we always begin writing the request body right away (like previously, and like HTTP/1), but if we're still writing the request body and the server replies with a status code over 299 (not 1xx and not 2xx), then we stop writing the request body, assuming the server doesn't care about it. There is currently no switch (and hopefully won't be) to force enable this behavior. In the case where the server replied with a 1xx/2xx and we're still writing the request body but the server doesn't want it, the server can do a RST_STREAM, which we respect as before and stop sending. Also in this CL: * adds an h2demo handler at https://http2.golang.org/ECHO to demo it * fixes a potential flow control integer truncation bug * start of clientTester type used for the tests in this CL, similar to the serverTester. It's still a bit cumbersome to write client tests, though. * fix potential deadlock where awaitFlowControl could block while waiting a stream reset arrived. fix it by moving all checks into the sync.Cond loop, rather than having a sync.Cond check followed by a select. simplifies code, too. * fix two data races in test-only code. Updates golang/go#13444 Change-Id: Idfda6833a212a89fcd65293cdeb4169d1723724f Reviewed-on: https://go-review.googlesource.com/17310 Reviewed-by: Blake Mizerany --- http2/h2demo/h2demo.go | 36 +++++ http2/server_test.go | 6 +- http2/transport.go | 98 ++++++++----- http2/transport_test.go | 303 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 406 insertions(+), 37 deletions(-) diff --git a/http2/h2demo/h2demo.go b/http2/h2demo/h2demo.go index 8d5e4fd1c..15ef52f9b 100644 --- a/http2/h2demo/h2demo.go +++ b/http2/h2demo/h2demo.go @@ -91,6 +91,7 @@ href="https://golang.org/issues">file a bug.

  • GET /redirect to redirect back to / (this page)
  • GET /goroutines to see all active goroutines in this server
  • PUT something to /crc32 to get a count of number of bytes and its CRC-32
  • +
  • PUT something to /ECHO and it will be streamed back to you capitalized
  • `) @@ -124,6 +125,40 @@ func crcHandler(w http.ResponseWriter, r *http.Request) { } } +type capitalizeReader struct { + r io.Reader +} + +func (cr capitalizeReader) Read(p []byte) (n int, err error) { + n, err = cr.r.Read(p) + for i, b := range p[:n] { + if b >= 'a' && b <= 'z' { + p[i] = b - ('a' - 'A') + } + } + return +} + +type flushWriter struct { + w io.Writer +} + +func (fw flushWriter) Write(p []byte) (n int, err error) { + n, err = fw.w.Write(p) + if f, ok := fw.w.(http.Flusher); ok { + f.Flush() + } + return +} + +func echoCapitalHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != "PUT" { + http.Error(w, "PUT required.", 400) + return + } + io.Copy(flushWriter{w}, capitalizeReader{r.Body}) +} + var ( fsGrp singleflight.Group fsMu sync.Mutex // guards fsCache @@ -217,6 +252,7 @@ func registerHandlers() { mux2.Handle("/file/go.src.tar.gz", fileServer("https://storage.googleapis.com/golang/go1.4.1.src.tar.gz")) mux2.HandleFunc("/reqinfo", reqInfoHandler) mux2.HandleFunc("/crc32", crcHandler) + mux2.HandleFunc("/ECHO", echoCapitalHandler) mux2.HandleFunc("/clockstream", clockStreamHandler) mux2.Handle("/gophertiles", tiles) mux2.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) { diff --git a/http2/server_test.go b/http2/server_test.go index 7a4205159..7e8eb7eb0 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -2213,6 +2213,9 @@ func testServerWithCurl(t *testing.T, permitProhibitedCipherSuites bool) { t.Skip("skipping curl test in short mode") } requireCurl(t) + var gotConn int32 + testHookOnConn = func() { atomic.StoreInt32(&gotConn, 1) } + const msg = "Hello from curl!\n" ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Foo", "Bar") @@ -2226,9 +2229,6 @@ func testServerWithCurl(t *testing.T, permitProhibitedCipherSuites bool) { ts.StartTLS() defer ts.Close() - var gotConn int32 - testHookOnConn = func() { atomic.StoreInt32(&gotConn, 1) } - t.Logf("Running test server for curl to hit at: %s", ts.URL) container := curl(t, "--silent", "--http2", "--insecure", "-v", ts.URL) defer kill(container) diff --git a/http2/transport.go b/http2/transport.go index 320bf6718..9d9e09c81 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -155,6 +155,7 @@ type clientStream struct { inflow flow // guarded by cc.mu bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read readErr error // sticky read error; owned by transportResponseBody.Read + stopReqBody bool // stop writing req body; guarded by cc.mu peerReset chan struct{} // closed on peer reset resetErr error // populated before peerReset is closed @@ -171,6 +172,14 @@ func (cs *clientStream) checkReset() error { } } +func (cs *clientStream) abortRequestBodyWrite() { + cc := cs.cc + cc.mu.Lock() + cs.stopReqBody = true + cc.cond.Broadcast() + cc.mu.Unlock() +} + type stickyErrWriter struct { w io.Writer err *error @@ -516,26 +525,33 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { return nil, werr } - var bodyCopyErrc chan error - var gotResHeaders chan struct{} // closed on resheaders + var bodyCopyErrc chan error // result of body copy if hasBody { bodyCopyErrc = make(chan error, 1) - gotResHeaders = make(chan struct{}) go func() { - bodyCopyErrc <- cs.writeRequestBody(req.Body, gotResHeaders) + bodyCopyErrc <- cs.writeRequestBody(req.Body) }() } for { select { case re := <-cs.resc: - if gotResHeaders != nil { - close(gotResHeaders) + res := re.res + if re.err != nil || res.StatusCode > 299 { + // On error or status code 3xx, 4xx, 5xx, etc abort any + // ongoing write, assuming that the server doesn't care + // about our request body. If the server replied with 1xx or + // 2xx, however, then assume the server DOES potentially + // want our body (e.g. full-duplex streaming: + // golang.org/issue/13444). If it turns out the server + // doesn't, they'll RST_STREAM us soon enough. This is a + // heuristic to avoid adding knobs to Transport. Hopefully + // we can keep it. + cs.abortRequestBodyWrite() } if re.err != nil { return nil, re.err } - res := re.res res.Request = req res.TLS = cc.tlsState return res, nil @@ -547,45 +563,56 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { } } -var errServerResponseBeforeRequestBody = errors.New("http2: server sent response while still writing request body") +// errAbortReqBodyWrite is an internal error value. +// It doesn't escape to callers. +var errAbortReqBodyWrite = errors.New("http2: aborting request body write") -func (cs *clientStream) writeRequestBody(body io.Reader, gotResHeaders <-chan struct{}) error { +func (cs *clientStream) writeRequestBody(body io.ReadCloser) (err error) { cc := cs.cc sentEnd := false // whether we sent the final DATA frame w/ END_STREAM buf := cc.frameScratchBuffer() defer cc.putFrameScratchBuffer(buf) - for !sentEnd { - var sawEOF bool - n, err := io.ReadFull(body, buf) - if err == io.ErrUnexpectedEOF { + defer func() { + // TODO: write h12Compare test showing whether + // Request.Body is closed by the Transport, + // and in multiple cases: server replies <=299 and >299 + // while still writing request body + cerr := body.Close() + if err == nil { + err = cerr + } + }() + + var sawEOF bool + for !sawEOF { + n, err := body.Read(buf) + if err == io.EOF { sawEOF = true err = nil - } else if err == io.EOF { - break } else if err != nil { return err } - toWrite := buf[:n] - for len(toWrite) > 0 && err == nil { + remain := buf[:n] + for len(remain) > 0 && err == nil { var allowed int32 - allowed, err = cs.awaitFlowControl(int32(len(toWrite))) + allowed, err = cs.awaitFlowControl(len(remain)) if err != nil { return err } - cc.wmu.Lock() - select { - case <-gotResHeaders: - err = errServerResponseBeforeRequestBody - case <-cs.peerReset: - err = cs.resetErr - default: - data := toWrite[:allowed] - toWrite = toWrite[allowed:] - sentEnd = sawEOF && len(toWrite) == 0 - err = cc.fr.WriteData(cs.ID, sentEnd, data) + data := remain[:allowed] + remain = remain[allowed:] + sentEnd = sawEOF && len(remain) == 0 + err = cc.fr.WriteData(cs.ID, sentEnd, data) + if err == nil { + // TODO(bradfitz): this flush is for latency, not bandwidth. + // Most requests won't need this. Make this opt-in or opt-out? + // Use some heuristic on the body type? Nagel-like timers? + // Based on 'n'? Only last chunk of this for loop, unless flow control + // tokens are low? For now, always: + err = cc.bw.Flush() } cc.wmu.Unlock() } @@ -594,8 +621,6 @@ func (cs *clientStream) writeRequestBody(body io.Reader, gotResHeaders <-chan st } } - var err error - cc.wmu.Lock() if !sentEnd { err = cc.fr.WriteData(cs.ID, true, nil) @@ -612,7 +637,7 @@ func (cs *clientStream) writeRequestBody(body io.Reader, gotResHeaders <-chan st // control tokens from the server. // It returns either the non-zero number of tokens taken or an error // if the stream is dead. -func (cs *clientStream) awaitFlowControl(maxBytes int32) (taken int32, err error) { +func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) { cc := cs.cc cc.mu.Lock() defer cc.mu.Unlock() @@ -620,13 +645,17 @@ func (cs *clientStream) awaitFlowControl(maxBytes int32) (taken int32, err error if cc.closed { return 0, errClientConnClosed } + if cs.stopReqBody { + return 0, errAbortReqBodyWrite + } if err := cs.checkReset(); err != nil { return 0, err } if a := cs.flow.available(); a > 0 { take := a - if take > maxBytes { - take = maxBytes + if int(take) > maxBytes { + + take = int32(maxBytes) // can't truncate int; take is int32 } if take > int32(cc.maxFrameSize) { take = int32(cc.maxFrameSize) @@ -1092,6 +1121,7 @@ func (rl *clientConnReadLoop) processResetStream(f *RSTStreamFrame) error { cs.resetErr = err close(cs.peerReset) cs.bufPipe.CloseWithError(err) + cs.cc.cond.Broadcast() // wake up checkReset via clientStream.awaitFlowControl } delete(rl.activeRes, cs.ID) return nil diff --git a/http2/transport_test.go b/http2/transport_test.go index 83791575e..0c875acfd 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -5,21 +5,29 @@ package http2 import ( + "bufio" + "bytes" "crypto/tls" + "errors" "flag" "fmt" "io" "io/ioutil" + "log" "math/rand" "net" "net/http" "net/url" "os" "reflect" + "strconv" "strings" "sync" + "sync/atomic" "testing" "time" + + "golang.org/x/net/http2/hpack" ) var ( @@ -182,6 +190,8 @@ func TestTransportGroupsPendingDials(t *testing.T) { if !ok { return fmt.Errorf("Conn pool is %T; want *clientConnPool", tr.connPool()) } + cp.mu.Lock() + defer cp.mu.Unlock() if len(cp.dialing) != 0 { return fmt.Errorf("dialing map = %v; want empty", cp.dialing) } @@ -456,3 +466,296 @@ func TestConfigureTransport(t *testing.T) { t.Errorf("body = %q; want %q", got, want) } } + +type capitalizeReader struct { + r io.Reader +} + +func (cr capitalizeReader) Read(p []byte) (n int, err error) { + n, err = cr.r.Read(p) + for i, b := range p[:n] { + if b >= 'a' && b <= 'z' { + p[i] = b - ('a' - 'A') + } + } + return +} + +type flushWriter struct { + w io.Writer +} + +func (fw flushWriter) Write(p []byte) (n int, err error) { + n, err = fw.w.Write(p) + if f, ok := fw.w.(http.Flusher); ok { + f.Flush() + } + return +} + +type clientTester struct { + t *testing.T + tr *Transport + sc, cc net.Conn // server and client conn + fr *Framer // server's framer + client func() error + server func() error +} + +func newClientTester(t *testing.T) *clientTester { + var dialOnce struct { + sync.Mutex + dialed bool + } + ct := &clientTester{ + t: t, + } + ct.tr = &Transport{ + TLSClientConfig: tlsConfigInsecure, + DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { + dialOnce.Lock() + defer dialOnce.Unlock() + if dialOnce.dialed { + return nil, errors.New("only one dial allowed in test mode") + } + dialOnce.dialed = true + return ct.cc, nil + }, + } + + ln := newLocalListener(t) + cc, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + + } + sc, err := ln.Accept() + if err != nil { + t.Fatal(err) + } + ln.Close() + ct.cc = cc + ct.sc = sc + ct.fr = NewFramer(sc, sc) + return ct +} + +func newLocalListener(t *testing.T) net.Listener { + ln, err := net.Listen("tcp4", "127.0.0.1:0") + if err == nil { + return ln + } + ln, err = net.Listen("tcp6", "[::1]:0") + if err != nil { + t.Fatal(err) + } + return ln +} + +func (ct *clientTester) greet() { + buf := make([]byte, len(ClientPreface)) + _, err := io.ReadFull(ct.sc, buf) + if err != nil { + ct.t.Fatalf("reading client preface: %v", err) + } + f, err := ct.fr.ReadFrame() + if err != nil { + ct.t.Fatalf("Reading client settings frame: %v", err) + } + if sf, ok := f.(*SettingsFrame); !ok { + ct.t.Fatalf("Wanted client settings frame; got %v", f) + _ = sf // stash it away? + } + if err := ct.fr.WriteSettings(); err != nil { + ct.t.Fatal(err) + } + if err := ct.fr.WriteSettingsAck(); err != nil { + ct.t.Fatal(err) + } +} + +func (ct *clientTester) run() { + errc := make(chan error, 2) + ct.start("client", errc, ct.client) + ct.start("server", errc, ct.server) + for i := 0; i < 2; i++ { + if err := <-errc; err != nil { + ct.t.Error(err) + return + } + } +} + +func (ct *clientTester) start(which string, errc chan<- error, fn func() error) { + go func() { + finished := false + var err error + defer func() { + if !finished { + err = fmt.Errorf("%s goroutine didn't finish.", which) + } else if err != nil { + err = fmt.Errorf("%s: %v", which, err) + } + errc <- err + }() + err = fn() + finished = true + }() +} + +type countingReader struct { + n *int64 +} + +func (r countingReader) Read(p []byte) (n int, err error) { + for i := range p { + p[i] = byte(i) + } + atomic.AddInt64(r.n, int64(len(p))) + return len(p), err +} + +func TestTransportReqBodyAfterResponse_200(t *testing.T) { testTransportReqBodyAfterResponse(t, 200) } +func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyAfterResponse(t, 403) } + +func testTransportReqBodyAfterResponse(t *testing.T, status int) { + const bodySize = 10 << 20 + ct := newClientTester(t) + ct.client = func() error { + var n int64 // atomic + req, err := http.NewRequest("PUT", "https://dummy.tld/", io.LimitReader(countingReader{&n}, bodySize)) + if err != nil { + return err + } + res, err := ct.tr.RoundTrip(req) + if err != nil { + return fmt.Errorf("RoundTrip: %v", err) + } + defer res.Body.Close() + if res.StatusCode != status { + return fmt.Errorf("status code = %v; want %v", res.StatusCode, status) + } + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + return fmt.Errorf("Slurp: %v", err) + } + if len(slurp) > 0 { + return fmt.Errorf("unexpected body: %q", slurp) + } + if status == 200 { + if got := atomic.LoadInt64(&n); got != bodySize { + return fmt.Errorf("For 200 response, Transport wrote %d bytes; want %d", got, bodySize) + } + } else { + if got := atomic.LoadInt64(&n); got == 0 || got >= bodySize { + return fmt.Errorf("For %d response, Transport wrote %d bytes; want (0,%d) exclusive", status, got, bodySize) + } + } + return nil + } + ct.server = func() error { + ct.greet() + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + var dataRecv int64 + var closed bool + for { + f, err := ct.fr.ReadFrame() + if err != nil { + return err + } + //println(fmt.Sprintf("server got frame: %v", f)) + switch f := f.(type) { + case *WindowUpdateFrame, *SettingsFrame: + case *HeadersFrame: + if !f.HeadersEnded() { + return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) + } + if f.StreamEnded() { + return fmt.Errorf("headers contains END_STREAM unexpectedly: %v", f) + } + time.Sleep(50 * time.Millisecond) // let client send body + enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}) + ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: f.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + case *DataFrame: + dataLen := len(f.Data()) + dataRecv += int64(dataLen) + if dataLen > 0 { + if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil { + return err + } + if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil { + return err + } + } + if !closed && ((status != 200 && dataRecv > 0) || + (status == 200 && dataRecv == bodySize)) { + closed = true + if err := ct.fr.WriteData(f.StreamID, true, nil); err != nil { + return err + } + return nil + } + default: + return fmt.Errorf("Unexpected client frame %v", f) + } + } + return nil + } + ct.run() +} + +// See golang.org/issue/13444 +func TestTransportFullDuplex(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) // redundant but for clarity + w.(http.Flusher).Flush() + io.Copy(flushWriter{w}, capitalizeReader{r.Body}) + fmt.Fprintf(w, "bye.\n") + }, optOnlyServer) + defer st.Close() + + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + c := &http.Client{Transport: tr} + + pr, pw := io.Pipe() + req, err := http.NewRequest("PUT", st.ts.URL, ioutil.NopCloser(pr)) + if err != nil { + log.Fatal(err) + } + res, err := c.Do(req) + if err != nil { + log.Fatal(err) + } + defer res.Body.Close() + if res.StatusCode != 200 { + t.Fatalf("StatusCode = %v; want %v", res.StatusCode, 200) + } + bs := bufio.NewScanner(res.Body) + want := func(v string) { + if !bs.Scan() { + t.Fatalf("wanted to read %q but Scan() = false, err = %v", v, bs.Err()) + } + } + write := func(v string) { + _, err := io.WriteString(pw, v) + if err != nil { + t.Fatalf("pipe write: %v", err) + } + } + write("foo\n") + want("FOO") + write("bar\n") + want("BAR") + pw.Close() + want("bye.") + if err := bs.Err(); err != nil { + t.Fatal(err) + } +}