Skip to content

Commit

Permalink
[FIXED] Websocket compression/decompression issue with continuation f…
Browse files Browse the repository at this point in the history
…rames

For compression, continuation frames had the compress bit set, which is
wrong since only the first frame should.
For decompression, continuation frames were decompressed individually
instead of assembling the full payload and then decompressing.

Resolves #2287

Signed-off-by: Ivan Kozlovic <ivan@synadia.com>
  • Loading branch information
kozlovic committed Jun 21, 2021
1 parent 6129562 commit 1893364
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 34 deletions.
133 changes: 106 additions & 27 deletions server/websocket.go
Expand Up @@ -75,7 +75,6 @@ const (
wsFirstFrame = true
wsContFrame = false
wsFinalFrame = true
wsCompressedFrame = true
wsUncompressedFrame = false

wsSchemePrefix = "ws"
Expand All @@ -92,6 +91,7 @@ const (
)

var decompressorPool sync.Pool
var compressLastBlock = []byte{0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff}

// From https://tools.ietf.org/html/rfc6455#section-1.3
var wsGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
Expand Down Expand Up @@ -144,7 +144,8 @@ type wsReadInfo struct {
mask bool // Incoming leafnode connections may not have masking.
mkpos byte
mkey [4]byte
buf []byte
cbufs [][]byte
coff int
}

func (r *wsReadInfo) init() {
Expand Down Expand Up @@ -292,42 +293,118 @@ func (c *client) wsRead(r *wsReadInfo, ior io.Reader, buf []byte) ([][]byte, err
b = buf[pos : pos+n]
pos += n
r.rem -= n
if r.fc {
r.buf = append(r.buf, b...)
b = r.buf
// If needed, unmask the buffer
if r.mask {
r.unmask(b)
}
if !r.fc || r.rem == 0 {
if r.mask {
r.unmask(b)
}
if r.fc {
// As per https://tools.ietf.org/html/rfc7692#section-7.2.2
// add 0x00, 0x00, 0xff, 0xff and then a final block so that flate reader
// does not report unexpected EOF.
b = append(b, 0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff)
br := bytes.NewBuffer(b)
d, _ := decompressorPool.Get().(io.ReadCloser)
if d == nil {
d = flate.NewReader(br)
} else {
d.(flate.Resetter).Reset(br, nil)
}
b, err = ioutil.ReadAll(d)
decompressorPool.Put(d)
addToBufs := true
// Handle compressed message
if r.fc {
// Assume that we may have continuation frames or not the full payload.
addToBufs = false
// Make a copy of the buffer before adding it to the list
// of compressed fragments.
r.cbufs = append(r.cbufs, append([]byte(nil), b...))
// When we have the final frame and we have read the full payload,
// we can decompress it.
if r.ff && r.rem == 0 {
b, err = r.decompress()
if err != nil {
return bufs, err
}
r.fc = false
// Now we can add to `bufs`
addToBufs = true
}
}
// For non compressed frames, or when we have decompressed the
// whole message.
if addToBufs {
bufs = append(bufs, b)
if r.rem == 0 {
r.fs, r.fc, r.buf = true, false, nil
}
}
// If payload has been fully read, then indicate that next
// is the start of a frame.
if r.rem == 0 {
r.fs = true
}
}
}
return bufs, nil
}

func (r *wsReadInfo) Read(dst []byte) (int, error) {
if len(dst) == 0 {
return 0, nil
}
if len(r.cbufs) == 0 {
return 0, io.EOF
}
copied := 0
rem := len(dst)
for buf := r.cbufs[0]; buf != nil && rem > 0; {
n := len(buf[r.coff:])
if n > rem {
n = rem
}
copy(dst[copied:], buf[r.coff:r.coff+n])
copied += n
rem -= n
r.coff += n
buf = r.nextCBuf()
}
return copied, nil
}

func (r *wsReadInfo) nextCBuf() []byte {
// We still have remaining data in the first buffer
if r.coff != len(r.cbufs[0]) {
return r.cbufs[0]
}
// We read the full first buffer. Reset offset.
r.coff = 0
// We were at the last buffer, so we are done.
if len(r.cbufs) == 1 {
r.cbufs = nil
return nil
}
// Here we move to the next buffer.
r.cbufs = r.cbufs[1:]
return r.cbufs[0]
}

func (r *wsReadInfo) ReadByte() (byte, error) {
if len(r.cbufs) == 0 {
return 0, io.EOF
}
b := r.cbufs[0][r.coff]
r.coff++
r.nextCBuf()
return b, nil
}

func (r *wsReadInfo) decompress() ([]byte, error) {
r.coff = 0
// As per https://tools.ietf.org/html/rfc7692#section-7.2.2
// add 0x00, 0x00, 0xff, 0xff and then a final block so that flate reader
// does not report unexpected EOF.
r.cbufs = append(r.cbufs, compressLastBlock)
// Get a decompressor from the pool and bind it to this object (wsReadInfo)
// that provides Read() and ReadByte() APIs that will consume the compressed
// buffers (r.cbufs).
d, _ := decompressorPool.Get().(io.ReadCloser)
if d == nil {
d = flate.NewReader(r)
} else {
d.(flate.Resetter).Reset(r, nil)
}
// This will do the decompression.
b, err := ioutil.ReadAll(d)
decompressorPool.Put(d)
// Now reset the compressed buffers list.
r.cbufs = nil
return b, err
}

// Handles the PING, PONG and CLOSE websocket control frames.
//
// Client lock MUST NOT be held on entry.
Expand Down Expand Up @@ -1211,7 +1288,9 @@ func (c *client) wsCollapsePtoNB() (net.Buffers, int64) {
final = true
}
fh := make([]byte, wsMaxFrameHeaderSize)
n, key := wsFillFrameHeader(fh, mask, first, final, wsCompressedFrame, wsBinaryMessage, lp)
// Only the first frame should be marked as compressed, so pass
// `first` for the compressed boolean.
n, key := wsFillFrameHeader(fh, mask, first, final, first, wsBinaryMessage, lp)
if mask {
wsMaskBuf(key, p[:lp])
}
Expand Down
57 changes: 50 additions & 7 deletions server/websocket_test.go
Expand Up @@ -434,7 +434,10 @@ func TestWSReadCompressedFrames(t *testing.T) {
// Stress the fact that we use a pool and want to make sure
// that if we get a decompressor from the pool, it is properly reset
// with the buffer to decompress.
for i := 0; i < 9; i++ {
// Since we unmask the read buffer, reset it now and fill it
// with 10 compressed frames.
rb = nil
for i := 0; i < 10; i++ {
rb = append(rb, wsmsg1...)
}
bufs, err = c.wsRead(ri, tr, rb)
Expand All @@ -444,6 +447,31 @@ func TestWSReadCompressedFrames(t *testing.T) {
if n := len(bufs); n != 10 {
t.Fatalf("Unexpected buffer returned: %v", n)
}

// Compress a message and send it in several frames.
buf := &bytes.Buffer{}
compressor, _ := flate.NewWriter(buf, 1)
compressor.Write(uncompressed)
compressor.Flush()
compressed := buf.Bytes()
// The last 4 bytes are dropped
compressed = compressed[:len(compressed)-4]
ncomp := 10
frag1 := testWSCreateClientMsg(wsBinaryMessage, 1, false, false, compressed[:ncomp])
frag1[0] |= wsRsv1Bit
frag2 := testWSCreateClientMsg(wsBinaryMessage, 2, true, false, compressed[ncomp:])
rb = append([]byte(nil), frag1...)
rb = append(rb, frag2...)
bufs, err = c.wsRead(ri, tr, rb)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if n := len(bufs); n != 1 {
t.Fatalf("Unexpected buffer returned: %v", n)
}
if !bytes.Equal(bufs[0], uncompressed) {
t.Fatalf("Unexpected content: %s", bufs[0])
}
}

func TestWSReadCompressedFrameCorrupted(t *testing.T) {
Expand Down Expand Up @@ -827,15 +855,20 @@ func TestWSReadErrors(t *testing.T) {
},
{
func() []byte {
frag1 := testWSCreateClientMsg(wsBinaryMessage, 1, false, true, []byte("frag1"))
frag2 := testWSCreateClientMsg(wsBinaryMessage, 2, false, true, []byte("frag2"))
frag2[0] |= wsRsv1Bit
all := append([]byte(nil), frag1...)
all = append(all, frag2...)
frame := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("frame"))
frag := testWSCreateClientMsg(wsBinaryMessage, 2, false, false, []byte("continuation"))
all := append([]byte(nil), frame...)
all = append(all, frag...)
return all
},
"invalid continuation frame", 2,
},
{
func() []byte {
return testWSCreateClientMsg(wsBinaryMessage, 2, false, true, []byte("frame"))
},
"invalid continuation frame", 1,
},
{
func() []byte {
return testWSCreateClientMsg(99, 1, false, false, []byte("hello"))
Expand Down Expand Up @@ -2914,7 +2947,17 @@ func TestWSCompressionFrameSizeLimit(t *testing.T) {
}
}
// Check frame headers for the proper formatting.
if i%2 == 1 {
if i%2 == 0 {
// Only the first frame should have the compress bit set.
if b[0]&wsRsv1Bit != 0 {
if i > 0 {
t.Fatalf("Compressed bit should not be in continuation frame")
}
} else if i == 0 {
t.Fatalf("Compressed bit missing")
}
} else {
// Collect the payload
bb.Write(b)
}
}
Expand Down

0 comments on commit 1893364

Please sign in to comment.