diff --git a/jax/experimental/sparse/__init__.py b/jax/experimental/sparse/__init__.py index 1887c4cc0dbc..042b40ee0253 100644 --- a/jax/experimental/sparse/__init__.py +++ b/jax/experimental/sparse/__init__.py @@ -217,6 +217,9 @@ BCOO as BCOO, ) +from jax.experimental.sparse.bcsr import ( + BCSR as BCSR, +) from jax.experimental.sparse.api import ( empty as empty, eye as eye, diff --git a/jax/experimental/sparse/api.py b/jax/experimental/sparse/api.py index 49551fadcf8d..8d1a3356adec 100644 --- a/jax/experimental/sparse/api.py +++ b/jax/experimental/sparse/api.py @@ -37,6 +37,7 @@ from jax import tree_util from jax.experimental.sparse._base import JAXSparse from jax.experimental.sparse.bcoo import BCOO +from jax.experimental.sparse.bcsr import BCSR from jax.experimental.sparse.coo import COO from jax.experimental.sparse.csr import CSR, CSC from jax.experimental.sparse.util import _coo_extract @@ -116,7 +117,7 @@ def empty(shape, dtype=None, index_dtype='int32', sparse_format='bcoo', **kwds): Returns: mat: empty sparse matrix. """ - formats = {'bcoo': BCOO, 'coo': COO, 'csr': CSR, 'csc': CSC} + formats = {'bcsr': BCSR, 'bcoo': BCOO, 'coo': COO, 'csr': CSR, 'csc': CSC} if sparse_format not in formats: raise ValueError(f"sparse_format={sparse_format!r} not recognized; " f"must be one of {list(formats.keys())}") diff --git a/jax/experimental/sparse/bcsr.py b/jax/experimental/sparse/bcsr.py new file mode 100644 index 000000000000..f22d36b9c764 --- /dev/null +++ b/jax/experimental/sparse/bcsr.py @@ -0,0 +1,93 @@ +# Copyright 2022 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""BCSR (Bached compressed row) matrix object and associated primitives.""" + +from typing import Tuple + +from jax import core +from jax.experimental.sparse._base import JAXSparse +from jax.experimental.sparse.util import _safe_asarray +import jax.numpy as jnp +from jax.util import split_list + +Shape = Tuple[int, ...] + + +class BCSR(JAXSparse): + """Experimental batched CSR matrix implemented in JAX.""" + + data: jnp.ndarray + indices: jnp.ndarray + indptr: jnp.ndarray + shape: Shape + nse = property(lambda self: self.indices.shape[-1]) + dtype = property(lambda self: self.data.dtype) + n_batch = property(lambda self: self.indices.ndim - 1) + n_sparse = property(lambda _: 2) + n_dense = property(lambda self: self.data.ndim - self.indices.ndim) + + @property + def _sparse_shape(self): + return tuple(self.shape[self.n_batch:self.n_batch + 2]) + + def __init__(self, args, *, shape): + # JAX transforms will sometimes instantiate pytrees with null values, so we + # must catch that in the initialization of inputs. + self.data, self.indices, self.indptr = _safe_asarray(args) + super().__init__(args, shape=shape) + + def __repr__(self): + name = self.__class__.__name__ + try: + nse = self.nse + n_batch = self.n_batch + n_dense = self.n_dense + dtype = self.dtype + shape = list(self.shape) + except Exception: # pylint: disable=broad-except + repr_ = f"{name}()" + else: + extra = f", nse={nse}" + if n_batch: extra += f", n_batch={n_batch}" + if n_dense: extra += f", n_dense={n_dense}" + repr_ = f"{name}({dtype}{shape}{extra})" + if isinstance(self.data, core.Tracer): + repr_ = f"{type(self.data).__name__}[{repr_}]" + return repr_ + + def transpose(self, *args, **kwargs): + raise NotImplementedError("Tranpose is not implemented.") + + def tree_flatten(self): + return (self.data, self.indices, self.indptr), {} + + @classmethod + def _empty(cls, shape, *, dtype=None, index_dtype='int32', n_dense=0, + n_batch=0, nse=0): + """Create an empty BCSR instance. Public method is sparse.empty().""" + shape = tuple(shape) + if n_dense < 0 or n_batch < 0 or nse < 0: + raise ValueError(f"Invalid inputs: shape={shape}, n_dense={n_dense}," + f"n_batch={n_batch}, nse={nse}") + n_sparse = len(shape) - n_dense - n_batch + if n_sparse != 2: + raise ValueError("BCSR sparse.empty: must have 2 sparse dimensions.") + batch_shape, sparse_shape, dense_shape = split_list(shape, + [n_batch, n_sparse]) + data = jnp.zeros((*batch_shape, nse, *dense_shape), dtype) + indices = jnp.full((*batch_shape, nse), jnp.array(sparse_shape[1]), + index_dtype) + indptr = jnp.zeros((*batch_shape, sparse_shape[0] + 1), index_dtype) + return cls((data, indices, indptr), shape=shape) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 5fa672ed85d7..c38568c051e2 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -2275,6 +2275,18 @@ def f(X, y): class SparseObjectTest(jtu.JaxTestCase): + @parameterized.named_parameters( + {"testcase_name": f"_{cls.__name__}", "cls": cls} + for cls in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO, sparse.BCSR]) + def test_pytree_flattening(self, cls): + sparse_format = cls.__name__.lower() + M = sparse.empty((2, 4), sparse_format=sparse_format) + self.assertIsInstance(M, cls) + buffers, tree = tree_util.tree_flatten(M) + M_out = tree_util.tree_unflatten(tree, buffers) + self.assertEqual(M.dtype, M_out.dtype) + self.assertEqual(M.shape, M_out.shape) + self.assertEqual(M.nse, M_out.nse) @parameterized.named_parameters( {"testcase_name": f"_{cls.__name__}", "cls": cls}