Skip to content

Commit

Permalink
websocket: support permessage-deflate extension; Thanks to Costas Chr…
Browse files Browse the repository at this point in the history
…istofi and Peter Kovary

Support for compression extension as described in RFC7692 https://tools.ietf.org/html/rfc7692

#417
  • Loading branch information
costasgambit authored and temoto committed Sep 11, 2017
1 parent 82f1877 commit b7d2a25
Show file tree
Hide file tree
Showing 3 changed files with 471 additions and 20 deletions.
2 changes: 2 additions & 0 deletions AUTHORS
Expand Up @@ -152,3 +152,5 @@ Thanks To
* Aayush Kasurde
* Linbing
* Geoffrey Thomas
* Costas Christofi, adding permessage-deflate weboscket extension support
* Peter Kovary, adding permessage-deflate weboscket extension support
183 changes: 163 additions & 20 deletions eventlet/websocket.py
Expand Up @@ -9,6 +9,8 @@
import sys
import time

import zlib

try:
from hashlib import md5, sha1
except ImportError: # pragma NO COVER
Expand Down Expand Up @@ -196,6 +198,76 @@ def _handle_legacy_request(self, environ):
sock.sendall(handshake_reply)
return WebSocket(sock, environ, self.protocol_version)

def _parse_extension_header(self, header):
if header is None:
return None
res = {}
for ext in header.split(","):
parts = ext.split(";")
config = {}
for part in parts[1:]:
key_val = part.split("=")
if len(key_val) == 1:
config[key_val[0].strip().lower()] = True
else:
config[key_val[0].strip().lower()] = key_val[1].strip().strip('"').lower()
res.setdefault(parts[0].strip().lower(), []).append(config)
return res

def _negotiate_permessage_deflate(self, extensions):
if not extensions:
return None
deflate = extensions.get("permessage-deflate")
if deflate is None:
return None
for config in deflate:
# We'll evaluate each config in the client's preferred order and pick
# the first that we can support.
want_config = {
# These are bool options, we can support both
"server_no_context_takeover": config.get("server_no_context_takeover", False),
"client_no_context_takeover": config.get("client_no_context_takeover", False)
}
# These are either bool OR int options. True means the client can accept a value
# for the option, a number means the client wants that specific value.
max_wbits = min(zlib.MAX_WBITS, 15)
mwb = config.get("server_max_window_bits")
if mwb is not None:
if mwb is True:
want_config["server_max_window_bits"] = max_wbits
else:
want_config["server_max_window_bits"] = \
int(config.get("server_max_window_bits", max_wbits))
if not (8 <= want_config["server_max_window_bits"] <= 15):
continue
mwb = config.get("client_max_window_bits")
if mwb is not None:
if mwb is True:
want_config["client_max_window_bits"] = max_wbits
else:
want_config["client_max_window_bits"] = \
int(config.get("client_max_window_bits", max_wbits))
if not (8 <= want_config["client_max_window_bits"] <= 15):
continue
return want_config
return None

def _format_extension_header(self, parsed_extensions):
if not parsed_extensions:
return None
parts = []
for name, config in parsed_extensions.items():
ext_parts = [six.b(name)]
for key, value in config.items():
if value is False:
pass
elif value is True:
ext_parts.append(six.b(key))
else:
ext_parts.append(six.b("%s=%s" % (key, str(value))))
parts.append(b"; ".join(ext_parts))
return b", ".join(parts)

def _handle_hybi_request(self, environ):
if 'eventlet.input' in environ:
sock = environ['eventlet.input'].get_socket()
Expand Down Expand Up @@ -226,9 +298,6 @@ def _handle_hybi_request(self, environ):
if p in self.supported_protocols:
negotiated_protocol = p
break
# extensions = environ.get('HTTP_SEC_WEBSOCKET_EXTENSIONS', None)
# if extensions:
# extensions = [i.strip() for i in extensions.split(',')]

key = environ['HTTP_SEC_WEBSOCKET_KEY']
response = base64.b64encode(sha1(six.b(key) + PROTOCOL_GUID).digest())
Expand All @@ -238,9 +307,22 @@ def _handle_hybi_request(self, environ):
b"Sec-WebSocket-Accept: " + response]
if negotiated_protocol:
handshake_reply.append(b"Sec-WebSocket-Protocol: " + six.b(negotiated_protocol))

parsed_extensions = {}
extensions = self._parse_extension_header(environ.get("HTTP_SEC_WEBSOCKET_EXTENSIONS"))

deflate = self._negotiate_permessage_deflate(extensions)
if deflate is not None:
parsed_extensions["permessage-deflate"] = deflate

formatted_ext = self._format_extension_header(parsed_extensions)
if formatted_ext is not None:
handshake_reply.append(b"Sec-WebSocket-Extensions: " + formatted_ext)

sock.sendall(b'\r\n'.join(handshake_reply) + b'\r\n\r\n')
return RFC6455WebSocket(sock, environ, self.protocol_version,
protocol=negotiated_protocol)
protocol=negotiated_protocol,
extensions=parsed_extensions)

def _extract_number(self, value):
"""
Expand Down Expand Up @@ -296,8 +378,7 @@ def __init__(self, sock, environ, version=76):
self._msgs = collections.deque()
self._sendlock = semaphore.Semaphore()

@staticmethod
def _pack_message(message):
def _pack_message(self, message):
"""Pack the message inside ``00`` and ``FF``
As per the dataframing section (5.3) for the websocket spec
Expand Down Expand Up @@ -409,11 +490,15 @@ class ProtocolError(ValueError):


class RFC6455WebSocket(WebSocket):
def __init__(self, sock, environ, version=13, protocol=None, client=False):
def __init__(self, sock, environ, version=13, protocol=None, client=False, extensions=None):
super(RFC6455WebSocket, self).__init__(sock, environ, version)
self.iterator = self._iter_frames()
self.client = client
self.protocol = protocol
self.extensions = extensions or {}

self._deflate_enc = None
self._deflate_dec = None

class UTF8Decoder(object):
def __init__(self):
Expand All @@ -436,6 +521,45 @@ def decode(self, data, final=False):
raise ValueError('Data is not valid unicode')
return self.decoder.decode(data, final)

def _get_permessage_deflate_enc(self):
options = self.extensions.get("permessage-deflate")
if options is None:
return None

def _make():
return zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED,
-options.get("client_max_window_bits" if self.client
else "server_max_window_bits",
zlib.MAX_WBITS))

if options.get("client_no_context_takeover" if self.client
else "server_no_context_takeover"):
# This option means we have to make a new one every time
return _make()
else:
if self._deflate_enc is None:
self._deflate_enc = _make()
return self._deflate_enc

def _get_permessage_deflate_dec(self, rsv1):
options = self.extensions.get("permessage-deflate")
if options is None or not rsv1:
return None

def _make():
return zlib.decompressobj(-options.get("server_max_window_bits" if self.client
else "client_max_window_bits",
zlib.MAX_WBITS))

if options.get("server_no_context_takeover" if self.client
else "client_no_context_takeover"):
# This option means we have to make a new one every time
return _make()
else:
if self._deflate_dec is None:
self._deflate_dec = _make()
return self._deflate_dec

def _get_bytes(self, numbytes):
data = b''
while len(data) < numbytes:
Expand All @@ -446,20 +570,24 @@ def _get_bytes(self, numbytes):
return data

class Message(object):
def __init__(self, opcode, decoder=None):
def __init__(self, opcode, decoder=None, decompressor=None):
self.decoder = decoder
self.data = []
self.finished = False
self.opcode = opcode
self.decompressor = decompressor

def push(self, data, final=False):
if self.decoder:
data = self.decoder.decode(data, final=final)
self.finished = final
self.data.append(data)

def getvalue(self):
return ('' if self.decoder else b'').join(self.data)
data = b"".join(self.data)
if not self.opcode & 8 and self.decompressor:
data = self.decompressor.decompress(data + b'\x00\x00\xff\xff')
if self.decoder:
data = self.decoder.decode(data, self.finished)
return data

@staticmethod
def _apply_mask(data, mask, length=None, offset=0):
Expand Down Expand Up @@ -523,16 +651,21 @@ def _iter_frames(self):

def _recv_frame(self, message=None):
recv = self._get_bytes

# Unpacking the frame described in Section 5.2 of RFC6455
# (https://tools.ietf.org/html/rfc6455#section-5.2)
header = recv(2)
a, b = struct.unpack('!BB', header)
finished = a >> 7 == 1
rsv123 = a >> 4 & 7
rsv1 = rsv123 & 4
if rsv123:
# must be zero
raise FailedConnectionError(
1002,
"RSV1, RSV2, RSV3: MUST be 0 unless an extension is"
" negotiated that defines meanings for non-zero values.")
if rsv1 and "permessage-deflate" not in self.extensions:
# must be zero - unless it's compressed then rsv1 is true
raise FailedConnectionError(
1002,
"RSV1, RSV2, RSV3: MUST be 0 unless an extension is"
" negotiated that defines meanings for non-zero values.")
opcode = a & 15
if opcode not in (0, 1, 2, 8, 9, 0xA):
raise FailedConnectionError(1002, "Unknown opcode received.")
Expand Down Expand Up @@ -569,7 +702,8 @@ def _recv_frame(self, message=None):
received = 0
if not message or opcode & 8:
decoder = self.UTF8Decoder() if opcode == 1 else None
message = self.Message(opcode, decoder=decoder)
decompressor = self._get_permessage_deflate_dec(rsv1)
message = self.Message(opcode, decoder=decoder, decompressor=decompressor)
if not length:
message.push(b'', final=finished)
else:
Expand All @@ -588,13 +722,22 @@ def _recv_frame(self, message=None):
1007, "Text data must be valid utf-8")
return message

@staticmethod
def _pack_message(message, masked=False,
def _pack_message(self, message, masked=False,
continuation=False, final=True, control_code=None):
is_text = False
if isinstance(message, six.text_type):
message = message.encode('utf-8')
is_text = True

compress_bit = 0
compressor = self._get_permessage_deflate_enc()
if message and compressor:
message = compressor.compress(message)
message += compressor.flush(zlib.Z_SYNC_FLUSH)
assert message[-4:] == b"\x00\x00\xff\xff"
message = message[:-4]
compress_bit = 1 << 6

length = len(message)
if not length:
# no point masking empty data
Expand All @@ -608,7 +751,7 @@ def _pack_message(message, masked=False,
raise ProtocolError('Control frame data too large (>125).')
header = struct.pack('!B', control_code | 1 << 7)
else:
opcode = 0 if continuation else (1 if is_text else 2)
opcode = 0 if continuation else ((1 if is_text else 2) | compress_bit)
header = struct.pack('!B', opcode | (1 << 7 if final else 0))
lengthdata = 1 << 7 if masked else 0
if length > 65535:
Expand Down

0 comments on commit b7d2a25

Please sign in to comment.