From 189336417ffdc7e2cb89bf309d22c25bf6321301 Mon Sep 17 00:00:00 2001 From: Ivan Kozlovic Date: Mon, 21 Jun 2021 11:09:19 -0600 Subject: [PATCH] [FIXED] Websocket compression/decompression issue with continuation frames For compression, continuation frames had the compress bit set, which is wrong since only the first frame should. For decompression, continuation frames were decompressed individually instead of assembling the full payload and then decompressing. Resolves #2287 Signed-off-by: Ivan Kozlovic --- server/websocket.go | 133 +++++++++++++++++++++++++++++++-------- server/websocket_test.go | 57 ++++++++++++++--- 2 files changed, 156 insertions(+), 34 deletions(-) diff --git a/server/websocket.go b/server/websocket.go index b699631da1..ab4ce82228 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -75,7 +75,6 @@ const ( wsFirstFrame = true wsContFrame = false wsFinalFrame = true - wsCompressedFrame = true wsUncompressedFrame = false wsSchemePrefix = "ws" @@ -92,6 +91,7 @@ const ( ) var decompressorPool sync.Pool +var compressLastBlock = []byte{0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff} // From https://tools.ietf.org/html/rfc6455#section-1.3 var wsGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") @@ -144,7 +144,8 @@ type wsReadInfo struct { mask bool // Incoming leafnode connections may not have masking. mkpos byte mkey [4]byte - buf []byte + cbufs [][]byte + coff int } func (r *wsReadInfo) init() { @@ -292,42 +293,118 @@ func (c *client) wsRead(r *wsReadInfo, ior io.Reader, buf []byte) ([][]byte, err b = buf[pos : pos+n] pos += n r.rem -= n - if r.fc { - r.buf = append(r.buf, b...) - b = r.buf + // If needed, unmask the buffer + if r.mask { + r.unmask(b) } - if !r.fc || r.rem == 0 { - if r.mask { - r.unmask(b) - } - if r.fc { - // As per https://tools.ietf.org/html/rfc7692#section-7.2.2 - // add 0x00, 0x00, 0xff, 0xff and then a final block so that flate reader - // does not report unexpected EOF. - b = append(b, 0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff) - br := bytes.NewBuffer(b) - d, _ := decompressorPool.Get().(io.ReadCloser) - if d == nil { - d = flate.NewReader(br) - } else { - d.(flate.Resetter).Reset(br, nil) - } - b, err = ioutil.ReadAll(d) - decompressorPool.Put(d) + addToBufs := true + // Handle compressed message + if r.fc { + // Assume that we may have continuation frames or not the full payload. + addToBufs = false + // Make a copy of the buffer before adding it to the list + // of compressed fragments. + r.cbufs = append(r.cbufs, append([]byte(nil), b...)) + // When we have the final frame and we have read the full payload, + // we can decompress it. + if r.ff && r.rem == 0 { + b, err = r.decompress() if err != nil { return bufs, err } + r.fc = false + // Now we can add to `bufs` + addToBufs = true } + } + // For non compressed frames, or when we have decompressed the + // whole message. + if addToBufs { bufs = append(bufs, b) - if r.rem == 0 { - r.fs, r.fc, r.buf = true, false, nil - } + } + // If payload has been fully read, then indicate that next + // is the start of a frame. + if r.rem == 0 { + r.fs = true } } } return bufs, nil } +func (r *wsReadInfo) Read(dst []byte) (int, error) { + if len(dst) == 0 { + return 0, nil + } + if len(r.cbufs) == 0 { + return 0, io.EOF + } + copied := 0 + rem := len(dst) + for buf := r.cbufs[0]; buf != nil && rem > 0; { + n := len(buf[r.coff:]) + if n > rem { + n = rem + } + copy(dst[copied:], buf[r.coff:r.coff+n]) + copied += n + rem -= n + r.coff += n + buf = r.nextCBuf() + } + return copied, nil +} + +func (r *wsReadInfo) nextCBuf() []byte { + // We still have remaining data in the first buffer + if r.coff != len(r.cbufs[0]) { + return r.cbufs[0] + } + // We read the full first buffer. Reset offset. + r.coff = 0 + // We were at the last buffer, so we are done. + if len(r.cbufs) == 1 { + r.cbufs = nil + return nil + } + // Here we move to the next buffer. + r.cbufs = r.cbufs[1:] + return r.cbufs[0] +} + +func (r *wsReadInfo) ReadByte() (byte, error) { + if len(r.cbufs) == 0 { + return 0, io.EOF + } + b := r.cbufs[0][r.coff] + r.coff++ + r.nextCBuf() + return b, nil +} + +func (r *wsReadInfo) decompress() ([]byte, error) { + r.coff = 0 + // As per https://tools.ietf.org/html/rfc7692#section-7.2.2 + // add 0x00, 0x00, 0xff, 0xff and then a final block so that flate reader + // does not report unexpected EOF. + r.cbufs = append(r.cbufs, compressLastBlock) + // Get a decompressor from the pool and bind it to this object (wsReadInfo) + // that provides Read() and ReadByte() APIs that will consume the compressed + // buffers (r.cbufs). + d, _ := decompressorPool.Get().(io.ReadCloser) + if d == nil { + d = flate.NewReader(r) + } else { + d.(flate.Resetter).Reset(r, nil) + } + // This will do the decompression. + b, err := ioutil.ReadAll(d) + decompressorPool.Put(d) + // Now reset the compressed buffers list. + r.cbufs = nil + return b, err +} + // Handles the PING, PONG and CLOSE websocket control frames. // // Client lock MUST NOT be held on entry. @@ -1211,7 +1288,9 @@ func (c *client) wsCollapsePtoNB() (net.Buffers, int64) { final = true } fh := make([]byte, wsMaxFrameHeaderSize) - n, key := wsFillFrameHeader(fh, mask, first, final, wsCompressedFrame, wsBinaryMessage, lp) + // Only the first frame should be marked as compressed, so pass + // `first` for the compressed boolean. + n, key := wsFillFrameHeader(fh, mask, first, final, first, wsBinaryMessage, lp) if mask { wsMaskBuf(key, p[:lp]) } diff --git a/server/websocket_test.go b/server/websocket_test.go index d3a5de5710..8c52b560be 100644 --- a/server/websocket_test.go +++ b/server/websocket_test.go @@ -434,7 +434,10 @@ func TestWSReadCompressedFrames(t *testing.T) { // Stress the fact that we use a pool and want to make sure // that if we get a decompressor from the pool, it is properly reset // with the buffer to decompress. - for i := 0; i < 9; i++ { + // Since we unmask the read buffer, reset it now and fill it + // with 10 compressed frames. + rb = nil + for i := 0; i < 10; i++ { rb = append(rb, wsmsg1...) } bufs, err = c.wsRead(ri, tr, rb) @@ -444,6 +447,31 @@ func TestWSReadCompressedFrames(t *testing.T) { if n := len(bufs); n != 10 { t.Fatalf("Unexpected buffer returned: %v", n) } + + // Compress a message and send it in several frames. + buf := &bytes.Buffer{} + compressor, _ := flate.NewWriter(buf, 1) + compressor.Write(uncompressed) + compressor.Flush() + compressed := buf.Bytes() + // The last 4 bytes are dropped + compressed = compressed[:len(compressed)-4] + ncomp := 10 + frag1 := testWSCreateClientMsg(wsBinaryMessage, 1, false, false, compressed[:ncomp]) + frag1[0] |= wsRsv1Bit + frag2 := testWSCreateClientMsg(wsBinaryMessage, 2, true, false, compressed[ncomp:]) + rb = append([]byte(nil), frag1...) + rb = append(rb, frag2...) + bufs, err = c.wsRead(ri, tr, rb) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if n := len(bufs); n != 1 { + t.Fatalf("Unexpected buffer returned: %v", n) + } + if !bytes.Equal(bufs[0], uncompressed) { + t.Fatalf("Unexpected content: %s", bufs[0]) + } } func TestWSReadCompressedFrameCorrupted(t *testing.T) { @@ -827,15 +855,20 @@ func TestWSReadErrors(t *testing.T) { }, { func() []byte { - frag1 := testWSCreateClientMsg(wsBinaryMessage, 1, false, true, []byte("frag1")) - frag2 := testWSCreateClientMsg(wsBinaryMessage, 2, false, true, []byte("frag2")) - frag2[0] |= wsRsv1Bit - all := append([]byte(nil), frag1...) - all = append(all, frag2...) + frame := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("frame")) + frag := testWSCreateClientMsg(wsBinaryMessage, 2, false, false, []byte("continuation")) + all := append([]byte(nil), frame...) + all = append(all, frag...) return all }, "invalid continuation frame", 2, }, + { + func() []byte { + return testWSCreateClientMsg(wsBinaryMessage, 2, false, true, []byte("frame")) + }, + "invalid continuation frame", 1, + }, { func() []byte { return testWSCreateClientMsg(99, 1, false, false, []byte("hello")) @@ -2914,7 +2947,17 @@ func TestWSCompressionFrameSizeLimit(t *testing.T) { } } // Check frame headers for the proper formatting. - if i%2 == 1 { + if i%2 == 0 { + // Only the first frame should have the compress bit set. + if b[0]&wsRsv1Bit != 0 { + if i > 0 { + t.Fatalf("Compressed bit should not be in continuation frame") + } + } else if i == 0 { + t.Fatalf("Compressed bit missing") + } + } else { + // Collect the payload bb.Write(b) } }