Skip to content

Commit

Permalink
ContextManagers: Add retrieve_stdout()
Browse files Browse the repository at this point in the history
Fixes #544
  • Loading branch information
sils committed May 25, 2015
1 parent db047bf commit f64d732
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 9 deletions.
44 changes: 36 additions & 8 deletions coalib/misc/ContextManagers.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,42 @@
from contextlib import contextmanager
from contextlib import contextmanager, closing
import sys
import os
from io import StringIO

This comment has been minimized.

Copy link
@AbdealiLoKo

AbdealiLoKo May 26, 2015

Contributor

Two new lines needed here

@contextmanager
def replace_stdout(replacement):
"""
Replaces stdout with the replacement, yields back to the caller and then
reverts everything back.
"""
_stdout = sys.stdout
sys.stdout = replacement
try:
yield
finally:
sys.stdout = _stdout


@contextmanager
def suppress_stdout():
with open(os.devnull, "w") as devnull:
old_stdout = sys.stdout
sys.stdout = devnull
try:
yield
finally:
sys.stdout = old_stdout
"""
Suppresses everything going to stdout.
"""
with open(os.devnull, "w") as devnull, replace_stdout(devnull):
yield


@contextmanager
def retrieve_stdout():
"""
Yields a StringIO object from which one can read everything that was
printed to stdout. (It won't be printed to the real stdout!)
Example usage:
with retrieve_stdout() as stdout:
print("something") # Won't print to the console
what_was_printed = stdout.getvalue() # Save the value
"""
with closing(StringIO()) as sio, replace_stdout(sio):
yield sio
9 changes: 8 additions & 1 deletion coalib/tests/misc/ContextManagersTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys

sys.path.insert(0, ".")
from coalib.misc.ContextManagers import suppress_stdout
from coalib.misc.ContextManagers import suppress_stdout, retrieve_stdout


class SuppressStdoutTest(unittest.TestCase):
Expand All @@ -25,5 +25,12 @@ def no_print_func():
sys.stdout = old_stdout


class RetrieveStdoutTest(unittest.TestCase):
def test_retrieve_stdout(self):
with retrieve_stdout() as sio:
print("test")
self.assertEqual(sio.getvalue(), "test\n")


if __name__ == '__main__':
unittest.main(verbosity=2)

0 comments on commit f64d732

Please sign in to comment.