Skip to content

Commit

Permalink
permessage-deflate decompression support in websocket (#3263)
Browse files Browse the repository at this point in the history
* handshake

* fixup

* +85 autobahn test passes

* fixup

* fixup

* fixup

* fixup

* fixup

* fixup

* fixup

* fixup

* fixup

* fixup

* fixup

* fixup

* add basic send queue

* fixup

* fixup

* fixup

* fix EVERY FAILURE!!!!
  • Loading branch information
KhafraDev committed May 20, 2024
1 parent 0ca9c1e commit 5a564bd
Show file tree
Hide file tree
Showing 10 changed files with 355 additions and 74 deletions.
1 change: 1 addition & 0 deletions lib/web/fetch/data-url.js
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,7 @@ module.exports = {
collectAnHTTPQuotedString,
serializeAMimeType,
removeChars,
removeHTTPWhitespace,
minimizeSupportedMimeType,
HTTP_TOKEN_CODEPOINTS,
isomorphicDecode
Expand Down
27 changes: 18 additions & 9 deletions lib/web/websocket/connection.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ const {
kReceivedClose,
kResponse
} = require('./symbols')
const { fireEvent, failWebsocketConnection, isClosing, isClosed, isEstablished } = require('./util')
const { fireEvent, failWebsocketConnection, isClosing, isClosed, isEstablished, parseExtensions } = require('./util')
const { channels } = require('../../core/diagnostics')
const { CloseEvent } = require('./events')
const { makeRequest } = require('../fetch/request')
Expand All @@ -31,7 +31,7 @@ try {
* @param {URL} url
* @param {string|string[]} protocols
* @param {import('./websocket').WebSocket} ws
* @param {(response: any) => void} onEstablish
* @param {(response: any, extensions: string[] | undefined) => void} onEstablish
* @param {Partial<import('../../types/websocket').WebSocketInit>} options
*/
function establishWebSocketConnection (url, protocols, client, ws, onEstablish, options) {
Expand Down Expand Up @@ -91,12 +91,11 @@ function establishWebSocketConnection (url, protocols, client, ws, onEstablish,
// 9. Let permessageDeflate be a user-agent defined
// "permessage-deflate" extension header value.
// https://github.com/mozilla/gecko-dev/blob/ce78234f5e653a5d3916813ff990f053510227bc/netwerk/protocol/websocket/WebSocketChannel.cpp#L2673
// TODO: enable once permessage-deflate is supported
const permessageDeflate = '' // 'permessage-deflate; 15'
const permessageDeflate = 'permessage-deflate; client_max_window_bits'

// 10. Append (`Sec-WebSocket-Extensions`, permessageDeflate) to
// request’s header list.
// request.headersList.append('sec-websocket-extensions', permessageDeflate)
request.headersList.append('sec-websocket-extensions', permessageDeflate)

// 11. Fetch request with useParallelQueue set to true, and
// processResponse given response being these steps:
Expand Down Expand Up @@ -167,10 +166,15 @@ function establishWebSocketConnection (url, protocols, client, ws, onEstablish,
// header field to determine which extensions are requested is
// discussed in Section 9.1.)
const secExtension = response.headersList.get('Sec-WebSocket-Extensions')
let extensions

if (secExtension !== null && secExtension !== permessageDeflate) {
failWebsocketConnection(ws, 'Received different permessage-deflate than the one set.')
return
if (secExtension !== null) {
extensions = parseExtensions(secExtension)

if (!extensions.has('permessage-deflate')) {
failWebsocketConnection(ws, 'Sec-WebSocket-Extensions header does not match.')
return
}
}

// 6. If the response includes a |Sec-WebSocket-Protocol| header field
Expand Down Expand Up @@ -206,7 +210,7 @@ function establishWebSocketConnection (url, protocols, client, ws, onEstablish,
})
}

onEstablish(response)
onEstablish(response, extensions)
}
})

Expand Down Expand Up @@ -290,6 +294,11 @@ function onSocketData (chunk) {
*/
function onSocketClose () {
const { ws } = this
const { [kResponse]: response } = ws

response.socket.off('data', onSocketData)
response.socket.off('close', onSocketClose)
response.socket.off('error', onSocketError)

// If the TCP connection was closed after the
// WebSocket closing handshake was completed, the WebSocket connection
Expand Down
10 changes: 9 additions & 1 deletion lib/web/websocket/constants.js
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ const parserStates = {

const emptyBuffer = Buffer.allocUnsafe(0)

const sendHints = {
string: 1,
typedArray: 2,
arrayBuffer: 3,
blob: 4
}

module.exports = {
uid,
sentCloseFrameState,
Expand All @@ -54,5 +61,6 @@ module.exports = {
opcodes,
maxUnsigned16Bit,
parserStates,
emptyBuffer
emptyBuffer,
sendHints
}
70 changes: 70 additions & 0 deletions lib/web/websocket/permessage-deflate.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
'use strict'

const { createInflateRaw, Z_DEFAULT_WINDOWBITS } = require('node:zlib')
const { isValidClientWindowBits } = require('./util')

const tail = Buffer.from([0x00, 0x00, 0xff, 0xff])
const kBuffer = Symbol('kBuffer')
const kLength = Symbol('kLength')

class PerMessageDeflate {
/** @type {import('node:zlib').InflateRaw} */
#inflate

#options = {}

constructor (extensions) {
this.#options.serverNoContextTakeover = extensions.has('server_no_context_takeover')
this.#options.serverMaxWindowBits = extensions.get('server_max_window_bits')
}

decompress (chunk, fin, callback) {
// An endpoint uses the following algorithm to decompress a message.
// 1. Append 4 octets of 0x00 0x00 0xff 0xff to the tail end of the
// payload of the message.
// 2. Decompress the resulting data using DEFLATE.

if (!this.#inflate) {
let windowBits = Z_DEFAULT_WINDOWBITS

if (this.#options.serverMaxWindowBits) { // empty values default to Z_DEFAULT_WINDOWBITS
if (!isValidClientWindowBits(this.#options.serverMaxWindowBits)) {
callback(new Error('Invalid server_max_window_bits'))
return
}

windowBits = Number.parseInt(this.#options.serverMaxWindowBits)
}

this.#inflate = createInflateRaw({ windowBits })
this.#inflate[kBuffer] = []
this.#inflate[kLength] = 0

this.#inflate.on('data', (data) => {
this.#inflate[kBuffer].push(data)
this.#inflate[kLength] += data.length
})

this.#inflate.on('error', (err) => {
this.#inflate = null
callback(err)
})
}

this.#inflate.write(chunk)
if (fin) {
this.#inflate.write(tail)
}

this.#inflate.flush(() => {
const full = Buffer.concat(this.#inflate[kBuffer], this.#inflate[kLength])

this.#inflate[kBuffer].length = 0
this.#inflate[kLength] = 0

callback(null, full)
})
}
}

module.exports = { PerMessageDeflate }
79 changes: 63 additions & 16 deletions lib/web/websocket/receiver.js
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ const {
} = require('./util')
const { WebsocketFrameSend } = require('./frame')
const { closeWebSocketConnection } = require('./connection')
const { PerMessageDeflate } = require('./permessage-deflate')

// This code was influenced by ws released under the MIT license.
// Copyright (c) 2011 Einar Otto Stangvik <einaros@gmail.com>
Expand All @@ -33,10 +34,18 @@ class ByteParser extends Writable {
#info = {}
#fragments = []

constructor (ws) {
/** @type {Map<string, PerMessageDeflate>} */
#extensions

constructor (ws, extensions) {
super()

this.ws = ws
this.#extensions = extensions == null ? new Map() : extensions

if (this.#extensions.has('permessage-deflate')) {
this.#extensions.set('permessage-deflate', new PerMessageDeflate(extensions))
}
}

/**
Expand Down Expand Up @@ -91,7 +100,16 @@ class ByteParser extends Writable {
// the negotiated extensions defines the meaning of such a nonzero
// value, the receiving endpoint MUST _Fail the WebSocket
// Connection_.
if (rsv1 !== 0 || rsv2 !== 0 || rsv3 !== 0) {
// This document allocates the RSV1 bit of the WebSocket header for
// PMCEs and calls the bit the "Per-Message Compressed" bit. On a
// WebSocket connection where a PMCE is in use, this bit indicates
// whether a message is compressed or not.
if (rsv1 !== 0 && !this.#extensions.has('permessage-deflate')) {
failWebsocketConnection(this.ws, 'Expected RSV1 to be clear.')
return
}

if (rsv2 !== 0 || rsv3 !== 0) {
failWebsocketConnection(this.ws, 'RSV1, RSV2, RSV3 must be clear')
return
}
Expand Down Expand Up @@ -122,7 +140,7 @@ class ByteParser extends Writable {
return
}

if (isContinuationFrame(opcode) && this.#fragments.length === 0) {
if (isContinuationFrame(opcode) && this.#fragments.length === 0 && !this.#info.compressed) {
failWebsocketConnection(this.ws, 'Unexpected continuation frame')
return
}
Expand All @@ -138,6 +156,7 @@ class ByteParser extends Writable {

if (isTextBinaryFrame(opcode)) {
this.#info.binaryType = opcode
this.#info.compressed = rsv1 !== 0
}

this.#info.opcode = opcode
Expand Down Expand Up @@ -185,21 +204,50 @@ class ByteParser extends Writable {

if (isControlFrame(this.#info.opcode)) {
this.#loop = this.parseControlFrame(body)
this.#state = parserStates.INFO
} else {
this.#fragments.push(body)

// If the frame is not fragmented, a message has been received.
// If the frame is fragmented, it will terminate with a fin bit set
// and an opcode of 0 (continuation), therefore we handle that when
// parsing continuation frames, not here.
if (!this.#info.fragmented && this.#info.fin) {
const fullMessage = Buffer.concat(this.#fragments)
websocketMessageReceived(this.ws, this.#info.binaryType, fullMessage)
this.#fragments.length = 0
if (!this.#info.compressed) {
this.#fragments.push(body)

// If the frame is not fragmented, a message has been received.
// If the frame is fragmented, it will terminate with a fin bit set
// and an opcode of 0 (continuation), therefore we handle that when
// parsing continuation frames, not here.
if (!this.#info.fragmented && this.#info.fin) {
const fullMessage = Buffer.concat(this.#fragments)
websocketMessageReceived(this.ws, this.#info.binaryType, fullMessage)
this.#fragments.length = 0
}

this.#state = parserStates.INFO
} else {
this.#extensions.get('permessage-deflate').decompress(body, this.#info.fin, (error, data) => {
if (error) {
closeWebSocketConnection(this.ws, 1007, error.message, error.message.length)
return
}

this.#fragments.push(data)

if (!this.#info.fin) {
this.#state = parserStates.INFO
this.#loop = true
this.run(callback)
return
}

websocketMessageReceived(this.ws, this.#info.binaryType, Buffer.concat(this.#fragments))

this.#loop = true
this.#state = parserStates.INFO
this.run(callback)
this.#fragments.length = 0
})

this.#loop = false
break
}
}

this.#state = parserStates.INFO
}
}
}
Expand Down Expand Up @@ -333,7 +381,6 @@ class ByteParser extends Writable {
this.ws[kReadyState] = states.CLOSING
this.ws[kReceivedClose] = true

this.end()
return false
} else if (opcode === opcodes.PING) {
// Upon receipt of a Ping frame, an endpoint MUST send a Pong frame in
Expand Down
85 changes: 85 additions & 0 deletions lib/web/websocket/sender.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
'use strict'

const { WebsocketFrameSend } = require('./frame')
const { opcodes, sendHints } = require('./constants')

/** @type {Uint8Array} */
const FastBuffer = Buffer[Symbol.species]

class SendQueue {
#queued = new Set()
#size = 0

/** @type {import('net').Socket} */
#socket

constructor (socket) {
this.#socket = socket
}

add (item, cb, hint) {
if (hint !== sendHints.blob) {
const data = clone(item, hint)

if (this.#size === 0) {
this.#dispatch(data, cb, hint)
} else {
this.#queued.add([data, cb, true, hint])
this.#size++

this.#run()
}

return
}

const promise = item.arrayBuffer()
const queue = [null, cb, false, hint]
promise.then((ab) => {
queue[0] = clone(ab, hint)
queue[2] = true

this.#run()
})

this.#queued.add(queue)
this.#size++
}

#run () {
for (const queued of this.#queued) {
const [data, cb, done, hint] = queued

if (!done) return

this.#queued.delete(queued)
this.#size--

this.#dispatch(data, cb, hint)
}
}

#dispatch (data, cb, hint) {
const frame = new WebsocketFrameSend()
const opcode = hint === sendHints.string ? opcodes.TEXT : opcodes.BINARY

frame.frameData = data
const buffer = frame.createFrame(opcode)

this.#socket.write(buffer, cb)
}
}

function clone (data, hint) {
switch (hint) {
case sendHints.string:
return Buffer.from(data)
case sendHints.arrayBuffer:
case sendHints.blob:
return new FastBuffer(data)
case sendHints.typedArray:
return Buffer.copyBytesFrom(data)
}
}

module.exports = { SendQueue }

0 comments on commit 5a564bd

Please sign in to comment.