diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py
index 0eca900dc54e..ba07ff344c85 100644
--- a/tests/array_interoperability_test.py
+++ b/tests/array_interoperability_test.py
@@ -74,7 +74,7 @@ def setUp(self):
   def testJaxRoundTrip(self, shape, dtype, take_ownership, gpu):
     rng = jtu.rand_default(self.rng())
     np = rng(shape, dtype)
-    if gpu and jax.default_backend() == "cpu":
+    if gpu and jtu.test_device_matches(["cpu"]):
       raise unittest.SkipTest("Skipping GPU test case on CPU")
     device = jax.devices("gpu" if gpu else "cpu")[0]
     x = jax.device_put(np, device)
@@ -180,7 +180,7 @@ def testNumpyToJax(self, shape, dtype):
     dtype=numpy_dtypes,
   )
   @unittest.skipIf(numpy_version < (1, 23, 0), "Requires numpy 1.23 or newer")
-  @jtu.skip_on_devices("gpu")  #NumPy only accepts cpu DLPacks
+  @jtu.run_on_devices("cpu")  # NumPy only accepts cpu DLPacks
   def testJaxToNumpy(self, shape, dtype):
     rng = jtu.rand_default(self.rng())
     x_jax = jnp.array(rng(shape, dtype))
@@ -192,7 +192,7 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase):

   def setUp(self):
     super().setUp()
-    if not jtu.test_device_matches(["gpu"]):
+    if not jtu.test_device_matches(["cuda"]):
       self.skipTest("__cuda_array_interface__ is only supported on GPU")

   @jtu.sample_product(
diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py
index 89fe88899575..42a0872737eb 100644
--- a/tests/debugging_primitives_test.py
+++ b/tests/debugging_primitives_test.py
@@ -98,10 +98,8 @@ def f(x):
     jax.effects_barrier()
     self.assertEqual(output(), "x: 2\n")

+  @jtu.device_supports_buffer_donation()
   def test_can_stage_out_debug_print_with_donate_argnums(self):
-    if jax.default_backend() not in {"gpu", "tpu"}:
-      raise unittest.SkipTest("Donate argnums not supported.")
-
     def f(x, y):
       debug_print('x: {x}', x=x)
       return x + y
@@ -120,10 +118,8 @@ def f(x):
     jax.effects_barrier()
     self.assertEqual(output(), "x: 2\n")

+  @jtu.device_supports_buffer_donation()
   def test_can_stage_out_ordered_print_with_donate_argnums(self):
-    if jax.default_backend() not in {"gpu", "tpu"}:
-      raise unittest.SkipTest("Donate argnums not supported.")
-
     def f(x, y):
       debug_print('x: {x}', x=x, ordered=True)
       return x + y
@@ -133,10 +129,8 @@ def f(x, y):
     jax.effects_barrier()
     self.assertEqual(output(), "x: 2\n")

+  @jtu.device_supports_buffer_donation()
   def test_can_stage_out_prints_with_donate_argnums(self):
-    if jax.default_backend() not in {"gpu", "tpu"}:
-      raise unittest.SkipTest("Donate argnums not supported.")
-
     def f(x, y):
       debug_print('x: {x}', x=x, ordered=True)
       debug_print('x: {x}', x=x)
diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py
index 51b1710ea27f..7940cf317969 100644
--- a/tests/lax_numpy_test.py
+++ b/tests/lax_numpy_test.py
@@ -1275,7 +1275,7 @@ def testTrimZeros(self, a_shape, dtype, trim):
   def testPoly(self, a_shape, dtype, rank):
     if dtype in (np.float16, jnp.bfloat16, np.int16):
       self.skipTest(f"{dtype} gets promoted to {np.float16}, which is not supported.")
-    elif rank == 2 and jtu.test_device_matches(["tpu", "gpu"]):
+    elif rank == 2 and not jtu.test_device_matches(["cpu"]):
       self.skipTest("Nonsymmetric eigendecomposition is only implemented on the CPU backend.")
     rng = jtu.rand_default(self.rng())
     tol = { np.int8: 2e-3, np.int32: 1e-3, np.float32: 1e-3, np.float64: 1e-6 }
@@ -1914,7 +1914,7 @@ def np_fun(x, y):
     xshape=one_dim_array_shapes,
     yshape=one_dim_array_shapes,
   )
-  @jtu.skip_on_devices("gpu", "tpu", "rocm")  # backends don't support all dtypes.
+ @jtu.skip_on_devices("cuda", "tpu", "rocm") # backends don't support all dtypes. def testConvolutionsPreferredElementType(self, xshape, yshape, dtype, mode, op): jnp_op = getattr(jnp, op) np_op = getattr(np, op) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 13abf2811049..9404b9fee11b 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -210,7 +210,7 @@ def testIssue1213(self): ) # TODO(phawkins): enable when there is an eigendecomposition implementation # for GPU/TPU. - @jtu.skip_on_devices("gpu", "tpu") + @jtu.run_on_devices("cpu") def testEig(self, shape, dtype, compute_left_eigenvectors, compute_right_eigenvectors): rng = jtu.rand_default(self.rng()) @@ -252,7 +252,7 @@ def check_left_eigenvectors(a, w, vl): ) # TODO(phawkins): enable when there is an eigendecomposition implementation # for GPU/TPU. - @jtu.skip_on_devices("gpu", "tpu") + @jtu.run_on_devices("cpu") def testEigvalsGrad(self, shape, dtype): # This test sometimes fails for large matrices. I (@j-towns) suspect, but # haven't checked, that might be because of perturbations causing the @@ -271,7 +271,7 @@ def testEigvalsGrad(self, shape, dtype): ) # TODO: enable when there is an eigendecomposition implementation # for GPU/TPU. - @jtu.skip_on_devices("gpu", "tpu") + @jtu.run_on_devices("cpu") def testEigvals(self, shape, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] @@ -280,7 +280,7 @@ def testEigvals(self, shape, dtype): w2 = jnp.linalg.eigvals(a) self.assertAllClose(w1, w2, rtol={np.complex64: 1e-5, np.complex128: 1e-14}) - @jtu.skip_on_devices("gpu", "tpu") + @jtu.run_on_devices("cpu") def testEigvalsInf(self): # https://github.com/google/jax/issues/2661 x = jnp.array([[jnp.inf]]) @@ -290,7 +290,7 @@ def testEigvalsInf(self): shape=[(1, 1), (4, 4), (5, 5)], dtype=float_types + complex_types, ) - @jtu.skip_on_devices("gpu", "tpu") + @jtu.run_on_devices("cpu") def testEigBatching(self, shape, dtype): rng = jtu.rand_default(self.rng()) shape = (10,) + shape @@ -688,7 +688,7 @@ def testNumpyQrModes(self, shape, dtype, mode): ) @jax.default_matmul_precision("float32") def testQr(self, shape, dtype, full_matrices): - if (jtu.test_device_matches(["gpu"]) and + if (jtu.test_device_matches(["cuda"]) and _is_required_cuda_version_satisfied(12000)): self.skipTest("Triggers a bug in cuda-12 b/287345077") rng = jtu.rand_default(self.rng()) @@ -1287,7 +1287,7 @@ def testTriangularSolveGradPrecision(self): dtype=int_types + float_types + complex_types ) def testExpm(self, n, batch_size, dtype): - if (jtu.test_device_matches(["gpu"]) and + if (jtu.test_device_matches(["cuda"]) and _is_required_cuda_version_satisfied(12000)): self.skipTest("Triggers a bug in cuda-12 b/287345077") @@ -1357,7 +1357,7 @@ def reference_fn(a, taus): dtype=float_types + complex_types, calc_q=[False, True], ) - @jtu.skip_on_devices("gpu", "tpu") + @jtu.run_on_devices("cpu") def testHessenberg(self, shape, dtype, calc_q): rng = jtu.rand_default(self.rng()) jsp_func = partial(jax.scipy.linalg.hessenberg, calc_q=calc_q) @@ -1514,7 +1514,7 @@ def expm(x): shape=[(4, 4), (15, 15), (50, 50), (100, 100)], dtype=float_types + complex_types, ) - @jtu.skip_on_devices("gpu", "tpu") + @jtu.run_on_devices("cpu") def testSchur(self, shape, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] @@ -1526,7 +1526,7 @@ def testSchur(self, shape, dtype): shape=[(1, 1), (4, 4), (15, 15), (50, 50), (100, 100)], dtype=float_types + complex_types, ) - @jtu.skip_on_devices("gpu", "tpu") + 
@jtu.run_on_devices("cpu") def testRsf2csf(self, shape, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype), rng(shape, dtype)] @@ -1542,7 +1542,7 @@ def testRsf2csf(self, shape, dtype): ) # funm uses jax.scipy.linalg.schur which is implemented for a CPU # backend only, so tests on GPU and TPU backends are skipped here - @jtu.skip_on_devices("gpu", "tpu") + @jtu.run_on_devices("cpu") def testFunm(self, shape, dtype, disp): def func(x): return x**-2.718 @@ -1558,7 +1558,7 @@ def func(x): shape=[(4, 4), (15, 15), (50, 50), (100, 100)], dtype=float_types + complex_types, ) - @jtu.skip_on_devices("gpu", "tpu") + @jtu.run_on_devices("cpu") def testSqrtmPSDMatrix(self, shape, dtype): # Checks against scipy.linalg.sqrtm when the principal square root # is guaranteed to be unique (i.e no negative real eigenvalue) @@ -1581,7 +1581,7 @@ def testSqrtmPSDMatrix(self, shape, dtype): shape=[(4, 4), (15, 15), (50, 50), (100, 100)], dtype=float_types + complex_types, ) - @jtu.skip_on_devices("gpu", "tpu") + @jtu.run_on_devices("cpu") def testSqrtmGenMatrix(self, shape, dtype): rng = jtu.rand_default(self.rng()) arg = rng(shape, dtype) @@ -1600,7 +1600,7 @@ def testSqrtmGenMatrix(self, shape, dtype): ], dtype=float_types + complex_types, ) - @jtu.skip_on_devices("gpu", "tpu") + @jtu.run_on_devices("cpu") def testSqrtmEdgeCase(self, diag, expected, dtype): """ Tests the zero numerator condition @@ -1773,7 +1773,7 @@ def test_tridiagonal_solve(self, dtype): shape=[(4, 4), (15, 15), (50, 50), (100, 100)], dtype=float_types + complex_types, ) - @jtu.skip_on_devices("gpu", "tpu") + @jtu.run_on_devices("cpu") def testSchur(self, shape, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] @@ -1785,7 +1785,7 @@ def testSchur(self, shape, dtype): shape=[(2, 2), (4, 4), (15, 15), (50, 50), (100, 100)], dtype=float_types + complex_types, ) - @jtu.skip_on_devices("gpu", "tpu") + @jtu.run_on_devices("cpu") def testSchurBatching(self, shape, dtype): rng = jtu.rand_default(self.rng()) batch_size = 10 diff --git a/tests/polynomial_test.py b/tests/polynomial_test.py index cdbf19743803..ccba4c2ef11f 100644 --- a/tests/polynomial_test.py +++ b/tests/polynomial_test.py @@ -72,7 +72,7 @@ def assertSetsAllClose(self, x, y, rtol=None, atol=None, check_dtypes=True): trailing=[0, 2], ) # TODO(phawkins): no nonsymmetric eigendecomposition implementation on GPU. - @jtu.skip_on_devices("gpu", "tpu") + @jtu.run_on_devices("cpu") def testRoots(self, dtype, length, leading, trailing): rng = jtu.rand_some_zero(self.rng()) @@ -98,7 +98,7 @@ def np_fun(arg): trailing=[0, 2], ) # TODO(phawkins): no nonsymmetric eigendecomposition implementation on GPU. 
- @jtu.skip_on_devices("gpu", "tpu") + @jtu.run_on_devices("cpu") def testRootsNoStrip(self, dtype, length, leading, trailing): rng = jtu.rand_some_zero(self.rng()) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 71f765117c90..17c55a7df2a3 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -1958,7 +1958,7 @@ def test_bcsr_matmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype): @jax.default_matmul_precision("float32") @jtu.ignore_warning(category=sparse.CuSparseEfficiencyWarning) def test_bcoo_matmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype): - if (jtu.test_device_matches(["gpu"]) and + if (jtu.test_device_matches(["cuda"]) and _is_required_cuda_version_satisfied(12000)): raise unittest.SkipTest("Triggers a bug in cuda-12 b/287344632") @@ -2777,8 +2777,7 @@ class SparseSolverTest(sptu.SparseTestCase): reorder=[0, 1, 2, 3], dtype=jtu.dtypes.floating + jtu.dtypes.complex, ) - @jtu.run_on_devices("cpu", "gpu") - @jtu.skip_on_devices("rocm") # test n gpu requires cusolver + @jtu.run_on_devices("cpu", "cuda") def test_sparse_qr_linear_solver(self, size, reorder, dtype): if jtu.test_device_matches(["cuda"]) and not GPU_LOWERING_ENABLED: raise unittest.SkipTest('test requires cusparse/cusolver') @@ -2805,8 +2804,7 @@ def sparse_solve(data, indices, indptr, b): size=[10, 20, 50], dtype=jtu.dtypes.floating, ) - @jtu.run_on_devices("cpu", "gpu") - @jtu.skip_on_devices("rocm") # test requires cusolver + @jtu.run_on_devices("cpu", "cuda") def test_sparse_qr_linear_solver_grads(self, size, dtype): if jtu.test_device_matches(["cuda"]) and not GPU_LOWERING_ENABLED: raise unittest.SkipTest('test requires cusparse/cusolver') diff --git a/tests/xmap_test.py b/tests/xmap_test.py index dc424c7fdbb4..76fa9c22e3d4 100644 --- a/tests/xmap_test.py +++ b/tests/xmap_test.py @@ -459,7 +459,7 @@ def f(x, y): self.assertAllClose(f_mapped(x, x), expected) @jtu.with_and_without_mesh - @jtu.run_on_devices("gpu", "tpu") # In/out aliasing not supported on CPU. + @jtu.device_supports_buffer_donation() # In/out aliasing not supported on CPU def testBufferDonation(self, mesh, axis_resources): shard = lambda x: x if axis_resources: @@ -476,7 +476,7 @@ def testBufferDonation(self, mesh, axis_resources): self.assertNotDeleted(y) self.assertDeleted(x) - @jtu.run_on_devices("gpu", "tpu") # In/out aliasing not supported on CPU. + @jtu.device_supports_buffer_donation() # In/out aliasing not supported on CPU @jtu.with_mesh([('x', 2)]) @jtu.ignore_warning(category=UserWarning, # SPMD test generates warning. message="Some donated buffers were not usable*")