From aaf8d5b9e212c96949371339457ecf0c792fdf98 Mon Sep 17 00:00:00 2001 From: Zach Sailer Date: Thu, 13 Nov 2025 12:21:46 -0800 Subject: [PATCH 1/3] Add v3 kernel API with shared kernel clients and enhanced monitoring Introduces a next-generation kernel API (v3) that can be enabled via the --kernels-v3 flag or by setting kernels_api_version=3 in config. Key improvements: - Shared kernel client per kernel: Single client instance shared across all websocket connections, reducing resource usage and improving consistency - Pre-created clients with automatic lifecycle management: Clients connect on kernel start and disconnect on shutdown/restart - Enhanced message routing: Channel and cell ID encoding in message IDs enables precise message delivery to originating cells - Improved kernel monitoring: Better execution state tracking and heartbeat monitoring for both local and gateway kernels - Backward compatible: Defaults to v2 API; v3 is opt-in The v3 classes are swapped in automatically when enabled, requiring no manual configuration of individual components. --- jupyter_server/gateway/v3/__init__.py | 0 jupyter_server/gateway/v3/managers.py | 214 ++++++ jupyter_server/serverapp.py | 45 +- .../services/kernels/v3/__init__.py | 0 jupyter_server/services/kernels/v3/client.py | 702 ++++++++++++++++++ .../kernels/v3/connection/__init__.py | 0 .../v3/connection/client_connection.py | 195 +++++ .../services/kernels/v3/kernelmanager.py | 120 +++ .../services/kernels/v3/message_utils.py | 330 ++++++++ jupyter_server/services/kernels/v3/states.py | 22 + 10 files changed, 1626 insertions(+), 2 deletions(-) create mode 100644 jupyter_server/gateway/v3/__init__.py create mode 100644 jupyter_server/gateway/v3/managers.py create mode 100644 jupyter_server/services/kernels/v3/__init__.py create mode 100644 jupyter_server/services/kernels/v3/client.py create mode 100644 jupyter_server/services/kernels/v3/connection/__init__.py create mode 100644 jupyter_server/services/kernels/v3/connection/client_connection.py create mode 100644 jupyter_server/services/kernels/v3/kernelmanager.py create mode 100644 jupyter_server/services/kernels/v3/message_utils.py create mode 100644 jupyter_server/services/kernels/v3/states.py diff --git a/jupyter_server/gateway/v3/__init__.py b/jupyter_server/gateway/v3/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/jupyter_server/gateway/v3/managers.py b/jupyter_server/gateway/v3/managers.py new file mode 100644 index 0000000000..35b0c74a77 --- /dev/null +++ b/jupyter_server/gateway/v3/managers.py @@ -0,0 +1,214 @@ +"""Gateway kernel manager that integrates with our kernel monitoring system.""" + +import asyncio +from jupyter_server.gateway.managers import GatewayMappingKernelManager +from jupyter_server.gateway.managers import GatewayKernelManager as _GatewayKernelManager +from jupyter_server.gateway.managers import GatewayKernelClient as _GatewayKernelClient +from traitlets import default, Instance, Type + +from jupyter_server.services.kernels.v3.client import JupyterServerKernelClientMixin + + +class GatewayKernelClient(JupyterServerKernelClientMixin, _GatewayKernelClient): + """ + Gateway kernel client that combines our monitoring capabilities with gateway support. + + This client inherits from: + - JupyterServerKernelClientMixin: Provides kernel monitoring capabilities, message caching, + and execution state tracking that integrates with our kernel monitor system + - GatewayKernelClient: Provides gateway communication capabilities for remote kernels + + The combination allows remote gateway kernels to be monitored with the same level of + detail as local kernels, including heartbeat monitoring, execution state tracking, + and kernel lifecycle management. + """ + + async def _test_kernel_communication(self, timeout: float = 10.0) -> bool: + """Skip kernel_info test for gateway kernels. + + Gateway kernels handle communication differently and the kernel_info + test can hang due to message routing differences. + + Returns: + bool: Always returns True for gateway kernels + """ + return True + + def _send_message(self, channel_name: str, msg: list[bytes]): + # Send to gateway channel + try: + channel = getattr(self, f"{channel_name}_channel", None) + if channel and hasattr(channel, 'send'): + # Convert raw message to gateway format + header = self.session.unpack(msg[0]) + parent_header = self.session.unpack(msg[1]) + metadata = self.session.unpack(msg[2]) + content = self.session.unpack(msg[3]) + + full_msg = { + 'header': header, + 'parent_header': parent_header, + 'metadata': metadata, + 'content': content, + 'buffers': msg[4:] if len(msg) > 4 else [], + 'channel': channel_name, + 'msg_id': header.get('msg_id'), + 'msg_type': header.get('msg_type') + } + + channel.send(full_msg) + except Exception as e: + self.log.warn(f"Error handling incoming message on gateway: {e}") + + async def _monitor_channel_messages(self, channel_name: str, channel): + """Monitor a gateway channel for incoming messages.""" + try: + while channel.is_alive(): + try: + # Get message from gateway channel queue + message = await channel.get_msg() + + # Update execution state from status messages + # Gateway messages are already deserialized dicts + self._update_execution_state_from_status( + channel_name, + message, + parent_msg_id=message.get("parent_header", {}).get("msg_id"), + execution_state=message.get("content", {}).get("execution_state") + ) + + # Serialize message to standard format for listeners + # Gateway messages are dicts, convert to list[bytes] format + # session.serialize() returns: [b'', signature, header, parent_header, metadata, content, buffers...] + serialized = self.session.serialize(message) + + # Skip delimiter (index 0) and signature (index 1) to get [header, parent_header, metadata, content, ...] + if serialized and len(serialized) >= 6: # Need delimiter + signature + 4 message parts + msg_list = serialized[2:] + else: + self.log.warning(f"Gateway message too short: {len(serialized) if serialized else 0} parts") + continue + + # Route to listeners + await self._route_to_listeners(channel_name, msg_list) + + except asyncio.TimeoutError: + # No message available, continue loop + continue + except Exception as e: + self.log.debug(f"Error processing gateway message in {channel_name}: {e}") + continue + + await asyncio.sleep(0.01) + + except asyncio.CancelledError: + pass + except Exception as e: + self.log.error(f"Gateway channel monitoring failed for {channel_name}: {e}") + + +class GatewayKernelManager(_GatewayKernelManager): + """ + Gateway kernel manager that uses our enhanced gateway kernel client. + + This manager inherits from jupyter_server's GatewayKernelManager and configures it + to use our GatewayKernelClient, which provides: + + - Gateway communication capabilities for remote kernels + - Kernel monitoring integration (heartbeat, execution state tracking) + - Message ID encoding with channel and src_id using simple string operations + - Full compatibility with our kernel monitor extension + - Pre-created kernel client instance stored as a property + - Automatic client connection/disconnection on kernel start/shutdown + + When jupyter_server is configured to use a gateway, this manager ensures that + remote kernels receive the same level of monitoring as local kernels. + """ + # Configure the manager to use our enhanced gateway client + client_class = GatewayKernelClient + client_factory = GatewayKernelClient + + kernel_client = Instance( + 'jupyter_client.client.KernelClient', + allow_none=True, + help="""Pre-created kernel client instance. Created on initialization.""" + ) + + def __init__(self, **kwargs): + """Initialize the kernel manager and create a kernel client instance.""" + super().__init__(**kwargs) + + # Create a kernel client instance immediately + self.kernel_client = self.client(session=self.session) + + async def post_start_kernel(self, **kwargs): + """After kernel starts, connect the kernel client. + + This method is called after the kernel has been successfully started. + It loads the latest connection info (with ports set by provisioner) + and connects the kernel client to the kernel. + + Note: If you override this method, make sure to call super().post_start_kernel(**kwargs) + to ensure the kernel client connects properly. + """ + await super().post_start_kernel(**kwargs) + + try: + # Load latest connection info from kernel manager + # The provisioner has now set the real ports + self.kernel_client.load_connection_info(self.get_connection_info(session=True)) + + # Connect the kernel client + success = await self.kernel_client.connect() + + if not success: + raise RuntimeError(f"Failed to connect kernel client for kernel {self.kernel_id}") + + self.log.info(f"Successfully connected kernel client for kernel {self.kernel_id}") + + except Exception as e: + self.log.error(f"Failed to connect kernel client: {e}") + # Re-raise to fail the kernel start + raise + + async def cleanup_resources(self, restart=False): + """Cleanup resources, disconnecting the kernel client if not restarting. + + Parameters + ---------- + restart : bool + If True, the kernel is being restarted and we should keep the client + connected but clear its state. If False, fully disconnect. + """ + if self.kernel_client: + if restart: + # On restart, clear client state but keep connection + # The connection will be refreshed in post_start_kernel after restart + self.log.debug(f"Clearing kernel client state for restart of kernel {self.kernel_id}") + self.kernel_client.last_shell_status_time = None + self.kernel_client.last_control_status_time = None + # Disconnect before restart - will reconnect after + await self.kernel_client.stop_listening() + self.kernel_client.stop_channels() + else: + # On shutdown, fully disconnect the client + self.log.debug(f"Disconnecting kernel client for kernel {self.kernel_id}") + await self.kernel_client.stop_listening() + self.kernel_client.stop_channels() + + await super().cleanup_resources(restart=restart) + + +class GatewayMultiKernelManager(GatewayMappingKernelManager): + """Custom kernel manager that uses enhanced monitoring kernel manager with v3 API.""" + + @default("kernel_manager_class") + def _default_kernel_manager_class(self): + return "jupyter_server.gateway.v3.managers.GatewayKernelManager" + + def start_watching_activity(self, kernel_id): + pass + + def stop_buffering(self, kernel_id): + pass + diff --git a/jupyter_server/serverapp.py b/jupyter_server/serverapp.py index 1afbef4d0d..eee6063122 100644 --- a/jupyter_server/serverapp.py +++ b/jupyter_server/serverapp.py @@ -110,6 +110,9 @@ GatewayMappingKernelManager, GatewaySessionManager, ) +from jupyter_server.gateway.v3.managers import ( + GatewayMultiKernelManager, +) from jupyter_server.log import log_request from jupyter_server.prometheus.metrics import ( ACTIVE_DURATION, @@ -131,6 +134,12 @@ AsyncMappingKernelManager, MappingKernelManager, ) +from jupyter_server.services.kernels.v3.connection.client_connection import ( + KernelClientWebsocketConnection, +) +from jupyter_server.services.kernels.v3.kernelmanager import ( + AsyncMappingKernelManager as V3AsyncMappingKernelManager, +) from jupyter_server.services.sessions.sessionmanager import SessionManager from jupyter_server.utils import ( JupyterServerAuthWarning, @@ -829,6 +838,13 @@ def start(self): extensions. """, ) +flags["kernels-v3"] = ( + {"ServerApp": {"kernels_api_version": 3}}, + _i18n( + "Enable the next-generation kernel API (v3) with shared kernel clients, " + "improved message routing, and enhanced kernel monitoring." + ), +) # Add notebook manager flags @@ -901,6 +917,10 @@ class ServerApp(JupyterApp): Authorizer, EventLogger, ZMQChannelsWebsocketConnection, + # V3 Kernel API classes + V3AsyncMappingKernelManager, + KernelClientWebsocketConnection, + GatewayMultiKernelManager, ] subcommands: dict[str, t.Any] = { @@ -1612,6 +1632,17 @@ def template_file_path(self) -> list[str]: help=_i18n("The content manager class to use."), ) + kernels_api_version = Integer( + 2, + config=True, + help=_i18n( + "Kernel API version to use. " + "Version 2 (default): Standard kernel API with direct ZMQ connections. " + "Version 3: Next-generation API with shared kernel clients, " + "improved message routing, and enhanced kernel monitoring." + ), + ) + kernel_manager_class = Type( klass=MappingKernelManager, config=True, @@ -1620,7 +1651,13 @@ def template_file_path(self) -> list[str]: @default("kernel_manager_class") def _default_kernel_manager_class(self) -> t.Union[str, type[AsyncMappingKernelManager]]: - if self.gateway_config.gateway_enabled: + if self.kernels_api_version == 3: + gateway_enabled = getattr(self, 'gateway_config', None) and getattr(self.gateway_config, 'gateway_enabled', False) + if gateway_enabled: + return "jupyter_server.gateway.v3.managers.GatewayMultiKernelManager" + return "jupyter_server.services.kernels.v3.kernelmanager.AsyncMappingKernelManager" + gateway_enabled = getattr(self, 'gateway_config', None) and getattr(self.gateway_config, 'gateway_enabled', False) + if gateway_enabled: return "jupyter_server.gateway.managers.GatewayMappingKernelManager" return AsyncMappingKernelManager @@ -1645,7 +1682,11 @@ def _default_session_manager_class(self) -> t.Union[str, type[SessionManager]]: def _default_kernel_websocket_connection_class( self, ) -> t.Union[str, type[ZMQChannelsWebsocketConnection]]: - if self.gateway_config.gateway_enabled: + if self.kernels_api_version == 3: + # V3 uses shared kernel client connection for both local and gateway + return "jupyter_server.services.kernels.v3.connection.client_connection.KernelClientWebsocketConnection" + gateway_enabled = getattr(self, 'gateway_config', None) and getattr(self.gateway_config, 'gateway_enabled', False) + if gateway_enabled: return "jupyter_server.gateway.connections.GatewayWebSocketConnection" return ZMQChannelsWebsocketConnection diff --git a/jupyter_server/services/kernels/v3/__init__.py b/jupyter_server/services/kernels/v3/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/jupyter_server/services/kernels/v3/client.py b/jupyter_server/services/kernels/v3/client.py new file mode 100644 index 0000000000..fbd11342d7 --- /dev/null +++ b/jupyter_server/services/kernels/v3/client.py @@ -0,0 +1,702 @@ +import asyncio +import time +import typing as t +from datetime import datetime, timezone + +from traitlets import HasTraits, Type +from jupyter_client.asynchronous.client import AsyncKernelClient +from jupyter_client.channels import AsyncZMQSocketChannel +from jupyter_client.channelsabc import ChannelABC +from .states import ExecutionStates +from .message_utils import parse_msg_id, encode_channel_in_message_dict + + +class NamedAsyncZMQSocketChannel(AsyncZMQSocketChannel): + """Prepends the channel name to all message IDs to this socket.""" + channel_name = "unknown" + + def send(self, msg): + """Send a message with automatic channel encoding.""" + msg = encode_channel_in_message_dict(msg, self.channel_name) + return super().send(msg) + + +class ShellChannel(NamedAsyncZMQSocketChannel): + """Shell channel that automatically encodes 'shell' in outgoing msg_ids.""" + channel_name = "shell" + + +class ControlChannel(AsyncZMQSocketChannel): + """Control channel that automatically encodes 'control' in outgoing msg_ids.""" + channel_name = "control" + + +class StdinChannel(AsyncZMQSocketChannel): + """Stdin channel that automatically encodes 'stdin' in outgoing msg_ids.""" + channel_name = "stdin" + + +class JupyterServerKernelClientMixin(HasTraits): + """Mixin that enhances AsyncKernelClient with listener API, message queuing, and channel encoding. + + Key Features: + + 1. **Listener API**: Register multiple listeners to receive kernel messages without blocking. + - `add_listener()`: Add a callback function to receive messages from the kernel + - `remove_listener()`: Remove a registered listener + - Supports message filtering by type and channel + - Multiple listeners can be registered (e.g., multiple WebSocket connections) + + 2. **Message Queuing**: Queue messages that arrive before the kernel client is ready. + - Messages from WebSockets are queued during kernel startup + - Queued messages are processed once the kernel connection is established + - Prevents message loss during the connection handshake + - Configurable queue size to prevent memory issues + + 3. **Channel Encoding**: Automatically encode channel names in all outgoing message IDs. + - All messages sent through shell, control, or stdin channels get the channel name prepended + - Format: `{channel}:{base_msg_id}` (e.g., "shell:abc123_456_0") + - Makes it easy to identify which channel status messages came from + - Enables proper execution state tracking (shell vs control channel responses) + - Uses custom channel classes (ShellChannel, ControlChannel, StdinChannel) + + This mixin is designed to work with Jupyter Server's multi-websocket architecture where + a single kernel client is shared across multiple WebSocket connections. + """ + + # Track kernel execution state (simplified - just a string) + execution_state: str = ExecutionStates.UNKNOWN.value + + # Track kernel activity + last_activity: datetime = None + + # Track last status message time per channel (shell and control) + last_shell_status_time: datetime = None + last_control_status_time: datetime = None + + # Connection test configuration + connection_test_timeout: float = 120.0 # Total timeout for connection test in seconds + connection_test_check_interval: float = 0.1 # How often to check for messages in seconds + connection_test_retry_interval: float = 10.0 # How often to retry kernel_info requests in seconds + + # Override channel classes to use our custom ones with automatic encoding + shell_channel_class = Type(ShellChannel) + control_channel_class = Type(ControlChannel) + stdin_channel_class = Type(StdinChannel) + + # Set of listener functions - don't use Traitlets Set, just plain Python set + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._listeners = {} # Maps callback -> filter config + self._listening = False + + # Connection state tracking + self._connecting = False + self._connection_ready = False + self._connection_ready_event = asyncio.Event() + + # Message queue for messages received before connection is ready + self._queued_messages = [] + self._max_queue_size = 1000 # Prevent memory issues + + # Note: The session is already EncodedMsgIdSession, created by the KernelManager + # No need to replace it here + + def add_listener( + self, + callback: t.Callable[[str, list[bytes]], None], + msg_types: t.Optional[t.List[t.Tuple[str, str]]] = None, + exclude_msg_types: t.Optional[t.List[t.Tuple[str, str]]] = None + ): + """Add a listener to be called when messages are received. + + Args: + callback: Function that takes (channel_name, msg_bytes) as arguments + msg_types: Optional list of (msg_type, channel) tuples to include. + If provided, only messages matching these filters will be sent to the listener. + Example: [("status", "iopub"), ("execute_reply", "shell")] + exclude_msg_types: Optional list of (msg_type, channel) tuples to exclude. + If provided, messages matching these filters will NOT be sent to the listener. + Example: [("status", "iopub")] + + Note: + - If both msg_types and exclude_msg_types are provided, msg_types takes precedence + - If neither is provided, all messages are sent (default behavior) + """ + if msg_types is not None and exclude_msg_types is not None: + raise ValueError("Cannot specify both msg_types and exclude_msg_types") + + # Store the listener with its filter configuration + self._listeners[callback] = { + 'msg_types': set(msg_types) if msg_types else None, + 'exclude_msg_types': set(exclude_msg_types) if exclude_msg_types else None + } + + def remove_listener(self, callback: t.Callable[[str, list[bytes]], None]): + """Remove a listener.""" + self._listeners.pop(callback, None) + + def mark_connection_ready(self): + """Mark the connection as ready and process queued messages.""" + if not self._connection_ready: + self._connecting = False + self._connection_ready = True + self._connection_ready_event.set() + + # Process queued messages + asyncio.create_task(self._process_queued_messages()) + + async def wait_for_connection_ready(self, timeout: float = 30.0) -> bool: + """Wait for the connection to be ready.""" + try: + await asyncio.wait_for(self._connection_ready_event.wait(), timeout=timeout) + return True + except asyncio.TimeoutError: + return False + + async def _process_queued_messages(self): + """Process all messages that were queued during startup.""" + self.log.info(f"Processing {len(self._queued_messages)} queued messages") + + queued_messages = self._queued_messages.copy() + self._queued_messages.clear() + + for channel_name, msg in queued_messages: + try: + # Send queued messages to the kernel (these are incoming from websockets) + self._send_message(channel_name, msg) + except Exception as e: + self.log.error(f"Error processing queued message: {e}") + + def _queue_message_if_not_ready(self, channel_name: str, msg: list[bytes]) -> bool: + """Queue a message if connection is not ready. Returns True if queued.""" + if not self._connection_ready: + if len(self._queued_messages) < self._max_queue_size: + self._queued_messages.append((channel_name, msg)) + return True + else: + # Queue is full, drop oldest message + self._queued_messages.pop(0) + self._queued_messages.append((channel_name, msg)) + self.log.warning("Message queue full, dropping oldest message") + return True + return False + + def _send_message(self, channel_name: str, msg: list[bytes]): + # Route message to the appropriate kernel channel + try: + channel = getattr(self, f"{channel_name}_channel", None) + channel.session.send_raw(channel.socket, msg) + + except Exception as e: + self.log.warn("Error handling incoming message.") + + def handle_incoming_message(self, channel_name: str, msg: list[bytes]): + """Handle incoming kernel messages and encode channel in msg_id. + + This method processes incoming kernel messages from WebSocket clients. + It prepends the channel name to the msg_id for internal routing. + + Args: + channel_name: The channel the message came from ('shell', 'iopub', etc.) + msg: The raw message bytes (already deserialized from websocket format) + """ + # Validate message has content + if not msg or len(msg) == 0: + return + + # Prepend channel to msg_id for internal routing + try: + header = self.session.unpack(msg[0]) + msg_id = header["msg_id"] + + # Check if msg_id already has channel encoded + if not msg_id.startswith(f"{channel_name}:"): + # Prepend channel + header["msg_id"] = f"{channel_name}:{msg_id}" + msg[0] = self.session.pack(header) + + except Exception as e: + self.log.debug(f"Error encoding channel in incoming message ID: {e}") + + # If connection is not ready, queue the message + if self._queue_message_if_not_ready(channel_name, msg): + return + + self._send_message(channel_name, msg) + + + def handle_outgoing_message(self, channel_name: str, msg: list[bytes]): + """Public API for manufacturing messages to send to kernel client listeners. + + This allows external code to simulate kernel messages and send them to all + registered listeners, useful for testing and message injection. + + Args: + channel_name: The channel the message came from ('shell', 'iopub', etc.) + msg: The raw message bytes + """ + # Same as handle_incoming_message - route to all listeners + asyncio.create_task(self._route_to_listeners(channel_name, msg)) + + async def _route_to_listeners(self, channel_name: str, msg: list[bytes]): + """Route message to all registered listeners based on their filters.""" + if not self._listeners: + return + + # Validate message format before routing + if not msg or len(msg) < 4: + self.log.warning(f"Cannot route malformed message on {channel_name}: {len(msg) if msg else 0} parts (expected at least 4)") + return + + # Extract message type for filtering + msg_type = None + try: + header = self.session.unpack(msg[0]) if msg and len(msg) > 0 else {} + msg_type = header.get('msg_type', 'unknown') + except Exception as e: + self.log.debug(f"Error extracting message type: {e}") + msg_type = 'unknown' + + # Create tasks for listeners that match the filter + tasks = [] + for listener, filter_config in self._listeners.items(): + if self._should_route_to_listener(msg_type, channel_name, filter_config): + task = asyncio.create_task(self._call_listener(listener, channel_name, msg)) + tasks.append(task) + + # Wait for all listeners to complete + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + + def _should_route_to_listener(self, msg_type: str, channel_name: str, filter_config: dict) -> bool: + """Determine if a message should be routed to a listener based on its filter configuration. + + Args: + msg_type: The message type (e.g., "status", "execute_reply") + channel_name: The channel name (e.g., "iopub", "shell") + filter_config: Dictionary with 'msg_types' and 'exclude_msg_types' keys + + Returns: + bool: True if the message should be routed to the listener, False otherwise + """ + msg_types = filter_config.get('msg_types') + exclude_msg_types = filter_config.get('exclude_msg_types') + + # If msg_types is specified (inclusion filter) + if msg_types is not None: + return (msg_type, channel_name) in msg_types + + # If exclude_msg_types is specified (exclusion filter) + if exclude_msg_types is not None: + return (msg_type, channel_name) not in exclude_msg_types + + # No filter specified - route all messages + return True + + async def _call_listener(self, listener: t.Callable, channel_name: str, msg: list[bytes]): + """Call a single listener, ensuring it's async and handling errors.""" + try: + result = listener(channel_name, msg) + if asyncio.iscoroutine(result): + await result + except Exception as e: + self.log.error(f"Error in listener {listener}: {e}") + + def _update_execution_state_from_status(self, channel_name: str, msg_dict: dict, parent_msg_id: str = None, execution_state: str = None): + """Update execution state from a status message if it originated from shell channel. + + This method checks if a status message on the iopub channel originated from a shell + channel request before updating the execution state. This prevents control channel + status messages from affecting execution state tracking. + + Additionally tracks the last time we received status messages from shell and control + channels for connection monitoring purposes. + + Args: + channel_name: The channel the message came from (should be 'iopub') + msg_dict: The deserialized message dictionary + parent_msg_id: Optional parent message ID (extracted if not provided) + execution_state: Optional execution state (extracted if not provided) + """ + if channel_name != "iopub" or msg_dict.get("msg_type") != "status": + return + + try: + # Extract parent_msg_id if not provided + if parent_msg_id is None: + parent_header = msg_dict.get("parent_header", {}) + if isinstance(parent_header, bytes): + parent_header = self.session.unpack(parent_header) + parent_msg_id = parent_header.get("msg_id") + + # Parse parent_msg_id to extract channel + if parent_msg_id: + try: + parent_channel, _, _ = parse_msg_id(parent_msg_id) + except Exception as e: + self.log.debug(f"Error parsing parent msg_id '{parent_msg_id}': {e}") + parent_channel = None + + # Track last status message time for both shell and control channels + current_time = datetime.now(timezone.utc) + if parent_channel == "shell": + self.last_shell_status_time = current_time + self.last_activity = current_time + elif parent_channel == "control": + self.last_control_status_time = current_time + + # Only update execution state if message came from shell channel + if parent_channel == "shell": + # Extract execution_state if not provided + if execution_state is None: + content = msg_dict.get("content", {}) + if isinstance(content, bytes): + content = self.session.unpack(content) + execution_state = content.get("execution_state") + + if execution_state: + old_state = self.execution_state + self.execution_state = execution_state + self.log.debug(f"Execution state: {old_state} -> {execution_state}") + elif parent_channel is None: + # Log when we can't determine parent channel + if execution_state is None: + content = msg_dict.get("content", {}) + if isinstance(content, bytes): + content = self.session.unpack(content) + execution_state = content.get("execution_state") + self.log.debug(f"Ignoring status message - cannot parse parent channel (state would be: {execution_state})") + except Exception as e: + self.log.debug(f"Error updating execution state from status message: {e}") + + async def broadcast_state(self): + """Broadcast current kernel execution state to all listeners. + + This method creates and sends a status message to all kernel listeners + (typically WebSocket connections) to inform them of the current kernel + execution state. + + The status message is manufactured using the session's message format + and sent through the normal listener routing mechanism. + + Note: Only broadcasts if execution_state is a valid kernel protocol state. + Skips broadcasting if state is "unknown" (not part of kernel protocol). + """ + try: + # Don't broadcast "unknown" state - it's not part of the kernel protocol + # Valid states are: starting, idle, busy, restarting, dead + if self.execution_state == ExecutionStates.UNKNOWN.value: + self.log.debug("Skipping broadcast_state - execution state is unknown") + return + + # Create status message + msg_dict = self.session.msg( + "status", + content={"execution_state": self.execution_state} + ) + + # Serialize the message + # session.serialize() returns: + # [b'', signature, header, parent_header, metadata, content, buffers...] + serialized = self.session.serialize(msg_dict) + + # Skip delimiter (index 0) and signature (index 1) to get message parts + # Result: [header, parent_header, metadata, content, buffers...] + if len(serialized) < 6: # Need delimiter + signature + 4 message parts minimum + self.log.warning(f"broadcast_state: serialized message too short: {len(serialized)} parts") + return + + msg_parts = serialized[2:] # Skip delimiter and signature + + # Send to listeners + self.handle_outgoing_message("iopub", msg_parts) + + except Exception as e: + self.log.warning(f"Failed to broadcast state: {e}") + + async def start_listening(self): + """Start listening for messages and monitoring channels.""" + # Start background tasks to monitor channels for messages + self._monitoring_tasks = [] + self._listening = True + + # Monitor each channel for incoming messages + for channel_name in ['iopub', 'shell', 'stdin', 'control']: + channel = getattr(self, f"{channel_name}_channel", None) + if channel and channel.is_alive(): + task = asyncio.create_task(self._monitor_channel_messages(channel_name, channel)) + self._monitoring_tasks.append(task) + + self.log.info(f"Started listening with {len(self._listeners)} listeners") + + async def stop_listening(self): + """Stop listening for messages.""" + # Stop monitoring tasks + if hasattr(self, '_monitoring_tasks'): + for task in self._monitoring_tasks: + task.cancel() + self._monitoring_tasks = [] + + self.log.info(f"Stopped listening") + + async def _monitor_channel_messages(self, channel_name: str, channel: ChannelABC): + """Monitor a channel for incoming messages and route them to listeners.""" + try: + while channel.is_alive(): + try: + # Check if there's a message ready (non-blocking) + has_message = await channel.msg_ready() + if has_message: + msg = await channel.socket.recv_multipart() + + # For deserialization and state tracking, use feed_identities to strip routing frames + idents, msg_list = channel.session.feed_identities(msg) + + # Deserialize WITHOUT content for performance (content=False) + msg_dict = channel.session.deserialize(msg_list, content=False) + + # Update execution state from status messages + self._update_execution_state_from_status(channel_name, msg_dict) + + # Route to listeners with msg_list + # After feed_identities, msg_list has format (delimiter already removed): + # [signature, header, parent_header, metadata, content, ...buffers] + # Skip signature (index 0) to get: [header, parent_header, metadata, content, ...buffers] + if msg_list and len(msg_list) >= 5: + await self._route_to_listeners(channel_name, msg_list[1:]) + else: + self.log.warning(f"Received malformed message on {channel_name}: {len(msg_list) if msg_list else 0} parts") + + except Exception as e: + # Log the error instead of silently ignoring it + self.log.debug(f"Error processing message in {channel_name}: {e}") + continue # Continue with next message instead of breaking + + # Small sleep to avoid busy waiting + await asyncio.sleep(0.01) + + except asyncio.CancelledError: + pass + except Exception as e: + self.log.error(f"Channel monitoring failed for {channel_name}: {e}") + + async def _test_kernel_communication(self, timeout: float = None) -> bool: + """Test kernel communication by monitoring execution state and sending kernel_info requests. + + This method uses a robust heuristic to determine if the kernel is connected: + 1. Checks if execution state is 'idle' (indicates shell channel is responding) + 2. Sends kernel_info requests to both shell and control channels in parallel + 3. Monitors for status message responses from either channel + 4. Retries periodically if no response is received + 5. Considers kernel connected if we receive any status messages, even if state is 'busy' + + Args: + timeout: Total timeout for connection test in seconds (uses connection_test_timeout if not provided) + + Returns: + bool: True if communication test successful, False otherwise + """ + if timeout is None: + timeout = self.connection_test_timeout + + start_time = time.time() + connection_attempt_time = datetime.now(timezone.utc) + + self.log.info("Starting kernel communication test") + + # Give the kernel a moment to be ready to receive messages + # Heartbeat beating doesn't guarantee the kernel is ready for requests + await asyncio.sleep(0.5) + + # Send initial kernel_info requests immediately + try: + await asyncio.gather( + self._send_kernel_info_shell(), + self._send_kernel_info_control(), + return_exceptions=True + ) + except Exception as e: + self.log.debug(f"Error sending initial kernel_info requests: {e}") + + last_kernel_info_time = time.time() + + while time.time() - start_time < timeout: + elapsed = time.time() - start_time + + # Check if execution state is idle (shell channel responding and kernel ready) + if self.execution_state == ExecutionStates.IDLE.value: + self.log.info("Kernel communication test succeeded: execution state is idle") + return True + + # Check if we've received any status messages since connection attempt + # This indicates the kernel is connected, even if busy executing something + if self.last_shell_status_time and self.last_shell_status_time > connection_attempt_time: + self.log.info("Kernel communication test succeeded: received shell status message") + return True + + if self.last_control_status_time and self.last_control_status_time > connection_attempt_time: + self.log.info("Kernel communication test succeeded: received control status message") + return True + + # Send kernel_info requests at regular intervals + time_since_last_request = time.time() - last_kernel_info_time + if time_since_last_request >= self.connection_test_retry_interval: + self.log.debug(f"Sending kernel_info requests to shell and control channels (elapsed: {elapsed:.1f}s)") + + try: + # Send kernel_info to both channels in parallel (no reply expected) + await asyncio.gather( + self._send_kernel_info_shell(), + self._send_kernel_info_control(), + return_exceptions=True + ) + last_kernel_info_time = time.time() + except Exception as e: + self.log.debug(f"Error sending kernel_info requests: {e}") + + # Wait before next check + await asyncio.sleep(self.connection_test_check_interval) + + self.log.error(f"Kernel communication test failed: no response after {timeout}s") + return False + + async def _send_kernel_info_shell(self): + """Send kernel_info request on shell channel (no reply expected).""" + try: + if hasattr(self, 'kernel_info'): + # Send without waiting for reply + self.kernel_info() + except Exception as e: + self.log.debug(f"Error sending kernel_info on shell channel: {e}") + + async def _send_kernel_info_control(self): + """Send kernel_info request on control channel (no reply expected).""" + try: + if hasattr(self.control_channel, 'send'): + msg = self.session.msg('kernel_info_request') + # Channel wrapper will automatically encode channel in msg_id + self.control_channel.send(msg) + except Exception as e: + self.log.debug(f"Error sending kernel_info on control channel: {e}") + + async def connect(self) -> bool: + """Connect to the kernel and verify communication. + + This method: + 1. Starts all channels + 2. Begins listening for messages + 3. Waits for heartbeat to confirm connectivity + 4. Tests kernel communication with configurable retries + 5. Marks connection as ready + + Returns: + bool: True if connection successful, False otherwise + """ + if self._connecting: + return await self.wait_for_connection_ready() + + if self._connection_ready: + return True + + self._connecting = True + + try: + self.execution_state = ExecutionStates.BUSY.value + self.last_activity = datetime.now(timezone.utc) + + # Handle both sync and async versions of start_channels + result = self.start_channels() + if asyncio.iscoroutine(result): + await result + + # Verify channels are running. + assert self.channels_running + + # Start our listening + await self.start_listening() + + # Unpause heartbeat channel if method exists + if hasattr(self.hb_channel, 'unpause'): + self.hb_channel.unpause() + + # Wait for heartbeat + attempt = 0 + max_attempts = 10 + while not self.hb_channel.is_beating(): + attempt += 1 + if attempt > max_attempts: + raise Exception("The kernel took too long to connect to the Kernel Sockets.") + await asyncio.sleep(0.1) + + # Test kernel communication (handles retries internally) + if not await self._test_kernel_communication(): + self.log.error(f"Kernel communication test failed after {self.connection_test_timeout}s timeout") + return False + + # Mark connection as ready and process queued messages + self.mark_connection_ready() + + # Update execution state to idle if it's not already set + # (it might already be idle if we received a status message during connection test) + if self.execution_state == ExecutionStates.BUSY.value: + self.execution_state = ExecutionStates.IDLE.value + self.last_activity = datetime.now(timezone.utc) + + self.log.info("Successfully connected to kernel") + return True + + except Exception as e: + self.log.error(f"Failed to connect to kernel: {e}") + self._connecting = False + return False + + async def disconnect(self): + """Disconnect from the kernel and reset connection state. + + This method: + 1. Stops listening for messages + 2. Stops all channels + 3. Resets connection state flags + 4. Clears channel references + + Note: Does not remove listeners - they will be preserved for reconnection. + """ + # Stop listening for messages + await self.stop_listening() + + # Stop all channels + self.stop_channels() + + # Reset connection state + self._connecting = False + self._connection_ready = False + self._connection_ready_event.clear() + + # Clear channel references + self._shell_channel = None + self._iopub_channel = None + self._stdin_channel = None + self._control_channel = None + self._hb_channel = None + + self.log.info("Disconnected from kernel") + + async def reconnect(self) -> bool: + """Reconnect to the kernel. + + This is a convenience method that disconnects and then connects again. + Useful for recovering from stale connections or network issues. + + Returns: + bool: True if reconnection successful, False otherwise + """ + self.log.info("Reconnecting to kernel...") + await self.disconnect() + return await self.connect() + + +class JupyterServerKernelClient(JupyterServerKernelClientMixin, AsyncKernelClient): + """ + A kernel client with listener functionality and message queuing. + """ diff --git a/jupyter_server/services/kernels/v3/connection/__init__.py b/jupyter_server/services/kernels/v3/connection/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/jupyter_server/services/kernels/v3/connection/client_connection.py b/jupyter_server/services/kernels/v3/connection/client_connection.py new file mode 100644 index 0000000000..afa336242d --- /dev/null +++ b/jupyter_server/services/kernels/v3/connection/client_connection.py @@ -0,0 +1,195 @@ +from tornado.websocket import WebSocketClosedError +from traitlets import List as TraitletsList, Tuple as TraitletsTuple +from jupyter_server.services.kernels.connection.base import ( + BaseKernelWebsocketConnection, +) +from jupyter_server.services.kernels.connection.base import deserialize_msg_from_ws_v1, serialize_msg_to_ws_v1 +from jupyter_client.session import Session +from ..message_utils import encode_cell_id_in_message, strip_encoding_from_message + + +class KernelClientWebsocketConnection(BaseKernelWebsocketConnection): + """WebSocket connection that bridges frontend clients to the shared kernel client. + + This class implements the WebSocket side of the shared kernel client architecture. + Instead of creating its own ZMQ connections to the kernel, it registers as a listener + on the kernel manager's shared kernel client and routes messages bidirectionally. + + Key Responsibilities: + + 1. **Listener Registration**: + - Registers itself as a message listener on the shared kernel client + - Receives all kernel messages through the listener callback + - Automatically removed when the WebSocket disconnects + + 2. **Cell ID Encoding** (Outgoing to Kernel): + - Extracts `cellId` from message metadata + - Appends cell ID to message ID: `msg_id#{cell_id}` + - Enables routing responses back to the originating cell + - Works in conjunction with kernel client's channel encoding + + 3. **Message ID Decoding** (Incoming from Kernel): + - Strips both channel and cell ID encoding from message IDs + - Returns messages to frontend with original msg_ids + - Frontend receives messages in the format it expects + + 4. **Message Filtering** (Configurable): + - Supports filtering messages by type and channel + - Can include specific message types (`msg_types` trait) + - Can exclude specific message types (`exclude_msg_types` trait) + - Reduces bandwidth for specialized clients + """ + + kernel_ws_protocol = "v1.kernel.websocket.jupyter.org" + + # Configurable message filtering traits + msg_types = TraitletsList( + trait=TraitletsTuple(), + default_value=None, + allow_none=True, + config=True, + help=""" + List of (msg_type, channel) tuples to include for this websocket connection. + If None (default), all messages are sent. If specified, only messages matching + these (msg_type, channel) pairs will be sent to the websocket. + + Example: [("status", "iopub"), ("execute_reply", "shell")] + """ + ) + + exclude_msg_types = TraitletsList( + trait=TraitletsTuple(), + default_value=None, + allow_none=True, + config=True, + help=""" + List of (msg_type, channel) tuples to exclude for this websocket connection. + If None (default), no messages are excluded. If specified, messages matching + these (msg_type, channel) pairs will NOT be sent to the websocket. + + Example: [("status", "iopub")] + + Note: Cannot be used together with msg_types. If both are specified, + msg_types takes precedence. + """ + ) + + def _get_kernel_client(self): + """Get the kernel client directly from the kernel manager. + + The kernel client is now a property on the kernel manager itself, + created immediately when the kernel manager is instantiated. + + Note: self.kernel_manager is actually the parent, which is the specific + KernelManager instance for this kernel (not the MultiKernelManager). + """ + try: + # self.kernel_manager is the specific KernelManager for this kernel + km = self.kernel_manager + if not km: + raise RuntimeError(f"No kernel manager found for kernel {self.kernel_id}") + + # Get the pre-created kernel client from the kernel manager + if not hasattr(km, 'kernel_client') or km.kernel_client is None: + raise RuntimeError(f"Kernel manager for {self.kernel_id} has no kernel_client") + + return km.kernel_client + + except Exception as e: + raise RuntimeError(f"Failed to get kernel client for kernel {self.kernel_id}: {e}") + + async def connect(self): + """Connect to the kernel via a kernel session with deferred channel connection. + + The client connection is now handled by the kernel manager in post_start_kernel(). + The websocket just needs to add itself as a listener to receive messages. + """ + # Get the client from the kernel manager + client = self._get_kernel_client() + + # Add websocket listener immediately (messages will be queued if not ready) + # Use configured message filtering if specified + if self.msg_types is not None: + # Convert list of tuples to list for the API + msg_types_list = [tuple(item) for item in self.msg_types] if self.msg_types else None + client.add_listener(self.handle_outgoing_message, msg_types=msg_types_list) + elif self.exclude_msg_types is not None: + # Convert list of tuples to list for the API + exclude_msg_types_list = [tuple(item) for item in self.exclude_msg_types] if self.exclude_msg_types else None + client.add_listener(self.handle_outgoing_message, exclude_msg_types=exclude_msg_types_list) + else: + # No filtering - listen to all messages (default) + client.add_listener(self.handle_outgoing_message) + + # Broadcast current kernel state to this websocket immediately + # This ensures websockets that connect during/after restart get the current state + await client.broadcast_state() + + self.log.info(f"Kernel websocket connected and listening for kernel {self.kernel_id}") + + def disconnect(self): + """Disconnect the websocket from the kernel client.""" + try: + # Get the kernel client from the kernel manager + client = self._get_kernel_client() + if client: + # Remove this websocket's listener from the client + client.remove_listener(self.handle_outgoing_message) + except Exception as e: + self.log.warning(f"Failed to disconnect websocket for kernel {self.kernel_id}: {e}") + + def handle_incoming_message(self, incoming_msg): + """Handle incoming messages from WebSocket, encoding cellId into msg_id.""" + channel_name, msg_list = deserialize_msg_from_ws_v1(incoming_msg) + + try: + # Get the kernel client from the kernel manager + client = self._get_kernel_client() + if not client: + return + + # Extract cellId from metadata and encode into msg_id + try: + if len(msg_list) >= 3: # Need header, parent_header, metadata + session = Session() + metadata = session.unpack(msg_list[2]) + cell_id = metadata.get("cellId") + + if cell_id: + msg_list = encode_cell_id_in_message(msg_list, cell_id) + except Exception as e: + self.log.debug(f"Error encoding cellId in msg_id: {e}") + + # Send to kernel client (which will prepend channel) + client.handle_incoming_message(channel_name, msg_list) + except Exception as e: + self.log.error(f"Failed to handle incoming message for kernel {self.kernel_id}: {e}") + + def handle_outgoing_message(self, channel_name, msg): + """Handle outgoing messages to WebSocket, stripping channel and cellId from msg_id.""" + try: + # Validate message has minimum required parts + if not msg or len(msg) < 4: + self.log.warning(f"Message on {channel_name} has insufficient parts: {len(msg) if msg else 0}") + return + + # Validate parts are bytes + for i, part in enumerate(msg[:4]): + if not isinstance(part, bytes): + self.log.error(f"Message part {i} on {channel_name} is not bytes: {type(part)}") + return + + # Strip channel and cellId from msg_ids before sending to frontend + try: + msg = strip_encoding_from_message(msg) + except Exception as e: + self.log.debug(f"Error stripping encoding from msg_ids: {e}") + # Continue with original message if stripping fails + + # Serialize to websocket format and send + bin_msg = serialize_msg_to_ws_v1(msg, channel_name) + self.websocket_handler.write_message(bin_msg, binary=True) + except WebSocketClosedError: + self.log.warning("A Kernel Socket message arrived on a closed websocket channel.") + except Exception as err: + self.log.error(f"Error handling outgoing message on {channel_name}: {err}", exc_info=True) \ No newline at end of file diff --git a/jupyter_server/services/kernels/v3/kernelmanager.py b/jupyter_server/services/kernels/v3/kernelmanager.py new file mode 100644 index 0000000000..6cbbc80272 --- /dev/null +++ b/jupyter_server/services/kernels/v3/kernelmanager.py @@ -0,0 +1,120 @@ +"""Kernel manager for the Apple JupyterLab Kernel Monitor Extension.""" + +from jupyter_client.multikernelmanager import AsyncMultiKernelManager +from traitlets import Type, observe, Instance, default + +from jupyter_server.services.kernels.kernelmanager import ( + MappingKernelManager, + ServerKernelManager as _ServerKernelManager, +) + +from .client import JupyterServerKernelClient + + +class ServerKernelManager(_ServerKernelManager): + """Kernel manager with enhanced client. + + This kernel manager inherits from ServerKernelManager and adds: + - Enhanced kernel client (JupyterServerKernelClient) with message ID encoding + - Pre-created kernel client instance stored as a property + - Automatic client connection/disconnection on kernel start/shutdown + + The client encodes channel information in message IDs using simple string operations. + """ + + client_class = Type( + default_value=JupyterServerKernelClient, + klass='jupyter_client.client.KernelClient', + config=True, + help="""The kernel client class to use for creating kernel clients.""" + ) + + client_factory = Type( + default_value=JupyterServerKernelClient, + klass='jupyter_client.client.KernelClient', + config=True, + help="""The kernel client factory class to use.""" + ) + + kernel_client = Instance( + 'jupyter_client.client.KernelClient', + allow_none=True, + help="""Pre-created kernel client instance. Created on initialization.""" + ) + + def __init__(self, **kwargs): + """Initialize the kernel manager and create a kernel client instance.""" + super().__init__(**kwargs) + + # Create a kernel client instance immediately + self.kernel_client = self.client(session=self.session) + + @observe('client_class') + def _client_class_changed(self, change): + """Override parent's _client_class_changed to handle Type trait instead of DottedObjectName.""" + # Set client_factory to the same class + self.client_factory = change['new'] + + async def _async_post_start_kernel(self, **kwargs): + """After kernel starts, connect the kernel client. + + This method is called after the kernel has been successfully started. + It loads the latest connection info (with ports set by provisioner) + and connects the kernel client to the kernel. + + Note: If you override this method, make sure to call super().post_start_kernel(**kwargs) + to ensure the kernel client connects properly. + """ + await super()._async_post_start_kernel(**kwargs) + try: + # Load latest connection info from kernel manager + # The provisioner has now set the real ports + self.kernel_client.load_connection_info(self.get_connection_info(session=True)) + + # Connect the kernel client + success = await self.kernel_client.connect() + + if not success: + raise RuntimeError(f"Failed to connect kernel client for kernel {self.kernel_id}") + + self.log.info(f"Successfully connected kernel client for kernel {self.kernel_id}") + + except Exception as e: + self.log.error(f"Failed to connect kernel client: {e}") + # Re-raise to fail the kernel start + raise + + async def cleanup_resources(self, restart=False): + """Cleanup resources, disconnecting the kernel client if not restarting. + + Parameters + ---------- + restart : bool + If True, the kernel is being restarted and we should keep the client + connected but clear its state. If False, fully disconnect. + """ + if self.kernel_client: + if restart: + # On restart, clear client state but keep connection + # The connection will be refreshed in post_start_kernel after restart + self.log.debug(f"Clearing kernel client state for restart of kernel {self.kernel_id}") + self.kernel_client.last_shell_status_time = None + self.kernel_client.last_control_status_time = None + # Disconnect before restart - will reconnect after + await self.kernel_client.stop_listening() + self.kernel_client.stop_channels() + else: + # On shutdown, fully disconnect the client + self.log.debug(f"Disconnecting kernel client for kernel {self.kernel_id}") + await self.kernel_client.stop_listening() + self.kernel_client.stop_channels() + + await super().cleanup_resources(restart=restart) + + +class AsyncMappingKernelManager(MappingKernelManager, AsyncMultiKernelManager): # type:ignore[misc] + """Custom kernel manager that uses enhanced monitoring kernel manager with v3 API.""" + + @default("kernel_manager_class") + def _default_kernel_manager_class(self): + return "jupyter_server.services.kernels.v3.kernelmanager.ServerKernelManager" \ No newline at end of file diff --git a/jupyter_server/services/kernels/v3/message_utils.py b/jupyter_server/services/kernels/v3/message_utils.py new file mode 100644 index 0000000000..c775e38d25 --- /dev/null +++ b/jupyter_server/services/kernels/v3/message_utils.py @@ -0,0 +1,330 @@ +"""Utilities for encoding and decoding channel and source ID information in message IDs. + +This module provides functions to encode channel names and source IDs (like cell IDs) +directly into message IDs, eliminating the need for a separate message cache to track +message metadata. + +Format: {channel}:{base_msg_id}#{src_id} + +Examples: + - With channel and src_id: "shell:a1b2c3d4_12345_0#cell-abc123" + - With channel only: "shell:a1b2c3d4_12345_0" + - Legacy format (no encoding): "a1b2c3d4_12345_0" +""" + +from typing import Optional, Tuple, List +from jupyter_client.session import Session + + +class MsgIdError(Exception): + """Base exception for message ID operations.""" + pass + + +class InvalidMsgIdFormatError(MsgIdError): + """Raised when a message ID has an invalid format.""" + pass + + +class InvalidChannelError(MsgIdError): + """Raised when a channel name contains reserved characters.""" + pass + + +class InvalidSrcIdError(MsgIdError): + """Raised when a source ID contains reserved characters.""" + pass + + +def validate_channel(channel: Optional[str]) -> None: + """Validate that a channel name doesn't contain reserved characters. + + Args: + channel: Channel name to validate + + Raises: + InvalidChannelError: If channel contains ':' character + """ + if channel is not None and ':' in channel: + raise InvalidChannelError( + f"Channel name cannot contain ':' character: {channel}" + ) + + +def validate_src_id(src_id: Optional[str]) -> None: + """Validate that a source ID doesn't contain reserved characters. + + Args: + src_id: Source ID to validate + + Raises: + InvalidSrcIdError: If src_id contains ':' or '#' characters + """ + if src_id is not None: + if ':' in src_id: + raise InvalidSrcIdError( + f"Source ID cannot contain ':' character: {src_id}" + ) + if '#' in src_id: + raise InvalidSrcIdError( + f"Source ID cannot contain '#' character: {src_id}" + ) + + +def create_msg_id( + base_msg_id: str, + channel: Optional[str] = None, + src_id: Optional[str] = None +) -> str: + """Create a structured message ID with optional channel and source ID encoding. + + Args: + base_msg_id: Core message ID (e.g., "{session}_{pid}_{counter}") + channel: Optional channel name (shell, control, etc.) + src_id: Optional source identifier (e.g., cell ID) + + Returns: + Formatted message ID string + + Raises: + InvalidChannelError: If channel contains ':' + InvalidSrcIdError: If src_id contains ':' or '#' + + Examples: + >>> create_msg_id("abc123_456_0", "shell", "cell-xyz") + 'shell:abc123_456_0#cell-xyz' + >>> create_msg_id("abc123_456_1", "control") + 'control:abc123_456_1' + >>> create_msg_id("abc123_456_2") + 'abc123_456_2' + """ + validate_channel(channel) + validate_src_id(src_id) + + if channel is None: + # Legacy format for backward compatibility + result = base_msg_id + else: + result = f"{channel}:{base_msg_id}" + + if src_id is not None: + result = f"{result}#{src_id}" + + return result + + +def parse_msg_id(msg_id: str) -> Tuple[Optional[str], Optional[str], str]: + """Parse a message ID into its components. + + Args: + msg_id: Message ID string to parse + + Returns: + Tuple of (channel, src_id, base_msg_id) where: + - channel: Channel name ('shell', 'control', etc.) or None for legacy format + - src_id: Source identifier (e.g., cell ID) or None + - base_msg_id: Core message ID + + Examples: + >>> parse_msg_id("shell:abc123_456_0#cell-xyz") + ('shell', 'cell-xyz', 'abc123_456_0') + >>> parse_msg_id("control:abc123_456_1") + ('control', None, 'abc123_456_1') + >>> parse_msg_id("abc123_456_2") + (None, None, 'abc123_456_2') + """ + if not msg_id: + raise InvalidMsgIdFormatError("Message ID cannot be empty") + + # Split off src_id if present (after #) + if '#' in msg_id: + msg_id_part, src_id = msg_id.split('#', 1) + else: + msg_id_part = msg_id + src_id = None + + # Split channel and base msg_id (before :) + if ':' in msg_id_part: + channel, base_msg_id = msg_id_part.split(':', 1) + else: + # Legacy format - no channel specified + channel = None + base_msg_id = msg_id_part + + return channel, src_id, base_msg_id + + +def extract_channel(msg_id: str) -> Optional[str]: + """Extract just the channel from a message ID. + + Args: + msg_id: Message ID string + + Returns: + Channel name or None if not present + + Examples: + >>> extract_channel("shell:abc123_456_0#cell-xyz") + 'shell' + >>> extract_channel("abc123_456_2") + None + """ + channel, _, _ = parse_msg_id(msg_id) + return channel + + +def extract_src_id(msg_id: str) -> Optional[str]: + """Extract just the source ID from a message ID. + + Args: + msg_id: Message ID string + + Returns: + Source ID or None if not present + + Examples: + >>> extract_src_id("shell:abc123_456_0#cell-xyz") + 'cell-xyz' + >>> extract_src_id("shell:abc123_456_1") + None + """ + _, src_id, _ = parse_msg_id(msg_id) + return src_id + + +def extract_base_msg_id(msg_id: str) -> str: + """Extract just the base message ID from a message ID. + + Args: + msg_id: Message ID string + + Returns: + Base message ID (without channel or src_id encoding) + + Examples: + >>> extract_base_msg_id("shell:abc123_456_0#cell-xyz") + 'abc123_456_0' + >>> extract_base_msg_id("abc123_456_2") + 'abc123_456_2' + """ + _, _, base_msg_id = parse_msg_id(msg_id) + return base_msg_id + + +# ============================================================================ +# Message-level utilities for websocket connection +# ============================================================================ + + +def encode_channel_in_message_dict(msg_dict: dict, channel: str) -> dict: + """Encode channel into a message dict's header msg_id. + + This utility is used for client-initiated messages (not from websocket) + to ensure they have channel encoding for proper state tracking. + + Args: + msg_dict: Message dictionary with header, parent_header, metadata, content + channel: Channel name to encode ('shell', 'control', etc.) + + Returns: + Modified message dict with channel encoded in header msg_id + + Examples: + >>> msg = session.msg('kernel_info_request') + >>> msg = encode_channel_in_message_dict(msg, 'shell') + >>> # msg['header']['msg_id'] now has 'shell:' prefix + """ + if 'header' in msg_dict and 'msg_id' in msg_dict['header']: + msg_id = msg_dict['header']['msg_id'] + # Only encode if not already encoded + if not msg_id.startswith(f"{channel}:"): + msg_dict['header']['msg_id'] = f"{channel}:{msg_id}" + return msg_dict + + +def encode_cell_id_in_message(msg_list: List[bytes], cell_id: str) -> List[bytes]: + """Encode a cell ID into the header msg_id of a message. + + This utility function encapsulates the session pack/unpack operations needed + to add a cell ID to a message's header. It's designed to keep the websocket + connection code lean. + + Args: + msg_list: Message parts as list of bytes [header, parent_header, metadata, content, ...] + cell_id: Cell ID to encode into the message + + Returns: + Modified message list with cell ID encoded in header msg_id + + Examples: + >>> # msg_list with msg_id "abc123" becomes msg_id "abc123#cell-xyz" + >>> modified_msg = encode_cell_id_in_message(msg_list, "cell-xyz") + """ + # Need at least header part + if not msg_list or len(msg_list) < 1: + return msg_list + + try: + session = Session() + msg_copy = list(msg_list) # Make a copy to avoid modifying original + + # Unpack header + header = session.unpack(msg_copy[0]) + + # Encode cell ID into msg_id if not already present + if "msg_id" in header: + msg_id = header["msg_id"] + if "#" not in msg_id: # Only add if not already encoded + header["msg_id"] = f"{msg_id}#{cell_id}" + msg_copy[0] = session.pack(header) + + return msg_copy + except Exception: + # If encoding fails, return original message + return msg_list + + +def strip_encoding_from_message(msg_list: List[bytes]) -> List[bytes]: + """Strip channel and cell ID encoding from header and parent_header msg_ids. + + This utility function encapsulates the session pack/unpack operations needed + to strip encoding from a message before sending to the frontend. It's designed + to keep the websocket connection code lean. + + Args: + msg_list: Message parts as list of bytes [header, parent_header, metadata, content, ...] + + Returns: + Modified message list with encoding stripped from msg_ids + + Examples: + >>> # msg_list with msg_id "shell:abc123#cell-xyz" becomes "abc123" + >>> clean_msg = strip_encoding_from_message(msg_list) + """ + # Need at least header and parent_header + if not msg_list or len(msg_list) < 2: + return msg_list + + try: + session = Session() + msg_copy = list(msg_list) # Make a copy to avoid modifying original + + # Strip from header msg_id + header = session.unpack(msg_copy[0]) + if 'msg_id' in header: + _, _, base_msg_id = parse_msg_id(header['msg_id']) + header['msg_id'] = base_msg_id + msg_copy[0] = session.pack(header) + + # Strip from parent_header msg_id + parent_header = session.unpack(msg_copy[1]) + if 'msg_id' in parent_header and parent_header['msg_id']: + _, _, base_msg_id = parse_msg_id(parent_header['msg_id']) + parent_header['msg_id'] = base_msg_id + msg_copy[1] = session.pack(parent_header) + + return msg_copy + except Exception: + # If decoding fails, return original message + return msg_list + diff --git a/jupyter_server/services/kernels/v3/states.py b/jupyter_server/services/kernels/v3/states.py new file mode 100644 index 0000000000..6d6b28ba6d --- /dev/null +++ b/jupyter_server/services/kernels/v3/states.py @@ -0,0 +1,22 @@ +from enum import Enum +from enum import EnumMeta + + +class StrContainerEnumMeta(EnumMeta): + def __contains__(cls, item): + for name, member in cls.__members__.items(): + if item == name or item == member.value: + return True + return False + +class StrContainerEnum(str, Enum, metaclass=StrContainerEnumMeta): + """A Enum object that enables search for items + in a normal Enum object based on key and value. + """ + +class ExecutionStates(StrContainerEnum): + BUSY = "busy" + IDLE = "idle" + STARTING = "starting" + UNKNOWN = "unknown" + DEAD = "dead" \ No newline at end of file From 2100b20b5c5582c311282557943ef248c8e0b762 Mon Sep 17 00:00:00 2001 From: Zach Sailer Date: Mon, 17 Nov 2025 09:28:23 -0800 Subject: [PATCH 2/3] Improve code formatting and consistency in v3 kernel client - Fix import ordering and add blank lines for better readability - Improve code formatting with consistent spacing and line breaks - Remove unused exception variable and fix minor style issues - Ensure consistent inheritance for channel classes - Update type annotations to use modern Python syntax --- jupyter_server/services/kernels/v3/client.py | 117 ++++++++++++------- 1 file changed, 74 insertions(+), 43 deletions(-) diff --git a/jupyter_server/services/kernels/v3/client.py b/jupyter_server/services/kernels/v3/client.py index fbd11342d7..ecf05e9aa2 100644 --- a/jupyter_server/services/kernels/v3/client.py +++ b/jupyter_server/services/kernels/v3/client.py @@ -3,36 +3,41 @@ import typing as t from datetime import datetime, timezone -from traitlets import HasTraits, Type from jupyter_client.asynchronous.client import AsyncKernelClient from jupyter_client.channels import AsyncZMQSocketChannel from jupyter_client.channelsabc import ChannelABC +from traitlets import HasTraits, Type + +from .message_utils import encode_channel_in_message_dict, parse_msg_id from .states import ExecutionStates -from .message_utils import parse_msg_id, encode_channel_in_message_dict class NamedAsyncZMQSocketChannel(AsyncZMQSocketChannel): """Prepends the channel name to all message IDs to this socket.""" + channel_name = "unknown" - + def send(self, msg): """Send a message with automatic channel encoding.""" msg = encode_channel_in_message_dict(msg, self.channel_name) - return super().send(msg) - - + return super().send(msg) + + class ShellChannel(NamedAsyncZMQSocketChannel): """Shell channel that automatically encodes 'shell' in outgoing msg_ids.""" + channel_name = "shell" -class ControlChannel(AsyncZMQSocketChannel): +class ControlChannel(NamedAsyncZMQSocketChannel): """Control channel that automatically encodes 'control' in outgoing msg_ids.""" + channel_name = "control" -class StdinChannel(AsyncZMQSocketChannel): +class StdinChannel(NamedAsyncZMQSocketChannel): """Stdin channel that automatically encodes 'stdin' in outgoing msg_ids.""" + channel_name = "stdin" @@ -77,7 +82,9 @@ class JupyterServerKernelClientMixin(HasTraits): # Connection test configuration connection_test_timeout: float = 120.0 # Total timeout for connection test in seconds connection_test_check_interval: float = 0.1 # How often to check for messages in seconds - connection_test_retry_interval: float = 10.0 # How often to retry kernel_info requests in seconds + connection_test_retry_interval: float = ( + 10.0 # How often to retry kernel_info requests in seconds + ) # Override channel classes to use our custom ones with automatic encoding shell_channel_class = Type(ShellChannel) @@ -105,8 +112,8 @@ def __init__(self, *args, **kwargs): def add_listener( self, callback: t.Callable[[str, list[bytes]], None], - msg_types: t.Optional[t.List[t.Tuple[str, str]]] = None, - exclude_msg_types: t.Optional[t.List[t.Tuple[str, str]]] = None + msg_types: t.Optional[list[tuple[str, str]]] = None, + exclude_msg_types: t.Optional[list[tuple[str, str]]] = None, ): """Add a listener to be called when messages are received. @@ -128,8 +135,8 @@ def add_listener( # Store the listener with its filter configuration self._listeners[callback] = { - 'msg_types': set(msg_types) if msg_types else None, - 'exclude_msg_types': set(exclude_msg_types) if exclude_msg_types else None + "msg_types": set(msg_types) if msg_types else None, + "exclude_msg_types": set(exclude_msg_types) if exclude_msg_types else None, } def remove_listener(self, callback: t.Callable[[str, list[bytes]], None]): @@ -188,7 +195,7 @@ def _send_message(self, channel_name: str, msg: list[bytes]): channel = getattr(self, f"{channel_name}_channel", None) channel.session.send_raw(channel.socket, msg) - except Exception as e: + except Exception: self.log.warn("Error handling incoming message.") def handle_incoming_message(self, channel_name: str, msg: list[bytes]): @@ -225,7 +232,6 @@ def handle_incoming_message(self, channel_name: str, msg: list[bytes]): self._send_message(channel_name, msg) - def handle_outgoing_message(self, channel_name: str, msg: list[bytes]): """Public API for manufacturing messages to send to kernel client listeners. @@ -246,17 +252,19 @@ async def _route_to_listeners(self, channel_name: str, msg: list[bytes]): # Validate message format before routing if not msg or len(msg) < 4: - self.log.warning(f"Cannot route malformed message on {channel_name}: {len(msg) if msg else 0} parts (expected at least 4)") + self.log.warning( + f"Cannot route malformed message on {channel_name}: {len(msg) if msg else 0} parts (expected at least 4)" + ) return # Extract message type for filtering msg_type = None try: header = self.session.unpack(msg[0]) if msg and len(msg) > 0 else {} - msg_type = header.get('msg_type', 'unknown') + msg_type = header.get("msg_type", "unknown") except Exception as e: self.log.debug(f"Error extracting message type: {e}") - msg_type = 'unknown' + msg_type = "unknown" # Create tasks for listeners that match the filter tasks = [] @@ -269,7 +277,9 @@ async def _route_to_listeners(self, channel_name: str, msg: list[bytes]): if tasks: await asyncio.gather(*tasks, return_exceptions=True) - def _should_route_to_listener(self, msg_type: str, channel_name: str, filter_config: dict) -> bool: + def _should_route_to_listener( + self, msg_type: str, channel_name: str, filter_config: dict + ) -> bool: """Determine if a message should be routed to a listener based on its filter configuration. Args: @@ -280,8 +290,8 @@ def _should_route_to_listener(self, msg_type: str, channel_name: str, filter_con Returns: bool: True if the message should be routed to the listener, False otherwise """ - msg_types = filter_config.get('msg_types') - exclude_msg_types = filter_config.get('exclude_msg_types') + msg_types = filter_config.get("msg_types") + exclude_msg_types = filter_config.get("exclude_msg_types") # If msg_types is specified (inclusion filter) if msg_types is not None: @@ -303,7 +313,13 @@ async def _call_listener(self, listener: t.Callable, channel_name: str, msg: lis except Exception as e: self.log.error(f"Error in listener {listener}: {e}") - def _update_execution_state_from_status(self, channel_name: str, msg_dict: dict, parent_msg_id: str = None, execution_state: str = None): + def _update_execution_state_from_status( + self, + channel_name: str, + msg_dict: dict, + parent_msg_id: str = None, + execution_state: str = None, + ): """Update execution state from a status message if it originated from shell channel. This method checks if a status message on the iopub channel originated from a shell @@ -366,7 +382,9 @@ def _update_execution_state_from_status(self, channel_name: str, msg_dict: dict, if isinstance(content, bytes): content = self.session.unpack(content) execution_state = content.get("execution_state") - self.log.debug(f"Ignoring status message - cannot parse parent channel (state would be: {execution_state})") + self.log.debug( + f"Ignoring status message - cannot parse parent channel (state would be: {execution_state})" + ) except Exception as e: self.log.debug(f"Error updating execution state from status message: {e}") @@ -391,10 +409,7 @@ async def broadcast_state(self): return # Create status message - msg_dict = self.session.msg( - "status", - content={"execution_state": self.execution_state} - ) + msg_dict = self.session.msg("status", content={"execution_state": self.execution_state}) # Serialize the message # session.serialize() returns: @@ -404,7 +419,9 @@ async def broadcast_state(self): # Skip delimiter (index 0) and signature (index 1) to get message parts # Result: [header, parent_header, metadata, content, buffers...] if len(serialized) < 6: # Need delimiter + signature + 4 message parts minimum - self.log.warning(f"broadcast_state: serialized message too short: {len(serialized)} parts") + self.log.warning( + f"broadcast_state: serialized message too short: {len(serialized)} parts" + ) return msg_parts = serialized[2:] # Skip delimiter and signature @@ -422,7 +439,7 @@ async def start_listening(self): self._listening = True # Monitor each channel for incoming messages - for channel_name in ['iopub', 'shell', 'stdin', 'control']: + for channel_name in ["iopub", "shell", "stdin", "control"]: channel = getattr(self, f"{channel_name}_channel", None) if channel and channel.is_alive(): task = asyncio.create_task(self._monitor_channel_messages(channel_name, channel)) @@ -433,12 +450,12 @@ async def start_listening(self): async def stop_listening(self): """Stop listening for messages.""" # Stop monitoring tasks - if hasattr(self, '_monitoring_tasks'): + if hasattr(self, "_monitoring_tasks"): for task in self._monitoring_tasks: task.cancel() self._monitoring_tasks = [] - self.log.info(f"Stopped listening") + self.log.info("Stopped listening") async def _monitor_channel_messages(self, channel_name: str, channel: ChannelABC): """Monitor a channel for incoming messages and route them to listeners.""" @@ -466,7 +483,9 @@ async def _monitor_channel_messages(self, channel_name: str, channel: ChannelABC if msg_list and len(msg_list) >= 5: await self._route_to_listeners(channel_name, msg_list[1:]) else: - self.log.warning(f"Received malformed message on {channel_name}: {len(msg_list) if msg_list else 0} parts") + self.log.warning( + f"Received malformed message on {channel_name}: {len(msg_list) if msg_list else 0} parts" + ) except Exception as e: # Log the error instead of silently ignoring it @@ -514,7 +533,7 @@ async def _test_kernel_communication(self, timeout: float = None) -> bool: await asyncio.gather( self._send_kernel_info_shell(), self._send_kernel_info_control(), - return_exceptions=True + return_exceptions=True, ) except Exception as e: self.log.debug(f"Error sending initial kernel_info requests: {e}") @@ -531,25 +550,35 @@ async def _test_kernel_communication(self, timeout: float = None) -> bool: # Check if we've received any status messages since connection attempt # This indicates the kernel is connected, even if busy executing something - if self.last_shell_status_time and self.last_shell_status_time > connection_attempt_time: + if ( + self.last_shell_status_time + and self.last_shell_status_time > connection_attempt_time + ): self.log.info("Kernel communication test succeeded: received shell status message") return True - if self.last_control_status_time and self.last_control_status_time > connection_attempt_time: - self.log.info("Kernel communication test succeeded: received control status message") + if ( + self.last_control_status_time + and self.last_control_status_time > connection_attempt_time + ): + self.log.info( + "Kernel communication test succeeded: received control status message" + ) return True # Send kernel_info requests at regular intervals time_since_last_request = time.time() - last_kernel_info_time if time_since_last_request >= self.connection_test_retry_interval: - self.log.debug(f"Sending kernel_info requests to shell and control channels (elapsed: {elapsed:.1f}s)") + self.log.debug( + f"Sending kernel_info requests to shell and control channels (elapsed: {elapsed:.1f}s)" + ) try: # Send kernel_info to both channels in parallel (no reply expected) await asyncio.gather( self._send_kernel_info_shell(), self._send_kernel_info_control(), - return_exceptions=True + return_exceptions=True, ) last_kernel_info_time = time.time() except Exception as e: @@ -564,7 +593,7 @@ async def _test_kernel_communication(self, timeout: float = None) -> bool: async def _send_kernel_info_shell(self): """Send kernel_info request on shell channel (no reply expected).""" try: - if hasattr(self, 'kernel_info'): + if hasattr(self, "kernel_info"): # Send without waiting for reply self.kernel_info() except Exception as e: @@ -573,8 +602,8 @@ async def _send_kernel_info_shell(self): async def _send_kernel_info_control(self): """Send kernel_info request on control channel (no reply expected).""" try: - if hasattr(self.control_channel, 'send'): - msg = self.session.msg('kernel_info_request') + if hasattr(self.control_channel, "send"): + msg = self.session.msg("kernel_info_request") # Channel wrapper will automatically encode channel in msg_id self.control_channel.send(msg) except Exception as e: @@ -617,7 +646,7 @@ async def connect(self) -> bool: await self.start_listening() # Unpause heartbeat channel if method exists - if hasattr(self.hb_channel, 'unpause'): + if hasattr(self.hb_channel, "unpause"): self.hb_channel.unpause() # Wait for heartbeat @@ -631,7 +660,9 @@ async def connect(self) -> bool: # Test kernel communication (handles retries internally) if not await self._test_kernel_communication(): - self.log.error(f"Kernel communication test failed after {self.connection_test_timeout}s timeout") + self.log.error( + f"Kernel communication test failed after {self.connection_test_timeout}s timeout" + ) return False # Mark connection as ready and process queued messages From 37dd9fe1bd784a9ca80ca07b32f2fdd4ea1803c6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 17 Nov 2025 17:55:47 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- jupyter_server/gateway/v3/managers.py | 49 ++++++++------- jupyter_server/serverapp.py | 12 +++- .../v3/connection/client_connection.py | 32 ++++++---- .../services/kernels/v3/kernelmanager.py | 30 ++++++---- .../services/kernels/v3/message_utils.py | 60 +++++++++---------- jupyter_server/services/kernels/v3/states.py | 11 ++-- 6 files changed, 110 insertions(+), 84 deletions(-) diff --git a/jupyter_server/gateway/v3/managers.py b/jupyter_server/gateway/v3/managers.py index 35b0c74a77..61ce1dff3b 100644 --- a/jupyter_server/gateway/v3/managers.py +++ b/jupyter_server/gateway/v3/managers.py @@ -1,11 +1,12 @@ """Gateway kernel manager that integrates with our kernel monitoring system.""" import asyncio -from jupyter_server.gateway.managers import GatewayMappingKernelManager -from jupyter_server.gateway.managers import GatewayKernelManager as _GatewayKernelManager -from jupyter_server.gateway.managers import GatewayKernelClient as _GatewayKernelClient -from traitlets import default, Instance, Type +from traitlets import Instance, Type, default + +from jupyter_server.gateway.managers import GatewayKernelClient as _GatewayKernelClient +from jupyter_server.gateway.managers import GatewayKernelManager as _GatewayKernelManager +from jupyter_server.gateway.managers import GatewayMappingKernelManager from jupyter_server.services.kernels.v3.client import JupyterServerKernelClientMixin @@ -38,7 +39,7 @@ def _send_message(self, channel_name: str, msg: list[bytes]): # Send to gateway channel try: channel = getattr(self, f"{channel_name}_channel", None) - if channel and hasattr(channel, 'send'): + if channel and hasattr(channel, "send"): # Convert raw message to gateway format header = self.session.unpack(msg[0]) parent_header = self.session.unpack(msg[1]) @@ -46,14 +47,14 @@ def _send_message(self, channel_name: str, msg: list[bytes]): content = self.session.unpack(msg[3]) full_msg = { - 'header': header, - 'parent_header': parent_header, - 'metadata': metadata, - 'content': content, - 'buffers': msg[4:] if len(msg) > 4 else [], - 'channel': channel_name, - 'msg_id': header.get('msg_id'), - 'msg_type': header.get('msg_type') + "header": header, + "parent_header": parent_header, + "metadata": metadata, + "content": content, + "buffers": msg[4:] if len(msg) > 4 else [], + "channel": channel_name, + "msg_id": header.get("msg_id"), + "msg_type": header.get("msg_type"), } channel.send(full_msg) @@ -74,7 +75,7 @@ async def _monitor_channel_messages(self, channel_name: str, channel): channel_name, message, parent_msg_id=message.get("parent_header", {}).get("msg_id"), - execution_state=message.get("content", {}).get("execution_state") + execution_state=message.get("content", {}).get("execution_state"), ) # Serialize message to standard format for listeners @@ -83,10 +84,14 @@ async def _monitor_channel_messages(self, channel_name: str, channel): serialized = self.session.serialize(message) # Skip delimiter (index 0) and signature (index 1) to get [header, parent_header, metadata, content, ...] - if serialized and len(serialized) >= 6: # Need delimiter + signature + 4 message parts + if ( + serialized and len(serialized) >= 6 + ): # Need delimiter + signature + 4 message parts msg_list = serialized[2:] else: - self.log.warning(f"Gateway message too short: {len(serialized) if serialized else 0} parts") + self.log.warning( + f"Gateway message too short: {len(serialized) if serialized else 0} parts" + ) continue # Route to listeners @@ -124,14 +129,15 @@ class GatewayKernelManager(_GatewayKernelManager): When jupyter_server is configured to use a gateway, this manager ensures that remote kernels receive the same level of monitoring as local kernels. """ + # Configure the manager to use our enhanced gateway client client_class = GatewayKernelClient client_factory = GatewayKernelClient kernel_client = Instance( - 'jupyter_client.client.KernelClient', + "jupyter_client.client.KernelClient", allow_none=True, - help="""Pre-created kernel client instance. Created on initialization.""" + help="""Pre-created kernel client instance. Created on initialization.""", ) def __init__(self, **kwargs): @@ -184,7 +190,9 @@ async def cleanup_resources(self, restart=False): if restart: # On restart, clear client state but keep connection # The connection will be refreshed in post_start_kernel after restart - self.log.debug(f"Clearing kernel client state for restart of kernel {self.kernel_id}") + self.log.debug( + f"Clearing kernel client state for restart of kernel {self.kernel_id}" + ) self.kernel_client.last_shell_status_time = None self.kernel_client.last_control_status_time = None # Disconnect before restart - will reconnect after @@ -208,7 +216,6 @@ def _default_kernel_manager_class(self): def start_watching_activity(self, kernel_id): pass - + def stop_buffering(self, kernel_id): pass - diff --git a/jupyter_server/serverapp.py b/jupyter_server/serverapp.py index eee6063122..3a74f5306c 100644 --- a/jupyter_server/serverapp.py +++ b/jupyter_server/serverapp.py @@ -1652,11 +1652,15 @@ def template_file_path(self) -> list[str]: @default("kernel_manager_class") def _default_kernel_manager_class(self) -> t.Union[str, type[AsyncMappingKernelManager]]: if self.kernels_api_version == 3: - gateway_enabled = getattr(self, 'gateway_config', None) and getattr(self.gateway_config, 'gateway_enabled', False) + gateway_enabled = getattr(self, "gateway_config", None) and getattr( + self.gateway_config, "gateway_enabled", False + ) if gateway_enabled: return "jupyter_server.gateway.v3.managers.GatewayMultiKernelManager" return "jupyter_server.services.kernels.v3.kernelmanager.AsyncMappingKernelManager" - gateway_enabled = getattr(self, 'gateway_config', None) and getattr(self.gateway_config, 'gateway_enabled', False) + gateway_enabled = getattr(self, "gateway_config", None) and getattr( + self.gateway_config, "gateway_enabled", False + ) if gateway_enabled: return "jupyter_server.gateway.managers.GatewayMappingKernelManager" return AsyncMappingKernelManager @@ -1685,7 +1689,9 @@ def _default_kernel_websocket_connection_class( if self.kernels_api_version == 3: # V3 uses shared kernel client connection for both local and gateway return "jupyter_server.services.kernels.v3.connection.client_connection.KernelClientWebsocketConnection" - gateway_enabled = getattr(self, 'gateway_config', None) and getattr(self.gateway_config, 'gateway_enabled', False) + gateway_enabled = getattr(self, "gateway_config", None) and getattr( + self.gateway_config, "gateway_enabled", False + ) if gateway_enabled: return "jupyter_server.gateway.connections.GatewayWebSocketConnection" return ZMQChannelsWebsocketConnection diff --git a/jupyter_server/services/kernels/v3/connection/client_connection.py b/jupyter_server/services/kernels/v3/connection/client_connection.py index afa336242d..3baf9e702e 100644 --- a/jupyter_server/services/kernels/v3/connection/client_connection.py +++ b/jupyter_server/services/kernels/v3/connection/client_connection.py @@ -1,10 +1,14 @@ +from jupyter_client.session import Session from tornado.websocket import WebSocketClosedError -from traitlets import List as TraitletsList, Tuple as TraitletsTuple +from traitlets import List as TraitletsList +from traitlets import Tuple as TraitletsTuple + from jupyter_server.services.kernels.connection.base import ( BaseKernelWebsocketConnection, + deserialize_msg_from_ws_v1, + serialize_msg_to_ws_v1, ) -from jupyter_server.services.kernels.connection.base import deserialize_msg_from_ws_v1, serialize_msg_to_ws_v1 -from jupyter_client.session import Session + from ..message_utils import encode_cell_id_in_message, strip_encoding_from_message @@ -54,7 +58,7 @@ class KernelClientWebsocketConnection(BaseKernelWebsocketConnection): these (msg_type, channel) pairs will be sent to the websocket. Example: [("status", "iopub"), ("execute_reply", "shell")] - """ + """, ) exclude_msg_types = TraitletsList( @@ -71,7 +75,7 @@ class KernelClientWebsocketConnection(BaseKernelWebsocketConnection): Note: Cannot be used together with msg_types. If both are specified, msg_types takes precedence. - """ + """, ) def _get_kernel_client(self): @@ -90,7 +94,7 @@ def _get_kernel_client(self): raise RuntimeError(f"No kernel manager found for kernel {self.kernel_id}") # Get the pre-created kernel client from the kernel manager - if not hasattr(km, 'kernel_client') or km.kernel_client is None: + if not hasattr(km, "kernel_client") or km.kernel_client is None: raise RuntimeError(f"Kernel manager for {self.kernel_id} has no kernel_client") return km.kernel_client @@ -115,8 +119,12 @@ async def connect(self): client.add_listener(self.handle_outgoing_message, msg_types=msg_types_list) elif self.exclude_msg_types is not None: # Convert list of tuples to list for the API - exclude_msg_types_list = [tuple(item) for item in self.exclude_msg_types] if self.exclude_msg_types else None - client.add_listener(self.handle_outgoing_message, exclude_msg_types=exclude_msg_types_list) + exclude_msg_types_list = ( + [tuple(item) for item in self.exclude_msg_types] if self.exclude_msg_types else None + ) + client.add_listener( + self.handle_outgoing_message, exclude_msg_types=exclude_msg_types_list + ) else: # No filtering - listen to all messages (default) client.add_listener(self.handle_outgoing_message) @@ -170,7 +178,9 @@ def handle_outgoing_message(self, channel_name, msg): try: # Validate message has minimum required parts if not msg or len(msg) < 4: - self.log.warning(f"Message on {channel_name} has insufficient parts: {len(msg) if msg else 0}") + self.log.warning( + f"Message on {channel_name} has insufficient parts: {len(msg) if msg else 0}" + ) return # Validate parts are bytes @@ -192,4 +202,6 @@ def handle_outgoing_message(self, channel_name, msg): except WebSocketClosedError: self.log.warning("A Kernel Socket message arrived on a closed websocket channel.") except Exception as err: - self.log.error(f"Error handling outgoing message on {channel_name}: {err}", exc_info=True) \ No newline at end of file + self.log.error( + f"Error handling outgoing message on {channel_name}: {err}", exc_info=True + ) diff --git a/jupyter_server/services/kernels/v3/kernelmanager.py b/jupyter_server/services/kernels/v3/kernelmanager.py index 6cbbc80272..e9f82eefe0 100644 --- a/jupyter_server/services/kernels/v3/kernelmanager.py +++ b/jupyter_server/services/kernels/v3/kernelmanager.py @@ -1,10 +1,12 @@ """Kernel manager for the Apple JupyterLab Kernel Monitor Extension.""" from jupyter_client.multikernelmanager import AsyncMultiKernelManager -from traitlets import Type, observe, Instance, default +from traitlets import Instance, Type, default, observe from jupyter_server.services.kernels.kernelmanager import ( MappingKernelManager, +) +from jupyter_server.services.kernels.kernelmanager import ( ServerKernelManager as _ServerKernelManager, ) @@ -24,22 +26,22 @@ class ServerKernelManager(_ServerKernelManager): client_class = Type( default_value=JupyterServerKernelClient, - klass='jupyter_client.client.KernelClient', + klass="jupyter_client.client.KernelClient", config=True, - help="""The kernel client class to use for creating kernel clients.""" + help="""The kernel client class to use for creating kernel clients.""", ) client_factory = Type( default_value=JupyterServerKernelClient, - klass='jupyter_client.client.KernelClient', + klass="jupyter_client.client.KernelClient", config=True, - help="""The kernel client factory class to use.""" + help="""The kernel client factory class to use.""", ) kernel_client = Instance( - 'jupyter_client.client.KernelClient', + "jupyter_client.client.KernelClient", allow_none=True, - help="""Pre-created kernel client instance. Created on initialization.""" + help="""Pre-created kernel client instance. Created on initialization.""", ) def __init__(self, **kwargs): @@ -49,11 +51,11 @@ def __init__(self, **kwargs): # Create a kernel client instance immediately self.kernel_client = self.client(session=self.session) - @observe('client_class') + @observe("client_class") def _client_class_changed(self, change): """Override parent's _client_class_changed to handle Type trait instead of DottedObjectName.""" # Set client_factory to the same class - self.client_factory = change['new'] + self.client_factory = change["new"] async def _async_post_start_kernel(self, **kwargs): """After kernel starts, connect the kernel client. @@ -97,7 +99,9 @@ async def cleanup_resources(self, restart=False): if restart: # On restart, clear client state but keep connection # The connection will be refreshed in post_start_kernel after restart - self.log.debug(f"Clearing kernel client state for restart of kernel {self.kernel_id}") + self.log.debug( + f"Clearing kernel client state for restart of kernel {self.kernel_id}" + ) self.kernel_client.last_shell_status_time = None self.kernel_client.last_control_status_time = None # Disconnect before restart - will reconnect after @@ -110,11 +114,11 @@ async def cleanup_resources(self, restart=False): self.kernel_client.stop_channels() await super().cleanup_resources(restart=restart) - + class AsyncMappingKernelManager(MappingKernelManager, AsyncMultiKernelManager): # type:ignore[misc] """Custom kernel manager that uses enhanced monitoring kernel manager with v3 API.""" - + @default("kernel_manager_class") def _default_kernel_manager_class(self): - return "jupyter_server.services.kernels.v3.kernelmanager.ServerKernelManager" \ No newline at end of file + return "jupyter_server.services.kernels.v3.kernelmanager.ServerKernelManager" diff --git a/jupyter_server/services/kernels/v3/message_utils.py b/jupyter_server/services/kernels/v3/message_utils.py index c775e38d25..0544bb39fe 100644 --- a/jupyter_server/services/kernels/v3/message_utils.py +++ b/jupyter_server/services/kernels/v3/message_utils.py @@ -12,27 +12,32 @@ - Legacy format (no encoding): "a1b2c3d4_12345_0" """ -from typing import Optional, Tuple, List +from typing import List, Optional, Tuple + from jupyter_client.session import Session class MsgIdError(Exception): """Base exception for message ID operations.""" + pass class InvalidMsgIdFormatError(MsgIdError): """Raised when a message ID has an invalid format.""" + pass class InvalidChannelError(MsgIdError): """Raised when a channel name contains reserved characters.""" + pass class InvalidSrcIdError(MsgIdError): """Raised when a source ID contains reserved characters.""" + pass @@ -45,10 +50,8 @@ def validate_channel(channel: Optional[str]) -> None: Raises: InvalidChannelError: If channel contains ':' character """ - if channel is not None and ':' in channel: - raise InvalidChannelError( - f"Channel name cannot contain ':' character: {channel}" - ) + if channel is not None and ":" in channel: + raise InvalidChannelError(f"Channel name cannot contain ':' character: {channel}") def validate_src_id(src_id: Optional[str]) -> None: @@ -61,20 +64,14 @@ def validate_src_id(src_id: Optional[str]) -> None: InvalidSrcIdError: If src_id contains ':' or '#' characters """ if src_id is not None: - if ':' in src_id: - raise InvalidSrcIdError( - f"Source ID cannot contain ':' character: {src_id}" - ) - if '#' in src_id: - raise InvalidSrcIdError( - f"Source ID cannot contain '#' character: {src_id}" - ) + if ":" in src_id: + raise InvalidSrcIdError(f"Source ID cannot contain ':' character: {src_id}") + if "#" in src_id: + raise InvalidSrcIdError(f"Source ID cannot contain '#' character: {src_id}") def create_msg_id( - base_msg_id: str, - channel: Optional[str] = None, - src_id: Optional[str] = None + base_msg_id: str, channel: Optional[str] = None, src_id: Optional[str] = None ) -> str: """Create a structured message ID with optional channel and source ID encoding. @@ -137,15 +134,15 @@ def parse_msg_id(msg_id: str) -> Tuple[Optional[str], Optional[str], str]: raise InvalidMsgIdFormatError("Message ID cannot be empty") # Split off src_id if present (after #) - if '#' in msg_id: - msg_id_part, src_id = msg_id.split('#', 1) + if "#" in msg_id: + msg_id_part, src_id = msg_id.split("#", 1) else: msg_id_part = msg_id src_id = None # Split channel and base msg_id (before :) - if ':' in msg_id_part: - channel, base_msg_id = msg_id_part.split(':', 1) + if ":" in msg_id_part: + channel, base_msg_id = msg_id_part.split(":", 1) else: # Legacy format - no channel specified channel = None @@ -230,15 +227,15 @@ def encode_channel_in_message_dict(msg_dict: dict, channel: str) -> dict: Modified message dict with channel encoded in header msg_id Examples: - >>> msg = session.msg('kernel_info_request') - >>> msg = encode_channel_in_message_dict(msg, 'shell') + >>> msg = session.msg("kernel_info_request") + >>> msg = encode_channel_in_message_dict(msg, "shell") >>> # msg['header']['msg_id'] now has 'shell:' prefix """ - if 'header' in msg_dict and 'msg_id' in msg_dict['header']: - msg_id = msg_dict['header']['msg_id'] + if "header" in msg_dict and "msg_id" in msg_dict["header"]: + msg_id = msg_dict["header"]["msg_id"] # Only encode if not already encoded if not msg_id.startswith(f"{channel}:"): - msg_dict['header']['msg_id'] = f"{channel}:{msg_id}" + msg_dict["header"]["msg_id"] = f"{channel}:{msg_id}" return msg_dict @@ -311,20 +308,19 @@ def strip_encoding_from_message(msg_list: List[bytes]) -> List[bytes]: # Strip from header msg_id header = session.unpack(msg_copy[0]) - if 'msg_id' in header: - _, _, base_msg_id = parse_msg_id(header['msg_id']) - header['msg_id'] = base_msg_id + if "msg_id" in header: + _, _, base_msg_id = parse_msg_id(header["msg_id"]) + header["msg_id"] = base_msg_id msg_copy[0] = session.pack(header) # Strip from parent_header msg_id parent_header = session.unpack(msg_copy[1]) - if 'msg_id' in parent_header and parent_header['msg_id']: - _, _, base_msg_id = parse_msg_id(parent_header['msg_id']) - parent_header['msg_id'] = base_msg_id + if "msg_id" in parent_header and parent_header["msg_id"]: + _, _, base_msg_id = parse_msg_id(parent_header["msg_id"]) + parent_header["msg_id"] = base_msg_id msg_copy[1] = session.pack(parent_header) return msg_copy except Exception: # If decoding fails, return original message return msg_list - diff --git a/jupyter_server/services/kernels/v3/states.py b/jupyter_server/services/kernels/v3/states.py index 6d6b28ba6d..fb47eb8240 100644 --- a/jupyter_server/services/kernels/v3/states.py +++ b/jupyter_server/services/kernels/v3/states.py @@ -1,5 +1,4 @@ -from enum import Enum -from enum import EnumMeta +from enum import Enum, EnumMeta class StrContainerEnumMeta(EnumMeta): @@ -8,15 +7,17 @@ def __contains__(cls, item): if item == name or item == member.value: return True return False - + + class StrContainerEnum(str, Enum, metaclass=StrContainerEnumMeta): """A Enum object that enables search for items in a normal Enum object based on key and value. """ - + + class ExecutionStates(StrContainerEnum): BUSY = "busy" IDLE = "idle" STARTING = "starting" UNKNOWN = "unknown" - DEAD = "dead" \ No newline at end of file + DEAD = "dead"