Skip to content

[sparse] Compatiable jax>=0.9.2, decouple SparseMatrix from JAX and drop block-sparse impls#65

Merged
chaoming0625 merged 3 commits intomainfrom
update
Mar 20, 2026
Merged

[sparse] Compatiable jax>=0.9.2, decouple SparseMatrix from JAX and drop block-sparse impls#65
chaoming0625 merged 3 commits intomainfrom
update

Conversation

@chaoming0625
Copy link
Copy Markdown
Member

@chaoming0625 chaoming0625 commented Mar 20, 2026

Summary by Sourcery

Decouple the internal sparse matrix base class from JAX's JAXSparse, enhance its core interface, and tighten validation and interoperability across CSR, CSC, and COO sparse formats while removing unused block-sparse implementations.

Enhancements:

  • Redefine SparseMatrix as a lightweight, framework-independent base class with core shape, repr, transpose, and JAX PyTree-related hooks.
  • Strengthen CSR, CSC, and COO with_data methods to raise informative errors on shape, dtype, and unit mismatches instead of using bare assertions.
  • Extend binary and matmul operation guards to handle interactions with both JAXSparse and the new SparseMatrix base class.
  • Adjust CSR and CSC PyTree tree_unflatten logic to reconstruct key attributes explicitly rather than via dict updates.

Chores:

  • Remove legacy block sparse matrix implementations and benchmarks from both brainunit and saiunit packages.
  • Bump saiunit package version from 0.2.0 to 0.2.1.

@sourcery-ai
Copy link
Copy Markdown
Contributor

sourcery-ai bot commented Mar 20, 2026

Reviewer's Guide

Refactors the SparseMatrix base away from JAX’s JAXSparse, adds core container/utility behavior, updates CSR/CSC/COO implementations to interoperate with the new base and improve validation and tree_unflatten, and removes unused block-sparse modules while bumping the saiunit package version.

Class diagram for updated SparseMatrix base and sparse formats

classDiagram
    class SparseMatrix {
        +jax.Array data
        +tuple~int~ shape
        +property nse
        +property dtype
        +__hash__
        +SparseMatrix(args, shape)
        +__len__() int
        +size int
        +ndim int
        +__repr__() str
        +T SparseMatrix
        +block_until_ready() SparseMatrix
        +tree_flatten()
        +tree_unflatten(aux_data, children) SparseMatrix
        +transpose(axes)
        +todense()
        +with_data(data)
    }

    class CSR {
        +jax.Array data
        +jax.Array indices
        +jax.Array indptr
        +tuple~int~ shape
        +with_data(data) CSR
        +todense()
        +_binary_op(other, op)
        +_binary_rop(other, op)
        +__matmul__(other)
        +__rmatmul__(other)
        +tree_unflatten(aux_data, children) CSR
    }

    class CSC {
        +jax.Array data
        +jax.Array indices
        +jax.Array indptr
        +tuple~int~ shape
        +with_data(data) CSC
        +todense()
        +_binary_op(other, op)
        +_binary_rop(other, op)
        +__matmul__(other)
        +__rmatmul__(other)
        +tree_unflatten(aux_data, children) CSC
    }

    class COO {
        +jax.Array data
        +jax.Array row
        +jax.Array col
        +tuple~int~ shape
        +with_data(data) COO
        +todense()
        +_binary_op(other, op)
        +_binary_rop(other, op)
        +__matmul__(other)
        +__rmatmul__(other)
    }

    class JAXSparse

    SparseMatrix <|-- CSR
    SparseMatrix <|-- CSC
    SparseMatrix <|-- COO

    CSR ..> JAXSparse : runtime_checks
    CSC ..> JAXSparse : runtime_checks
    COO ..> JAXSparse : runtime_checks

    CSR ..> SparseMatrix : runtime_checks
    CSC ..> SparseMatrix : runtime_checks
    COO ..> SparseMatrix : runtime_checks
Loading

Flow diagram for with_data validation in sparse matrices

flowchart TD
    A[call with_data on sparse matrix] --> B{shape matches self.data.shape}
    B -- no --> E[raise ValueError Shape mismatch]
    B -- yes --> C{dtype matches self.data.dtype}
    C -- no --> F[raise ValueError Dtype mismatch]
    C -- yes --> D{unit matches self.data unit}
    D -- no --> G[raise ValueError Unit mismatch]
    D -- yes --> H[construct new instance with new data and existing indices and structure]
    H --> I[return new sparse matrix instance]
Loading

File-Level Changes

Change Details Files
Refactor SparseMatrix to be a lightweight base class independent of jax.experimental.sparse.JAXSparse, adding core attributes, properties, and utility methods.
  • Remove inheritance from JAXSparse and ABC, making SparseMatrix a plain Python base class
  • Document data, shape, nse, and dtype as class attributes and set hash = None to make instances unhashable
  • Implement init to accept args and shape, storing shape as a tuple[int, ...]
  • Add len, size, ndim, and T convenience methods plus a human-readable repr that uses nse, dtype, and shape when available
  • Introduce block_until_ready delegating to children from tree_flatten, and define abstract tree_flatten, tree_unflatten, transpose, and todense methods that raise NotImplementedError
saiunit/_sparse_base.py
Strengthen validation in with_data for CSR/CSC/COO and adjust sparse binary/matmul operations to treat SparseMatrix like JAXSparse for unsupported-sparse-sparse operations.
  • Replace assertion-based checks in with_data with explicit shape, dtype, and unit validations that raise ValueError with descriptive messages
  • Extend isinstance checks in _binary_op, _binary_rop, matmul, and rmatmul to treat both JAXSparse and SparseMatrix as sparse objects that are not yet supported for binary or matmul operations
  • Update CSR and CSC tree_unflatten to validate aux_data keys, fix error message name for CSC, and assign shape/indices/indptr explicitly instead of updating dict wholesale
saiunit/sparse/_csr.py
saiunit/sparse/_csc.py
saiunit/sparse/_coo.py
Clean up block-sparse experimental modules and bump saiunit version.
  • Remove block CSR/ELL implementations, benchmarks, and tests from both brainunit and saiunit sparse subpackages
  • Increment version from 0.2.0 to 0.2.1 to reflect the API/behavior changes
brainunit/brainunit/sparse/_block_csr.py
brainunit/brainunit/sparse/_block_ell.py
saiunit/sparse/_block_csr.py
saiunit/sparse/_block_csr_benchmark.py
saiunit/sparse/_block_ell.py
saiunit/sparse/_block_ell_benchmark.py
saiunit/sparse/_block_sparse_test.py
saiunit/_version.py

Tips and commands

Interacting with Sourcery

  • Trigger a new review: Comment @sourcery-ai review on the pull request.
  • Continue discussions: Reply directly to Sourcery's review comments.
  • Generate a GitHub issue from a review comment: Ask Sourcery to create an
    issue from a review comment by replying to it. You can also reply to a
    review comment with @sourcery-ai issue to create an issue from it.
  • Generate a pull request title: Write @sourcery-ai anywhere in the pull
    request title to generate a title at any time. You can also comment
    @sourcery-ai title on the pull request to (re-)generate the title at any time.
  • Generate a pull request summary: Write @sourcery-ai summary anywhere in
    the pull request body to generate a PR summary at any time exactly where you
    want it. You can also comment @sourcery-ai summary on the pull request to
    (re-)generate the summary at any time.
  • Generate reviewer's guide: Comment @sourcery-ai guide on the pull
    request to (re-)generate the reviewer's guide at any time.
  • Resolve all Sourcery comments: Comment @sourcery-ai resolve on the
    pull request to resolve all Sourcery comments. Useful if you've already
    addressed all the comments and don't want to see them anymore.
  • Dismiss all Sourcery reviews: Comment @sourcery-ai dismiss on the pull
    request to dismiss all existing Sourcery reviews. Especially useful if you
    want to start fresh with a new review - don't forget to comment
    @sourcery-ai review to trigger a new review!

Customizing Your Experience

Access your dashboard to:

  • Enable or disable review features such as the Sourcery-generated pull request
    summary, the reviewer's guide, and others.
  • Change the review language.
  • Add, remove or edit custom review instructions.
  • Adjust other review settings.

Getting Help

@chaoming0625
Copy link
Copy Markdown
Member Author

@sourcery-ai title

@sourcery-ai sourcery-ai bot changed the title Update [sparse] Decouple SparseMatrix from JAX and drop block-sparse impls Mar 20, 2026
Copy link
Copy Markdown
Contributor

@sourcery-ai sourcery-ai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey - I've found 4 issues, and left some high level feedback:

  • In SparseMatrix.__init__, the args tuple is accepted but never used and data/nse/dtype are not initialized; consider either making this class explicitly abstract for those attributes or wiring args into the base initialization so subclasses have a consistent contract.
  • SparseMatrix.block_until_ready assumes every child from tree_flatten() has a block_until_ready method; if non-JAX arrays or scalars can appear there, you may want to guard this with hasattr or a narrower expectation.
Prompt for AI Agents
Please address the comments from this code review:

## Overall Comments
- In `SparseMatrix.__init__`, the `args` tuple is accepted but never used and `data`/`nse`/`dtype` are not initialized; consider either making this class explicitly abstract for those attributes or wiring `args` into the base initialization so subclasses have a consistent contract.
- `SparseMatrix.block_until_ready` assumes every child from `tree_flatten()` has a `block_until_ready` method; if non-JAX arrays or scalars can appear there, you may want to guard this with `hasattr` or a narrower expectation.

## Individual Comments

### Comment 1
<location path="saiunit/_sparse_base.py" line_range="65-70" />
<code_context>

+    data: jax.Array
+    shape: tuple[int, ...]
+    nse: property
+    dtype: property
+
+    __hash__ = None
</code_context>
<issue_to_address>
**suggestion:** Replace these property annotations with explicit abstract @property accessors for better typing and subclass contracts.

Using `nse: property` / `dtype: property` suggests the attributes themselves are `property` objects, which is confusing for both readers and type checkers. Since these are part of the base interface, please define them as abstract-style properties instead, for example:

```python
@property
def nse(self) -> int:
    raise NotImplementedError

@property
def dtype(self):  # or jax.typing.DTypeLike
    raise NotImplementedError
```

This keeps the intended interface while giving subclasses a clear contract and better typing/tooling support.

```suggestion
    data: jax.Array
    shape: tuple[int, ...]

    __hash__ = None

    @property
    def nse(self) -> int:
        """Number of stored (non-zero) elements in the sparse array.

        Subclasses must override this to return the actual count.
        """
        raise NotImplementedError

    @property
    def dtype(self):
        """Data type of the sparse array values.

        Subclasses must override this to return the concrete dtype.
        """
        raise NotImplementedError
```
</issue_to_address>

### Comment 2
<location path="saiunit/_sparse_base.py" line_range="72-78" />
<code_context>
+
+    def __init__(
+        self,
+        args: tuple[jax.Array, ...],
+        *,
+        shape: Sequence[int]
+    ):
</code_context>
<issue_to_address>
**suggestion:** Either use or drop the unused `args` parameter in `SparseMatrix.__init__`.

If subclasses are meant to use `args`, consider validating its structure (e.g., non-empty, expected length) or documenting the expected contents. If not, it’s clearer to remove `args` from the base `__init__` and let subclasses define their own parameters to avoid confusion and misuse.

```suggestion
    def __init__(
        self,
        args: tuple[jax.Array, ...],
        *,
        shape: Sequence[int]
    ):
        """Base initializer for sparse matrices.

        Parameters
        ----------
        args:
            Positional array arguments used by concrete sparse formats
            (e.g. data, indices, indptr). Subclasses may rely on this tuple.
        shape:
            Overall shape of the sparse matrix.
        """
        if not args:
            raise ValueError(
                "SparseMatrix.__init__ expected at least one array in `args` "
                "for use by subclasses."
            )
        self._args = args
        self.shape = tuple(int(s) for s in shape)
```
</issue_to_address>

### Comment 3
<location path="saiunit/_sparse_base.py" line_range="91-101" />
<code_context>
+
+    def __repr__(self):
+        name = self.__class__.__name__
+        try:
+            nse = self.nse
+            dtype = self.dtype
+            shape = list(self.shape)
+        except Exception:
+            repr_ = f"{name}(<invalid>)"
+        else:
</code_context>
<issue_to_address>
**suggestion (bug_risk):** Narrow the exception handling in `__repr__` to avoid masking unrelated errors.

Catching a bare `Exception` here will turn any error (including real bugs) into `ClassName(<invalid>)`, hiding issues and complicating debugging. Limit the `except` to the specific failures you expect (e.g. `AttributeError` or a small set of known exceptions) so unexpected errors still surface.

```python
try:
    nse = self.nse
    dtype = self.dtype
    shape = list(self.shape)
except AttributeError:
    repr_ = f"{name}(<invalid>)"
else:
    repr_ = f"{name}({dtype}{shape}, {nse=})"
```

```suggestion
    def __repr__(self):
        name = self.__class__.__name__
        try:
            nse = self.nse
            dtype = self.dtype
            shape = list(self.shape)
        except AttributeError:
            repr_ = f"{name}(<invalid>)"
        else:
            repr_ = f"{name}({dtype}{shape}, {nse=})"
        return repr_
```
</issue_to_address>

### Comment 4
<location path="saiunit/sparse/_csr.py" line_range="186-191" />
<code_context>
-        assert data.shape == self.data.shape
-        assert data.dtype == self.data.dtype
-        assert get_unit(data) == get_unit(self.data)
+        if data.shape != self.data.shape:
+            raise ValueError(f"Shape mismatch: expected {self.data.shape}, got {data.shape}")
+        if data.dtype != self.data.dtype:
+            raise ValueError(f"Dtype mismatch: expected {self.data.dtype}, got {data.dtype}")
+        if get_unit(data) != get_unit(self.data):
+            raise ValueError(f"Unit mismatch: expected {get_unit(self.data)}, got {get_unit(data)}")
         return COO(
             (data, self.row, self.col),
</code_context>
<issue_to_address>
**suggestion:** Consider factoring out the repeated shape/dtype/unit validation logic used by CSR/CSC/COO.with_data.

The explicit `ValueError` checks are a nice improvement over asserts, but the same shape/dtype/unit validation and error text are now repeated across CSR/CSC/COO. Extracting this into a shared helper (e.g. `_validate_new_data(self.data, data)` on a common base or utility) would reduce duplication and keep future validation changes consistent.
</issue_to_address>

Sourcery is free for open source - if you like our reviews please consider sharing them ✨
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.

Comment thread saiunit/_sparse_base.py
Comment on lines +65 to +70
data: jax.Array
shape: tuple[int, ...]
nse: property
dtype: property

__hash__ = None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Replace these property annotations with explicit abstract @Property accessors for better typing and subclass contracts.

Using nse: property / dtype: property suggests the attributes themselves are property objects, which is confusing for both readers and type checkers. Since these are part of the base interface, please define them as abstract-style properties instead, for example:

@property
def nse(self) -> int:
    raise NotImplementedError

@property
def dtype(self):  # or jax.typing.DTypeLike
    raise NotImplementedError

This keeps the intended interface while giving subclasses a clear contract and better typing/tooling support.

Suggested change
data: jax.Array
shape: tuple[int, ...]
nse: property
dtype: property
__hash__ = None
data: jax.Array
shape: tuple[int, ...]
__hash__ = None
@property
def nse(self) -> int:
"""Number of stored (non-zero) elements in the sparse array.
Subclasses must override this to return the actual count.
"""
raise NotImplementedError
@property
def dtype(self):
"""Data type of the sparse array values.
Subclasses must override this to return the concrete dtype.
"""
raise NotImplementedError

Comment thread saiunit/_sparse_base.py
Comment on lines +72 to +78
def __init__(
self,
args: tuple[jax.Array, ...],
*,
shape: Sequence[int]
):
self.shape = tuple(int(s) for s in shape)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Either use or drop the unused args parameter in SparseMatrix.__init__.

If subclasses are meant to use args, consider validating its structure (e.g., non-empty, expected length) or documenting the expected contents. If not, it’s clearer to remove args from the base __init__ and let subclasses define their own parameters to avoid confusion and misuse.

Suggested change
def __init__(
self,
args: tuple[jax.Array, ...],
*,
shape: Sequence[int]
):
self.shape = tuple(int(s) for s in shape)
def __init__(
self,
args: tuple[jax.Array, ...],
*,
shape: Sequence[int]
):
"""Base initializer for sparse matrices.
Parameters
----------
args:
Positional array arguments used by concrete sparse formats
(e.g. data, indices, indptr). Subclasses may rely on this tuple.
shape:
Overall shape of the sparse matrix.
"""
if not args:
raise ValueError(
"SparseMatrix.__init__ expected at least one array in `args` "
"for use by subclasses."
)
self._args = args
self.shape = tuple(int(s) for s in shape)

Comment thread saiunit/_sparse_base.py
Comment on lines +91 to +101
def __repr__(self):
name = self.__class__.__name__
try:
nse = self.nse
dtype = self.dtype
shape = list(self.shape)
except Exception:
repr_ = f"{name}(<invalid>)"
else:
repr_ = f"{name}({dtype}{shape}, {nse=})"
return repr_
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (bug_risk): Narrow the exception handling in __repr__ to avoid masking unrelated errors.

Catching a bare Exception here will turn any error (including real bugs) into ClassName(<invalid>), hiding issues and complicating debugging. Limit the except to the specific failures you expect (e.g. AttributeError or a small set of known exceptions) so unexpected errors still surface.

try:
    nse = self.nse
    dtype = self.dtype
    shape = list(self.shape)
except AttributeError:
    repr_ = f"{name}(<invalid>)"
else:
    repr_ = f"{name}({dtype}{shape}, {nse=})"
Suggested change
def __repr__(self):
name = self.__class__.__name__
try:
nse = self.nse
dtype = self.dtype
shape = list(self.shape)
except Exception:
repr_ = f"{name}(<invalid>)"
else:
repr_ = f"{name}({dtype}{shape}, {nse=})"
return repr_
def __repr__(self):
name = self.__class__.__name__
try:
nse = self.nse
dtype = self.dtype
shape = list(self.shape)
except AttributeError:
repr_ = f"{name}(<invalid>)"
else:
repr_ = f"{name}({dtype}{shape}, {nse=})"
return repr_

Comment thread saiunit/sparse/_csr.py
Comment on lines +186 to +191
if data.shape != self.data.shape:
raise ValueError(f"Shape mismatch: expected {self.data.shape}, got {data.shape}")
if data.dtype != self.data.dtype:
raise ValueError(f"Dtype mismatch: expected {self.data.dtype}, got {data.dtype}")
if get_unit(data) != get_unit(self.data):
raise ValueError(f"Unit mismatch: expected {get_unit(self.data)}, got {get_unit(data)}")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Consider factoring out the repeated shape/dtype/unit validation logic used by CSR/CSC/COO.with_data.

The explicit ValueError checks are a nice improvement over asserts, but the same shape/dtype/unit validation and error text are now repeated across CSR/CSC/COO. Extracting this into a shared helper (e.g. _validate_new_data(self.data, data) on a common base or utility) would reduce duplication and keep future validation changes consistent.

@chaoming0625 chaoming0625 changed the title [sparse] Decouple SparseMatrix from JAX and drop block-sparse impls [sparse] Compatiable jax>=0.9.2 Decouple SparseMatrix from JAX and drop block-sparse impls Mar 20, 2026
@chaoming0625 chaoming0625 changed the title [sparse] Compatiable jax>=0.9.2 Decouple SparseMatrix from JAX and drop block-sparse impls [sparse] Compatiable jax>=0.9.2, decouple SparseMatrix from JAX and drop block-sparse impls Mar 20, 2026
@chaoming0625 chaoming0625 merged commit 5501b81 into main Mar 20, 2026
4 checks passed
@chaoming0625 chaoming0625 deleted the update branch March 20, 2026 06:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant