Skip to content

Commit

Permalink
Merge pull request #13616 from jakevdp:fix-sparse-error
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 494758906
  • Loading branch information
jax authors committed Dec 12, 2022
2 parents 4a9e9d5 + 2e95990 commit b868cf7
Showing 1 changed file with 21 additions and 28 deletions.
49 changes: 21 additions & 28 deletions tests/sparse_test.py
Expand Up @@ -151,7 +151,7 @@ def gpu_dense_conversion_warning_context(self, dtype):
return self.assertWarns(sparse.CuSparseEfficiencyWarning)
return contextlib.nullcontext()

def gpu_matmul_warning_context(self, dtype):
def gpu_matmul_dtype_warning_context(self, dtype):
if jtu.device_under_test() == "gpu" and dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
return self.assertWarns(sparse.CuSparseEfficiencyWarning)
return contextlib.nullcontext()
Expand Down Expand Up @@ -318,7 +318,7 @@ def test_csr_matvec(self, shape, dtype, transpose):
matvec = lambda *args: sparse.csr_matvec(*args, shape=M.shape, transpose=transpose)

self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
with self.gpu_matmul_warning_context(dtype):
with self.gpu_matmul_dtype_warning_context(dtype):
self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)

@jtu.sample_product(
Expand All @@ -338,7 +338,7 @@ def test_csr_matmat(self, shape, dtype, transpose):
matmat = lambda *args: sparse.csr_matmat(*args, shape=shape, transpose=transpose)

self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
with self.gpu_matmul_warning_context(dtype):
with self.gpu_matmul_dtype_warning_context(dtype):
self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)

@jtu.sample_product(
Expand Down Expand Up @@ -397,7 +397,7 @@ def test_coo_matvec(self, shape, dtype, transpose):
matvec = lambda *args: sparse_coo._coo_matvec(*args, spinfo=sparse_coo.COOInfo(shape=M.shape, rows_sorted=True), transpose=transpose)

self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
with self.gpu_matmul_warning_context(dtype):
with self.gpu_matmul_dtype_warning_context(dtype):
self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)

@jtu.sample_product(
Expand All @@ -418,7 +418,7 @@ def test_coo_matmat(self, shape, dtype, transpose):
matmat = lambda *args: sparse_coo._coo_matmat(*args, spinfo=sparse_coo.COOInfo(shape=shape, rows_sorted=True), transpose=transpose)

self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
with self.gpu_matmul_warning_context(dtype):
with self.gpu_matmul_dtype_warning_context(dtype):
self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)

def test_coo_matmat_layout(self):
Expand Down Expand Up @@ -654,6 +654,11 @@ def test_coo_matmul_ad(self, shape, dtype, bshape):

class BCOOTest(sptu.SparseTestCase):

def gpu_matmul_warning_context(self, msg):
if GPU_LOWERING_ENABLED and config.jax_bcoo_cusparse_lowering:
return self.assertWarnsRegex(sparse.CuSparseEfficiencyWarning, msg)
return contextlib.nullcontext()

def test_vmappable(self):
"""Test does not depend on batching rules of BCOO primitives."""
M = jnp.arange(9).reshape((3, 3))
Expand Down Expand Up @@ -1166,10 +1171,8 @@ def f_sparse(lhs_bcoo, lhs, rhs):
else:
lhs_bcoo, lhs, rhs = args_maker()
matmat_expected = f_dense(lhs_bcoo, lhs, rhs)
with self.assertWarnsRegex(
sparse.CuSparseEfficiencyWarning,
"bcoo_dot_general GPU lowering currently does not support this "
"batch-mode computation.*"):
with self.gpu_matmul_warning_context(
"bcoo_dot_general GPU lowering currently does not support this batch-mode computation.*"):
matmat_default_lowering_fallback = jit(f_sparse)(lhs_bcoo, lhs, rhs)
self.assertAllClose(matmat_expected, matmat_default_lowering_fallback,
atol=1E-6, rtol=1E-6)
Expand Down Expand Up @@ -1204,13 +1207,9 @@ def test_bcoo_batched_matmat_default_lowering(
sp_matmat = jit(partial(sparse_bcoo.bcoo_dot_general,
dimension_numbers=dimension_numbers))

if config.jax_bcoo_cusparse_lowering:
with self.assertWarnsRegex(
sparse.CuSparseEfficiencyWarning,
"bcoo_dot_general GPU lowering currently does not support this "
"batch-mode computation.*"):
matmat_default_lowering_fallback = sp_matmat(lhs_bcoo, rhs)

with self.gpu_matmul_warning_context(
"bcoo_dot_general GPU lowering currently does not support this batch-mode computation.*"):
matmat_default_lowering_fallback = sp_matmat(lhs_bcoo, rhs)
self.assertArraysEqual(matmat_expected, matmat_default_lowering_fallback)

@unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse")
Expand All @@ -1236,14 +1235,11 @@ def test_bcoo_dot_general_oob_and_unsorted_indices_cusparse(self):

matmat_expected = lax.dot_general(lhs_mat_dense, rhs,
dimension_numbers=dimension_numbers_2d)
if config.jax_bcoo_cusparse_lowering:
with self.assertWarnsRegex(
sparse.CuSparseEfficiencyWarning,
with self.subTest(msg="2D"):
with self.gpu_matmul_warning_context(
"bcoo_dot_general GPU lowering requires matrices with sorted indices*"):
matmat_unsorted_fallback = sp_matmat(lhs_mat_bcoo_unsorted, rhs)

with self.subTest(msg="2D"):
self.assertArraysEqual(matmat_expected, matmat_unsorted_fallback)
self.assertArraysEqual(matmat_expected, matmat_unsorted_fallback)

lhs_vec_dense = jnp.array([0, 1, 0, 2, 0], dtype=jnp.float32)
lhs_vec_bcoo = sparse.BCOO.fromdense(lhs_vec_dense, nse=5)
Expand All @@ -1260,14 +1256,11 @@ def test_bcoo_dot_general_oob_and_unsorted_indices_cusparse(self):
vecmat_expected = lax.dot_general(lhs_vec_dense, rhs,
dimension_numbers=dimension_numbers_1d)

if config.jax_bcoo_cusparse_lowering:
with self.assertWarnsRegex(
sparse.CuSparseEfficiencyWarning,
with self.subTest(msg="1D"):
with self.gpu_matmul_warning_context(
"bcoo_dot_general GPU lowering requires matrices with sorted indices*"):
vecmat_unsorted_fallback = sp_vecmat(lhs_vec_bcoo_unsorted, rhs)

with self.subTest(msg="1D"):
self.assertArraysEqual(vecmat_expected, vecmat_unsorted_fallback)
self.assertArraysEqual(vecmat_expected, vecmat_unsorted_fallback)

@jtu.sample_product(
props=_generate_bcoo_dot_general_properties(
Expand Down

0 comments on commit b868cf7

Please sign in to comment.