Skip to content

Commit

Permalink
[sparse] track sorted columns for COO GPU lowerings
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Mar 9, 2022
1 parent 537e35b commit 3679e0c
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 40 deletions.
78 changes: 54 additions & 24 deletions jax/experimental/sparse/coo.py
Expand Up @@ -28,6 +28,7 @@
from jax.experimental.sparse.util import _coo_extract, _safe_asarray, CuSparseEfficiencyWarning
from jax import tree_util
from jax._src.numpy.lax_numpy import _promote_dtypes
from jax._src.lib import xla_client
import jax.numpy as jnp

try:
Expand All @@ -40,13 +41,15 @@
except ImportError:
hipsparse = None

xops = xla_client.ops

Dtype = Any
Shape = Tuple[int, ...]

class COOInfo(NamedTuple):
shape: Shape
rows_sorted: bool = False
cols_sorted: bool = False


@tree_util.register_pytree_node_class
Expand All @@ -58,13 +61,16 @@ class COO(JAXSparse):
shape: Tuple[int, int]
nse = property(lambda self: self.data.size)
dtype = property(lambda self: self.data.dtype)
_info = property(lambda self: COOInfo(self.shape, self._rows_sorted))
_info = property(lambda self: COOInfo(
shape=self.shape, rows_sorted=self._rows_sorted, cols_sorted=self._cols_sorted))
_bufs = property(lambda self: (self.data, self.row, self.col))
_rows_sorted: bool
_cols_sorted: bool

def __init__(self, args, *, shape, rows_sorted=False):
def __init__(self, args, *, shape, rows_sorted=False, cols_sorted=False):
self.data, self.row, self.col = _safe_asarray(args)
self._rows_sorted = rows_sorted
self._cols_sorted = cols_sorted
super().__init__(args, shape=shape)

@classmethod
Expand All @@ -90,15 +96,16 @@ def _empty(cls, shape, *, dtype=None, index_dtype='int32'):
raise ValueError(f"COO must have ndim=2; got shape={shape}")
data = jnp.empty(0, dtype)
row = col = jnp.empty(0, index_dtype)
return cls((data, row, col), shape=shape)
return cls((data, row, col), shape=shape, rows_sorted=True, cols_sorted=True)

def todense(self):
return coo_todense(self)

def transpose(self, axes=None):
if axes is not None:
raise NotImplementedError("axes argument to transpose()")
return COO((self.data, self.col, self.row), shape=self.shape[::-1], rows_sorted=False)
return COO((self.data, self.col, self.row), shape=self.shape[::-1],
rows_sorted=self._cols_sorted, cols_sorted=self._rows_sorted)

def tree_flatten(self):
return (self.data, self.row, self.col), self._info._asdict()
Expand Down Expand Up @@ -164,17 +171,27 @@ def _coo_todense_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _coo_todense_translation_rule(ctx, avals_in, avals_out, data, row, col,
spinfo=spinfo)
if not spinfo.rows_sorted:
warnings.warn("coo_todense GPU lowering requires matrices with sorted rows. To sort the rows "
"in your matrix, use e.g. mat = mat._sort_rows(). Falling back to the default "
"implementation.", CuSparseEfficiencyWarning)

if spinfo.rows_sorted:
shape = spinfo.shape
transpose = False
elif spinfo.cols_sorted:
row, col = col, row
transpose = True
shape = spinfo.shape[::-1]
else:
warnings.warn("coo_todense GPU lowering requires matrices with sorted rows or sorted cols. "
"To sort the rows in your matrix, use e.g. mat = mat._sort_rows(). Falling "
"back to the default implementation.", CuSparseEfficiencyWarning)
return _coo_todense_translation_rule(ctx, avals_in, avals_out, data, row, col,
spinfo=spinfo)

if cusparse is not None:
return [cusparse.coo_todense(ctx.builder, data, row, col, shape=spinfo.shape)]
result = cusparse.coo_todense(ctx.builder, data, row, col, shape=shape)
else:
return [hipsparse.coo_todense(ctx.builder, data, row, col, shape=spinfo.shape)]
result = hipsparse.coo_todense(ctx.builder, data, row, col, shape=spinfo.shape)

return [xops.Transpose(result, (1, 0))] if transpose else [result]

def _coo_todense_jvp(data_dot, data, row, col, *, spinfo):
return _coo_todense(data_dot, row, col, spinfo=spinfo)
Expand Down Expand Up @@ -380,19 +397,26 @@ def _coo_matvec_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _coo_matvec_translation_rule(ctx, avals_in, avals_out, data, row, col, v,
spinfo=spinfo, transpose=transpose)
if not spinfo.rows_sorted:
warnings.warn("coo_matvec GPU lowering requires matrices with sorted rows. To sort the rows "
"in your matrix, use e.g. mat = mat._sort_rows(). Falling back to the default "
"implementation.", CuSparseEfficiencyWarning)

if spinfo.rows_sorted:
shape = spinfo.shape
elif spinfo.cols_sorted:
row, col = col, row
transpose = not transpose
shape = spinfo.shape[::-1]
else:
warnings.warn("coo_matvec GPU lowering requires matrices with sorted rows or sorted cols. "
"To sort the rows in your matrix, use e.g. mat = mat._sort_rows(). Falling "
"back to the default implementation.", CuSparseEfficiencyWarning)
return _coo_matvec_translation_rule(ctx, avals_in, avals_out, data, row, col, v,
spinfo=spinfo, transpose=transpose)

if cusparse is not None:
return [cusparse.coo_matvec(ctx.builder, data, row, col, v, shape=spinfo.shape,
return [cusparse.coo_matvec(ctx.builder, data, row, col, v, shape=shape,
transpose=transpose)]
else:
return [hipsparse.coo_matvec(ctx.builder, data, row, col, v, shape=spinfo.shape,
transpose=transpose)]
return [hipsparse.coo_matvec(ctx.builder, data, row, col, v, shape=shape,
transpose=transpose)]

def _coo_matvec_jvp_mat(data_dot, data, row, col, v, *, spinfo, transpose):
return _coo_matvec(data_dot, row, col, v, spinfo=spinfo, transpose=transpose)
Expand Down Expand Up @@ -489,19 +513,25 @@ def _coo_matmat_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _coo_matmat_translation_rule(ctx, avals_in, avals_out, data, row, col, B,
spinfo=spinfo, transpose=transpose)
if not spinfo.rows_sorted:
warnings.warn("coo_matmat GPU lowering requires matrices with sorted rows. To sort the rows "
"in your matrix, use e.g. mat = mat._sort_rows(). Falling back to the default "
"implementation.", CuSparseEfficiencyWarning)
if spinfo.rows_sorted:
shape = spinfo.shape
elif spinfo.cols_sorted:
row, col = col, row
transpose = not transpose
shape = spinfo.shape[::-1]
else:
warnings.warn("coo_matmat GPU lowering requires matrices with sorted rows or sorted cols. "
"To sort the rows in your matrix, use e.g. mat = mat._sort_rows(). Falling "
"back to the default implementation.", CuSparseEfficiencyWarning)
return _coo_matmat_translation_rule(ctx, avals_in, avals_out, data, row, col, B,
spinfo=spinfo, transpose=transpose)

if cusparse is not None:
return [cusparse.coo_matmat(ctx.builder, data, row, col, B, shape=spinfo.shape,
return [cusparse.coo_matmat(ctx.builder, data, row, col, B, shape=shape,
transpose=transpose)]
else:
return [hipsparse.coo_matmat(ctx.builder, data, row, col, B, shape=spinfo.shape,
transpose=transpose)]
return [hipsparse.coo_matmat(ctx.builder, data, row, col, B, shape=shape,
transpose=transpose)]

def _coo_matmat_jvp_left(data_dot, data, row, col, B, *, spinfo, transpose):
return _coo_matmat(data_dot, row, col, B, spinfo=spinfo, transpose=transpose)
Expand Down
69 changes: 53 additions & 16 deletions tests/sparse_test.py
Expand Up @@ -19,6 +19,7 @@
import random
import unittest
from typing import NamedTuple, Tuple
import warnings

from absl.testing import absltest
from absl.testing import parameterized
Expand Down Expand Up @@ -132,6 +133,12 @@ def gpu_matmul_warning_context(self, dtype):
return self.assertWarns(sparse.CuSparseEfficiencyWarning)
return contextlib.nullcontext()

@contextlib.contextmanager
def assertNoWarnings(self):
with warnings.catch_warnings(record=True) as caught_warnings:
yield
self.assertEmpty(caught_warnings)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
Expand Down Expand Up @@ -435,36 +442,66 @@ def test_coo_sorted_indices(self):
self.assertArraysEqual(mat.todense(), mat_resorted.todense())

@unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse")
def test_coo_sorted_indices_gpu_warnings(self):
def test_coo_sorted_indices_gpu_lowerings(self):
dtype = jnp.float32

mat_sorted = sparse.COO.fromdense(jnp.arange(9, dtype=dtype).reshape(3, 3))
self.assertTrue(mat_sorted._rows_sorted)
mat = jnp.arange(12, dtype=dtype).reshape(4, 3)

mat_rows_sorted = sparse.COO.fromdense(mat)
self.assertTrue(mat_rows_sorted._rows_sorted)
self.assertFalse(mat_rows_sorted._cols_sorted)

mat_cols_sorted = sparse.COO.fromdense(mat.T).T
self.assertFalse(mat_cols_sorted._rows_sorted)
self.assertTrue(mat_cols_sorted._cols_sorted)

mat_unsorted = mat_sorted.T
mat_unsorted = sparse.COO(mat_rows_sorted._bufs, shape=mat_rows_sorted.shape)
self.assertFalse(mat_unsorted._rows_sorted)
self.assertFalse(mat_unsorted._cols_sorted)

self.assertArraysEqual(mat_sorted.todense().T, mat_unsorted._sort_rows().todense())
self.assertArraysEqual(mat, mat_rows_sorted._sort_rows().todense())
self.assertArraysEqual(mat, mat_cols_sorted._sort_rows().todense())
self.assertArraysEqual(mat, mat_unsorted._sort_rows().todense())

todense = jit(sparse.coo_todense)
todense(mat_sorted)
todense(mat_unsorted._sort_rows())
with self.assertNoWarnings():
dense_rows_sorted = todense(mat_rows_sorted)
dense_cols_sorted = todense(mat_cols_sorted)
dense_unsorted = todense(mat_unsorted._sort_rows())
with self.assertWarnsRegex(sparse.CuSparseEfficiencyWarning, "coo_todense GPU lowering requires matrices with sorted rows.*"):
todense(mat_unsorted)
dense_unsorted_fallback = todense(mat_unsorted)
self.assertArraysEqual(mat, dense_rows_sorted)
self.assertArraysEqual(mat, dense_cols_sorted)
self.assertArraysEqual(mat, dense_unsorted)
self.assertArraysEqual(mat, dense_unsorted_fallback)

lhs_vec = jnp.arange(3, dtype=dtype)
rhs_vec = jnp.arange(3, dtype=dtype)
matvec = jit(sparse.coo_matvec)
matvec(mat_sorted, lhs_vec)
matvec(mat_unsorted._sort_rows(), lhs_vec)
matvec_expected = mat @ rhs_vec
with self.assertNoWarnings():
matvec_rows_sorted = matvec(mat_rows_sorted, rhs_vec)
matvec_cols_sorted = matvec(mat_cols_sorted, rhs_vec)
matvec_unsorted = matvec(mat_unsorted._sort_rows(), rhs_vec)
with self.assertWarnsRegex(sparse.CuSparseEfficiencyWarning, "coo_matvec GPU lowering requires matrices with sorted rows.*"):
matvec(mat_unsorted, lhs_vec)
matvec_unsorted_fallback = matvec(mat_unsorted, rhs_vec)
self.assertArraysEqual(matvec_expected, matvec_rows_sorted)
self.assertArraysEqual(matvec_expected, matvec_cols_sorted)
self.assertArraysEqual(matvec_expected, matvec_unsorted)
self.assertArraysEqual(matvec_expected, matvec_unsorted_fallback)

lhs_mat = jnp.arange(6, dtype=dtype).reshape(3, 2)
rhs_mat = jnp.arange(6, dtype=dtype).reshape(3, 2)
matmat = jit(sparse.coo_matmat)
matmat(mat_sorted, lhs_mat)
matmat(mat_unsorted._sort_rows(), lhs_mat)
matmat_expected = mat @ rhs_mat
with self.assertNoWarnings():
matmat_rows_sorted = matmat(mat_rows_sorted, rhs_mat)
matmat_cols_sorted = matmat(mat_cols_sorted, rhs_mat)
matmat_unsorted = matmat(mat_unsorted._sort_rows(), rhs_mat)
with self.assertWarnsRegex(sparse.CuSparseEfficiencyWarning, "coo_matmat GPU lowering requires matrices with sorted rows.*"):
matmat(mat_unsorted, lhs_mat)
matmat_unsorted_fallback = matmat(mat_unsorted, rhs_mat)
self.assertArraysEqual(matmat_expected, matmat_rows_sorted)
self.assertArraysEqual(matmat_expected, matmat_cols_sorted)
self.assertArraysEqual(matmat_expected, matmat_unsorted)
self.assertArraysEqual(matmat_expected, matmat_unsorted_fallback)

@unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
def test_gpu_translation_rule(self):
Expand Down

0 comments on commit 3679e0c

Please sign in to comment.