[sparse]: COO: check for sorted rows before cusparse lowering
jakevdp committed Mar 8, 2022
1 parent 809156c commit 43c3bfd
Showing 2 changed files with 89 additions and 9 deletions.
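
In short: COO gains a rows_sorted flag, COO.fromdense marks its output as row-sorted, transpose() clears the flag, and the cusparse GPU lowerings for coo_todense, coo_matvec, and coo_matmat now warn and fall back to the default lowering when the flag is unset. A minimal sketch of the resulting behavior (illustrative only; _rows_sorted and _sort_rows are the private APIs introduced by this commit):

import jax.numpy as jnp
from jax.experimental import sparse

mat = sparse.COO.fromdense(jnp.arange(9.0).reshape(3, 3))
mat._rows_sorted                 # True: fromdense emits entries in row order
mat.T._rows_sorted               # False: transpose swaps row and col indices
mat.T._sort_rows()._rows_sorted  # True: _sort_rows re-sorts the entries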
44 changes: 40 additions & 4 deletions jax/experimental/sparse/coo.py
@@ -21,6 +21,7 @@
import numpy as np

from jax import core
from jax import lax
from jax.interpreters import ad
from jax.interpreters import xla
from jax.experimental.sparse._base import JAXSparse
@@ -45,6 +46,7 @@

class COOInfo(NamedTuple):
shape: Shape
rows_sorted: bool = False


@tree_util.register_pytree_node_class
@@ -56,17 +58,30 @@ class COO(JAXSparse):
shape: Tuple[int, int]
nse = property(lambda self: self.data.size)
dtype = property(lambda self: self.data.dtype)
_info = property(lambda self: COOInfo(self.shape))
_info = property(lambda self: COOInfo(self.shape, self._rows_sorted))
_bufs = property(lambda self: (self.data, self.row, self.col))
_rows_sorted: bool

def __init__(self, args, *, shape):
def __init__(self, args, *, shape, rows_sorted=False):
self.data, self.row, self.col = _safe_asarray(args)
self._rows_sorted = rows_sorted
super().__init__(args, shape=shape)

@classmethod
def fromdense(cls, mat, *, nse=None, index_dtype=np.int32):
return coo_fromdense(mat, nse=nse, index_dtype=index_dtype)

def _sort_rows(self):
"""Return a copy of the COO matrix with sorted rows.
If self._rows_sorted is True, this returns ``self`` without a copy.
"""
# TODO(jakevdp): would we benefit from lowering this to a cusparse sort_rows utility?
if self._rows_sorted:
return self
row, col, data = lax.sort((self.row, self.col, self.data), num_keys=1)
return self.__class__((data, row, col), shape=self.shape, rows_sorted=True)

@classmethod
def _empty(cls, shape, *, dtype=None, index_dtype='int32'):
"""Create an empty COO instance. Public method is sparse.empty()."""
@@ -83,7 +98,7 @@ def todense(self):
def transpose(self, axes=None):
if axes is not None:
raise NotImplementedError("axes argument to transpose()")
return COO((self.data, self.col, self.row), shape=self.shape[::-1])
return COO((self.data, self.col, self.row), shape=self.shape[::-1], rows_sorted=False)

def tree_flatten(self):
return (self.data, self.row, self.col), self._info._asdict()
@@ -149,6 +164,13 @@ def _coo_todense_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _coo_todense_translation_rule(ctx, avals_in, avals_out, data, row, col,
spinfo=spinfo)
if not spinfo.rows_sorted:
warnings.warn("coo_todense GPU lowering requires matrices with sorted rows. To sort the rows "
"in your matrix, use e.g. mat = mat._sort_rows(). Falling back to the default "
"implementation.", CuSparseEfficiencyWarning)
return _coo_todense_translation_rule(ctx, avals_in, avals_out, data, row, col,
spinfo=spinfo)

if cusparse is not None:
return [cusparse.coo_todense(ctx.builder, data, row, col, shape=spinfo.shape)]
else:
@@ -196,7 +218,7 @@ def coo_fromdense(mat, *, nse=None, index_dtype=jnp.int32):
if nse is None:
nse = (mat != 0).sum()
nse = core.concrete_or_error(operator.index, nse, "coo_fromdense nse argument")
return COO(_coo_fromdense(mat, nse=nse, index_dtype=index_dtype), shape=mat.shape)
return COO(_coo_fromdense(mat, nse=nse, index_dtype=index_dtype), shape=mat.shape, rows_sorted=True)

def _coo_fromdense(mat, *, nse, index_dtype=jnp.int32):
"""Create COO-format sparse matrix from a dense matrix.
@@ -358,6 +380,13 @@ def _coo_matvec_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _coo_matvec_translation_rule(ctx, avals_in, avals_out, data, row, col, v,
spinfo=spinfo, transpose=transpose)
if not spinfo.rows_sorted:
warnings.warn("coo_matvec GPU lowering requires matrices with sorted rows. To sort the rows "
"in your matrix, use e.g. mat = mat._sort_rows(). Falling back to the default "
"implementation.", CuSparseEfficiencyWarning)
return _coo_matvec_translation_rule(ctx, avals_in, avals_out, data, row, col, v,
spinfo=spinfo, transpose=transpose)

if cusparse is not None:
return [cusparse.coo_matvec(ctx.builder, data, row, col, v, shape=spinfo.shape,
transpose=transpose)]
@@ -460,6 +489,13 @@ def _coo_matmat_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _coo_matmat_translation_rule(ctx, avals_in, avals_out, data, row, col, B,
spinfo=spinfo, transpose=transpose)
if not spinfo.rows_sorted:
warnings.warn("coo_matmat GPU lowering requires matrices with sorted rows. To sort the rows "
"in your matrix, use e.g. mat = mat._sort_rows(). Falling back to the default "
"implementation.", CuSparseEfficiencyWarning)
return _coo_matmat_translation_rule(ctx, avals_in, avals_out, data, row, col, B,
spinfo=spinfo, transpose=transpose)

if cusparse is not None:
return [cusparse.coo_matmat(ctx.builder, data, row, col, B, shape=spinfo.shape,
transpose=transpose)]
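For reference, _sort_rows above relies on lax.sort with num_keys=1, which sorts a tuple of operands using the first operand as the (stable) sort key and permutes the remaining operands to match. A small standalone sketch with illustrative values:

import jax.numpy as jnp
from jax import lax

row = jnp.array([2, 0, 1, 0])
col = jnp.array([1, 3, 0, 2])
data = jnp.array([10., 20., 30., 40.])
# Sort all three arrays, using only `row` as the key; the sort is stable,
# so entries with equal rows keep their original relative order.
row, col, data = lax.sort((row, col, data), num_keys=1)
# row  -> [0, 0, 1, 2]
# col  -> [3, 2, 0, 1]
# data -> [20., 40., 30., 10.]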
54 changes: 49 additions & 5 deletions tests/sparse_test.py
@@ -55,6 +55,8 @@
np.complex128: 1E-10,
}

GPU_LOWERING_ENABLED = (cusparse and cusparse.is_supported) or (hipsparse and hipsparse.is_supported)

class BcooDotGeneralProperties(NamedTuple):
lhs_shape: Tuple[int]
rhs_shape: Tuple[int]
@@ -331,7 +333,7 @@ def test_coo_todense(self, shape, dtype):
M = rng(shape, dtype)

args = (M.data, M.row, M.col)
todense = lambda *args: sparse_coo._coo_todense(*args, spinfo=sparse_coo.COOInfo(shape=M.shape))
todense = lambda *args: sparse_coo._coo_todense(*args, spinfo=sparse_coo.COOInfo(shape=M.shape, rows_sorted=True))

self.assertArraysEqual(M.toarray(), todense(*args))
with self.gpu_dense_conversion_warning_context(dtype):
@@ -377,7 +379,7 @@ def test_coo_matvec(self, shape, dtype, transpose):
v = v_rng(op(M).shape[1], dtype)

args = (M.data, M.row, M.col, v)
matvec = lambda *args: sparse_coo._coo_matvec(*args, spinfo=sparse_coo.COOInfo(shape=M.shape), transpose=transpose)
matvec = lambda *args: sparse_coo._coo_matvec(*args, spinfo=sparse_coo.COOInfo(shape=M.shape, rows_sorted=True), transpose=transpose)

self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
with self.gpu_matmul_warning_context(dtype):
@@ -399,7 +401,7 @@ def test_coo_matmat(self, shape, dtype, transpose):
B = B_rng((op(M).shape[1], 4), dtype)

args = (M.data, M.row, M.col, B)
matmat = lambda *args: sparse_coo._coo_matmat(*args, spinfo=sparse_coo.COOInfo(shape=shape), transpose=transpose)
matmat = lambda *args: sparse_coo._coo_matmat(*args, spinfo=sparse_coo.COOInfo(shape=shape, rows_sorted=True), transpose=transpose)

self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
with self.gpu_matmul_warning_context(dtype):
@@ -415,13 +417,55 @@ def test_coo_matmat_layout(self):
x = jnp.arange(9).reshape(3, 3).astype(d.dtype)

def f(x):
return sparse_coo._coo_matmat(d, i, j, x.T, spinfo=sparse_coo.COOInfo(shape=shape))
return sparse_coo._coo_matmat(d, i, j, x.T, spinfo=sparse_coo.COOInfo(shape=shape, rows_sorted=True))

result = f(x)
result_jit = jit(f)(x)

self.assertAllClose(result, result_jit)

def test_coo_sorted_indices(self):
rng = self.rng()
sprng = rand_sparse(rng)

mat = sparse.COO.fromdense(sprng((5, 6), np.float32))
perm = rng.permutation(mat.nse)
mat_unsorted = sparse.COO((mat.data[perm], mat.row[perm], mat.col[perm]), shape=mat.shape)
mat_resorted = mat_unsorted._sort_rows()
self.assertArraysEqual(mat.todense(), mat_resorted.todense())

@unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse")
def test_coo_sorted_indices_gpu_warnings(self):
dtype = jnp.float32

mat_sorted = sparse.COO.fromdense(jnp.arange(9, dtype=dtype).reshape(3, 3))
self.assertTrue(mat_sorted._rows_sorted)

mat_unsorted = mat_sorted.T
self.assertFalse(mat_unsorted._rows_sorted)

self.assertArraysEqual(mat_sorted.todense().T, mat_unsorted._sort_rows().todense())

todense = jit(sparse.coo_todense)
todense(mat_sorted)
todense(mat_unsorted._sort_rows())
with self.assertWarnsRegex(sparse.CuSparseEfficiencyWarning, "coo_todense GPU lowering requires matrices with sorted rows.*"):
todense(mat_unsorted)

lhs_vec = jnp.arange(3, dtype=dtype)
matvec = jit(sparse.coo_matvec)
matvec(mat_sorted, lhs_vec)
matvec(mat_unsorted._sort_rows(), lhs_vec)
with self.assertWarnsRegex(sparse.CuSparseEfficiencyWarning, "coo_matvec GPU lowering requires matrices with sorted rows.*"):
matvec(mat_unsorted, lhs_vec)

lhs_mat = jnp.arange(6, dtype=dtype).reshape(3, 2)
matmat = jit(sparse.coo_matmat)
matmat(mat_sorted, lhs_mat)
matmat(mat_unsorted._sort_rows(), lhs_mat)
with self.assertWarnsRegex(sparse.CuSparseEfficiencyWarning, "coo_matmat GPU lowering requires matrices with sorted rows.*"):
matmat(mat_unsorted, lhs_mat)

@unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
def test_gpu_translation_rule(self):
version = xla_bridge.get_backend().platform_version
@@ -470,7 +514,7 @@ def test_coo_todense_ad(self, shape, dtype):
rng = rand_sparse(self.rng(), post=jnp.array)
M = rng(shape, dtype)
data, row, col = sparse_coo._coo_fromdense(M, nse=(M != 0).sum())
f = lambda data: sparse_coo._coo_todense(data, row, col, spinfo=sparse_coo.COOInfo(shape=M.shape))
f = lambda data: sparse_coo._coo_todense(data, row, col, spinfo=sparse_coo.COOInfo(shape=M.shape, rows_sorted=True))

# Forward-mode
primals, tangents = jax.jvp(f, [data], [jnp.ones_like(data)])
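The tests above pin down the user-facing contract; as a usage sketch (assuming a GPU backend with cusparse support, since the warning is only emitted by the GPU lowering):

import warnings
from jax import jit
import jax.numpy as jnp
from jax.experimental import sparse

mat = sparse.COO.fromdense(jnp.arange(9.0, dtype=jnp.float32).reshape(3, 3))
unsorted = mat.T  # transposing invalidates the sorted-rows invariant

# On GPU this warns (CuSparseEfficiencyWarning) and falls back to the
# slower default lowering:
with warnings.catch_warnings(record=True):
    warnings.simplefilter("always")
    jit(sparse.coo_todense)(unsorted)

# Sorting first keeps the cusparse fast path and emits no warning:
jit(sparse.coo_todense)(unsorted._sort_rows())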
