Skip to content

Commit

Permalink
Merge the gateway handlers into the standard handlers.
Browse files Browse the repository at this point in the history
The `WebSocketChannelsHandler` and `GatewayResourceHandler` classes are removed, and
their functionality is merged into the `KernelWebsocketHandler` and
`KernelSpecResourceHandler` classes, respectively.

For the `KernelSpecResourceHandler` class, this change is rather straightforward as
we can simply make the existing handler check if the kernel spec manager has a
`get_kernel_spec_resource` method, and if so delegate to that method instead of
trying to read resources from disk.

The `KernelWebsocketHandler` conversion is more complicated, though. The handling of
websocket connections was generalized/extended in jupyter-server#1047 to allow the definition of
a `kernel_websocket_connection_class` as an alternative to replacing the entire
websocket handler.

This change builds on that by converting the `GatewayWebSocketClient` class into a
subclass of the kernel websocket connection base class (`BaseKernelWebsocketConnection`),
and accordingly renames it to `GatewayWebSocketConnection`.

When the gateway client is enabled, the default `kernel_websocket_connection_class`
is changed to this `GatewayWebSocketConnection` class similarly to how the kernel
and kernel spec manager default classes are updated.
  • Loading branch information
ojarjur committed Apr 20, 2023
1 parent 5c49253 commit 6bb3da4
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 194 deletions.
206 changes: 24 additions & 182 deletions jupyter_server/gateway/handlers.py
Original file line number Diff line number Diff line change
@@ -1,171 +1,42 @@
"""Gateway API handlers."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.

import asyncio
import logging
import mimetypes
import os
import random
from typing import Optional, cast
from typing import cast

from jupyter_client.session import Session
from tornado import web
from tornado.concurrent import Future
from tornado.escape import json_decode, url_escape, utf8
from tornado.escape import url_escape
from tornado.httpclient import HTTPRequest
from tornado.ioloop import IOLoop, PeriodicCallback
from tornado.websocket import WebSocketHandler, websocket_connect
from traitlets.config.configurable import LoggingConfigurable
from tornado.ioloop import IOLoop
from tornado.websocket import websocket_connect

from ..base.handlers import APIHandler, JupyterHandler
from ..services.kernels.connection.base import BaseKernelWebsocketConnection
from ..utils import url_path_join
from .managers import GatewayClient

# Keepalive ping interval (default: 30 seconds)
GATEWAY_WS_PING_INTERVAL_SECS = int(os.getenv("GATEWAY_WS_PING_INTERVAL_SECS", "30"))


class WebSocketChannelsHandler(WebSocketHandler, JupyterHandler):
    """Gateway web socket channels handler.

    Accepts the notebook client's websocket and relays traffic between it and
    a ``GatewayWebSocketClient``. While the client socket is open, a periodic
    ping is sent to it every ``GATEWAY_WS_PING_INTERVAL_SECS`` seconds.
    """

    # Populated in ``initialize``/``get``/``open``; ``None`` until then.
    session = None
    gateway = None
    kernel_id = None
    ping_callback = None

    def check_origin(self, origin=None):
        """Check origin for the socket."""
        return JupyterHandler.check_origin(self, origin)

    def set_default_headers(self):
        """Undo the set_default_headers in JupyterHandler which doesn't make sense for websockets"""
        pass

    def get_compression_options(self):
        """Get the compression options for the socket."""
        # use deflate compress websocket
        return {}

    def authenticate(self):
        """Run before finishing the GET request.

        Extend this method to add logic that should fire before
        the websocket finishes completing.

        Raises:
            tornado.web.HTTPError: 403 when there is no authenticated user.
        """
        # authenticate the request before opening the websocket
        if self.current_user is None:
            self.log.warning("Couldn't authenticate WebSocket connection")
            raise web.HTTPError(403)

        # Read the query argument once instead of parsing it twice.
        session_id = self.get_argument("session_id", None)
        if session_id:
            assert self.session is not None
            self.session.session = session_id
        else:
            self.log.warning("No session ID specified")

    def initialize(self):
        """Initialize the socket."""
        self.log.debug("Initializing websocket connection %s", self.request.path)
        self.session = Session(config=self.config)
        self.gateway = GatewayWebSocketClient(gateway_url=GatewayClient.instance().url)

    async def get(self, kernel_id, *args, **kwargs):
        """Authenticate the request, record the kernel id and start the websocket handshake."""
        self.authenticate()
        self.kernel_id = kernel_id
        kwargs["kernel_id"] = kernel_id
        await super().get(*args, **kwargs)

    def send_ping(self):
        """Send a keepalive ping to the client, or stop pinging once the socket is gone."""
        if self.ws_connection is None:
            # The connection is closed: never attempt to ping it, and stop the
            # periodic callback if one was started.
            if self.ping_callback is not None:
                self.ping_callback.stop()
            return

        self.ping(b"")

    def open(self, kernel_id, *args, **kwargs):
        """Handle web socket connection open to notebook server and delegate to gateway web socket handler"""
        self.ping_callback = PeriodicCallback(self.send_ping, GATEWAY_WS_PING_INTERVAL_SECS * 1000)
        self.ping_callback.start()

        assert self.gateway is not None
        self.gateway.on_open(
            kernel_id=kernel_id,
            message_callback=self.write_message,
            compression_options=self.get_compression_options(),
        )

    def on_message(self, message):
        """Forward message to gateway web socket handler."""
        assert self.gateway is not None
        self.gateway.on_message(message)

    def write_message(self, message, binary=False):
        """Send message back to notebook client. This is called via callback from self.gateway._read_messages."""
        if self.ws_connection:  # prevent WebSocketClosedError
            if isinstance(message, bytes):
                binary = True
            super().write_message(message, binary=binary)
        elif self.log.isEnabledFor(logging.DEBUG):
            # Summarising is best-effort: a dropped payload that is not a JSON
            # message must not turn a debug log line into an exception.
            try:
                msg_summary = WebSocketChannelsHandler._get_message_summary(
                    json_decode(utf8(message))
                )
            except (ValueError, TypeError, KeyError):
                msg_summary = "<unparseable message>"
            self.log.debug(
                "Notebook client closed websocket connection - message dropped: %s",
                msg_summary,
            )

    def on_close(self):
        """Handle a closing socket."""
        self.log.debug("Closing websocket connection %s", self.request.path)
        # Stop the keepalive right away rather than waiting for the next tick.
        if self.ping_callback is not None:
            self.ping_callback.stop()
        assert self.gateway is not None
        self.gateway.on_close()
        super().on_close()

    @staticmethod
    def _get_message_summary(message):
        """Get a short, log-safe summary of a decoded message dict."""
        message_type = message["msg_type"]

        if message_type == "status":
            detail = ", state: {}".format(message["content"]["execution_state"])
        elif message_type == "error":
            content = message["content"]
            detail = ", {}:{}:{}".format(content["ename"], content["evalue"], content["traceback"])
        else:
            detail = ", ..."  # don't display potentially sensitive data

        return f"type: {message_type}{detail}"


class GatewayWebSocketClient(LoggingConfigurable):
class GatewayWebSocketConnection(BaseKernelWebsocketConnection):
"""Proxy web socket connection to a kernel/enterprise gateway."""

kernel_ws_protocol = None

def __init__(self, **kwargs):
"""Initialize the gateway web socket client."""
super().__init__()
self.kernel_id = None
super().__init__(**kwargs)
self.ws = None
self.ws_future: Future = Future()
self.disconnected = False
self.retry = 0

async def _connect(self, kernel_id, message_callback):
async def connect(self):
"""Connect to the socket."""
# websocket is initialized before connection
self.ws = None
self.kernel_id = kernel_id
ws_url = url_path_join(
GatewayClient.instance().ws_url,
GatewayClient.instance().kernels_endpoint,
url_escape(kernel_id),
url_escape(self.kernel_id),
"channels",
)
self.log.info(f"Connecting to {ws_url}")
Expand All @@ -177,7 +48,7 @@ async def _connect(self, kernel_id, message_callback):
self.ws_future.add_done_callback(self._connection_done)

loop = IOLoop.current()
loop.add_future(self.ws_future, lambda future: self._read_messages(message_callback))
loop.add_future(self.ws_future, lambda future: self._read_messages())

def _connection_done(self, fut):
"""Handle a finished connection."""
Expand All @@ -195,7 +66,7 @@ def _connection_done(self, fut):
)
)

def _disconnect(self):
def disconnect(self):
"""Handle a disconnect."""
self.disconnected = True
if self.ws is not None:
Expand All @@ -206,7 +77,7 @@ def _disconnect(self):
self.ws_future.cancel()
self.log.debug(f"_disconnect: future cancelled, disconnected: {self.disconnected}")

async def _read_messages(self, callback):
async def _read_messages(self):
"""Read messages from gateway server."""
while self.ws is not None:
message = None
Expand All @@ -221,7 +92,7 @@ async def _read_messages(self, callback):
if not self.disconnected:
self.log.warning(f"Lost connection to Gateway: {self.kernel_id}")
break
callback(
self.handle_outgoing_message(
message
) # pass back to notebook client (see self.handle_outgoing_message)
else: # ws cancelled - stop reading
Expand All @@ -247,14 +118,13 @@ async def _read_messages(self, callback):
)
await asyncio.sleep(retry_interval)
loop = IOLoop.current()
loop.spawn_callback(self._connect, self.kernel_id, callback)
loop.spawn_callback(self.connect)

def on_open(self, kernel_id, message_callback, **kwargs):
"""Web socket connection open against gateway server."""
loop = IOLoop.current()
loop.spawn_callback(self._connect, kernel_id, message_callback)
def handle_outgoing_message(self, *args, **kwargs):
    """Relay a message to the notebook client.

    Every positional and keyword argument is forwarded unchanged to the
    ``write_message`` method of ``self.websocket_handler``.
    """
    client_socket = self.websocket_handler
    client_socket.write_message(*args, **kwargs)

def on_message(self, message):
def handle_incoming_message(self, message: str) -> None:
"""Send message to gateway server."""
if self.ws is None:
loop = IOLoop.current()
Expand All @@ -270,34 +140,6 @@ def _write_message(self, message):
except Exception as e:
self.log.error(f"Exception writing message to websocket: {e}") # , exc_info=True)

def on_close(self):
"""Web socket closed event."""
self._disconnect()


class GatewayResourceHandler(APIHandler):
    """Serve kernelspec resource files obtained from a kernel/enterprise gateway."""

    @web.authenticated
    async def get(self, kernel_name, path, include_body=True):
        """Fetch the resource at ``path`` for ``kernel_name`` and finish the request with it.

        The resource is requested from the kernel spec manager's
        ``get_kernel_spec_resource``. When nothing comes back, a warning is
        logged and the request is finished with no content type set.
        """
        spec_manager = self.kernel_spec_manager
        resource = await spec_manager.get_kernel_spec_resource(kernel_name, path)

        content_type: Optional[str]
        if resource is not None:
            # Fall back to plain text when the extension is not recognised.
            content_type = mimetypes.guess_type(path)[0] or "text/plain"
        else:
            content_type = None
            self.log.warning(
                "Kernelspec resource '{}' for '{}' not found. Gateway may not support"
                " resource serving.".format(path, kernel_name)
            )

        self.finish(resource, set_content_type=content_type)


from ..services.kernels.handlers import _kernel_id_regex
from ..services.kernelspecs.handlers import kernel_name_regex

# URL patterns paired with the gateway-specific handler classes defined in this module.
default_handlers = [
    (r"/api/kernels/%s/channels" % _kernel_id_regex, WebSocketChannelsHandler),
    (r"/kernelspecs/%s/(?P<path>.*)" % kernel_name_regex, GatewayResourceHandler),
]
@classmethod
async def close_all(cls):
pass
23 changes: 23 additions & 0 deletions jupyter_server/kernelspecs/handlers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
"""Kernelspecs API Handlers."""
import mimetypes
from typing import Optional

from jupyter_core.utils import ensure_async
from tornado import web

Expand Down Expand Up @@ -27,6 +30,26 @@ async def get(self, kernel_name, path, include_body=True):
ksm = self.kernel_spec_manager
if path.lower().endswith(".png"):
self.set_header("Cache-Control", f"max-age={60*60*24*30}")
ksm = self.kernel_spec_manager
if hasattr(ksm, "get_kernel_spec_resource"):
# If the kernel spec manager defines a method to get kernelspec resources,
# then use that instead of trying to read from disk.
kernel_spec_res = await ksm.get_kernel_spec_resource(kernel_name, path)
if kernel_spec_res is not None:
# We have to explicitly specify the `absolute_path` attribute so that
# the underlying StaticFileHandler methods can calculate an etag.
self.absolute_path = path
mimetype: Optional[str] = mimetypes.guess_type(path)[0] or "text/plain"
self.set_header("Content-Type", mimetype)
self.finish(kernel_spec_res)
return
else:
self.log.warning(
"Kernelspec resource '{}' for '{}' not found. Kernel spec manager may"
" not support resource serving. Falling back to reading from disk".format(
path, kernel_name
)
)
try:
kspec = await ensure_async(ksm.get_kernel_spec(kernel_name))
self.root = kspec.resource_dir
Expand Down
20 changes: 8 additions & 12 deletions jupyter_server/serverapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
from jupyter_server.extension.config import ExtensionConfigManager
from jupyter_server.extension.manager import ExtensionManager
from jupyter_server.extension.serverextension import ServerExtensionApp
from jupyter_server.gateway.handlers import GatewayWebSocketConnection
from jupyter_server.gateway.managers import (
GatewayClient,
GatewayKernelSpecManager,
Expand Down Expand Up @@ -433,17 +434,6 @@ def init_handlers(self, default_services, settings):
# And from identity provider
handlers.extend(settings["identity_provider"].get_handlers())

# If gateway mode is enabled, replace appropriate handlers to perform redirection
if GatewayClient.instance().gateway_enabled:
# for each handler required for gateway, locate its pattern
# in the current list and replace that entry...
gateway_handlers = load_handlers("jupyter_server.gateway.handlers")
for _, gwh in enumerate(gateway_handlers):
for j, h in enumerate(handlers):
if gwh[0] == h[0]:
handlers[j] = (gwh[0], gwh[1])
break

# register base handlers last
handlers.extend(load_handlers("jupyter_server.base.handlers"))

Expand Down Expand Up @@ -796,6 +786,7 @@ class ServerApp(JupyterApp):
GatewayMappingKernelManager,
GatewayKernelSpecManager,
GatewaySessionManager,
GatewayWebSocketConnection,
GatewayClient,
Authorizer,
EventLogger,
Expand Down Expand Up @@ -1505,12 +1496,17 @@ def _default_session_manager_class(self):
return SessionManager

kernel_websocket_connection_class = Type(
default_value=ZMQChannelsWebsocketConnection,
klass=BaseKernelWebsocketConnection,
config=True,
help=_i18n("The kernel websocket connection class to use."),
)

@default("kernel_websocket_connection_class")
def _default_kernel_websocket_connection_class(self):
    """Pick the default kernel websocket connection class.

    Returns the dotted path of ``GatewayWebSocketConnection`` when
    ``self.gateway_config.gateway_enabled`` is truthy, otherwise the
    ``ZMQChannelsWebsocketConnection`` class.
    """
    gateway_on = self.gateway_config.gateway_enabled
    return (
        "jupyter_server.gateway.handlers.GatewayWebSocketConnection"
        if gateway_on
        else ZMQChannelsWebsocketConnection
    )

config_manager_class = Type(
default_value=ConfigManager,
config=True,
Expand Down

0 comments on commit 6bb3da4

Please sign in to comment.