Commit
Copybara import of the project:
--
d42fffd by Jake VanderPlas <jakevdp@google.com>:

JaxTestCase: set numpy_rank_promotion='raise' by default
PiperOrigin-RevId: 427896974
jax authors committed Feb 11, 2022
1 parent 8b4a7ce commit 5691010
Showing 29 changed files with 73 additions and 20 deletions.
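
For context, the `jax_numpy_rank_promotion` config controls whether `jax.numpy` operations may implicitly broadcast operands of different ranks. A minimal sketch of the difference between the default `'allow'` and the `'raise'` setting used throughout this commit (illustrative only; the exact error type and message may vary across JAX versions):

```python
import jax
import jax.numpy as jnp

x = jnp.ones(3)        # rank 1
y = jnp.ones((2, 3))   # rank 2

# Under the default 'allow', x is implicitly promoted to rank 2, as in NumPy.
print((x + y).shape)   # (2, 3)

# Under 'raise', the same operation fails; ranks must be matched explicitly,
# e.g. x[None, :] + y.
jax.config.update("jax_numpy_rank_promotion", "raise")
try:
  x + y
except Exception as err:  # a ValueError in recent JAX releases
  print("implicit rank promotion disallowed:", err)
finally:
  jax.config.update("jax_numpy_rank_promotion", "allow")
```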
8 changes: 0 additions & 8 deletions CHANGELOG.md
@@ -29,14 +29,6 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
`dialect=` is passed.
* The `jax.jit(f).lower(...).compiler_ir(dialect='mhlo')` now returns an MLIR
`ir.Module` object instead of its string representation.
* `jax.test_util.JaxTestCase` now sets `jax_numpy_rank_promotion='raise'` by
default. To recover the previous behavior, use the `jax.test_util.with_config`
decorator:
```python
@jtu.with_config(jax_numpy_rank_promotion='allow')
class MyTest(jtu.JaxTestCase):
...
```

## jaxlib 0.1.76 (Jan 27, 2022)

5 changes: 1 addition & 4 deletions jax/_src/test_util.py
@@ -915,10 +915,7 @@ def decorator(cls):

class JaxTestCase(parameterized.TestCase):
"""Base class for JAX tests including numerical checks and boilerplate."""
_default_config = {
'jax_enable_checks': True,
'jax_numpy_rank_promotion': 'raise',
}
_default_config = {'jax_enable_checks': True}

# TODO(mattjj): this obscures the error messages from failures, figure out how
# to re-enable it
1 change: 0 additions & 1 deletion jax/experimental/jax2tf/tests/tf_test_util.py
@@ -150,7 +150,6 @@ def ComputeTfValueAndGrad(tf_f: Callable, tf_args: Sequence,
return f1(*args1)


@jtu.with_config(jax_numpy_rank_promotion="allow")
class JaxToTfTestCase(jtu.JaxTestCase):

def setUp(self):
17 changes: 15 additions & 2 deletions tests/api_test.py
@@ -70,6 +70,7 @@
numpy_version = tuple(map(int, np.__version__.split('.')[:3]))


@jtu.with_config(jax_numpy_rank_promotion="raise")
class CPPJitTest(jtu.BufferDonationTestCase):
"""Shared tests between the Python and the C++ jax,jit implementations.
@@ -859,14 +860,15 @@ def f(k):
python_should_be_executing = False
self.assertEqual(x, f(x))


@jtu.with_config(jax_numpy_rank_promotion="raise")
class PythonJitTest(CPPJitTest):

@property
def jit(self):
return api._python_jit


@jtu.with_config(jax_numpy_rank_promotion="raise")
class APITest(jtu.JaxTestCase):

def test_grad_item(self):
@@ -3414,6 +3416,7 @@ def f(x):
FLAGS.jax_numpy_rank_promotion = allow_promotion


@jtu.with_config(jax_numpy_rank_promotion="raise")
class RematTest(jtu.JaxTestCase):

@parameterized.named_parameters(
@@ -4270,6 +4273,7 @@ def f(u, x):

_ = api.linearize(partial(f, core.unit), 3.)

@jtu.with_config(jax_numpy_rank_promotion="raise")
class JaxprTest(jtu.JaxTestCase):

def test_scalar_literals(self):
@@ -4413,6 +4417,7 @@ def test_convert_element_type_literal_constant_folding(self):
self.assertLen(jaxpr.eqns, 0)


@jtu.with_config(jax_numpy_rank_promotion="raise")
class CustomJVPTest(jtu.JaxTestCase):

def test_basic(self):
@@ -5387,6 +5392,7 @@ def f_jvp(primals, tangents):
self.assertEqual(shape, ())


@jtu.with_config(jax_numpy_rank_promotion="raise")
class CustomVJPTest(jtu.JaxTestCase):

def test_basic(self):
@@ -6355,6 +6361,7 @@ def transposed(y):
return transposed


@jtu.with_config(jax_numpy_rank_promotion="raise")
class CustomTransposeTest(jtu.JaxTestCase):

def test_linear_call(self):
@@ -6683,6 +6690,7 @@ def tp(r, t): return t / r
self.assertAllClose(f_t(x), jax.jit(f_t)(x))


@jtu.with_config(jax_numpy_rank_promotion="raise")
class CustomVmapTest(jtu.JaxTestCase):

def test_basic(self):
@@ -7109,6 +7117,7 @@ def vmap_ref(xs, y):
self.assertEqual(str(jaxpr), str(jaxpr_ref))


@jtu.with_config(jax_numpy_rank_promotion="raise")
class CustomApiTest(jtu.JaxTestCase):
"""Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs"""

@@ -7146,6 +7155,7 @@ def test_def_method_forwarding_all_permutations(self):
self.assertIsInstance(getattr(f, method), Callable)


@jtu.with_config(jax_numpy_rank_promotion="raise")
class InvertibleADTest(jtu.JaxTestCase):

@jtu.ignore_warning(message="Values that an @invertible function closes")
@@ -7254,6 +7264,7 @@ def f(x, y):
check_dtypes=True)


@jtu.with_config(jax_numpy_rank_promotion="raise")
class BufferDonationTest(jtu.BufferDonationTestCase):

@jtu.skip_on_devices("cpu") # In/out aliasing not supported on CPU.
@@ -7276,6 +7287,7 @@ def test_pmap_nested_donate_ignored(self):
pmap_fun(a) # doesn't crash


@jtu.with_config(jax_numpy_rank_promotion="raise")
class NamedCallTest(jtu.JaxTestCase):

def test_default_name(self):
@@ -7356,6 +7368,7 @@ def test_integer_overflow(self, jit_type, func):
self.assertRaises(OverflowError, f, int_min - 1)


@jtu.with_config(jax_numpy_rank_promotion="raise")
class BackendsTest(jtu.JaxTestCase):

@unittest.skipIf(not sys.executable, "test requires sys.executable")
@@ -7378,6 +7391,7 @@ def test_cpu_warning_suppression(self):
assert "No GPU/TPU found" not in result.stderr.decode()


@jtu.with_config(jax_numpy_rank_promotion="raise")
class CleanupTest(jtu.JaxTestCase):
def test_call_wrapped_second_phase_cleanup(self):
try:
@@ -7538,7 +7552,6 @@ def f(x, n):
self.assertIs(jaxpr.jaxpr.invars[1], jaxpr.out_avals[0].shape[0])
self.assertEqual(4, jaxpr.out_avals[0].shape[1])

@jax.numpy_rank_promotion("allow") # explicitly exercises implicit rank promotion.
def test_basic_batchpoly_neuralnet(self):
def predict(params, inputs):
for W, b in params:
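The `tests/api_test.py` hunks above pair a class-level `@jtu.with_config(jax_numpy_rank_promotion="raise")` decorator with a per-method `@jax.numpy_rank_promotion("allow")` override for tests that deliberately exercise implicit promotion. A minimal sketch of that pattern, with a hypothetical test class invented for illustration (the `jax.test_util` import path follows the CHANGELOG entry above and may differ in later releases):

```python
import jax
import jax.numpy as jnp
from absl.testing import absltest
import jax.test_util as jtu

# Class-level config: every test in the class runs with rank promotion disallowed.
@jtu.with_config(jax_numpy_rank_promotion="raise")
class RankPromotionDemoTest(jtu.JaxTestCase):  # hypothetical class, for illustration only

  def test_explicit_broadcast(self):
    x, y = jnp.ones(3), jnp.ones((2, 3))
    self.assertEqual((x[None, :] + y).shape, (2, 3))

  # Method-level override: this one test opts back into implicit promotion.
  @jax.numpy_rank_promotion("allow")
  def test_implicit_promotion(self):
    x, y = jnp.ones(3), jnp.ones((2, 3))
    self.assertEqual((x + y).shape, (2, 3))


if __name__ == "__main__":
  absltest.main()
```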
1 change: 1 addition & 0 deletions tests/batching_test.py
@@ -40,6 +40,7 @@
# These are 'manual' tests for batching (vmap). The more exhaustive, more
# systematic tests are in lax_test.py's LaxVmapTest class.

@jtu.with_config(jax_numpy_rank_promotion="raise")
class BatchingTest(jtu.JaxTestCase):

def testConstantFunction(self):
1 change: 1 addition & 0 deletions tests/fft_test.py
@@ -93,6 +93,7 @@ def _zero_for_irfft(z, axes):
return jnp.concatenate(parts, axis=axis)


@jtu.with_config(jax_numpy_rank_promotion="raise")
class FftTest(jtu.JaxTestCase):

def testNotImplemented(self):
1 change: 1 addition & 0 deletions tests/infeed_test.py
@@ -26,6 +26,7 @@

config.parse_flags_with_absl()

@jtu.with_config(jax_numpy_rank_promotion="raise")
class InfeedTest(jtu.JaxTestCase):

@jax.numpy_rank_promotion("allow") # Test explicitly exercises implicit rank promotion.
2 changes: 1 addition & 1 deletion tests/lax_autodiff_test.py
@@ -189,7 +189,7 @@ def check_grads_bilinear(f, args, order,
check_grads(lambda rhs: f(lhs, rhs), (rhs,), order,
modes=modes, atol=atol, rtol=rtol, eps=1.)


@jtu.with_config(jax_numpy_rank_promotion="raise")
class LaxAutodiffTest(jtu.JaxTestCase):

@parameterized.named_parameters(itertools.chain.from_iterable(
1 change: 1 addition & 0 deletions tests/lax_numpy_einsum_test.py
@@ -30,6 +30,7 @@
config.parse_flags_with_absl()


@jtu.with_config(jax_numpy_rank_promotion="raise")
class EinsumTest(jtu.JaxTestCase):

def _check(self, s, *ops):
2 changes: 2 additions & 0 deletions tests/lax_numpy_indexing_test.py
@@ -414,6 +414,7 @@ def check_grads(f, args, order, atol=None, rtol=None, eps=None):
MODES = ["clip", "drop", "promise_in_bounds"]


@jtu.with_config(jax_numpy_rank_promotion="raise")
class IndexingTest(jtu.JaxTestCase):
"""Tests for Numpy indexing translation rules."""

@@ -996,6 +997,7 @@ def _update_tol(op):
return tol


@jtu.with_config(jax_numpy_rank_promotion="raise")
class IndexedUpdateTest(jtu.JaxTestCase):

@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
5 changes: 5 additions & 0 deletions tests/lax_numpy_test.py
@@ -519,6 +519,7 @@ def wrapper(*args, **kw):
return wrapper


@jtu.with_config(jax_numpy_rank_promotion="raise")
class LaxBackedNumpyTests(jtu.JaxTestCase):
"""Tests for LAX-backed Numpy implementation."""

@@ -5914,6 +5915,7 @@ def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None):
GradSpecialValuesTestSpec(jnp.sinc, [0.], 1),
]

@jtu.with_config(jax_numpy_rank_promotion="raise")
class NumpyGradTests(jtu.JaxTestCase):

@parameterized.named_parameters(itertools.chain.from_iterable(
@@ -6018,6 +6020,7 @@ def testGradLogaddexp2Complex(self, shapes, dtype):
tol = 3e-2
check_grads(jnp.logaddexp2, args, 1, ["fwd", "rev"], tol, tol)

@jtu.with_config(jax_numpy_rank_promotion="raise")
class NumpySignaturesTest(jtu.JaxTestCase):

def testWrappedSignaturesMatch(self):
@@ -6133,6 +6136,7 @@ def _dtypes_for_ufunc(name: str) -> Iterator[Tuple[str, ...]]:
yield arg_dtypes


@jtu.with_config(jax_numpy_rank_promotion="raise")
class NumpyUfuncTests(jtu.JaxTestCase):

@parameterized.named_parameters(
@@ -6164,6 +6168,7 @@ def testUfuncInputTypes(self, name, arg_dtypes):
# that jnp returns float32. e.g. np.cos(np.uint8(0))
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=False, tol=1E-2)

@jtu.with_config(jax_numpy_rank_promotion="raise")
class NumpyDocTests(jtu.JaxTestCase):

def test_lax_numpy_docstrings(self):
1 change: 1 addition & 0 deletions tests/lax_numpy_vectorize_test.py
@@ -25,6 +25,7 @@
config.parse_flags_with_absl()


@jtu.with_config(jax_numpy_rank_promotion="raise")
class VectorizeTest(jtu.JaxTestCase):

@parameterized.named_parameters(jtu.cases_from_list(
1 change: 1 addition & 0 deletions tests/lax_scipy_sparse_test.py
@@ -64,6 +64,7 @@ def rand_sym_pos_def(rng, shape, dtype):
return matrix @ matrix.T.conj()


@jtu.with_config(jax_numpy_rank_promotion="raise")
class LaxBackedScipyTests(jtu.JaxTestCase):

def _fetch_preconditioner(self, preconditioner, A, rng=None):
1 change: 1 addition & 0 deletions tests/lax_scipy_test.py
@@ -144,6 +144,7 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t
]


@jtu.with_config(jax_numpy_rank_promotion="raise")
class LaxBackedScipyTests(jtu.JaxTestCase):
"""Tests for LAX-backed Scipy implementation."""

3 changes: 3 additions & 0 deletions tests/lax_test.py
@@ -181,6 +181,7 @@ def op_record(op, nargs, dtypes, rng_factory, tol=None):
]


@jtu.with_config(jax_numpy_rank_promotion="raise")
class LaxTest(jtu.JaxTestCase):
"""Numerical tests for LAX operations."""

@@ -2668,6 +2669,7 @@ def testDynamicSliceU8Index(self):
np.array(lax.dynamic_slice(x, np.uint8([128]), (1,))), [128])


@jtu.with_config(jax_numpy_rank_promotion="raise")
class LazyConstantTest(jtu.JaxTestCase):
def _Check(self, make_const, expected):
# check casting to ndarray works
@@ -2870,6 +2872,7 @@ def testLog1pNearOne(self):
np.log1p(np.float32(1e-5)), lax.log1p(np.complex64(1e-5)))


@jtu.with_config(jax_numpy_rank_promotion="raise")
class LaxNamedShapeTest(jtu.JaxTestCase):

def test_abstract_eval(self):
4 changes: 3 additions & 1 deletion tests/linalg_test.py
@@ -42,6 +42,7 @@
complex_types = jtu.dtypes.complex


@jtu.with_config(jax_numpy_rank_promotion='raise')
class NumpyLinalgTest(jtu.JaxTestCase):

def testNotImplemented(self):
@@ -956,6 +957,7 @@ def f(inp):
self.assertFalse(np.any(np.isnan(cube_func(a))))


@jtu.with_config(jax_numpy_rank_promotion='raise')
class ScipyLinalgTest(jtu.JaxTestCase):

@parameterized.named_parameters(jtu.cases_from_list(
@@ -1372,7 +1374,7 @@ def expm(x):
jtu.check_grads(expm, (a,), modes=["fwd", "rev"], order=1, atol=tol,
rtol=tol)


@jtu.with_config(jax_numpy_rank_promotion='raise')
class LaxLinalgTest(jtu.JaxTestCase):

def run_test(self, alpha, beta):
3 changes: 2 additions & 1 deletion tests/nn_test.py
@@ -35,6 +35,7 @@
config.parse_flags_with_absl()


@jtu.with_config(jax_numpy_rank_promotion="raise")
class NNFunctionsTest(jtu.JaxTestCase):
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testSoftplusGrad(self):
@@ -229,7 +230,7 @@ def initializer_record(name, initializer, dtypes, min_dims=2, max_dims=4):
initializer_record("delta_orthogonal", nn.initializers.delta_orthogonal, jtu.dtypes.floating, 4, 4)
]


@jtu.with_config(jax_numpy_rank_promotion="raise")
class NNInitializersTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
4 changes: 4 additions & 0 deletions tests/pjit_test.py
@@ -74,6 +74,7 @@ def check_1d_2d_mesh(f, set_mesh):


# TODO(skye): make the buffer donation utils part of JaxTestCase
@jtu.with_config(jax_numpy_rank_promotion="raise")
class PJitTest(jtu.BufferDonationTestCase):

@jtu.with_mesh([('x', 1)])
@@ -634,6 +635,7 @@ def f(x, y):
self.assertEqual(f(1, 'bye'), 5)


@jtu.with_config(jax_numpy_rank_promotion="raise")
class GDAPjitTest(jtu.JaxTestCase):

@jtu.with_mesh([('x', 4), ('y', 2)])
@@ -951,6 +953,7 @@ def spec_regex(s):
return str(s).replace(r"(", r"\(").replace(r")", r"\)")


@jtu.with_config(jax_numpy_rank_promotion="raise")
class PJitErrorTest(jtu.JaxTestCase):
@check_1d_2d_mesh(set_mesh=True)
def testNonDivisibleArgs(self, mesh, resources):
@@ -1178,6 +1181,7 @@ def h(x):
f(x)


@jtu.with_config(jax_numpy_rank_promotion="raise")
class UtilTest(jtu.JaxTestCase):

def testOpShardingRoundTrip(self):