4 changes: 2 additions & 2 deletions pyproject.toml
@@ -39,7 +39,7 @@ dependencies = [
"traitlets~=5.14",
"jupyter_client~=8.6",
"pyzmq~=27.0",
"anyio~=4.8",
"anyio~=4.11",
"typing_extensions~=4.14",
"sniffio~=1.3",
"matplotlib-inline~=0.1",
@@ -75,7 +75,7 @@ test = [
"pytest-mock",
"pytest-rerunfailures",
"pytest>=8.4,<9",
"trio",
"trio>=0.31.0",
"hatch"
]
dev = [
184 changes: 87 additions & 97 deletions src/async_kernel/caller.py
@@ -6,19 +6,17 @@
import functools
import inspect
import logging
import math
import reprlib
import threading
import time
import weakref
from collections import deque
from collections.abc import AsyncGenerator, Awaitable, Callable, Generator
from types import CoroutineType
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self, cast, overload
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Never, Self, cast, overload

import anyio
import sniffio
from anyio.streams.memory import MemoryObjectSendStream
from typing_extensions import override
from zmq import Context, Socket, SocketType

@@ -32,7 +30,6 @@
from types import CoroutineType

from anyio.abc import TaskGroup, TaskStatus
from anyio.streams.memory import MemoryObjectSendStream

from async_kernel.typing import P

@@ -63,7 +60,7 @@ class InvalidStateError(RuntimeError):
class AsyncEvent:
"""An asynchronous thread-safe event compatible with [async_kernel.caller.Caller][]."""

__slots__ = ["_events", "_flag", "_thread"]
__slots__ = ["__weakref__", "_events", "_flag", "_thread"]

def __init__(self, thread: threading.Thread | None = None) -> None:
self._thread = thread or threading.current_thread()
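Adding `__weakref__` to `__slots__` is what allows instances of the slotted `AsyncEvent` class to be the target of weak references. A minimal sketch of the difference (class names here are illustrative, not from the PR):

```python
import weakref

class WithoutWeakrefSlot:
    __slots__ = ["_flag"]

class WithWeakrefSlot:
    __slots__ = ["__weakref__", "_flag"]

try:
    weakref.ref(WithoutWeakrefSlot())   # slotted class without __weakref__
except TypeError as e:
    print(e)                            # cannot create weak reference to 'WithoutWeakrefSlot' object

obj = WithWeakrefSlot()
ref = weakref.ref(obj)                  # works once __weakref__ is declared in __slots__
print(ref() is obj)                     # True while obj is alive
```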
@@ -128,7 +125,7 @@ class Future(Awaitable[T]):
"""

_cancelled = False
_cancel_scope: anyio.CancelScope | None = None
_canceller: Callable[[str | None], Any] | None = None
_exception = None
_setting_value = False
_result: T
@@ -277,20 +274,27 @@ def add_done_callback(self, fn: Callable[[Self], Any]) -> None:

def cancel(self, msg: str | None = None) -> bool:
"""
Cancel the Future and schedule callbacks (thread-safe using Caller).
Cancel the Future (thread-safe using Caller).

!!! note

- Cancellation cannot be undone.
- The future will not be done until set_result or set_exception is called;
in both cases the value is ignored and replaced with a [FutureCancelledError][async_kernel.caller.FutureCancelledError]
and the result is inaccessible.

Args:
msg: The message to use when raising a FutureCancelledError.
msg: The message to use when cancelling.

Returns whether the future has been cancelled.
"""
if not self.done():
if msg and isinstance(self._cancelled, str):
msg = f"{self._cancelled}\n{msg}"
self._cancelled = msg or self._cancelled or True
if scope := self._cancel_scope:
if canceller := self._canceller:
if threading.current_thread() is self._thread:
scope.cancel()
canceller(msg)
else:
Caller(thread=self._thread).call_direct(self.cancel)
return self.cancelled()
@@ -339,11 +343,20 @@ def remove_done_callback(self, fn: Callable[[Self], object], /) -> int:
self._done_callbacks.remove(fn)
return n

def set_cancel_scope(self, scope: anyio.CancelScope) -> None:
"Provide a cancel scope for cancellation."
if self._cancelled or self._cancel_scope:
def set_canceller(self, canceller: Callable[[str | None], Any]) -> None:
"""
Set a callback to handle cancellation.

!!! note

`set_result` must still be called to mark the future as completed. You can pass any
value as it will be replaced with a [async_kernel.caller.FutureCancelledError][].
"""
if self.done() or self._canceller:
raise InvalidStateError
self._cancel_scope = scope
self._canceller = canceller
if self.cancelled():
self.cancel()

def get_caller(self) -> Caller:
"The [async_kernel.caller.Caller][] that is running for this *futures* thread."
@@ -374,7 +387,7 @@ class Caller:
_to_thread_pool: ClassVar[deque[Self]] = deque()
_pool_instances: ClassVar[weakref.WeakSet[Self]] = weakref.WeakSet()
_backend: Backend
_queue_map: weakref.WeakKeyDictionary[Callable[..., Awaitable[Any]], MemoryObjectSendStream[tuple]]
_queue_map: dict[int, Future]
_taskgroup: TaskGroup | None = None
_jobs: deque[tuple[contextvars.Context, Future] | Callable[[], Any]]
_thread: threading.Thread
@@ -429,7 +442,7 @@ def __new__(
inst._jobs = deque()
inst._job_added = threading.Event()
inst._protected = protected
inst._queue_map = weakref.WeakKeyDictionary()
inst._queue_map = {}
cls._instances[thread] = inst
return inst

@@ -489,29 +502,29 @@ async def _server_loop(self, tg: TaskGroup, task_status: TaskStatus[None]) -> None:

async def _wrap_call(self, fut: Future) -> None:
if fut.cancelled():
fut.set_result(None) # This will cancel
if not fut.done():
fut.set_result(None) # This will cancel
return
md = fut.metadata
func = md["func"]
token = self._future_var.set(fut)
try:
with anyio.CancelScope() as scope:
fut.set_cancel_scope(scope)
fut.set_canceller(scope.cancel)
try:
if (delay := md.get("delay")) and ((delay := delay - time.monotonic() + md["start_time"]) > 0):
await anyio.sleep(delay)
# Evaluate
result = func(*md["args"], **md["kwargs"])
if inspect.iscoroutine(result):
result = await result
# Cancellation
if fut.cancelled() and not scope.cancel_called:
scope.cancel()
fut.set_result(result)
except anyio.get_cancelled_exc_class():
with contextlib.suppress(anyio.get_cancelled_exc_class()):
fut.cancel()
if not fut.cancelled():
with contextlib.suppress(anyio.get_cancelled_exc_class()):
fut.cancel()
fut.set_result(None) # This will cancel
raise
except Exception as e:
fut.set_exception(e)
except Exception as e:
@@ -582,9 +595,8 @@ def stop(self, *, force=False) -> None:
if self._protected and not force:
return
self._stopped = True
for sender in self._queue_map.values():
sender.close()
self._queue_map.clear()
for func in tuple(self._queue_map):
self.queue_close(func)
self._job_added.set()
self._instances.pop(self.thread, None)
if self in self._to_thread_pool:
@@ -699,54 +711,17 @@ def call_direct(
self._jobs.append(functools.partial(func, *args, **kwargs))
self._job_added.set()

def queue_exists(self, func: Callable) -> bool:
"Returns True if an execution queue exists for `func`."
return func in self._queue_map

def queue_get_sender(
self, func: Callable, max_buffer_size: float = math.inf
) -> MemoryObjectSendStream[tuple[contextvars.Context, tuple, dict]]:
"""
Get or create a new queue unique to func in this caller.

This method can be used to configure the buffer size of the queue for the methods
- `queue_call`
- `queue_call_no_wait`
def queue_get(self, func: Callable) -> Future[Never] | None:
"""Returns Future for `func` where the queue is running.

Args:
func: The function whose queue is associated with this caller.
max_buffer_size: The maximum allowed queued messages, see [anyio.create_memory_object_stream][] for further details.

!!! info

The queue will stay open until one of the following occurs.
!!! warning

1. It is explicitly closed with the method `queue_close`.
1. All strong references to the function/method are lost.
- This future loops forever until the loop is closed or func no longer exists.
- `queue_close` is the preferred means to shut down the queue.
"""
self._check_in_thread()
if not (sender := self._queue_map.get(func)):
sender, queue = anyio.create_memory_object_stream[tuple[contextvars.Context, tuple, dict]](max_buffer_size)
return self._queue_map.get(hash(func))

async def execute_loop():
try:
with contextlib.suppress(anyio.get_cancelled_exc_class()):
async with queue as receive_stream:
async for context, args, kwargs in receive_stream:
try:
result = context.run(func, *args, **kwargs)
if inspect.iscoroutine(result):
await result
except Exception as e:
self.log.exception("Execution %f failed", func, exc_info=e)
finally:
self._queue_map.pop(func, None)

self._queue_map[func] = sender
self.call_soon(execute_loop)
return sender

async def queue_call(
def queue_call(
self,
func: Callable[P, T | CoroutineType[Any, Any, T]],
/,
@@ -756,48 +731,63 @@ async def queue_call(
"""
Queue the execution of `func` in a queue unique to it and this caller (thread-safe).

This is the async version that will wait until the call is added to the queue. This
is the preferred way to queue calls, with the optimal pathway being in the current thread.

Args:
func: The function.
*args: Arguments to use with `func`.
**kwargs: Keyword arguments to use with `func`.
"""
if self._thread is not threading.current_thread():
await self.call_soon(self.queue_call, func, *args, **kwargs)
else:
sender = self.queue_get_sender(func)
await sender.send((contextvars.copy_context(), args, kwargs))
key = hash(func)
if not (fut := self._queue_map.get(key)):
queue = deque()
event_added = threading.Event()

def queue_call_no_wait(
self,
func: Callable[P, T | CoroutineType[Any, Any, T]],
/,
*args: P.args,
**kwargs: P.kwargs,
) -> None:
"""
Queue the execution of `func` in a queue unique to it and this caller (thread-safe).
with contextlib.suppress(TypeError):
weakref.finalize(func.__self__ if inspect.ismethod(func) else func, lambda: self.queue_close(key))

This is a convenience method that calls `queue_call` as a task.
def sender(args):
queue.append(args)
event_added.set()

Args:
func: The function.
*args: Arguments to use with `func`.
**kwargs: Keyword arguments to use with `func`.
"""
self.call_soon(self.queue_call, func, *args, **kwargs)
async def execute_loop(sender, queue: deque, event_added=event_added) -> None:
fut = self.current_future()
assert fut
try:
while True:
event_added.clear()
if queue:
context, func_, args, kwargs = queue.popleft()
try:
result = context.run(func_, *args, **kwargs)
if inspect.iscoroutine(object=result):
await result
except (anyio.get_cancelled_exc_class(), Exception) as e:
if fut.cancelled():
break
self.log.exception("Execution %f failed", func_, exc_info=e)
finally:
func_ = None
else:
await wait_thread_event(event_added)
finally:
self._queue_map.pop(key)

self._queue_map[key] = fut = self.call_soon(
execute_loop, sender=sender, queue=queue, event_added=event_added
)
fut.metadata["kwargs"]["sender"]((contextvars.copy_context(), func, args, kwargs))

def queue_close(self, func: Callable) -> None:
def queue_close(self, func: Callable | int) -> None:
"""
Close the execution queue associated with `func` (thread-safe).

Args:
func: The queue of the function to close.
"""
if sender := self._queue_map.pop(func, None):
self.call_direct(sender.close)
key = func if isinstance(func, int) else hash(func)
if fut := self._queue_map.pop(key, None):
if (kwargs := fut.metadata.get("kwargs")) and (event := kwargs.get("event_added")):
event.set()
fut.cancel()

@classmethod
def stop_all(cls, *, _stop_protected: bool = False) -> None:
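From the consumer side, the reworked queue API is now synchronous: `queue_call` enqueues work into a per-function queue serviced by a long-running future, `queue_get` exposes that future, and `queue_close` tears it down. A hedged usage sketch (the `handle` coroutine is illustrative, and it assumes a `Caller` is already running for the current thread):

```python
from async_kernel.caller import Caller

async def handle(msg: str) -> None:
    print("handled", msg)

caller = Caller()                       # caller bound to the current thread
caller.queue_call(handle, "first")      # first call creates the queue loop
caller.queue_call(handle, "second")     # later calls reuse the same queue

loop_future = caller.queue_get(handle)  # the long-running execution loop, or None
if loop_future is not None:
    print(loop_future.done())           # False while the queue is open

caller.queue_close(handle)              # preferred way to shut the queue down
```

This mirrors the kernel.py change below, where `handle_message_request` no longer awaits `queue_call`.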
2 changes: 1 addition & 1 deletion src/async_kernel/kernel.py
@@ -744,7 +744,7 @@ async def handle_message_request(self, job: Job, /) -> None:
runner = _wrap_handler(self.run_handler, handler)
match run_mode:
case RunMode.queue:
await Caller().queue_call(runner, job)
Caller().queue_call(runner, job)
case RunMode.thread:
Caller.to_thread(runner, job)
case RunMode.task: