Clean up some device opt-in/opt-outs in test suite.
Use allowlists rather than denylists in a few places.

PiperOrigin-RevId: 568968749
hawkinsp authored and jax authors committed Sep 27, 2023
1 parent 5384561 commit 6be860b
Showing 7 changed files with 31 additions and 39 deletions.
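
For context, the pattern this commit applies looks roughly like the sketch below. It is not part of the diff: the class name and test bodies are hypothetical, and it assumes the JAX test-suite helpers that appear in the changed files (jtu.run_on_devices, jtu.skip_on_devices, jtu.device_supports_buffer_donation, jtu.JaxTestCase). An allowlist decorator names the backends a test is known to support, a denylist enumerates the backends it must avoid, and a capability decorator replaces an explicit backend check.

from absl.testing import absltest

import jax
import jax.numpy as jnp
from jax._src import test_util as jtu


class DeviceOptInExampleTest(jtu.JaxTestCase):  # hypothetical example class

  # Denylist style: every backend that cannot run the test must be listed,
  # so a newly added backend runs the test by default.
  @jtu.skip_on_devices("gpu", "tpu")
  def test_cpu_only_denylist(self):
    self.assertArraysEqual(jnp.arange(3), jnp.arange(3))

  # Allowlist style (what this commit moves toward): name the backends the
  # test supports; anything else is skipped.
  @jtu.run_on_devices("cpu")
  def test_cpu_only_allowlist(self):
    self.assertArraysEqual(jnp.arange(3), jnp.arange(3))

  # Capability check instead of a backend list, as in
  # tests/debugging_primitives_test.py and tests/xmap_test.py below.
  @jtu.device_supports_buffer_donation()
  def test_needs_buffer_donation(self):
    f = jax.jit(lambda x: x + 1, donate_argnums=0)
    self.assertAllClose(f(jnp.arange(3.0)), jnp.arange(3.0) + 1)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())

The allowlist and capability forms stay correct when a new backend is added: unsupported backends are skipped by default instead of being run by accident.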
6 changes: 3 additions & 3 deletions tests/array_interoperability_test.py
@@ -74,7 +74,7 @@ def setUp(self):
def testJaxRoundTrip(self, shape, dtype, take_ownership, gpu):
rng = jtu.rand_default(self.rng())
np = rng(shape, dtype)
if gpu and jax.default_backend() == "cpu":
if gpu and jtu.test_device_matches(["cpu"]):
raise unittest.SkipTest("Skipping GPU test case on CPU")
device = jax.devices("gpu" if gpu else "cpu")[0]
x = jax.device_put(np, device)
@@ -180,7 +180,7 @@ def testNumpyToJax(self, shape, dtype):
dtype=numpy_dtypes,
)
@unittest.skipIf(numpy_version < (1, 23, 0), "Requires numpy 1.23 or newer")
@jtu.skip_on_devices("gpu") #NumPy only accepts cpu DLPacks
@jtu.run_on_devices("cpu") # NumPy only accepts cpu DLPacks
def testJaxToNumpy(self, shape, dtype):
rng = jtu.rand_default(self.rng())
x_jax = jnp.array(rng(shape, dtype))
@@ -192,7 +192,7 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase):

def setUp(self):
super().setUp()
if not jtu.test_device_matches(["gpu"]):
if not jtu.test_device_matches(["cuda"]):
self.skipTest("__cuda_array_interface__ is only supported on GPU")

@jtu.sample_product(
12 changes: 3 additions & 9 deletions tests/debugging_primitives_test.py
@@ -98,10 +98,8 @@ def f(x):
jax.effects_barrier()
self.assertEqual(output(), "x: 2\n")

@jtu.device_supports_buffer_donation()
def test_can_stage_out_debug_print_with_donate_argnums(self):
if jax.default_backend() not in {"gpu", "tpu"}:
raise unittest.SkipTest("Donate argnums not supported.")

def f(x, y):
debug_print('x: {x}', x=x)
return x + y
@@ -120,10 +118,8 @@ def f(x):
jax.effects_barrier()
self.assertEqual(output(), "x: 2\n")

@jtu.device_supports_buffer_donation()
def test_can_stage_out_ordered_print_with_donate_argnums(self):
if jax.default_backend() not in {"gpu", "tpu"}:
raise unittest.SkipTest("Donate argnums not supported.")

def f(x, y):
debug_print('x: {x}', x=x, ordered=True)
return x + y
@@ -133,10 +129,8 @@ def f(x, y):
jax.effects_barrier()
self.assertEqual(output(), "x: 2\n")

@jtu.device_supports_buffer_donation()
def test_can_stage_out_prints_with_donate_argnums(self):
if jax.default_backend() not in {"gpu", "tpu"}:
raise unittest.SkipTest("Donate argnums not supported.")

def f(x, y):
debug_print('x: {x}', x=x, ordered=True)
debug_print('x: {x}', x=x)
4 changes: 2 additions & 2 deletions tests/lax_numpy_test.py
@@ -1275,7 +1275,7 @@ def testTrimZeros(self, a_shape, dtype, trim):
def testPoly(self, a_shape, dtype, rank):
if dtype in (np.float16, jnp.bfloat16, np.int16):
self.skipTest(f"{dtype} gets promoted to {np.float16}, which is not supported.")
elif rank == 2 and jtu.test_device_matches(["tpu", "gpu"]):
elif rank == 2 and not jtu.test_device_matches(["cpu"]):
self.skipTest("Nonsymmetric eigendecomposition is only implemented on the CPU backend.")
rng = jtu.rand_default(self.rng())
tol = { np.int8: 2e-3, np.int32: 1e-3, np.float32: 1e-3, np.float64: 1e-6 }
@@ -1914,7 +1914,7 @@ def np_fun(x, y):
xshape=one_dim_array_shapes,
yshape=one_dim_array_shapes,
)
@jtu.skip_on_devices("gpu", "tpu", "rocm") # backends don't support all dtypes.
@jtu.skip_on_devices("cuda", "tpu", "rocm") # backends don't support all dtypes.
def testConvolutionsPreferredElementType(self, xshape, yshape, dtype, mode, op):
jnp_op = getattr(jnp, op)
np_op = getattr(np, op)
32 changes: 16 additions & 16 deletions tests/linalg_test.py
@@ -210,7 +210,7 @@ def testIssue1213(self):
)
# TODO(phawkins): enable when there is an eigendecomposition implementation
# for GPU/TPU.
@jtu.skip_on_devices("gpu", "tpu")
@jtu.run_on_devices("cpu")
def testEig(self, shape, dtype, compute_left_eigenvectors,
compute_right_eigenvectors):
rng = jtu.rand_default(self.rng())
@@ -252,7 +252,7 @@ def check_left_eigenvectors(a, w, vl):
)
# TODO(phawkins): enable when there is an eigendecomposition implementation
# for GPU/TPU.
@jtu.skip_on_devices("gpu", "tpu")
@jtu.run_on_devices("cpu")
def testEigvalsGrad(self, shape, dtype):
# This test sometimes fails for large matrices. I (@j-towns) suspect, but
# haven't checked, that might be because of perturbations causing the
@@ -271,7 +271,7 @@ def testEigvalsGrad(self, shape, dtype):
)
# TODO: enable when there is an eigendecomposition implementation
# for GPU/TPU.
@jtu.skip_on_devices("gpu", "tpu")
@jtu.run_on_devices("cpu")
def testEigvals(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
@@ -280,7 +280,7 @@ def testEigvals(self, shape, dtype):
w2 = jnp.linalg.eigvals(a)
self.assertAllClose(w1, w2, rtol={np.complex64: 1e-5, np.complex128: 1e-14})

@jtu.skip_on_devices("gpu", "tpu")
@jtu.run_on_devices("cpu")
def testEigvalsInf(self):
# https://github.com/google/jax/issues/2661
x = jnp.array([[jnp.inf]])
@@ -290,7 +290,7 @@ def testEigvalsInf(self):
shape=[(1, 1), (4, 4), (5, 5)],
dtype=float_types + complex_types,
)
@jtu.skip_on_devices("gpu", "tpu")
@jtu.run_on_devices("cpu")
def testEigBatching(self, shape, dtype):
rng = jtu.rand_default(self.rng())
shape = (10,) + shape
@@ -688,7 +688,7 @@ def testNumpyQrModes(self, shape, dtype, mode):
)
@jax.default_matmul_precision("float32")
def testQr(self, shape, dtype, full_matrices):
if (jtu.test_device_matches(["gpu"]) and
if (jtu.test_device_matches(["cuda"]) and
_is_required_cuda_version_satisfied(12000)):
self.skipTest("Triggers a bug in cuda-12 b/287345077")
rng = jtu.rand_default(self.rng())
@@ -1287,7 +1287,7 @@ def testTriangularSolveGradPrecision(self):
dtype=int_types + float_types + complex_types
)
def testExpm(self, n, batch_size, dtype):
if (jtu.test_device_matches(["gpu"]) and
if (jtu.test_device_matches(["cuda"]) and
_is_required_cuda_version_satisfied(12000)):
self.skipTest("Triggers a bug in cuda-12 b/287345077")

@@ -1357,7 +1357,7 @@ def reference_fn(a, taus):
dtype=float_types + complex_types,
calc_q=[False, True],
)
@jtu.skip_on_devices("gpu", "tpu")
@jtu.run_on_devices("cpu")
def testHessenberg(self, shape, dtype, calc_q):
rng = jtu.rand_default(self.rng())
jsp_func = partial(jax.scipy.linalg.hessenberg, calc_q=calc_q)
@@ -1514,7 +1514,7 @@ def expm(x):
shape=[(4, 4), (15, 15), (50, 50), (100, 100)],
dtype=float_types + complex_types,
)
@jtu.skip_on_devices("gpu", "tpu")
@jtu.run_on_devices("cpu")
def testSchur(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
@@ -1526,7 +1526,7 @@ def testSchur(self, shape, dtype):
shape=[(1, 1), (4, 4), (15, 15), (50, 50), (100, 100)],
dtype=float_types + complex_types,
)
@jtu.skip_on_devices("gpu", "tpu")
@jtu.run_on_devices("cpu")
def testRsf2csf(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype), rng(shape, dtype)]
@@ -1542,7 +1542,7 @@ def testRsf2csf(self, shape, dtype):
)
# funm uses jax.scipy.linalg.schur which is implemented for a CPU
# backend only, so tests on GPU and TPU backends are skipped here
@jtu.skip_on_devices("gpu", "tpu")
@jtu.run_on_devices("cpu")
def testFunm(self, shape, dtype, disp):
def func(x):
return x**-2.718
@@ -1558,7 +1558,7 @@ def func(x):
shape=[(4, 4), (15, 15), (50, 50), (100, 100)],
dtype=float_types + complex_types,
)
@jtu.skip_on_devices("gpu", "tpu")
@jtu.run_on_devices("cpu")
def testSqrtmPSDMatrix(self, shape, dtype):
# Checks against scipy.linalg.sqrtm when the principal square root
# is guaranteed to be unique (i.e no negative real eigenvalue)
@@ -1581,7 +1581,7 @@ def testSqrtmPSDMatrix(self, shape, dtype):
shape=[(4, 4), (15, 15), (50, 50), (100, 100)],
dtype=float_types + complex_types,
)
@jtu.skip_on_devices("gpu", "tpu")
@jtu.run_on_devices("cpu")
def testSqrtmGenMatrix(self, shape, dtype):
rng = jtu.rand_default(self.rng())
arg = rng(shape, dtype)
@@ -1600,7 +1600,7 @@ def testSqrtmGenMatrix(self, shape, dtype):
],
dtype=float_types + complex_types,
)
@jtu.skip_on_devices("gpu", "tpu")
@jtu.run_on_devices("cpu")
def testSqrtmEdgeCase(self, diag, expected, dtype):
"""
Tests the zero numerator condition
@@ -1773,7 +1773,7 @@ def test_tridiagonal_solve(self, dtype):
shape=[(4, 4), (15, 15), (50, 50), (100, 100)],
dtype=float_types + complex_types,
)
@jtu.skip_on_devices("gpu", "tpu")
@jtu.run_on_devices("cpu")
def testSchur(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
@@ -1785,7 +1785,7 @@ def testSchur(self, shape, dtype):
shape=[(2, 2), (4, 4), (15, 15), (50, 50), (100, 100)],
dtype=float_types + complex_types,
)
@jtu.skip_on_devices("gpu", "tpu")
@jtu.run_on_devices("cpu")
def testSchurBatching(self, shape, dtype):
rng = jtu.rand_default(self.rng())
batch_size = 10
4 changes: 2 additions & 2 deletions tests/polynomial_test.py
@@ -72,7 +72,7 @@ def assertSetsAllClose(self, x, y, rtol=None, atol=None, check_dtypes=True):
trailing=[0, 2],
)
# TODO(phawkins): no nonsymmetric eigendecomposition implementation on GPU.
@jtu.skip_on_devices("gpu", "tpu")
@jtu.run_on_devices("cpu")
def testRoots(self, dtype, length, leading, trailing):
rng = jtu.rand_some_zero(self.rng())

@@ -98,7 +98,7 @@ def np_fun(arg):
trailing=[0, 2],
)
# TODO(phawkins): no nonsymmetric eigendecomposition implementation on GPU.
@jtu.skip_on_devices("gpu", "tpu")
@jtu.run_on_devices("cpu")
def testRootsNoStrip(self, dtype, length, leading, trailing):
rng = jtu.rand_some_zero(self.rng())

8 changes: 3 additions & 5 deletions tests/sparse_test.py
@@ -1958,7 +1958,7 @@ def test_bcsr_matmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
@jax.default_matmul_precision("float32")
@jtu.ignore_warning(category=sparse.CuSparseEfficiencyWarning)
def test_bcoo_matmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
if (jtu.test_device_matches(["gpu"]) and
if (jtu.test_device_matches(["cuda"]) and
_is_required_cuda_version_satisfied(12000)):
raise unittest.SkipTest("Triggers a bug in cuda-12 b/287344632")

@@ -2777,8 +2777,7 @@ class SparseSolverTest(sptu.SparseTestCase):
reorder=[0, 1, 2, 3],
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
)
@jtu.run_on_devices("cpu", "gpu")
@jtu.skip_on_devices("rocm") # test on gpu requires cusolver
@jtu.run_on_devices("cpu", "cuda")
def test_sparse_qr_linear_solver(self, size, reorder, dtype):
if jtu.test_device_matches(["cuda"]) and not GPU_LOWERING_ENABLED:
raise unittest.SkipTest('test requires cusparse/cusolver')
@@ -2805,8 +2804,7 @@ def sparse_solve(data, indices, indptr, b):
size=[10, 20, 50],
dtype=jtu.dtypes.floating,
)
@jtu.run_on_devices("cpu", "gpu")
@jtu.skip_on_devices("rocm") # test requires cusolver
@jtu.run_on_devices("cpu", "cuda")
def test_sparse_qr_linear_solver_grads(self, size, dtype):
if jtu.test_device_matches(["cuda"]) and not GPU_LOWERING_ENABLED:
raise unittest.SkipTest('test requires cusparse/cusolver')
4 changes: 2 additions & 2 deletions tests/xmap_test.py
@@ -459,7 +459,7 @@ def f(x, y):
self.assertAllClose(f_mapped(x, x), expected)

@jtu.with_and_without_mesh
@jtu.run_on_devices("gpu", "tpu") # In/out aliasing not supported on CPU.
@jtu.device_supports_buffer_donation() # In/out aliasing not supported on CPU
def testBufferDonation(self, mesh, axis_resources):
shard = lambda x: x
if axis_resources:
@@ -476,7 +476,7 @@ def testBufferDonation(self, mesh, axis_resources):
self.assertNotDeleted(y)
self.assertDeleted(x)

@jtu.run_on_devices("gpu", "tpu") # In/out aliasing not supported on CPU.
@jtu.device_supports_buffer_donation() # In/out aliasing not supported on CPU
@jtu.with_mesh([('x', 2)])
@jtu.ignore_warning(category=UserWarning, # SPMD test generates warning.
message="Some donated buffers were not usable*")
