|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import asyncio |
3 | 4 | import contextlib |
4 | 5 | import contextvars |
5 | 6 | import functools |
|
11 | 12 | import weakref |
12 | 13 | from collections import deque |
13 | 14 | from collections.abc import AsyncGenerator, Awaitable, Callable, Generator |
| 15 | +from types import CoroutineType |
14 | 16 | from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self, cast, overload |
15 | 17 |
|
16 | 18 | import anyio |
@@ -556,6 +558,22 @@ def stopped(self) -> bool: |
556 | 558 | "Returns `True` if the caller is stopped." |
557 | 559 | return self._stopped |
558 | 560 |
|
| 561 | + def get_runner(self, *, started: Callable[[], None] | None = None) -> CoroutineType[Any, Any, None]: |
| 562 | + """A convenience method to run the caller. |
| 563 | +
|
| 564 | + !!! tip |
| 565 | +
|
| 566 | + See [async_kernel.caller.Caller.get_instance][] for a usage example. |
| 567 | + """ |
| 568 | + if self.running or self.stopped: |
| 569 | + raise RuntimeError |
| 570 | + |
| 571 | + async def runner() -> None: |
| 572 | + async with self: |
| 573 | + await anyio.sleep_forever() |
| 574 | + |
| 575 | + return runner() |
| 576 | + |
559 | 577 | def stop(self, *, force=False) -> None: |
560 | 578 | """ |
561 | 579 | Stop the caller, cancelling all pending tasks and close the thread. |
@@ -751,8 +769,17 @@ def get_instance(cls, name: str | None = "MainThread", *, create: bool = False) |
751 | 769 | for thread in cls._instances: |
752 | 770 | if thread.name == name: |
753 | 771 | return cls._instances[thread] |
754 | | - if create: |
755 | | - return cls.start_new(name=name) |
| 772 | + if name == "MainThread": |
| 773 | + if threading.current_thread() is threading.main_thread(): |
| 774 | + if (backend := sniffio.current_async_library()) == Backend.asyncio: |
| 775 | + inst = cls(create=True) |
| 776 | + inst._task = asyncio.create_task(inst.get_runner()) # pyright: ignore[reportAttributeAccessIssue] |
| 777 | + return inst |
| 778 | + msg = f"Starting a caller for the MainThread is not supported for {backend=}" |
| 779 | + raise RuntimeError(msg) |
| 780 | + else: |
| 781 | + if create is True: |
| 782 | + return cls.start_new(name=name) |
756 | 783 | msg = f"A Caller was not found for {name=}." |
757 | 784 | raise RuntimeError(msg) |
758 | 785 |
|
|
0 commit comments