Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
"--notebook-dir=docs/notebooks",
"--no-browser",
"--IdentityProvider.token=''",
"--port=9991"
"--port=9991",
"--KernelManager.transport_encryption=auto"
],
"cwd": "${workspaceFolder}",
"justMyCode": false,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ dependencies = [
"sniffio>=1.3.0; sys_platform != 'emscripten'",
"outcome; sys_platform != 'emscripten'",
"pyzmq>=27.0; sys_platform != 'emscripten'", # pyzmq sockets (and threading) are not supported on pyodide (emscripten), `CallableInterface` is used for kernel messaging.
"jupyter_client>=8.8; sys_platform != 'emscripten'",
"jupyter_client>=8.9.0; sys_platform != 'emscripten'",
]

[project.urls]
Expand Down
10 changes: 9 additions & 1 deletion src/async_kernel/interface/zmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Generator

from jupyter_client import KernelConnectionInfo


__all__ = ["ZMQInterface"]

Expand Down Expand Up @@ -162,7 +164,7 @@ def _observe_connection_file(self, change) -> None:
self.load_connection_info(json.loads(path.read_bytes()))

@override
def load_connection_info(self, info: dict[str, Any]) -> None:
def load_connection_info(self, info: KernelConnectionInfo) -> None:
if self._sockets:
msg = "It is too late to configure!"
raise RuntimeError(msg)
Expand Down Expand Up @@ -285,6 +287,10 @@ def _bind_socket(self, channel: Channel) -> Generator[Any | Socket[Any], Any, No
socket_type = zmq.XPUB
socket: zmq.Socket = self._zmq_context.socket(socket_type)
socket.linger = 50
if self.curve_secretkey is not None and self.curve_publickey is not None:
socket.curve_secretkey = self.curve_secretkey
socket.curve_publickey = self.curve_publickey
socket.curve_server = True
name = f"{channel}_port"
port = bind_socket(socket=socket, transport=self.transport, ip=self.ip, port=getattr(self, name)) # pyright: ignore[reportArgumentType]
self.set_trait(name, port)
Expand Down Expand Up @@ -315,6 +321,8 @@ def _write_connection_file(
key=self.session.key,
signature_scheme=self.session.signature_scheme,
kernel_name=self.name,
curve_publickey=self.curve_publickey,
curve_secretkey=self.curve_secretkey,
**{f"{channel}_port": getattr(self, f"{channel}_port") for channel in Channel},
)
ip_files: list[pathlib.Path] = []
Expand Down
4 changes: 3 additions & 1 deletion src/async_kernel/kernelspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,9 @@ def write_kernel_spec(
"display_name": display_name or f"Python {sys.version.split()[0]} ({name})",
"language": language,
"interrupt_mode": "message",
"metadata": metadata if metadata is not None else {"debugger": True, "concurrent": True},
"metadata": metadata
if metadata is not None
else {"debugger": True, "concurrent": True, "supported_encryption": "curve"},
"kernel_protocol_version": PROTOCOL_VERSION,
}
# write kernel.json
Expand Down
2 changes: 1 addition & 1 deletion tests/test_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_install_kernel_start_zmq_interface(monkeypatch, fake_kernel_dir: pathli
"display_name": "my kernel",
"language": "python",
"interrupt_mode": "message",
"metadata": {"debugger": True, "concurrent": True},
"metadata": {"debugger": True, "concurrent": True, "supported_encryption": "curve"},
"kernel_protocol_version": "5.5",
}

Expand Down
2 changes: 1 addition & 1 deletion tests/test_kernelspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_install_kernel_spec(tmp_path: Path, monkeypatch):
"display_name": f"Python {sys.version.split()[0]} ({name})",
"language": "python",
"interrupt_mode": "message",
"metadata": {"debugger": True, "concurrent": True},
"metadata": {"debugger": True, "concurrent": True, "supported_encryption": "curve"},
"kernel_protocol_version": "5.5",
}
}
Expand Down
13 changes: 5 additions & 8 deletions tests/test_pending.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,14 +255,11 @@ async def test_wait_sync(self, caller: Caller, result: bool, anyio_backend: Back
assert pen.wait_sync(result=result) == (2 if result else None)
assert pen.wait_sync(result=result) == (2 if result else None)

async def test_wait_sync_timeout(self, anyio_backend: Backend):
async with Caller("manual") as caller:
pen = caller.call_soon(lambda: 0 / 1) # should never get called
with pytest.raises(TimeoutError):
pen.wait_sync(timeout=0.001)
assert pen.cancelled()
with pytest.raises(PendingCancelled):
await pen
def test_wait_sync_timeout(self):
pen = Pending()
with pytest.raises(TimeoutError):
pen.wait_sync(timeout=0.001)
assert pen.cancelled()

async def test_many_waiters(self, caller: Caller):
N = 100
Expand Down
63 changes: 63 additions & 0 deletions tests/test_zmq_curve_encrypted.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import anyio
import pytest
import zmq
from jupyter_client import connect
from jupyter_client.asynchronous.client import AsyncKernelClient

from async_kernel.interface.zmq import ZMQInterface
from async_kernel.typing import Channel
from tests import utils

if TYPE_CHECKING:
import pathlib

from async_kernel import Kernel


# pyright: reportPrivateUsage=false


@pytest.fixture(scope="module")
async def curve_encrypted_kernel(anyio_backend, tmp_path_factory):
connection_file: pathlib.Path = tmp_path_factory.mktemp("async_kernel") / "temp_connection.json"
curve_publickey, curve_secretkey = zmq.curve_keypair()
connect.write_connection_file(
str(connection_file),
curve_publickey=curve_publickey,
curve_secretkey=curve_secretkey,
)
interface = ZMQInterface(connection_file=connection_file)
async with interface:
yield interface.kernel


@pytest.fixture(scope="module")
async def curve_encrypted_client(curve_encrypted_kernel: Kernel):

assert isinstance(curve_encrypted_kernel.parent, ZMQInterface)
client = AsyncKernelClient()
client.load_connection_info(curve_encrypted_kernel.parent.get_connection_info())
client.start_channels()
try:
yield client
finally:
await utils.clear_iopub(client, timeout=0.1)
client.stop_channels()
await anyio.sleep(0)


async def test_curve_encryption(
curve_encrypted_kernel: Kernel[ZMQInterface], curve_encrypted_client: AsyncKernelClient
):
assert curve_encrypted_kernel.parent._sockets[Channel.shell].curve_server == 1

msg_id, reply = await utils.execute(curve_encrypted_client, "1+1", clear_pub=False)
assert reply["status"] == "ok"
await utils.check_pub_message(curve_encrypted_client, msg_id, execution_state="busy")
await utils.check_pub_message(curve_encrypted_client, msg_id, msg_type="execute_input")
await utils.check_pub_message(curve_encrypted_client, msg_id, msg_type="execute_result")
await utils.check_pub_message(curve_encrypted_client, msg_id, execution_state="idle")
Loading
Loading