Skip to content

Commit

Permalink
[testing] auto-replay captured streams (#13803)
Browse files Browse the repository at this point in the history
  • Loading branch information
stas00 committed Sep 30, 2021
1 parent 5f25855 commit e1d1c7c
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 15 deletions.
2 changes: 2 additions & 0 deletions docs/source/testing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1080,6 +1080,8 @@ If you need to capture both streams at once, use the parent :obj:`CaptureStd` cl
function_that_writes_to_stdout_and_stderr()
print(cs.err, cs.out)
Also, to aid debugging test issues, by default these context managers automatically replay the captured streams on exit
from the context.


Capturing logger stream
Expand Down
56 changes: 41 additions & 15 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,34 +610,54 @@ class CaptureStd:
"""
Context manager to capture:
- stdout, clean it up and make it available via obj.out
- stderr, and make it available via obj.err
- stdout: replay it, clean it up and make it available via ``obj.out``
- stderr: replay it and make it available via ``obj.err``
init arguments:
- out - capture stdout: True/False, default True
- err - capture stdout: True/False, default True
- out - capture stdout:`` True``/``False``, default ``True``
- err - capture stdout: ``True``/``False``, default ``True``
- replay - whether to replay or not: ``True``/``False``, default ``True``. By default each
captured stream gets replayed back on context's exit, so that one can see what the test was doing. If this is a
not wanted behavior and the captured data shouldn't be replayed, pass ``replay=False`` to disable this feature.
Examples::
# to capture stdout only with auto-replay
with CaptureStdout() as cs:
print("Secret message")
print(f"captured: {cs.out}")
assert "message" in cs.out
# to capture stderr only with auto-replay
import sys
with CaptureStderr() as cs:
print("Warning: ", file=sys.stderr)
print(f"captured: {cs.err}")
assert "Warning" in cs.err
# to capture just one of the streams, but not the other
# to capture both streams with auto-replay
with CaptureStd() as cs:
print("Secret message")
print("Warning: ", file=sys.stderr)
assert "message" in cs.out
assert "Warning" in cs.err
# to capture just one of the streams, and not the other, with auto-replay
with CaptureStd(err=False) as cs:
print("Secret message")
print(f"captured: {cs.out}")
assert "message" in cs.out
# but best use the stream-specific subclasses
# to capture without auto-replay
with CaptureStd(replay=False) as cs:
print("Secret message")
assert "message" in cs.out
"""

def __init__(self, out=True, err=True):
def __init__(self, out=True, err=True, replay=True):

self.replay = replay

if out:
self.out_buf = StringIO()
self.out = "error: CaptureStd context is unfinished yet, called too early"
Expand Down Expand Up @@ -666,11 +686,17 @@ def __enter__(self):
def __exit__(self, *exc):
if self.out_buf:
sys.stdout = self.out_old
self.out = apply_print_resets(self.out_buf.getvalue())
captured = self.out_buf.getvalue()
if self.replay:
sys.stdout.write(captured)
self.out = apply_print_resets(captured)

if self.err_buf:
sys.stderr = self.err_old
self.err = self.err_buf.getvalue()
captured = self.err_buf.getvalue()
if self.replay:
sys.stderr.write(captured)
self.err = captured

def __repr__(self):
msg = ""
Expand All @@ -690,15 +716,15 @@ def __repr__(self):
class CaptureStdout(CaptureStd):
"""Same as CaptureStd but captures only stdout"""

def __init__(self):
super().__init__(err=False)
def __init__(self, replay=True):
super().__init__(err=False, replay=replay)


class CaptureStderr(CaptureStd):
"""Same as CaptureStd but captures only stderr"""

def __init__(self):
super().__init__(out=False)
def __init__(self, replay=True):
super().__init__(out=False, replay=replay)


class CaptureLogger:
Expand Down

0 comments on commit e1d1c7c

Please sign in to comment.