[sparse] Add bcsr dot_general
PiperOrigin-RevId: 496489368
tlu7 authored and jax authors committed Dec 19, 2022
1 parent dbc3944 commit bc34af9
Showing 3 changed files with 196 additions and 9 deletions.
2 changes: 2 additions & 0 deletions jax/experimental/sparse/__init__.py
@@ -223,6 +223,8 @@
)

from jax.experimental.sparse.bcsr import (
bcsr_dot_general as bcsr_dot_general,
bcsr_dot_general_p as bcsr_dot_general_p,
bcsr_extract as bcsr_extract,
bcsr_extract_p as bcsr_extract_p,
bcsr_fromdense as bcsr_fromdense,
158 changes: 156 additions & 2 deletions jax/experimental/sparse/bcsr.py
@@ -17,17 +17,24 @@

import operator

from typing import NamedTuple, Optional, Sequence, Tuple
from typing import NamedTuple, Optional, Sequence, Tuple, Union

import numpy as np

from jax import core
from jax import lax
from jax import tree_util
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse import bcoo
from jax.experimental.sparse.util import _broadcasting_vmap, _count_stored_elements, _csr_to_coo, Shape
from jax.experimental.sparse.util import (
_broadcasting_vmap, _count_stored_elements,
_csr_to_coo, _dot_general_validated_shape,
SparseInfo, Shape)
import jax.numpy as jnp
from jax._src import api_util
from jax._src.lax.lax import DotDimensionNumbers
from jax.util import split_list, safe_zip
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax._src.typing import Array, ArrayLike, DTypeLike
@@ -297,6 +304,150 @@ def _bcsr_extract_abstract_eval(indices, indptr, mat):
_bcsr_extract_impl, multiple_results=False))


#----------------------------------------------------------------------
# bcsr_dot_general


bcsr_dot_general_p = core.Primitive('bcsr_dot_general')


def bcsr_dot_general(lhs: Union[BCSR, Array], rhs: Array, *,
dimension_numbers: DotDimensionNumbers,
precision: None = None,
preferred_element_type: None = None) -> Array:
"""A general contraction operation.

  Args:
    lhs: An ndarray or BCSR-format sparse array.
    rhs: An ndarray or BCSR-format sparse array.
    dimension_numbers: a tuple of tuples of the form
      `((lhs_contracting_dims, rhs_contracting_dims),
      (lhs_batch_dims, rhs_batch_dims))`.
    precision: unused
    preferred_element_type: unused

  Returns:
    An ndarray or BCSR-format sparse array containing the result. If both
    inputs are sparse, the result will be sparse, of type BCSR. If either
    input is dense, the result will be dense, of type ndarray.
  """
del precision, preferred_element_type # unused
if isinstance(rhs, (np.ndarray, jnp.ndarray)):
if isinstance(lhs, (np.ndarray, jnp.ndarray)):
return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)

if isinstance(lhs, BCSR):
lhs_data, lhs_indices, lhs_indptr = lhs._bufs
return _bcsr_dot_general(lhs_data, lhs_indices, lhs_indptr, rhs,
dimension_numbers=dimension_numbers,
lhs_spinfo=lhs._info)

  raise NotImplementedError("bcsr_dot_general is currently implemented only "
                            "for BCSR lhs and ndarray rhs.")

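# A minimal usage sketch (not part of this diff; assumes BCSR.fromdense,
# which wraps the bcsr_fromdense export above). The dimension_numbers below
# spell out an ordinary (2, 3) @ (3, 2) matmul:
#
#   import jax.numpy as jnp
#   from jax.experimental import sparse
#
#   mat = jnp.array([[1., 0., 2.],
#                    [0., 0., 3.]])
#   mat_sp = sparse.BCSR.fromdense(mat)  # data/indices/indptr buffers
#   rhs = jnp.arange(6., dtype=mat.dtype).reshape(3, 2)
#   dn = (((1,), (0,)), ((), ()))  # contract lhs dim 1 with rhs dim 0
#   out = sparse.bcsr_dot_general(mat_sp, rhs, dimension_numbers=dn)
#   assert jnp.allclose(out, mat @ rhs)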

def _bcsr_dot_general(lhs_data: jnp.ndarray, lhs_indices: jnp.ndarray,
lhs_indptr: jnp.ndarray, rhs: Array, *,
dimension_numbers: DotDimensionNumbers,
lhs_spinfo: SparseInfo) -> Array:
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
cdims = (api_util._ensure_index_tuple(lhs_contract),
api_util._ensure_index_tuple(rhs_contract))
bdims = (api_util._ensure_index_tuple(lhs_batch),
api_util._ensure_index_tuple(rhs_batch))
return bcsr_dot_general_p.bind(jnp.asarray(lhs_data),
jnp.asarray(lhs_indices),
jnp.asarray(lhs_indptr), jnp.asarray(rhs),
dimension_numbers=(cdims, bdims),
lhs_spinfo=lhs_spinfo)
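# The _ensure_index_tuple calls above canonicalize dimension_numbers into
# hashable tuples of Python ints before they become primitive parameters.
# A quick check of that behavior (a sketch using the same private helper):
#
#   import numpy as np
#   from jax._src import api_util
#
#   assert api_util._ensure_index_tuple([np.int64(0), 1]) == (0, 1)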


@bcsr_dot_general_p.def_impl
def _bcsr_dot_general_impl(lhs_data, lhs_indices, lhs_indptr, rhs, *,
dimension_numbers, lhs_spinfo):
lhs_data = jnp.asarray(lhs_data)
lhs_bcsr_indices = jnp.asarray(lhs_indices)
lhs_bcsr_indptr = jnp.asarray(lhs_indptr)
rhs = jnp.asarray(rhs)
lhs_bcoo_indices = _bcsr_to_bcoo(lhs_bcsr_indices, lhs_bcsr_indptr,
shape=lhs_spinfo.shape)
return bcoo._bcoo_dot_general_impl(lhs_data, lhs_bcoo_indices, rhs,
dimension_numbers=dimension_numbers,
lhs_spinfo=lhs_spinfo)
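# The impl's only BCSR-specific step is expanding compressed row offsets
# into explicit row indices before delegating to the BCOO kernel. In the
# unbatched case the expansion reduces to this sketch (a hypothetical
# standalone version; _bcsr_to_bcoo additionally handles batch dimensions):
#
#   import jax.numpy as jnp
#
#   def csr_rows(indptr, nse):
#     # Row i owns stored elements indptr[i]:indptr[i+1]; searchsorted with
#     # side='right' recovers the owning row of each stored position.
#     return jnp.searchsorted(indptr, jnp.arange(nse), side='right') - 1
#
#   indptr = jnp.array([0, 2, 2, 3])  # 3 rows with 2, 0, and 1 elements
#   csr_rows(indptr, nse=3)           # -> [0, 0, 2]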


@bcsr_dot_general_p.def_abstract_eval
def _bcsr_dot_general_abstract_eval(lhs_data, lhs_indices, lhs_indptr, rhs, *,
dimension_numbers, lhs_spinfo):
if lhs_data.dtype != rhs.dtype:
raise ValueError("bcsr_dot_general requires arguments to have matching "
f"dtypes; got lhs.dtype={lhs_data.dtype}, "
f"rhs.dtype={rhs.dtype}")

(lhs_contracting, _), (lhs_batch, _) = dimension_numbers
props = _validate_bcsr_indices(lhs_indices, lhs_indptr, lhs_spinfo.shape)
out_shape = _dot_general_validated_shape(lhs_spinfo.shape, rhs.shape,
dimension_numbers)

if lhs_batch and max(lhs_batch) >= props.n_batch:
    raise NotImplementedError(
        "bcsr_dot_general batch dimensions must be among the batch "
        "dimensions in the sparse representation.\n"
        f"got {lhs_batch=}, {props.n_batch=}")

# TODO: support contraction of dense dimensions?
if any(d >= props.n_batch + 2 for d in lhs_contracting):
raise NotImplementedError("bcsr_dot_general: contracting over dense dimensions.")

return core.ShapedArray(out_shape, lhs_data.dtype)
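# A sketch of the constraint enforced above: batching is only supported over
# dimensions stored as BCSR batch dimensions, and the output shape follows
# the usual dot_general rule (batch dims, then remaining lhs dims, then
# remaining rhs dims). Assumes BCSR.fromdense accepts n_batch, mirroring
# bcsr_fromdense:
#
#   import jax.numpy as jnp
#   from jax.experimental import sparse
#
#   lhs = jnp.zeros((2, 3, 4)).at[:, 0, 1].set(1.)  # two 3x4 matrices
#   lhs_sp = sparse.BCSR.fromdense(lhs, n_batch=1)  # n_batch=1, n_sparse=2
#   rhs = jnp.ones((2, 4, 5))
#   dn = (((2,), (1,)), ((0,), (0,)))  # batch dim 0; contract the 4-dims
#   out = sparse.bcsr_dot_general(lhs_sp, rhs, dimension_numbers=dn)
#   assert out.shape == (2, 3, 5)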


# def _bcsr_dot_general_jvp_lhs(lhs_data_dot, lhs_data, lhs_indices, lhs_indptr,
# rhs, *, dimension_numbers, lhs_spinfo):
# del lhs_data
# return _bcsr_dot_general(lhs_data_dot, lhs_indices, lhs_indptr, rhs,
# dimension_numbers=dimension_numbers,
# lhs_spinfo=lhs_spinfo)


# def _bcsr_dot_general_jvp_rhs(rhs_dot, lhs_data, lhs_indices, lhs_indptr, rhs,
# *, dimension_numbers, lhs_spinfo):
# del rhs
# return _bcsr_dot_general(lhs_data, lhs_indices, lhs_indptr, rhs_dot,
# dimension_numbers=dimension_numbers,
# lhs_spinfo=lhs_spinfo)


# def _bcsr_dot_general_transpose(ct, lhs_data, lhs_indices, lhs_inptr, rhs, *,
# dimension_numbers, lhs_spinfo):
# lhs_bcoo_indices = _bcsr_to_bcoo(
# lhs_indices, lhs_inptr, shape=lhs_spinfo.shape)
# return bcoo._bcoo_dot_general_transpose(
# ct, lhs_data, lhs_bcoo_indices, rhs, dimension_numbers=dimension_numbers,
# lhs_spinfo=lhs_spinfo)


# def _bcsr_dot_general_batch_rule(batched_args, batch_dims, *,
# dimension_numbers, lhs_spinfo):
# lhs_data, lhs_indices, lhs_indptr, rhs = batched_args
# lhs_bcoo_indices = _bcsr_to_bcoo(
# lhs_indices, lhs_indptr, shape=lhs_spinfo.shape)
# return bcoo._bcoo_dot_general_batch_rule(
# (lhs_data, lhs_bcoo_indices, rhs), batch_dims,
# dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)


# ad.defjvp(bcsr_dot_general_p, _bcsr_dot_general_jvp_lhs, None,
# _bcsr_dot_general_jvp_rhs)
# ad.primitive_transposes[bcsr_dot_general_p] = _bcsr_dot_general_transpose
# batching.primitive_batchers[bcsr_dot_general_p] = _bcsr_dot_general_batch_rule


_bcsr_dot_general_default_lowering = mlir.lower_fun(
_bcsr_dot_general_impl, multiple_results=False)
mlir.register_lowering(
bcsr_dot_general_p, _bcsr_dot_general_default_lowering)
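# Because the registered lowering is lower_fun over the impl, the primitive
# already works under jit with no dedicated kernel; a small sketch (toy
# inputs, same public exports as above):
#
#   import jax
#   import jax.numpy as jnp
#   from jax.experimental import sparse
#
#   dn = (((1,), (0,)), ((), ()))
#   f = jax.jit(lambda m, x: sparse.bcsr_dot_general(m, x, dimension_numbers=dn))
#   out = f(sparse.BCSR.fromdense(jnp.eye(3)), jnp.ones((3, 2)))
#   assert out.shape == (3, 2)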


@tree_util.register_pytree_node_class
class BCSR(JAXSparse):
"""Experimental batched CSR matrix implemented in JAX."""
@@ -310,6 +461,8 @@ class BCSR(JAXSparse):
n_batch = property(lambda self: self.indices.ndim - 1)
n_sparse = property(lambda _: 2)
n_dense = property(lambda self: self.data.ndim - self.indices.ndim)
_bufs = property(lambda self: (self.data, self.indices, self.indptr))
_info = property(lambda self: SparseInfo(self.shape))

@property
def _sparse_shape(self):
@@ -345,6 +498,7 @@ def transpose(self, *args, **kwargs):
    raise NotImplementedError("Transpose is not implemented.")

def tree_flatten(self):
# TODO(tianjianlu): Unflatten SparseInfo with self._info._asdict().
return (self.data, self.indices, self.indptr), {'shape': self.shape}

@classmethod
45 changes: 38 additions & 7 deletions tests/sparse_test.py
@@ -74,13 +74,15 @@
[(2,), (2,)]
]

class BcooDotGeneralProperties(NamedTuple):

class BatchedDotGeneralProperties(NamedTuple):
lhs_shape: Tuple[int, ...]
rhs_shape: Tuple[int, ...]
n_batch: int
n_dense: int
dimension_numbers: DotDimensionNumbers


def _iter_subsets(s: Sequence) -> Iterable[Tuple]:
  """Return an iterator over all subsets of a sequence s."""
return itertools.chain.from_iterable(itertools.combinations(s, n) for n in range(len(s) + 1))
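# For concreteness, _iter_subsets enumerates the power set in order of
# increasing subset size, e.g.:
#
#   assert list(_iter_subsets((0, 1))) == [(), (0,), (1,), (0, 1)]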
@@ -99,12 +101,19 @@ def iter_sparse_layouts(shape: Sequence[int], min_n_batch=0) -> Iterator[SparseLayout]:
yield SparseLayout(n_batch=n_batch, n_sparse=n_sparse, n_dense=n_dense)


def _generate_bcoo_dot_general_properties(shapes=((5,), (2, 3), (2, 3, 4), (2, 3, 4, 4))) -> BcooDotGeneralProperties:
def _generate_batched_dot_general_properties(
shapes=((5,), (2, 3), (2, 3, 4), (2, 3, 4, 4)),
    sparse_format='bcoo') -> BatchedDotGeneralProperties:
  """Generator of properties for bcoo_dot_general and bcsr_dot_general tests."""
rng = random.Random(0)

if sparse_format not in ['bcoo', 'bcsr']:
raise ValueError(f"Sparse format {sparse_format} not supported.")

for shape in shapes:
for layout in iter_sparse_layouts(shape):
if sparse_format == "bcsr" and layout.n_sparse != 2:
continue
subsets = split_list(range(len(shape)), [layout.n_batch, layout.n_sparse])
for batch_dims in _iter_subsets(range(layout.n_batch)):
for contracting_dims in _iter_subsets(remaining(range(layout.n_batch + layout.n_sparse), batch_dims)):
Expand All @@ -113,7 +122,7 @@ def _generate_bcoo_dot_general_properties(shapes=((5,), (2, 3), (2, 3, 4), (2, 3
rhs_permute = rng.sample(range(len(shape)), len(shape))
lhs_permute = list(itertools.chain.from_iterable(
rng.sample(subset, len(subset)) for subset in subsets))
yield BcooDotGeneralProperties(
yield BatchedDotGeneralProperties(
lhs_shape=tuple(shape[p] for p in lhs_permute),
rhs_shape=tuple(shape[p] for p in rhs_permute),
n_batch=layout.n_batch,
@@ -141,6 +150,7 @@ def _rand_sparse(shape, dtype, nse=nse):
return post(M)
return _rand_sparse


def _is_required_cuda_version_satisfied(cuda_version):
version = xla_bridge.get_backend().platform_version
if version == "<unknown>" or version.split()[0] == "rocm":
@@ -995,11 +1005,11 @@ def test_bcoo_todense_partial_batch(self, shape, dtype, n_batch, n_dense):
self.assertAllClose(M3, M4)

@jtu.sample_product(
props=_generate_bcoo_dot_general_properties(),
props=_generate_batched_dot_general_properties(),
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
)
@jax.default_matmul_precision("float32")
def test_bcoo_dot_general(self, dtype: np.dtype, props: BcooDotGeneralProperties):
def test_bcoo_dot_general(self, dtype: np.dtype, props: BatchedDotGeneralProperties):
rng = jtu.rand_default(self.rng())
sprng = sptu.rand_bcoo(self.rng(), n_batch=props.n_batch, n_dense=props.n_dense)
args_maker = lambda: [sprng(props.lhs_shape, dtype),
@@ -1204,11 +1214,11 @@ def test_bcoo_dot_general_oob_and_unsorted_indices_cusparse(self):
self.assertArraysEqual(vecmat_expected, vecmat_unsorted_fallback)

@jtu.sample_product(
props=_generate_bcoo_dot_general_properties(),
props=_generate_batched_dot_general_properties(),
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
)
@jax.default_matmul_precision("float32")
def test_bcoo_rdot_general(self, dtype: np.dtype, props: BcooDotGeneralProperties):
def test_bcoo_rdot_general(self, dtype: np.dtype, props: BatchedDotGeneralProperties):
rng = jtu.rand_default(self.rng())
sprng = sptu.rand_bcoo(self.rng(), n_batch=props.n_batch, n_dense=props.n_dense)
args_maker = lambda: [rng(props.rhs_shape, dtype),
@@ -2266,6 +2276,27 @@ def test_bcsr_extract(self, shape, dtype, n_batch):
args_maker_bcsr_extract = lambda: [indices, indptr, M]
self._CompileAndCheck(sparse.bcsr_extract, args_maker_bcsr_extract)

@jtu.sample_product(
props=_generate_batched_dot_general_properties(
shapes=((2, 3), (2, 3, 4), (2, 3, 4, 4)), sparse_format='bcsr'),
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
)
@jax.default_matmul_precision("float32")
def test_bcsr_dot_general(self, dtype: np.dtype, props: BatchedDotGeneralProperties):
rng = jtu.rand_default(self.rng())
sprng = sptu.rand_bcsr(self.rng(), n_batch=props.n_batch, n_dense=props.n_dense)
args_maker = lambda: [sprng(props.lhs_shape, dtype),
rng(props.rhs_shape, dtype)]
dense_fun = partial(lax.dot_general,
dimension_numbers=props.dimension_numbers)
sparse_fun = partial(sparse.bcsr_dot_general,
dimension_numbers=props.dimension_numbers)

tol = {np.float64: 1E-12, np.complex128: 1E-12,
np.float32: 1E-5, np.complex64: 1E-5}

self._CheckAgainstDense(dense_fun, sparse_fun, args_maker, tol=tol)
self._CompileAndCheckSparse(sparse_fun, args_maker, atol=tol, rtol=tol)

class SparseGradTest(sptu.SparseTestCase):
def test_sparse_grad(self):