Skip to content

Commit

Permalink
[sparse] Add BCSR format template.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 477013899
  • Loading branch information
tlu7 authored and jax authors committed Sep 26, 2022
1 parent 82636b0 commit 71bcabe
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 1 deletion.
3 changes: 3 additions & 0 deletions jax/experimental/sparse/__init__.py
Expand Up @@ -217,6 +217,9 @@
BCOO as BCOO,
)

from jax.experimental.sparse.bcsr import (
BCSR as BCSR,
)
from jax.experimental.sparse.api import (
empty as empty,
eye as eye,
Expand Down
3 changes: 2 additions & 1 deletion jax/experimental/sparse/api.py
Expand Up @@ -37,6 +37,7 @@
from jax import tree_util
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.bcoo import BCOO
from jax.experimental.sparse.bcsr import BCSR
from jax.experimental.sparse.coo import COO
from jax.experimental.sparse.csr import CSR, CSC
from jax.experimental.sparse.util import _coo_extract
Expand Down Expand Up @@ -116,7 +117,7 @@ def empty(shape, dtype=None, index_dtype='int32', sparse_format='bcoo', **kwds):
Returns:
mat: empty sparse matrix.
"""
formats = {'bcoo': BCOO, 'coo': COO, 'csr': CSR, 'csc': CSC}
formats = {'bcsr': BCSR, 'bcoo': BCOO, 'coo': COO, 'csr': CSR, 'csc': CSC}
if sparse_format not in formats:
raise ValueError(f"sparse_format={sparse_format!r} not recognized; "
f"must be one of {list(formats.keys())}")
Expand Down
93 changes: 93 additions & 0 deletions jax/experimental/sparse/bcsr.py
@@ -0,0 +1,93 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BCSR (Bached compressed row) matrix object and associated primitives."""

from typing import Tuple

from jax import core
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.util import _safe_asarray
import jax.numpy as jnp
from jax.util import split_list

Shape = Tuple[int, ...]


class BCSR(JAXSparse):
"""Experimental batched CSR matrix implemented in JAX."""

data: jnp.ndarray
indices: jnp.ndarray
indptr: jnp.ndarray
shape: Shape
nse = property(lambda self: self.indices.shape[-1])
dtype = property(lambda self: self.data.dtype)
n_batch = property(lambda self: self.indices.ndim - 1)
n_sparse = property(lambda _: 2)
n_dense = property(lambda self: self.data.ndim - self.indices.ndim)

@property
def _sparse_shape(self):
return tuple(self.shape[self.n_batch:self.n_batch + 2])

def __init__(self, args, *, shape):
# JAX transforms will sometimes instantiate pytrees with null values, so we
# must catch that in the initialization of inputs.
self.data, self.indices, self.indptr = _safe_asarray(args)
super().__init__(args, shape=shape)

def __repr__(self):
name = self.__class__.__name__
try:
nse = self.nse
n_batch = self.n_batch
n_dense = self.n_dense
dtype = self.dtype
shape = list(self.shape)
except Exception: # pylint: disable=broad-except
repr_ = f"{name}(<invalid>)"
else:
extra = f", nse={nse}"
if n_batch: extra += f", n_batch={n_batch}"
if n_dense: extra += f", n_dense={n_dense}"
repr_ = f"{name}({dtype}{shape}{extra})"
if isinstance(self.data, core.Tracer):
repr_ = f"{type(self.data).__name__}[{repr_}]"
return repr_

def transpose(self, *args, **kwargs):
raise NotImplementedError("Tranpose is not implemented.")

def tree_flatten(self):
return (self.data, self.indices, self.indptr), {}

@classmethod
def _empty(cls, shape, *, dtype=None, index_dtype='int32', n_dense=0,
n_batch=0, nse=0):
"""Create an empty BCSR instance. Public method is sparse.empty()."""
shape = tuple(shape)
if n_dense < 0 or n_batch < 0 or nse < 0:
raise ValueError(f"Invalid inputs: shape={shape}, n_dense={n_dense},"
f"n_batch={n_batch}, nse={nse}")
n_sparse = len(shape) - n_dense - n_batch
if n_sparse != 2:
raise ValueError("BCSR sparse.empty: must have 2 sparse dimensions.")
batch_shape, sparse_shape, dense_shape = split_list(shape,
[n_batch, n_sparse])
data = jnp.zeros((*batch_shape, nse, *dense_shape), dtype)
indices = jnp.full((*batch_shape, nse), jnp.array(sparse_shape[1]),
index_dtype)
indptr = jnp.zeros((*batch_shape, sparse_shape[0] + 1), index_dtype)
return cls((data, indices, indptr), shape=shape)
12 changes: 12 additions & 0 deletions tests/sparse_test.py
Expand Up @@ -2275,6 +2275,18 @@ def f(X, y):


class SparseObjectTest(jtu.JaxTestCase):
@parameterized.named_parameters(
{"testcase_name": f"_{cls.__name__}", "cls": cls}
for cls in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO, sparse.BCSR])
def test_pytree_flattening(self, cls):
sparse_format = cls.__name__.lower()
M = sparse.empty((2, 4), sparse_format=sparse_format)
self.assertIsInstance(M, cls)
buffers, tree = tree_util.tree_flatten(M)
M_out = tree_util.tree_unflatten(tree, buffers)
self.assertEqual(M.dtype, M_out.dtype)
self.assertEqual(M.shape, M_out.shape)
self.assertEqual(M.nse, M_out.nse)

@parameterized.named_parameters(
{"testcase_name": f"_{cls.__name__}", "cls": cls}
Expand Down

0 comments on commit 71bcabe

Please sign in to comment.