diff --git a/CHANGELOG.md b/CHANGELOG.md index 3190fe95d3b2..cca7775bad04 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,14 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. `dialect=` is passed. * The `jax.jit(f).lower(...).compiler_ir(dialect='mhlo')` now returns an MLIR `ir.Module` object instead of its string representation. + * `jax.test_util.JaxTestCase` now sets `jax_numpy_rank_promotion='raise'` by + default. To recover the previous behavior, use the `jax.test_util.with_config` + decorator: + ```python + @jtu.with_config(jax_numpy_rank_promotion='allow') + class MyTest(jtu.JaxTestCase): + ... + ``` ## jaxlib 0.1.76 (Jan 27, 2022) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 484134a48f6b..baff15718876 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -915,7 +915,10 @@ def decorator(cls): class JaxTestCase(parameterized.TestCase): """Base class for JAX tests including numerical checks and boilerplate.""" - _default_config = {'jax_enable_checks': True} + _default_config = { + 'jax_enable_checks': True, + 'jax_numpy_rank_promotion': 'raise', + } # TODO(mattjj): this obscures the error messages from failures, figure out how # to re-enable it diff --git a/jax/experimental/jax2tf/tests/tf_test_util.py b/jax/experimental/jax2tf/tests/tf_test_util.py index b12b448a0965..2ad4f9492737 100644 --- a/jax/experimental/jax2tf/tests/tf_test_util.py +++ b/jax/experimental/jax2tf/tests/tf_test_util.py @@ -150,6 +150,7 @@ def ComputeTfValueAndGrad(tf_f: Callable, tf_args: Sequence, return f1(*args1) +@jtu.with_config(jax_numpy_rank_promotion="allow") class JaxToTfTestCase(jtu.JaxTestCase): def setUp(self): diff --git a/tests/api_test.py b/tests/api_test.py index 671cadd2d57a..2a9704613be7 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -70,7 +70,6 @@ numpy_version = tuple(map(int, np.__version__.split('.')[:3])) -@jtu.with_config(jax_numpy_rank_promotion="raise") class CPPJitTest(jtu.BufferDonationTestCase): """Shared tests between the Python and the C++ jax,jit implementations. @@ -860,7 +859,7 @@ def f(k): python_should_be_executing = False self.assertEqual(x, f(x)) -@jtu.with_config(jax_numpy_rank_promotion="raise") + class PythonJitTest(CPPJitTest): @property @@ -868,7 +867,6 @@ def jit(self): return api._python_jit -@jtu.with_config(jax_numpy_rank_promotion="raise") class APITest(jtu.JaxTestCase): def test_grad_item(self): @@ -3416,7 +3414,6 @@ def f(x): FLAGS.jax_numpy_rank_promotion = allow_promotion -@jtu.with_config(jax_numpy_rank_promotion="raise") class RematTest(jtu.JaxTestCase): @parameterized.named_parameters( @@ -4273,7 +4270,6 @@ def f(u, x): _ = api.linearize(partial(f, core.unit), 3.) -@jtu.with_config(jax_numpy_rank_promotion="raise") class JaxprTest(jtu.JaxTestCase): def test_scalar_literals(self): @@ -4417,7 +4413,6 @@ def test_convert_element_type_literal_constant_folding(self): self.assertLen(jaxpr.eqns, 0) -@jtu.with_config(jax_numpy_rank_promotion="raise") class CustomJVPTest(jtu.JaxTestCase): def test_basic(self): @@ -5392,7 +5387,6 @@ def f_jvp(primals, tangents): self.assertEqual(shape, ()) -@jtu.with_config(jax_numpy_rank_promotion="raise") class CustomVJPTest(jtu.JaxTestCase): def test_basic(self): @@ -6361,7 +6355,6 @@ def transposed(y): return transposed -@jtu.with_config(jax_numpy_rank_promotion="raise") class CustomTransposeTest(jtu.JaxTestCase): def test_linear_call(self): @@ -6690,7 +6683,6 @@ def tp(r, t): return t / r self.assertAllClose(f_t(x), jax.jit(f_t)(x)) -@jtu.with_config(jax_numpy_rank_promotion="raise") class CustomVmapTest(jtu.JaxTestCase): def test_basic(self): @@ -7117,7 +7109,6 @@ def vmap_ref(xs, y): self.assertEqual(str(jaxpr), str(jaxpr_ref)) -@jtu.with_config(jax_numpy_rank_promotion="raise") class CustomApiTest(jtu.JaxTestCase): """Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs""" @@ -7155,7 +7146,6 @@ def test_def_method_forwarding_all_permutations(self): self.assertIsInstance(getattr(f, method), Callable) -@jtu.with_config(jax_numpy_rank_promotion="raise") class InvertibleADTest(jtu.JaxTestCase): @jtu.ignore_warning(message="Values that an @invertible function closes") @@ -7264,7 +7254,6 @@ def f(x, y): check_dtypes=True) -@jtu.with_config(jax_numpy_rank_promotion="raise") class BufferDonationTest(jtu.BufferDonationTestCase): @jtu.skip_on_devices("cpu") # In/out aliasing not supported on CPU. @@ -7287,7 +7276,6 @@ def test_pmap_nested_donate_ignored(self): pmap_fun(a) # doesn't crash -@jtu.with_config(jax_numpy_rank_promotion="raise") class NamedCallTest(jtu.JaxTestCase): def test_default_name(self): @@ -7368,7 +7356,6 @@ def test_integer_overflow(self, jit_type, func): self.assertRaises(OverflowError, f, int_min - 1) -@jtu.with_config(jax_numpy_rank_promotion="raise") class BackendsTest(jtu.JaxTestCase): @unittest.skipIf(not sys.executable, "test requires sys.executable") @@ -7391,7 +7378,6 @@ def test_cpu_warning_suppression(self): assert "No GPU/TPU found" not in result.stderr.decode() -@jtu.with_config(jax_numpy_rank_promotion="raise") class CleanupTest(jtu.JaxTestCase): def test_call_wrapped_second_phase_cleanup(self): try: @@ -7552,6 +7538,7 @@ def f(x, n): self.assertIs(jaxpr.jaxpr.invars[1], jaxpr.out_avals[0].shape[0]) self.assertEqual(4, jaxpr.out_avals[0].shape[1]) + @jax.numpy_rank_promotion("allow") # explicitly exercises implicit rank promotion. def test_basic_batchpoly_neuralnet(self): def predict(params, inputs): for W, b in params: diff --git a/tests/batching_test.py b/tests/batching_test.py index 00cea40a6bea..b7d0cf640856 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -40,7 +40,6 @@ # These are 'manual' tests for batching (vmap). The more exhaustive, more # systematic tests are in lax_test.py's LaxVmapTest class. -@jtu.with_config(jax_numpy_rank_promotion="raise") class BatchingTest(jtu.JaxTestCase): def testConstantFunction(self): diff --git a/tests/fft_test.py b/tests/fft_test.py index 53125e15e45d..b597cc6b7bf0 100644 --- a/tests/fft_test.py +++ b/tests/fft_test.py @@ -93,7 +93,6 @@ def _zero_for_irfft(z, axes): return jnp.concatenate(parts, axis=axis) -@jtu.with_config(jax_numpy_rank_promotion="raise") class FftTest(jtu.JaxTestCase): def testNotImplemented(self): diff --git a/tests/infeed_test.py b/tests/infeed_test.py index 2d9e75ed59f1..570038343309 100644 --- a/tests/infeed_test.py +++ b/tests/infeed_test.py @@ -26,7 +26,6 @@ config.parse_flags_with_absl() -@jtu.with_config(jax_numpy_rank_promotion="raise") class InfeedTest(jtu.JaxTestCase): @jax.numpy_rank_promotion("allow") # Test explicitly exercises implicit rank promotion. diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index e91bd02f4fdf..4c959a7f1c58 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -189,7 +189,7 @@ def check_grads_bilinear(f, args, order, check_grads(lambda rhs: f(lhs, rhs), (rhs,), order, modes=modes, atol=atol, rtol=rtol, eps=1.) -@jtu.with_config(jax_numpy_rank_promotion="raise") + class LaxAutodiffTest(jtu.JaxTestCase): @parameterized.named_parameters(itertools.chain.from_iterable( diff --git a/tests/lax_numpy_einsum_test.py b/tests/lax_numpy_einsum_test.py index 12c687a0a9bd..a9ed44a06a48 100644 --- a/tests/lax_numpy_einsum_test.py +++ b/tests/lax_numpy_einsum_test.py @@ -30,7 +30,6 @@ config.parse_flags_with_absl() -@jtu.with_config(jax_numpy_rank_promotion="raise") class EinsumTest(jtu.JaxTestCase): def _check(self, s, *ops): diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index cdf19edda531..740eb2a371db 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -414,7 +414,6 @@ def check_grads(f, args, order, atol=None, rtol=None, eps=None): MODES = ["clip", "drop", "promise_in_bounds"] -@jtu.with_config(jax_numpy_rank_promotion="raise") class IndexingTest(jtu.JaxTestCase): """Tests for Numpy indexing translation rules.""" @@ -997,7 +996,6 @@ def _update_tol(op): return tol -@jtu.with_config(jax_numpy_rank_promotion="raise") class IndexedUpdateTest(jtu.JaxTestCase): @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index db3a458b70aa..bc596d8d2703 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -519,7 +519,6 @@ def wrapper(*args, **kw): return wrapper -@jtu.with_config(jax_numpy_rank_promotion="raise") class LaxBackedNumpyTests(jtu.JaxTestCase): """Tests for LAX-backed Numpy implementation.""" @@ -5915,7 +5914,6 @@ def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None): GradSpecialValuesTestSpec(jnp.sinc, [0.], 1), ] -@jtu.with_config(jax_numpy_rank_promotion="raise") class NumpyGradTests(jtu.JaxTestCase): @parameterized.named_parameters(itertools.chain.from_iterable( @@ -6020,7 +6018,6 @@ def testGradLogaddexp2Complex(self, shapes, dtype): tol = 3e-2 check_grads(jnp.logaddexp2, args, 1, ["fwd", "rev"], tol, tol) -@jtu.with_config(jax_numpy_rank_promotion="raise") class NumpySignaturesTest(jtu.JaxTestCase): def testWrappedSignaturesMatch(self): @@ -6136,7 +6133,6 @@ def _dtypes_for_ufunc(name: str) -> Iterator[Tuple[str, ...]]: yield arg_dtypes -@jtu.with_config(jax_numpy_rank_promotion="raise") class NumpyUfuncTests(jtu.JaxTestCase): @parameterized.named_parameters( @@ -6168,7 +6164,6 @@ def testUfuncInputTypes(self, name, arg_dtypes): # that jnp returns float32. e.g. np.cos(np.uint8(0)) self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=False, tol=1E-2) -@jtu.with_config(jax_numpy_rank_promotion="raise") class NumpyDocTests(jtu.JaxTestCase): def test_lax_numpy_docstrings(self): diff --git a/tests/lax_numpy_vectorize_test.py b/tests/lax_numpy_vectorize_test.py index 0e8fd29fe91d..2cbf3ef28794 100644 --- a/tests/lax_numpy_vectorize_test.py +++ b/tests/lax_numpy_vectorize_test.py @@ -25,7 +25,6 @@ config.parse_flags_with_absl() -@jtu.with_config(jax_numpy_rank_promotion="raise") class VectorizeTest(jtu.JaxTestCase): @parameterized.named_parameters(jtu.cases_from_list( diff --git a/tests/lax_scipy_sparse_test.py b/tests/lax_scipy_sparse_test.py index 93c9140b6bbf..671050fe9bd8 100644 --- a/tests/lax_scipy_sparse_test.py +++ b/tests/lax_scipy_sparse_test.py @@ -64,7 +64,6 @@ def rand_sym_pos_def(rng, shape, dtype): return matrix @ matrix.T.conj() -@jtu.with_config(jax_numpy_rank_promotion="raise") class LaxBackedScipyTests(jtu.JaxTestCase): def _fetch_preconditioner(self, preconditioner, A, rng=None): diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index 489c6023202c..346981b3de57 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -144,7 +144,6 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t ] -@jtu.with_config(jax_numpy_rank_promotion="raise") class LaxBackedScipyTests(jtu.JaxTestCase): """Tests for LAX-backed Scipy implementation.""" diff --git a/tests/lax_test.py b/tests/lax_test.py index d1efe27f8d7b..2b0b44c72ee3 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -181,7 +181,6 @@ def op_record(op, nargs, dtypes, rng_factory, tol=None): ] -@jtu.with_config(jax_numpy_rank_promotion="raise") class LaxTest(jtu.JaxTestCase): """Numerical tests for LAX operations.""" @@ -2669,7 +2668,6 @@ def testDynamicSliceU8Index(self): np.array(lax.dynamic_slice(x, np.uint8([128]), (1,))), [128]) -@jtu.with_config(jax_numpy_rank_promotion="raise") class LazyConstantTest(jtu.JaxTestCase): def _Check(self, make_const, expected): # check casting to ndarray works @@ -2872,7 +2870,6 @@ def testLog1pNearOne(self): np.log1p(np.float32(1e-5)), lax.log1p(np.complex64(1e-5))) -@jtu.with_config(jax_numpy_rank_promotion="raise") class LaxNamedShapeTest(jtu.JaxTestCase): def test_abstract_eval(self): diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 2f98b07b81d1..5d3863eeabb0 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -42,7 +42,6 @@ complex_types = jtu.dtypes.complex -@jtu.with_config(jax_numpy_rank_promotion='raise') class NumpyLinalgTest(jtu.JaxTestCase): def testNotImplemented(self): @@ -957,7 +956,6 @@ def f(inp): self.assertFalse(np.any(np.isnan(cube_func(a)))) -@jtu.with_config(jax_numpy_rank_promotion='raise') class ScipyLinalgTest(jtu.JaxTestCase): @parameterized.named_parameters(jtu.cases_from_list( @@ -1374,7 +1372,7 @@ def expm(x): jtu.check_grads(expm, (a,), modes=["fwd", "rev"], order=1, atol=tol, rtol=tol) -@jtu.with_config(jax_numpy_rank_promotion='raise') + class LaxLinalgTest(jtu.JaxTestCase): def run_test(self, alpha, beta): diff --git a/tests/nn_test.py b/tests/nn_test.py index fe41be4e5f0c..1659f7d40781 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -35,7 +35,6 @@ config.parse_flags_with_absl() -@jtu.with_config(jax_numpy_rank_promotion="raise") class NNFunctionsTest(jtu.JaxTestCase): @jtu.skip_on_flag("jax_skip_slow_tests", True) def testSoftplusGrad(self): @@ -230,7 +229,7 @@ def initializer_record(name, initializer, dtypes, min_dims=2, max_dims=4): initializer_record("delta_orthogonal", nn.initializers.delta_orthogonal, jtu.dtypes.floating, 4, 4) ] -@jtu.with_config(jax_numpy_rank_promotion="raise") + class NNInitializersTest(jtu.JaxTestCase): @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": diff --git a/tests/pjit_test.py b/tests/pjit_test.py index b02eb8d4854c..d847fed51c1f 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -74,7 +74,6 @@ def check_1d_2d_mesh(f, set_mesh): # TODO(skye): make the buffer donation utils part of JaxTestCase -@jtu.with_config(jax_numpy_rank_promotion="raise") class PJitTest(jtu.BufferDonationTestCase): @jtu.with_mesh([('x', 1)]) @@ -635,7 +634,6 @@ def f(x, y): self.assertEqual(f(1, 'bye'), 5) -@jtu.with_config(jax_numpy_rank_promotion="raise") class GDAPjitTest(jtu.JaxTestCase): @jtu.with_mesh([('x', 4), ('y', 2)]) @@ -953,7 +951,6 @@ def spec_regex(s): return str(s).replace(r"(", r"\(").replace(r")", r"\)") -@jtu.with_config(jax_numpy_rank_promotion="raise") class PJitErrorTest(jtu.JaxTestCase): @check_1d_2d_mesh(set_mesh=True) def testNonDivisibleArgs(self, mesh, resources): @@ -1181,7 +1178,6 @@ def h(x): f(x) -@jtu.with_config(jax_numpy_rank_promotion="raise") class UtilTest(jtu.JaxTestCase): def testOpShardingRoundTrip(self): diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 0f60577b16f6..a00dde7a131f 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -109,7 +109,6 @@ def tearDownModule(): jtu.ignore_warning, message=".*is an experimental.*") -@jtu.with_config(jax_numpy_rank_promotion="raise") class PythonPmapTest(jtu.JaxTestCase): @property @@ -1905,7 +1904,6 @@ def pmap(self): return src_api._cpp_pmap -@jtu.with_config(jax_numpy_rank_promotion="raise") class VmapOfPmapTest(jtu.JaxTestCase): # TODO(apaszke) @@ -1948,7 +1946,6 @@ def args_slice(vi, pi): self.assertAllClose(ans, expected) -@jtu.with_config(jax_numpy_rank_promotion="raise") class VmapPmapCollectivesTest(jtu.JaxTestCase): @parameterized.named_parameters( @@ -2134,7 +2131,6 @@ def f(x): self.assertAllClose(f(jax.pmap)(x), f(jax.vmap)(x)) -@jtu.with_config(jax_numpy_rank_promotion="raise") class PmapWithDevicesTest(jtu.JaxTestCase): def testAllDevices(self): @@ -2387,7 +2383,6 @@ def h(y): jax.grad(mk_case(vmap))(x, y)) -@jtu.with_config(jax_numpy_rank_promotion="raise") class ShardedDeviceArrayTest(jtu.JaxTestCase): def testThreadsafeIndexing(self): @@ -2493,7 +2488,6 @@ def test_delete_is_idempotent(self): _ = x[0] -@jtu.with_config(jax_numpy_rank_promotion="raise") class SpecToIndicesTest(jtu.JaxTestCase): def testShardsPerAxis(self): @@ -2623,7 +2617,6 @@ def _spec_str(spec): f"{spec.mesh_mapping},)") -@jtu.with_config(jax_numpy_rank_promotion="raise") class ShardArgsTest(jtu.JaxTestCase): def numpy_array(x): diff --git a/tests/qdwh_test.py b/tests/qdwh_test.py index 6377ef2c1c49..5e2e35d627a2 100644 --- a/tests/qdwh_test.py +++ b/tests/qdwh_test.py @@ -58,7 +58,6 @@ def _compute_relative_diff(actual, expected): _dot = functools.partial(jnp.dot, precision="highest") -@jtu.with_config(jax_numpy_rank_promotion="raise") class QdwhTest(jtu.JaxTestCase): @parameterized.named_parameters(jtu.cases_from_list( diff --git a/tests/random_test.py b/tests/random_test.py index 68dbf53f2e08..c949a706c314 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -56,7 +56,6 @@ def _prng_key_as_array(key): ('unsafe_rbg', prng.unsafe_rbg_prng_impl)] -@jtu.with_config(jax_numpy_rank_promotion="raise") class PrngTest(jtu.JaxTestCase): def testThreefry2x32(self): @@ -315,7 +314,6 @@ def test_key_array_indexing_nd(self): lambda: keys[0, 1, None, 2]) -@jtu.with_config(jax_numpy_rank_promotion="raise") class LaxRandomTest(jtu.JaxTestCase): def _CheckCollisions(self, samples, nbits): @@ -1226,7 +1224,6 @@ def _double_threefry_fold_in(key, data): @skipIf(not config.jax_enable_custom_prng, 'custom PRNG tests require config.jax_enable_custom_prng') -@jtu.with_config(jax_numpy_rank_promotion="raise") class LaxRandomWithCustomPRNGTest(LaxRandomTest): def seed_prng(self, seed): return prng.seed_with_impl(double_threefry_prng_impl, seed) @@ -1255,7 +1252,6 @@ def test_grad_of_prng_key(self): @skipIf(not config.jax_enable_custom_prng, 'custom PRNG tests require config.jax_enable_custom_prng') -@jtu.with_config(jax_numpy_rank_promotion="raise") class LaxRandomWithRBGPRNGTest(LaxRandomTest): def seed_prng(self, seed): return random.rbg_key(seed) diff --git a/tests/scipy_fft_test.py b/tests/scipy_fft_test.py index a4ab47066226..31bc01e7173a 100644 --- a/tests/scipy_fft_test.py +++ b/tests/scipy_fft_test.py @@ -42,7 +42,7 @@ def _get_dctn_test_s(shape, axes): s_list.extend(itertools.product(*[[shape[ax]+i for i in range(-shape[ax]+1, shape[ax]+1)] for ax in axes])) return s_list -@jtu.with_config(jax_numpy_rank_promotion="raise") + class LaxBackedScipyFftTests(jtu.JaxTestCase): """Tests for LAX-backed scipy.fft implementations""" diff --git a/tests/scipy_ndimage_test.py b/tests/scipy_ndimage_test.py index 347e025c52b0..045146d6a052 100644 --- a/tests/scipy_ndimage_test.py +++ b/tests/scipy_ndimage_test.py @@ -57,7 +57,6 @@ def _fixed_ref_map_coordinates(input, coordinates, order, mode, cval=0.0): return result -@jtu.with_config(jax_numpy_rank_promotion="raise") class NdimageTest(jtu.JaxTestCase): @parameterized.named_parameters(jtu.cases_from_list( diff --git a/tests/scipy_optimize_test.py b/tests/scipy_optimize_test.py index a2207e8c2e05..d98329735310 100644 --- a/tests/scipy_optimize_test.py +++ b/tests/scipy_optimize_test.py @@ -64,7 +64,6 @@ def zakharovFromIndices(x, ii): return answer -@jtu.with_config(jax_numpy_rank_promotion="raise") class TestBFGS(jtu.JaxTestCase): @parameterized.named_parameters(jtu.cases_from_list( @@ -141,7 +140,6 @@ def f(x): jax.scipy.optimize.minimize(f, jnp.ones(2), args=45, method='BFGS') -@jtu.with_config(jax_numpy_rank_promotion="raise") class TestLBFGS(jtu.JaxTestCase): @parameterized.named_parameters(jtu.cases_from_list( diff --git a/tests/scipy_signal_test.py b/tests/scipy_signal_test.py index 4c9a80747114..0162a2915a96 100644 --- a/tests/scipy_signal_test.py +++ b/tests/scipy_signal_test.py @@ -35,7 +35,6 @@ default_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex -@jtu.with_config(jax_numpy_rank_promotion="raise") class LaxBackedScipySignalTests(jtu.JaxTestCase): """Tests for LAX-backed scipy.stats implementations""" diff --git a/tests/sharded_jit_test.py b/tests/sharded_jit_test.py index 37248e36f281..a83784e1d042 100644 --- a/tests/sharded_jit_test.py +++ b/tests/sharded_jit_test.py @@ -39,7 +39,6 @@ config.parse_flags_with_absl() -@jtu.with_config(jax_numpy_rank_promotion="raise") class ShardedJitTest(jtu.JaxTestCase): def setUp(self): @@ -277,7 +276,6 @@ def testCompilationCache(self): # TODO(skye): add more error tests -@jtu.with_config(jax_numpy_rank_promotion="raise") class ShardedJitErrorsTest(jtu.JaxTestCase): def setUp(self): @@ -300,7 +298,6 @@ def f(x): # Tests that don't need a TPU to run. -@jtu.with_config(jax_numpy_rank_promotion="raise") class ShardedJitTestNoTpu(jtu.JaxTestCase): def testTranslationRule(self): @@ -329,7 +326,7 @@ def f(x): # Annotation from sharded_jit self.assertIn("sharding={replicated}", hlo.as_hlo_text()) -@jtu.with_config(jax_numpy_rank_promotion="raise") + class PmapOfShardedJitTest(jtu.JaxTestCase): def setUp(self): diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 9c672d4531cc..72456ad6c830 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -116,7 +116,6 @@ def _rand_sparse(shape, dtype, nse=nse): return _rand_sparse -@jtu.with_config(jax_numpy_rank_promotion="raise") class cuSparseTest(jtu.JaxTestCase): def gpu_dense_conversion_warning_context(self, dtype): if jtu.device_under_test() == "gpu" and np.issubdtype(dtype, np.integer): @@ -555,7 +554,6 @@ def test_coo_matmul_ad(self, shape, dtype, bshape): self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol) -@jtu.with_config(jax_numpy_rank_promotion="raise") class BCOOTest(jtu.JaxTestCase): @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_nbatch={}_ndense={}".format( @@ -1679,7 +1677,6 @@ def test_bcoo_bad_fillvals(self): self.assertArraysEqual((y_sp @ x_sp).todense(), y_de @ x_de) -@jtu.with_config(jax_numpy_rank_promotion="raise") class SparseGradTest(jtu.JaxTestCase): def test_sparse_grad(self): rng_sparse = rand_sparse(self.rng()) @@ -1702,7 +1699,6 @@ def f(X, y): self.assertArraysEqual(grad_sparse.todense(), grad_sparse_from_dense) -@jtu.with_config(jax_numpy_rank_promotion="raise") class SparseObjectTest(jtu.JaxTestCase): def test_repr(self): M = sparse.BCOO.fromdense(jnp.arange(5, dtype='float32')) @@ -1898,7 +1894,6 @@ def test_bcoo_methods(self): self.assertArraysEqual(M.sum(), Msp.sum()) -@jtu.with_config(jax_numpy_rank_promotion="raise") class SparseRandomTest(jtu.JaxTestCase): @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_indices_dtype={}_nbatch={}_ndense={}".format( diff --git a/tests/sparsify_test.py b/tests/sparsify_test.py index b622d59dfb69..aa03ef129135 100644 --- a/tests/sparsify_test.py +++ b/tests/sparsify_test.py @@ -30,7 +30,6 @@ config.parse_flags_with_absl() -@jtu.with_config(jax_numpy_rank_promotion="raise") class SparsifyTest(jtu.JaxTestCase): @classmethod def sparsify(cls, f): diff --git a/tests/xmap_test.py b/tests/xmap_test.py index b196e9812374..f3ec3f40e8b3 100644 --- a/tests/xmap_test.py +++ b/tests/xmap_test.py @@ -210,7 +210,6 @@ def divisors2(n: int) -> Iterator[Tuple[int, int]]: yield axis_resources, mesh_data -@jtu.with_config(jax_numpy_rank_promotion="raise") class XMapTestCase(jtu.BufferDonationTestCase): pass @@ -1178,7 +1177,6 @@ def test_xeinsum_no_named_axes_reduce_sum(self): self.assertAllClose(out, expected, check_dtypes=True) -@jtu.with_config(jax_numpy_rank_promotion="raise") class XMapErrorTest(jtu.JaxTestCase): @jtu.with_mesh([('x', 2)]) @@ -1410,7 +1408,6 @@ def testAxesMismatch(self): xmap(lambda x: x, (p,), (p, ['x']))([x, x, x]) # Error, we raise a generic tree mismatch message -@jtu.with_config(jax_numpy_rank_promotion="raise") class NamedAutodiffTests(jtu.JaxTestCase): def testVjpReduceAxes(self):