Skip to content

Commit

Permalink
more updates for unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Zsailer committed Nov 17, 2022
1 parent 9cec9e8 commit 5f8d2f3
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 9 deletions.
4 changes: 2 additions & 2 deletions jupyter_server/services/kernels/connection/abc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Callable
from typing import Any


class KernelWebsocketConnectionABC(ABC):
Expand All @@ -10,7 +10,7 @@ class KernelWebsocketConnectionABC(ABC):
interface.
"""

write_message: Callable
websocket_handler: Any

@abstractmethod
async def connect(self):
Expand Down
5 changes: 3 additions & 2 deletions jupyter_server/services/kernels/connection/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import sys

from jupyter_client.session import Session
from traitlets import Callable, Float, Instance, default
from tornado.websocket import WebSocketHandler
from traitlets import Float, Instance, default
from traitlets.config import LoggingConfigurable

try:
Expand Down Expand Up @@ -135,7 +136,7 @@ def _default_kernel_info_timeout(self):
def _default_session(self):
return Session(config=self.config)

write_message = Callable()
websocket_handler = Instance(WebSocketHandler)

async def connect(self):
raise NotImplementedError()
Expand Down
27 changes: 23 additions & 4 deletions jupyter_server/services/kernels/connection/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,24 @@ class ZMQChannelsWebsocketConnection(BaseKernelWebsocketConnection):
),
)

kernel_ws_protocol = Unicode(
None,
allow_none=True,
config=True,
help=_i18n(
"Preferred kernel message protocol over websocket to use (default: None). "
"If an empty string is passed, select the legacy protocol. If None, "
"the selected protocol will depend on what the front-end supports "
"(usually the most recent protocol supported by the back-end and the "
"front-end)."
),
)

@property
def write_message(self):
"""Alias to the websocket handler's write_message method."""
return self.websocket_handler.write_message

# class-level registry of open sessions
# allows checking for conflict on session-id,
# which is used as a zmq identity and must be unique.
Expand Down Expand Up @@ -117,13 +135,14 @@ def _default_close_future(self):
@classmethod
async def close_all(cls):
"""Tornado does not provide a way to close open sockets, so add one."""
for socket in list(cls._open_sockets):
await socket.close()
for connection in list(cls._open_sockets):
connection.disconnect()
await _ensure_future(connection._close_future)

@property
def subprotocol(self):
try:
protocol = self.selected_subprotocol
protocol = self.websocket_handler.selected_subprotocol
except Exception:
protocol = None
return protocol
Expand Down Expand Up @@ -520,7 +539,7 @@ def _on_zmq_reply(self, stream, msg_list):
self.close()
return
channel = getattr(stream, "channel", None)
if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org":
if self.subprotocol == "v1.kernel.websocket.jupyter.org":
bin_msg = serialize_msg_to_ws_v1(msg_list, channel)
self.write_message(bin_msg, binary=True)
else:
Expand Down
2 changes: 1 addition & 1 deletion jupyter_server/services/kernels/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ async def pre_get(self):
kernel = self.kernel_manager.get_kernel(self.kernel_id)
self.connection = self.kernel_websocket_connection_class(
parent=kernel,
write_message=self.write_message,
websocket_handler=self,
)

if self.get_argument("session_id", None):
Expand Down

0 comments on commit 5f8d2f3

Please sign in to comment.