Skip to content

Commit

Permalink
[sparse] Add conversions between BCSR and BCOO.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 478816413
  • Loading branch information
tlu7 authored and jax authors committed Oct 4, 2022
1 parent 37f9db7 commit ae49d2e
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 3 deletions.
21 changes: 21 additions & 0 deletions jax/experimental/sparse/bcoo.py
Expand Up @@ -138,6 +138,27 @@ def _validate_bcoo_indices(indices: jnp.ndarray, shape: Sequence[int]) -> BCOOPr
return BCOOProperties(n_batch=n_batch, n_sparse=n_sparse, n_dense=n_dense, nse=nse)


def _bcoo_to_bcsr(indices: jnp.ndarray, *, shape: Sequence[int],
index_dtype=jnp.int32):
"""Given BCOO (indices), return BCSR (indices, indptr)."""
n_batch, n_sparse, _, _ = _validate_bcoo_indices(indices, shape)

if n_sparse != 2:
raise ValueError("Must have 2 sparse dimensions to be converted to BCSR.")

n_rows = shape[n_batch]

def get_ptr(i):
indptr = jnp.zeros(n_rows + 1, index_dtype)
return indptr.at[1:].set(jnp.cumsum(
jnp.bincount(i, length=n_rows).astype(index_dtype)))

for _ in range(n_batch):
get_ptr = vmap(get_ptr)

return indices[..., 1], get_ptr(indices[..., 0])


#----------------------------------------------------------------------
# bcoo_todense

Expand Down
67 changes: 64 additions & 3 deletions jax/experimental/sparse/bcsr.py
Expand Up @@ -14,17 +14,78 @@

"""BCSR (Bached compressed row) matrix object and associated primitives."""

from typing import Tuple
from typing import NamedTuple, Sequence, Tuple

from jax import core
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.util import _safe_asarray
from jax.experimental.sparse.util import _broadcasting_vmap, _csr_to_coo, _safe_asarray
import jax.numpy as jnp
from jax.util import split_list
from jax.util import split_list, safe_zip

Shape = Tuple[int, ...]


class BCSRProperties(NamedTuple):
n_batch: int
n_dense: int
nse: int


def _compatible(shape1, shape2):
return all(s1 in (1, s2) for s1, s2 in safe_zip(shape1, shape2))


def _validate_bcsr_indices(indices: jnp.ndarray, indptr: jnp.ndarray,
shape: Sequence[int]) -> BCSRProperties:
assert jnp.issubdtype(indices.dtype, jnp.integer)
assert jnp.issubdtype(indptr.dtype, jnp.integer)
shape = tuple(shape)

nse = indices.shape[-1]
n_batch = indices.ndim - 1
n_dense = len(shape) - n_batch - 2
assert n_dense >= 0

if not _compatible(indices.shape[:n_batch], shape[:n_batch]):
raise ValueError("indices batch dimensions not compatible for "
f"indices.shape={indices.shape}, shape={shape}")
if not _compatible(indptr.shape[:n_batch], shape[:n_batch]):
raise ValueError("indptr batch dimensions not compatible for "
f"indptr.shape={indptr.shape}, shape={shape}")
if indptr.shape[n_batch:] != (shape[n_batch] + 1,):
raise ValueError("indptr shape must match the matrix shape plus 1.")

return BCSRProperties(n_batch=n_batch, n_dense=n_dense, nse=nse)


def _validate_bcsr(data: jnp.ndarray, indices: jnp.ndarray,
indptr: jnp.ndarray, shape: Sequence[int]) -> BCSRProperties:
props = _validate_bcsr_indices(indices, indptr, shape)
shape = tuple(shape)
n_batch, n_dense, nse = props.n_batch, props.n_dense, props.nse
n_sparse = data.ndim - n_batch - n_dense
if n_sparse != 2:
raise ValueError("BCSR array must have 2 sparse dimensions; "
f"{n_sparse} is given.")
if not _compatible(data.shape[:n_batch], shape[:n_batch]):
raise ValueError("data batch dimensions not compatible for "
f"data.shape={data.shape}, shape={shape}")
if data.shape[-(n_dense + 1):] != (nse,) + shape[n_batch + 2:]:
raise ValueError(f"Invalid data.shape={data.shape} for "
f"nse={nse}, n_batch={n_batch}, n_dense={n_dense}")
return props


def _bcsr_to_bcoo(indices: jnp.ndarray, indptr: jnp.ndarray, *,
shape: Sequence[int]) -> jnp.ndarray:
"""Given BCSR (indices, indptr), return BCOO (indices)."""
n_batch, _, _ = _validate_bcsr_indices(indices, indptr, shape)
csr_to_coo = _csr_to_coo
for _ in range(n_batch):
csr_to_coo = _broadcasting_vmap(csr_to_coo)
return jnp.stack(csr_to_coo(indices, indptr), axis=indices.ndim)


class BCSR(JAXSparse):
"""Experimental batched CSR matrix implemented in JAX."""

Expand Down
35 changes: 35 additions & 0 deletions tests/sparse_test.py
Expand Up @@ -31,6 +31,7 @@
from jax.experimental import sparse
from jax.experimental.sparse import coo as sparse_coo
from jax.experimental.sparse import bcoo as sparse_bcoo
from jax.experimental.sparse import bcsr as sparse_bcsr
from jax.experimental.sparse.bcoo import BCOOInfo
from jax import lax
from jax._src.lib import xla_extension_version
Expand Down Expand Up @@ -2498,6 +2499,40 @@ def test_bcoo_methods(self):
self.assertArraysEqual(M.sum(1), Msp.sum(1).todense())
self.assertArraysEqual(M.sum(), Msp.sum())

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_nbatch={}".format(
jtu.format_shape_dtype_string(shape, dtype), n_batch),
"shape": shape, "dtype": dtype, "n_batch": n_batch}
for shape in [(5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex
for n_batch in range(len(shape) - 1)))
def test_bcoo_to_bcsr_round_trip(self, shape, dtype, n_batch):
rng = rand_sparse(self.rng())
M = rng(shape, dtype)
n_dense = len(shape) - 2 - n_batch
nse = sparse.util._count_stored_elements(M, n_batch=n_batch,
n_dense=n_dense)
_, bcoo_indices = sparse_bcoo._bcoo_fromdense(M, nse=nse, n_batch=n_batch,
n_dense=n_dense)

bcoo_to_bcsr = partial(sparse_bcoo._bcoo_to_bcsr, shape=shape)

args_maker_bcoo_to_bcsr = lambda: [bcoo_indices]
self._CompileAndCheck(bcoo_to_bcsr, args_maker_bcoo_to_bcsr)

bcsr_indices, indptr = bcoo_to_bcsr(bcoo_indices)
bcsr_indices_jit, indptr_jit = jit(bcoo_to_bcsr)(bcoo_indices)

self.assertEqual(bcsr_indices.dtype, jnp.int32)
self.assertEqual(bcsr_indices.shape, shape[:n_batch] + (nse,))
self.assertEqual(indptr.dtype, jnp.int32)
self.assertEqual(indptr.shape, shape[:n_batch] + (shape[n_batch] + 1,))

bcsr_to_bcoo = partial(sparse_bcsr._bcsr_to_bcoo, shape=shape)
self.assertArraysEqual(bcoo_indices, bcsr_to_bcoo(bcsr_indices, indptr))
args_maker_bcsr_to_bcoo = lambda: [bcsr_indices, indptr]
self._CompileAndCheck(bcsr_to_bcoo, args_maker_bcsr_to_bcoo)


class SparseRandomTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
Expand Down

0 comments on commit ae49d2e

Please sign in to comment.