Skip to content

Commit

Permalink
[sparse] validate BCOO on instantiation
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Dec 12, 2022
1 parent b868cf7 commit e9cc523
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 16 deletions.
3 changes: 2 additions & 1 deletion jax/experimental/sparse/_base.py
Expand Up @@ -62,8 +62,9 @@ def tree_flatten(self):
...

@classmethod
@abc.abstractmethod
def tree_unflatten(cls, aux_data, children):
return cls(children, **aux_data)
...

@abc.abstractmethod
def transpose(self, axes=None):
Expand Down
14 changes: 12 additions & 2 deletions jax/experimental/sparse/bcoo.py
Expand Up @@ -31,7 +31,7 @@
from jax.config import config
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.util import (
_broadcasting_vmap, _count_stored_elements, _safe_asarray,
_broadcasting_vmap, _count_stored_elements,
_dot_general_validated_shape, CuSparseEfficiencyWarning,
SparseEfficiencyError, SparseEfficiencyWarning)
from jax.interpreters import batching
Expand Down Expand Up @@ -2386,10 +2386,11 @@ def __init__(self, args: Tuple[Array, Array], *, shape: Sequence[int],
indices_sorted: bool = False, unique_indices: bool = False):
# JAX transforms will sometimes instantiate pytrees with null values, so we
# must catch that in the initialization of inputs.
self.data, self.indices = _safe_asarray(args) # type: ignore[assignment]
self.data, self.indices = map(jnp.asarray, args)
self.indices_sorted = indices_sorted
self.unique_indices = unique_indices
super().__init__(args, shape=tuple(shape))
_validate_bcoo(self.data, self.indices, self.shape)

def __repr__(self):
name = self.__class__.__name__
Expand Down Expand Up @@ -2582,6 +2583,15 @@ def transpose(self, axes: Optional[Sequence[int]] = None) -> BCOO:
def tree_flatten(self):
return (self.data, self.indices), self._info._asdict()

@classmethod
def tree_unflatten(cls, aux_data, children):
obj = object.__new__(cls)
obj.data, obj.indices = children
if aux_data.keys() != {'shape', 'indices_sorted', 'unique_indices'}:
raise ValueError(f"BCOO.tree_unflatten: invalid {aux_data=}")
obj.__dict__.update(**aux_data)
return obj


# vmappable handlers
def _bcoo_to_elt(cont, _, val, axis):
Expand Down
14 changes: 12 additions & 2 deletions jax/experimental/sparse/bcsr.py
Expand Up @@ -25,7 +25,7 @@
from jax import tree_util
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse import bcoo
from jax.experimental.sparse.util import _broadcasting_vmap, _count_stored_elements, _csr_to_coo, _safe_asarray
from jax.experimental.sparse.util import _broadcasting_vmap, _count_stored_elements, _csr_to_coo
import jax.numpy as jnp
from jax.util import split_list, safe_zip
from jax.interpreters import batching
Expand Down Expand Up @@ -320,8 +320,9 @@ def _sparse_shape(self):
def __init__(self, args, *, shape):
# JAX transforms will sometimes instantiate pytrees with null values, so we
# must catch that in the initialization of inputs.
self.data, self.indices, self.indptr = _safe_asarray(args)
self.data, self.indices, self.indptr = map(jnp.asarray, args)
super().__init__(args, shape=shape)
_validate_bcsr(self.data, self.indices, self.indptr, self.shape)

def __repr__(self):
name = self.__class__.__name__
Expand All @@ -348,6 +349,15 @@ def transpose(self, *args, **kwargs):
def tree_flatten(self):
return (self.data, self.indices, self.indptr), {'shape': self.shape}

@classmethod
def tree_unflatten(cls, aux_data, children):
obj = object.__new__(cls)
obj.data, obj.indices, obj.indptr = children
if aux_data.keys() != {'shape'}:
raise ValueError(f"BCSR.tree_unflatten: invalid {aux_data=}")
obj.__dict__.update(**aux_data)
return obj

@classmethod
def _empty(cls, shape, *, dtype=None, index_dtype='int32', n_dense=0,
n_batch=0, nse=0):
Expand Down
15 changes: 13 additions & 2 deletions jax/experimental/sparse/coo.py
Expand Up @@ -27,7 +27,7 @@
from jax.interpreters import ad
from jax.interpreters import mlir
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.util import _coo_extract, _safe_asarray, CuSparseEfficiencyWarning
from jax.experimental.sparse.util import _coo_extract, CuSparseEfficiencyWarning
from jax import tree_util
from jax._src.lax.lax import _const
from jax._src.lib.mlir.dialects import mhlo
Expand Down Expand Up @@ -69,7 +69,7 @@ class COO(JAXSparse):

def __init__(self, args: Tuple[Array, Array, Array], *, shape: Shape,
rows_sorted: bool = False, cols_sorted: bool = False):
self.data, self.row, self.col = _safe_asarray(args) # type: ignore[assignment]
self.data, self.row, self.col = map(jnp.asarray, args)
self._rows_sorted = rows_sorted
self._cols_sorted = cols_sorted
super().__init__(args, shape=shape)
Expand Down Expand Up @@ -135,6 +135,17 @@ def transpose(self, axes: Optional[Tuple[int, ...]] = None) -> COO:
def tree_flatten(self) -> Tuple[Tuple[Array, Array, Array], Dict[str, Any]]:
return (self.data, self.row, self.col), self._info._asdict()

@classmethod
def tree_unflatten(cls, aux_data, children):
obj = object.__new__(cls)
obj.data, obj.row, obj.col = children
if aux_data.keys() != {'shape', 'rows_sorted', 'cols_sorted'}:
raise ValueError(f"COO.tree_unflatten: invalid {aux_data=}")
obj.shape = aux_data['shape']
obj._rows_sorted = aux_data['rows_sorted']
obj._cols_sorted = aux_data['cols_sorted']
return obj

def __matmul__(self, other: ArrayLike) -> Array:
if isinstance(other, JAXSparse):
raise NotImplementedError("matmul between two sparse objects.")
Expand Down
24 changes: 21 additions & 3 deletions jax/experimental/sparse/csr.py
Expand Up @@ -26,7 +26,7 @@
from jax.interpreters import mlir
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.coo import _coo_matmat, _coo_matvec, _coo_todense, COOInfo
from jax.experimental.sparse.util import _csr_to_coo, _csr_extract, _safe_asarray, CuSparseEfficiencyWarning
from jax.experimental.sparse.util import _csr_to_coo, _csr_extract, CuSparseEfficiencyWarning
from jax import lax
from jax import tree_util
from jax._src.lax.lax import _const
Expand All @@ -51,7 +51,7 @@ class CSR(JAXSparse):
dtype = property(lambda self: self.data.dtype)

def __init__(self, args, *, shape):
self.data, self.indices, self.indptr = _safe_asarray(args)
self.data, self.indices, self.indptr = map(jnp.asarray, args)
super().__init__(args, shape=shape)

@classmethod
Expand Down Expand Up @@ -116,6 +116,15 @@ def __matmul__(self, other):
def tree_flatten(self):
return (self.data, self.indices, self.indptr), {"shape": self.shape}

@classmethod
def tree_unflatten(cls, aux_data, children):
obj = object.__new__(cls)
obj.data, obj.indices, obj.indptr = children
if aux_data.keys() != {'shape'}:
raise ValueError(f"CSR.tree_unflatten: invalid {aux_data=}")
obj.__dict__.update(**aux_data)
return obj


@tree_util.register_pytree_node_class
class CSC(JAXSparse):
Expand All @@ -128,7 +137,7 @@ class CSC(JAXSparse):
dtype = property(lambda self: self.data.dtype)

def __init__(self, args, *, shape):
self.data, self.indices, self.indptr = _safe_asarray(args)
self.data, self.indices, self.indptr = map(jnp.asarray, args)
super().__init__(args, shape=shape)

@classmethod
Expand Down Expand Up @@ -174,6 +183,15 @@ def __matmul__(self, other):
def tree_flatten(self):
return (self.data, self.indices, self.indptr), {"shape": self.shape}

@classmethod
def tree_unflatten(cls, aux_data, children):
obj = object.__new__(cls)
obj.data, obj.indices, obj.indptr = children
if aux_data.keys() != {'shape'}:
raise ValueError(f"CSC.tree_unflatten: invalid {aux_data=}")
obj.__dict__.update(**aux_data)
return obj


#--------------------------------------------------------------------
# csr_todense
Expand Down
5 changes: 0 additions & 5 deletions jax/experimental/sparse/util.py
Expand Up @@ -107,11 +107,6 @@ def _is_aval(*args: Any) -> bool:
def _is_arginfo(*args: Any) -> bool:
return all(isinstance(arg, stages.ArgInfo) for arg in args)

def _safe_asarray(args: Sequence[Any]) -> Iterable[Union[np.ndarray, Array]]:
if _is_pytree_placeholder(*args) or _is_aval(*args) or _is_arginfo(*args):
return args
return map(_asarray_or_float0, args)

def _dot_general_validated_shape(
lhs_shape: Tuple[int, ...], rhs_shape: Tuple[int, ...],
dimension_numbers: DotDimensionNumbers) -> Tuple[int, ...]:
Expand Down
3 changes: 2 additions & 1 deletion tests/sparse_test.py
Expand Up @@ -692,7 +692,8 @@ def test_repr(self):
y = sparse.BCOO.fromdense(jnp.arange(6, dtype='float32').reshape(2, 3), n_batch=1, n_dense=1)
self.assertEqual(repr(y), "BCOO(float32[2, 3], nse=1, n_batch=1, n_dense=1)")

M_invalid = sparse.BCOO(([], []), shape=(100,))
M_invalid = sparse.BCOO.fromdense(jnp.arange(6, dtype='float32').reshape(2, 3))
M_invalid.indices = jnp.array([])
self.assertEqual(repr(M_invalid), "BCOO(<invalid>)")

@jit
Expand Down

0 comments on commit e9cc523

Please sign in to comment.