From b1b5fcae3774b7d67fcd1bf242dc78039c3912cf Mon Sep 17 00:00:00 2001 From: Alan Fleming <> Date: Sun, 14 Sep 2025 20:57:23 +1000 Subject: [PATCH] Caller.queue_call - divide into queue_get_sender, queue_call and queue_call_no_wait. --- src/async_kernel/caller.py | 108 +++++++++++++++++++------------------ src/async_kernel/kernel.py | 2 +- src/async_kernel/typing.py | 3 +- tests/test_caller.py | 16 +++--- 4 files changed, 65 insertions(+), 64 deletions(-) diff --git a/src/async_kernel/caller.py b/src/async_kernel/caller.py index 622d45e..63f6cfa 100644 --- a/src/async_kernel/caller.py +++ b/src/async_kernel/caller.py @@ -21,7 +21,7 @@ import async_kernel from async_kernel.kernelspec import Backend -from async_kernel.typing import NoValue, PosArgsT, T +from async_kernel.typing import NoValue, T from async_kernel.utils import wait_thread_event if TYPE_CHECKING: @@ -587,7 +587,7 @@ def call_later( Schedule func to be called in caller's event loop copying the current context. Args: - func: The function (awaitables permitted, though discouraged). + func: The function. delay: The minimum delay to add between submission and execution. *args: Arguments to use with func. **kwargs: Keyword arguments to use with func. @@ -605,7 +605,7 @@ def call_soon( Schedule func to be called in caller's event loop copying the current context. Args: - func: The function (awaitables permitted, though discouraged). + func: The function. *args: Arguments to use with func. **kwargs: Keyword arguments to use with func. """ @@ -625,7 +625,7 @@ def call_direct( need to be performed from within the callers event loop/taskgroup. Args: - func: The function (awaitables permitted, though discouraged). + func: The function. *args: Arguments to use with func. **kwargs: Keyword arguments to use with func. @@ -641,49 +641,15 @@ def queue_exists(self, func: Callable) -> bool: "Returns True if an execution queue exists for `func`." return func in self._queue_map - if TYPE_CHECKING: - - @overload - def queue_call( - self, - func: Callable[[*PosArgsT], Awaitable[Any]], - /, - *args: *PosArgsT, - max_buffer_size: NoValue | int = NoValue, # pyright: ignore[reportInvalidTypeForm] - wait: Literal[True], - ) -> CoroutineType[Any, Any, None]: ... - @overload - def queue_call( - self, - func: Callable[[*PosArgsT], Awaitable[Any]], - /, - *args: *PosArgsT, - max_buffer_size: NoValue | int = NoValue, # pyright: ignore[reportInvalidTypeForm] - wait: Literal[False] | Any = False, - ) -> None: ... - - def queue_call( - self, - func: Callable[[*PosArgsT], Awaitable[Any]], - /, - *args: *PosArgsT, - max_buffer_size: NoValue | int = NoValue, # pyright: ignore[reportInvalidTypeForm] - wait: bool = False, - ) -> CoroutineType[Any, Any, None] | None: + def queue_get_sender( + self, func: Callable, max_buffer_size: None | int = None + ) -> MemoryObjectSendStream[tuple[contextvars.Context, tuple, dict]]: """ - Queue the execution of `func` with the arguments `*args` in a queue unique to it (not thread-safe). - - The args are added to a queue associated with the provided `func`. If queue does not already exist for - func, a new queue is created with a specified maximum buffer size. The arguments are then sent to the queue, - and an `execute_loop` coroutine is started to consume the queue and execute the function with the received - arguments. Exceptions during execution are caught and logged. + Get or create a new queue unique to func in this caller. - Args: - func: The asynchronous function to execute. - *args: The arguments to pass to the function. - max_buffer_size: The maximum buffer size for the queue. If NoValue, defaults to [async_kernel.Caller.MAX_BUFFER_SIZE]. - wait: Set as True to return a coroutine that will return once the request is sent. - Use this to prevent experiencing exceptions if the buffer is full. + This method can be used to configure the buffer size of the queue for the methods + - `queue_call` + - `queue_call_no_wait` !!! info @@ -691,20 +657,21 @@ def queue_call( 1. It explicitly closed with the method `queue_close`. 1. All strong references are lost the function/method. - """ self._check_in_thread() + max_buffer_size = max_buffer_size or self.MAX_BUFFER_SIZE if not (sender := self._queue_map.get(func)): - max_buffer_size = self.MAX_BUFFER_SIZE if max_buffer_size is NoValue else max_buffer_size - sender, queue = anyio.create_memory_object_stream[tuple[*PosArgsT]](max_buffer_size=max_buffer_size) + sender, queue = anyio.create_memory_object_stream[tuple[contextvars.Context, tuple, dict]](max_buffer_size) async def execute_loop(): try: with contextlib.suppress(anyio.get_cancelled_exc_class()): async with queue as receive_stream: - async for args in receive_stream: + async for context, args, kwargs in receive_stream: try: - await func(*args) + result = context.run(func, *args, **kwargs) + if inspect.iscoroutine(result): + await result except Exception as e: self.log.exception("Execution %f failed", func, exc_info=e) finally: @@ -712,7 +679,44 @@ async def execute_loop(): self._queue_map[func] = sender self.call_soon(execute_loop) - return sender.send(args) if wait else sender.send_nowait(args) + return sender + + async def queue_call( + self, + func: Callable[P, T | CoroutineType[Any, Any, T]], + /, + *args: P.args, + **kwargs: P.kwargs, + ) -> None: + """ + Queue the execution of `func` in a queue unique to it and this caller (not thread-safe). + + Args: + func: The function. + *args: Arguments to use with func. + **kwargs: Keyword arguments to use with func. + """ + sender = self.queue_get_sender(func) + await sender.send((contextvars.copy_context(), args, kwargs)) + + def queue_call_no_wait( + self, + func: Callable[P, T | CoroutineType[Any, Any, T]], + /, + *args: P.args, + **kwargs: P.kwargs, + ) -> None: + """ + Queue the execution of `func` in a queue unique to it and this caller (not thread-safe). + + Args: + func: The function. + *args: Arguments to use with func. + **kwargs: Keyword arguments to use with func. + """ + + sender = self.queue_get_sender(func) + sender.send_nowait((contextvars.copy_context(), args, kwargs)) def queue_close(self, func: Callable) -> None: """ @@ -781,7 +785,7 @@ def to_thread_by_name( [^notes]: 'MainThread' is special name corresponding to the main thread. A `RuntimeError` will be raised if a Caller does not exist for the main thread. - func: The function (awaitables permitted, though discouraged). + func: The function. *args: Arguments to use with func. **kwargs: Keyword arguments to use with func. diff --git a/src/async_kernel/kernel.py b/src/async_kernel/kernel.py index a54e905..12c52f1 100644 --- a/src/async_kernel/kernel.py +++ b/src/async_kernel/kernel.py @@ -746,7 +746,7 @@ async def handle_message_request(self, job: Job, /) -> None: runner = _wrap_handler(self.run_handler, handler) match run_mode: case RunMode.queue: - await Caller().queue_call(runner, job, wait=True) + await Caller().queue_call(runner, job) case RunMode.thread: Caller.to_thread(runner, job) case RunMode.task: diff --git a/src/async_kernel/typing.py b/src/async_kernel/typing.py index 307915b..e0a7da9 100644 --- a/src/async_kernel/typing.py +++ b/src/async_kernel/typing.py @@ -2,7 +2,7 @@ import enum from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, Any, Generic, Literal, ParamSpec, TypedDict, TypeVar, TypeVarTuple +from typing import TYPE_CHECKING, Any, Generic, Literal, ParamSpec, TypedDict, TypeVar from typing_extensions import Sentinel, override @@ -30,7 +30,6 @@ T = TypeVar("T") D = TypeVar("D", bound=dict) P = ParamSpec("P") -PosArgsT = TypeVarTuple("PosArgsT") class SocketID(enum.StrEnum): diff --git a/tests/test_caller.py b/tests/test_caller.py index 6aa041c..35af686 100644 --- a/tests/test_caller.py +++ b/tests/test_caller.py @@ -414,21 +414,19 @@ async def test__check_in_thread(self, anyio_backend): worker._check_in_thread() # pyright: ignore[reportPrivateUsage] async def test_execution_queue(self, caller: Caller): - delay = 0.01 - N = 10 + N = 5 pool = list(range(N)) - results = [] - async def func(a, b, results=results): + async def func(a, b, /, delay, *, results): await anyio.sleep(delay) results.append(b) - for i in range(3): + caller.queue_get_sender(func, max_buffer_size=2) + for _ in range(2): + results = [] for j in pool: - buff = i * N + 1 - if waiter := caller.queue_call(func, 0, j, wait=j >= buff, max_buffer_size=buff): - await waiter # pyright: ignore[reportGeneralTypeIssues] + await caller.queue_call(func, 0, j, delay=0.05 * j, results=results) assert caller.queue_exists(func) assert results != pool caller.queue_close(func) @@ -451,7 +449,7 @@ async def method(self): async with Caller(create=True) as caller: obj = MyObj() weakref.finalize(obj, obj_finalized.set) - caller.queue_call(obj.method) + caller.queue_call_no_wait(obj.method) await method_called.wait() assert caller.queue_exists(obj.method), "A ref should be retained unless it is explicitly removed" del obj