Skip to content

Commit

Permalink
Support websocket subprotocols
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Jan 18, 2022
1 parent 9d44029 commit cdcffaf
Show file tree
Hide file tree
Showing 2 changed files with 328 additions and 67 deletions.
123 changes: 110 additions & 13 deletions jupyter_server/base/zmqhandlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,64 @@
from .handlers import JupyterHandler


def serialize_binary_message(msg):
"""serialize a message as a binary blob
Header:
4 bytes: number of msg parts (nbufs) as 32b int
4 * nbufs bytes: offset for each buffer as integer as 32b int
Offsets are from the start of the buffer, including the header.
Returns
-------
The message serialized to bytes.
"""
# don't modify msg or buffer list in-place
msg = msg.copy()
buffers = list(msg.pop("buffers"))
if sys.version_info < (3, 4):
buffers = [x.tobytes() for x in buffers]
bmsg = json.dumps(msg, default=json_default).encode("utf8")
buffers.insert(0, bmsg)
nbufs = len(buffers)
offsets = [4 * (nbufs + 1)]
for buf in buffers[:-1]:
offsets.append(offsets[-1] + len(buf))
offsets_buf = struct.pack("!" + "I" * (nbufs + 1), nbufs, *offsets)
buffers.insert(0, offsets_buf)
return b"".join(buffers)


def deserialize_binary_message(bmsg):
"""deserialize a message from a binary blog
Header:
4 bytes: number of msg parts (nbufs) as 32b int
4 * nbufs bytes: offset for each buffer as integer as 32b int
Offsets are from the start of the buffer, including the header.
Returns
-------
message dictionary
"""
nbufs = struct.unpack("!i", bmsg[:4])[0]
offsets = list(struct.unpack("!" + "I" * nbufs, bmsg[4 : 4 * (nbufs + 1)]))
offsets.append(None)
bufs = []
for start, stop in zip(offsets[:-1], offsets[1:]):
bufs.append(bmsg[start:stop])
msg = json.loads(bufs[0].decode("utf8"))
msg["header"] = extract_dates(msg["header"])
msg["parent_header"] = extract_dates(msg["parent_header"])
msg["buffers"] = bufs[1:]
return msg


# ping interval for keeping websockets alive (30 seconds)
WS_PING_INTERVAL = 30000

Expand Down Expand Up @@ -155,6 +213,37 @@ def send_error(self, *args, **kwargs):
# we can close the connection more gracefully.
self.stream.close()

def _reserialize_reply(self, msg_or_list, channel=None):
"""Reserialize a reply message using JSON.
msg_or_list can be an already-deserialized msg dict or the zmq buffer list.
If it is the zmq list, it will be deserialized with self.session.
This takes the msg list from the ZMQ socket and serializes the result for the websocket.
This method should be used by self._on_zmq_reply to build messages that can
be sent back to the browser.
"""
if isinstance(msg_or_list, dict):
# already unpacked
msg = msg_or_list
else:
idents, msg_list = self.session.feed_identities(msg_or_list)
msg = self.session.deserialize(msg_list)
if channel:
msg["channel"] = channel
if msg["buffers"]:
buf = serialize_binary_message(msg)
return buf
else:
smsg = json.dumps(msg, default=json_default)
return cast_unicode(smsg)

def select_subprotocol(self, subprotocols):
selected_subprotocol = "0.0.1" if "0.0.1" in subprotocols else None
# None is the default, "legacy" protocol
return selected_subprotocol

def _on_zmq_reply(self, stream, msg_list):
# Sometimes this gets triggered when the on_close method is scheduled in the
# eventloop but hasn't been called.
Expand All @@ -163,19 +252,27 @@ def _on_zmq_reply(self, stream, msg_list):
self.close()
return
channel = getattr(stream, "channel", None)
offsets = []
curr_sum = 0
for msg in msg_list:
length = len(msg)
offsets.append(length + curr_sum)
curr_sum += length
layout = json.dumps({
"channel": channel,
"offsets": offsets,
}).encode("utf-8")
layout_length = len(layout).to_bytes(2, byteorder="little")
bin_msg = b"".join([layout_length, layout] + msg_list)
self.write_message(bin_msg, binary=True)
if not self.selected_subprotocol:
try:
msg = self._reserialize_reply(msg_list, channel=channel)
except Exception:
self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
else:
self.write_message(msg, binary=isinstance(msg, bytes))
elif self.selected_subprotocol == "0.0.1":
offsets = []
curr_sum = 0
for msg in msg_list:
length = len(msg)
offsets.append(length + curr_sum)
curr_sum += length
layout = json.dumps({
"channel": channel,
"offsets": offsets,
}).encode("utf-8")
layout_length = len(layout).to_bytes(2, byteorder="little")
bin_msg = b"".join([layout_length, layout] + msg_list)
self.write_message(bin_msg, binary=True)


class AuthenticatedZMQStreamHandler(ZMQStreamHandler, JupyterHandler):
Expand Down
Loading

0 comments on commit cdcffaf

Please sign in to comment.