Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 43 additions & 43 deletions src/async_kernel/caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import functools
import inspect
import logging
import reprlib
import threading
import time
import weakref
Expand Down Expand Up @@ -42,6 +43,11 @@
"ReentrantAsyncLock",
]

truncated_rep = reprlib.Repr()
truncated_rep.maxlevel = 1
truncated_rep.maxother = 100
truncated_rep.fillvalue = "…"


class FutureCancelledError(anyio.ClosedResourceError):
"Used to indicate a `Future` is cancelled."
Expand Down Expand Up @@ -107,36 +113,28 @@ class Future(Awaitable[T]):
execution results.
"""

__slots__ = [
"__weakref__",
"_cancel_scope",
"_cancelled",
"_done",
"_done_callbacks",
"_exception",
"_metadata",
"_result",
"_setting_value",
"_thread",
]
_cancelled = False
_cancel_scope: anyio.CancelScope | None = None
_exception = None
_setting_value = False
_result: T

"The thread in which the result is targeted to run."
REPR_OMIT: ClassVar[set[str]] = {"func", "args", "kwargs", "start_time", "delay"}

def __init__(self, thread: threading.Thread | None = None, /, **metadata) -> None:
self._cancel_scope: anyio.CancelScope | None = None
self._cancelled = False
self._done = AsyncEvent(thread)
self._done_callbacks = []
self._exception = None
self._metadata = metadata
self._setting_value = False
self._thread = thread or threading.current_thread()
self._thread = thread = thread or threading.current_thread()
self._done = AsyncEvent(thread)

@override
def __repr__(self) -> str:
metadata = " ".join(f"{k}:{v!r}" for k, v in self.metadata.items())
return f"Future<thread:{self._thread.name!r} {metadata}>"
md = self.metadata
if "func" in md:
items = [truncated_rep.repr(v) for k, v in md.items() if k not in self.REPR_OMIT]
rep = f"| {md['func']} {' | '.join(items) if items else ''}"
else:
rep = f"{truncated_rep.repr(md)}" if md else ""
return f"Future< {self._thread.name} {rep}>"

@override
def __await__(self) -> Generator[Any, None, T]:
Expand Down Expand Up @@ -363,7 +361,7 @@ class Caller:
_pool_instances: ClassVar[weakref.WeakSet[Self]] = weakref.WeakSet()
_queue_map: weakref.WeakKeyDictionary[Callable[..., Awaitable[Any]], MemoryObjectSendStream[tuple]]
_taskgroup: TaskGroup | None = None
_callers: deque[tuple[contextvars.Context, tuple[Future, float, float, Callable, tuple, dict]] | Callable[[], Any]]
_callers: deque[tuple[contextvars.Context, Future] | Callable[[], Any]]
_callers_added: threading.Event
_stopped_event: threading.Event
_stopped = False
Expand Down Expand Up @@ -466,38 +464,32 @@ async def _server_loop(self, tg: TaskGroup, task_status: TaskStatus[None]) -> No
except Exception as e:
self.log.exception("Simple call failed", exc_info=e)
else:
context, args = job
context.run(tg.start_soon, self._wrap_call, *args)
context, fut = job
context.run(tg.start_soon, self._wrap_call, fut)
finally:
self._running = False
for job in self._callers:
if not callable(job):
job[1][0].set_exception(FutureCancelledError())
if isinstance(job, tuple):
job[1].set_exception(FutureCancelledError())
socket.close()
self.iopub_sockets.pop(self.thread, None)
self._stopped_event.set()
tg.cancel_scope.cancel()

async def _wrap_call(
self,
fut: Future[T],
starttime: float,
delay: float,
func: Callable[..., T | Awaitable[T]],
args: tuple,
kwargs: dict,
) -> None:
async def _wrap_call(self, fut: Future[T]) -> None:
self._future_var.set(fut)
if fut.cancelled():
fut.set_result(cast("T", None)) # This will cancel
return
md = fut.metadata
func = md["func"]
try:
with anyio.CancelScope() as scope:
fut.set_cancel_scope(scope)
try:
if (delay_ := delay - time.monotonic() + starttime) > 0:
if (delay_ := md["delay"] - time.monotonic() + md["start_time"]) > 0:
await anyio.sleep(float(delay_))
result = func(*args, **kwargs) if callable(func) else func # pyright: ignore[reportAssignmentType]
result = func(*md["args"], **md["kwargs"]) if callable(func) else func # pyright: ignore[reportAssignmentType]
if inspect.isawaitable(result) and result is not fut:
result: T = await result
if fut.cancelled() and not scope.cancel_called:
Expand Down Expand Up @@ -575,10 +567,11 @@ def call_later(
if self._stopped:
raise anyio.ClosedResourceError
fut: Future[T] = Future(self.thread)
fut.metadata.update(start_time=time.monotonic(), delay=delay, func=func, args=args, kwargs=kwargs)
if threading.current_thread() is self.thread and (tg := self._taskgroup):
tg.start_soon(self._wrap_call, fut, time.monotonic(), delay, func, args, kwargs)
tg.start_soon(self._wrap_call, fut)
else:
self._callers.append((contextvars.copy_context(), (fut, time.monotonic(), delay, func, args, kwargs)))
self._callers.append((contextvars.copy_context(), fut))
self._callers_added.set()
self._outstanding += 1
return fut
Expand All @@ -596,9 +589,10 @@ def call_soon(self, func: Callable[P, T | Awaitable[T]], /, *args: P.args, **kwa

def call_direct(self, func: Callable[P, Any], /, *args: P.args, **kwargs: P.kwargs) -> None:
"""
Schedule func to be called in caller's event loop directly.
Schedule `func` to be called in caller's event loop directly.

The call is made without copying the context and does not use a future.
This method is provided to facilitate lightweight *thread-safe* function calls that
need to be done from within the callers event loop.

Args:
func: The function (awaitables permitted, though discouraged).
Expand All @@ -607,7 +601,8 @@ def call_direct(self, func: Callable[P, Any], /, *args: P.args, **kwargs: P.kwar

??? warning

**Use this method for lightweight calls only.**
- Use this method for lightweight calls only.
- Corroutines will **not** be awaited.
"""
self._callers.append(functools.partial(func, *args, **kwargs))
self._callers_added.set()
Expand Down Expand Up @@ -918,6 +913,11 @@ async def wait(

Returns two sets of the futures: (done, pending).

Args:
items: An iterable of futures to wait for.
timeout: The maximum time before returning.
return_when: The same options as available for [asyncio.wait][].

!!! example

```python
Expand Down
18 changes: 15 additions & 3 deletions tests/test_caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,9 @@ async def test_wait_cancelled_shield(self, anyio_backend):
await fut.wait(timeout=0.001)
assert fut.cancelled()

def test_metadata(self):
fut = Future(name="test")
assert repr(fut) == "Future<thread:'MainThread' name:'test'>"
def test_repr(self):
fut = Future(name="test", mydict={"test": "a long string" * 100})
assert repr(fut) == "Future< MainThread {'mydict': {…}, 'name': 'test'}>"


@pytest.mark.anyio
Expand All @@ -190,6 +190,18 @@ async def test_sync(self):
caller.call_later(0.01, is_called.set)
await is_called.wait()

async def test_repr(self, caller):
async def test_func(a, b, c):
pass

a = "long string" * 100
b = {f"name {i}": "long_string" * 100 for i in range(100)}
c = Future()
c.metadata.update(a=a, b=b)
assert repr(c) == "Future< MainThread {'a': 'long stringl…nglong string', 'b': {…}}>"
fut = caller.call_soon(test_func, a, b, c)
assert repr(fut).startswith("Future< MainThread | <function TestCaller.test_repr.<locals>.test_func at")

def test_no_thread(self):
with pytest.raises(RuntimeError):
Caller()
Expand Down
Loading