Permalink
Browse files

Implement support for the Websocket deflate-frame (draft 05) extension.

It's currently implemented in Chrome as x-webkit-deflate-frame.

This (or any other extensions that may be added) can be disabled by
overriding allow_extension in the Websocket handler.
  • Loading branch information...
flodiebold committed Aug 5, 2012
1 parent 303e963 commit 5052e5cfeccaaee2a8ed682fda6c3de793ebcd0e
Showing with 95 additions and 3 deletions.
  1. +95 −3 tornado/websocket.py
View
@@ -29,7 +29,9 @@
import base64
import tornado.escape
import tornado.web
+import zlib
+from tornado.httputil import _parse_header
from tornado.util import bytes_type, b
@@ -183,6 +185,18 @@ def close(self):
"""
self.ws_connection.close()
+ def allow_extension(self, name):
+ """Decides whether a given extension is used.
+
+ Override to disallow specific extensions.
+ """
+ return True
+
+ def get_extensions(self):
+ """Returns a list of the used extensions."""
+ return self.ws_connection.get_extensions()
+
+
def allow_draft76(self):
"""Override to enable support for the older "draft76" protocol.
@@ -265,6 +279,9 @@ def wrapper(*args, **kwargs):
def on_connection_close(self):
self._abort()
+ def get_extensions(self):
+ return []
+
def _abort(self):
"""Instantly aborts the WebSocket connection by closing the socket"""
self.client_terminated = True
@@ -435,6 +452,40 @@ def close(self):
self._waiting = self.stream.io_loop.add_timeout(
time.time() + 5, self._abort)
+class Deflater(object):
+ def __init__(self, window_bits, no_context_takeover):
+ self._window_bits = window_bits
+ self._no_context_takeover = no_context_takeover
+ self._compressobj = None
+
+ def compress(self, data):
+ # data is expected to be a byte string
+ if self._compressobj == None or self._no_context_takeover:
+ self._compressobj = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
+ zlib.DEFLATED, -self._window_bits)
+
+ compressed = self._compressobj.compress(data)
+ compressed += self._compressobj.flush(zlib.Z_SYNC_FLUSH)
+ return compressed[:-4] # Remove 4 bytes as in deflate-frame spec
+
+class Inflater(object):
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self._decompressobj = zlib.decompressobj(-zlib.MAX_WBITS)
+
+ def decompress(self, data):
+ # data is expected to be a byte array
+ # Restore 4 bytes as in deflate-frame spec
+ data.fromstring(b("\x00\x00\xff\xff"))
+ decompressed = array.array("B", self._decompressobj.decompress(data))
+ while len(self._decompressobj.unused_data):
+ unused = self._decompressobj.unused_data
+ self.reset()
+ decompressed.fromstring(self._decompressobj.decompress(unused))
+ return decompressed
+
class WebSocketProtocol13(WebSocketProtocol):
"""Implementation of the WebSocket protocol from RFC 6455.
@@ -445,12 +496,17 @@ class WebSocketProtocol13(WebSocketProtocol):
def __init__(self, handler):
WebSocketProtocol.__init__(self, handler)
self._final_frame = False
+ self._frame_reserved_bits = None
self._frame_opcode = None
self._frame_mask = None
self._frame_length = None
self._fragmented_message_buffer = None
self._fragmented_message_opcode = None
self._waiting = None
+ self._forbidden_reserved_bits = 0x70
+ self._deflater = None
+ self._inflater = None
+ self._extensions = []
def accept_connection(self):
try:
@@ -461,6 +517,9 @@ def accept_connection(self):
self._abort()
return
+ def get_extensions(self):
+ return self._extensions
+
def _handle_websocket_headers(self):
"""Verifies all invariant- and required headers
@@ -478,6 +537,30 @@ def _challenge_response(self):
sha1.update(b("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")) # Magic value
return tornado.escape.native_str(base64.b64encode(sha1.digest()))
+ def _select_extensions(self):
+ ext_header = self.request.headers.get("Sec-Websocket-Extensions", "")
+ exts = [ext.strip() for ext in ext_header.split(",")]
+ accepted_exts = []
+ for ext in exts:
+ ext_id, params = _parse_header(ext)
+ if ext_id in ["deflate-frame", "x-webkit-deflate-frame"]:
+ if not self.handler.allow_extension("deflate-frame"): continue
+ if "max_window_bits" in params:
+ window_bits = int(params.get("max_window_bits"))
+ else:
+ window_bits = zlib.MAX_WBITS
+ no_context_takeover = ("no_context_takeover" in params)
+ self._deflater = Deflater(window_bits, no_context_takeover)
+ self._inflater = Inflater()
+ self._forbidden_reserved_bits &= 0x30 # Allow RSV1
+ accepted_exts.append(ext_id)
+ self._extensions.append("deflate-frame")
+
+ if accepted_exts:
+ return "Sec-Websocket-Extensions: %s\r\n" % ",".join(accepted_exts)
+ else:
+ return ""
+
def _accept_connection(self):
subprotocol_header = ''
subprotocols = self.request.headers.get("Sec-WebSocket-Protocol", '')
@@ -488,13 +571,16 @@ def _accept_connection(self):
assert selected in subprotocols
subprotocol_header = "Sec-WebSocket-Protocol: %s\r\n" % selected
+ extension_header = self._select_extensions()
+
self.stream.write(tornado.escape.utf8(
"HTTP/1.1 101 Switching Protocols\r\n"
"Upgrade: websocket\r\n"
"Connection: Upgrade\r\n"
"Sec-WebSocket-Accept: %s\r\n"
"%s"
- "\r\n" % (self._challenge_response(), subprotocol_header)))
+ "%s"
+ "\r\n" % (self._challenge_response(), extension_header, subprotocol_header)))
self.async_callback(self.handler.open)(*self.handler.open_args, **self.handler.open_kwargs)
self._receive_frame()
@@ -504,6 +590,9 @@ def _write_frame(self, fin, opcode, data):
finbit = 0x80
else:
finbit = 0
+ if self._deflater:
+ data = self._deflater.compress(data)
+ finbit |= 0x40 # RSV1 aka COMP bit
frame = struct.pack("B", finbit | opcode)
l = len(data)
if l < 126:
@@ -531,10 +620,10 @@ def _receive_frame(self):
def _on_frame_start(self, data):
header, payloadlen = struct.unpack("BB", data)
self._final_frame = header & 0x80
- reserved_bits = header & 0x70
+ self._frame_reserved_bits = header & 0x70
self._frame_opcode = header & 0xf
self._frame_opcode_is_control = self._frame_opcode & 0x8
- if reserved_bits:
+ if self._frame_reserved_bits & self._forbidden_reserved_bits:
# client is using as-yet-undefined extensions; abort
self._abort()
return
@@ -572,6 +661,9 @@ def _on_frame_data(self, data):
for i in xrange(len(data)):
unmasked[i] = unmasked[i] ^ self._frame_mask[i % 4]
+ if self._inflater and (self._frame_reserved_bits & 0x40) > 0:
+ unmasked = self._inflater.decompress(unmasked)
+
if self._frame_opcode_is_control:
# control frames may be interleaved with a series of fragmented
# data frames, so control frames must not interact with

0 comments on commit 5052e5c

Please sign in to comment.