Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 17 additions & 20 deletions src/async_kernel/interface/zmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@
import threading
import time
from contextlib import asynccontextmanager
from threading import Event
from typing import TYPE_CHECKING, Any, Generic, Literal, Never, Self

import zmq
from aiologic import BinarySemaphore
from aiologic import BinarySemaphore, CountdownEvent
from aiologic.lowlevel import enable_signal_safety
from jupyter_client.connect import ConnectionFileMixin
from jupyter_client.session import Session
Expand All @@ -27,7 +26,7 @@
from async_kernel.typing import Channel, Content, Job, Message, MsgHeader, NoValue, T_shell_co

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Generator
from collections.abc import AsyncGenerator, Callable, Generator

from jupyter_client import KernelConnectionInfo

Expand Down Expand Up @@ -134,18 +133,18 @@ async def __asynccontextmanager__(self, *, set_started=True) -> AsyncGenerator[S
self._zmq_context.term()

def _start_hb_iopub_shell_control_threads(self) -> None:
def heartbeat(ready: Event) -> None:
def heartbeat(ready: Callable[[], None]) -> None:
# ref: https://jupyter-client.readthedocs.io/en/stable/messaging.html#heartbeat-for-kernels
utils.mark_thread_pydev_do_not_trace()
with self._bind_socket(Channel.heartbeat) as socket:
ready.set()
ready()
self.started.wait_sync()
try:
zmq.proxy(socket, socket)
except zmq.ContextTerminated:
return

def pub_proxy(ready: Event) -> None:
def pub_proxy(ready: Callable[[], None]) -> None:
utils.mark_thread_pydev_do_not_trace()

# We use an internal proxy to collect pub messages for distribution.
Expand All @@ -158,30 +157,28 @@ def pub_proxy(ready: Event) -> None:
# Capture broadcast messages received on both frontend and backend
capture = self._zmq_context.socket(zmq.PUB)
capture.bind(self._iopub_url)
threading.Thread(target=self._pub_capture, args=[ready]).start()

with self._bind_socket(Channel.iopub) as iopub_socket:
ready.set()
ready()
try:
zmq.proxy(frontend, iopub_socket, capture)
except (zmq.ContextTerminated, Exception):
pass
frontend.close(linger=50)
capture.close(linger=50)

hb_ready, iopub_ready = (Event(), Event())
threading.Thread(target=heartbeat, name="heartbeat", args=[hb_ready]).start()
hb_ready.wait()
threading.Thread(target=pub_proxy, name="iopub proxy", args=[iopub_ready]).start()
iopub_ready.wait()
ready = CountdownEvent(5)

threading.Thread(target=heartbeat, name="heartbeat", args=[ready.down]).start()
threading.Thread(target=pub_proxy, name="iopub proxy", args=[ready.down]).start()
threading.Thread(target=self._pub_capture, args=[ready.down]).start()
# message loops
for channel in [Channel.shell, Channel.control]:
ready = Event()
name = f"{channel}-receive_msg_loop"
threading.Thread(target=self.receive_msg_loop, name=name, args=(channel, ready)).start()
ready.wait()
threading.Thread(target=self.receive_msg_loop, name=name, args=(channel, ready.down)).start()
ready.wait()

def _pub_capture(self, ready: Event) -> None:
def _pub_capture(self, ready: Callable[[], None]) -> None:
"""
Capture connection messages on iopub.

Expand All @@ -197,7 +194,7 @@ def _pub_capture(self, ready: Event) -> None:
# Only subscribe to the 'pub subscribe' topic byte `1` (byte `0` is 'pub unsubscribe').
socket.subscribe(b"\x01")
with socket:
ready.wait()
ready()
self.started.wait_sync()
while True:
try:
Expand Down Expand Up @@ -307,7 +304,7 @@ def iopub_send(
buffers=buffers,
)

def receive_msg_loop(self, channel: Literal[Channel.control, Channel.shell], ready: Event) -> None:
def receive_msg_loop(self, channel: Literal[Channel.control, Channel.shell], ready: Callable[[], None]) -> None:
"""
Opens a zmq socket for the channel, receives messages and calls the message handler.
"""
Expand All @@ -333,7 +330,7 @@ async def send_reply(job: Job, content: dict, /) -> None:
log.debug("***send_reply %s*** %s", channel, msg)

with self._bind_socket(channel) as socket:
ready.set()
ready()
self.started.wait_sync()

while True:
Expand Down
Loading