diff --git a/jupyter_server/services/kernels/connection/abc.py b/jupyter_server/services/kernels/connection/abc.py index f918466c9..4bdf6e3ed 100644 --- a/jupyter_server/services/kernels/connection/abc.py +++ b/jupyter_server/services/kernels/connection/abc.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Callable +from typing import Any class KernelWebsocketConnectionABC(ABC): @@ -10,7 +10,7 @@ class KernelWebsocketConnectionABC(ABC): interface. """ - write_message: Callable + websocket_handler: Any @abstractmethod async def connect(self): diff --git a/jupyter_server/services/kernels/connection/base.py b/jupyter_server/services/kernels/connection/base.py index 12b062e78..fc8e1f8f0 100644 --- a/jupyter_server/services/kernels/connection/base.py +++ b/jupyter_server/services/kernels/connection/base.py @@ -3,7 +3,8 @@ import sys from jupyter_client.session import Session -from traitlets import Callable, Float, Instance, default +from tornado.websocket import WebSocketHandler +from traitlets import Float, Instance, default from traitlets.config import LoggingConfigurable try: @@ -135,7 +136,7 @@ def _default_kernel_info_timeout(self): def _default_session(self): return Session(config=self.config) - write_message = Callable() + websocket_handler = Instance(WebSocketHandler) async def connect(self): raise NotImplementedError() diff --git a/jupyter_server/services/kernels/connection/channels.py b/jupyter_server/services/kernels/connection/channels.py index a1233477d..f206a40f7 100644 --- a/jupyter_server/services/kernels/connection/channels.py +++ b/jupyter_server/services/kernels/connection/channels.py @@ -79,6 +79,24 @@ class ZMQChannelsWebsocketConnection(BaseKernelWebsocketConnection): ), ) + kernel_ws_protocol = Unicode( + None, + allow_none=True, + config=True, + help=_i18n( + "Preferred kernel message protocol over websocket to use (default: None). " + "If an empty string is passed, select the legacy protocol. If None, " + "the selected protocol will depend on what the front-end supports " + "(usually the most recent protocol supported by the back-end and the " + "front-end)." + ), + ) + + @property + def write_message(self): + """Alias to the websocket handler's write_message method.""" + return self.websocket_handler.write_message + # class-level registry of open sessions # allows checking for conflict on session-id, # which is used as a zmq identity and must be unique. @@ -117,13 +135,14 @@ def _default_close_future(self): @classmethod async def close_all(cls): """Tornado does not provide a way to close open sockets, so add one.""" - for socket in list(cls._open_sockets): - await socket.close() + for connection in list(cls._open_sockets): + connection.disconnect() + await _ensure_future(connection._close_future) @property def subprotocol(self): try: - protocol = self.selected_subprotocol + protocol = self.websocket_handler.selected_subprotocol except Exception: protocol = None return protocol @@ -520,7 +539,7 @@ def _on_zmq_reply(self, stream, msg_list): self.close() return channel = getattr(stream, "channel", None) - if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org": + if self.subprotocol == "v1.kernel.websocket.jupyter.org": bin_msg = serialize_msg_to_ws_v1(msg_list, channel) self.write_message(bin_msg, binary=True) else: diff --git a/jupyter_server/services/kernels/websocket.py b/jupyter_server/services/kernels/websocket.py index 70ac886ea..e7a0cfad8 100644 --- a/jupyter_server/services/kernels/websocket.py +++ b/jupyter_server/services/kernels/websocket.py @@ -179,7 +179,7 @@ async def pre_get(self): kernel = self.kernel_manager.get_kernel(self.kernel_id) self.connection = self.kernel_websocket_connection_class( parent=kernel, - write_message=self.write_message, + websocket_handler=self, ) if self.get_argument("session_id", None):