From 5691010d2ff6a4f24f0768df48a3b26624e32d34 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Thu, 10 Feb 2022 19:08:01 -0800
Subject: [PATCH] Copybara import of the project:

--
d42fffd849a4bac0c0c11a3346c93f07f8c64c44 by Jake VanderPlas:

JaxTestCase: set numpy_rank_promotion='raise' by default

PiperOrigin-RevId: 427896974
---
 CHANGELOG.md | 8 --------
 jax/_src/test_util.py | 5 +----
 jax/experimental/jax2tf/tests/tf_test_util.py | 1 -
 tests/api_test.py | 17 +++++++++++++++--
 tests/batching_test.py | 1 +
 tests/fft_test.py | 1 +
 tests/infeed_test.py | 1 +
 tests/lax_autodiff_test.py | 2 +-
 tests/lax_numpy_einsum_test.py | 1 +
 tests/lax_numpy_indexing_test.py | 2 ++
 tests/lax_numpy_test.py | 5 +++++
 tests/lax_numpy_vectorize_test.py | 1 +
 tests/lax_scipy_sparse_test.py | 1 +
 tests/lax_scipy_test.py | 1 +
 tests/lax_test.py | 3 +++
 tests/linalg_test.py | 4 +++-
 tests/nn_test.py | 3 ++-
 tests/pjit_test.py | 4 ++++
 tests/pmap_test.py | 7 +++++++
 tests/qdwh_test.py | 1 +
 tests/random_test.py | 4 ++++
 tests/scipy_fft_test.py | 2 +-
 tests/scipy_ndimage_test.py | 1 +
 tests/scipy_optimize_test.py | 2 ++
 tests/scipy_signal_test.py | 1 +
 tests/sharded_jit_test.py | 5 ++++-
 tests/sparse_test.py | 5 +++++
 tests/sparsify_test.py | 1 +
 tests/xmap_test.py | 3 +++
 29 files changed, 73 insertions(+), 20 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index cca7775bad04..3190fe95d3b2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -29,14 +29,6 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
     `dialect=` is passed.
   * The `jax.jit(f).lower(...).compiler_ir(dialect='mhlo')` now returns an MLIR
     `ir.Module` object instead of its string representation.
-  * `jax.test_util.JaxTestCase` now sets `jax_numpy_rank_promotion='raise'` by
-    default. To recover the previous behavior, use the `jax.test_util.with_config`
-    decorator:
-    ```python
-    @jtu.with_config(jax_numpy_rank_promotion='allow')
-    class MyTest(jtu.JaxTestCase):
-      ...
-    ```
 
 ## jaxlib 0.1.76 (Jan 27, 2022)
 
diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index baff15718876..484134a48f6b 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -915,10 +915,7 @@ def decorator(cls):
 
 class JaxTestCase(parameterized.TestCase):
   """Base class for JAX tests including numerical checks and boilerplate."""
-  _default_config = {
-    'jax_enable_checks': True,
-    'jax_numpy_rank_promotion': 'raise',
-  }
+  _default_config = {'jax_enable_checks': True}
 
   # TODO(mattjj): this obscures the error messages from failures, figure out how
   # to re-enable it
diff --git a/jax/experimental/jax2tf/tests/tf_test_util.py b/jax/experimental/jax2tf/tests/tf_test_util.py
index 2ad4f9492737..b12b448a0965 100644
--- a/jax/experimental/jax2tf/tests/tf_test_util.py
+++ b/jax/experimental/jax2tf/tests/tf_test_util.py
@@ -150,7 +150,6 @@ def ComputeTfValueAndGrad(tf_f: Callable, tf_args: Sequence,
     return f1(*args1)
 
 
-@jtu.with_config(jax_numpy_rank_promotion="allow")
 class JaxToTfTestCase(jtu.JaxTestCase):
 
   def setUp(self):
diff --git a/tests/api_test.py b/tests/api_test.py
index 2a9704613be7..671cadd2d57a 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -70,6 +70,7 @@
 numpy_version = tuple(map(int, np.__version__.split('.')[:3]))
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class CPPJitTest(jtu.BufferDonationTestCase):
   """Shared tests between the Python and the C++ jax,jit implementations.
 
@@ -859,7 +860,7 @@ def f(k):
     python_should_be_executing = False
     self.assertEqual(x, f(x))
 
-
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class PythonJitTest(CPPJitTest):
 
   @property
@@ -867,6 +868,7 @@ def jit(self):
     return api._python_jit
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class APITest(jtu.JaxTestCase):
 
   def test_grad_item(self):
@@ -3414,6 +3416,7 @@ def f(x):
       FLAGS.jax_numpy_rank_promotion = allow_promotion
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class RematTest(jtu.JaxTestCase):
 
   @parameterized.named_parameters(
@@ -4270,6 +4273,7 @@ def f(u, x):
     _ = api.linearize(partial(f, core.unit), 3.)
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class JaxprTest(jtu.JaxTestCase):
 
   def test_scalar_literals(self):
@@ -4413,6 +4417,7 @@ def test_convert_element_type_literal_constant_folding(self):
     self.assertLen(jaxpr.eqns, 0)
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class CustomJVPTest(jtu.JaxTestCase):
 
   def test_basic(self):
@@ -5387,6 +5392,7 @@ def f_jvp(primals, tangents):
     self.assertEqual(shape, ())
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class CustomVJPTest(jtu.JaxTestCase):
 
   def test_basic(self):
@@ -6355,6 +6361,7 @@ def transposed(y):
     return transposed
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class CustomTransposeTest(jtu.JaxTestCase):
 
   def test_linear_call(self):
@@ -6683,6 +6690,7 @@ def tp(r, t): return t / r
     self.assertAllClose(f_t(x), jax.jit(f_t)(x))
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class CustomVmapTest(jtu.JaxTestCase):
 
   def test_basic(self):
@@ -7109,6 +7117,7 @@ def vmap_ref(xs, y):
     self.assertEqual(str(jaxpr), str(jaxpr_ref))
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class CustomApiTest(jtu.JaxTestCase):
   """Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs"""
 
@@ -7146,6 +7155,7 @@ def test_def_method_forwarding_all_permutations(self):
       self.assertIsInstance(getattr(f, method), Callable)
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class InvertibleADTest(jtu.JaxTestCase):
 
   @jtu.ignore_warning(message="Values that an @invertible function closes")
@@ -7254,6 +7264,7 @@ def f(x, y):
                         check_dtypes=True)
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class BufferDonationTest(jtu.BufferDonationTestCase):
 
   @jtu.skip_on_devices("cpu")  # In/out aliasing not supported on CPU.
@@ -7276,6 +7287,7 @@ def test_pmap_nested_donate_ignored(self):
     pmap_fun(a)  # doesn't crash
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class NamedCallTest(jtu.JaxTestCase):
 
   def test_default_name(self):
@@ -7356,6 +7368,7 @@ def test_integer_overflow(self, jit_type, func):
     self.assertRaises(OverflowError, f, int_min - 1)
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class BackendsTest(jtu.JaxTestCase):
 
   @unittest.skipIf(not sys.executable, "test requires sys.executable")
@@ -7378,6 +7391,7 @@ def test_cpu_warning_suppression(self):
     assert "No GPU/TPU found" not in result.stderr.decode()
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class CleanupTest(jtu.JaxTestCase):
   def test_call_wrapped_second_phase_cleanup(self):
     try:
@@ -7538,7 +7552,6 @@ def f(x, n):
     self.assertIs(jaxpr.jaxpr.invars[1], jaxpr.out_avals[0].shape[0])
     self.assertEqual(4, jaxpr.out_avals[0].shape[1])
 
-  @jax.numpy_rank_promotion("allow")  # explicitly exercises implicit rank promotion.
   def test_basic_batchpoly_neuralnet(self):
     def predict(params, inputs):
       for W, b in params:
diff --git a/tests/batching_test.py b/tests/batching_test.py
index b7d0cf640856..00cea40a6bea 100644
--- a/tests/batching_test.py
+++ b/tests/batching_test.py
@@ -40,6 +40,7 @@
 
 # These are 'manual' tests for batching (vmap). The more exhaustive, more
 # systematic tests are in lax_test.py's LaxVmapTest class.
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class BatchingTest(jtu.JaxTestCase):
 
   def testConstantFunction(self):
diff --git a/tests/fft_test.py b/tests/fft_test.py
index b597cc6b7bf0..53125e15e45d 100644
--- a/tests/fft_test.py
+++ b/tests/fft_test.py
@@ -93,6 +93,7 @@ def _zero_for_irfft(z, axes):
   return jnp.concatenate(parts, axis=axis)
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class FftTest(jtu.JaxTestCase):
 
   def testNotImplemented(self):
diff --git a/tests/infeed_test.py b/tests/infeed_test.py
index 570038343309..2d9e75ed59f1 100644
--- a/tests/infeed_test.py
+++ b/tests/infeed_test.py
@@ -26,6 +26,7 @@
 config.parse_flags_with_absl()
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class InfeedTest(jtu.JaxTestCase):
 
   @jax.numpy_rank_promotion("allow")  # Test explicitly exercises implicit rank promotion.
diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py
index 4c959a7f1c58..e91bd02f4fdf 100644
--- a/tests/lax_autodiff_test.py
+++ b/tests/lax_autodiff_test.py
@@ -189,7 +189,7 @@ def check_grads_bilinear(f, args, order,
   check_grads(lambda rhs: f(lhs, rhs), (rhs,), order,
               modes=modes, atol=atol, rtol=rtol, eps=1.)
 
-
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class LaxAutodiffTest(jtu.JaxTestCase):
 
   @parameterized.named_parameters(itertools.chain.from_iterable(
diff --git a/tests/lax_numpy_einsum_test.py b/tests/lax_numpy_einsum_test.py
index a9ed44a06a48..12c687a0a9bd 100644
--- a/tests/lax_numpy_einsum_test.py
+++ b/tests/lax_numpy_einsum_test.py
@@ -30,6 +30,7 @@
 config.parse_flags_with_absl()
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class EinsumTest(jtu.JaxTestCase):
 
   def _check(self, s, *ops):
diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py
index 740eb2a371db..cdf19edda531 100644
--- a/tests/lax_numpy_indexing_test.py
+++ b/tests/lax_numpy_indexing_test.py
@@ -414,6 +414,7 @@ def check_grads(f, args, order, atol=None, rtol=None, eps=None):
 MODES = ["clip", "drop", "promise_in_bounds"]
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class IndexingTest(jtu.JaxTestCase):
   """Tests for Numpy indexing translation rules."""
 
@@ -996,6 +997,7 @@ def _update_tol(op):
   return tol
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class IndexedUpdateTest(jtu.JaxTestCase):
 
   @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py
index bc596d8d2703..db3a458b70aa 100644
--- a/tests/lax_numpy_test.py
+++ b/tests/lax_numpy_test.py
@@ -519,6 +519,7 @@ def wrapper(*args, **kw):
     return wrapper
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class LaxBackedNumpyTests(jtu.JaxTestCase):
   """Tests for LAX-backed Numpy implementation."""
 
@@ -5914,6 +5915,7 @@ def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None):
     GradSpecialValuesTestSpec(jnp.sinc, [0.], 1),
 ]
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class NumpyGradTests(jtu.JaxTestCase):
 
   @parameterized.named_parameters(itertools.chain.from_iterable(
@@ -6018,6 +6020,7 @@ def testGradLogaddexp2Complex(self, shapes, dtype):
     tol = 3e-2
     check_grads(jnp.logaddexp2, args, 1, ["fwd", "rev"], tol, tol)
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class NumpySignaturesTest(jtu.JaxTestCase):
 
   def testWrappedSignaturesMatch(self):
@@ -6133,6 +6136,7 @@ def _dtypes_for_ufunc(name: str) -> Iterator[Tuple[str, ...]]:
     yield arg_dtypes
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class NumpyUfuncTests(jtu.JaxTestCase):
 
   @parameterized.named_parameters(
@@ -6164,6 +6168,7 @@ def testUfuncInputTypes(self, name, arg_dtypes):
     # that jnp returns float32. e.g. np.cos(np.uint8(0))
    self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=False, tol=1E-2)
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class NumpyDocTests(jtu.JaxTestCase):
 
   def test_lax_numpy_docstrings(self):
diff --git a/tests/lax_numpy_vectorize_test.py b/tests/lax_numpy_vectorize_test.py
index 2cbf3ef28794..0e8fd29fe91d 100644
--- a/tests/lax_numpy_vectorize_test.py
+++ b/tests/lax_numpy_vectorize_test.py
@@ -25,6 +25,7 @@
 config.parse_flags_with_absl()
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class VectorizeTest(jtu.JaxTestCase):
 
   @parameterized.named_parameters(jtu.cases_from_list(
diff --git a/tests/lax_scipy_sparse_test.py b/tests/lax_scipy_sparse_test.py
index 671050fe9bd8..93c9140b6bbf 100644
--- a/tests/lax_scipy_sparse_test.py
+++ b/tests/lax_scipy_sparse_test.py
@@ -64,6 +64,7 @@ def rand_sym_pos_def(rng, shape, dtype):
   return matrix @ matrix.T.conj()
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class LaxBackedScipyTests(jtu.JaxTestCase):
 
   def _fetch_preconditioner(self, preconditioner, A, rng=None):
diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py
index 346981b3de57..489c6023202c 100644
--- a/tests/lax_scipy_test.py
+++ b/tests/lax_scipy_test.py
@@ -144,6 +144,7 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t
 ]
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class LaxBackedScipyTests(jtu.JaxTestCase):
   """Tests for LAX-backed Scipy implementation."""
 
diff --git a/tests/lax_test.py b/tests/lax_test.py
index 2b0b44c72ee3..d1efe27f8d7b 100644
--- a/tests/lax_test.py
+++ b/tests/lax_test.py
@@ -181,6 +181,7 @@ def op_record(op, nargs, dtypes, rng_factory, tol=None):
 ]
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class LaxTest(jtu.JaxTestCase):
   """Numerical tests for LAX operations."""
 
@@ -2668,6 +2669,7 @@ def testDynamicSliceU8Index(self):
         np.array(lax.dynamic_slice(x, np.uint8([128]), (1,))), [128])
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class LazyConstantTest(jtu.JaxTestCase):
   def _Check(self, make_const, expected):
     # check casting to ndarray works
@@ -2870,6 +2872,7 @@ def testLog1pNearOne(self):
       np.log1p(np.float32(1e-5)), lax.log1p(np.complex64(1e-5)))
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class LaxNamedShapeTest(jtu.JaxTestCase):
 
   def test_abstract_eval(self):
diff --git a/tests/linalg_test.py b/tests/linalg_test.py
index 5d3863eeabb0..2f98b07b81d1 100644
--- a/tests/linalg_test.py
+++ b/tests/linalg_test.py
@@ -42,6 +42,7 @@
 complex_types = jtu.dtypes.complex
 
 
+@jtu.with_config(jax_numpy_rank_promotion='raise')
 class NumpyLinalgTest(jtu.JaxTestCase):
 
   def testNotImplemented(self):
@@ -956,6 +957,7 @@ def f(inp):
     self.assertFalse(np.any(np.isnan(cube_func(a))))
 
 
+@jtu.with_config(jax_numpy_rank_promotion='raise')
 class ScipyLinalgTest(jtu.JaxTestCase):
 
   @parameterized.named_parameters(jtu.cases_from_list(
@@ -1372,7 +1374,7 @@ def expm(x):
     jtu.check_grads(expm, (a,), modes=["fwd", "rev"], order=1, atol=tol,
                     rtol=tol)
 
-
+@jtu.with_config(jax_numpy_rank_promotion='raise')
 class LaxLinalgTest(jtu.JaxTestCase):
 
   def run_test(self, alpha, beta):
diff --git a/tests/nn_test.py b/tests/nn_test.py
index 1659f7d40781..fe41be4e5f0c 100644
--- a/tests/nn_test.py
+++ b/tests/nn_test.py
@@ -35,6 +35,7 @@
 config.parse_flags_with_absl()
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class NNFunctionsTest(jtu.JaxTestCase):
   @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testSoftplusGrad(self):
@@ -229,7 +230,7 @@ def initializer_record(name, initializer, dtypes, min_dims=2, max_dims=4):
   initializer_record("delta_orthogonal", nn.initializers.delta_orthogonal,
                      jtu.dtypes.floating, 4, 4)
 ]
-
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class NNInitializersTest(jtu.JaxTestCase):
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name":
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index d847fed51c1f..b02eb8d4854c 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -74,6 +74,7 @@ def check_1d_2d_mesh(f, set_mesh):
 
 
 # TODO(skye): make the buffer donation utils part of JaxTestCase
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class PJitTest(jtu.BufferDonationTestCase):
 
   @jtu.with_mesh([('x', 1)])
@@ -634,6 +635,7 @@ def f(x, y):
     self.assertEqual(f(1, 'bye'), 5)
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class GDAPjitTest(jtu.JaxTestCase):
 
   @jtu.with_mesh([('x', 4), ('y', 2)])
@@ -951,6 +953,7 @@ def spec_regex(s):
   return str(s).replace(r"(", r"\(").replace(r")", r"\)")
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class PJitErrorTest(jtu.JaxTestCase):
   @check_1d_2d_mesh(set_mesh=True)
   def testNonDivisibleArgs(self, mesh, resources):
@@ -1178,6 +1181,7 @@ def h(x):
       f(x)
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class UtilTest(jtu.JaxTestCase):
 
   def testOpShardingRoundTrip(self):
diff --git a/tests/pmap_test.py b/tests/pmap_test.py
index a00dde7a131f..0f60577b16f6 100644
--- a/tests/pmap_test.py
+++ b/tests/pmap_test.py
@@ -109,6 +109,7 @@ def tearDownModule():
     jtu.ignore_warning, message=".*is an experimental.*")
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class PythonPmapTest(jtu.JaxTestCase):
 
   @property
@@ -1904,6 +1905,7 @@ def pmap(self):
     return src_api._cpp_pmap
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class VmapOfPmapTest(jtu.JaxTestCase):
 
   # TODO(apaszke)
@@ -1946,6 +1948,7 @@ def args_slice(vi, pi):
     self.assertAllClose(ans, expected)
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class VmapPmapCollectivesTest(jtu.JaxTestCase):
 
   @parameterized.named_parameters(
@@ -2131,6 +2134,7 @@ def f(x):
     self.assertAllClose(f(jax.pmap)(x), f(jax.vmap)(x))
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class PmapWithDevicesTest(jtu.JaxTestCase):
 
   def testAllDevices(self):
@@ -2383,6 +2387,7 @@ def h(y):
                         jax.grad(mk_case(vmap))(x, y))
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class ShardedDeviceArrayTest(jtu.JaxTestCase):
 
   def testThreadsafeIndexing(self):
@@ -2488,6 +2493,7 @@ def test_delete_is_idempotent(self):
      _ = x[0]
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class SpecToIndicesTest(jtu.JaxTestCase):
 
   def testShardsPerAxis(self):
@@ -2617,6 +2623,7 @@ def _spec_str(spec):
           f"{spec.mesh_mapping},)")
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class ShardArgsTest(jtu.JaxTestCase):
 
   def numpy_array(x):
diff --git a/tests/qdwh_test.py b/tests/qdwh_test.py
index 5e2e35d627a2..6377ef2c1c49 100644
--- a/tests/qdwh_test.py
+++ b/tests/qdwh_test.py
@@ -58,6 +58,7 @@ def _compute_relative_diff(actual, expected):
 _dot = functools.partial(jnp.dot, precision="highest")
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class QdwhTest(jtu.JaxTestCase):
 
   @parameterized.named_parameters(jtu.cases_from_list(
diff --git a/tests/random_test.py b/tests/random_test.py
index c949a706c314..68dbf53f2e08 100644
--- a/tests/random_test.py
+++ b/tests/random_test.py
@@ -56,6 +56,7 @@ def _prng_key_as_array(key):
               ('unsafe_rbg', prng.unsafe_rbg_prng_impl)]
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class PrngTest(jtu.JaxTestCase):
 
   def testThreefry2x32(self):
@@ -314,6 +315,7 @@ def test_key_array_indexing_nd(self):
                       lambda: keys[0, 1, None, 2])
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class LaxRandomTest(jtu.JaxTestCase):
 
   def _CheckCollisions(self, samples, nbits):
@@ -1224,6 +1226,7 @@ def _double_threefry_fold_in(key, data):
 
 @skipIf(not config.jax_enable_custom_prng,
         'custom PRNG tests require config.jax_enable_custom_prng')
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class LaxRandomWithCustomPRNGTest(LaxRandomTest):
   def seed_prng(self, seed):
     return prng.seed_with_impl(double_threefry_prng_impl, seed)
@@ -1252,6 +1255,7 @@ def test_grad_of_prng_key(self):
 
 @skipIf(not config.jax_enable_custom_prng,
         'custom PRNG tests require config.jax_enable_custom_prng')
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class LaxRandomWithRBGPRNGTest(LaxRandomTest):
   def seed_prng(self, seed):
     return random.rbg_key(seed)
diff --git a/tests/scipy_fft_test.py b/tests/scipy_fft_test.py
index 31bc01e7173a..a4ab47066226 100644
--- a/tests/scipy_fft_test.py
+++ b/tests/scipy_fft_test.py
@@ -42,7 +42,7 @@ def _get_dctn_test_s(shape, axes):
   s_list.extend(itertools.product(*[[shape[ax]+i for i in range(-shape[ax]+1, shape[ax]+1)]
                                     for ax in axes]))
   return s_list
-
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class LaxBackedScipyFftTests(jtu.JaxTestCase):
   """Tests for LAX-backed scipy.fft implementations"""
 
diff --git a/tests/scipy_ndimage_test.py b/tests/scipy_ndimage_test.py
index 045146d6a052..347e025c52b0 100644
--- a/tests/scipy_ndimage_test.py
+++ b/tests/scipy_ndimage_test.py
@@ -57,6 +57,7 @@ def _fixed_ref_map_coordinates(input, coordinates, order, mode, cval=0.0):
   return result
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class NdimageTest(jtu.JaxTestCase):
 
   @parameterized.named_parameters(jtu.cases_from_list(
diff --git a/tests/scipy_optimize_test.py b/tests/scipy_optimize_test.py
index d98329735310..a2207e8c2e05 100644
--- a/tests/scipy_optimize_test.py
+++ b/tests/scipy_optimize_test.py
@@ -64,6 +64,7 @@ def zakharovFromIndices(x, ii):
   return answer
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class TestBFGS(jtu.JaxTestCase):
 
   @parameterized.named_parameters(jtu.cases_from_list(
@@ -140,6 +141,7 @@ def f(x):
       jax.scipy.optimize.minimize(f, jnp.ones(2), args=45, method='BFGS')
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class TestLBFGS(jtu.JaxTestCase):
 
   @parameterized.named_parameters(jtu.cases_from_list(
diff --git a/tests/scipy_signal_test.py b/tests/scipy_signal_test.py
index 0162a2915a96..4c9a80747114 100644
--- a/tests/scipy_signal_test.py
+++ b/tests/scipy_signal_test.py
@@ -35,6 +35,7 @@
 default_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class LaxBackedScipySignalTests(jtu.JaxTestCase):
   """Tests for LAX-backed scipy.stats implementations"""
 
diff --git a/tests/sharded_jit_test.py b/tests/sharded_jit_test.py
index a83784e1d042..37248e36f281 100644
--- a/tests/sharded_jit_test.py
+++ b/tests/sharded_jit_test.py
@@ -39,6 +39,7 @@
 config.parse_flags_with_absl()
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class ShardedJitTest(jtu.JaxTestCase):
 
   def setUp(self):
@@ -276,6 +277,7 @@ def testCompilationCache(self):
 
 
 # TODO(skye): add more error tests
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class ShardedJitErrorsTest(jtu.JaxTestCase):
 
   def setUp(self):
@@ -298,6 +300,7 @@ def f(x):
 
 
 # Tests that don't need a TPU to run.
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class ShardedJitTestNoTpu(jtu.JaxTestCase):
 
   def testTranslationRule(self):
@@ -326,7 +329,7 @@ def f(x):
     # Annotation from sharded_jit
     self.assertIn("sharding={replicated}", hlo.as_hlo_text())
 
-
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class PmapOfShardedJitTest(jtu.JaxTestCase):
 
   def setUp(self):
diff --git a/tests/sparse_test.py b/tests/sparse_test.py
index 72456ad6c830..9c672d4531cc 100644
--- a/tests/sparse_test.py
+++ b/tests/sparse_test.py
@@ -116,6 +116,7 @@ def _rand_sparse(shape, dtype, nse=nse):
   return _rand_sparse
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class cuSparseTest(jtu.JaxTestCase):
   def gpu_dense_conversion_warning_context(self, dtype):
     if jtu.device_under_test() == "gpu" and np.issubdtype(dtype, np.integer):
@@ -554,6 +555,7 @@ def test_coo_matmul_ad(self, shape, dtype, bshape):
     self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol)
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class BCOOTest(jtu.JaxTestCase):
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_{}_nbatch={}_ndense={}".format(
@@ -1677,6 +1679,7 @@ def test_bcoo_bad_fillvals(self):
     self.assertArraysEqual((y_sp @ x_sp).todense(), y_de @ x_de)
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class SparseGradTest(jtu.JaxTestCase):
   def test_sparse_grad(self):
     rng_sparse = rand_sparse(self.rng())
@@ -1699,6 +1702,7 @@ def f(X, y):
     self.assertArraysEqual(grad_sparse.todense(), grad_sparse_from_dense)
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class SparseObjectTest(jtu.JaxTestCase):
   def test_repr(self):
     M = sparse.BCOO.fromdense(jnp.arange(5, dtype='float32'))
@@ -1894,6 +1898,7 @@ def test_bcoo_methods(self):
     self.assertArraysEqual(M.sum(), Msp.sum())
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class SparseRandomTest(jtu.JaxTestCase):
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_{}_indices_dtype={}_nbatch={}_ndense={}".format(
diff --git a/tests/sparsify_test.py b/tests/sparsify_test.py
index aa03ef129135..b622d59dfb69 100644
--- a/tests/sparsify_test.py
+++ b/tests/sparsify_test.py
@@ -30,6 +30,7 @@
 config.parse_flags_with_absl()
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class SparsifyTest(jtu.JaxTestCase):
   @classmethod
   def sparsify(cls, f):
diff --git a/tests/xmap_test.py b/tests/xmap_test.py
index f3ec3f40e8b3..b196e9812374 100644
--- a/tests/xmap_test.py
+++ b/tests/xmap_test.py
@@ -210,6 +210,7 @@ def divisors2(n: int) -> Iterator[Tuple[int, int]]:
       yield axis_resources, mesh_data
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class XMapTestCase(jtu.BufferDonationTestCase):
   pass
 
@@ -1177,6 +1178,7 @@ def test_xeinsum_no_named_axes_reduce_sum(self):
     self.assertAllClose(out, expected, check_dtypes=True)
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class XMapErrorTest(jtu.JaxTestCase):
 
   @jtu.with_mesh([('x', 2)])
@@ -1408,6 +1410,7 @@ def testAxesMismatch(self):
      xmap(lambda x: x, (p,), (p, ['x']))([x, x, x])  # Error, we raise a generic tree mismatch message
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class NamedAutodiffTests(jtu.JaxTestCase):
 
   def testVjpReduceAxes(self):
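
For reference, a minimal sketch of the test-class pattern this patch applies throughout. The file, class, and test names below are hypothetical and not part of the patch; the imports follow the `jax.test_util` spelling used in the removed CHANGELOG entry.

```python
# rank_promotion_example_test.py -- hypothetical file name, for illustration only.
from absl.testing import absltest

import numpy as np

import jax.numpy as jnp
from jax import test_util as jtu
from jax.config import config

config.parse_flags_with_absl()


# Every test in the class runs with implicit rank promotion turned into an
# error -- the same opt-in this patch adds to each existing test class.
@jtu.with_config(jax_numpy_rank_promotion="raise")
class RankPromotionExampleTest(jtu.JaxTestCase):

  def test_explicit_broadcast_is_allowed(self):
    x = jnp.ones((3, 3))
    b = jnp.ones(3)
    # Broadcasting the rank-1 operand by hand keeps the ranks equal, so this
    # addition is fine even under jax_numpy_rank_promotion="raise".
    self.assertAllClose(x + b[None, :], x + np.ones((1, 3)))

  def test_implicit_rank_promotion_raises(self):
    # Adding a rank-2 and a rank-1 array relies on implicit rank promotion,
    # which the "raise" setting rejects.
    with self.assertRaises(ValueError):
      jnp.ones((3, 3)) + jnp.ones(3)


if __name__ == "__main__":
  absltest.main()
```

Per-class decorators like this keep the strict setting local to suites known to be clean, rather than flipping the global default in `JaxTestCase._default_config`.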