diff --git a/jupyter_server/gateway/v3/__init__.py b/jupyter_server/gateway/v3/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/jupyter_server/gateway/v3/managers.py b/jupyter_server/gateway/v3/managers.py
new file mode 100644
index 000000000..61ce1dff3
--- /dev/null
+++ b/jupyter_server/gateway/v3/managers.py
@@ -0,0 +1,233 @@
+"""Gateway kernel manager that integrates with our kernel monitoring system."""
+
+import asyncio
+
+from traitlets import Instance, Type, default
+
+from jupyter_server.gateway.managers import GatewayKernelClient as _GatewayKernelClient
+from jupyter_server.gateway.managers import GatewayKernelManager as _GatewayKernelManager
+from jupyter_server.gateway.managers import GatewayMappingKernelManager
+from jupyter_server.services.kernels.v3.client import JupyterServerKernelClientMixin
+
+
+class GatewayKernelClient(JupyterServerKernelClientMixin, _GatewayKernelClient):
+    """
+    Gateway kernel client that combines our monitoring capabilities with gateway support.
+
+    This client inherits from:
+    - JupyterServerKernelClientMixin: Provides kernel monitoring capabilities, message caching,
+      and execution state tracking that integrates with our kernel monitor system
+    - GatewayKernelClient: Provides gateway communication capabilities for remote kernels
+
+    The combination allows remote gateway kernels to be monitored with the same level of
+    detail as local kernels, including heartbeat monitoring, execution state tracking,
+    and kernel lifecycle management.
+    """
+
+    async def _test_kernel_communication(self, timeout: float = 10.0) -> bool:
+        """Skip kernel_info test for gateway kernels.
+
+        Gateway kernels handle communication differently and the kernel_info
+        test can hang due to message routing differences.
+
+        Returns:
+            bool: Always returns True for gateway kernels
+        """
+        return True
+
+    def _send_message(self, channel_name: str, msg: list[bytes]):
+        """Repackage raw message frames as a dict and send them on a gateway channel.
+
+        Best effort: any failure is logged as a warning instead of being raised.
+
+        Args:
+            channel_name: Prefix of the ``<name>_channel`` attribute to send on
+            msg: Frames ordered ``[header, parent_header, metadata, content, *buffers]``
+        """
+        # Send to gateway channel
+        try:
+            channel = getattr(self, f"{channel_name}_channel", None)
+            if channel and hasattr(channel, "send"):
+                # Convert raw message to gateway format
+                header = self.session.unpack(msg[0])
+                parent_header = self.session.unpack(msg[1])
+                metadata = self.session.unpack(msg[2])
+                content = self.session.unpack(msg[3])
+
+                full_msg = {
+                    "header": header,
+                    "parent_header": parent_header,
+                    "metadata": metadata,
+                    "content": content,
+                    "buffers": msg[4:] if len(msg) > 4 else [],
+                    "channel": channel_name,
+                    "msg_id": header.get("msg_id"),
+                    "msg_type": header.get("msg_type"),
+                }
+
+                channel.send(full_msg)
+        except Exception as e:
+            # Logger.warn is a deprecated alias of Logger.warning
+            self.log.warning(f"Error handling incoming message on gateway: {e}")
+
+    async def _monitor_channel_messages(self, channel_name: str, channel):
+        """Monitor a gateway channel for incoming messages."""
+        try:
+            while channel.is_alive():
+                try:
+                    # Get message from gateway channel queue
+                    message = await channel.get_msg()
+
+                    # Update execution state from status messages
+                    # Gateway messages are already deserialized dicts
+                    self._update_execution_state_from_status(
+                        channel_name,
+                        message,
+                        parent_msg_id=message.get("parent_header", {}).get("msg_id"),
+                        execution_state=message.get("content", {}).get("execution_state"),
+                    )
+
+                    # Serialize message to standard format for listeners
+                    # Gateway messages are dicts, convert to list[bytes] format
+                    # session.serialize() returns: [b'', signature, header, parent_header, metadata, content, buffers...]
+                    serialized = self.session.serialize(message)
+
+                    # Skip delimiter (index 0) and signature (index 1) to get [header, parent_header, metadata, content, ...]
+                    if (
+                        serialized and len(serialized) >= 6
+                    ):  # Need delimiter + signature + 4 message parts
+                        msg_list = serialized[2:]
+                    else:
+                        self.log.warning(
+                            f"Gateway message too short: {len(serialized) if serialized else 0} parts"
+                        )
+                        continue
+
+                    # Route to listeners
+                    await self._route_to_listeners(channel_name, msg_list)
+
+                except asyncio.TimeoutError:
+                    # No message available, continue loop
+                    continue
+                except Exception as e:
+                    # No `continue` here: fall through to the sleep below so a channel
+                    # that keeps failing cannot spin the event loop
+                    self.log.debug(f"Error processing gateway message in {channel_name}: {e}")
+
+                await asyncio.sleep(0.01)
+
+        except asyncio.CancelledError:
+            pass
+        except Exception as e:
+            self.log.error(f"Gateway channel monitoring failed for {channel_name}: {e}")
+
+
+class GatewayKernelManager(_GatewayKernelManager):
+    """
+    Gateway kernel manager that uses our enhanced gateway kernel client.
+
+    This manager inherits from jupyter_server's GatewayKernelManager and configures it
+    to use our GatewayKernelClient, which provides:
+
+    - Gateway communication capabilities for remote kernels
+    - Kernel monitoring integration (heartbeat, execution state tracking)
+    - Message ID encoding with channel and src_id using simple string operations
+    - Full compatibility with our kernel monitor extension
+    - Pre-created kernel client instance stored as a property
+    - Automatic client connection/disconnection on kernel start/shutdown
+
+    When jupyter_server is configured to use a gateway, this manager ensures that
+    remote kernels receive the same level of monitoring as local kernels.
+    """
+
+    # Configure the manager to use our enhanced gateway client
+    client_class = GatewayKernelClient
+    client_factory = GatewayKernelClient
+
+    kernel_client = Instance(
+        "jupyter_client.client.KernelClient",
+        allow_none=True,
+        help="""Pre-created kernel client instance. Created on initialization.""",
+    )
+
+    def __init__(self, **kwargs):
+        """Initialize the kernel manager and create a kernel client instance."""
+        super().__init__(**kwargs)
+
+        # Create a kernel client instance immediately
+        self.kernel_client = self.client(session=self.session)
+
+    async def post_start_kernel(self, **kwargs):
+        """After kernel starts, connect the kernel client.
+
+        This method is called after the kernel has been successfully started.
+        It loads the latest connection info (with ports set by provisioner)
+        and connects the kernel client to the kernel.
+
+        Note: If you override this method, make sure to call super().post_start_kernel(**kwargs)
+        to ensure the kernel client connects properly.
+        """
+        await super().post_start_kernel(**kwargs)
+
+        try:
+            # Load latest connection info from kernel manager
+            # The provisioner has now set the real ports
+            self.kernel_client.load_connection_info(self.get_connection_info(session=True))
+
+            # Connect the kernel client
+            success = await self.kernel_client.connect()
+
+            if not success:
+                raise RuntimeError(f"Failed to connect kernel client for kernel {self.kernel_id}")
+
+            self.log.info(f"Successfully connected kernel client for kernel {self.kernel_id}")
+
+        except Exception as e:
+            self.log.error(f"Failed to connect kernel client: {e}")
+            # Re-raise to fail the kernel start
+            raise
+
+    async def cleanup_resources(self, restart=False):
+        """Stop the kernel client's listeners and channels, then clean up the manager.
+
+        Parameters
+        ----------
+        restart : bool
+            If True, the kernel is being restarted: the client's last status
+            timestamps are cleared before it is disconnected. If False, the
+            client is simply disconnected.
+        """
+        if self.kernel_client:
+            if restart:
+                # On restart, clear client state before disconnecting
+                # The connection will be refreshed in post_start_kernel after restart
+                self.log.debug(
+                    f"Clearing kernel client state for restart of kernel {self.kernel_id}"
+                )
+                self.kernel_client.last_shell_status_time = None
+                self.kernel_client.last_control_status_time = None
+            else:
+                # On shutdown, fully disconnect the client
+                self.log.debug(f"Disconnecting kernel client for kernel {self.kernel_id}")
+
+            # Both paths tear down listeners and channels
+            # NOTE(review): the client's connection-ready state is not reset here;
+            # confirm that connect() re-establishes the channels after a restart
+            await self.kernel_client.stop_listening()
+            self.kernel_client.stop_channels()
+
+        await super().cleanup_resources(restart=restart)
+
+
+class GatewayMultiKernelManager(GatewayMappingKernelManager):
+    """Custom kernel manager that uses enhanced monitoring kernel manager with v3 API."""
+
+    @default("kernel_manager_class")
+    def _default_kernel_manager_class(self):
+        return "jupyter_server.gateway.v3.managers.GatewayKernelManager"
+
+    def start_watching_activity(self, kernel_id):
+        pass
+
+    def stop_buffering(self, kernel_id):
+        pass
diff --git a/jupyter_server/serverapp.py b/jupyter_server/serverapp.py
index 1afbef4d0..3a74f5306 100644
--- a/jupyter_server/serverapp.py
+++ b/jupyter_server/serverapp.py
@@ -110,6 +110,9 @@
     GatewayMappingKernelManager,
     GatewaySessionManager,
 )
+from jupyter_server.gateway.v3.managers import (
+    GatewayMultiKernelManager,
+)
 from jupyter_server.log import log_request
 from jupyter_server.prometheus.metrics import (
     ACTIVE_DURATION,
@@ -131,6 +134,12 @@
     AsyncMappingKernelManager,
     MappingKernelManager,
 )
+from jupyter_server.services.kernels.v3.connection.client_connection import (
+    KernelClientWebsocketConnection,
+)
+from jupyter_server.services.kernels.v3.kernelmanager import (
+    AsyncMappingKernelManager as V3AsyncMappingKernelManager,
+)
 from jupyter_server.services.sessions.sessionmanager import
SessionManager
 from jupyter_server.utils import (
     JupyterServerAuthWarning,
@@ -829,6 +838,13 @@ def start(self):
     extensions.
     """,
 )
+flags["kernels-v3"] = (
+    {"ServerApp": {"kernels_api_version": 3}},
+    _i18n(
+        "Enable the next-generation kernel API (v3) with shared kernel clients, "
+        "improved message routing, and enhanced kernel monitoring."
+    ),
+)
 
 # Add notebook manager flags
 
@@ -901,6 +917,10 @@ class ServerApp(JupyterApp):
         Authorizer,
         EventLogger,
         ZMQChannelsWebsocketConnection,
+        # V3 Kernel API classes
+        V3AsyncMappingKernelManager,
+        KernelClientWebsocketConnection,
+        GatewayMultiKernelManager,
     ]
 
     subcommands: dict[str, t.Any] = {
@@ -1612,6 +1632,17 @@ def template_file_path(self) -> list[str]:
         help=_i18n("The content manager class to use."),
     )
 
+    kernels_api_version = Integer(
+        2,
+        config=True,
+        help=_i18n(
+            "Kernel API version to use. "
+            "Version 2 (default): Standard kernel API with direct ZMQ connections. "
+            "Version 3: Next-generation API with shared kernel clients, "
+            "improved message routing, and enhanced kernel monitoring."
+        ),
+    )
+
     kernel_manager_class = Type(
         klass=MappingKernelManager,
         config=True,
@@ -1620,7 +1651,15 @@ def template_file_path(self) -> list[str]:
 
     @default("kernel_manager_class")
     def _default_kernel_manager_class(self) -> t.Union[str, type[AsyncMappingKernelManager]]:
-        if self.gateway_config.gateway_enabled:
+        # Resolve the gateway flag once; both API versions branch on it
+        gateway_enabled = getattr(self, "gateway_config", None) and getattr(
+            self.gateway_config, "gateway_enabled", False
+        )
+        if self.kernels_api_version == 3:
+            if gateway_enabled:
+                return "jupyter_server.gateway.v3.managers.GatewayMultiKernelManager"
+            return "jupyter_server.services.kernels.v3.kernelmanager.AsyncMappingKernelManager"
+        if gateway_enabled:
             return "jupyter_server.gateway.managers.GatewayMappingKernelManager"
         return AsyncMappingKernelManager
 
@@ -1645,7 +1684,13 @@ def _default_session_manager_class(self) -> t.Union[str, type[SessionManager]]:
     def _default_kernel_websocket_connection_class(
         self,
     ) -> t.Union[str, type[ZMQChannelsWebsocketConnection]]:
-        if self.gateway_config.gateway_enabled:
+        if self.kernels_api_version == 3:
+            # V3 uses shared kernel client connection for both local and gateway
+            return "jupyter_server.services.kernels.v3.connection.client_connection.KernelClientWebsocketConnection"
+        gateway_enabled = getattr(self, "gateway_config", None) and getattr(
+            self.gateway_config, "gateway_enabled", False
+        )
+        if gateway_enabled:
             return "jupyter_server.gateway.connections.GatewayWebSocketConnection"
         return ZMQChannelsWebsocketConnection
 
diff --git a/jupyter_server/services/kernels/v3/__init__.py b/jupyter_server/services/kernels/v3/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/jupyter_server/services/kernels/v3/client.py b/jupyter_server/services/kernels/v3/client.py
new file mode 100644
index 000000000..ecf05e9aa
--- /dev/null
+++ b/jupyter_server/services/kernels/v3/client.py
@@
-0,0 +1,769 @@
+import asyncio
+import time
+import typing as t
+from datetime import datetime, timezone
+
+from jupyter_client.asynchronous.client import AsyncKernelClient
+from jupyter_client.channels import AsyncZMQSocketChannel
+from jupyter_client.channelsabc import ChannelABC
+from traitlets import HasTraits, Type
+
+from .message_utils import encode_channel_in_message_dict, parse_msg_id
+from .states import ExecutionStates
+
+
+class NamedAsyncZMQSocketChannel(AsyncZMQSocketChannel):
+    """Prepends the channel name to all message IDs to this socket."""
+
+    channel_name = "unknown"
+
+    def send(self, msg):
+        """Send a message with automatic channel encoding."""
+        msg = encode_channel_in_message_dict(msg, self.channel_name)
+        return super().send(msg)
+
+
+class ShellChannel(NamedAsyncZMQSocketChannel):
+    """Shell channel that automatically encodes 'shell' in outgoing msg_ids."""
+
+    channel_name = "shell"
+
+
+class ControlChannel(NamedAsyncZMQSocketChannel):
+    """Control channel that automatically encodes 'control' in outgoing msg_ids."""
+
+    channel_name = "control"
+
+
+class StdinChannel(NamedAsyncZMQSocketChannel):
+    """Stdin channel that automatically encodes 'stdin' in outgoing msg_ids."""
+
+    channel_name = "stdin"
+
+
+class JupyterServerKernelClientMixin(HasTraits):
+    """Mixin that enhances AsyncKernelClient with listener API, message queuing, and channel encoding.
+
+    Key Features:
+
+    1. **Listener API**: Register multiple listeners to receive kernel messages without blocking.
+       - `add_listener()`: Add a callback function to receive messages from the kernel
+       - `remove_listener()`: Remove a registered listener
+       - Supports message filtering by type and channel
+       - Multiple listeners can be registered (e.g., multiple WebSocket connections)
+
+    2. **Message Queuing**: Queue messages that arrive before the kernel client is ready.
+       - Messages from WebSockets are queued during kernel startup
+       - Queued messages are processed once the kernel connection is established
+       - Prevents message loss during the connection handshake
+       - Configurable queue size to prevent memory issues
+
+    3. **Channel Encoding**: Automatically encode channel names in all outgoing message IDs.
+       - All messages sent through shell, control, or stdin channels get the channel name prepended
+       - Format: `{channel}:{base_msg_id}` (e.g., "shell:abc123_456_0")
+       - Makes it easy to identify which channel status messages came from
+       - Enables proper execution state tracking (shell vs control channel responses)
+       - Uses custom channel classes (ShellChannel, ControlChannel, StdinChannel)
+
+    This mixin is designed to work with Jupyter Server's multi-websocket architecture where
+    a single kernel client is shared across multiple WebSocket connections.
+    """
+
+    # Track kernel execution state (simplified - just a string)
+    execution_state: str = ExecutionStates.UNKNOWN.value
+
+    # Track kernel activity
+    last_activity: t.Optional[datetime] = None
+
+    # Track last status message time per channel (shell and control)
+    last_shell_status_time: t.Optional[datetime] = None
+    last_control_status_time: t.Optional[datetime] = None
+
+    # Connection test configuration
+    connection_test_timeout: float = 120.0  # Total timeout for connection test in seconds
+    connection_test_check_interval: float = 0.1  # How often to check for messages in seconds
+    connection_test_retry_interval: float = (
+        10.0  # How often to retry kernel_info requests in seconds
+    )
+
+    # Override channel classes to use our custom ones with automatic encoding
+    shell_channel_class = Type(ShellChannel)
+    control_channel_class = Type(ControlChannel)
+    stdin_channel_class = Type(StdinChannel)
+
+    # Listener registry - don't use a Traitlets container, just a plain Python dict
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._listeners = {}  # Maps callback -> filter config
+        self._listening = False
+
+        # Connection state tracking
+        self._connecting = False
+        self._connection_ready = False
+        self._connection_ready_event = asyncio.Event()
+
+        # Message queue for messages received before connection is ready
+        self._queued_messages = []
+        self._max_queue_size = 1000  # Prevent memory issues
+
+        # Strong references to fire-and-forget tasks (see _spawn_background_task)
+        self._background_tasks = set()
+
+        # Note: The session is already EncodedMsgIdSession, created by the KernelManager
+        # No need to replace it here
+
+    def add_listener(
+        self,
+        callback: t.Callable[[str, list[bytes]], None],
+        msg_types: t.Optional[list[tuple[str, str]]] = None,
+        exclude_msg_types: t.Optional[list[tuple[str, str]]] = None,
+    ):
+        """Add a listener to be called when messages are received.
+
+        Args:
+            callback: Function that takes (channel_name, msg_bytes) as arguments
+            msg_types: Optional list of (msg_type, channel) tuples to include.
+                      If provided, only messages matching these filters will be sent to the listener.
+                      Example: [("status", "iopub"), ("execute_reply", "shell")]
+            exclude_msg_types: Optional list of (msg_type, channel) tuples to exclude.
+                              If provided, messages matching these filters will NOT be sent to the listener.
+                              Example: [("status", "iopub")]
+
+        Raises:
+            ValueError: If both msg_types and exclude_msg_types are provided
+
+        Note:
+            - msg_types and exclude_msg_types are mutually exclusive
+            - If neither is provided (or the given list is empty), all messages are sent
+        """
+        if msg_types is not None and exclude_msg_types is not None:
+            raise ValueError("Cannot specify both msg_types and exclude_msg_types")
+
+        # Store the listener with its filter configuration
+        self._listeners[callback] = {
+            "msg_types": set(msg_types) if msg_types else None,
+            "exclude_msg_types": set(exclude_msg_types) if exclude_msg_types else None,
+        }
+
+    def remove_listener(self, callback: t.Callable[[str, list[bytes]], None]):
+        """Remove a listener."""
+        self._listeners.pop(callback, None)
+
+    def _spawn_background_task(self, coro) -> asyncio.Task:
+        """Schedule ``coro`` as a task and keep a reference to it until it is done.
+
+        The event loop only holds weak references to tasks, so the result of a
+        bare ``asyncio.create_task()`` call may be garbage collected mid-flight.
+        """
+        task = asyncio.create_task(coro)
+        self._background_tasks.add(task)
+        task.add_done_callback(self._background_tasks.discard)
+        return task
+
+    def mark_connection_ready(self):
+        """Mark the connection as ready and process queued messages."""
+        if not self._connection_ready:
+            self._connecting = False
+            self._connection_ready = True
+            self._connection_ready_event.set()
+
+            # Process queued messages
+            self._spawn_background_task(self._process_queued_messages())
+
+    async def wait_for_connection_ready(self, timeout: float = 30.0) -> bool:
+        """Wait for the connection to be ready."""
+        try:
+            await asyncio.wait_for(self._connection_ready_event.wait(), timeout=timeout)
+            return True
+        except asyncio.TimeoutError:
+            return False
+
+    async def _process_queued_messages(self):
+        """Process all messages that were queued during startup."""
+        self.log.info(f"Processing {len(self._queued_messages)} queued messages")
+
+        queued_messages = self._queued_messages.copy()
+        self._queued_messages.clear()
+
+        for channel_name, msg in queued_messages:
+            try:
+                # Send queued messages to the kernel (these are incoming from websockets)
+                self._send_message(channel_name, msg)
+            except Exception as e:
+                self.log.error(f"Error processing queued message: {e}")
+
+    def _queue_message_if_not_ready(self, channel_name: str, msg: list[bytes]) -> bool:
+        """Queue a message if connection is not ready. Returns True if queued."""
+        if not self._connection_ready:
+            if len(self._queued_messages) < self._max_queue_size:
+                self._queued_messages.append((channel_name, msg))
+                return True
+            else:
+                # Queue is full, drop oldest message
+                self._queued_messages.pop(0)
+                self._queued_messages.append((channel_name, msg))
+                self.log.warning("Message queue full, dropping oldest message")
+                return True
+        return False
+
+    def _send_message(self, channel_name: str, msg: list[bytes]):
+        """Send raw message frames to the kernel on the named channel.
+
+        Best effort: failures are logged as warnings instead of being raised.
+
+        Args:
+            channel_name: Prefix of the ``<name>_channel`` attribute to send on
+            msg: Message frames passed unchanged to ``session.send_raw``
+        """
+        # Route message to the appropriate kernel channel
+        channel = getattr(self, f"{channel_name}_channel", None)
+        if channel is None:
+            # Report the real cause instead of an AttributeError on None below
+            self.log.warning(f"Cannot send message: no '{channel_name}' channel")
+            return
+
+        try:
+            channel.session.send_raw(channel.socket, msg)
+        except Exception as e:
+            # Logger.warn is a deprecated alias of Logger.warning
+            self.log.warning(f"Error handling incoming message: {e}")
+
+    def handle_incoming_message(self, channel_name: str, msg: list[bytes]):
+        """Handle incoming kernel messages and encode channel in msg_id.
+
+        This method processes incoming kernel messages from WebSocket clients.
+        It prepends the channel name to the msg_id for internal routing.
+
+        Args:
+            channel_name: The channel the message came from ('shell', 'iopub', etc.)
+            msg: The raw message bytes (already deserialized from websocket format)
+        """
+        # Validate message has content
+        if not msg or len(msg) == 0:
+            return
+
+        # Prepend channel to msg_id for internal routing
+        try:
+            header = self.session.unpack(msg[0])
+            msg_id = header["msg_id"]
+
+            # Check if msg_id already has channel encoded
+            if not msg_id.startswith(f"{channel_name}:"):
+                # Prepend channel
+                header["msg_id"] = f"{channel_name}:{msg_id}"
+                msg[0] = self.session.pack(header)
+
+        except Exception as e:
+            self.log.debug(f"Error encoding channel in incoming message ID: {e}")
+
+        # If connection is not ready, queue the message
+        if self._queue_message_if_not_ready(channel_name, msg):
+            return
+
+        self._send_message(channel_name, msg)
+
+    def handle_outgoing_message(self, channel_name: str, msg: list[bytes]):
+        """Public API for manufacturing messages to send to kernel client listeners.
+
+        This allows external code to simulate kernel messages and send them to all
+        registered listeners, useful for testing and message injection.
+
+        Args:
+            channel_name: The channel the message came from ('shell', 'iopub', etc.)
+            msg: The raw message bytes
+        """
+        # Route to all listeners in the background; the task is tracked until done
+        self._spawn_background_task(self._route_to_listeners(channel_name, msg))
+
+    async def _route_to_listeners(self, channel_name: str, msg: list[bytes]):
+        """Route message to all registered listeners based on their filters."""
+        if not self._listeners:
+            return
+
+        # Validate message format before routing
+        if not msg or len(msg) < 4:
+            self.log.warning(
+                f"Cannot route malformed message on {channel_name}: {len(msg) if msg else 0} parts (expected at least 4)"
+            )
+            return
+
+        # Extract message type for filtering
+        msg_type = None
+        try:
+            header = self.session.unpack(msg[0]) if msg and len(msg) > 0 else {}
+            msg_type = header.get("msg_type", "unknown")
+        except Exception as e:
+            self.log.debug(f"Error extracting message type: {e}")
+            msg_type = "unknown"
+
+        # Create tasks for listeners that match the filter
+        tasks = []
+        for listener, filter_config in self._listeners.items():
+            if self._should_route_to_listener(msg_type, channel_name, filter_config):
+                task = asyncio.create_task(self._call_listener(listener, channel_name, msg))
+                tasks.append(task)
+
+        # Wait for all listeners to complete
+        if tasks:
+            await asyncio.gather(*tasks, return_exceptions=True)
+
+    def _should_route_to_listener(
+        self, msg_type: str, channel_name: str, filter_config: dict
+    ) -> bool:
+        """Determine if a message should be routed to a listener based on its filter configuration.
+
+        Args:
+            msg_type: The message type (e.g., "status", "execute_reply")
+            channel_name: The channel name (e.g., "iopub", "shell")
+            filter_config: Dictionary with 'msg_types' and 'exclude_msg_types' keys
+
+        Returns:
+            bool: True if the message should be routed to the listener, False otherwise
+        """
+        msg_types = filter_config.get("msg_types")
+        exclude_msg_types = filter_config.get("exclude_msg_types")
+
+        # If msg_types is specified (inclusion filter)
+        if msg_types is not None:
+            return (msg_type, channel_name) in msg_types
+
+        # If exclude_msg_types is specified (exclusion filter)
+        if exclude_msg_types is not None:
+            return (msg_type, channel_name) not in exclude_msg_types
+
+        # No filter specified - route all messages
+        return True
+
+    async def _call_listener(self, listener: t.Callable, channel_name: str, msg: list[bytes]):
+        """Call a single listener, ensuring it's async and handling errors."""
+        try:
+            result = listener(channel_name, msg)
+            if asyncio.iscoroutine(result):
+                await result
+        except Exception as e:
+            self.log.error(f"Error in listener {listener}: {e}")
+
+    def _update_execution_state_from_status(
+        self,
+        channel_name: str,
+        msg_dict: dict,
+        parent_msg_id: str = None,
+        execution_state: str = None,
+    ):
+        """Update execution state from a status message if it originated from shell channel.
+
+        This method checks if a status message on the iopub channel originated from a shell
+        channel request before updating the execution state. This prevents control channel
+        status messages from affecting execution state tracking.
+
+        Additionally tracks the last time we received status messages from shell and control
+        channels for connection monitoring purposes.
+
+        Args:
+            channel_name: The channel the message came from (should be 'iopub')
+            msg_dict: The deserialized message dictionary
+            parent_msg_id: Optional parent message ID (extracted if not provided)
+            execution_state: Optional execution state (extracted if not provided)
+        """
+        if channel_name != "iopub" or msg_dict.get("msg_type") != "status":
+            return
+
+        try:
+            # Extract parent_msg_id if not provided
+            if parent_msg_id is None:
+                parent_header = msg_dict.get("parent_header", {})
+                if isinstance(parent_header, bytes):
+                    parent_header = self.session.unpack(parent_header)
+                parent_msg_id = parent_header.get("msg_id")
+
+            # Default so the checks below are well-defined without a parent msg_id
+            parent_channel = None
+            # Parse parent_msg_id to extract channel
+            if parent_msg_id:
+                try:
+                    parent_channel, _, _ = parse_msg_id(parent_msg_id)
+                except Exception as e:
+                    self.log.debug(f"Error parsing parent msg_id '{parent_msg_id}': {e}")
+                    parent_channel = None
+
+            # Track last status message time for both shell and control channels
+            current_time = datetime.now(timezone.utc)
+            if parent_channel == "shell":
+                self.last_shell_status_time = current_time
+                self.last_activity = current_time
+            elif parent_channel == "control":
+                self.last_control_status_time = current_time
+
+            # Only update execution state if message came from shell channel
+            if parent_channel == "shell":
+                # Extract execution_state if not provided
+                if execution_state is None:
+                    content = msg_dict.get("content", {})
+                    if isinstance(content, bytes):
+                        content = self.session.unpack(content)
+                    execution_state = content.get("execution_state")
+
+                if execution_state:
+                    old_state = self.execution_state
+                    self.execution_state = execution_state
+                    self.log.debug(f"Execution state: {old_state} -> {execution_state}")
+            elif parent_channel is None:
+                # Log when we can't determine parent channel
+                if execution_state is None:
+                    content = msg_dict.get("content", {})
+                    if isinstance(content, bytes):
+                        content = self.session.unpack(content)
+                    execution_state = content.get("execution_state")
+                self.log.debug(
+                    f"Ignoring status message - cannot parse parent channel (state would be: {execution_state})"
+                )
+        except Exception as e:
+            self.log.debug(f"Error updating execution state from status message: {e}")
+
+    async def broadcast_state(self):
+        """Broadcast current kernel execution state to all listeners.
+
+        This method creates and sends a status message to all kernel listeners
+        (typically WebSocket connections) to inform them of the current kernel
+        execution state.
+
+        The status message is manufactured using the session's message format
+        and sent through the normal listener routing mechanism.
+
+        Note: Only broadcasts if execution_state is a valid kernel protocol state.
+        Skips broadcasting if state is "unknown" (not part of kernel protocol).
+        """
+        try:
+            # Don't broadcast "unknown" state - it's not part of the kernel protocol
+            # Valid states are: starting, idle, busy, restarting, dead
+            if self.execution_state == ExecutionStates.UNKNOWN.value:
+                self.log.debug("Skipping broadcast_state - execution state is unknown")
+                return
+
+            # Create status message
+            msg_dict = self.session.msg("status", content={"execution_state": self.execution_state})
+
+            # Serialize the message
+            # session.serialize() returns:
+            # [b'', signature, header, parent_header, metadata, content, buffers...]
+            serialized = self.session.serialize(msg_dict)
+
+            # Skip delimiter (index 0) and signature (index 1) to get message parts
+            # Result: [header, parent_header, metadata, content, buffers...]
+            if len(serialized) < 6:  # Need delimiter + signature + 4 message parts minimum
+                self.log.warning(
+                    f"broadcast_state: serialized message too short: {len(serialized)} parts"
+                )
+                return
+
+            msg_parts = serialized[2:]  # Skip delimiter and signature
+
+            # Send to listeners
+            self.handle_outgoing_message("iopub", msg_parts)
+
+        except Exception as e:
+            self.log.warning(f"Failed to broadcast state: {e}")
+
+    async def start_listening(self):
+        """Start listening for messages and monitoring channels."""
+        # Start background tasks to monitor channels for messages
+        self._monitoring_tasks = []
+        self._listening = True
+
+        # Monitor each channel for incoming messages
+        for channel_name in ["iopub", "shell", "stdin", "control"]:
+            channel = getattr(self, f"{channel_name}_channel", None)
+            if channel and channel.is_alive():
+                task = asyncio.create_task(self._monitor_channel_messages(channel_name, channel))
+                self._monitoring_tasks.append(task)
+
+        self.log.info(f"Started listening with {len(self._listeners)} listeners")
+
+    async def stop_listening(self):
+        """Stop listening for messages."""
+        # Stop monitoring tasks
+        if hasattr(self, "_monitoring_tasks"):
+            for task in self._monitoring_tasks:
+                task.cancel()
+            self._monitoring_tasks = []
+        self._listening = False
+
+        self.log.info("Stopped listening")
+
+    async def _monitor_channel_messages(self, channel_name: str, channel: ChannelABC):
+        """Monitor a channel for incoming messages and route them to listeners."""
+        try:
+            while channel.is_alive():
+                try:
+                    # Check if there's a message ready (non-blocking)
+                    has_message = await channel.msg_ready()
+                    if has_message:
+                        msg = await channel.socket.recv_multipart()
+
+                        # For deserialization and state tracking, use feed_identities to strip routing frames
+                        idents, msg_list = channel.session.feed_identities(msg)
+
+                        # Deserialize WITHOUT content for performance (content=False)
+                        msg_dict = channel.session.deserialize(msg_list, content=False)
+
+                        # Update execution state from status messages
+                        self._update_execution_state_from_status(channel_name, msg_dict)
+
+                        # Route to listeners with msg_list
+                        # After feed_identities, msg_list has format (delimiter already removed):
+                        # [signature, header, parent_header, metadata, content, ...buffers]
+                        # Skip signature (index 0) to get: [header, parent_header, metadata, content, ...buffers]
+                        if msg_list and len(msg_list) >= 5:
+                            await self._route_to_listeners(channel_name, msg_list[1:])
+                        else:
+                            self.log.warning(
+                                f"Received malformed message on {channel_name}: {len(msg_list) if msg_list else 0} parts"
+                            )
+
+                except Exception as e:
+                    # Log and keep monitoring; fall through to the sleep below so a
+                    # channel that keeps failing cannot spin the event loop
+                    self.log.debug(f"Error processing message in {channel_name}: {e}")
+
+                # Small sleep to avoid busy waiting
+                await asyncio.sleep(0.01)
+
+        except asyncio.CancelledError:
+            pass
+        except Exception as e:
+            self.log.error(f"Channel monitoring failed for {channel_name}: {e}")
+
+    async def _test_kernel_communication(self, timeout: float = None) -> bool:
+        """Test kernel communication by monitoring execution state and sending kernel_info requests.
+
+        This method uses a robust heuristic to determine if the kernel is connected:
+        1. Checks if execution state is 'idle' (indicates shell channel is responding)
+        2. Sends kernel_info requests to both shell and control channels in parallel
+        3. Monitors for status message responses from either channel
+        4. Retries periodically if no response is received
+        5. Considers kernel connected if we receive any status messages, even if state is 'busy'
+
+        Args:
+            timeout: Total timeout for connection test in seconds (uses connection_test_timeout if not provided)
+
+        Returns:
+            bool: True if communication test successful, False otherwise
+        """
+        if timeout is None:
+            timeout = self.connection_test_timeout
+
+        # Elapsed time uses the monotonic clock so wall-clock adjustments cannot
+        # shorten or extend the timeout; status timestamps stay timezone-aware
+        start_time = time.monotonic()
+        connection_attempt_time = datetime.now(timezone.utc)
+
+        self.log.info("Starting kernel communication test")
+
+        # Give the kernel a moment to be ready to receive messages
+        # Heartbeat beating doesn't guarantee the kernel is ready for requests
+        await asyncio.sleep(0.5)
+
+        # Send initial kernel_info requests immediately
+        try:
+            await asyncio.gather(
+                self._send_kernel_info_shell(),
+                self._send_kernel_info_control(),
+                return_exceptions=True,
+            )
+        except Exception as e:
+            self.log.debug(f"Error sending initial kernel_info requests: {e}")
+
+        last_kernel_info_time = time.monotonic()
+
+        while time.monotonic() - start_time < timeout:
+            elapsed = time.monotonic() - start_time
+
+            # Check if execution state is idle (shell channel responding and kernel ready)
+            if self.execution_state == ExecutionStates.IDLE.value:
+                self.log.info("Kernel communication test succeeded: execution state is idle")
+                return True
+
+            # Check if we've received any status messages since connection attempt
+            # This indicates the kernel is connected, even if busy executing something
+            if (
+                self.last_shell_status_time
+                and self.last_shell_status_time > connection_attempt_time
+            ):
+                self.log.info("Kernel communication test succeeded: received shell status message")
+                return True
+
+            if (
+                self.last_control_status_time
+                and self.last_control_status_time > connection_attempt_time
+            ):
+                self.log.info(
+                    "Kernel communication test succeeded: received control status message"
+                )
+                return True
+
+            # Send kernel_info requests at regular intervals
+            time_since_last_request = time.monotonic() - last_kernel_info_time
+            if time_since_last_request >= self.connection_test_retry_interval:
+                self.log.debug(
+                    f"Sending kernel_info requests to shell and control channels (elapsed: {elapsed:.1f}s)"
+                )
+
+                try:
+                    # Send kernel_info to both channels in parallel (no reply expected)
+                    await asyncio.gather(
+                        self._send_kernel_info_shell(),
+                        self._send_kernel_info_control(),
+                        return_exceptions=True,
+                    )
+                    last_kernel_info_time = time.monotonic()
+                except Exception as e:
+                    self.log.debug(f"Error sending kernel_info requests: {e}")
+
+            # Wait before next check
+            await asyncio.sleep(self.connection_test_check_interval)
+
+        self.log.error(f"Kernel communication test failed: no response after {timeout}s")
+        return False
+
+    async def _send_kernel_info_shell(self):
+        """Send kernel_info request on shell channel (no reply expected)."""
+        try:
+            if hasattr(self, "kernel_info"):
+                # Send without waiting for reply
+                self.kernel_info()
+        except Exception as e:
+            self.log.debug(f"Error sending kernel_info on shell channel: {e}")
+
+    async def _send_kernel_info_control(self):
+        """Send kernel_info request on control channel (no reply expected)."""
+        try:
+            if hasattr(self.control_channel, "send"):
+                msg = self.session.msg("kernel_info_request")
+                # Channel wrapper will automatically encode channel in msg_id
+                self.control_channel.send(msg)
+        except Exception as e:
+            self.log.debug(f"Error sending kernel_info on control channel: {e}")
+
+    async def connect(self) -> bool:
+        """Connect to the kernel and verify communication.
+
+        This method:
+        1. Starts all channels
+        2. Begins listening for messages
+        3. Waits for heartbeat to confirm connectivity
+        4. Tests kernel communication with configurable retries
+        5. Marks connection as ready
+
+        Returns:
+            bool: True if connection successful, False otherwise
+        """
+        if self._connecting:
+            return await self.wait_for_connection_ready()
+
+        if self._connection_ready:
+            return True
+
+        self._connecting = True
+
+        try:
+            self.execution_state = ExecutionStates.BUSY.value
+            self.last_activity = datetime.now(timezone.utc)
+
+            # Handle both sync and async versions of start_channels
+            result = self.start_channels()
+            if asyncio.iscoroutine(result):
+                await result
+
+            # Verify channels are running (explicit raise: asserts vanish under -O)
+            if not self.channels_running:
+                raise RuntimeError("Kernel channels failed to start")
+
+            # Start our listening
+            await self.start_listening()
+
+            # Unpause heartbeat channel if method exists
+            if hasattr(self.hb_channel, "unpause"):
+                self.hb_channel.unpause()
+
+            # Wait for heartbeat
+            attempt = 0
+            max_attempts = 10
+            while not self.hb_channel.is_beating():
+                attempt += 1
+                if attempt > max_attempts:
+                    raise Exception("The kernel took too long to connect to the Kernel Sockets.")
+                await asyncio.sleep(0.1)
+
+            # Test kernel communication (handles retries internally)
+            if not await self._test_kernel_communication():
+                self.log.error(
+                    f"Kernel communication test failed after {self.connection_test_timeout}s timeout"
+                )
+                # Clear the in-progress flag so that a later connect() call retries
+                self._connecting = False
+                return False
+
+            # Mark connection as ready and process queued messages
+            self.mark_connection_ready()
+
+            # Update execution state to idle if it's not already set
+            # (it might already be idle if we received a status message during connection test)
+            if self.execution_state == ExecutionStates.BUSY.value:
+                self.execution_state = ExecutionStates.IDLE.value
+                self.last_activity = datetime.now(timezone.utc)
+
+            self.log.info("Successfully connected to kernel")
+            return True
+
+        except Exception as e:
+            self.log.error(f"Failed to connect to kernel: {e}")
+            self._connecting = False
+            return False
+
+    async def disconnect(self):
+        """Disconnect from the kernel and reset connection state.
+
+        This method:
+        1. Stops listening for messages
+        2. Stops all channels
+        3.
Resets connection state flags + 4. Clears channel references + + Note: Does not remove listeners - they will be preserved for reconnection. + """ + # Stop listening for messages + await self.stop_listening() + + # Stop all channels + self.stop_channels() + + # Reset connection state + self._connecting = False + self._connection_ready = False + self._connection_ready_event.clear() + + # Clear channel references + self._shell_channel = None + self._iopub_channel = None + self._stdin_channel = None + self._control_channel = None + self._hb_channel = None + + self.log.info("Disconnected from kernel") + + async def reconnect(self) -> bool: + """Reconnect to the kernel. + + This is a convenience method that disconnects and then connects again. + Useful for recovering from stale connections or network issues. + + Returns: + bool: True if reconnection successful, False otherwise + """ + self.log.info("Reconnecting to kernel...") + await self.disconnect() + return await self.connect() + + +class JupyterServerKernelClient(JupyterServerKernelClientMixin, AsyncKernelClient): + """ + A kernel client with listener functionality and message queuing. 
class KernelClientWebsocketConnection(BaseKernelWebsocketConnection):
    """WebSocket connection that bridges frontend clients to the shared kernel client.

    This class implements the WebSocket side of the shared kernel client architecture.
    Instead of creating its own ZMQ connections to the kernel, it registers as a listener
    on the kernel manager's shared kernel client and routes messages bidirectionally.

    Key Responsibilities:

    1. **Listener Registration**:
       - Registers itself as a message listener on the shared kernel client
       - Receives all kernel messages through the listener callback
       - The listener is removed again in :meth:`disconnect`

    2. **Cell ID Encoding** (Outgoing to Kernel):
       - Extracts `cellId` from message metadata
       - Appends cell ID to message ID: `msg_id#{cell_id}`
       - Enables routing responses back to the originating cell
       - Works in conjunction with kernel client's channel encoding

    3. **Message ID Decoding** (Incoming from Kernel):
       - Strips both channel and cell ID encoding from message IDs
       - Returns messages to frontend with original msg_ids
       - Frontend receives messages in the format it expects

    4. **Message Filtering** (Configurable):
       - Supports filtering messages by type and channel
       - Can include specific message types (`msg_types` trait)
       - Can exclude specific message types (`exclude_msg_types` trait)
       - Reduces bandwidth for specialized clients
    """

    kernel_ws_protocol = "v1.kernel.websocket.jupyter.org"

    # Configurable message filtering traits
    msg_types = TraitletsList(
        trait=TraitletsTuple(),
        default_value=None,
        allow_none=True,
        config=True,
        help="""
        List of (msg_type, channel) tuples to include for this websocket connection.
        If None (default), all messages are sent. If specified, only messages matching
        these (msg_type, channel) pairs will be sent to the websocket.

        Example: [("status", "iopub"), ("execute_reply", "shell")]
        """,
    )

    exclude_msg_types = TraitletsList(
        trait=TraitletsTuple(),
        default_value=None,
        allow_none=True,
        config=True,
        help="""
        List of (msg_type, channel) tuples to exclude for this websocket connection.
        If None (default), no messages are excluded. If specified, messages matching
        these (msg_type, channel) pairs will NOT be sent to the websocket.

        Example: [("status", "iopub")]

        Note: Cannot be used together with msg_types. If both are specified,
        msg_types takes precedence.
        """,
    )

    def _get_kernel_client(self):
        """Return the pre-created kernel client stored on the kernel manager.

        NOTE(review): per the original author's note, ``self.kernel_manager``
        is the per-kernel KernelManager (not the MultiKernelManager) - confirm
        against the base class.

        Returns:
            The ``kernel_client`` attribute of ``self.kernel_manager``.

        Raises:
            RuntimeError: If there is no kernel manager, or it has no
                ``kernel_client``. The underlying error is attached as
                ``__cause__``.
        """
        try:
            km = self.kernel_manager
            if not km:
                raise RuntimeError(f"No kernel manager found for kernel {self.kernel_id}")

            # Get the pre-created kernel client from the kernel manager
            if not hasattr(km, "kernel_client") or km.kernel_client is None:
                raise RuntimeError(f"Kernel manager for {self.kernel_id} has no kernel_client")

            return km.kernel_client

        except Exception as e:
            # Chain the original error so its traceback is not lost.
            raise RuntimeError(
                f"Failed to get kernel client for kernel {self.kernel_id}: {e}"
            ) from e

    async def connect(self):
        """Register this websocket as a listener on the shared kernel client.

        No channels are opened here. The websocket only adds
        :meth:`handle_outgoing_message` as a listener, applying ``msg_types``
        (takes precedence) or ``exclude_msg_types`` when configured, and then
        asks the client to broadcast its current state.

        Raises:
            RuntimeError: If the kernel client cannot be obtained.
        """
        # Get the client from the kernel manager
        client = self._get_kernel_client()

        # Add websocket listener immediately, using configured filtering if specified.
        if self.msg_types is not None:
            # Normalize each entry to a tuple; an empty list is passed on as None.
            msg_types_list = [tuple(item) for item in self.msg_types] if self.msg_types else None
            client.add_listener(self.handle_outgoing_message, msg_types=msg_types_list)
        elif self.exclude_msg_types is not None:
            # Normalize each entry to a tuple; an empty list is passed on as None.
            exclude_msg_types_list = (
                [tuple(item) for item in self.exclude_msg_types] if self.exclude_msg_types else None
            )
            client.add_listener(
                self.handle_outgoing_message, exclude_msg_types=exclude_msg_types_list
            )
        else:
            # No filtering - listen to all messages (default)
            client.add_listener(self.handle_outgoing_message)

        # Broadcast current kernel state to this websocket immediately, so
        # websockets that connect during/after a restart get the current state.
        await client.broadcast_state()

        self.log.info(f"Kernel websocket connected and listening for kernel {self.kernel_id}")

    def disconnect(self):
        """Remove this websocket's listener from the kernel client.

        Failures are logged as warnings and never raised.
        """
        try:
            client = self._get_kernel_client()
            if client:
                client.remove_listener(self.handle_outgoing_message)
        except Exception as e:
            self.log.warning(f"Failed to disconnect websocket for kernel {self.kernel_id}: {e}")

    def handle_incoming_message(self, incoming_msg):
        """Forward a websocket message to the kernel client, encoding ``cellId`` into msg_id.

        Args:
            incoming_msg: Raw websocket frame in the v1 kernel websocket format.

        Deserialization errors propagate to the caller. Errors raised after
        deserialization are logged and swallowed.
        """
        channel_name, msg_list = deserialize_msg_from_ws_v1(incoming_msg)

        try:
            client = self._get_kernel_client()
            if not client:
                return

            # Extract cellId from metadata and encode into msg_id. This step is
            # best-effort: on failure the message is forwarded unmodified.
            try:
                if len(msg_list) >= 3:  # Need header, parent_header, metadata
                    session = Session()
                    metadata = session.unpack(msg_list[2])
                    cell_id = metadata.get("cellId")

                    if cell_id:
                        msg_list = encode_cell_id_in_message(msg_list, cell_id)
            except Exception as e:
                self.log.debug(f"Error encoding cellId in msg_id: {e}")

            # Send to kernel client (which will prepend channel)
            client.handle_incoming_message(channel_name, msg_list)
        except Exception as e:
            self.log.error(f"Failed to handle incoming message for kernel {self.kernel_id}: {e}")

    def handle_outgoing_message(self, channel_name, msg):
        """Send a kernel message to the websocket, stripping msg_id encoding first.

        Args:
            channel_name: Name of the channel the message arrived on.
            msg: Message parts ``[header, parent_header, metadata, content, ...]``;
                the first four parts must be ``bytes``.

        Messages that fail validation are logged and dropped. Send errors are
        logged and never raised.
        """
        try:
            # Validate message has minimum required parts
            if not msg or len(msg) < 4:
                self.log.warning(
                    f"Message on {channel_name} has insufficient parts: {len(msg) if msg else 0}"
                )
                return

            # Validate parts are bytes
            for i, part in enumerate(msg[:4]):
                if not isinstance(part, bytes):
                    self.log.error(f"Message part {i} on {channel_name} is not bytes: {type(part)}")
                    return

            # Strip channel and cellId from msg_ids before sending to frontend
            try:
                msg = strip_encoding_from_message(msg)
            except Exception as e:
                self.log.debug(f"Error stripping encoding from msg_ids: {e}")
                # Continue with original message if stripping fails

            # Serialize to websocket format and send
            bin_msg = serialize_msg_to_ws_v1(msg, channel_name)
            self.websocket_handler.write_message(bin_msg, binary=True)
        except WebSocketClosedError:
            self.log.warning("A Kernel Socket message arrived on a closed websocket channel.")
        except Exception as err:
            self.log.error(
                f"Error handling outgoing message on {channel_name}: {err}", exc_info=True
            )
def __init__(self, **kwargs):
    """Initialize the kernel manager and pre-create its kernel client.

    Args:
        **kwargs: Forwarded unchanged to the parent constructor.
    """
    super().__init__(**kwargs)

    # Create the kernel client immediately, sharing this manager's session.
    self.kernel_client = self.client(session=self.session)


@observe("client_class")
def _client_class_changed(self, change):
    """Keep ``client_factory`` in sync whenever ``client_class`` changes.

    Args:
        change: Traitlets change dict; ``change["new"]`` is the new class.
    """
    self.client_factory = change["new"]


async def _async_post_start_kernel(self, **kwargs):
    """Connect the pre-created kernel client once the kernel has started.

    After the parent hook runs, the client loads the manager's current
    connection info and connects.

    Note: overrides must call ``super()._async_post_start_kernel(**kwargs)``
    so the kernel client still connects.

    Args:
        **kwargs: Forwarded unchanged to the parent hook.

    Raises:
        RuntimeError: If the kernel client reports a failed connection.
        Exception: Any other error from loading connection info or
            connecting is logged and re-raised to fail the kernel start.
    """
    await super()._async_post_start_kernel(**kwargs)
    try:
        # Load the latest connection info from the kernel manager.
        self.kernel_client.load_connection_info(self.get_connection_info(session=True))

        # Connect the kernel client
        success = await self.kernel_client.connect()

        if not success:
            raise RuntimeError(f"Failed to connect kernel client for kernel {self.kernel_id}")

        self.log.info(f"Successfully connected kernel client for kernel {self.kernel_id}")

    except Exception as e:
        self.log.error(f"Failed to connect kernel client: {e}")
        # Re-raise to fail the kernel start
        raise


async def cleanup_resources(self, restart=False):
    """Stop the kernel client, then run the parent's resource cleanup.

    Parameters
    ----------
    restart : bool
        If True, the client's last status timestamps are cleared before its
        listening loop and channels are stopped. If False, the client is
        simply stopped.

    The parent cleanup runs even when stopping the client raises; the
    client error still propagates afterwards.
    """
    try:
        if self.kernel_client:
            if restart:
                self.log.debug(
                    f"Clearing kernel client state for restart of kernel {self.kernel_id}"
                )
                self.kernel_client.last_shell_status_time = None
                self.kernel_client.last_control_status_time = None
            else:
                self.log.debug(f"Disconnecting kernel client for kernel {self.kernel_id}")

            # Both paths tear the client down the same way.
            # NOTE(review): the client's ready/connecting flags are not reset
            # here - confirm stop_listening() resets them, otherwise a later
            # connect() may return early.
            await self.kernel_client.stop_listening()
            self.kernel_client.stop_channels()
    finally:
        # Always release the manager's own resources, even if the client
        # teardown above failed.
        await super().cleanup_resources(restart=restart)
class MsgIdError(Exception):
    """Base exception for message ID operations."""


class InvalidMsgIdFormatError(MsgIdError):
    """Raised when a message ID has an invalid format."""


class InvalidChannelError(MsgIdError):
    """Raised when a channel name contains reserved characters."""


class InvalidSrcIdError(MsgIdError):
    """Raised when a source ID contains reserved characters."""


def validate_channel(channel: Optional[str]) -> None:
    """Validate that a channel name doesn't contain reserved characters.

    Both ':' (channel separator) and '#' (source-ID separator) are rejected:
    a channel containing either would produce an ID that ``parse_msg_id``
    cannot split back into the same parts.

    Args:
        channel: Channel name to validate; ``None`` is always valid.

    Raises:
        InvalidChannelError: If channel contains ':' or '#'.
    """
    if channel is None:
        return
    for reserved in (":", "#"):
        if reserved in channel:
            raise InvalidChannelError(
                f"Channel name cannot contain '{reserved}' character: {channel}"
            )


def validate_src_id(src_id: Optional[str]) -> None:
    """Validate that a source ID doesn't contain reserved characters.

    Args:
        src_id: Source ID to validate; ``None`` is always valid.

    Raises:
        InvalidSrcIdError: If src_id contains ':' or '#'.
    """
    if src_id is not None:
        if ":" in src_id:
            raise InvalidSrcIdError(f"Source ID cannot contain ':' character: {src_id}")
        if "#" in src_id:
            raise InvalidSrcIdError(f"Source ID cannot contain '#' character: {src_id}")


def create_msg_id(
    base_msg_id: str, channel: Optional[str] = None, src_id: Optional[str] = None
) -> str:
    """Create a structured message ID with optional channel and source ID encoding.

    Format: ``{channel}:{base_msg_id}#{src_id}``; each optional part is
    omitted, together with its separator, when it is ``None``.

    Args:
        base_msg_id: Core message ID (e.g., "{session}_{pid}_{counter}")
        channel: Optional channel name (shell, control, etc.)
        src_id: Optional source identifier (e.g., cell ID)

    Returns:
        Formatted message ID string

    Raises:
        InvalidChannelError: If channel contains ':' or '#'
        InvalidSrcIdError: If src_id contains ':' or '#'

    Examples:
        >>> create_msg_id("abc123_456_0", "shell", "cell-xyz")
        'shell:abc123_456_0#cell-xyz'
        >>> create_msg_id("abc123_456_1", "control")
        'control:abc123_456_1'
        >>> create_msg_id("abc123_456_2")
        'abc123_456_2'
    """
    validate_channel(channel)
    validate_src_id(src_id)

    # Without a channel the ID stays in the legacy (unprefixed) format.
    result = base_msg_id if channel is None else f"{channel}:{base_msg_id}"

    if src_id is not None:
        result = f"{result}#{src_id}"

    return result


def parse_msg_id(msg_id: str) -> Tuple[Optional[str], Optional[str], str]:
    """Parse a message ID into its components.

    Args:
        msg_id: Message ID string to parse

    Returns:
        Tuple of (channel, src_id, base_msg_id) where:
        - channel: Channel name ('shell', 'control', etc.) or None for legacy format
        - src_id: Source identifier (e.g., cell ID) or None
        - base_msg_id: Core message ID

    Raises:
        InvalidMsgIdFormatError: If ``msg_id`` is empty.

    Examples:
        >>> parse_msg_id("shell:abc123_456_0#cell-xyz")
        ('shell', 'cell-xyz', 'abc123_456_0')
        >>> parse_msg_id("control:abc123_456_1")
        ('control', None, 'abc123_456_1')
        >>> parse_msg_id("abc123_456_2")
        (None, None, 'abc123_456_2')
    """
    if not msg_id:
        raise InvalidMsgIdFormatError("Message ID cannot be empty")

    # Split off src_id if present (everything after the first '#')
    if "#" in msg_id:
        msg_id_part, src_id = msg_id.split("#", 1)
    else:
        msg_id_part = msg_id
        src_id = None

    # Split channel and base msg_id (at the first ':')
    if ":" in msg_id_part:
        channel, base_msg_id = msg_id_part.split(":", 1)
    else:
        # Legacy format - no channel specified
        channel = None
        base_msg_id = msg_id_part

    return channel, src_id, base_msg_id


def extract_channel(msg_id: str) -> Optional[str]:
    """Extract just the channel from a message ID.

    Args:
        msg_id: Message ID string

    Returns:
        Channel name or None if not present

    Raises:
        InvalidMsgIdFormatError: If ``msg_id`` is empty.

    Examples:
        >>> extract_channel("shell:abc123_456_0#cell-xyz")
        'shell'
        >>> extract_channel("abc123_456_2") is None
        True
    """
    channel, _, _ = parse_msg_id(msg_id)
    return channel


def extract_src_id(msg_id: str) -> Optional[str]:
    """Extract just the source ID from a message ID.

    Args:
        msg_id: Message ID string

    Returns:
        Source ID or None if not present

    Raises:
        InvalidMsgIdFormatError: If ``msg_id`` is empty.

    Examples:
        >>> extract_src_id("shell:abc123_456_0#cell-xyz")
        'cell-xyz'
        >>> extract_src_id("shell:abc123_456_1") is None
        True
    """
    _, src_id, _ = parse_msg_id(msg_id)
    return src_id


def extract_base_msg_id(msg_id: str) -> str:
    """Extract just the base message ID from a message ID.

    Args:
        msg_id: Message ID string

    Returns:
        Base message ID (without channel or src_id encoding)

    Raises:
        InvalidMsgIdFormatError: If ``msg_id`` is empty.

    Examples:
        >>> extract_base_msg_id("shell:abc123_456_0#cell-xyz")
        'abc123_456_0'
        >>> extract_base_msg_id("abc123_456_2")
        'abc123_456_2'
    """
    _, _, base_msg_id = parse_msg_id(msg_id)
    return base_msg_id


# ============================================================================
# Message-level utilities for websocket connection
# ============================================================================


def encode_channel_in_message_dict(msg_dict: dict, channel: str) -> dict:
    """Encode channel into a message dict's header msg_id.

    The dict is modified in place and also returned. A msg_id that already
    starts with ``"{channel}:"`` is left untouched, so repeated calls with
    the same channel are idempotent.

    Args:
        msg_dict: Message dictionary with header, parent_header, metadata, content
        channel: Channel name to encode ('shell', 'control', etc.)

    Returns:
        The same message dict, with channel encoded in header msg_id

    Examples:
        >>> msg = {"header": {"msg_id": "abc123"}}
        >>> encode_channel_in_message_dict(msg, "shell")["header"]["msg_id"]
        'shell:abc123'
    """
    if "header" in msg_dict and "msg_id" in msg_dict["header"]:
        msg_id = msg_dict["header"]["msg_id"]
        # Only encode if not already encoded
        if not msg_id.startswith(f"{channel}:"):
            msg_dict["header"]["msg_id"] = f"{channel}:{msg_id}"
    return msg_dict


def encode_cell_id_in_message(msg_list: List[bytes], cell_id: str) -> List[bytes]:
    """Encode a cell ID into the header msg_id of a message.

    Best-effort: if the header cannot be unpacked or repacked, the original
    list is returned unchanged. The input list itself is never mutated.

    Args:
        msg_list: Message parts as list of bytes [header, parent_header, metadata, content, ...]
        cell_id: Cell ID to encode into the message

    Returns:
        A copy of the message list with ``#{cell_id}`` appended to the header
        msg_id, unless the msg_id already contains '#'.

    Examples:
        >>> # msg_list with msg_id "abc123" becomes msg_id "abc123#cell-xyz"
        >>> modified_msg = encode_cell_id_in_message(msg_list, "cell-xyz")
    """
    # Need at least the header part
    if not msg_list:
        return msg_list

    try:
        session = Session()
        msg_copy = list(msg_list)  # Copy to avoid modifying the caller's list

        header = session.unpack(msg_copy[0])

        if "msg_id" in header:
            msg_id = header["msg_id"]
            if "#" not in msg_id:  # Only add if not already encoded
                header["msg_id"] = f"{msg_id}#{cell_id}"
                msg_copy[0] = session.pack(header)

        return msg_copy
    except Exception:
        # Deliberate best-effort: if encoding fails, forward the original message.
        return msg_list
class StrContainerEnumMeta(EnumMeta):
    """Enum metaclass whose ``in`` operator accepts member names and member values."""

    def __contains__(cls, item):
        # True as soon as ``item`` equals any member's name or its value.
        return any(
            item == label or item == entry.value for label, entry in cls.__members__.items()
        )


class StrContainerEnum(str, Enum, metaclass=StrContainerEnumMeta):
    """A string Enum whose members can be looked up with ``in``
    by either their name or their value.
    """


class ExecutionStates(StrContainerEnum):
    """Execution state values defined for a kernel."""

    BUSY = "busy"
    IDLE = "idle"
    STARTING = "starting"
    UNKNOWN = "unknown"
    DEAD = "dead"