diff --git a/src/async_kernel/interface/base.py b/src/async_kernel/interface/base.py index 0df09c59..e58672df 100644 --- a/src/async_kernel/interface/base.py +++ b/src/async_kernel/interface/base.py @@ -205,7 +205,7 @@ def flusher(string: str, name=name) -> None: echo.write(string) # pragma: no cover echo.flush() # pragma: no cover - wrapper = OutStream(flusher=flusher) + wrapper = OutStream(flusher=flusher, mode=name) setattr(sys, name, wrapper) return restore diff --git a/src/async_kernel/iostream.py b/src/async_kernel/iostream.py index f3d523fd..e42d27d1 100644 --- a/src/async_kernel/iostream.py +++ b/src/async_kernel/iostream.py @@ -6,6 +6,7 @@ from aiologic import Lock from typing_extensions import override +import async_kernel from async_kernel.common import Fixed if TYPE_CHECKING: @@ -17,16 +18,18 @@ class OutStream(TextIOBase): _write_lock = Fixed(Lock) - def __init__(self, flusher: Callable[[str], None]) -> None: + def __init__(self, flusher: Callable[[str], None], *, mode: str) -> None: """ Args: flusher: A callback responsible for sending the output. + ctx: The context variable to redirect output. [reference for IOBase](https://docs.python.org/3/library/io.html#io.IOBase) """ super().__init__() self._flusher = flusher self._out = "" + self._ctx = {"stdout": async_kernel.utils._stdout_context, "stderr": async_kernel.utils._stdout_context}[mode] # pyright: ignore[reportPrivateUsage] @override def isatty(self) -> Literal[True]: @@ -52,9 +55,13 @@ def flush(self) -> None: @override def write(self, string: str) -> int: - with self._write_lock: - self._out = string - self.flush() + if out := self._ctx.get(): + out.write(string) + out.flush() + else: + with self._write_lock: + self._out = string + self.flush() return len(string) @override diff --git a/src/async_kernel/utils.py b/src/async_kernel/utils.py index 06c7b3b7..6341c950 100644 --- a/src/async_kernel/utils.py +++ b/src/async_kernel/utils.py @@ -4,6 +4,7 @@ import sys import threading import traceback +from collections.abc import Generator from contextlib import contextmanager from contextvars import ContextVar from typing import TYPE_CHECKING, Any @@ -16,6 +17,7 @@ if TYPE_CHECKING: from collections.abc import Generator + from contextlib import _SupportsRedirect, _SupportsRedirectT # pyright: ignore[reportPrivateUsage] from async_kernel.kernel import Kernel from async_kernel.typing import Content, Job, Message @@ -33,14 +35,19 @@ "get_tags", "get_timeout", "mark_thread_pydev_do_not_trace", + "redirect_stderr", + "redirect_stdout", "setattr_nested", "subshell_context", ] LAUNCHED_BY_DEBUGPY = "debugpy" in sys.modules + _job_var: ContextVar[Job[Any]] = ContextVar("async-kernel job") _cell_id_var: ContextVar[str | None] = ContextVar("async-kernel cell_id", default=None) +_stdout_context: ContextVar[_SupportsRedirect | None] = ContextVar("async-kernel std_out", default=None) +_stderr_context: ContextVar[_SupportsRedirect | None] = ContextVar("async-kernel std_err", default=None) def mark_thread_pydev_do_not_trace(thread: threading.Thread | None = None, *, remove=False) -> None: @@ -195,3 +202,36 @@ def error_to_content(error: BaseException, /) -> Content: "evalue": str(error), "traceback": traceback.format_exception(error), } + + +@contextmanager +def redirect_stdout(stream: _SupportsRedirectT, /) -> Generator[_SupportsRedirectT, Any, None]: + """ + Re-direct [sys.stdout][] generated in the current context. + + See also: + - [contextlib.redirect_stdout][] + """ + assert get_kernel().event_started + token = _stdout_context.set(stream) + try: + yield stream + finally: + _stdout_context.reset(token) + + +@contextmanager +def redirect_stderr(stream: _SupportsRedirectT, /) -> Generator[_SupportsRedirectT, Any, None]: + """ + Re-direct [sys.stderr][] generated in the current context. + + See also: + - [contextlib.redirect_stderr][] + """ + + assert get_kernel().event_started + token = _stdout_context.set(stream) + try: + yield stream + finally: + _stdout_context.reset(token) diff --git a/tests/test_iostream.py b/tests/test_iostream.py index d4a6ca40..c30015e8 100644 --- a/tests/test_iostream.py +++ b/tests/test_iostream.py @@ -15,7 +15,7 @@ def flusher(string: str): nonlocal output output += string - stream = OutStream(flusher) + stream = OutStream(flusher, mode="stdout") assert stream.errors is None with pytest.raises(io.UnsupportedOperation): diff --git a/tests/test_kernel.py b/tests/test_kernel.py index 10293253..df0cdca5 100644 --- a/tests/test_kernel.py +++ b/tests/test_kernel.py @@ -1,6 +1,8 @@ from __future__ import annotations +import io import pathlib +import sys import threading from typing import TYPE_CHECKING, Any, Literal @@ -667,3 +669,19 @@ class MyShell(AsyncInteractiveShell): # pyright: ignore[reportUnusedClass] class MySubshell(AsyncInteractiveSubshell): # pyright: ignore[reportUnusedClass] pass + + +async def test_redirect_stdout(kernel: Kernel): + + with async_kernel.utils.redirect_stdout(io.StringIO()) as f: + print("hello") + print("world") + assert f.getvalue() == "hello\nworld\n" + + +async def test_redirect_stderr(kernel: Kernel): + + with async_kernel.utils.redirect_stderr(io.StringIO()) as f: + sys.stderr.write("hello") + sys.stderr.flush() + assert f.getvalue() == "hello"