Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 21 additions & 25 deletions src/async_kernel/caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from collections.abc import AsyncGenerator, Awaitable, Callable, Generator
from contextlib import asynccontextmanager
from types import CoroutineType
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Never, Self, cast, overload
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Never, Self, Unpack, cast, overload

import anyio
import sniffio
Expand Down Expand Up @@ -801,25 +801,27 @@ def stop_all(cls, *, _stop_protected: bool = False) -> None:
caller.stop(force=_stop_protected)

@classmethod
def get_instance(cls, name: str | None = "MainThread", *, create: bool = False) -> Self:
def get_instance(cls, *, create: bool | NoValue = NoValue, **kwargs: Unpack[CallerStartNewOptions]) -> Self: # pyright: ignore[reportInvalidTypeForm]
"""
A [classmethod][] that gets the caller associated to the thread using the threads name.


The default will provide the caller from the MainThread. If an instance doesn't exist
for the main thread an instance will be created and started when the backend provided
there is a running event loop.
When called without a name `MainThread` will be used as the `name`.

Args:
name: The name of the thread where the caller is base. When name is `None`, a new worker thread is created.
create: Create a new instance if one with the corresponding name does not already exist.
When not provided it defaults to `True` when `name` is `MainThread` otherwise `False`.
kwargs:
Options to use to identify or create a new instance if an instance does not already exist.
"""
if "name" not in kwargs:
kwargs["name"] = "MainThread"
for caller in cls._instances.values():
if caller.name == name:
if caller.name == kwargs["name"]:
return caller
if create is True or name == "MainThread":
return cls.start_new(name=name)
msg = f"A Caller was not found for {name=}."
if create is True or (create is NoValue and kwargs["name"] == "MainThread"):
return cls.start_new(**kwargs)
msg = f"A Caller was not found for {kwargs['name']=}."
raise RuntimeError(msg)

@classmethod
Expand All @@ -831,12 +833,12 @@ def to_thread(
**kwargs: P.kwargs,
) -> Future[T]:
"""A [classmethod][] to call func in a separate thread see also [to_thread_advanced][async_kernel.Caller.to_thread_advanced]."""
return cls.to_thread_advanced(None, func, *args, **kwargs)
return cls.to_thread_advanced({"name": None}, func, *args, **kwargs)

@classmethod
def to_thread_advanced(
cls,
options: CallerStartNewOptions | None,
options: CallerStartNewOptions,
func: Callable[P, T | CoroutineType[Any, Any, T]],
/,
*args: P.args,
Expand All @@ -861,20 +863,14 @@ def to_thread_advanced(
Returns:
A future that can be awaited for the result of func.
"""
name = options["name"] if options else None
try:
caller = (
cls._to_thread_pool.popleft()
if not options and cls._to_thread_pool
else cls.get_instance(name=name, create=name is None)
)
except RuntimeError:
if not options:
raise
caller = cls.start_new(**options)

caller = None
if not options.get("name"):
with contextlib.suppress(IndexError):
caller = cls._to_thread_pool.popleft()
if caller is None:
caller = cls.get_instance(create=True, **options)
fut = caller.call_soon(func, *args, **kwargs)
if not name:
if not options.get("name"):
cls._pool_instances.add(caller)
cls._busy_worker_threads += 1

Expand Down
2 changes: 1 addition & 1 deletion src/async_kernel/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ class ExecuteContent(TypedDict):
class CallerStartNewOptions(TypedDict):
"Options for [async_kernel.caller.Caller.start_new][]."

name: str
name: NotRequired[str | None]
log: NotRequired[logging.LoggerAdapter]
backend: NotRequired[Backend]
protected: NotRequired[bool]
Expand Down
5 changes: 4 additions & 1 deletion src/async_kernel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ async def wait_thread_event(thread_event: threading.Event, /):

def _wait_thread_event(thread_event: threading.Event, event: anyio.Event, token):
thread_event.wait()
from_thread.run_sync(event.set, token=token)
try:
from_thread.run_sync(event.set, token=token)
except anyio.RunFinishedError:
pass

try:
event = anyio.Event()
Expand Down
6 changes: 3 additions & 3 deletions tests/test_caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ async def _run():
the_thread.start()
ready.wait()
assert isinstance(finished_event, AsyncEvent)
caller = Caller.get_instance(the_thread.name)
caller = Caller.get_instance(name=the_thread.name)
if check_result == "result":
expr = "10"
context = contextlib.nullcontext()
Expand Down Expand Up @@ -384,7 +384,7 @@ async def waiter():

async def test_get_instance_no_instance(self, anyio_backend):
with pytest.raises(RuntimeError):
Caller.get_instance(None, create=False)
Caller.get_instance(name=None, create=False)

async def test_get_instance_get_runner(self, anyio_backend):
if anyio_backend == Backend.trio:
Expand Down Expand Up @@ -562,7 +562,7 @@ async def close_tsc():
await anyio.sleep_forever()

fut = Caller.to_thread(close_tsc)
caller = Caller.get_instance(fut.thread.name)
caller = Caller.get_instance(name=fut.thread.name)
ready.wait()
never_called_future = caller.call_later(10, str)
proceed.set()
Expand Down