[sparse] Compatiable jax>=0.9.2, decouple SparseMatrix from JAX and drop block-sparse impls#65
[sparse] Compatiable jax>=0.9.2, decouple SparseMatrix from JAX and drop block-sparse impls#65chaoming0625 merged 3 commits intomainfrom
jax>=0.9.2, decouple SparseMatrix from JAX and drop block-sparse impls#65Conversation
…nce and adding new properties and methods
… checks in COO and CSR classes
Reviewer's GuideRefactors the SparseMatrix base away from JAX’s JAXSparse, adds core container/utility behavior, updates CSR/CSC/COO implementations to interoperate with the new base and improve validation and tree_unflatten, and removes unused block-sparse modules while bumping the saiunit package version. Class diagram for updated SparseMatrix base and sparse formatsclassDiagram
class SparseMatrix {
+jax.Array data
+tuple~int~ shape
+property nse
+property dtype
+__hash__
+SparseMatrix(args, shape)
+__len__() int
+size int
+ndim int
+__repr__() str
+T SparseMatrix
+block_until_ready() SparseMatrix
+tree_flatten()
+tree_unflatten(aux_data, children) SparseMatrix
+transpose(axes)
+todense()
+with_data(data)
}
class CSR {
+jax.Array data
+jax.Array indices
+jax.Array indptr
+tuple~int~ shape
+with_data(data) CSR
+todense()
+_binary_op(other, op)
+_binary_rop(other, op)
+__matmul__(other)
+__rmatmul__(other)
+tree_unflatten(aux_data, children) CSR
}
class CSC {
+jax.Array data
+jax.Array indices
+jax.Array indptr
+tuple~int~ shape
+with_data(data) CSC
+todense()
+_binary_op(other, op)
+_binary_rop(other, op)
+__matmul__(other)
+__rmatmul__(other)
+tree_unflatten(aux_data, children) CSC
}
class COO {
+jax.Array data
+jax.Array row
+jax.Array col
+tuple~int~ shape
+with_data(data) COO
+todense()
+_binary_op(other, op)
+_binary_rop(other, op)
+__matmul__(other)
+__rmatmul__(other)
}
class JAXSparse
SparseMatrix <|-- CSR
SparseMatrix <|-- CSC
SparseMatrix <|-- COO
CSR ..> JAXSparse : runtime_checks
CSC ..> JAXSparse : runtime_checks
COO ..> JAXSparse : runtime_checks
CSR ..> SparseMatrix : runtime_checks
CSC ..> SparseMatrix : runtime_checks
COO ..> SparseMatrix : runtime_checks
Flow diagram for with_data validation in sparse matricesflowchart TD
A[call with_data on sparse matrix] --> B{shape matches self.data.shape}
B -- no --> E[raise ValueError Shape mismatch]
B -- yes --> C{dtype matches self.data.dtype}
C -- no --> F[raise ValueError Dtype mismatch]
C -- yes --> D{unit matches self.data unit}
D -- no --> G[raise ValueError Unit mismatch]
D -- yes --> H[construct new instance with new data and existing indices and structure]
H --> I[return new sparse matrix instance]
File-Level Changes
Tips and commandsInteracting with Sourcery
Customizing Your ExperienceAccess your dashboard to:
Getting Help
|
|
@sourcery-ai title |
There was a problem hiding this comment.
Hey - I've found 4 issues, and left some high level feedback:
- In
SparseMatrix.__init__, theargstuple is accepted but never used anddata/nse/dtypeare not initialized; consider either making this class explicitly abstract for those attributes or wiringargsinto the base initialization so subclasses have a consistent contract. SparseMatrix.block_until_readyassumes every child fromtree_flatten()has ablock_until_readymethod; if non-JAX arrays or scalars can appear there, you may want to guard this withhasattror a narrower expectation.
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- In `SparseMatrix.__init__`, the `args` tuple is accepted but never used and `data`/`nse`/`dtype` are not initialized; consider either making this class explicitly abstract for those attributes or wiring `args` into the base initialization so subclasses have a consistent contract.
- `SparseMatrix.block_until_ready` assumes every child from `tree_flatten()` has a `block_until_ready` method; if non-JAX arrays or scalars can appear there, you may want to guard this with `hasattr` or a narrower expectation.
## Individual Comments
### Comment 1
<location path="saiunit/_sparse_base.py" line_range="65-70" />
<code_context>
+ data: jax.Array
+ shape: tuple[int, ...]
+ nse: property
+ dtype: property
+
+ __hash__ = None
</code_context>
<issue_to_address>
**suggestion:** Replace these property annotations with explicit abstract @property accessors for better typing and subclass contracts.
Using `nse: property` / `dtype: property` suggests the attributes themselves are `property` objects, which is confusing for both readers and type checkers. Since these are part of the base interface, please define them as abstract-style properties instead, for example:
```python
@property
def nse(self) -> int:
raise NotImplementedError
@property
def dtype(self): # or jax.typing.DTypeLike
raise NotImplementedError
```
This keeps the intended interface while giving subclasses a clear contract and better typing/tooling support.
```suggestion
data: jax.Array
shape: tuple[int, ...]
__hash__ = None
@property
def nse(self) -> int:
"""Number of stored (non-zero) elements in the sparse array.
Subclasses must override this to return the actual count.
"""
raise NotImplementedError
@property
def dtype(self):
"""Data type of the sparse array values.
Subclasses must override this to return the concrete dtype.
"""
raise NotImplementedError
```
</issue_to_address>
### Comment 2
<location path="saiunit/_sparse_base.py" line_range="72-78" />
<code_context>
+
+ def __init__(
+ self,
+ args: tuple[jax.Array, ...],
+ *,
+ shape: Sequence[int]
+ ):
</code_context>
<issue_to_address>
**suggestion:** Either use or drop the unused `args` parameter in `SparseMatrix.__init__`.
If subclasses are meant to use `args`, consider validating its structure (e.g., non-empty, expected length) or documenting the expected contents. If not, it’s clearer to remove `args` from the base `__init__` and let subclasses define their own parameters to avoid confusion and misuse.
```suggestion
def __init__(
self,
args: tuple[jax.Array, ...],
*,
shape: Sequence[int]
):
"""Base initializer for sparse matrices.
Parameters
----------
args:
Positional array arguments used by concrete sparse formats
(e.g. data, indices, indptr). Subclasses may rely on this tuple.
shape:
Overall shape of the sparse matrix.
"""
if not args:
raise ValueError(
"SparseMatrix.__init__ expected at least one array in `args` "
"for use by subclasses."
)
self._args = args
self.shape = tuple(int(s) for s in shape)
```
</issue_to_address>
### Comment 3
<location path="saiunit/_sparse_base.py" line_range="91-101" />
<code_context>
+
+ def __repr__(self):
+ name = self.__class__.__name__
+ try:
+ nse = self.nse
+ dtype = self.dtype
+ shape = list(self.shape)
+ except Exception:
+ repr_ = f"{name}(<invalid>)"
+ else:
</code_context>
<issue_to_address>
**suggestion (bug_risk):** Narrow the exception handling in `__repr__` to avoid masking unrelated errors.
Catching a bare `Exception` here will turn any error (including real bugs) into `ClassName(<invalid>)`, hiding issues and complicating debugging. Limit the `except` to the specific failures you expect (e.g. `AttributeError` or a small set of known exceptions) so unexpected errors still surface.
```python
try:
nse = self.nse
dtype = self.dtype
shape = list(self.shape)
except AttributeError:
repr_ = f"{name}(<invalid>)"
else:
repr_ = f"{name}({dtype}{shape}, {nse=})"
```
```suggestion
def __repr__(self):
name = self.__class__.__name__
try:
nse = self.nse
dtype = self.dtype
shape = list(self.shape)
except AttributeError:
repr_ = f"{name}(<invalid>)"
else:
repr_ = f"{name}({dtype}{shape}, {nse=})"
return repr_
```
</issue_to_address>
### Comment 4
<location path="saiunit/sparse/_csr.py" line_range="186-191" />
<code_context>
- assert data.shape == self.data.shape
- assert data.dtype == self.data.dtype
- assert get_unit(data) == get_unit(self.data)
+ if data.shape != self.data.shape:
+ raise ValueError(f"Shape mismatch: expected {self.data.shape}, got {data.shape}")
+ if data.dtype != self.data.dtype:
+ raise ValueError(f"Dtype mismatch: expected {self.data.dtype}, got {data.dtype}")
+ if get_unit(data) != get_unit(self.data):
+ raise ValueError(f"Unit mismatch: expected {get_unit(self.data)}, got {get_unit(data)}")
return COO(
(data, self.row, self.col),
</code_context>
<issue_to_address>
**suggestion:** Consider factoring out the repeated shape/dtype/unit validation logic used by CSR/CSC/COO.with_data.
The explicit `ValueError` checks are a nice improvement over asserts, but the same shape/dtype/unit validation and error text are now repeated across CSR/CSC/COO. Extracting this into a shared helper (e.g. `_validate_new_data(self.data, data)` on a common base or utility) would reduce duplication and keep future validation changes consistent.
</issue_to_address>Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
| data: jax.Array | ||
| shape: tuple[int, ...] | ||
| nse: property | ||
| dtype: property | ||
|
|
||
| __hash__ = None |
There was a problem hiding this comment.
suggestion: Replace these property annotations with explicit abstract @Property accessors for better typing and subclass contracts.
Using nse: property / dtype: property suggests the attributes themselves are property objects, which is confusing for both readers and type checkers. Since these are part of the base interface, please define them as abstract-style properties instead, for example:
@property
def nse(self) -> int:
raise NotImplementedError
@property
def dtype(self): # or jax.typing.DTypeLike
raise NotImplementedErrorThis keeps the intended interface while giving subclasses a clear contract and better typing/tooling support.
| data: jax.Array | |
| shape: tuple[int, ...] | |
| nse: property | |
| dtype: property | |
| __hash__ = None | |
| data: jax.Array | |
| shape: tuple[int, ...] | |
| __hash__ = None | |
| @property | |
| def nse(self) -> int: | |
| """Number of stored (non-zero) elements in the sparse array. | |
| Subclasses must override this to return the actual count. | |
| """ | |
| raise NotImplementedError | |
| @property | |
| def dtype(self): | |
| """Data type of the sparse array values. | |
| Subclasses must override this to return the concrete dtype. | |
| """ | |
| raise NotImplementedError |
| def __init__( | ||
| self, | ||
| args: tuple[jax.Array, ...], | ||
| *, | ||
| shape: Sequence[int] | ||
| ): | ||
| self.shape = tuple(int(s) for s in shape) |
There was a problem hiding this comment.
suggestion: Either use or drop the unused args parameter in SparseMatrix.__init__.
If subclasses are meant to use args, consider validating its structure (e.g., non-empty, expected length) or documenting the expected contents. If not, it’s clearer to remove args from the base __init__ and let subclasses define their own parameters to avoid confusion and misuse.
| def __init__( | |
| self, | |
| args: tuple[jax.Array, ...], | |
| *, | |
| shape: Sequence[int] | |
| ): | |
| self.shape = tuple(int(s) for s in shape) | |
| def __init__( | |
| self, | |
| args: tuple[jax.Array, ...], | |
| *, | |
| shape: Sequence[int] | |
| ): | |
| """Base initializer for sparse matrices. | |
| Parameters | |
| ---------- | |
| args: | |
| Positional array arguments used by concrete sparse formats | |
| (e.g. data, indices, indptr). Subclasses may rely on this tuple. | |
| shape: | |
| Overall shape of the sparse matrix. | |
| """ | |
| if not args: | |
| raise ValueError( | |
| "SparseMatrix.__init__ expected at least one array in `args` " | |
| "for use by subclasses." | |
| ) | |
| self._args = args | |
| self.shape = tuple(int(s) for s in shape) |
| def __repr__(self): | ||
| name = self.__class__.__name__ | ||
| try: | ||
| nse = self.nse | ||
| dtype = self.dtype | ||
| shape = list(self.shape) | ||
| except Exception: | ||
| repr_ = f"{name}(<invalid>)" | ||
| else: | ||
| repr_ = f"{name}({dtype}{shape}, {nse=})" | ||
| return repr_ |
There was a problem hiding this comment.
suggestion (bug_risk): Narrow the exception handling in __repr__ to avoid masking unrelated errors.
Catching a bare Exception here will turn any error (including real bugs) into ClassName(<invalid>), hiding issues and complicating debugging. Limit the except to the specific failures you expect (e.g. AttributeError or a small set of known exceptions) so unexpected errors still surface.
try:
nse = self.nse
dtype = self.dtype
shape = list(self.shape)
except AttributeError:
repr_ = f"{name}(<invalid>)"
else:
repr_ = f"{name}({dtype}{shape}, {nse=})"| def __repr__(self): | |
| name = self.__class__.__name__ | |
| try: | |
| nse = self.nse | |
| dtype = self.dtype | |
| shape = list(self.shape) | |
| except Exception: | |
| repr_ = f"{name}(<invalid>)" | |
| else: | |
| repr_ = f"{name}({dtype}{shape}, {nse=})" | |
| return repr_ | |
| def __repr__(self): | |
| name = self.__class__.__name__ | |
| try: | |
| nse = self.nse | |
| dtype = self.dtype | |
| shape = list(self.shape) | |
| except AttributeError: | |
| repr_ = f"{name}(<invalid>)" | |
| else: | |
| repr_ = f"{name}({dtype}{shape}, {nse=})" | |
| return repr_ |
| if data.shape != self.data.shape: | ||
| raise ValueError(f"Shape mismatch: expected {self.data.shape}, got {data.shape}") | ||
| if data.dtype != self.data.dtype: | ||
| raise ValueError(f"Dtype mismatch: expected {self.data.dtype}, got {data.dtype}") | ||
| if get_unit(data) != get_unit(self.data): | ||
| raise ValueError(f"Unit mismatch: expected {get_unit(self.data)}, got {get_unit(data)}") |
There was a problem hiding this comment.
suggestion: Consider factoring out the repeated shape/dtype/unit validation logic used by CSR/CSC/COO.with_data.
The explicit ValueError checks are a nice improvement over asserts, but the same shape/dtype/unit validation and error text are now repeated across CSR/CSC/COO. Extracting this into a shared helper (e.g. _validate_new_data(self.data, data) on a common base or utility) would reduce duplication and keep future validation changes consistent.
jax>=0.9.2 Decouple SparseMatrix from JAX and drop block-sparse impls
jax>=0.9.2 Decouple SparseMatrix from JAX and drop block-sparse implsjax>=0.9.2, decouple SparseMatrix from JAX and drop block-sparse impls
Summary by Sourcery
Decouple the internal sparse matrix base class from JAX's JAXSparse, enhance its core interface, and tighten validation and interoperability across CSR, CSC, and COO sparse formats while removing unused block-sparse implementations.
Enhancements:
Chores: