Skip to content

Commit

Permalink
http2: don't leak streams on broken body
Browse files Browse the repository at this point in the history
Updates golang/go#27208

Change-Id: I5d9a643f33d27d33b24f670c98f5a51aa6000967
GitHub-Last-Rev: 3ac4a57
GitHub-Pull-Request: #18
Reviewed-on: https://go-review.googlesource.com/c/132715
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
  • Loading branch information
Ruslan Nigmatullin authored and bradfitz committed Nov 7, 2018
1 parent a544f70 commit 1c5f79c
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 0 deletions.
2 changes: 2 additions & 0 deletions http2/transport.go
Expand Up @@ -1100,6 +1100,7 @@ func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAf
default:
}
if err != nil {
cc.forgetStreamID(cs.ID)
return nil, cs.getStartedWrite(), err
}
bodyWritten = true
Expand Down Expand Up @@ -1221,6 +1222,7 @@ func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (
sawEOF = true
err = nil
} else if err != nil {
cc.writeStreamReset(cs.ID, ErrCodeCancel, err)
return err
}

Expand Down
96 changes: 96 additions & 0 deletions http2/transport_test.go
Expand Up @@ -4180,3 +4180,99 @@ func TestNoDialH2RoundTripperType(t *testing.T) {
t.Fatalf("wrong kind %T; want *Transport", v.Interface())
}
}

type errReader struct {
body []byte
err error
}

func (r *errReader) Read(p []byte) (int, error) {
if len(r.body) > 0 {
n := copy(p, r.body)
r.body = r.body[n:]
return n, nil
}
return 0, r.err
}

func testTransportBodyReadError(t *testing.T, body []byte) {
clientDone := make(chan struct{})
ct := newClientTester(t)
ct.client = func() error {
defer ct.cc.(*net.TCPConn).CloseWrite()
defer close(clientDone)

checkNoStreams := func() error {
cp, ok := ct.tr.connPool().(*clientConnPool)
if !ok {
return fmt.Errorf("conn pool is %T; want *clientConnPool", ct.tr.connPool())
}
cp.mu.Lock()
defer cp.mu.Unlock()
conns, ok := cp.conns["dummy.tld:443"]
if !ok {
return fmt.Errorf("missing connection")
}
if len(conns) != 1 {
return fmt.Errorf("conn pool size: %v; expect 1", len(conns))
}
if activeStreams(conns[0]) != 0 {
return fmt.Errorf("active streams count: %v; want 0", activeStreams(conns[0]))
}
return nil
}
bodyReadError := errors.New("body read error")
body := &errReader{body, bodyReadError}
req, err := http.NewRequest("PUT", "https://dummy.tld/", body)
if err != nil {
return err
}
_, err = ct.tr.RoundTrip(req)
if err != bodyReadError {
return fmt.Errorf("err = %v; want %v", err, bodyReadError)
}
if err = checkNoStreams(); err != nil {
return err
}
return nil
}
ct.server = func() error {
ct.greet()
var receivedBody []byte
var resetCount int
for {
f, err := ct.fr.ReadFrame()
if err != nil {
select {
case <-clientDone:
// If the client's done, it
// will have reported any
// errors on its side.
if bytes.Compare(receivedBody, body) != 0 {
return fmt.Errorf("body: %v; expected %v", receivedBody, body)
}
if resetCount != 1 {
return fmt.Errorf("stream reset count: %v; expected: 1", resetCount)
}
return nil
default:
return err
}
}
switch f := f.(type) {
case *WindowUpdateFrame, *SettingsFrame:
case *HeadersFrame:
case *DataFrame:
receivedBody = append(receivedBody, f.Data()...)
case *RSTStreamFrame:
resetCount++
default:
return fmt.Errorf("Unexpected client frame %v", f)
}
}
}
ct.run()
}

func TestTransportBodyReadError_Immediately(t *testing.T) { testTransportBodyReadError(t, nil) }
func TestTransportBodyReadError_Some(t *testing.T) { testTransportBodyReadError(t, []byte("123")) }

0 comments on commit 1c5f79c

Please sign in to comment.