From a87eae1d6ff562bf944755f4a6f527b6ef6fb888 Mon Sep 17 00:00:00 2001 From: Gary Burd Date: Wed, 29 Jun 2016 17:03:55 -0700 Subject: [PATCH] Add hooks to support RFC 7692 (per-message compression extension) Add newCompressionWriter and newDecompressionReader fields to Conn. When not nil, these functions are used to create a compression/decompression wrapper around an underlying message writer/reader. Add code to set and check for RSV1 frame header bit. Add functions compressNoContextTakeover and decompressNoContextTakeover for creating no context takeover wrappers around an underlying message writer/reader. Work remaining: - Add fields to Dialer and Upgrader for specifying compression options. - Add compression negotiation to Dialer and Upgrader. - Add function to enable/disable write compression: // EnableWriteCompression enables and disables write compression of // subsequent text and binary messages. This function is a noop if // compression was not negotiated with the peer. func (c *Conn) EnableWriteCompression(enable bool) { c.enableWriteCompression = enable } --- compression.go | 85 +++++++++++++++++++++++++++++++++++ compression_test.go | 31 +++++++++++++ conn.go | 107 +++++++++++++++++++++++++++++++------------- 3 files changed, 191 insertions(+), 32 deletions(-) create mode 100644 compression.go create mode 100644 compression_test.go diff --git a/compression.go b/compression.go new file mode 100644 index 00000000..e2ac7617 --- /dev/null +++ b/compression.go @@ -0,0 +1,85 @@ +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "compress/flate" + "errors" + "io" + "strings" +) + +func decompressNoContextTakeover(r io.Reader) io.Reader { + const tail = + // Add four bytes as specified in RFC + "\x00\x00\xff\xff" + + // Add final block to squelch unexpected EOF error from flate reader. + "\x01\x00\x00\xff\xff" + + return flate.NewReader(io.MultiReader(r, strings.NewReader(tail))) +} + +func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) { + tw := &truncWriter{w: w} + fw, err := flate.NewWriter(tw, 3) + return &flateWrapper{fw: fw, tw: tw}, err +} + +// truncWriter is an io.Writer that writes all but the last four bytes of the +// stream to another io.Writer. +type truncWriter struct { + w io.WriteCloser + n int + p [4]byte +} + +func (w *truncWriter) Write(p []byte) (int, error) { + n := 0 + + // fill buffer first for simplicity. + if w.n < len(w.p) { + n = copy(w.p[w.n:], p) + p = p[n:] + w.n += n + if len(p) == 0 { + return n, nil + } + } + + m := len(p) + if m > len(w.p) { + m = len(w.p) + } + + if nn, err := w.w.Write(w.p[:m]); err != nil { + return n + nn, err + } + + copy(w.p[:], w.p[m:]) + copy(w.p[len(w.p)-m:], p[len(p)-m:]) + nn, err := w.w.Write(p[:len(p)-m]) + return n + nn, err +} + +type flateWrapper struct { + fw *flate.Writer + tw *truncWriter +} + +func (w *flateWrapper) Write(p []byte) (int, error) { + return w.fw.Write(p) +} + +func (w *flateWrapper) Close() error { + err1 := w.fw.Flush() + if w.tw.p != [4]byte{0, 0, 0xff, 0xff} { + return errors.New("websocket: internal error, unexpected bytes at end of flate stream") + } + err2 := w.tw.w.Close() + if err1 != nil { + return err1 + } + return err2 +} diff --git a/compression_test.go b/compression_test.go new file mode 100644 index 00000000..cad70fb5 --- /dev/null +++ b/compression_test.go @@ -0,0 +1,31 @@ +package websocket + +import ( + "bytes" + "io" + "testing" +) + +type nopCloser struct{ io.Writer } + +func (nopCloser) Close() error { return nil } + +func TestTruncWriter(t *testing.T) { + const data = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijlkmnopqrstuvwxyz987654321" + for n := 1; n <= 10; n++ { + var b bytes.Buffer + w := &truncWriter{w: nopCloser{&b}} + p := []byte(data) + for len(p) > 0 { + m := len(p) + if m > n { + m = n + } + w.Write(p[:m]) + p = p[m:] + } + if b.String() != data[:len(data)-len(w.p)] { + t.Errorf("%d: %q", n, b.String()) + } + } +} diff --git a/conn.go b/conn.go index 794c2eff..eb4334e7 100644 --- a/conn.go +++ b/conn.go @@ -18,11 +18,19 @@ import ( ) const ( + // Frame header byte 0 bits from Section 5.2 of RFC 6455 + finalBit = 1 << 7 + rsv1Bit = 1 << 6 + rsv2Bit = 1 << 5 + rsv3Bit = 1 << 4 + + // Frame header byte 1 bits from Section 5.2 of RFC 6455 + maskBit = 1 << 7 + maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask maxControlFramePayloadSize = 125 - finalBit = 1 << 7 - maskBit = 1 << 7 - writeWait = time.Second + + writeWait = time.Second defaultReadBufferSize = 4096 defaultWriteBufferSize = 4096 @@ -230,17 +238,20 @@ type Conn struct { subprotocol string // Write fields - mu chan bool // used as mutex to protect write to conn and closeSent - closeSent bool // true if close message was sent - - // Message writer fields. + mu chan bool // used as mutex to protect write to conn and closeSent + closeSent bool // whether close message was sent writeErr error writeBuf []byte // frame is constructed in this buffer. writePos int // end of data in writeBuf. writeFrameType int // type of the current frame. writeDeadline time.Time + messageWriter *messageWriter // the current low-level message writer + writer io.WriteCloser // the current writer returned to the application isWriting bool // for best-effort concurrent write detection - messageWriter *messageWriter // the current writer + + enableWriteCompression bool + writeCompress bool // whether next call to flushFrame should set RSV1 + newCompressionWriter func(io.WriteCloser) (io.WriteCloser, error) // Read fields readErr error @@ -254,7 +265,10 @@ type Conn struct { handlePong func(string) error handlePing func(string) error readErrCount int - messageReader *messageReader // the current reader + messageReader *messageReader // the current low-level reader + + readDecompress bool // whether last read frame had RSV1 set + newDecompressionReader func(io.Reader) io.Reader } func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { @@ -272,14 +286,15 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) } c := &Conn{ - isServer: isServer, - br: bufio.NewReaderSize(conn, readBufferSize), - conn: conn, - mu: mu, - readFinal: true, - writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize), - writeFrameType: noFrame, - writePos: maxFrameHeaderSize, + isServer: isServer, + br: bufio.NewReaderSize(conn, readBufferSize), + conn: conn, + mu: mu, + readFinal: true, + writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize), + writeFrameType: noFrame, + writePos: maxFrameHeaderSize, + enableWriteCompression: true, } c.SetPingHandler(nil) c.SetPongHandler(nil) @@ -403,8 +418,12 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { return nil, c.writeErr } - if c.writeFrameType != noFrame { - if err := c.flushFrame(true, nil); err != nil { + // Close previous writer if not already closed by the application. It's + // probably better to return an error in this situation, but we cannot + // change this without breaking existing applications. + if c.writer != nil { + err := c.writer.Close() + if err != nil { return nil, err } } @@ -414,11 +433,24 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { } c.writeFrameType = messageType - w := &messageWriter{c} - c.messageWriter = w + c.messageWriter = &messageWriter{c} + + var w io.WriteCloser = c.messageWriter + if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { + c.writeCompress = true + var err error + w, err = c.newCompressionWriter(w) + if err != nil { + c.writer.Close() + return nil, err + } + } + return w, nil } +// flushFrame writes buffered data and extra as a frame to the network. The +// final argument indicates that this is the last frame in the message. func (c *Conn) flushFrame(final bool, extra []byte) error { length := c.writePos - maxFrameHeaderSize + len(extra) @@ -426,6 +458,7 @@ func (c *Conn) flushFrame(final bool, extra []byte) error { if isControl(c.writeFrameType) && (!final || length > maxControlFramePayloadSize) { c.messageWriter = nil + c.writer = nil c.writeFrameType = noFrame c.writePos = maxFrameHeaderSize return errInvalidControlFrame @@ -435,6 +468,11 @@ func (c *Conn) flushFrame(final bool, extra []byte) error { if final { b0 |= finalBit } + if c.writeCompress { + b0 |= rsv1Bit + } + c.writeCompress = false + b1 := byte(0) if !c.isServer { b1 |= maskBit @@ -494,6 +532,7 @@ func (c *Conn) flushFrame(final bool, extra []byte) error { c.writeFrameType = continuationFrame if final { c.messageWriter = nil + c.writer = nil c.writeFrameType = noFrame } return c.writeErr @@ -526,14 +565,14 @@ func (w *messageWriter) ncopy(max int) (int, error) { return n, nil } -func (w *messageWriter) write(final bool, p []byte) (int, error) { +func (w *messageWriter) Write(p []byte) (int, error) { if err := w.err(); err != nil { return 0, err } if len(p) > 2*len(w.c.writeBuf) && w.c.isServer { // Don't buffer large messages. - err := w.c.flushFrame(final, p) + err := w.c.flushFrame(false, p) if err != nil { return 0, err } @@ -553,10 +592,6 @@ func (w *messageWriter) write(final bool, p []byte) (int, error) { return nn, nil } -func (w *messageWriter) Write(p []byte) (int, error) { - return w.write(false, p) -} - func (w *messageWriter) WriteString(p string) (int, error) { if err := w.err(); err != nil { return 0, err @@ -658,12 +693,17 @@ func (c *Conn) advanceFrame() (int, error) { final := p[0]&finalBit != 0 frameType := int(p[0] & 0xf) - reserved := int((p[0] >> 4) & 0x7) mask := p[1]&maskBit != 0 c.readRemaining = int64(p[1] & 0x7f) - if reserved != 0 { - return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved)) + c.readDecompress = false + if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 { + c.readDecompress = true + p[0] &^= rsv1Bit + } + + if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 { + return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16)) } switch frameType { @@ -807,8 +847,11 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { break } if frameType == TextMessage || frameType == BinaryMessage { - r := &messageReader{c} - c.messageReader = r + c.messageReader = &messageReader{c} + var r io.Reader = c.messageReader + if c.readDecompress { + r = c.newDecompressionReader(r) + } return frameType, r, nil } }