Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/async_kernel/interface/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def flusher(string: str, name=name) -> None:
echo.write(string) # pragma: no cover
echo.flush() # pragma: no cover

wrapper = OutStream(flusher=flusher)
wrapper = OutStream(flusher=flusher, mode=name)
setattr(sys, name, wrapper)

return restore
Expand Down
15 changes: 11 additions & 4 deletions src/async_kernel/iostream.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from aiologic import Lock
from typing_extensions import override

import async_kernel
from async_kernel.common import Fixed

if TYPE_CHECKING:
Expand All @@ -17,16 +18,18 @@ class OutStream(TextIOBase):

_write_lock = Fixed(Lock)

def __init__(self, flusher: Callable[[str], None]) -> None:
def __init__(self, flusher: Callable[[str], None], *, mode: str) -> None:
"""
Args:
flusher: A callback responsible for sending the output.
ctx: The context variable to redirect output.

[reference for IOBase](https://docs.python.org/3/library/io.html#io.IOBase)
"""
super().__init__()
self._flusher = flusher
self._out = ""
self._ctx = {"stdout": async_kernel.utils._stdout_context, "stderr": async_kernel.utils._stdout_context}[mode] # pyright: ignore[reportPrivateUsage]

@override
def isatty(self) -> Literal[True]:
Expand All @@ -52,9 +55,13 @@ def flush(self) -> None:

@override
def write(self, string: str) -> int:
with self._write_lock:
self._out = string
self.flush()
if out := self._ctx.get():
out.write(string)
out.flush()
else:
with self._write_lock:
self._out = string
self.flush()
return len(string)

@override
Expand Down
40 changes: 40 additions & 0 deletions src/async_kernel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sys
import threading
import traceback
from collections.abc import Generator
from contextlib import contextmanager
from contextvars import ContextVar
from typing import TYPE_CHECKING, Any
Expand All @@ -16,6 +17,7 @@

if TYPE_CHECKING:
from collections.abc import Generator
from contextlib import _SupportsRedirect, _SupportsRedirectT # pyright: ignore[reportPrivateUsage]

from async_kernel.kernel import Kernel
from async_kernel.typing import Content, Job, Message
Expand All @@ -33,14 +35,19 @@
"get_tags",
"get_timeout",
"mark_thread_pydev_do_not_trace",
"redirect_stderr",
"redirect_stdout",
"setattr_nested",
"subshell_context",
]

LAUNCHED_BY_DEBUGPY = "debugpy" in sys.modules


_job_var: ContextVar[Job[Any]] = ContextVar("async-kernel job")
_cell_id_var: ContextVar[str | None] = ContextVar("async-kernel cell_id", default=None)
_stdout_context: ContextVar[_SupportsRedirect | None] = ContextVar("async-kernel std_out", default=None)
_stderr_context: ContextVar[_SupportsRedirect | None] = ContextVar("async-kernel std_err", default=None)


def mark_thread_pydev_do_not_trace(thread: threading.Thread | None = None, *, remove=False) -> None:
Expand Down Expand Up @@ -195,3 +202,36 @@ def error_to_content(error: BaseException, /) -> Content:
"evalue": str(error),
"traceback": traceback.format_exception(error),
}


@contextmanager
def redirect_stdout(stream: _SupportsRedirectT, /) -> Generator[_SupportsRedirectT, Any, None]:
"""
Re-direct [sys.stdout][] generated in the current context.

See also:
- [contextlib.redirect_stdout][]
"""
assert get_kernel().event_started
token = _stdout_context.set(stream)
try:
yield stream
finally:
_stdout_context.reset(token)


@contextmanager
def redirect_stderr(stream: _SupportsRedirectT, /) -> Generator[_SupportsRedirectT, Any, None]:
"""
Re-direct [sys.stderr][] generated in the current context.

See also:
- [contextlib.redirect_stderr][]
"""

assert get_kernel().event_started
token = _stdout_context.set(stream)
try:
yield stream
finally:
_stdout_context.reset(token)
2 changes: 1 addition & 1 deletion tests/test_iostream.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def flusher(string: str):
nonlocal output
output += string

stream = OutStream(flusher)
stream = OutStream(flusher, mode="stdout")

assert stream.errors is None
with pytest.raises(io.UnsupportedOperation):
Expand Down
18 changes: 18 additions & 0 deletions tests/test_kernel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import io
import pathlib
import sys
import threading
from typing import TYPE_CHECKING, Any, Literal

Expand Down Expand Up @@ -667,3 +669,19 @@ class MyShell(AsyncInteractiveShell): # pyright: ignore[reportUnusedClass]

class MySubshell(AsyncInteractiveSubshell): # pyright: ignore[reportUnusedClass]
pass


async def test_redirect_stdout(kernel: Kernel):

with async_kernel.utils.redirect_stdout(io.StringIO()) as f:
print("hello")
print("world")
assert f.getvalue() == "hello\nworld\n"


async def test_redirect_stderr(kernel: Kernel):

with async_kernel.utils.redirect_stderr(io.StringIO()) as f:
sys.stderr.write("hello")
sys.stderr.flush()
assert f.getvalue() == "hello"
Loading