Skip to content

Commit d85392d

Browse files
author
Alan Fleming
committed
Stricter handling in Caller class.
1 parent 311dcd8 commit d85392d

File tree

2 files changed

+71
-49
lines changed

2 files changed

+71
-49
lines changed

src/async_kernel/caller.py

Lines changed: 61 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -371,9 +371,9 @@ class Caller:
371371
_backend: Backend
372372
_queue_map: weakref.WeakKeyDictionary[Callable[..., Awaitable[Any]], MemoryObjectSendStream[tuple]]
373373
_taskgroup: TaskGroup | None = None
374-
_callers: deque[tuple[contextvars.Context, Future] | Callable[[], Any]]
374+
_jobs: deque[tuple[contextvars.Context, Future] | Callable[[], Any]]
375375
_thread: threading.Thread
376-
_callers_added: threading.Event
376+
_job_added: threading.Event
377377
_stopped_event: threading.Event
378378
_stopped = False
379379
_protected = False
@@ -428,8 +428,8 @@ def __new__(
428428
inst._backend = Backend(sniffio.current_async_library())
429429
inst._thread = thread
430430
inst.log = log or logging.LoggerAdapter(logging.getLogger())
431-
inst._callers = deque()
432-
inst._callers_added = threading.Event()
431+
inst._jobs = deque()
432+
inst._job_added = threading.Event()
433433
inst._protected = protected
434434
inst._queue_map = weakref.WeakKeyDictionary()
435435
cls._instances[thread] = inst
@@ -461,24 +461,27 @@ async def _server_loop(self, tg: TaskGroup, task_status: TaskStatus[None]) -> No
461461
self.iopub_sockets[self.thread] = socket
462462
task_status.started()
463463
while not self._stopped:
464-
if not self._callers:
465-
self._callers_added.clear()
466-
await wait_thread_event(self._callers_added)
467-
while self._callers:
464+
if not self._jobs:
465+
self._job_added.clear()
466+
await wait_thread_event(self._job_added)
467+
while self._jobs:
468468
if self._stopped:
469469
return
470-
job = self._callers.popleft()
470+
job = self._jobs.popleft()
471471
if isinstance(job, Callable):
472472
try:
473-
job()
473+
if inspect.iscoroutinefunction(job):
474+
await job()
475+
else:
476+
job()
474477
except Exception as e:
475478
self.log.exception("Simple call failed", exc_info=e)
476479
else:
477480
context, fut = job
478481
context.run(tg.start_soon, self._wrap_call, fut)
479482
finally:
480483
self._running = False
481-
for job in self._callers:
484+
for job in self._jobs:
482485
if isinstance(job, tuple):
483486
job[1].set_exception(FutureCancelledError())
484487
socket.close()
@@ -493,8 +496,8 @@ def _schedule_wrapped_call(self, func: Callable, /, args: tuple, kwargs: dict, *
493496
if threading.current_thread() is self.thread and (tg := self._taskgroup):
494497
tg.start_soon(self._wrap_call, fut)
495498
else:
496-
self._callers.append((contextvars.copy_context(), fut))
497-
self._callers_added.set()
499+
self._jobs.append((contextvars.copy_context(), fut))
500+
self._job_added.set()
498501
return fut
499502

500503
async def _wrap_call(self, fut: Future) -> None:
@@ -510,9 +513,12 @@ async def _wrap_call(self, fut: Future) -> None:
510513
try:
511514
if (delay := md.get("delay")) and ((delay := delay - time.monotonic() + md["start_time"]) > 0):
512515
await anyio.sleep(delay)
513-
result = func(*md["args"], **md["kwargs"]) if callable(func) else func
514-
if inspect.isawaitable(result) and result is not fut:
515-
result = await result
516+
# Evaluate
517+
if inspect.iscoroutinefunction(func):
518+
result = await func(*md["args"], **md["kwargs"])
519+
else:
520+
result = func(*md["args"], **md["kwargs"])
521+
# Cancellation
516522
if fut.cancelled() and not scope.cancel_called:
517523
scope.cancel()
518524
fut.set_result(result)
@@ -566,15 +572,20 @@ def stop(self, *, force=False) -> None:
566572
for sender in self._queue_map.values():
567573
sender.close()
568574
self._queue_map.clear()
569-
self._callers_added.set()
575+
self._job_added.set()
570576
self._instances.pop(self.thread, None)
571577
if self in self._to_thread_pool:
572578
self._to_thread_pool.remove(self)
573579
if self.thread is not threading.current_thread():
574580
self._stopped_event.wait()
575581

576582
def call_later(
577-
self, delay: float, func: Callable[P, T | Awaitable[T]], /, *args: P.args, **kwargs: P.kwargs
583+
self,
584+
delay: float,
585+
func: Callable[P, T | CoroutineType[Any, Any, T]],
586+
/,
587+
*args: P.args,
588+
**kwargs: P.kwargs,
578589
) -> Future[T]:
579590
"""
580591
Schedule func to be called in caller's event loop copying the current context.
@@ -587,7 +598,13 @@ def call_later(
587598
"""
588599
return self._schedule_wrapped_call(func, args, kwargs, delay=delay, start_time=time.monotonic())
589600

590-
def call_soon(self, func: Callable[P, T | Awaitable[T]], /, *args: P.args, **kwargs: P.kwargs) -> Future[T]:
601+
def call_soon(
602+
self,
603+
func: Callable[P, T | CoroutineType[Any, Any, T]],
604+
/,
605+
*args: P.args,
606+
**kwargs: P.kwargs,
607+
) -> Future[T]:
591608
"""
592609
Schedule func to be called in caller's event loop copying the current context.
593610
@@ -598,12 +615,18 @@ def call_soon(self, func: Callable[P, T | Awaitable[T]], /, *args: P.args, **kwa
598615
"""
599616
return self._schedule_wrapped_call(func, args, kwargs)
600617

601-
def call_direct(self, func: Callable[P, Any], /, *args: P.args, **kwargs: P.kwargs) -> None:
618+
def call_direct(
619+
self,
620+
func: Callable[P, T | CoroutineType[Any, Any, T]],
621+
/,
622+
*args: P.args,
623+
**kwargs: P.kwargs,
624+
) -> None:
602625
"""
603626
Schedule `func` to be called in caller's event loop directly.
604627
605628
This method is provided to facilitate lightweight *thread-safe* function calls that
606-
need to be done from within the callers event loop.
629+
need to be performed from within the callers event loop/taskgroup.
607630
608631
Args:
609632
func: The function (awaitables permitted, though discouraged).
@@ -612,11 +635,11 @@ def call_direct(self, func: Callable[P, Any], /, *args: P.args, **kwargs: P.kwar
612635
613636
??? warning
614637
615-
- Use this method for lightweight calls only.
616-
- Corroutines will **not** be awaited.
638+
**Use this method for lightweight calls only!**
639+
617640
"""
618-
self._callers.append(functools.partial(func, *args, **kwargs))
619-
self._callers_added.set()
641+
self._jobs.append(functools.partial(func, *args, **kwargs))
642+
self._job_added.set()
620643

621644
def queue_exists(self, func: Callable) -> bool:
622645
"Returns True if an execution queue exists for `func`."
@@ -736,13 +759,24 @@ def get_instance(cls, name: str | None = "MainThread", *, create: bool = False)
736759
raise RuntimeError(msg)
737760

738761
@classmethod
739-
def to_thread(cls, func: Callable[P, T | Awaitable[T]], /, *args: P.args, **kwargs: P.kwargs) -> Future[T]:
762+
def to_thread(
763+
cls,
764+
func: Callable[P, T | CoroutineType[Any, Any, T]],
765+
/,
766+
*args: P.args,
767+
**kwargs: P.kwargs,
768+
) -> Future[T]:
740769
"""A classmethod to call func in a separate thread see also [to_thread_by_name][async_kernel.Caller.to_thread_by_name]."""
741770
return cls.to_thread_by_name(None, func, *args, **kwargs)
742771

743772
@classmethod
744773
def to_thread_by_name(
745-
cls, name: str | None, func: Callable[P, T | Awaitable[T]], /, *args: P.args, **kwargs: P.kwargs
774+
cls,
775+
name: str | None,
776+
func: Callable[P, T | CoroutineType[Any, Any, T]],
777+
/,
778+
*args: P.args,
779+
**kwargs: P.kwargs,
746780
) -> Future[T]:
747781
"""
748782
A classmethod to call func in the thread specified by name.

tests/test_caller.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import threading
55
import time
66
import weakref
7-
from contextvars import ContextVar
87
from random import random
98
from typing import Literal, cast
109

@@ -253,6 +252,16 @@ async def runner():
253252
await anyio.to_thread.run_sync(_in_thread)
254253
assert caller not in Caller.all_callers()
255254

255+
async def test_direct_async(self, caller: Caller):
256+
event = AsyncEvent()
257+
258+
async def set_event():
259+
event.set()
260+
261+
caller.call_direct(set_event)
262+
with anyio.fail_after(1):
263+
await event.wait()
264+
256265
async def test_cancels_on_exit(self):
257266
is_cancelled = False
258267
async with Caller(create=True) as caller:
@@ -457,27 +466,6 @@ async def test_call_early(self, anyio_backend) -> None:
457466
async with caller:
458467
await fut
459468

460-
async def test_call_coroutine(self, caller: Caller):
461-
# Test we can await a coroutine, note that it is not permitted with the type hints,
462-
# but should probably be discouraged anyway since there is no way of knowing
463-
# (with type hints) if a coroutine has already been awaited.
464-
my_contextvar = ContextVar[int]("my_contextvar")
465-
my_contextvar.set(1)
466-
467-
async def my_func():
468-
await anyio.sleep(0)
469-
assert my_contextvar.get() == 1
470-
return True
471-
472-
# Discouraged
473-
fut = Caller.to_thread(my_func()) # pyright: ignore[reportCallIssue, reportArgumentType]
474-
val = await fut
475-
assert val is True
476-
# This the preferred way of calling.
477-
fut = Caller.to_thread(my_func)
478-
val = await fut
479-
assert val is True
480-
481469
async def test_current_future(self, anyio_backend):
482470
async with Caller(create=True) as caller:
483471
fut = caller.call_soon(Caller.current_future)

0 commit comments

Comments
 (0)