diff --git a/docs/source/api/jupyter_server.services.kernelspecs.rst b/docs/source/api/jupyter_server.services.kernelspecs.rst index 3f210d0f55..3487be9f34 100644 --- a/docs/source/api/jupyter_server.services.kernelspecs.rst +++ b/docs/source/api/jupyter_server.services.kernelspecs.rst @@ -10,6 +10,12 @@ Submodules :undoc-members: :show-inheritance: + +.. automodule:: jupyter_server.services.kernelspecs.renaming + :members: + :undoc-members: + :show-inheritance: + Module contents --------------- diff --git a/jupyter_server/gateway/managers.py b/jupyter_server/gateway/managers.py index e8a622401a..f6c36f829c 100644 --- a/jupyter_server/gateway/managers.py +++ b/jupyter_server/gateway/managers.py @@ -16,10 +16,10 @@ from jupyter_client.clientabc import KernelClientABC from jupyter_client.kernelspec import KernelSpecManager from jupyter_client.managerabc import KernelManagerABC -from jupyter_core.utils import ensure_async +from jupyter_core.utils import ensure_async, run_sync from tornado import web from tornado.escape import json_decode, json_encode, url_escape, utf8 -from traitlets import DottedObjectName, Instance, Type, default +from traitlets import DottedObjectName, Instance, Type, Unicode, default, observe from .._tz import UTC, utcnow from ..services.kernels.kernelmanager import ( @@ -27,6 +27,7 @@ ServerKernelManager, emit_kernel_action_event, ) +from ..services.kernelspecs.renaming import RenamingKernelSpecManagerMixin, normalize_kernel_name from ..services.sessions.sessionmanager import SessionManager from ..utils import url_path_join from .gateway_client import GatewayClient, gateway_request @@ -60,7 +61,8 @@ def remove_kernel(self, kernel_id): except KeyError: pass - async def start_kernel(self, *, kernel_id=None, path=None, **kwargs): + @normalize_kernel_name + async def start_kernel(self, *, kernel_id=None, path=None, renamed_kernel=None, **kwargs): """Start a kernel for a session and return its kernel_id. Parameters @@ -80,6 +82,10 @@ async def start_kernel(self, *, kernel_id=None, path=None, **kwargs): km = self.kernel_manager_factory(parent=self, log=self.log) await km.start_kernel(kernel_id=kernel_id, **kwargs) + if renamed_kernel is not None: + km.kernel_name = renamed_kernel + if km.kernel: + km.kernel["name"] = km.kernel_name kernel_id = km.kernel_id self._kernels[kernel_id] = km # Initialize culling if not already @@ -210,6 +216,27 @@ async def cull_kernels(self): class GatewayKernelSpecManager(KernelSpecManager): """A gateway kernel spec manager.""" + default_kernel_name = Unicode(allow_none=True) + + # Use a hidden trait for the default kernel name we get from the remote. + # + # This is automatically copied to the corresponding public trait. + # + # We use two layers of trait so that sub classes can modify the public + # trait without confusing the logic that tracks changes to the remote + # default kernel name. + _remote_default_kernel_name = Unicode(allow_none=True) + + @default("default_kernel_name") + def _default_default_kernel_name(self): + # The default kernel name is taken from the remote gateway + run_sync(self.get_all_specs)() + return self._remote_default_kernel_name + + @observe("_remote_default_kernel_name") + def _observe_remote_default_kernel_name(self, change): + self.default_kernel_name = change.new + def __init__(self, **kwargs): """Initialize a gateway kernel spec manager.""" super().__init__(**kwargs) @@ -273,14 +300,13 @@ async def get_all_specs(self): # If different log a warning and reset the default. However, the # caller of this method will still return this server's value until # the next fetch of kernelspecs - at which time they'll match. - km = self.parent.kernel_manager remote_default_kernel_name = fetched_kspecs.get("default") - if remote_default_kernel_name != km.default_kernel_name: + if remote_default_kernel_name != self._remote_default_kernel_name: self.log.info( f"Default kernel name on Gateway server ({remote_default_kernel_name}) differs from " - f"Notebook server ({km.default_kernel_name}). Updating to Gateway server's value." + f"Notebook server ({self._remote_default_kernel_name}). Updating to Gateway server's value." ) - km.default_kernel_name = remote_default_kernel_name + self._remote_default_kernel_name = remote_default_kernel_name remote_kspecs = fetched_kspecs.get("kernelspecs") return remote_kspecs @@ -345,6 +371,18 @@ async def get_kernel_spec_resource(self, kernel_name, path): return kernel_spec_resource +class GatewayRenamingKernelSpecManager(RenamingKernelSpecManagerMixin, GatewayKernelSpecManager): + spec_name_prefix = Unicode( + "remote-", help="Prefix to be added onto the front of kernel spec names." + ) + + display_name_suffix = Unicode( + " (Remote)", + config=True, + help="Suffix to be added onto the end of kernel spec display names.", + ) + + class GatewaySessionManager(SessionManager): """A gateway session manager.""" @@ -453,6 +491,8 @@ async def refresh_model(self, model=None): # a parent instance if, say, a server extension is using another application # (e.g., papermill) that uses a KernelManager instance directly. self.parent._kernel_connections[self.kernel_id] = int(model["connections"]) + if self.kernel_name: + model["name"] = self.kernel_name self.kernel = model return model @@ -477,7 +517,8 @@ async def start_kernel(self, **kwargs): if kernel_id is None: kernel_name = kwargs.get("kernel_name", "python3") - self.log.debug("Request new kernel at: %s" % self.kernels_url) + self.kernel_name = kernel_name + self.log.debug(f"Request new kernel at: {self.kernels_url} using {kernel_name}") # Let KERNEL_USERNAME take precedent over http_user config option. if os.environ.get("KERNEL_USERNAME") is None and GatewayClient.instance().http_user: diff --git a/jupyter_server/services/kernels/kernelmanager.py b/jupyter_server/services/kernels/kernelmanager.py index aec1be9a8c..bcdc8d11f7 100644 --- a/jupyter_server/services/kernels/kernelmanager.py +++ b/jupyter_server/services/kernels/kernelmanager.py @@ -17,6 +17,7 @@ from typing import Optional from jupyter_client.ioloop.manager import AsyncIOLoopKernelManager +from jupyter_client.kernelspec import NATIVE_KERNEL_NAME from jupyter_client.multikernelmanager import AsyncMultiKernelManager, MultiKernelManager from jupyter_client.session import Session from jupyter_core.paths import exists @@ -38,6 +39,7 @@ TraitError, Unicode, default, + observe, validate, ) @@ -46,6 +48,8 @@ from jupyter_server.prometheus.metrics import KERNEL_CURRENTLY_RUNNING_TOTAL from jupyter_server.utils import ApiPath, import_item, to_os_path +from ..kernelspecs.renaming import normalize_kernel_name + class MappingKernelManager(MultiKernelManager): """A KernelManager that handles @@ -206,8 +210,14 @@ async def _remove_kernel_when_ready(self, kernel_id, kernel_awaitable): # TODO DEC 2022: Revise the type-ignore once the signatures have been changed upstream # https://github.com/jupyter/jupyter_client/pull/905 - async def _async_start_kernel( # type:ignore[override] - self, *, kernel_id: Optional[str] = None, path: Optional[ApiPath] = None, **kwargs: str + @normalize_kernel_name + async def _async_start_kernel( + self, + *, + kernel_id: Optional[str] = None, + path: Optional[ApiPath] = None, + renamed_kernel: Optional[str] = None, + **kwargs: str, ) -> str: """Start a kernel for a session and return its kernel_id. @@ -231,6 +241,8 @@ async def _async_start_kernel( # type:ignore[override] assert kernel_id is not None, "Never Fail, but necessary for mypy " kwargs["kernel_id"] = kernel_id kernel_id = await self.pinned_superclass._async_start_kernel(self, **kwargs) + if renamed_kernel: + self._kernels[kernel_id].kernel_name = renamed_kernel self._kernel_connections[kernel_id] = 0 task = asyncio.create_task(self._finish_kernel_start(kernel_id)) if not getattr(self, "use_pending_kernels", None): @@ -261,7 +273,7 @@ async def _async_start_kernel( # type:ignore[override] # see https://github.com/jupyter-server/jupyter_server/issues/1165 # this assignment is technically incorrect, but might need a change of API # in jupyter_client. - start_kernel = _async_start_kernel # type:ignore[assignment] + start_kernel = _async_start_kernel async def _finish_kernel_start(self, kernel_id): """Handle a kernel that finishes starting.""" @@ -678,7 +690,7 @@ async def cull_kernel_if_idle(self, kernel_id): # AsyncMappingKernelManager inherits as much as possible from MappingKernelManager, # overriding only what is different. -class AsyncMappingKernelManager(MappingKernelManager, AsyncMultiKernelManager): # type:ignore[misc] +class AsyncMappingKernelManager(MappingKernelManager, AsyncMultiKernelManager): """An asynchronous mapping kernel manager.""" @default("kernel_manager_class") @@ -700,6 +712,42 @@ def _validate_kernel_manager_class(self, proposal): ) return km_class_value + @default("default_kernel_name") + def _default_default_kernel_name(self): + if ( + hasattr(self.kernel_spec_manager, "default_kernel_name") + and self.kernel_spec_manager.default_kernel_name + ): + return self.kernel_spec_manager.default_kernel_name + return NATIVE_KERNEL_NAME + + @observe("default_kernel_name") + def _observe_default_kernel_name(self, change): + if ( + hasattr(self.kernel_spec_manager, "default_kernel_name") + and self.kernel_spec_manager.default_kernel_name + ): + # If the kernel spec manager defines a default kernel name, treat that + # one as authoritative. + kernel_name = change.new + if kernel_name == self.kernel_spec_manager.default_kernel_name: + return + self.log.debug( + f"The MultiKernelManager default kernel name '{kernel_name}'" + " differs from the KernelSpecManager default kernel name" + f" '{self.kernel_spec_manager.default_kernel_name}'..." + " Using the kernel spec manager's default name." + ) + self.default_kernel_name = self.kernel_spec_manager.default_kernel_name + + def _on_kernel_spec_manager_default_kernel_name_changed(self, change): + # Sync the kernel-spec-manager's trait to the multi-kernel-manager's trait. + kernel_name = change.new + if kernel_name is None: + return + self.log.debug(f"KernelSpecManager default kernel name changed: {kernel_name}") + self.default_kernel_name = kernel_name + def __init__(self, **kwargs): """Initialize an async mapping kernel manager.""" self.pinned_superclass = MultiKernelManager @@ -707,6 +755,13 @@ def __init__(self, **kwargs): self.pinned_superclass.__init__(self, **kwargs) self.last_kernel_activity = utcnow() + if hasattr(self.kernel_spec_manager, "default_kernel_name"): + self.kernel_spec_manager.observe( + self._on_kernel_spec_manager_default_kernel_name_changed, "default_kernel_name" + ) + if not self.kernel_spec_manager.default_kernel_name: + self.kernel_spec_manager.default_kernel_name = self.default_kernel_name + def emit_kernel_action_event(success_msg: str = ""): # type: ignore """Decorate kernel action methods to diff --git a/jupyter_server/services/kernelspecs/handlers.py b/jupyter_server/services/kernelspecs/handlers.py index e1ed186fa7..be3480418e 100644 --- a/jupyter_server/services/kernelspecs/handlers.py +++ b/jupyter_server/services/kernelspecs/handlers.py @@ -64,10 +64,10 @@ async def get(self): """Get the list of kernel specs.""" ksm = self.kernel_spec_manager km = self.kernel_manager + kspecs = await ensure_async(ksm.get_all_specs()) model = {} model["default"] = km.default_kernel_name model["kernelspecs"] = specs = {} - kspecs = await ensure_async(ksm.get_all_specs()) for kernel_name, kernel_info in kspecs.items(): try: if is_kernelspec_model(kernel_info): diff --git a/jupyter_server/services/kernelspecs/renaming.py b/jupyter_server/services/kernelspecs/renaming.py new file mode 100644 index 0000000000..e09c1de246 --- /dev/null +++ b/jupyter_server/services/kernelspecs/renaming.py @@ -0,0 +1,164 @@ +"""Support for renaming kernel specs at runtime.""" +# Copyright (c) Jupyter Development Team. +# Distributed under the terms of the Modified BSD License. +from functools import wraps +from typing import Any, Dict, Tuple + +from jupyter_client.kernelspec import KernelSpecManager +from jupyter_core.utils import ensure_async, run_sync +from traitlets import Unicode, default, observe +from traitlets.config import LoggingConfigurable + + +def normalize_kernel_name(method): + @wraps(method) + async def wrapped_method(self, *args, **kwargs): + kernel_name = kwargs.get("kernel_name", None) + if ( + kernel_name + and hasattr(self, "kernel_spec_manager") + and hasattr(self.kernel_spec_manager, "original_kernel_name") + ): + original_kernel_name = self.kernel_spec_manager.original_kernel_name(kernel_name) + if kernel_name != original_kernel_name: + self.log.debug( + f"Renamed kernel '{kernel_name}' replaced with original kernel name '{original_kernel_name}'" + ) + kwargs["renamed_kernel"] = kernel_name + kwargs["kernel_name"] = original_kernel_name + + return await method(self, *args, **kwargs) + + return wrapped_method + + +class RenamingKernelSpecManagerMixin(LoggingConfigurable): + """KernelSpecManager mixin that renames kernel specs. + + The base KernelSpecManager class only has synchronous methods, but some child + classes (in particular, GatewayKernelManager) change those methods to be async. + + In order to support both versions, we provide both synchronous and async versions + of all the relevant kernel spec manager methods. We first do the renaming in the + async version, but override the KernelSpecManager base methods using the + synchronous versions. + """ + + spec_name_prefix = Unicode(help="Prefix to be added onto the front of kernel spec names.") + + display_name_suffix = Unicode( + config=True, help="Suffix to be added onto the end of kernel spec display names." + ) + + display_name_format = Unicode( + config=True, help="Format for rewritten kernel spec display names." + ) + + @default("display_name_format") + def _default_display_name_format(self): + if self.display_name_suffix: + return "{}" + self.display_name_suffix + return "{}" + + default_kernel_name = Unicode(allow_none=True) + + @observe("default_kernel_name") + def _observe_default_kernel_name(self, change): + kernel_name = change.new + if self.original_kernel_name(kernel_name) is not kernel_name: + # The default kernel name has already been renamed + return + updated_kernel_name = self.rename_kernel(kernel_name) + self.log.debug(f"Renaming default kernel name {kernel_name} to {updated_kernel_name}") + self.default_kernel_name = updated_kernel_name + + def rename_kernel(self, kernel_name: str) -> str: + """Rename the supplied kernel spec based on the configured format string.""" + if kernel_name.startswith(self.spec_name_prefix): + return kernel_name + return self.spec_name_prefix + kernel_name + + def original_kernel_name(self, kernel_name: str) -> str: + if not kernel_name.startswith(self.spec_name_prefix): + return kernel_name + return kernel_name[len(self.spec_name_prefix) :] + + def _update_display_name(self, display_name: str) -> str: + if not display_name: + return display_name + return self.display_name_format.format(display_name) + + def _update_spec(self, original_name: str, kernel_spec: Dict) -> Tuple[str, Dict]: + original_prefix = f"/kernelspecs/{original_name}" + spec_name = self.rename_kernel(original_name) + new_prefix = f"/kernelspecs/{spec_name}" + + kernel_spec["name"] = spec_name + kernel_spec["spec"] = kernel_spec.get("spec", {}) + kernel_spec["resources"] = kernel_spec.get("resources", {}) + + spec = kernel_spec["spec"] + spec["display_name"] = self._update_display_name(spec.get("display_name", "")) + + resources = kernel_spec["resources"] + for name, value in resources.items(): + resources[name] = value.replace(original_prefix, new_prefix) + return spec_name, kernel_spec + + async def async_get_all_specs(self): + ks: Dict = {} + original_ks = await ensure_async(super().get_all_specs()) # type:ignore[misc] + for s, k in original_ks.items(): + spec_name, kernel_spec = self._update_spec(s, k) + ks[spec_name] = kernel_spec + return ks + + def get_all_specs(self): + return run_sync(self.async_get_all_specs)() + + async def async_get_kernel_spec(self, kernel_name: str, *args: Any, **kwargs: Any) -> Any: + original_kernel_name = self.original_kernel_name(kernel_name) + self.log.debug(f"Found original kernel name '{original_kernel_name}' for '{kernel_name}'") + kspec = await ensure_async( + super().get_kernel_spec(original_kernel_name, *args, **kwargs) # type:ignore[misc] + ) + if original_kernel_name == kernel_name: + # The kernel wasn't renamed, so don't modify its contents + return kspec + + # KernelSpecManager and GatewayKernelSpec manager return different types for the + # wrapped `get_kernel_spec` call (KernelSpec vs. Dict). To accommodate both, + # we check the type of the returned value and operate on the two different + # types as appropriate. + if isinstance(kspec, dict): + kspec["name"] = kernel_name + kspec["display_name"] = self._update_display_name(kspec.get("display_name", "")) + else: + kspec.name = kernel_name + kspec.display_name = self._update_display_name(kspec.display_name) + return kspec + + def get_kernel_spec(self, kernel_name: str, *args: Any, **kwargs: Any) -> Any: + return run_sync(self.async_get_kernel_spec)(kernel_name, *args, **kwargs) + + async def get_kernel_spec_resource(self, kernel_name: str, *args: Any, **kwargs: Any) -> Any: + if not hasattr(super(), "get_kernel_spec_resource"): + return None + kernel_name = self.original_kernel_name(kernel_name) + return await ensure_async( + super().get_kernel_spec_resource(kernel_name, *args, **kwargs) # type:ignore[misc] + ) + + +class RenamingKernelSpecManager(RenamingKernelSpecManagerMixin, KernelSpecManager): + """KernelSpecManager that renames kernels""" + + spec_name_prefix = Unicode( + "local-", help="Prefix to be added onto the front of kernel spec names." + ) + + display_name_suffix = Unicode( + " (Local)", + config=True, + help="Suffix to be added onto the end of kernel spec display names.", + ) diff --git a/tests/services/kernels/test_api.py b/tests/services/kernels/test_api.py index dfbf1a47bc..2094c37067 100644 --- a/tests/services/kernels/test_api.py +++ b/tests/services/kernels/test_api.py @@ -58,6 +58,12 @@ async def _(kernel_id, ready=None): "kernel_manager_class": "jupyter_server.services.kernels.kernelmanager.AsyncMappingKernelManager" } }, + { + "ServerApp": { + "kernel_manager_class": "jupyter_server.services.kernels.kernelmanager.AsyncMappingKernelManager", + "kernel_spec_manager_class": "jupyter_server.services.kernelspecs.renaming.RenamingKernelSpecManager", + }, + }, ] @@ -66,13 +72,22 @@ async def _(kernel_id, ready=None): # See https://github.com/jupyter-server/jupyter_server/issues/672 if os.name != "nt" and jupyter_client._version.version_info >= (7, 1): # Add a pending kernels condition - c = { - "ServerApp": { - "kernel_manager_class": "jupyter_server.services.kernels.kernelmanager.AsyncMappingKernelManager" + cs = [ + { + "ServerApp": { + "kernel_manager_class": "jupyter_server.services.kernels.kernelmanager.AsyncMappingKernelManager" + }, + "AsyncMappingKernelManager": {"use_pending_kernels": True}, + }, + { + "ServerApp": { + "kernel_manager_class": "jupyter_server.services.kernels.kernelmanager.AsyncMappingKernelManager", + "kernel_spec_manager_class": "jupyter_server.services.kernelspecs.renaming.RenamingKernelSpecManager", + }, + "AsyncMappingKernelManager": {"use_pending_kernels": True}, }, - "AsyncMappingKernelManager": {"use_pending_kernels": True}, - } - configs.append(c) + ] + configs.extend(cs) @pytest.fixture(params=configs) @@ -102,15 +117,44 @@ async def test_default_kernels(jp_fetch, jp_base_url): @pytest.mark.timeout(TEST_TIMEOUT) -async def test_main_kernel_handler(jp_fetch, jp_base_url, jp_serverapp, pending_kernel_is_ready): +async def test_kernels_with_default_kernelspec(jp_fetch, jp_base_url, jp_kernelspecs): + r = await jp_fetch("api", "kernels", method="POST", allow_nonstandard_methods=True) + kernel = json.loads(r.body.decode()) + assert r.headers["location"] == url_path_join(jp_base_url, "/api/kernels/", kernel["id"]) + assert r.code == 201 + assert isinstance(kernel, dict) + + report_uri = url_path_join(jp_base_url, "/api/security/csp-report") + expected_csp = "; ".join( + ["frame-ancestors 'self'", "report-uri " + report_uri, "default-src 'none'"] + ) + assert r.headers["Content-Security-Policy"] == expected_csp + + # Verify that the default kernel was created using the default kernelspec + r2 = await jp_fetch("api", "kernelspecs", method="GET") + model = json.loads(r2.body.decode()) + assert isinstance(model, dict) + assert model["default"] == kernel["name"] + specs = model["kernelspecs"] + assert isinstance(specs, dict) + assert kernel["name"] in specs + + +@pytest.mark.timeout(TEST_TIMEOUT) +async def test_main_kernel_handler( + jp_server_config, jp_fetch, jp_base_url, jp_serverapp, pending_kernel_is_ready +): # Start the first kernel - r = await jp_fetch( - "api", "kernels", method="POST", body=json.dumps({"name": NATIVE_KERNEL_NAME}) + spec_name_prefix = jp_server_config.get("RenamingKernelSpecManager", {}).get( + "spec_name_prefix", "" ) + kernel_name = spec_name_prefix + NATIVE_KERNEL_NAME + r = await jp_fetch("api", "kernels", method="POST", body=json.dumps({"name": kernel_name})) kernel1 = json.loads(r.body.decode()) assert r.headers["location"] == url_path_join(jp_base_url, "/api/kernels/", kernel1["id"]) assert r.code == 201 assert isinstance(kernel1, dict) + assert kernel1["name"] == kernel_name report_uri = url_path_join(jp_base_url, "/api/security/csp-report") expected_csp = "; ".join( @@ -128,11 +172,10 @@ async def test_main_kernel_handler(jp_fetch, jp_base_url, jp_serverapp, pending_ await pending_kernel_is_ready(kernel1["id"]) # Start a second kernel - r = await jp_fetch( - "api", "kernels", method="POST", body=json.dumps({"name": NATIVE_KERNEL_NAME}) - ) + r = await jp_fetch("api", "kernels", method="POST", body=json.dumps({"name": kernel_name})) kernel2 = json.loads(r.body.decode()) assert isinstance(kernel2, dict) + assert kernel2["name"] == kernel_name await pending_kernel_is_ready(kernel1["id"]) # Get kernel list again @@ -176,7 +219,7 @@ async def test_main_kernel_handler(jp_fetch, jp_base_url, jp_serverapp, pending_ "api", "kernels", method="POST", - body=json.dumps({"name": NATIVE_KERNEL_NAME, "path": "/foo"}), + body=json.dumps({"name": kernel_name, "path": "/foo"}), ) kernel3 = json.loads(r.body.decode()) assert isinstance(kernel3, dict) @@ -184,11 +227,13 @@ async def test_main_kernel_handler(jp_fetch, jp_base_url, jp_serverapp, pending_ @pytest.mark.timeout(TEST_TIMEOUT) -async def test_kernel_handler(jp_fetch, jp_serverapp, pending_kernel_is_ready): +async def test_kernel_handler(jp_server_config, jp_fetch, jp_serverapp, pending_kernel_is_ready): # Create a kernel - r = await jp_fetch( - "api", "kernels", method="POST", body=json.dumps({"name": NATIVE_KERNEL_NAME}) + spec_name_prefix = jp_server_config.get("RenamingKernelSpecManager", {}).get( + "spec_name_prefix", "" ) + kernel_name = spec_name_prefix + NATIVE_KERNEL_NAME + r = await jp_fetch("api", "kernels", method="POST", body=json.dumps({"name": kernel_name})) kernel_id = json.loads(r.body.decode())["id"] r = await jp_fetch("api", "kernels", kernel_id, method="GET") kernel = json.loads(r.body.decode()) @@ -257,11 +302,13 @@ async def test_kernel_handler_startup_error_pending( @pytest.mark.timeout(TEST_TIMEOUT) -async def test_connection(jp_fetch, jp_ws_fetch, jp_http_port, jp_auth_header): +async def test_connection(jp_server_config, jp_fetch, jp_ws_fetch, jp_http_port, jp_auth_header): # Create kernel - r = await jp_fetch( - "api", "kernels", method="POST", body=json.dumps({"name": NATIVE_KERNEL_NAME}) + spec_name_prefix = jp_server_config.get("RenamingKernelSpecManager", {}).get( + "spec_name_prefix", "" ) + kernel_name = spec_name_prefix + NATIVE_KERNEL_NAME + r = await jp_fetch("api", "kernels", method="POST", body=json.dumps({"name": kernel_name})) kid = json.loads(r.body.decode())["id"] # Get kernel info diff --git a/tests/services/kernelspecs/test_api.py b/tests/services/kernelspecs/test_api.py index 105dd983fa..e23c990453 100644 --- a/tests/services/kernelspecs/test_api.py +++ b/tests/services/kernelspecs/test_api.py @@ -8,7 +8,26 @@ from ...utils import expected_http_error, some_resource -async def test_list_kernelspecs_bad(jp_fetch, jp_kernelspecs, jp_data_dir, jp_serverapp): +@pytest.fixture(params=[False, True]) +def jp_rename_kernels(request): + return request.param + + +@pytest.fixture +def jp_argv(jp_rename_kernels): + argv = [] + if jp_rename_kernels: + argv.extend( + [ + "--ServerApp.kernel_spec_manager_class=jupyter_server.services.kernelspecs.renaming.RenamingKernelSpecManager", + ] + ) + return argv + + +async def test_list_kernelspecs_bad( + jp_rename_kernels, jp_fetch, jp_kernelspecs, jp_data_dir, jp_serverapp +): app: ServerApp = jp_serverapp default = app.kernel_manager.default_kernel_name bad_kernel_dir = jp_data_dir.joinpath(jp_data_dir, "kernels", "bad2") @@ -25,7 +44,7 @@ async def test_list_kernelspecs_bad(jp_fetch, jp_kernelspecs, jp_data_dir, jp_se assert len(specs) > 2 -async def test_list_kernelspecs(jp_fetch, jp_kernelspecs, jp_serverapp): +async def test_list_kernelspecs(jp_rename_kernels, jp_fetch, jp_kernelspecs, jp_serverapp): app: ServerApp = jp_serverapp default = app.kernel_manager.default_kernel_name r = await jp_fetch("api", "kernelspecs", method="GET") @@ -37,21 +56,38 @@ async def test_list_kernelspecs(jp_fetch, jp_kernelspecs, jp_serverapp): assert len(specs) > 2 def is_sample_kernelspec(s): - return s["name"] == "sample" and s["spec"]["display_name"] == "Test kernel" + if jp_rename_kernels: + return ( + s["name"] == "local-sample" and s["spec"]["display_name"] == "Test kernel (Local)" + ) + else: + return s["name"] == "sample" and s["spec"]["display_name"] == "Test kernel" def is_default_kernelspec(s): return s["name"] == default assert any(is_sample_kernelspec(s) for s in specs.values()), specs - assert any(is_default_kernelspec(s) for s in specs.values()), specs + assert any( + is_default_kernelspec(s) for s in specs.values() + ), f"Default kernel name {default} not found in {specs}" -async def test_get_kernelspecs(jp_fetch, jp_kernelspecs): - r = await jp_fetch("api", "kernelspecs", "Sample", method="GET") +async def test_get_kernelspecs(jp_rename_kernels, jp_fetch, jp_kernelspecs): + kernel_name = "Sample" + if jp_rename_kernels: + kernel_name = "local-sample" + r = await jp_fetch("api", "kernelspecs", kernel_name, method="GET") model = json.loads(r.body.decode()) - assert model["name"].lower() == "sample" + if jp_rename_kernels: + assert model["name"].lower() == "local-sample" + else: + assert model["name"].lower() == "sample" + assert isinstance(model["spec"], dict) - assert model["spec"]["display_name"] == "Test kernel" + if jp_rename_kernels: + assert model["spec"]["display_name"] == "Test kernel (Local)" + else: + assert model["spec"]["display_name"] == "Test kernel" assert isinstance(model["resources"], dict) diff --git a/tests/test_gateway.py b/tests/test_gateway.py index 86fcf508ca..b7dd653664 100644 --- a/tests/test_gateway.py +++ b/tests/test_gateway.py @@ -222,6 +222,29 @@ def get_token( return f"{self.config_var_2}{self.config_var_1}" +@pytest.fixture(params=[False, True]) +def jp_rename_kernels(request): + return request.param + + +@pytest.fixture +def jp_argv(jp_rename_kernels): + argv = [ + "--gateway-url=" + mock_gateway_url, + ] + if jp_rename_kernels: + argv.append( + "--ServerApp.kernel_spec_manager_class=jupyter_server.gateway.managers.GatewayRenamingKernelSpecManager" + ) + return argv + + +@pytest.fixture +def jp_mocked_gateway(jp_server_config, jp_argv, jp_configurable_serverapp): + with mocked_gateway: + yield jp_configurable_serverapp(config=jp_server_config, argv=jp_argv) + + @pytest.fixture() def jp_server_config(): return Config( @@ -248,7 +271,7 @@ def init_gateway(monkeypatch): GatewayClient.clear_instance() -async def test_gateway_env_options(init_gateway, jp_serverapp): +async def test_gateway_env_options(init_gateway, jp_mocked_gateway, jp_serverapp): assert jp_serverapp.gateway_config.gateway_enabled is True assert jp_serverapp.gateway_config.url == mock_gateway_url assert jp_serverapp.gateway_config.http_user == mock_http_user @@ -264,18 +287,19 @@ async def test_gateway_env_options(init_gateway, jp_serverapp): assert GatewayClient.instance().KERNEL_LAUNCH_TIMEOUT == 43 -def test_gateway_cli_options(jp_configurable_serverapp, capsys): - argv = [ - "--gateway-url=" + mock_gateway_url, - "--GatewayClient.http_user=" + mock_http_user, - "--GatewayClient.connect_timeout=44.4", - "--GatewayClient.request_timeout=96.0", - "--GatewayClient.launch_timeout_pad=5.1", - "--GatewayClient.env_whitelist=FOO,BAR", - ] +def test_gateway_cli_options(jp_mocked_gateway, jp_argv, jp_configurable_serverapp, capsys): + jp_argv.extend( + [ + "--GatewayClient.http_user=" + mock_http_user, + "--GatewayClient.connect_timeout=44.4", + "--GatewayClient.request_timeout=96.0", + "--GatewayClient.launch_timeout_pad=5.1", + "--GatewayClient.env_whitelist=FOO,BAR", + ] + ) GatewayClient.clear_instance() - app = jp_configurable_serverapp(argv=argv) + app = jp_configurable_serverapp(argv=jp_argv) assert app.gateway_config.gateway_enabled is True assert app.gateway_config.url == mock_gateway_url @@ -299,15 +323,16 @@ def test_gateway_cli_options(jp_configurable_serverapp, capsys): @pytest.mark.parametrize("renewer_type", ["default", "custom"]) -def test_token_renewer_config(jp_server_config, jp_configurable_serverapp, renewer_type): - argv = ["--gateway-url=" + mock_gateway_url] +def test_token_renewer_config( + jp_server_config, jp_argv, jp_configurable_serverapp, renewer_type, jp_mocked_gateway +): if renewer_type == "custom": - argv.append( + jp_argv.append( "--GatewayClient.gateway_token_renewer_class=tests.test_gateway.CustomTestTokenRenewer" ) GatewayClient.clear_instance() - app = jp_configurable_serverapp(argv=argv) + app = jp_configurable_serverapp(argv=jp_argv) assert app.gateway_config.gateway_enabled is True assert app.gateway_config.url == mock_gateway_url @@ -404,51 +429,78 @@ def test_gateway_request_with_expiring_cookies( GatewayClient.clear_instance() -async def test_gateway_class_mappings(init_gateway, jp_serverapp): +async def test_gateway_class_mappings( + init_gateway, jp_mocked_gateway, jp_rename_kernels, jp_serverapp +): # Ensure appropriate class mappings are in place. assert jp_serverapp.kernel_manager_class.__name__ == "GatewayMappingKernelManager" assert jp_serverapp.session_manager_class.__name__ == "GatewaySessionManager" - assert jp_serverapp.kernel_spec_manager_class.__name__ == "GatewayKernelSpecManager" + if jp_rename_kernels: + assert jp_serverapp.kernel_spec_manager_class.__name__ == "GatewayRenamingKernelSpecManager" + else: + assert jp_serverapp.kernel_spec_manager_class.__name__ == "GatewayKernelSpecManager" -async def test_gateway_get_kernelspecs(init_gateway, jp_fetch, jp_serverapp): +async def test_gateway_get_kernelspecs( + init_gateway, jp_mocked_gateway, jp_rename_kernels, jp_fetch, jp_serverapp +): # Validate that kernelspecs come from gateway. - with mocked_gateway: - r = await jp_fetch("api", "kernelspecs", method="GET") - assert r.code == 200 - content = json.loads(r.body.decode("utf-8")) - kspecs = content.get("kernelspecs") - assert len(kspecs) == 2 - assert kspecs.get("kspec_bar").get("name") == "kspec_bar" - assert ( - kspecs.get("kspec_bar").get("resources")["logo-64x64"].startswith(jp_serverapp.base_url) - ) + r = await jp_fetch("api", "kernelspecs", method="GET") + assert r.code == 200 + content = json.loads(r.body.decode("utf-8")) + default_kernel_name = content.get("default") + kspecs = content.get("kernelspecs") + assert len(kspecs) == 2 + expected_kernel_name = "kspec_bar" + expected_default_kernel_name = "kspec_foo" + if jp_rename_kernels: + expected_kernel_name = "remote-kspec_bar" + expected_default_kernel_name = "remote-kspec_foo" + assert default_kernel_name == expected_default_kernel_name + assert kspecs.get(expected_kernel_name).get("name") == expected_kernel_name + assert ( + kspecs.get(expected_kernel_name) + .get("resources")["logo-64x64"] + .startswith(jp_serverapp.base_url) + ) -async def test_gateway_get_named_kernelspec(init_gateway, jp_fetch): +async def test_gateway_get_named_kernelspec( + init_gateway, jp_mocked_gateway, jp_rename_kernels, jp_fetch +): # Validate that a specific kernelspec can be retrieved from gateway (and an invalid spec can't) - with mocked_gateway: - r = await jp_fetch("api", "kernelspecs", "kspec_foo", method="GET") - assert r.code == 200 - kspec_foo = json.loads(r.body.decode("utf-8")) - assert kspec_foo.get("name") == "kspec_foo" + kernel_name = "kspec_foo" + if jp_rename_kernels: + kernel_name = "remote-kspec_foo" + r = await jp_fetch("api", "kernelspecs", kernel_name, method="GET") + assert r.code == 200 + kspec_foo = json.loads(r.body.decode("utf-8")) + assert kspec_foo.get("name") == kernel_name - r = await jp_fetch("kernelspecs", "kspec_foo", "logo-64x64.png", method="GET") - assert r.code == 200 - assert r.body == b"foo" - assert r.headers["content-type"] == "image/png" + r = await jp_fetch("kernelspecs", kernel_name, "logo-64x64.png", method="GET") + assert r.code == 200 + assert r.body == b"foo" + assert r.headers["content-type"] == "image/png" - with pytest.raises(tornado.httpclient.HTTPClientError) as e: - await jp_fetch("api", "kernelspecs", "no_such_spec", method="GET") - assert expected_http_error(e, 404) + with pytest.raises(tornado.httpclient.HTTPClientError) as e: + await jp_fetch("api", "kernelspecs", "no_such_spec", method="GET") + assert expected_http_error(e, 404) @pytest.mark.parametrize("cull_kernel", [False, True]) -async def test_gateway_session_lifecycle(init_gateway, jp_root_dir, jp_fetch, cull_kernel): +async def test_gateway_session_lifecycle( + init_gateway, jp_mocked_gateway, jp_rename_kernels, jp_root_dir, jp_fetch, cull_kernel +): # Validate session lifecycle functions; create and delete. # create - session_id, kernel_id = await create_session(jp_fetch, "kspec_foo") + kernel_name = "kspec_foo" + remote_kernel_name = kernel_name + if jp_rename_kernels: + kernel_name = "remote-kspec_foo" + session_id, kernel_id = await create_session( + jp_fetch, kernel_name, remote_kernel_name=remote_kernel_name + ) # ensure kernel still considered running assert await is_session_active(jp_fetch, session_id) is True @@ -460,7 +512,7 @@ async def test_gateway_session_lifecycle(init_gateway, jp_root_dir, jp_fetch, cu assert await is_session_active(jp_fetch, session_id) is True # restart - await restart_kernel(jp_fetch, kernel_id) + await restart_kernel(jp_fetch, kernel_id, kernel_name, remote_kernel_name) assert await is_session_active(jp_fetch, session_id) is True @@ -486,6 +538,8 @@ async def test_gateway_session_lifecycle(init_gateway, jp_root_dir, jp_fetch, cu @pytest.mark.parametrize("cull_kernel", [False, True]) async def test_gateway_kernel_lifecycle( init_gateway, + jp_mocked_gateway, + jp_rename_kernels, jp_configurable_serverapp, jp_read_emitted_events, jp_event_handler, @@ -499,7 +553,11 @@ async def test_gateway_kernel_lifecycle( app.event_logger.register_handler(jp_event_handler) # create - kernel_id = await create_kernel(jp_fetch, "kspec_bar") + kernel_name = "kspec_bar" + remote_kernel_name = kernel_name + if jp_rename_kernels: + kernel_name = "remote-kspec_bar" + kernel_id = await create_kernel(jp_fetch, kernel_name, remote_kernel_name=remote_kernel_name) output = jp_read_emitted_events()[0] assert "action" in output and output["action"] == "start" @@ -529,7 +587,7 @@ async def test_gateway_kernel_lifecycle( assert await is_kernel_running(jp_fetch, kernel_id) is True # restart - await restart_kernel(jp_fetch, kernel_id) + await restart_kernel(jp_fetch, kernel_id, kernel_name, remote_kernel_name) output = jp_read_emitted_events()[0] assert "action" in output and output["action"] == "restart" @@ -565,7 +623,9 @@ async def test_gateway_kernel_lifecycle( @pytest.mark.parametrize("missing_kernel", [True, False]) -async def test_gateway_shutdown(init_gateway, jp_serverapp, jp_fetch, missing_kernel): +async def test_gateway_shutdown( + init_gateway, jp_mocked_gateway, jp_serverapp, jp_fetch, missing_kernel +): # Validate server shutdown when multiple gateway kernels are present or # we've lost track of at least one (missing) kernel @@ -580,8 +640,7 @@ async def test_gateway_shutdown(init_gateway, jp_serverapp, jp_fetch, missing_ke if missing_kernel: running_kernels.pop(k1) # "terminate" kernel w/o our knowledge - with mocked_gateway: - await jp_serverapp.kernel_manager.shutdown_all() + await jp_serverapp.kernel_manager.shutdown_all() assert await is_kernel_running(jp_fetch, k1) is False assert await is_kernel_running(jp_fetch, k2) is False @@ -589,7 +648,7 @@ async def test_gateway_shutdown(init_gateway, jp_serverapp, jp_fetch, missing_ke @patch("websocket.create_connection", mock_websocket_create_connection(recv_side_effect=Exception)) async def test_kernel_client_response_router_notifies_channel_queue_when_finished( - init_gateway, jp_serverapp, jp_fetch + init_gateway, jp_mocked_gateway, jp_serverapp, jp_fetch ): # create kernel_id = await create_kernel(jp_fetch, "kspec_bar") @@ -666,129 +725,124 @@ async def test_channel_queue_get_msg_when_response_router_had_finished(): async def is_session_active(jp_fetch, session_id): """Issues request to get the set of running kernels""" - with mocked_gateway: - # Get list of running kernels - r = await jp_fetch("api", "sessions", method="GET") - assert r.code == 200 - sessions = json.loads(r.body.decode("utf-8")) - assert len(sessions) == len(running_kernels) # Use running_kernels as truth - return any(model.get("id") == session_id for model in sessions) + # Get list of running kernels + r = await jp_fetch("api", "sessions", method="GET") + assert r.code == 200 + sessions = json.loads(r.body.decode("utf-8")) + assert len(sessions) == len(running_kernels) # Use running_kernels as truth + return any(model.get("id") == session_id for model in sessions) -async def create_session(jp_fetch, kernel_name): +async def create_session(jp_fetch, kernel_name, remote_kernel_name=None): """Creates a session for a kernel. The session is created against the server which then uses the gateway for kernel management. """ - with mocked_gateway: - nb_path = "/testgw.ipynb" - body = json.dumps( - {"path": str(nb_path), "type": "notebook", "kernel": {"name": kernel_name}} - ) - - # add a KERNEL_ value to the current env and we'll ensure that that value exists in the mocked method - os.environ["KERNEL_KSPEC_NAME"] = kernel_name - - # Create the kernel... (also tests get_kernel) - r = await jp_fetch("api", "sessions", method="POST", body=body) - assert r.code == 201 - model = json.loads(r.body.decode("utf-8")) - assert model.get("path") == str(nb_path) - kernel_id = model.get("kernel").get("id") - # ensure its in the running_kernels and name matches. - running_kernel = running_kernels.get(kernel_id) - assert running_kernel is not None - assert kernel_id == running_kernel.get("id") - assert model.get("kernel").get("name") == running_kernel.get("name") - session_id = model.get("id") - - # restore env - os.environ.pop("KERNEL_KSPEC_NAME") - return session_id, kernel_id + nb_path = "/testgw.ipynb" + body = json.dumps({"path": str(nb_path), "type": "notebook", "kernel": {"name": kernel_name}}) + + # add a KERNEL_ value to the current env and we'll ensure that that value exists in the mocked method + remote_kernel_name = remote_kernel_name or kernel_name + os.environ["KERNEL_KSPEC_NAME"] = remote_kernel_name + + # Create the kernel... (also tests get_kernel) + r = await jp_fetch("api", "sessions", method="POST", body=body) + assert r.code == 201 + model = json.loads(r.body.decode("utf-8")) + assert model.get("path") == str(nb_path) + kernel_id = model.get("kernel").get("id") + # ensure its in the running_kernels and name matches. + running_kernel = running_kernels.get(kernel_id) + assert running_kernel is not None + assert kernel_id == running_kernel.get("id") + assert running_kernel.get("name") == remote_kernel_name + assert model.get("kernel").get("name") == kernel_name + session_id = model.get("id") + + # restore env + os.environ.pop("KERNEL_KSPEC_NAME") + return session_id, kernel_id async def delete_session(jp_fetch, session_id): """Deletes a session corresponding to the given session id.""" - with mocked_gateway: - # Delete the session (and kernel) - r = await jp_fetch("api", "sessions", session_id, method="DELETE") - assert r.code == 204 - assert r.reason == "No Content" + # Delete the session (and kernel) + r = await jp_fetch("api", "sessions", session_id, method="DELETE") + assert r.code == 204 + assert r.reason == "No Content" async def is_kernel_running(jp_fetch, kernel_id): """Issues request to get the set of running kernels""" - with mocked_gateway: - # Get list of running kernels - r = await jp_fetch("api", "kernels", method="GET") - assert r.code == 200 - kernels = json.loads(r.body.decode("utf-8")) - assert len(kernels) == len(running_kernels) - return any(model.get("id") == kernel_id for model in kernels) + # Get list of running kernels + r = await jp_fetch("api", "kernels", method="GET") + assert r.code == 200 + kernels = json.loads(r.body.decode("utf-8")) + assert len(kernels) == len(running_kernels) + return any(model.get("id") == kernel_id for model in kernels) -async def create_kernel(jp_fetch, kernel_name): +async def create_kernel(jp_fetch, kernel_name, remote_kernel_name=None): """Issues request to retart the given kernel""" - with mocked_gateway: - body = json.dumps({"name": kernel_name}) + body = json.dumps({"name": kernel_name}) - # add a KERNEL_ value to the current env and we'll ensure that that value exists in the mocked method - os.environ["KERNEL_KSPEC_NAME"] = kernel_name + # add a KERNEL_ value to the current env and we'll ensure that that value exists in the mocked method + remote_kernel_name = remote_kernel_name or kernel_name + os.environ["KERNEL_KSPEC_NAME"] = remote_kernel_name - r = await jp_fetch("api", "kernels", method="POST", body=body) - assert r.code == 201 - model = json.loads(r.body.decode("utf-8")) - kernel_id = model.get("id") - # ensure its in the running_kernels and name matches. - running_kernel = running_kernels.get(kernel_id) - assert running_kernel is not None - assert kernel_id == running_kernel.get("id") - assert model.get("name") == kernel_name + r = await jp_fetch("api", "kernels", method="POST", body=body) + assert r.code == 201 + model = json.loads(r.body.decode("utf-8")) + kernel_id = model.get("id") + # ensure its in the running_kernels and name matches. + running_kernel = running_kernels.get(kernel_id) + assert running_kernel is not None + assert kernel_id == running_kernel.get("id") + assert running_kernel.get("name") == remote_kernel_name + assert model.get("name") == kernel_name - # restore env - os.environ.pop("KERNEL_KSPEC_NAME") - return kernel_id + # restore env + os.environ.pop("KERNEL_KSPEC_NAME") + return kernel_id async def interrupt_kernel(jp_fetch, kernel_id): """Issues request to interrupt the given kernel""" - with mocked_gateway: - r = await jp_fetch( - "api", - "kernels", - kernel_id, - "interrupt", - method="POST", - allow_nonstandard_methods=True, - ) - assert r.code == 204 - assert r.reason == "No Content" + r = await jp_fetch( + "api", + "kernels", + kernel_id, + "interrupt", + method="POST", + allow_nonstandard_methods=True, + ) + assert r.code == 204 + assert r.reason == "No Content" -async def restart_kernel(jp_fetch, kernel_id): +async def restart_kernel(jp_fetch, kernel_id, kernel_name, remote_kernel_name): """Issues request to retart the given kernel""" - with mocked_gateway: - r = await jp_fetch( - "api", - "kernels", - kernel_id, - "restart", - method="POST", - allow_nonstandard_methods=True, - ) - assert r.code == 200 - model = json.loads(r.body.decode("utf-8")) - restarted_kernel_id = model.get("id") - # ensure its in the running_kernels and name matches. - running_kernel = running_kernels.get(restarted_kernel_id) - assert running_kernel is not None - assert restarted_kernel_id == running_kernel.get("id") - assert model.get("name") == running_kernel.get("name") + r = await jp_fetch( + "api", + "kernels", + kernel_id, + "restart", + method="POST", + allow_nonstandard_methods=True, + ) + assert r.code == 200 + model = json.loads(r.body.decode("utf-8")) + restarted_kernel_id = model.get("id") + # ensure its in the running_kernels and name matches. + running_kernel = running_kernels.get(restarted_kernel_id) + assert running_kernel is not None + assert restarted_kernel_id == running_kernel.get("id") + assert running_kernel.get("name") == remote_kernel_name + assert model.get("name") == kernel_name async def delete_kernel(jp_fetch, kernel_id): """Deletes kernel corresponding to the given kernel id.""" - with mocked_gateway: - # Delete the session (and kernel) - r = await jp_fetch("api", "kernels", kernel_id, method="DELETE") - assert r.code == 204 - assert r.reason == "No Content" + # Delete the session (and kernel) + r = await jp_fetch("api", "kernels", kernel_id, method="DELETE") + assert r.code == 204 + assert r.reason == "No Content"