@@ -371,9 +371,9 @@ class Caller:
371371 _backend : Backend
372372 _queue_map : weakref .WeakKeyDictionary [Callable [..., Awaitable [Any ]], MemoryObjectSendStream [tuple ]]
373373 _taskgroup : TaskGroup | None = None
374- _callers : deque [tuple [contextvars .Context , Future ] | Callable [[], Any ]]
374+ _jobs : deque [tuple [contextvars .Context , Future ] | Callable [[], Any ]]
375375 _thread : threading .Thread
376- _callers_added : threading .Event
376+ _job_added : threading .Event
377377 _stopped_event : threading .Event
378378 _stopped = False
379379 _protected = False
@@ -428,8 +428,8 @@ def __new__(
428428 inst ._backend = Backend (sniffio .current_async_library ())
429429 inst ._thread = thread
430430 inst .log = log or logging .LoggerAdapter (logging .getLogger ())
431- inst ._callers = deque ()
432- inst ._callers_added = threading .Event ()
431+ inst ._jobs = deque ()
432+ inst ._job_added = threading .Event ()
433433 inst ._protected = protected
434434 inst ._queue_map = weakref .WeakKeyDictionary ()
435435 cls ._instances [thread ] = inst
@@ -461,24 +461,27 @@ async def _server_loop(self, tg: TaskGroup, task_status: TaskStatus[None]) -> No
461461 self .iopub_sockets [self .thread ] = socket
462462 task_status .started ()
463463 while not self ._stopped :
464- if not self ._callers :
465- self ._callers_added .clear ()
466- await wait_thread_event (self ._callers_added )
467- while self ._callers :
464+ if not self ._jobs :
465+ self ._job_added .clear ()
466+ await wait_thread_event (self ._job_added )
467+ while self ._jobs :
468468 if self ._stopped :
469469 return
470- job = self ._callers .popleft ()
470+ job = self ._jobs .popleft ()
471471 if isinstance (job , Callable ):
472472 try :
473- job ()
473+ if inspect .iscoroutinefunction (job ):
474+ await job ()
475+ else :
476+ job ()
474477 except Exception as e :
475478 self .log .exception ("Simple call failed" , exc_info = e )
476479 else :
477480 context , fut = job
478481 context .run (tg .start_soon , self ._wrap_call , fut )
479482 finally :
480483 self ._running = False
481- for job in self ._callers :
484+ for job in self ._jobs :
482485 if isinstance (job , tuple ):
483486 job [1 ].set_exception (FutureCancelledError ())
484487 socket .close ()
@@ -493,8 +496,8 @@ def _schedule_wrapped_call(self, func: Callable, /, args: tuple, kwargs: dict, *
493496 if threading .current_thread () is self .thread and (tg := self ._taskgroup ):
494497 tg .start_soon (self ._wrap_call , fut )
495498 else :
496- self ._callers .append ((contextvars .copy_context (), fut ))
497- self ._callers_added .set ()
499+ self ._jobs .append ((contextvars .copy_context (), fut ))
500+ self ._job_added .set ()
498501 return fut
499502
500503 async def _wrap_call (self , fut : Future ) -> None :
@@ -510,9 +513,12 @@ async def _wrap_call(self, fut: Future) -> None:
510513 try :
511514 if (delay := md .get ("delay" )) and ((delay := delay - time .monotonic () + md ["start_time" ]) > 0 ):
512515 await anyio .sleep (delay )
513- result = func (* md ["args" ], ** md ["kwargs" ]) if callable (func ) else func
514- if inspect .isawaitable (result ) and result is not fut :
515- result = await result
516+ # Evaluate
517+ if inspect .iscoroutinefunction (func ):
518+ result = await func (* md ["args" ], ** md ["kwargs" ])
519+ else :
520+ result = func (* md ["args" ], ** md ["kwargs" ])
521+ # Cancellation
516522 if fut .cancelled () and not scope .cancel_called :
517523 scope .cancel ()
518524 fut .set_result (result )
@@ -566,15 +572,20 @@ def stop(self, *, force=False) -> None:
566572 for sender in self ._queue_map .values ():
567573 sender .close ()
568574 self ._queue_map .clear ()
569- self ._callers_added .set ()
575+ self ._job_added .set ()
570576 self ._instances .pop (self .thread , None )
571577 if self in self ._to_thread_pool :
572578 self ._to_thread_pool .remove (self )
573579 if self .thread is not threading .current_thread ():
574580 self ._stopped_event .wait ()
575581
576582 def call_later (
577- self , delay : float , func : Callable [P , T | Awaitable [T ]], / , * args : P .args , ** kwargs : P .kwargs
583+ self ,
584+ delay : float ,
585+ func : Callable [P , T | CoroutineType [Any , Any , T ]],
586+ / ,
587+ * args : P .args ,
588+ ** kwargs : P .kwargs ,
578589 ) -> Future [T ]:
579590 """
580591 Schedule func to be called in caller's event loop copying the current context.
@@ -587,7 +598,13 @@ def call_later(
587598 """
588599 return self ._schedule_wrapped_call (func , args , kwargs , delay = delay , start_time = time .monotonic ())
589600
590- def call_soon (self , func : Callable [P , T | Awaitable [T ]], / , * args : P .args , ** kwargs : P .kwargs ) -> Future [T ]:
601+ def call_soon (
602+ self ,
603+ func : Callable [P , T | CoroutineType [Any , Any , T ]],
604+ / ,
605+ * args : P .args ,
606+ ** kwargs : P .kwargs ,
607+ ) -> Future [T ]:
591608 """
592609 Schedule func to be called in caller's event loop copying the current context.
593610
@@ -598,12 +615,18 @@ def call_soon(self, func: Callable[P, T | Awaitable[T]], /, *args: P.args, **kwa
598615 """
599616 return self ._schedule_wrapped_call (func , args , kwargs )
600617
601- def call_direct (self , func : Callable [P , Any ], / , * args : P .args , ** kwargs : P .kwargs ) -> None :
618+ def call_direct (
619+ self ,
620+ func : Callable [P , T | CoroutineType [Any , Any , T ]],
621+ / ,
622+ * args : P .args ,
623+ ** kwargs : P .kwargs ,
624+ ) -> None :
602625 """
603626 Schedule `func` to be called in caller's event loop directly.
604627
605628 This method is provided to facilitate lightweight *thread-safe* function calls that
606- need to be done from within the callers event loop.
629+ need to be performed from within the callers event loop/taskgroup .
607630
608631 Args:
609632 func: The function (awaitables permitted, though discouraged).
@@ -612,11 +635,11 @@ def call_direct(self, func: Callable[P, Any], /, *args: P.args, **kwargs: P.kwar
612635
613636 ??? warning
614637
615- - Use this method for lightweight calls only.
616- - Corroutines will **not** be awaited.
638+ ** Use this method for lightweight calls only!**
639+
617640 """
618- self ._callers .append (functools .partial (func , * args , ** kwargs ))
619- self ._callers_added .set ()
641+ self ._jobs .append (functools .partial (func , * args , ** kwargs ))
642+ self ._job_added .set ()
620643
621644 def queue_exists (self , func : Callable ) -> bool :
622645 "Returns True if an execution queue exists for `func`."
@@ -736,13 +759,24 @@ def get_instance(cls, name: str | None = "MainThread", *, create: bool = False)
736759 raise RuntimeError (msg )
737760
738761 @classmethod
739- def to_thread (cls , func : Callable [P , T | Awaitable [T ]], / , * args : P .args , ** kwargs : P .kwargs ) -> Future [T ]:
762+ def to_thread (
763+ cls ,
764+ func : Callable [P , T | CoroutineType [Any , Any , T ]],
765+ / ,
766+ * args : P .args ,
767+ ** kwargs : P .kwargs ,
768+ ) -> Future [T ]:
740769 """A classmethod to call func in a separate thread see also [to_thread_by_name][async_kernel.Caller.to_thread_by_name]."""
741770 return cls .to_thread_by_name (None , func , * args , ** kwargs )
742771
743772 @classmethod
744773 def to_thread_by_name (
745- cls , name : str | None , func : Callable [P , T | Awaitable [T ]], / , * args : P .args , ** kwargs : P .kwargs
774+ cls ,
775+ name : str | None ,
776+ func : Callable [P , T | CoroutineType [Any , Any , T ]],
777+ / ,
778+ * args : P .args ,
779+ ** kwargs : P .kwargs ,
746780 ) -> Future [T ]:
747781 """
748782 A classmethod to call func in the thread specified by name.
0 commit comments