From e9cc5238730146bf900b2fdb7b5f027caaad99bb Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 12 Dec 2022 11:39:06 -0800 Subject: [PATCH] [sparse] validate BCOO on instantiation --- jax/experimental/sparse/_base.py | 3 ++- jax/experimental/sparse/bcoo.py | 14 ++++++++++++-- jax/experimental/sparse/bcsr.py | 14 ++++++++++++-- jax/experimental/sparse/coo.py | 15 +++++++++++++-- jax/experimental/sparse/csr.py | 24 +++++++++++++++++++++--- jax/experimental/sparse/util.py | 5 ----- tests/sparse_test.py | 3 ++- 7 files changed, 62 insertions(+), 16 deletions(-) diff --git a/jax/experimental/sparse/_base.py b/jax/experimental/sparse/_base.py index 0ba83699bc09..67a918c59a90 100644 --- a/jax/experimental/sparse/_base.py +++ b/jax/experimental/sparse/_base.py @@ -62,8 +62,9 @@ def tree_flatten(self): ... @classmethod + @abc.abstractmethod def tree_unflatten(cls, aux_data, children): - return cls(children, **aux_data) + ... @abc.abstractmethod def transpose(self, axes=None): diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index 1ee609242442..7b4e720aa4d1 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -31,7 +31,7 @@ from jax.config import config from jax.experimental.sparse._base import JAXSparse from jax.experimental.sparse.util import ( - _broadcasting_vmap, _count_stored_elements, _safe_asarray, + _broadcasting_vmap, _count_stored_elements, _dot_general_validated_shape, CuSparseEfficiencyWarning, SparseEfficiencyError, SparseEfficiencyWarning) from jax.interpreters import batching @@ -2386,10 +2386,11 @@ def __init__(self, args: Tuple[Array, Array], *, shape: Sequence[int], indices_sorted: bool = False, unique_indices: bool = False): # JAX transforms will sometimes instantiate pytrees with null values, so we # must catch that in the initialization of inputs. - self.data, self.indices = _safe_asarray(args) # type: ignore[assignment] + self.data, self.indices = map(jnp.asarray, args) self.indices_sorted = indices_sorted self.unique_indices = unique_indices super().__init__(args, shape=tuple(shape)) + _validate_bcoo(self.data, self.indices, self.shape) def __repr__(self): name = self.__class__.__name__ @@ -2582,6 +2583,15 @@ def transpose(self, axes: Optional[Sequence[int]] = None) -> BCOO: def tree_flatten(self): return (self.data, self.indices), self._info._asdict() + @classmethod + def tree_unflatten(cls, aux_data, children): + obj = object.__new__(cls) + obj.data, obj.indices = children + if aux_data.keys() != {'shape', 'indices_sorted', 'unique_indices'}: + raise ValueError(f"BCOO.tree_unflatten: invalid {aux_data=}") + obj.__dict__.update(**aux_data) + return obj + # vmappable handlers def _bcoo_to_elt(cont, _, val, axis): diff --git a/jax/experimental/sparse/bcsr.py b/jax/experimental/sparse/bcsr.py index d8b40598d3a0..663efcfabf16 100644 --- a/jax/experimental/sparse/bcsr.py +++ b/jax/experimental/sparse/bcsr.py @@ -25,7 +25,7 @@ from jax import tree_util from jax.experimental.sparse._base import JAXSparse from jax.experimental.sparse import bcoo -from jax.experimental.sparse.util import _broadcasting_vmap, _count_stored_elements, _csr_to_coo, _safe_asarray +from jax.experimental.sparse.util import _broadcasting_vmap, _count_stored_elements, _csr_to_coo import jax.numpy as jnp from jax.util import split_list, safe_zip from jax.interpreters import batching @@ -320,8 +320,9 @@ def _sparse_shape(self): def __init__(self, args, *, shape): # JAX transforms will sometimes instantiate pytrees with null values, so we # must catch that in the initialization of inputs. - self.data, self.indices, self.indptr = _safe_asarray(args) + self.data, self.indices, self.indptr = map(jnp.asarray, args) super().__init__(args, shape=shape) + _validate_bcsr(self.data, self.indices, self.indptr, self.shape) def __repr__(self): name = self.__class__.__name__ @@ -348,6 +349,15 @@ def transpose(self, *args, **kwargs): def tree_flatten(self): return (self.data, self.indices, self.indptr), {'shape': self.shape} + @classmethod + def tree_unflatten(cls, aux_data, children): + obj = object.__new__(cls) + obj.data, obj.indices, obj.indptr = children + if aux_data.keys() != {'shape'}: + raise ValueError(f"BCSR.tree_unflatten: invalid {aux_data=}") + obj.__dict__.update(**aux_data) + return obj + @classmethod def _empty(cls, shape, *, dtype=None, index_dtype='int32', n_dense=0, n_batch=0, nse=0): diff --git a/jax/experimental/sparse/coo.py b/jax/experimental/sparse/coo.py index 65e9199cf18b..6778cd1f9454 100644 --- a/jax/experimental/sparse/coo.py +++ b/jax/experimental/sparse/coo.py @@ -27,7 +27,7 @@ from jax.interpreters import ad from jax.interpreters import mlir from jax.experimental.sparse._base import JAXSparse -from jax.experimental.sparse.util import _coo_extract, _safe_asarray, CuSparseEfficiencyWarning +from jax.experimental.sparse.util import _coo_extract, CuSparseEfficiencyWarning from jax import tree_util from jax._src.lax.lax import _const from jax._src.lib.mlir.dialects import mhlo @@ -69,7 +69,7 @@ class COO(JAXSparse): def __init__(self, args: Tuple[Array, Array, Array], *, shape: Shape, rows_sorted: bool = False, cols_sorted: bool = False): - self.data, self.row, self.col = _safe_asarray(args) # type: ignore[assignment] + self.data, self.row, self.col = map(jnp.asarray, args) self._rows_sorted = rows_sorted self._cols_sorted = cols_sorted super().__init__(args, shape=shape) @@ -135,6 +135,17 @@ def transpose(self, axes: Optional[Tuple[int, ...]] = None) -> COO: def tree_flatten(self) -> Tuple[Tuple[Array, Array, Array], Dict[str, Any]]: return (self.data, self.row, self.col), self._info._asdict() + @classmethod + def tree_unflatten(cls, aux_data, children): + obj = object.__new__(cls) + obj.data, obj.row, obj.col = children + if aux_data.keys() != {'shape', 'rows_sorted', 'cols_sorted'}: + raise ValueError(f"COO.tree_unflatten: invalid {aux_data=}") + obj.shape = aux_data['shape'] + obj._rows_sorted = aux_data['rows_sorted'] + obj._cols_sorted = aux_data['cols_sorted'] + return obj + def __matmul__(self, other: ArrayLike) -> Array: if isinstance(other, JAXSparse): raise NotImplementedError("matmul between two sparse objects.") diff --git a/jax/experimental/sparse/csr.py b/jax/experimental/sparse/csr.py index 9db6c5c3dd2b..3ea81b5532bd 100644 --- a/jax/experimental/sparse/csr.py +++ b/jax/experimental/sparse/csr.py @@ -26,7 +26,7 @@ from jax.interpreters import mlir from jax.experimental.sparse._base import JAXSparse from jax.experimental.sparse.coo import _coo_matmat, _coo_matvec, _coo_todense, COOInfo -from jax.experimental.sparse.util import _csr_to_coo, _csr_extract, _safe_asarray, CuSparseEfficiencyWarning +from jax.experimental.sparse.util import _csr_to_coo, _csr_extract, CuSparseEfficiencyWarning from jax import lax from jax import tree_util from jax._src.lax.lax import _const @@ -51,7 +51,7 @@ class CSR(JAXSparse): dtype = property(lambda self: self.data.dtype) def __init__(self, args, *, shape): - self.data, self.indices, self.indptr = _safe_asarray(args) + self.data, self.indices, self.indptr = map(jnp.asarray, args) super().__init__(args, shape=shape) @classmethod @@ -116,6 +116,15 @@ def __matmul__(self, other): def tree_flatten(self): return (self.data, self.indices, self.indptr), {"shape": self.shape} + @classmethod + def tree_unflatten(cls, aux_data, children): + obj = object.__new__(cls) + obj.data, obj.indices, obj.indptr = children + if aux_data.keys() != {'shape'}: + raise ValueError(f"CSR.tree_unflatten: invalid {aux_data=}") + obj.__dict__.update(**aux_data) + return obj + @tree_util.register_pytree_node_class class CSC(JAXSparse): @@ -128,7 +137,7 @@ class CSC(JAXSparse): dtype = property(lambda self: self.data.dtype) def __init__(self, args, *, shape): - self.data, self.indices, self.indptr = _safe_asarray(args) + self.data, self.indices, self.indptr = map(jnp.asarray, args) super().__init__(args, shape=shape) @classmethod @@ -174,6 +183,15 @@ def __matmul__(self, other): def tree_flatten(self): return (self.data, self.indices, self.indptr), {"shape": self.shape} + @classmethod + def tree_unflatten(cls, aux_data, children): + obj = object.__new__(cls) + obj.data, obj.indices, obj.indptr = children + if aux_data.keys() != {'shape'}: + raise ValueError(f"CSC.tree_unflatten: invalid {aux_data=}") + obj.__dict__.update(**aux_data) + return obj + #-------------------------------------------------------------------- # csr_todense diff --git a/jax/experimental/sparse/util.py b/jax/experimental/sparse/util.py index 6c4f883ba9c3..81680eea1353 100644 --- a/jax/experimental/sparse/util.py +++ b/jax/experimental/sparse/util.py @@ -107,11 +107,6 @@ def _is_aval(*args: Any) -> bool: def _is_arginfo(*args: Any) -> bool: return all(isinstance(arg, stages.ArgInfo) for arg in args) -def _safe_asarray(args: Sequence[Any]) -> Iterable[Union[np.ndarray, Array]]: - if _is_pytree_placeholder(*args) or _is_aval(*args) or _is_arginfo(*args): - return args - return map(_asarray_or_float0, args) - def _dot_general_validated_shape( lhs_shape: Tuple[int, ...], rhs_shape: Tuple[int, ...], dimension_numbers: DotDimensionNumbers) -> Tuple[int, ...]: diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 212cff547b79..debc8058d309 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -692,7 +692,8 @@ def test_repr(self): y = sparse.BCOO.fromdense(jnp.arange(6, dtype='float32').reshape(2, 3), n_batch=1, n_dense=1) self.assertEqual(repr(y), "BCOO(float32[2, 3], nse=1, n_batch=1, n_dense=1)") - M_invalid = sparse.BCOO(([], []), shape=(100,)) + M_invalid = sparse.BCOO.fromdense(jnp.arange(6, dtype='float32').reshape(2, 3)) + M_invalid.indices = jnp.array([]) self.assertEqual(repr(M_invalid), "BCOO()") @jit