Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 56 additions & 52 deletions src/async_kernel/caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import async_kernel
from async_kernel.kernelspec import Backend
from async_kernel.typing import NoValue, PosArgsT, T
from async_kernel.typing import NoValue, T
from async_kernel.utils import wait_thread_event

if TYPE_CHECKING:
Expand Down Expand Up @@ -587,7 +587,7 @@ def call_later(
Schedule func to be called in caller's event loop copying the current context.

Args:
func: The function (awaitables permitted, though discouraged).
func: The function.
delay: The minimum delay to add between submission and execution.
*args: Arguments to use with func.
**kwargs: Keyword arguments to use with func.
Expand All @@ -605,7 +605,7 @@ def call_soon(
Schedule func to be called in caller's event loop copying the current context.

Args:
func: The function (awaitables permitted, though discouraged).
func: The function.
*args: Arguments to use with func.
**kwargs: Keyword arguments to use with func.
"""
Expand All @@ -625,7 +625,7 @@ def call_direct(
need to be performed from within the callers event loop/taskgroup.

Args:
func: The function (awaitables permitted, though discouraged).
func: The function.
*args: Arguments to use with func.
**kwargs: Keyword arguments to use with func.

Expand All @@ -641,78 +641,82 @@ def queue_exists(self, func: Callable) -> bool:
"Returns True if an execution queue exists for `func`."
return func in self._queue_map

if TYPE_CHECKING:

@overload
def queue_call(
self,
func: Callable[[*PosArgsT], Awaitable[Any]],
/,
*args: *PosArgsT,
max_buffer_size: NoValue | int = NoValue, # pyright: ignore[reportInvalidTypeForm]
wait: Literal[True],
) -> CoroutineType[Any, Any, None]: ...
@overload
def queue_call(
self,
func: Callable[[*PosArgsT], Awaitable[Any]],
/,
*args: *PosArgsT,
max_buffer_size: NoValue | int = NoValue, # pyright: ignore[reportInvalidTypeForm]
wait: Literal[False] | Any = False,
) -> None: ...

def queue_call(
self,
func: Callable[[*PosArgsT], Awaitable[Any]],
/,
*args: *PosArgsT,
max_buffer_size: NoValue | int = NoValue, # pyright: ignore[reportInvalidTypeForm]
wait: bool = False,
) -> CoroutineType[Any, Any, None] | None:
def queue_get_sender(
self, func: Callable, max_buffer_size: None | int = None
) -> MemoryObjectSendStream[tuple[contextvars.Context, tuple, dict]]:
"""
Queue the execution of `func` with the arguments `*args` in a queue unique to it (not thread-safe).

The args are added to a queue associated with the provided `func`. If queue does not already exist for
func, a new queue is created with a specified maximum buffer size. The arguments are then sent to the queue,
and an `execute_loop` coroutine is started to consume the queue and execute the function with the received
arguments. Exceptions during execution are caught and logged.
Get or create a new queue unique to func in this caller.

Args:
func: The asynchronous function to execute.
*args: The arguments to pass to the function.
max_buffer_size: The maximum buffer size for the queue. If NoValue, defaults to [async_kernel.Caller.MAX_BUFFER_SIZE].
wait: Set as True to return a coroutine that will return once the request is sent.
Use this to prevent experiencing exceptions if the buffer is full.
This method can be used to configure the buffer size of the queue for the methods
- `queue_call`
- `queue_call_no_wait`

!!! info

The queue will stay open until one of the following occurs.

1. It explicitly closed with the method `queue_close`.
1. All strong references are lost the function/method.

"""
self._check_in_thread()
max_buffer_size = max_buffer_size or self.MAX_BUFFER_SIZE
if not (sender := self._queue_map.get(func)):
max_buffer_size = self.MAX_BUFFER_SIZE if max_buffer_size is NoValue else max_buffer_size
sender, queue = anyio.create_memory_object_stream[tuple[*PosArgsT]](max_buffer_size=max_buffer_size)
sender, queue = anyio.create_memory_object_stream[tuple[contextvars.Context, tuple, dict]](max_buffer_size)

async def execute_loop():
try:
with contextlib.suppress(anyio.get_cancelled_exc_class()):
async with queue as receive_stream:
async for args in receive_stream:
async for context, args, kwargs in receive_stream:
try:
await func(*args)
result = context.run(func, *args, **kwargs)
if inspect.iscoroutine(result):
await result
except Exception as e:
self.log.exception("Execution %f failed", func, exc_info=e)
finally:
self._queue_map.pop(func, None)

self._queue_map[func] = sender
self.call_soon(execute_loop)
return sender.send(args) if wait else sender.send_nowait(args)
return sender

async def queue_call(
self,
func: Callable[P, T | CoroutineType[Any, Any, T]],
/,
*args: P.args,
**kwargs: P.kwargs,
) -> None:
"""
Queue the execution of `func` in a queue unique to it and this caller (not thread-safe).

Args:
func: The function.
*args: Arguments to use with func.
**kwargs: Keyword arguments to use with func.
"""
sender = self.queue_get_sender(func)
await sender.send((contextvars.copy_context(), args, kwargs))

def queue_call_no_wait(
self,
func: Callable[P, T | CoroutineType[Any, Any, T]],
/,
*args: P.args,
**kwargs: P.kwargs,
) -> None:
"""
Queue the execution of `func` in a queue unique to it and this caller (not thread-safe).

Args:
func: The function.
*args: Arguments to use with func.
**kwargs: Keyword arguments to use with func.
"""

sender = self.queue_get_sender(func)
sender.send_nowait((contextvars.copy_context(), args, kwargs))

def queue_close(self, func: Callable) -> None:
"""
Expand Down Expand Up @@ -781,7 +785,7 @@ def to_thread_by_name(
[^notes]: 'MainThread' is special name corresponding to the main thread.
A `RuntimeError` will be raised if a Caller does not exist for the main thread.

func: The function (awaitables permitted, though discouraged).
func: The function.
*args: Arguments to use with func.
**kwargs: Keyword arguments to use with func.

Expand Down
2 changes: 1 addition & 1 deletion src/async_kernel/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ async def handle_message_request(self, job: Job, /) -> None:
runner = _wrap_handler(self.run_handler, handler)
match run_mode:
case RunMode.queue:
await Caller().queue_call(runner, job, wait=True)
await Caller().queue_call(runner, job)
case RunMode.thread:
Caller.to_thread(runner, job)
case RunMode.task:
Expand Down
3 changes: 1 addition & 2 deletions src/async_kernel/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import enum
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any, Generic, Literal, ParamSpec, TypedDict, TypeVar, TypeVarTuple
from typing import TYPE_CHECKING, Any, Generic, Literal, ParamSpec, TypedDict, TypeVar

from typing_extensions import Sentinel, override

Expand Down Expand Up @@ -30,7 +30,6 @@
T = TypeVar("T")
D = TypeVar("D", bound=dict)
P = ParamSpec("P")
PosArgsT = TypeVarTuple("PosArgsT")


class SocketID(enum.StrEnum):
Expand Down
16 changes: 7 additions & 9 deletions tests/test_caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,21 +414,19 @@ async def test__check_in_thread(self, anyio_backend):
worker._check_in_thread() # pyright: ignore[reportPrivateUsage]

async def test_execution_queue(self, caller: Caller):
delay = 0.01
N = 10
N = 5

pool = list(range(N))
results = []

async def func(a, b, results=results):
async def func(a, b, /, delay, *, results):
await anyio.sleep(delay)
results.append(b)

for i in range(3):
caller.queue_get_sender(func, max_buffer_size=2)
for _ in range(2):
results = []
for j in pool:
buff = i * N + 1
if waiter := caller.queue_call(func, 0, j, wait=j >= buff, max_buffer_size=buff):
await waiter # pyright: ignore[reportGeneralTypeIssues]
await caller.queue_call(func, 0, j, delay=0.05 * j, results=results)
assert caller.queue_exists(func)
assert results != pool
caller.queue_close(func)
Expand All @@ -451,7 +449,7 @@ async def method(self):
async with Caller(create=True) as caller:
obj = MyObj()
weakref.finalize(obj, obj_finalized.set)
caller.queue_call(obj.method)
caller.queue_call_no_wait(obj.method)
await method_called.wait()
assert caller.queue_exists(obj.method), "A ref should be retained unless it is explicitly removed"
del obj
Expand Down