Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 33 additions & 7 deletions src/async_kernel/debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import anyio.abc
from aiologic import Event, Lock
from aiologic.lowlevel import create_async_waiter
from traitlets import traitlets
from traitlets.config import LoggingConfigurable

Expand All @@ -21,6 +22,8 @@
from async_kernel.pending import Pending

if TYPE_CHECKING:
from collections.abc import Callable

from async_kernel.kernel import Kernel
from async_kernel.typing import DebugMessage

Expand Down Expand Up @@ -151,25 +154,34 @@ def put_tcp_frame(self, frame: bytes) -> None:
result.set_result(msg)
self.tcp_buffer = b""

async def connect_tcp_socket(self, ready: Event) -> None:
async def _connect_tcp_socket(self, ready: Callable[[], Any]) -> None:
"""Connect to the tcp socket."""
global _host_port # noqa: PLW0603
if not _host_port:
import debugpy # noqa: PLC0415

_host_port = debugpy.listen(0)
try:
self.log.debug("++ debugpy socketstream connecting ++")
self.log.debug("debugpy socketstream connecting")
async with await anyio.connect_tcp(*_host_port) as socketstream:
self._socketstream = socketstream
self.log.debug("++ debugpy socketstream connected ++")
ready.set()
self.log.debug("debugpy socketstream connected")
ready()
while True:
data = await socketstream.receive()
self.put_tcp_frame(data)
except anyio.EndOfStream:
self.log.debug("++ debugpy socketstream disconnected ++")
self.log.debug("debugpy socketstream disconnected")
return
except anyio.get_cancelled_exc_class():
msg = {
"type": "request",
"seq": self.kernel.debugger.next_seq(),
"command": "configurationDone",
}
with anyio.CancelScope(shield=True):
await self.kernel.debugger.do_disconnect(msg)
raise
finally:
self._socketstream = None

Expand Down Expand Up @@ -301,14 +313,28 @@ async def do_initialize(self, msg: DebugMessage, /) -> dict[str, Any]:
if thread.name in self.no_debug:
utils.mark_thread_pydev_do_not_trace(thread)
if not self.debugpy_client.connected:
ready = Event()
Caller().call_soon(self.debugpy_client.connect_tcp_socket, ready)
ready = create_async_waiter()
Caller().call_soon(self._debupy_socket_connection, ready.wake)
await ready

reply = await self.send_dap_request(msg)
if capabilities := reply.get("body"):
self.capabilities = capabilities
return reply

async def _debupy_socket_connection(self, ready: Callable[[], Any]) -> None:
"Maintain a connection to the debugger"
msg = {
"type": "request",
"seq": self.next_seq(),
"command": "configurationDone",
}
async with anyio.create_task_group() as tg:
tg.start_soon(self.debugpy_client._connect_tcp_socket, ready) # pyright: ignore[reportPrivateUsage]
await self.parent.stopping
await self.do_disconnect(msg)
tg.cancel_scope.cancel()

async def do_debug_info(self, msg: DebugMessage, /) -> dict[str, Any]:
"""Handle an debug info message."""
breakpoint_list = []
Expand Down
12 changes: 10 additions & 2 deletions src/async_kernel/interface/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from async_kernel import utils
from async_kernel.caller import Caller
from async_kernel.common import Fixed
from async_kernel.pending import Pending
from async_kernel.typing import (
Backend,
Channel,
Expand Down Expand Up @@ -179,6 +180,11 @@ class BaseInterface(Application, anyio.AsyncContextManagerMixin, Generic[T_shell
event_started = Fixed(Event)
"An event that occurs when the interface is started."

stopping = Fixed(Pending)
"""
A Pending that is set when stop is called.
"""

_instance: Self | None = None
_zmq_context = None
last_interrupt_frame = None
Expand Down Expand Up @@ -363,7 +369,8 @@ async def run(self, *, stopped: Callable[[], Any] | None = None) -> None:
"""
try:
async with self:
await anyio.sleep_forever()
await self.stopping
await anyio.sleep(0.1)
finally:
if stopped:
stopped()
Expand All @@ -372,10 +379,11 @@ def stop(self) -> None:
"""
Stop the kernel.
"""
self.stopping.set_result(None)
if scope := getattr(self, "_scope", None):
del self._scope
self.log.info("Stopping kernel")
self.callers[Channel.shell].call_direct(scope.cancel, "Stopping kernel")
self.callers[Channel.shell].call_later(0.5, scope.cancel, "Stopping kernel")
if not self.event_started:
self.event_started.set()
if BaseInterface._instance is self:
Expand Down
1 change: 0 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ async def subprocess_kernels_client(anyio_backend, tmp_path_factory, name: str,
await utils.execute(client, "kernel_info")
yield client
await utils.get_reply(client, client.shutdown(), channel=Channel.control)
# Warning: Inserting Debug breakpoints below here won't work, but don't worry about it.
assert process.wait() == 0
finally:
process.terminate()
Expand Down
3 changes: 3 additions & 0 deletions tests/test_debugger.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import sys
from typing import TYPE_CHECKING

import anyio
import pytest

from async_kernel.typing import MsgType
from tests import utils
Expand Down Expand Up @@ -57,6 +59,7 @@ async def send_debug_request(client: AsyncKernelClient, command: str, arguments:
return reply["content"]


@pytest.mark.skipif(sys.platform == "darwin", reason="Test is flaky on CI")
async def test_debugger(subprocess_kernels_client: AsyncKernelClient):
client = subprocess_kernels_client
reply = await send_debug_request(client=client, command="initialize", arguments=initialize_args)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_zmq_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ async def test_launch_too_late(kernel: Kernel):
ZMQInterface.launch_instance()


async def test_already_entered(kernel: Kernel):
with pytest.raises(RuntimeError, match="has already been entered"):
async with kernel.parent:
pass
# async def test_already_entered(kernel: Kernel):
# with pytest.raises(RuntimeError, match="has already been entered"):
# async with kernel.parent:
# pass
Loading