Skip to content

Commit

Permalink
test_util: add decorator to set config values in test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Aug 5, 2021
1 parent df69062 commit 6114e6a
Show file tree
Hide file tree
Showing 11 changed files with 37 additions and 145 deletions.
22 changes: 21 additions & 1 deletion jax/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from contextlib import contextmanager
import inspect
import functools
import re
import os
Expand Down Expand Up @@ -870,8 +871,18 @@ def getTestCaseNames(self, testCaseClass):
return names


def with_config(**kwds):
  """Test case decorator that overrides config values for a JaxTestCase subclass.

  The given config options are merged over ``JaxTestCase._default_config`` and
  stored on the decorated class; ``JaxTestCase.setUp`` applies them before each
  test method and ``tearDown`` restores the previous values afterwards.

  Args:
    **kwds: config option names mapped to the values to set for the tests.

  Returns:
    A class decorator.

  Raises:
    TypeError: if applied to anything other than a JaxTestCase subclass.
  """
  def decorator(cls):
    # Raise rather than `assert`: asserts are stripped under `python -O`, and
    # this misuse check should never be silently skipped.
    if not (inspect.isclass(cls) and issubclass(cls, JaxTestCase)):
      raise TypeError(
          "@with_config can only wrap JaxTestCase class definitions.")
    # Copy-and-merge so the subclass gets its own dict rather than mutating
    # the base class's defaults.
    cls._default_config = {**JaxTestCase._default_config, **kwds}
    return cls
  return decorator


class JaxTestCase(parameterized.TestCase):
"""Base class for JAX tests including numerical checks and boilerplate."""
_default_config = {'jax_enable_checks': True}

# TODO(mattjj): this obscures the error messages from failures, figure out how
# to re-enable it
Expand All @@ -880,12 +891,21 @@ class JaxTestCase(parameterized.TestCase):

def setUp(self):
  """Apply the class's config overrides and seed the per-test RNG."""
  super().setUp()
  # Remember each option's current value so tearDown can restore it, then
  # install the override.
  self._original_config = {}
  for name, override in self._default_config.items():
    self._original_config[name] = getattr(config, name)
    config.update(name, override)

  # Seed from adler32 of the test name because:
  # a) it is deterministic run to run, unlike hash() which is randomized.
  # b) it returns values in int32 range, which RandomState requires.
  seed = zlib.adler32(self._testMethodName.encode())
  self._rng = npr.RandomState(seed)

def tearDown(self):
  """Restore every config value that setUp overrode, then finish teardown."""
  for name, saved_value in self._original_config.items():
    config.update(name, saved_value)
  super().tearDown()

def rng(self):
  """Return the deterministic per-test numpy RandomState seeded in setUp."""
  return self._rng

Expand Down
10 changes: 1 addition & 9 deletions tests/lax_numpy_einsum_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,9 @@
config.parse_flags_with_absl()


@jtu.with_config(jax_numpy_rank_promotion="raise")
class EinsumTest(jtu.JaxTestCase):

def setUp(self):
super().setUp()
self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
config.update("jax_numpy_rank_promotion", "raise")

def tearDown(self):
config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
super().tearDown()

def _check(self, s, *ops):
a = np.einsum(s, *ops)
b = jnp.einsum(s, *ops, precision=lax.Precision.HIGHEST)
Expand Down
20 changes: 2 additions & 18 deletions tests/lax_numpy_indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,18 +425,10 @@ def check_grads(f, args, order, atol=None, rtol=None, eps=None):
np.array([[1, 0], [1, 0]]))),
]),]

@jtu.with_config(jax_numpy_rank_promotion="raise")
class IndexingTest(jtu.JaxTestCase):
"""Tests for Numpy indexing translation rules."""

def setUp(self):
super().setUp()
self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
config.update("jax_numpy_rank_promotion", "raise")

def tearDown(self):
config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
super().tearDown()

@parameterized.named_parameters(jtu.cases_from_list({
"testcase_name": "{}_inshape={}_indexer={}".format(
name, jtu.format_shape_dtype_string( shape, dtype), indexer),
Expand Down Expand Up @@ -947,17 +939,9 @@ def dtypes(op):
else:
return default_dtypes

@jtu.with_config(jax_numpy_rank_promotion="raise")
class IndexedUpdateTest(jtu.JaxTestCase):

def setUp(self):
super().setUp()
self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
config.update("jax_numpy_rank_promotion", "raise")

def tearDown(self):
config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
super().tearDown()

@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
"testcase_name": "{}_inshape={}_indexer={}_update={}_sugared={}_op={}".format(
name, jtu.format_shape_dtype_string(shape, dtype), indexer,
Expand Down
50 changes: 5 additions & 45 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,18 +501,10 @@ def wrapper(*args, **kw):
return wrapper


@jtu.with_config(jax_numpy_rank_promotion="raise")
class LaxBackedNumpyTests(jtu.JaxTestCase):
"""Tests for LAX-backed Numpy implementation."""

def setUp(self):
super().setUp()
self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
config.update("jax_numpy_rank_promotion", "raise")

def tearDown(self):
config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
super().tearDown()

def _GetArgsMaker(self, rng, shapes, dtypes, np_arrays=True):
def f():
out = [rng(shape, dtype or jnp.float_)
Expand Down Expand Up @@ -5492,17 +5484,9 @@ def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None):
GradSpecialValuesTestSpec(jnp.sinc, [0.], 1),
]

@jtu.with_config(jax_numpy_rank_promotion="raise")
class NumpyGradTests(jtu.JaxTestCase):

def setUp(self):
super().setUp()
self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
config.update("jax_numpy_rank_promotion", "raise")

def tearDown(self):
config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
super().tearDown()

@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix(
Expand Down Expand Up @@ -5605,17 +5589,9 @@ def testGradLogaddexp2Complex(self, shapes, dtype):
tol = 3e-2
check_grads(jnp.logaddexp2, args, 1, ["fwd", "rev"], tol, tol)

@jtu.with_config(jax_numpy_rank_promotion="raise")
class NumpySignaturesTest(jtu.JaxTestCase):

def setUp(self):
super().setUp()
self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
config.update("jax_numpy_rank_promotion", "raise")

def tearDown(self):
config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
super().tearDown()

def testWrappedSignaturesMatch(self):
"""Test that jax.numpy function signatures match numpy."""
jnp_funcs = {name: getattr(jnp, name) for name in dir(jnp)}
Expand Down Expand Up @@ -5732,17 +5708,9 @@ def _dtypes_for_ufunc(name: str) -> Iterator[Tuple[str, ...]]:
yield arg_dtypes


@jtu.with_config(jax_numpy_rank_promotion="raise")
class NumpyUfuncTests(jtu.JaxTestCase):

def setUp(self):
super().setUp()
self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
config.update("jax_numpy_rank_promotion", "raise")

def tearDown(self):
config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
super().tearDown()

@parameterized.named_parameters(
{"testcase_name": f"_{name}_{','.join(arg_dtypes)}",
"name": name, "arg_dtypes": arg_dtypes}
Expand Down Expand Up @@ -5774,17 +5742,9 @@ def testUfuncInputTypes(self, name, arg_dtypes):
# that jnp returns float32. e.g. np.cos(np.uint8(0))
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=False, tol=1E-2)

@jtu.with_config(jax_numpy_rank_promotion="raise")
class NumpyDocTests(jtu.JaxTestCase):

def setUp(self):
super().setUp()
self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
config.update("jax_numpy_rank_promotion", "raise")

def tearDown(self):
config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
super().tearDown()

def test_lax_numpy_docstrings(self):
# Test that docstring wrapping & transformation didn't fail.

Expand Down
10 changes: 1 addition & 9 deletions tests/lax_numpy_vectorize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,9 @@
config.parse_flags_with_absl()


@jtu.with_config(jax_numpy_rank_promotion="raise")
class VectorizeTest(jtu.JaxTestCase):

def setUp(self):
super().setUp()
self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
config.update("jax_numpy_rank_promotion", "raise")

def tearDown(self):
config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
super().tearDown()

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_leftshape={}_rightshape={}".format(left_shape, right_shape),
"left_shape": left_shape, "right_shape": right_shape, "result_shape": result_shape}
Expand Down
10 changes: 1 addition & 9 deletions tests/lax_scipy_sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,9 @@ def rand_sym_pos_def(rng, shape, dtype):
return matrix @ matrix.T.conj()


@jtu.with_config(jax_numpy_rank_promotion="raise")
class LaxBackedScipyTests(jtu.JaxTestCase):

def setUp(self):
super().setUp()
self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
config.update("jax_numpy_rank_promotion", "raise")

def tearDown(self):
config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
super().tearDown()

def _fetch_preconditioner(self, preconditioner, A, rng=None):
"""
Returns one of various preconditioning matrices depending on the identifier
Expand Down
10 changes: 1 addition & 9 deletions tests/lax_scipy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,18 +141,10 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t
]


@jtu.with_config(jax_numpy_rank_promotion="raise")
class LaxBackedScipyTests(jtu.JaxTestCase):
"""Tests for LAX-backed Scipy implementation."""

def setUp(self):
super().setUp()
self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
config.update("jax_numpy_rank_promotion", "raise")

def tearDown(self):
config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
super().tearDown()

def _GetArgsMaker(self, rng, shapes, dtypes):
return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

Expand Down
10 changes: 1 addition & 9 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,9 @@
int_dtypes = jtu.dtypes.all_integer
uint_dtypes = jtu.dtypes.all_unsigned

@jtu.with_config(jax_numpy_rank_promotion="raise")
class LaxRandomTest(jtu.JaxTestCase):

def setUp(self):
self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
config.update("jax_numpy_rank_promotion", "raise")
super().setUp()

def tearDown(self):
super().tearDown()
config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)

def _CheckCollisions(self, samples, nbits):
fail_prob = 0.01 # conservative bound on statistical fail prob by Chebyshev
nitems = len(samples)
Expand Down
10 changes: 1 addition & 9 deletions tests/scipy_ndimage_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,9 @@ def _fixed_ref_map_coordinates(input, coordinates, order, mode, cval=0.0):
return result


@jtu.with_config(jax_numpy_rank_promotion="raise")
class NdimageTest(jtu.JaxTestCase):

def setUp(self):
super().setUp()
self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
config.update("jax_numpy_rank_promotion", "raise")

def tearDown(self):
config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
super().tearDown()

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_coordinates={}_order={}_mode={}_cval={}_impl={}_round={}".format(
jtu.format_shape_dtype_string(shape, dtype),
Expand Down
20 changes: 2 additions & 18 deletions tests/scipy_optimize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,9 @@ def zakharovFromIndices(x, ii):
return answer


@jtu.with_config(jax_numpy_rank_promotion="raise")
class TestBFGS(jtu.JaxTestCase):

def setUp(self):
super().setUp()
self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
config.update("jax_numpy_rank_promotion", "raise")

def tearDown(self):
config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
super().tearDown()

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_func={}_maxiter={}".format(func_and_init[0].__name__, maxiter),
"maxiter": maxiter, "func_and_init": func_and_init}
Expand Down Expand Up @@ -149,17 +141,9 @@ def f(x):
jax.scipy.optimize.minimize(f, jnp.ones(2), args=45, method='BFGS')


@jtu.with_config(jax_numpy_rank_promotion="raise")
class TestLBFGS(jtu.JaxTestCase):

def setUp(self):
super().setUp()
self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
config.update("jax_numpy_rank_promotion", "raise")

def tearDown(self):
config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
super().tearDown()

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_func={}_maxiter={}".format(func_and_init[0].__name__, maxiter),
"maxiter": maxiter, "func_and_init": func_and_init}
Expand Down
10 changes: 1 addition & 9 deletions tests/scipy_signal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,10 @@
default_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex


@jtu.with_config(jax_numpy_rank_promotion="raise")
class LaxBackedScipySignalTests(jtu.JaxTestCase):
"""Tests for LAX-backed scipy.stats implementations"""

def setUp(self):
super().setUp()
self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
config.update("jax_numpy_rank_promotion", "raise")

def tearDown(self):
config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)
super().tearDown()

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_xshape={}_yshape={}_mode={}".format(
op,
Expand Down

0 comments on commit 6114e6a

Please sign in to comment.