Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ dev = [
"basedpyright",
"ruff",
"debugpy",
"ipykernel",
{include-group = "uvloop"},
{include-group = "test"}
]
Expand Down
38 changes: 31 additions & 7 deletions src/async_kernel/asyncshell.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
from async_kernel import utils
from async_kernel.caller import Caller
from async_kernel.compiler import XCachingCompiler
from async_kernel.typing import Content, MetadataKeys, Tags
from async_kernel.typing import MetadataKeys, Tags

if TYPE_CHECKING:
from collections.abc import Callable

from async_kernel.kernel import Kernel


Expand Down Expand Up @@ -73,10 +75,14 @@ class AsyncDisplayPublisher(DisplayPublisher):

topic: ClassVar = b"display_data"

def __init__(self, shell=None, *args, **kwargs) -> None:
super().__init__(shell, *args, **kwargs)
self._hooks = []

@override
def publish( # pyright: ignore[reportIncompatibleMethodOverride]
self,
data: Content,
data: dict[str, Any],
metadata: dict | None = None,
*,
transient: dict | None = None,
Expand All @@ -95,11 +101,17 @@ def publish( # pyright: ignore[reportIncompatibleMethodOverride]

[Reference](https://jupyter-client.readthedocs.io/en/stable/messaging.html#update-display-data)
"""
utils.get_kernel().iopub_send(
msg_or_type="update_display_data" if update else "display_data",
content={"data": data, "metadata": metadata or {}, "transient": transient or {}} | kwargs,
ident=self.topic,
)
content = {"data": data, "metadata": metadata or {}, "transient": transient or {}} | kwargs
msg_type = "update_display_data" if update else "display_data"
msg = utils.get_kernel().session.msg(msg_type, content, parent=utils.get_parent()) # pyright: ignore[reportArgumentType]
for hook in self._hooks:
try:
msg = hook(msg)
except Exception:
pass
if msg is None:
return
utils.get_kernel().iopub_send(msg)

@override
def clear_output(self, wait: bool = False) -> None:
Expand All @@ -113,6 +125,18 @@ def clear_output(self, wait: bool = False) -> None:
"""
utils.get_kernel().iopub_send(msg_or_type="clear_output", content={"wait": wait}, ident=self.topic)

def register_hook(self, hook: Callable[[dict], dict | None]) -> None:
"""Register a hook for when publish is called.

The hook should return the message or None.
Only return `None` when the message should *not* be sent.
"""
self._hooks.append(hook)

def unregister_hook(self, hook: Callable[[dict], dict | None]) -> None:
while hook in self._hooks:
self._hooks.remove(hook)


class AsyncInteractiveShell(InteractiveShell):
"""
Expand Down
8 changes: 8 additions & 0 deletions src/async_kernel/comm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import contextlib
from typing import Self

import comm
Expand Down Expand Up @@ -123,3 +124,10 @@ def set_comm():
"""
comm.create_comm = Comm
comm.get_comm_manager = get_comm_manager

# # Monkey patch ipykernel for modules that use it such as pyviz_comms:https://github.com/holoviz/pyviz_comms/blob/4cd44d902364590ba8892c8e7f48d7888d0a1c0c/pyviz_comms/__init__.py#L403C14-L403C28
with contextlib.suppress(ImportError):
import ipykernel.comm # noqa: PLC0415

ipykernel.comm.Comm = Comm
ipykernel.comm.CommManager = CommManager
2 changes: 1 addition & 1 deletion src/async_kernel/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,6 @@ def __init__(self, settings: dict | None = None, /) -> None:
if not os.environ.get("MPLBACKEND"):
os.environ["MPLBACKEND"] = "module://matplotlib_inline.backend_inline"
# setting get loaded in `_validate_settings`
assert self.shell, "The shell should be loaded here."
self._settings = settings or {}

@override
Expand Down Expand Up @@ -474,6 +473,7 @@ async def _start_in_context(self) -> AsyncGenerator[Self, Any]:
if self._sockets:
msg = "Already started"
raise RuntimeError(msg)
assert self.shell
self.anyio_backend = sniffio.current_async_library()
try:
async with Caller(log=self.log, create=True, protected=True) as caller:
Expand Down
15 changes: 15 additions & 0 deletions tests/test_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,21 @@ async def test_shell_can_set_namespace(kernel):
assert set(kernel.shell.user_ns) == {"Out", "_oh", "In", "exit", "_dh", "open", "get_ipython", "_ih", "quit"}


async def test_shell_display_hook_reg(kernel: Kernel):
val: None | dict = None

def my_hook(msg):
nonlocal val
val = msg

kernel.shell.display_pub.register_hook(my_hook)
assert my_hook in kernel.shell.display_pub._hooks # pyright: ignore[reportPrivateUsage]
kernel.shell.display_pub.publish({"test": True})
kernel.shell.display_pub.unregister_hook(my_hook)
assert my_hook not in kernel.shell.display_pub._hooks # pyright: ignore[reportPrivateUsage]
assert val


@pytest.mark.parametrize("mode", RunMode)
async def test_header_mode(client, mode: RunMode):
code = f"""
Expand Down
2 changes: 2 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.