Skip to content

Commit c812974

Browse files
authored
Add Caller.get_runner. (#126)
1 parent 85f6f95 commit c812974

File tree

2 files changed

+46
-2
lines changed

2 files changed

+46
-2
lines changed

src/async_kernel/caller.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import contextlib
45
import contextvars
56
import functools
@@ -11,6 +12,7 @@
1112
import weakref
1213
from collections import deque
1314
from collections.abc import AsyncGenerator, Awaitable, Callable, Generator
15+
from types import CoroutineType
1416
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self, cast, overload
1517

1618
import anyio
@@ -556,6 +558,22 @@ def stopped(self) -> bool:
556558
"Returns `True` if the caller is stopped."
557559
return self._stopped
558560

561+
def get_runner(self, *, started: Callable[[], None] | None = None) -> CoroutineType[Any, Any, None]:
562+
"""A convenience method to run the caller.
563+
564+
!!! tip
565+
566+
See [async_kernel.caller.Caller.get_instance][] for a usage example.
567+
"""
568+
if self.running or self.stopped:
569+
raise RuntimeError
570+
571+
async def runner() -> None:
572+
async with self:
573+
await anyio.sleep_forever()
574+
575+
return runner()
576+
559577
def stop(self, *, force=False) -> None:
560578
"""
561579
Stop the caller, cancelling all pending tasks and close the thread.
@@ -751,8 +769,17 @@ def get_instance(cls, name: str | None = "MainThread", *, create: bool = False)
751769
for thread in cls._instances:
752770
if thread.name == name:
753771
return cls._instances[thread]
754-
if create:
755-
return cls.start_new(name=name)
772+
if name == "MainThread":
773+
if threading.current_thread() is threading.main_thread():
774+
if (backend := sniffio.current_async_library()) == Backend.asyncio:
775+
inst = cls(create=True)
776+
inst._task = asyncio.create_task(inst.get_runner()) # pyright: ignore[reportAttributeAccessIssue]
777+
return inst
778+
msg = f"Starting a caller for the MainThread is not supported for {backend=}"
779+
raise RuntimeError(msg)
780+
else:
781+
if create is True:
782+
return cls.start_new(name=name)
756783
msg = f"A Caller was not found for {name=}."
757784
raise RuntimeError(msg)
758785

tests/test_caller.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,23 @@ async def test_get_instance_no_instance(self, anyio_backend):
349349
with pytest.raises(RuntimeError):
350350
Caller.get_instance(None, create=False)
351351

352+
async def test_get_instance_get_runner(self, anyio_backend):
353+
if anyio_backend == Backend.trio:
354+
with pytest.raises(RuntimeError):
355+
Caller.get_instance()
356+
return
357+
caller = Caller.get_instance()
358+
try:
359+
await caller.call_soon(anyio.sleep, 0.01)
360+
finally:
361+
caller.stop()
362+
363+
async def test_get_runner_error(self):
364+
caller = Caller(create=True)
365+
caller.stop()
366+
with pytest.raises(RuntimeError):
367+
caller.get_runner() # pyright: ignore[reportUnusedCoroutine]
368+
352369
@pytest.mark.parametrize("mode", ["restricted", "surge"])
353370
async def test_as_completed(self, anyio_backend, mode: Literal["restricted", "surge"], mocker):
354371
mocker.patch.object(Caller, "MAX_IDLE_POOL_INSTANCES", new=2)

0 commit comments

Comments
 (0)