diff --git a/conn.go b/conn.go index 4c0933b7..7bbc6013 100644 --- a/conn.go +++ b/conn.go @@ -265,6 +265,10 @@ type Conn struct { } func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { + return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil) +} + +func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, brw *bufio.ReadWriter) *Conn { mu := make(chan bool, 1) mu <- true @@ -274,13 +278,28 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) if readBufferSize < maxControlFramePayloadSize { readBufferSize = maxControlFramePayloadSize } + + // Reuse the supplied brw.Reader if brw.Reader's buf is the requested size. + var br *bufio.Reader + if brw != nil && brw.Reader != nil { + // This code assumes that peek on a reset reader returns + // bufio.Reader.buf[:0]. + brw.Reader.Reset(conn) + if p, err := brw.Reader.Peek(0); err == nil && cap(p) == readBufferSize { + br = brw.Reader + } + } + if br == nil { + br = bufio.NewReaderSize(conn, readBufferSize) + } + if writeBufferSize == 0 { writeBufferSize = defaultWriteBufferSize } c := &Conn{ isServer: isServer, - br: bufio.NewReaderSize(conn, readBufferSize), + br: br, conn: conn, mu: mu, readFinal: true, diff --git a/conn_test.go b/conn_test.go index 7431383b..ba347f8b 100644 --- a/conn_test.go +++ b/conn_test.go @@ -463,3 +463,17 @@ func TestFailedConnectionReadPanic(t *testing.T) { } t.Fatal("should not get here") } + +func TestBufioReaderReuse(t *testing.T) { + brw := bufio.NewReadWriter(bufio.NewReader(nil), nil) + c := newConnBRW(nil, false, 0, 0, brw) + if c.br != brw.Reader { + t.Error("connection did not reuse bufio.Reader") + } + + brw = bufio.NewReadWriter(bufio.NewReaderSize(nil, 1234), nil) // size must not equal bufio.defaultBufSize + c = newConnBRW(nil, false, 0, 0, brw) + if c.br == brw.Reader { + t.Error("connection reuse bufio.Reader with wrong size") + } +} diff --git a/server.go b/server.go index 6f6ac832..95c16566 100644 --- a/server.go +++ b/server.go @@ -152,7 +152,6 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade var ( netConn net.Conn - br *bufio.Reader err error ) @@ -160,19 +159,18 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade if !ok { return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker") } - var rw *bufio.ReadWriter - netConn, rw, err = h.Hijack() + var brw *bufio.ReadWriter + netConn, brw, err = h.Hijack() if err != nil { return u.returnError(w, r, http.StatusInternalServerError, err.Error()) } - br = rw.Reader - if br.Buffered() > 0 { + if brw.Reader.Buffered() > 0 { netConn.Close() return nil, errors.New("websocket: client sent data before handshake is complete") } - c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize) + c := newConnBRW(netConn, true, u.ReadBufferSize, u.WriteBufferSize, brw) c.subprotocol = subprotocol if compress {