Skip to content

Commit

Permalink
[sparse] Propagate SparseInfo to BCSR todense() and tree_(un)flatten().
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 496945167
  • Loading branch information
tlu7 authored and jax authors committed Dec 21, 2022
1 parent 4f75ad6 commit e89b60e
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 16 deletions.
31 changes: 17 additions & 14 deletions jax/experimental/sparse/bcsr.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def _bcsr_fromdense_impl(mat, *, nse, n_batch, n_dense, index_dtype):


@bcsr_fromdense_p.def_abstract_eval
def _bcoo_fromdense_abstract_eval(mat, *, nse, n_batch, n_dense, index_dtype):
def _bcsr_fromdense_abstract_eval(mat, *, nse, n_batch, n_dense, index_dtype):
n_sparse = mat.ndim - n_batch - n_dense
if n_sparse != 2:
raise ValueError("bcsr_fromdense: must have 2 sparse dimensions.")
Expand Down Expand Up @@ -218,39 +218,43 @@ def bcsr_todense(mat: BCSR) -> Array:
Returns:
The dense version of ``mat``.
"""
return _bcsr_todense(mat.data, mat.indices, mat.indptr,
shape=tuple(mat.shape))
return _bcsr_todense(mat.data, mat.indices, mat.indptr, spinfo=mat._info)


def _bcsr_todense(data: ArrayLike, indices: ArrayLike, indptr: ArrayLike, *, shape: Shape) -> Array:
def _bcsr_todense(data: ArrayLike, indices: ArrayLike, indptr: ArrayLike, *,
spinfo: SparseInfo) -> Array:
"""Convert batched sparse matrix to a dense matrix.
Args:
data : array of shape ``batch_dims + (nse,) + dense_dims``.
indices : array of shape ``batch_dims + (nse,)``.
indptr : array of shape ``batch_dims + (shape[len(batch_dims)] + 1,).
shape : tuple; the shape of the (batched) matrix. Equal to
``batch_dims + 2(sparse_dims) + dense_dims``
spinfo : SparseInfo. In particular, this includes the shape
of the matrix, which is equal to
``batch_dims + 2(sparse_dims) + block_dims`` where
``len(sparse_dims) == 2``.
Returns:
mat : array with specified shape and dtype matching ``data``
"""
return bcsr_todense_p.bind(jnp.asarray(data), jnp.asarray(indices),
jnp.asarray(indptr), shape=shape)
jnp.asarray(indptr), spinfo=spinfo)


@bcsr_todense_p.def_impl
def _bcsr_todense_impl(data, indices, indptr, *, shape):
def _bcsr_todense_impl(data, indices, indptr, *, spinfo):
shape = spinfo.shape
bcoo_indices = _bcsr_to_bcoo(indices, indptr, shape=shape)
return (bcoo.BCOO((data, bcoo_indices), shape=shape)).todense()


@bcsr_todense_p.def_abstract_eval
def _bcsr_todense_abstract_eval(data, indices, indptr, *, shape):
def _bcsr_todense_abstract_eval(data, indices, indptr, *, spinfo):
shape = spinfo.shape
_validate_bcsr(data, indices, indptr, shape)
return core.ShapedArray(shape, data.dtype)


def _bcsr_todense_batching_rule(batched_args, batch_dims, *, shape):
def _bcsr_todense_batching_rule(batched_args, batch_dims, *, spinfo):
data, indices, indptr = batched_args
if any(b not in [0, None] for b in batch_dims):
raise NotImplementedError(f"{batch_dims=}. Only 0 and None are supported.")
Expand All @@ -260,7 +264,7 @@ def _bcsr_todense_batching_rule(batched_args, batch_dims, *, shape):
indices = indices[None, ...]
if batch_dims[2] is None:
indptr = indptr[None, ...]
return _bcsr_todense(data, indices, indptr, shape=shape), 0
return _bcsr_todense(data, indices, indptr, spinfo=spinfo), 0

batching.primitive_batchers[bcsr_todense_p] = _bcsr_todense_batching_rule
mlir.register_lowering(bcsr_todense_p, mlir.lower_fun(
Expand Down Expand Up @@ -498,14 +502,13 @@ def transpose(self, *args, **kwargs):
raise NotImplementedError("Tranpose is not implemented.")

def tree_flatten(self):
# TODO(tianjianlu): Unflatten SparseInfo with self._info._asdict().
return (self.data, self.indices, self.indptr), {'shape': self.shape}
return (self.data, self.indices, self.indptr), self._info._asdict()

@classmethod
def tree_unflatten(cls, aux_data, children):
obj = object.__new__(cls)
obj.data, obj.indices, obj.indptr = children
if aux_data.keys() != {'shape'}:
if aux_data.keys() != {'shape', 'indices_sorted', 'unique_indices'}:
raise ValueError(f"BCSR.tree_unflatten: invalid {aux_data=}")
obj.__dict__.update(**aux_data)
return obj
Expand Down
6 changes: 4 additions & 2 deletions tests/sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2172,7 +2172,8 @@ def test_bcsr_dense_round_trip(self, shape, dtype, n_batch):
self.assertEqual(indptr.dtype, jnp.int32)
self.assertEqual(indptr.shape, shape[:n_batch] + (shape[n_batch] + 1,))

todense = partial(sparse_bcsr._bcsr_todense, shape=shape)
todense = partial(sparse_bcsr._bcsr_todense,
spinfo=sparse_util.SparseInfo(shape=shape))
self.assertArraysEqual(M, todense(data, indices, indptr))
args_maker_todense = lambda: [data, indices, indptr]
self._CompileAndCheck(todense, args_maker_todense)
Expand All @@ -2195,7 +2196,8 @@ def test_bcsr_dense_round_trip_batched(self, shape, dtype, n_batch):

fromdense = partial(sparse_bcsr._bcsr_fromdense, nse=nse, n_batch=0,
n_dense=n_dense)
todense = partial(sparse_bcsr._bcsr_todense, shape=shape)
todense = partial(sparse_bcsr._bcsr_todense,
spinfo=sparse_util.SparseInfo(shape))

for _ in range(n_batch):
fromdense = jax.vmap(fromdense)
Expand Down

0 comments on commit e89b60e

Please sign in to comment.