Skip to content

Commit

Permalink
net/http: disable 100 continue status after handler finished
Browse files Browse the repository at this point in the history
When client supplies "Expect: 100-continue" header,
server wraps request body into expectContinueReader
that writes 100 Continue status on the first body read.

When handler acts as a reverse proxy and passes incoming
request (or body) to the client (or transport) it may happen
that request body is read after handler exists which may
cause nil pointer panic on connection write if server
already closed the connection.

This change disables write of 100 Continue status by expectContinueReader
after handler finished and before connection is closed.

It also fixes racy access to w.wroteContinue.

Fixes #53808
Updates #46866
  • Loading branch information
AlexanderYastrebov committed Mar 30, 2024
1 parent bb523c9 commit 4a48705
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 29 deletions.
25 changes: 25 additions & 0 deletions src/net/http/clientserver_test.go
Expand Up @@ -1754,3 +1754,28 @@ func testEarlyHintsRequest(t *testing.T, mode testMode) {
t.Errorf("Read body %q; want Hello", body)
}
}

// Issue 53808
func TestServerReadAfterHandlerDone100Continue(t *testing.T) {
run(t, testServerReadAfterHandlerDone100Continue)
}
func testServerReadAfterHandlerDone100Continue(t *testing.T, mode testMode) {
readyc := make(chan struct{})
cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
go func() {
<-readyc
io.ReadAll(r.Body)
<-readyc
}()
}))

req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
req.Header.Set("Expect", "100-continue")
res, err := cst.c.Do(req)
if err != nil {
t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
}
res.Body.Close()
readyc <- struct{}{} // server starts reading from the request body
readyc <- struct{}{} // server finishes reading from the request body
}
57 changes: 31 additions & 26 deletions src/net/http/server.go
Expand Up @@ -425,7 +425,6 @@ type response struct {
reqBody io.ReadCloser
cancelCtx context.CancelFunc // when ServeHTTP exits
wroteHeader bool // a non-1xx header has been (logically) written
wroteContinue bool // 100 Continue response was written
wants10KeepAlive bool // HTTP/1.0 w/ Connection "keep-alive"
wantsClose bool // HTTP request has Connection "close"

Expand All @@ -438,6 +437,7 @@ type response struct {
// against the main writer.
canWriteContinue atomic.Bool
writeContinueMu sync.Mutex
wroteContinue atomic.Bool // 100 Continue response was written

w *bufio.Writer // buffers output in chunks to chunkWriter
cw chunkWriter
Expand Down Expand Up @@ -916,17 +916,7 @@ func (ecr *expectContinueReader) Read(p []byte) (n int, err error) {
if ecr.closed.Load() {
return 0, ErrBodyReadAfterClose
}
w := ecr.resp
if !w.wroteContinue && w.canWriteContinue.Load() && !w.conn.hijacked() {
w.wroteContinue = true
w.writeContinueMu.Lock()
if w.canWriteContinue.Load() {
w.conn.bufw.WriteString("HTTP/1.1 100 Continue\r\n\r\n")
w.conn.bufw.Flush()
w.canWriteContinue.Store(false)
}
w.writeContinueMu.Unlock()
}
ecr.resp.writeContinueOnce()
n, err = ecr.readCloser.Read(p)
if err == io.EOF {
ecr.sawEOF.Store(true)
Expand Down Expand Up @@ -1165,10 +1155,8 @@ func (w *response) WriteHeader(code int) {
// so it takes the non-informational path.
if code >= 100 && code <= 199 && code != StatusSwitchingProtocols {
// Prevent a potential race with an automatically-sent 100 Continue triggered by Request.Body.Read()
if code == 100 && w.canWriteContinue.Load() {
w.writeContinueMu.Lock()
w.canWriteContinue.Store(false)
w.writeContinueMu.Unlock()
if code == 100 {
w.disableContinue()
}

writeStatusLine(w.conn.bufw, w.req.ProtoAtLeast(1, 1), code, w.statusBuf[:])
Expand Down Expand Up @@ -1383,7 +1371,7 @@ func (cw *chunkWriter) writeHeader(p []byte) {

switch bdy := w.req.Body.(type) {
case *expectContinueReader:
if bdy.resp.wroteContinue {
if bdy.resp.wroteContinue.Load() {
discard = true
}
case *body:
Expand Down Expand Up @@ -1625,15 +1613,7 @@ func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err er
return 0, ErrHijacked
}

if w.canWriteContinue.Load() {
// Body reader wants to write 100 Continue but hasn't yet.
// Tell it not to. The store must be done while holding the lock
// because the lock makes sure that there is not an active write
// this very moment.
w.writeContinueMu.Lock()
w.canWriteContinue.Store(false)
w.writeContinueMu.Unlock()
}
w.disableContinue()

if !w.wroteHeader {
w.WriteHeader(StatusOK)
Expand Down Expand Up @@ -1679,6 +1659,29 @@ func (w *response) finishRequest() {
}
}

// disableContinue disables write of 100 Continue status.
func (w *response) disableContinue() {
if w.canWriteContinue.Load() {
w.writeContinueMu.Lock()
w.canWriteContinue.Store(false)
w.writeContinueMu.Unlock()
}
}

// writeContinueOnce writes 100 Continue status if allowed.
func (w *response) writeContinueOnce() {
if w.canWriteContinue.Load() && !w.conn.hijacked() {
w.wroteContinue.Store(true)
w.writeContinueMu.Lock()
if w.canWriteContinue.Load() {
w.conn.bufw.WriteString("HTTP/1.1 100 Continue\r\n\r\n")
w.conn.bufw.Flush()
w.canWriteContinue.Store(false)
}
w.writeContinueMu.Unlock()
}
}

// shouldReuseConnection reports whether the underlying TCP connection can be reused.
// It must only be called after the handler is done executing.
func (w *response) shouldReuseConnection() bool {
Expand Down Expand Up @@ -1905,6 +1908,7 @@ func (c *conn) serve(ctx context.Context) {
if inFlightResponse != nil {
inFlightResponse.conn.r.abortPendingRead()
inFlightResponse.reqBody.Close()
inFlightResponse.disableContinue()
}
c.close()
c.setState(c.rwc, StateClosed, runHooks)
Expand Down Expand Up @@ -2046,6 +2050,7 @@ func (c *conn) serve(ctx context.Context) {
return
}
w.finishRequest()
w.disableContinue()
c.rwc.SetWriteDeadline(time.Time{})
if !w.shouldReuseConnection() {
if w.requestBodyLimitHit || w.closedRequestBodyEarly() {
Expand Down
19 changes: 16 additions & 3 deletions src/net/http/transport_test.go
Expand Up @@ -26,6 +26,7 @@ import (
"log"
mrand "math/rand"
"net"
"net/http"
. "net/http"
"net/http/httptest"
"net/http/httptrace"
Expand Down Expand Up @@ -6901,19 +6902,31 @@ func testHandlerAbortRacesBodyRead(t *testing.T, mode testMode) {
panic(ErrAbortHandler)
})).ts

newRequest := func() *http.Request {
const reqLen = 6 * 1024 * 1024
req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
req.ContentLength = reqLen
return req
}

var wg sync.WaitGroup
for i := 0; i < 2; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 10; j++ {
const reqLen = 6 * 1024 * 1024
req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
req.ContentLength = reqLen
req := newRequest()
resp, _ := ts.Client().Transport.RoundTrip(req)
if resp != nil {
resp.Body.Close()
}

req = newRequest()
req.Header.Set("Expect", "100-continue")
resp, _ = ts.Client().Transport.RoundTrip(req)
if resp != nil {
resp.Body.Close()
}
}
}()
}
Expand Down

0 comments on commit 4a48705

Please sign in to comment.