Merge pull request #14536 from jakevdp:coo-oob
PiperOrigin-RevId: 510281491
jax authors committed Feb 17, 2023
2 parents eea1fef + df35824 commit c467d84
Showing 2 changed files with 61 additions and 12 deletions.
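
In short: constructing a COO matrix with an explicitly requested nse can leave more buffer slots than true nonzeros. This change records that fact in a new padded flag and makes the cusparse/hipsparse GPU lowerings fall back to the default lowering for padded inputs, where the GPU rules would otherwise return incorrect results. A minimal sketch of the user-visible behavior being fixed (values are illustrative, not from the PR itself):

import jax.numpy as jnp
from jax.experimental import sparse

M = jnp.array([[1., 0., 2.],
               [0., 0., 3.]])
# Request more slots than the 3 true nonzeros: the result is padded.
M_sp = sparse.COO.fromdense(M, nse=5)

# With this change, padded COO operations route to the safe default
# lowering on GPU, so the results agree with the dense computation.
print(M_sp.todense())        # should equal M
print(M_sp @ jnp.ones(3))    # should equal M @ jnp.ones(3)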
jax/experimental/sparse/coo.py: 23 additions, 9 deletions
@@ -45,6 +45,7 @@ class COOInfo(NamedTuple):
   shape: Shape
   rows_sorted: bool = False
   cols_sorted: bool = False
+  padded: bool = False


 @tree_util.register_pytree_node_class
@@ -63,16 +64,19 @@ class COO(JAXSparse):
   dtype = property(lambda self: self.data.dtype)
   _info = property(lambda self: COOInfo(
       shape=self.shape, rows_sorted=self._rows_sorted,
-      cols_sorted=self._cols_sorted))
+      cols_sorted=self._cols_sorted, padded=self._padded))
   _bufs = property(lambda self: (self.data, self.row, self.col))
   _rows_sorted: bool
   _cols_sorted: bool
+  _padded: bool

   def __init__(self, args: Tuple[Array, Array, Array], *, shape: Shape,
-               rows_sorted: bool = False, cols_sorted: bool = False):
+               rows_sorted: bool = False, cols_sorted: bool = False,
+               padded: bool = True):
     self.data, self.row, self.col = map(jnp.asarray, args)
     self._rows_sorted = rows_sorted
     self._cols_sorted = cols_sorted
+    self._padded = padded
     super().__init__(args, shape=shape)

   @classmethod
@@ -131,7 +135,7 @@ def transpose(self, axes: Optional[Tuple[int, ...]] = None) -> COO:
     if axes is not None:
       raise NotImplementedError("axes argument to transpose()")
     return COO((self.data, self.col, self.row), shape=self.shape[::-1],
-               rows_sorted=self._cols_sorted, cols_sorted=self._rows_sorted)
+               rows_sorted=self._cols_sorted, cols_sorted=self._rows_sorted, padded=self._padded)

   def tree_flatten(self) -> Tuple[Tuple[Array, Array, Array], Dict[str, Any]]:
     return (self.data, self.row, self.col), self._info._asdict()
@@ -140,11 +144,12 @@ def tree_flatten(self) -> Tuple[Tuple[Array, Array, Array], Dict[str, Any]]:
   def tree_unflatten(cls, aux_data, children):
     obj = object.__new__(cls)
     obj.data, obj.row, obj.col = children
-    if aux_data.keys() != {'shape', 'rows_sorted', 'cols_sorted'}:
+    if aux_data.keys() != {'shape', 'rows_sorted', 'cols_sorted', 'padded'}:
       raise ValueError(f"COO.tree_unflatten: invalid {aux_data=}")
     obj.shape = aux_data['shape']
     obj._rows_sorted = aux_data['rows_sorted']
     obj._cols_sorted = aux_data['cols_sorted']
+    obj._padded = aux_data['padded']
     return obj

   def __matmul__(self, other: ArrayLike) -> Array:
@@ -207,6 +212,9 @@ def _coo_todense_gpu_lowering(coo_todense_hlo, ctx, data, row, col, *, spinfo):
     warnings.warn(f"coo_todense cusparse/hipsparse lowering not available for {dtype=}. "
                   "Falling back to default implementation.", CuSparseEfficiencyWarning)
     return _coo_todense_lowering(ctx, data, row, col, spinfo=spinfo)
+  if spinfo.padded:
+    # GPU rule returns incorrect results with padded representation.
+    return _coo_todense_lowering(ctx, data, row, col, spinfo=spinfo)

   if spinfo.rows_sorted:
     shape = spinfo.shape
@@ -274,11 +282,12 @@ def coo_fromdense(mat: Array, *, nse: Optional[int] = None, index_dtype: DTypeLike = jnp.int32) -> COO:
   Returns:
     mat_coo : COO representation of the matrix.
   """
+  padded = nse is not None
   if nse is None:
     nse = int((mat != 0).sum())
   nse_int = core.concrete_or_error(operator.index, nse, "coo_fromdense nse argument")
   return COO(_coo_fromdense(mat, nse=nse_int, index_dtype=index_dtype),
-             shape=mat.shape, rows_sorted=True)
+             shape=mat.shape, rows_sorted=True, padded=padded)

 def _coo_fromdense(mat: Array, *, nse: int, index_dtype: DTypeLike = jnp.int32) -> Tuple[Array, Array, Array]:
   """Create COO-format sparse matrix from a dense matrix.
@@ -446,8 +455,10 @@ def _coo_matvec_gpu_lowering(coo_matvec_hlo, ctx, data, row, col, v, *, spinfo, transpose):
   if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
     warnings.warn(f"coo_matvec cusparse/hipsparse lowering not available for {dtype=}. "
                   "Falling back to default implementation.", CuSparseEfficiencyWarning)
-    return _coo_matvec_lowering(ctx, data, row, col, v, spinfo=spinfo,
-                                transpose=transpose)
+    return _coo_matvec_lowering(ctx, data, row, col, v, spinfo=spinfo, transpose=transpose)
+  if spinfo.padded:
+    # GPU rule returns incorrect results with padded representation.
+    return _coo_matvec_lowering(ctx, data, row, col, v, spinfo=spinfo, transpose=transpose)

   if spinfo.rows_sorted:
     shape = spinfo.shape
@@ -569,8 +580,11 @@ def _coo_matmat_gpu_lowering(coo_matmat_hlo, ctx, data, row, col, B, *, spinfo, transpose):
   if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
     warnings.warn(f"coo_matmat cusparse/hipsparse lowering not available for {dtype=}. "
                   "Falling back to default implementation.", CuSparseEfficiencyWarning)
-    return _coo_matmat_lowering(ctx, data, row, col, B, spinfo=spinfo,
-                                transpose=transpose)
+    return _coo_matmat_lowering(ctx, data, row, col, B, spinfo=spinfo, transpose=transpose)
+  if spinfo.padded:
+    # GPU rule returns incorrect results with padded representation.
+    return _coo_matmat_lowering(ctx, data, row, col, B, spinfo=spinfo, transpose=transpose)

   if spinfo.rows_sorted:
     shape = spinfo.shape
   elif spinfo.cols_sorted:
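
Because the flag lives on the COO object, it must survive every code path that rebuilds one; the transpose and tree_unflatten hunks above exist for exactly that reason. A quick round-trip sketch (again peeking at _padded only to illustrate):

import jax
import jax.numpy as jnp
from jax.experimental import sparse

M_sp = sparse.COO.fromdense(jnp.eye(3), nse=5)  # padded
assert M_sp.transpose()._padded                 # carried through transpose

# Pytree round-trip, the mechanism jit/vmap use to rebuild the object.
leaves, treedef = jax.tree_util.tree_flatten(M_sp)
assert jax.tree_util.tree_unflatten(treedef, leaves)._padded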
tests/sparse_test.py: 38 additions, 3 deletions
@@ -508,7 +508,7 @@ def test_coo_sorted_indices_gpu_lowerings(self):
     self.assertFalse(mat_cols_sorted._rows_sorted)
     self.assertTrue(mat_cols_sorted._cols_sorted)

-    mat_unsorted = sparse.COO(mat_rows_sorted._bufs, shape=mat_rows_sorted.shape)
+    mat_unsorted = sparse.COO(mat_rows_sorted._bufs, shape=mat_rows_sorted.shape, padded=False)
     self.assertFalse(mat_unsorted._rows_sorted)
     self.assertFalse(mat_unsorted._cols_sorted)

@@ -582,10 +582,45 @@ def test_gpu_translation_rule(self):
   )
   def test_extra_nse(self, shape, dtype, mat_type):
     rng = rand_sparse(self.rng())
+    rng_dense = jtu.rand_default(self.rng())
     M = rng(shape, dtype)
     nse = (M != 0).sum() + 5
-    M_out = mat_type.fromdense(M, nse=nse, index_dtype=jnp.int32).todense()
-    self.assertArraysEqual(M, M_out)
+    M_sp = mat_type.fromdense(M, nse=nse)
+
+    with self.subTest("todense"):
+      def todense1(M, _):
+        assert isinstance(M, np.ndarray)
+        return M
+      def todense2(_, M):
+        assert isinstance(M, mat_type)
+        return M.todense()
+      args_maker = lambda: [M, M_sp]
+      self._CheckAgainstNumpy(todense1, todense2, args_maker)
+      self._CompileAndCheck(todense2, args_maker)
+
+    with self.subTest("matvec"):
+      v = rng_dense(M.shape[-1:], dtype)
+      args_maker = lambda: [M, M_sp, v]
+      def matvec1(M, _, v):
+        assert isinstance(M, np.ndarray)
+        return M @ v
+      def matvec2(_, M, v):
+        assert isinstance(M, mat_type)
+        return M @ v
+      self._CheckAgainstNumpy(matvec1, matvec2, args_maker)
+      self._CompileAndCheck(matvec2, args_maker)
+
+    with self.subTest("matmat"):
+      B = rng_dense(M.shape[::-1], dtype)
+      args_maker = lambda: [M, M_sp, B]
+      def matmat1(M, _, B):
+        assert isinstance(M, np.ndarray)
+        return M @ B
+      def matmat2(_, M, B):
+        assert isinstance(M, mat_type)
+        return M @ B
+      self._CheckAgainstNumpy(matmat1, matmat2, args_maker)
+      self._CompileAndCheck(matmat2, args_maker)

   @jtu.sample_product(
       shape=[(5, 8), (8, 5), (5, 5), (8, 8)],
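
The new test checks todense, matvec, and matmat against their dense counterparts, both eagerly and under jit. A standalone analogue outside the test harness might look like the following (the tolerance-free assert_allclose calls are an assumption, not part of the PR):

import numpy as np
import jax.numpy as jnp
from jax.experimental import sparse

M = np.array([[0., 1., 0.],
              [2., 0., 3.]], dtype=np.float32)
# Five extra slots beyond the true nonzeros, as in test_extra_nse.
M_sp = sparse.COO.fromdense(jnp.asarray(M), nse=int((M != 0).sum()) + 5)

v = np.ones(3, dtype=np.float32)
B = np.ones((3, 2), dtype=np.float32)
np.testing.assert_allclose(np.asarray(M_sp.todense()), M)
np.testing.assert_allclose(np.asarray(M_sp @ jnp.asarray(v)), M @ v)
np.testing.assert_allclose(np.asarray(M_sp @ jnp.asarray(B)), M @ B)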
