Skip to content

Commit

Permalink
Add hooks to support RFC 7692 (per-message compression extension)
Browse files Browse the repository at this point in the history
Add newCompressionWriter and newDecompressionReader fields to Conn. When
not nil, these functions are used to create a compression/decompression
wrapper around an underlying message writer/reader.

Add code to set and check for RSV1 frame header bit.

Add functions compressNoContextTakeover and decompressNoContextTakeover
for creating no context takeover wrappers around an underlying message
writer/reader.

Work remaining:

- Add fields to Dialer and Upgrader for specifying compression options.
- Add compression negotiation to Dialer and Upgrader.
- Add function to enable/disable write compression:

    // EnableWriteCompression enables and disables write compression of
    // subsequent text and binary messages. This function is a noop if
    // compression was not negotiated with the peer.
    func (c *Conn) EnableWriteCompression(enable bool) {
            c.enableWriteCompression = enable
    }
  • Loading branch information
garyburd committed Jun 30, 2016
1 parent b5389d0 commit a87eae1
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 32 deletions.
85 changes: 85 additions & 0 deletions compression.go
@@ -0,0 +1,85 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package websocket

import (
"compress/flate"
"errors"
"io"
"strings"
)

func decompressNoContextTakeover(r io.Reader) io.Reader {
const tail =
// Add four bytes as specified in RFC
"\x00\x00\xff\xff" +
// Add final block to squelch unexpected EOF error from flate reader.
"\x01\x00\x00\xff\xff"

return flate.NewReader(io.MultiReader(r, strings.NewReader(tail)))
}

func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) {
tw := &truncWriter{w: w}
fw, err := flate.NewWriter(tw, 3)
return &flateWrapper{fw: fw, tw: tw}, err
}

// truncWriter is an io.Writer that writes all but the last four bytes of the
// stream to another io.Writer.
type truncWriter struct {
w io.WriteCloser
n int
p [4]byte
}

func (w *truncWriter) Write(p []byte) (int, error) {
n := 0

// fill buffer first for simplicity.
if w.n < len(w.p) {
n = copy(w.p[w.n:], p)
p = p[n:]
w.n += n
if len(p) == 0 {
return n, nil
}
}

m := len(p)
if m > len(w.p) {
m = len(w.p)
}

if nn, err := w.w.Write(w.p[:m]); err != nil {
return n + nn, err
}

copy(w.p[:], w.p[m:])
copy(w.p[len(w.p)-m:], p[len(p)-m:])
nn, err := w.w.Write(p[:len(p)-m])
return n + nn, err
}

type flateWrapper struct {
fw *flate.Writer
tw *truncWriter
}

func (w *flateWrapper) Write(p []byte) (int, error) {
return w.fw.Write(p)
}

func (w *flateWrapper) Close() error {
err1 := w.fw.Flush()
if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
}
err2 := w.tw.w.Close()
if err1 != nil {
return err1
}
return err2
}
31 changes: 31 additions & 0 deletions compression_test.go
@@ -0,0 +1,31 @@
package websocket

import (
"bytes"
"io"
"testing"
)

type nopCloser struct{ io.Writer }

func (nopCloser) Close() error { return nil }

func TestTruncWriter(t *testing.T) {
const data = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijlkmnopqrstuvwxyz987654321"
for n := 1; n <= 10; n++ {
var b bytes.Buffer
w := &truncWriter{w: nopCloser{&b}}
p := []byte(data)
for len(p) > 0 {
m := len(p)
if m > n {
m = n
}
w.Write(p[:m])
p = p[m:]
}
if b.String() != data[:len(data)-len(w.p)] {
t.Errorf("%d: %q", n, b.String())
}
}
}
107 changes: 75 additions & 32 deletions conn.go
Expand Up @@ -18,11 +18,19 @@ import (
)

const (
// Frame header byte 0 bits from Section 5.2 of RFC 6455
finalBit = 1 << 7
rsv1Bit = 1 << 6
rsv2Bit = 1 << 5
rsv3Bit = 1 << 4

// Frame header byte 1 bits from Section 5.2 of RFC 6455
maskBit = 1 << 7

maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask
maxControlFramePayloadSize = 125
finalBit = 1 << 7
maskBit = 1 << 7
writeWait = time.Second

writeWait = time.Second

defaultReadBufferSize = 4096
defaultWriteBufferSize = 4096
Expand Down Expand Up @@ -230,17 +238,20 @@ type Conn struct {
subprotocol string

// Write fields
mu chan bool // used as mutex to protect write to conn and closeSent
closeSent bool // true if close message was sent

// Message writer fields.
mu chan bool // used as mutex to protect write to conn and closeSent
closeSent bool // whether close message was sent
writeErr error
writeBuf []byte // frame is constructed in this buffer.
writePos int // end of data in writeBuf.
writeFrameType int // type of the current frame.
writeDeadline time.Time
messageWriter *messageWriter // the current low-level message writer
writer io.WriteCloser // the current writer returned to the application
isWriting bool // for best-effort concurrent write detection
messageWriter *messageWriter // the current writer

enableWriteCompression bool
writeCompress bool // whether next call to flushFrame should set RSV1
newCompressionWriter func(io.WriteCloser) (io.WriteCloser, error)

// Read fields
readErr error
Expand All @@ -254,7 +265,10 @@ type Conn struct {
handlePong func(string) error
handlePing func(string) error
readErrCount int
messageReader *messageReader // the current reader
messageReader *messageReader // the current low-level reader

readDecompress bool // whether last read frame had RSV1 set
newDecompressionReader func(io.Reader) io.Reader
}

func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
Expand All @@ -272,14 +286,15 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int)
}

c := &Conn{
isServer: isServer,
br: bufio.NewReaderSize(conn, readBufferSize),
conn: conn,
mu: mu,
readFinal: true,
writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize),
writeFrameType: noFrame,
writePos: maxFrameHeaderSize,
isServer: isServer,
br: bufio.NewReaderSize(conn, readBufferSize),
conn: conn,
mu: mu,
readFinal: true,
writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize),
writeFrameType: noFrame,
writePos: maxFrameHeaderSize,
enableWriteCompression: true,
}
c.SetPingHandler(nil)
c.SetPongHandler(nil)
Expand Down Expand Up @@ -403,8 +418,12 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
return nil, c.writeErr
}

if c.writeFrameType != noFrame {
if err := c.flushFrame(true, nil); err != nil {
// Close previous writer if not already closed by the application. It's
// probably better to return an error in this situation, but we cannot
// change this without breaking existing applications.
if c.writer != nil {
err := c.writer.Close()
if err != nil {
return nil, err
}
}
Expand All @@ -414,18 +433,32 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
}

c.writeFrameType = messageType
w := &messageWriter{c}
c.messageWriter = w
c.messageWriter = &messageWriter{c}

var w io.WriteCloser = c.messageWriter
if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
c.writeCompress = true
var err error
w, err = c.newCompressionWriter(w)
if err != nil {
c.writer.Close()
return nil, err
}
}

return w, nil
}

// flushFrame writes buffered data and extra as a frame to the network. The
// final argument indicates that this is the last frame in the message.
func (c *Conn) flushFrame(final bool, extra []byte) error {
length := c.writePos - maxFrameHeaderSize + len(extra)

// Check for invalid control frames.
if isControl(c.writeFrameType) &&
(!final || length > maxControlFramePayloadSize) {
c.messageWriter = nil
c.writer = nil
c.writeFrameType = noFrame
c.writePos = maxFrameHeaderSize
return errInvalidControlFrame
Expand All @@ -435,6 +468,11 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
if final {
b0 |= finalBit
}
if c.writeCompress {
b0 |= rsv1Bit
}
c.writeCompress = false

b1 := byte(0)
if !c.isServer {
b1 |= maskBit
Expand Down Expand Up @@ -494,6 +532,7 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
c.writeFrameType = continuationFrame
if final {
c.messageWriter = nil
c.writer = nil
c.writeFrameType = noFrame
}
return c.writeErr
Expand Down Expand Up @@ -526,14 +565,14 @@ func (w *messageWriter) ncopy(max int) (int, error) {
return n, nil
}

func (w *messageWriter) write(final bool, p []byte) (int, error) {
func (w *messageWriter) Write(p []byte) (int, error) {
if err := w.err(); err != nil {
return 0, err
}

if len(p) > 2*len(w.c.writeBuf) && w.c.isServer {
// Don't buffer large messages.
err := w.c.flushFrame(final, p)
err := w.c.flushFrame(false, p)
if err != nil {
return 0, err
}
Expand All @@ -553,10 +592,6 @@ func (w *messageWriter) write(final bool, p []byte) (int, error) {
return nn, nil
}

func (w *messageWriter) Write(p []byte) (int, error) {
return w.write(false, p)
}

func (w *messageWriter) WriteString(p string) (int, error) {
if err := w.err(); err != nil {
return 0, err
Expand Down Expand Up @@ -658,12 +693,17 @@ func (c *Conn) advanceFrame() (int, error) {

final := p[0]&finalBit != 0
frameType := int(p[0] & 0xf)
reserved := int((p[0] >> 4) & 0x7)
mask := p[1]&maskBit != 0
c.readRemaining = int64(p[1] & 0x7f)

if reserved != 0 {
return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved))
c.readDecompress = false
if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
c.readDecompress = true
p[0] &^= rsv1Bit
}

if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 {
return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16))
}

switch frameType {
Expand Down Expand Up @@ -807,8 +847,11 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
break
}
if frameType == TextMessage || frameType == BinaryMessage {
r := &messageReader{c}
c.messageReader = r
c.messageReader = &messageReader{c}
var r io.Reader = c.messageReader
if c.readDecompress {
r = c.newDecompressionReader(r)
}
return frameType, r, nil
}
}
Expand Down

0 comments on commit a87eae1

Please sign in to comment.