diff --git a/jax/experimental/sparse/coo.py b/jax/experimental/sparse/coo.py index edc422ffcb91..0a205c9eca97 100644 --- a/jax/experimental/sparse/coo.py +++ b/jax/experimental/sparse/coo.py @@ -28,6 +28,7 @@ from jax.experimental.sparse.util import _coo_extract, _safe_asarray, CuSparseEfficiencyWarning from jax import tree_util from jax._src.numpy.lax_numpy import _promote_dtypes +from jax._src.lib import xla_client import jax.numpy as jnp try: @@ -40,6 +41,7 @@ except ImportError: hipsparse = None +xops = xla_client.ops Dtype = Any Shape = Tuple[int, ...] @@ -47,6 +49,7 @@ class COOInfo(NamedTuple): shape: Shape rows_sorted: bool = False + cols_sorted: bool = False @tree_util.register_pytree_node_class @@ -58,13 +61,16 @@ class COO(JAXSparse): shape: Tuple[int, int] nse = property(lambda self: self.data.size) dtype = property(lambda self: self.data.dtype) - _info = property(lambda self: COOInfo(self.shape, self._rows_sorted)) + _info = property(lambda self: COOInfo( + shape=self.shape, rows_sorted=self._rows_sorted, cols_sorted=self._cols_sorted)) _bufs = property(lambda self: (self.data, self.row, self.col)) _rows_sorted: bool + _cols_sorted: bool - def __init__(self, args, *, shape, rows_sorted=False): + def __init__(self, args, *, shape, rows_sorted=False, cols_sorted=False): self.data, self.row, self.col = _safe_asarray(args) self._rows_sorted = rows_sorted + self._cols_sorted = cols_sorted super().__init__(args, shape=shape) @classmethod @@ -90,7 +96,7 @@ def _empty(cls, shape, *, dtype=None, index_dtype='int32'): raise ValueError(f"COO must have ndim=2; got shape={shape}") data = jnp.empty(0, dtype) row = col = jnp.empty(0, index_dtype) - return cls((data, row, col), shape=shape) + return cls((data, row, col), shape=shape, rows_sorted=True, cols_sorted=True) def todense(self): return coo_todense(self) @@ -98,7 +104,8 @@ def todense(self): def transpose(self, axes=None): if axes is not None: raise NotImplementedError("axes argument to transpose()") - return 
COO((self.data, self.col, self.row), shape=self.shape[::-1], rows_sorted=False) + return COO((self.data, self.col, self.row), shape=self.shape[::-1], + rows_sorted=self._cols_sorted, cols_sorted=self._rows_sorted) def tree_flatten(self): return (self.data, self.row, self.col), self._info._asdict() @@ -164,17 +171,27 @@ def _coo_todense_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col, "Falling back to default implementation.", CuSparseEfficiencyWarning) return _coo_todense_translation_rule(ctx, avals_in, avals_out, data, row, col, spinfo=spinfo) - if not spinfo.rows_sorted: - warnings.warn("coo_todense GPU lowering requires matrices with sorted rows. To sort the rows " - "in your matrix, use e.g. mat = mat._sort_rows(). Falling back to the default " - "implementation.", CuSparseEfficiencyWarning) + + if spinfo.rows_sorted: + shape = spinfo.shape + transpose = False + elif spinfo.cols_sorted: + row, col = col, row + transpose = True + shape = spinfo.shape[::-1] + else: + warnings.warn("coo_todense GPU lowering requires matrices with sorted rows or sorted cols. " + "To sort the rows in your matrix, use e.g. mat = mat._sort_rows(). 
Falling " + "back to the default implementation.", CuSparseEfficiencyWarning) return _coo_todense_translation_rule(ctx, avals_in, avals_out, data, row, col, spinfo=spinfo) if cusparse is not None: - return [cusparse.coo_todense(ctx.builder, data, row, col, shape=spinfo.shape)] + result = cusparse.coo_todense(ctx.builder, data, row, col, shape=shape) else: - return [hipsparse.coo_todense(ctx.builder, data, row, col, shape=spinfo.shape)] + result = hipsparse.coo_todense(ctx.builder, data, row, col, shape=shape) + + return [xops.Transpose(result, (1, 0))] if transpose else [result] def _coo_todense_jvp(data_dot, data, row, col, *, spinfo): return _coo_todense(data_dot, row, col, spinfo=spinfo) @@ -380,19 +397,26 @@ def _coo_matvec_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col, "Falling back to default implementation.", CuSparseEfficiencyWarning) return _coo_matvec_translation_rule(ctx, avals_in, avals_out, data, row, col, v, spinfo=spinfo, transpose=transpose) - if not spinfo.rows_sorted: - warnings.warn("coo_matvec GPU lowering requires matrices with sorted rows. To sort the rows " - "in your matrix, use e.g. mat = mat._sort_rows(). Falling back to the default " - "implementation.", CuSparseEfficiencyWarning) + + if spinfo.rows_sorted: + shape = spinfo.shape + elif spinfo.cols_sorted: + row, col = col, row + transpose = not transpose + shape = spinfo.shape[::-1] + else: + warnings.warn("coo_matvec GPU lowering requires matrices with sorted rows or sorted cols. " + "To sort the rows in your matrix, use e.g. mat = mat._sort_rows(). 
Falling " + "back to the default implementation.", CuSparseEfficiencyWarning) return _coo_matvec_translation_rule(ctx, avals_in, avals_out, data, row, col, v, spinfo=spinfo, transpose=transpose) if cusparse is not None: - return [cusparse.coo_matvec(ctx.builder, data, row, col, v, shape=spinfo.shape, + return [cusparse.coo_matvec(ctx.builder, data, row, col, v, shape=shape, transpose=transpose)] else: - return [hipsparse.coo_matvec(ctx.builder, data, row, col, v, shape=spinfo.shape, - transpose=transpose)] + return [hipsparse.coo_matvec(ctx.builder, data, row, col, v, shape=shape, + transpose=transpose)] def _coo_matvec_jvp_mat(data_dot, data, row, col, v, *, spinfo, transpose): return _coo_matvec(data_dot, row, col, v, spinfo=spinfo, transpose=transpose) @@ -489,19 +513,25 @@ def _coo_matmat_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col, "Falling back to default implementation.", CuSparseEfficiencyWarning) return _coo_matmat_translation_rule(ctx, avals_in, avals_out, data, row, col, B, spinfo=spinfo, transpose=transpose) - if not spinfo.rows_sorted: - warnings.warn("coo_matmat GPU lowering requires matrices with sorted rows. To sort the rows " - "in your matrix, use e.g. mat = mat._sort_rows(). Falling back to the default " - "implementation.", CuSparseEfficiencyWarning) + if spinfo.rows_sorted: + shape = spinfo.shape + elif spinfo.cols_sorted: + row, col = col, row + transpose = not transpose + shape = spinfo.shape[::-1] + else: + warnings.warn("coo_matmat GPU lowering requires matrices with sorted rows or sorted cols. " + "To sort the rows in your matrix, use e.g. mat = mat._sort_rows(). 
Falling " + "back to the default implementation.", CuSparseEfficiencyWarning) return _coo_matmat_translation_rule(ctx, avals_in, avals_out, data, row, col, B, spinfo=spinfo, transpose=transpose) if cusparse is not None: - return [cusparse.coo_matmat(ctx.builder, data, row, col, B, shape=spinfo.shape, + return [cusparse.coo_matmat(ctx.builder, data, row, col, B, shape=shape, transpose=transpose)] else: - return [hipsparse.coo_matmat(ctx.builder, data, row, col, B, shape=spinfo.shape, - transpose=transpose)] + return [hipsparse.coo_matmat(ctx.builder, data, row, col, B, shape=shape, + transpose=transpose)] def _coo_matmat_jvp_left(data_dot, data, row, col, B, *, spinfo, transpose): return _coo_matmat(data_dot, row, col, B, spinfo=spinfo, transpose=transpose) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index e001a630156b..15fd97d10af1 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -19,6 +19,7 @@ import random import unittest from typing import NamedTuple, Tuple +import warnings from absl.testing import absltest from absl.testing import parameterized @@ -132,6 +133,12 @@ def gpu_matmul_warning_context(self, dtype): return self.assertWarns(sparse.CuSparseEfficiencyWarning) return contextlib.nullcontext() + @contextlib.contextmanager + def assertNoWarnings(self): + with warnings.catch_warnings(record=True) as caught_warnings: + yield + self.assertEmpty(caught_warnings) + @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)), "shape": shape, "dtype": dtype} @@ -435,36 +442,66 @@ def test_coo_sorted_indices(self): self.assertArraysEqual(mat.todense(), mat_resorted.todense()) @unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse") - def test_coo_sorted_indices_gpu_warnings(self): + def test_coo_sorted_indices_gpu_lowerings(self): dtype = jnp.float32 - mat_sorted = sparse.COO.fromdense(jnp.arange(9, dtype=dtype).reshape(3, 3)) - 
self.assertTrue(mat_sorted._rows_sorted) + mat = jnp.arange(12, dtype=dtype).reshape(4, 3) + + mat_rows_sorted = sparse.COO.fromdense(mat) + self.assertTrue(mat_rows_sorted._rows_sorted) + self.assertFalse(mat_rows_sorted._cols_sorted) + + mat_cols_sorted = sparse.COO.fromdense(mat.T).T + self.assertFalse(mat_cols_sorted._rows_sorted) + self.assertTrue(mat_cols_sorted._cols_sorted) - mat_unsorted = mat_sorted.T + mat_unsorted = sparse.COO(mat_rows_sorted._bufs, shape=mat_rows_sorted.shape) self.assertFalse(mat_unsorted._rows_sorted) + self.assertFalse(mat_unsorted._cols_sorted) - self.assertArraysEqual(mat_sorted.todense().T, mat_unsorted._sort_rows().todense()) + self.assertArraysEqual(mat, mat_rows_sorted._sort_rows().todense()) + self.assertArraysEqual(mat, mat_cols_sorted._sort_rows().todense()) + self.assertArraysEqual(mat, mat_unsorted._sort_rows().todense()) todense = jit(sparse.coo_todense) - todense(mat_sorted) - todense(mat_unsorted._sort_rows()) + with self.assertNoWarnings(): + dense_rows_sorted = todense(mat_rows_sorted) + dense_cols_sorted = todense(mat_cols_sorted) + dense_unsorted = todense(mat_unsorted._sort_rows()) with self.assertWarnsRegex(sparse.CuSparseEfficiencyWarning, "coo_todense GPU lowering requires matrices with sorted rows.*"): - todense(mat_unsorted) + dense_unsorted_fallback = todense(mat_unsorted) + self.assertArraysEqual(mat, dense_rows_sorted) + self.assertArraysEqual(mat, dense_cols_sorted) + self.assertArraysEqual(mat, dense_unsorted) + self.assertArraysEqual(mat, dense_unsorted_fallback) - lhs_vec = jnp.arange(3, dtype=dtype) + rhs_vec = jnp.arange(3, dtype=dtype) matvec = jit(sparse.coo_matvec) - matvec(mat_sorted, lhs_vec) - matvec(mat_unsorted._sort_rows(), lhs_vec) + matvec_expected = mat @ rhs_vec + with self.assertNoWarnings(): + matvec_rows_sorted = matvec(mat_rows_sorted, rhs_vec) + matvec_cols_sorted = matvec(mat_cols_sorted, rhs_vec) + matvec_unsorted = matvec(mat_unsorted._sort_rows(), rhs_vec) with 
self.assertWarnsRegex(sparse.CuSparseEfficiencyWarning, "coo_matvec GPU lowering requires matrices with sorted rows.*"): - matvec(mat_unsorted, lhs_vec) + matvec_unsorted_fallback = matvec(mat_unsorted, rhs_vec) + self.assertArraysEqual(matvec_expected, matvec_rows_sorted) + self.assertArraysEqual(matvec_expected, matvec_cols_sorted) + self.assertArraysEqual(matvec_expected, matvec_unsorted) + self.assertArraysEqual(matvec_expected, matvec_unsorted_fallback) - lhs_mat = jnp.arange(6, dtype=dtype).reshape(3, 2) + rhs_mat = jnp.arange(6, dtype=dtype).reshape(3, 2) matmat = jit(sparse.coo_matmat) - matmat(mat_sorted, lhs_mat) - matmat(mat_unsorted._sort_rows(), lhs_mat) + matmat_expected = mat @ rhs_mat + with self.assertNoWarnings(): + matmat_rows_sorted = matmat(mat_rows_sorted, rhs_mat) + matmat_cols_sorted = matmat(mat_cols_sorted, rhs_mat) + matmat_unsorted = matmat(mat_unsorted._sort_rows(), rhs_mat) with self.assertWarnsRegex(sparse.CuSparseEfficiencyWarning, "coo_matmat GPU lowering requires matrices with sorted rows.*"): - matmat(mat_unsorted, lhs_mat) + matmat_unsorted_fallback = matmat(mat_unsorted, rhs_mat) + self.assertArraysEqual(matmat_expected, matmat_rows_sorted) + self.assertArraysEqual(matmat_expected, matmat_cols_sorted) + self.assertArraysEqual(matmat_expected, matmat_unsorted) + self.assertArraysEqual(matmat_expected, matmat_unsorted_fallback) @unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU") def test_gpu_translation_rule(self):