Skip to content

Commit

Permalink
jax.test_util: add capture_stdout context manager
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Sep 12, 2022
1 parent e5725f1 commit 0063661
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 76 deletions.
11 changes: 10 additions & 1 deletion jax/_src/test_util.py
Expand Up @@ -14,12 +14,13 @@

from contextlib import contextmanager
import inspect
import io
import functools
from functools import partial
import re
import os
import textwrap
from typing import Dict, List, Generator, Sequence, Tuple, Union
from typing import Callable, Dict, List, Generator, Sequence, Tuple, Union
import unittest
import warnings
import zlib
Expand Down Expand Up @@ -153,6 +154,14 @@ def check_eq(xs, ys, err_msg=''):
tree_all(tree_map(assert_close, xs, ys))


@contextmanager
def capture_stdout() -> Generator[Callable[[], str], None, None]:
with unittest.mock.patch('sys.stdout', new_callable=io.StringIO) as fp:
def _read() -> str:
return fp.getvalue()
yield _read


@contextmanager
def count_device_put():
device_put = dispatch.device_put
Expand Down

0 comments on commit 0063661

Please sign in to comment.