Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 62 additions & 34 deletions src/async_kernel/caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,13 @@ def __init__(self, thread: threading.Thread | None = None, /, **metadata) -> Non
@override
def __repr__(self) -> str:
md = self.metadata
rep = f"Future< {self._thread.name}" + (" ⛔" if self.cancelled() else "") + (" 🏁" if self.done() else "")
if "func" in md:
items = [f"{k}={truncated_rep.repr(v)}" for k, v in md.items() if k not in self.REPR_OMIT]
rep = f"| {md['func']} {' | '.join(items) if items else ''}"
rep += f" | {md['func']} {' | '.join(items) if items else ''}"
else:
rep = f"{truncated_rep.repr(md)}" if md else ""
return f"Future< {self._thread.name} {rep}>"
rep += f" {truncated_rep.repr(md)}" if md else ""
return rep + " >"

@override
def __await__(self) -> Generator[Any, None, T]:
Expand Down Expand Up @@ -371,9 +372,9 @@ class Caller:
_backend: Backend
_queue_map: weakref.WeakKeyDictionary[Callable[..., Awaitable[Any]], MemoryObjectSendStream[tuple]]
_taskgroup: TaskGroup | None = None
_callers: deque[tuple[contextvars.Context, Future] | Callable[[], Any]]
_jobs: deque[tuple[contextvars.Context, Future] | Callable[[], Any]]
_thread: threading.Thread
_callers_added: threading.Event
_job_added: threading.Event
_stopped_event: threading.Event
_stopped = False
_protected = False
Expand Down Expand Up @@ -428,8 +429,8 @@ def __new__(
inst._backend = Backend(sniffio.current_async_library())
inst._thread = thread
inst.log = log or logging.LoggerAdapter(logging.getLogger())
inst._callers = deque()
inst._callers_added = threading.Event()
inst._jobs = deque()
inst._job_added = threading.Event()
inst._protected = protected
inst._queue_map = weakref.WeakKeyDictionary()
cls._instances[thread] = inst
Expand Down Expand Up @@ -461,24 +462,26 @@ async def _server_loop(self, tg: TaskGroup, task_status: TaskStatus[None]) -> No
self.iopub_sockets[self.thread] = socket
task_status.started()
while not self._stopped:
if not self._callers:
self._callers_added.clear()
await wait_thread_event(self._callers_added)
while self._callers:
if not self._jobs:
self._job_added.clear()
await wait_thread_event(self._job_added)
while self._jobs:
if self._stopped:
return
job = self._callers.popleft()
job = self._jobs.popleft()
if isinstance(job, Callable):
try:
job()
result = job()
if inspect.iscoroutine(result):
await result
except Exception as e:
self.log.exception("Simple call failed", exc_info=e)
else:
context, fut = job
context.run(tg.start_soon, self._wrap_call, fut)
finally:
self._running = False
for job in self._callers:
for job in self._jobs:
if isinstance(job, tuple):
job[1].set_exception(FutureCancelledError())
socket.close()
Expand All @@ -490,11 +493,8 @@ def _schedule_wrapped_call(self, func: Callable, /, args: tuple, kwargs: dict, *
if self._stopped:
raise anyio.ClosedResourceError
fut = Future(self.thread, func=func, args=args, kwargs=kwargs, **extra)
if threading.current_thread() is self.thread and (tg := self._taskgroup):
tg.start_soon(self._wrap_call, fut)
else:
self._callers.append((contextvars.copy_context(), fut))
self._callers_added.set()
self._jobs.append((contextvars.copy_context(), fut))
self._job_added.set()
return fut

async def _wrap_call(self, fut: Future) -> None:
Expand All @@ -510,9 +510,11 @@ async def _wrap_call(self, fut: Future) -> None:
try:
if (delay := md.get("delay")) and ((delay := delay - time.monotonic() + md["start_time"]) > 0):
await anyio.sleep(delay)
result = func(*md["args"], **md["kwargs"]) if callable(func) else func
if inspect.isawaitable(result) and result is not fut:
# Evaluate
result = func(*md["args"], **md["kwargs"])
if inspect.iscoroutine(result):
result = await result
# Cancellation
if fut.cancelled() and not scope.cancel_called:
scope.cancel()
fut.set_result(result)
Expand Down Expand Up @@ -566,15 +568,20 @@ def stop(self, *, force=False) -> None:
for sender in self._queue_map.values():
sender.close()
self._queue_map.clear()
self._callers_added.set()
self._job_added.set()
self._instances.pop(self.thread, None)
if self in self._to_thread_pool:
self._to_thread_pool.remove(self)
if self.thread is not threading.current_thread():
self._stopped_event.wait()

def call_later(
self, delay: float, func: Callable[P, T | Awaitable[T]], /, *args: P.args, **kwargs: P.kwargs
self,
delay: float,
func: Callable[P, T | CoroutineType[Any, Any, T]],
/,
*args: P.args,
**kwargs: P.kwargs,
) -> Future[T]:
"""
Schedule func to be called in caller's event loop copying the current context.
Expand All @@ -587,7 +594,13 @@ def call_later(
"""
return self._schedule_wrapped_call(func, args, kwargs, delay=delay, start_time=time.monotonic())

def call_soon(self, func: Callable[P, T | Awaitable[T]], /, *args: P.args, **kwargs: P.kwargs) -> Future[T]:
def call_soon(
self,
func: Callable[P, T | CoroutineType[Any, Any, T]],
/,
*args: P.args,
**kwargs: P.kwargs,
) -> Future[T]:
"""
Schedule func to be called in caller's event loop copying the current context.

Expand All @@ -598,12 +611,18 @@ def call_soon(self, func: Callable[P, T | Awaitable[T]], /, *args: P.args, **kwa
"""
return self._schedule_wrapped_call(func, args, kwargs)

def call_direct(self, func: Callable[P, Any], /, *args: P.args, **kwargs: P.kwargs) -> None:
def call_direct(
self,
func: Callable[P, T | CoroutineType[Any, Any, T]],
/,
*args: P.args,
**kwargs: P.kwargs,
) -> None:
"""
Schedule `func` to be called in caller's event loop directly.

This method is provided to facilitate lightweight *thread-safe* function calls that
need to be done from within the callers event loop.
need to be performed from within the callers event loop/taskgroup.

Args:
func: The function (awaitables permitted, though discouraged).
Expand All @@ -612,11 +631,11 @@ def call_direct(self, func: Callable[P, Any], /, *args: P.args, **kwargs: P.kwar

??? warning

- Use this method for lightweight calls only.
- Corroutines will **not** be awaited.
**Use this method for lightweight calls only!**

"""
self._callers.append(functools.partial(func, *args, **kwargs))
self._callers_added.set()
self._jobs.append(functools.partial(func, *args, **kwargs))
self._job_added.set()

def queue_exists(self, func: Callable) -> bool:
"Returns True if an execution queue exists for `func`."
Expand Down Expand Up @@ -684,8 +703,6 @@ async def execute_loop():
with contextlib.suppress(anyio.get_cancelled_exc_class()):
async with queue as receive_stream:
async for args in receive_stream:
if func not in self._queue_map:
break
try:
await func(*args)
except Exception as e:
Expand Down Expand Up @@ -736,13 +753,24 @@ def get_instance(cls, name: str | None = "MainThread", *, create: bool = False)
raise RuntimeError(msg)

@classmethod
def to_thread(cls, func: Callable[P, T | Awaitable[T]], /, *args: P.args, **kwargs: P.kwargs) -> Future[T]:
def to_thread(
cls,
func: Callable[P, T | CoroutineType[Any, Any, T]],
/,
*args: P.args,
**kwargs: P.kwargs,
) -> Future[T]:
"""A classmethod to call func in a separate thread see also [to_thread_by_name][async_kernel.Caller.to_thread_by_name]."""
return cls.to_thread_by_name(None, func, *args, **kwargs)

@classmethod
def to_thread_by_name(
cls, name: str | None, func: Callable[P, T | Awaitable[T]], /, *args: P.args, **kwargs: P.kwargs
cls,
name: str | None,
func: Callable[P, T | CoroutineType[Any, Any, T]],
/,
*args: P.args,
**kwargs: P.kwargs,
) -> Future[T]:
"""
A classmethod to call func in the thread specified by name.
Expand Down
85 changes: 36 additions & 49 deletions tests/test_caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,13 @@
import threading
import time
import weakref
from contextvars import ContextVar
from random import random
from typing import Literal, cast

import anyio
import anyio.to_thread
import pytest
import sniffio
from anyio.abc import TaskStatus

from async_kernel.caller import (
AsyncEvent,
Expand Down Expand Up @@ -173,7 +171,7 @@ async def test_wait_cancelled_shield(self, caller: Caller):

def test_repr(self):
fut = Future(name="test", mydict={"test": "a long string" * 100})
assert repr(fut) == "Future< MainThread {'mydict': {…}, 'name': 'test'}>"
assert repr(fut) == "Future< MainThread {'mydict': {…}, 'name': 'test'} >"


@pytest.mark.anyio
Expand All @@ -190,6 +188,11 @@ async def test_sync(self):
caller.call_later(0.01, is_called.set)
await is_called.wait()

async def test_call_returns_future(self, caller: Caller):
fut = Future()
caller.call_direct(lambda: fut)
assert await caller.call_soon(lambda: fut) is fut

async def test_repr(self, caller):
async def test_func(a, b, c):
pass
Expand All @@ -198,9 +201,13 @@ async def test_func(a, b, c):
b = {f"name {i}": "long_string" * 100 for i in range(100)}
c = Future()
c.metadata.update(a=a, b=b)
assert repr(c) == "Future< MainThread {'a': 'long stringl…nglong string', 'b': {…}}>"
assert repr(c) == "Future< MainThread {'a': 'long stringl…nglong string', 'b': {…}} >"
fut = caller.call_soon(test_func, a, b, c)
assert repr(fut).startswith("Future< MainThread | <function TestCaller.test_repr.<locals>.test_func at")
assert repr(fut).startswith("Future< MainThread | <function")
await fut
assert repr(fut).startswith("Future< MainThread 🏁 | <function")
c.cancel()
assert repr(c) == "Future< MainThread ⛔ {'a': 'long stringl…nglong string', 'b': {…}} >"

def test_no_thread(self):
with pytest.raises(RuntimeError):
Expand Down Expand Up @@ -253,6 +260,16 @@ async def runner():
await anyio.to_thread.run_sync(_in_thread)
assert caller not in Caller.all_callers()

async def test_direct_async(self, caller: Caller):
event = AsyncEvent()

async def set_event():
event.set()

caller.call_direct(set_event)
with anyio.fail_after(1):
await event.wait()

async def test_cancels_on_exit(self):
is_cancelled = False
async with Caller(create=True) as caller:
Expand Down Expand Up @@ -378,22 +395,16 @@ def func():
with pytest.raises(RuntimeError):
await fut

async def test_as_completed_cancelled(self, anyio_backend):
async def test_as_completed_cancelled(self, caller):
items = {Caller.to_thread(anyio.sleep, 100) for _ in range(4)}
async with Caller(create=True):

async def cancelled(task_status: TaskStatus[None]):
with pytest.raises(anyio.get_cancelled_exc_class()): # noqa: PT012
task_status.started()
async for _ in Caller.as_completed(items):
pass

async with anyio.create_task_group() as tg:
await tg.start(cancelled)
tg.cancel_scope.cancel()
for item in items:
with pytest.raises(FutureCancelledError):
await item
with anyio.move_on_after(0.1):
with pytest.raises(anyio.get_cancelled_exc_class()):
async for _ in Caller.as_completed(items):
pass
for item in items:
assert item.cancelled()
with pytest.raises(FutureCancelledError):
await item

async def test__check_in_thread(self, anyio_backend):
Caller.to_thread(anyio.sleep, 0.1)
Expand Down Expand Up @@ -424,7 +435,7 @@ async def func(a, b, results=results):
assert not caller.queue_exists(func)

async def test_gc(self, anyio_backend):
event_finalize_called = AsyncEvent()
event_finalize_called = anyio.Event()
async with Caller(create=True) as caller:
weakref.finalize(caller, event_finalize_called.set)
del caller
Expand Down Expand Up @@ -457,27 +468,6 @@ async def test_call_early(self, anyio_backend) -> None:
async with caller:
await fut

async def test_call_coroutine(self, caller: Caller):
# Test we can await a coroutine, note that it is not permitted with the type hints,
# but should probably be discouraged anyway since there is no way of knowing
# (with type hints) if a coroutine has already been awaited.
my_contextvar = ContextVar[int]("my_contextvar")
my_contextvar.set(1)

async def my_func():
await anyio.sleep(0)
assert my_contextvar.get() == 1
return True

# Discouraged
fut = Caller.to_thread(my_func()) # pyright: ignore[reportCallIssue, reportArgumentType]
val = await fut
assert val is True
# This the preferred way of calling.
fut = Caller.to_thread(my_func)
val = await fut
assert val is True

async def test_current_future(self, anyio_backend):
async with Caller(create=True) as caller:
fut = caller.call_soon(Caller.current_future)
Expand Down Expand Up @@ -545,14 +535,11 @@ async def async_func():
await anyio.sleep(10)
raise RuntimeError

async with anyio.create_task_group() as tg:
fut = caller.call_soon(async_func)
tg.start_soon(fut.wait)
await anyio.sleep(0)
tg.cancel_scope.cancel()
await anyio.sleep(0)
fut = caller.call_soon(async_func)
with anyio.move_on_after(0.1):
await fut
with pytest.raises(FutureCancelledError):
fut.exception() # pyright: ignore[reportPossiblyUnboundVariable]
fut.exception()

@pytest.mark.parametrize("return_when", ["FIRST_COMPLETED", "FIRST_EXCEPTION", "ALL_COMPLETED"])
async def test_wait(self, caller: Caller, return_when):
Expand Down
Loading