From df358242ff7f533fa8dc609486da8f2e7e7646e9 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 16 Feb 2023 16:27:31 -0800 Subject: [PATCH] [sparse] test coo/csr extra nse --- jax/experimental/sparse/coo.py | 32 ++++++++++++++++++-------- tests/sparse_test.py | 41 +++++++++++++++++++++++++++++++--- 2 files changed, 61 insertions(+), 12 deletions(-) diff --git a/jax/experimental/sparse/coo.py b/jax/experimental/sparse/coo.py index 6a6de2783c3d..0331e839584d 100644 --- a/jax/experimental/sparse/coo.py +++ b/jax/experimental/sparse/coo.py @@ -45,6 +45,7 @@ class COOInfo(NamedTuple): shape: Shape rows_sorted: bool = False cols_sorted: bool = False + padded: bool = False @tree_util.register_pytree_node_class @@ -63,16 +64,19 @@ class COO(JAXSparse): dtype = property(lambda self: self.data.dtype) _info = property(lambda self: COOInfo( shape=self.shape, rows_sorted=self._rows_sorted, - cols_sorted=self._cols_sorted)) + cols_sorted=self._cols_sorted, padded=self._padded)) _bufs = property(lambda self: (self.data, self.row, self.col)) _rows_sorted: bool _cols_sorted: bool + _padded: bool def __init__(self, args: Tuple[Array, Array, Array], *, shape: Shape, - rows_sorted: bool = False, cols_sorted: bool = False): + rows_sorted: bool = False, cols_sorted: bool = False, + padded: bool = True): self.data, self.row, self.col = map(jnp.asarray, args) self._rows_sorted = rows_sorted self._cols_sorted = cols_sorted + self._padded = padded super().__init__(args, shape=shape) @classmethod @@ -131,7 +135,7 @@ def transpose(self, axes: Optional[Tuple[int, ...]] = None) -> COO: if axes is not None: raise NotImplementedError("axes argument to transpose()") return COO((self.data, self.col, self.row), shape=self.shape[::-1], - rows_sorted=self._cols_sorted, cols_sorted=self._rows_sorted) + rows_sorted=self._cols_sorted, cols_sorted=self._rows_sorted, padded=self._padded) def tree_flatten(self) -> Tuple[Tuple[Array, Array, Array], Dict[str, Any]]: return (self.data, self.row, self.col), self._info._asdict() @@ -140,11 +144,12 @@ def tree_flatten(self) -> Tuple[Tuple[Array, Array, Array], Dict[str, Any]]: def tree_unflatten(cls, aux_data, children): obj = object.__new__(cls) obj.data, obj.row, obj.col = children - if aux_data.keys() != {'shape', 'rows_sorted', 'cols_sorted'}: + if aux_data.keys() != {'shape', 'rows_sorted', 'cols_sorted', 'padded'}: raise ValueError(f"COO.tree_unflatten: invalid {aux_data=}") obj.shape = aux_data['shape'] obj._rows_sorted = aux_data['rows_sorted'] obj._cols_sorted = aux_data['cols_sorted'] + obj._padded = aux_data['padded'] return obj def __matmul__(self, other: ArrayLike) -> Array: @@ -207,6 +212,9 @@ def _coo_todense_gpu_lowering(coo_todense_hlo, ctx, data, row, col, *, spinfo): warnings.warn(f"coo_todense cusparse/hipsparse lowering not available for {dtype=}. " "Falling back to default implementation.", CuSparseEfficiencyWarning) return _coo_todense_lowering(ctx, data, row, col, spinfo=spinfo) + if spinfo.padded: + # GPU rule returns incorrect results with padded representation. + return _coo_todense_lowering(ctx, data, row, col, spinfo=spinfo) if spinfo.rows_sorted: shape = spinfo.shape @@ -274,11 +282,12 @@ def coo_fromdense(mat: Array, *, nse: Optional[int] = None, index_dtype: DTypeLi Returns: mat_coo : COO representation of the matrix. """ + padded = nse is not None if nse is None: nse = int((mat != 0).sum()) nse_int = core.concrete_or_error(operator.index, nse, "coo_fromdense nse argument") return COO(_coo_fromdense(mat, nse=nse_int, index_dtype=index_dtype), - shape=mat.shape, rows_sorted=True) + shape=mat.shape, rows_sorted=True, padded=padded) def _coo_fromdense(mat: Array, *, nse: int, index_dtype: DTypeLike = jnp.int32) -> Tuple[Array, Array, Array]: """Create COO-format sparse matrix from a dense matrix. @@ -446,8 +455,10 @@ def _coo_matvec_gpu_lowering(coo_matvec_hlo, ctx, data, row, col, v, *, spinfo, if dtype not in [np.float32, np.float64, np.complex64, np.complex128]: warnings.warn(f"coo_matvec cusparse/hipsparse lowering not available for {dtype=}. " "Falling back to default implementation.", CuSparseEfficiencyWarning) - return _coo_matvec_lowering(ctx, data, row, col, v, spinfo=spinfo, - transpose=transpose) + return _coo_matvec_lowering(ctx, data, row, col, v, spinfo=spinfo, transpose=transpose) + if spinfo.padded: + # GPU rule returns incorrect results with padded representation. + return _coo_matvec_lowering(ctx, data, row, col, v, spinfo=spinfo, transpose=transpose) if spinfo.rows_sorted: shape = spinfo.shape @@ -569,8 +580,11 @@ def _coo_matmat_gpu_lowering(coo_matmat_hlo, ctx, data, row, col, B, *, spinfo, if dtype not in [np.float32, np.float64, np.complex64, np.complex128]: warnings.warn(f"coo_matmat cusparse/hipsprse lowering not available for {dtype=}. " "Falling back to default implementation.", CuSparseEfficiencyWarning) - return _coo_matmat_lowering(ctx, data, row, col, B, spinfo=spinfo, - transpose=transpose) + return _coo_matmat_lowering(ctx, data, row, col, B, spinfo=spinfo, transpose=transpose) + if spinfo.padded: + # GPU rule returns incorrect results with padded representation. + return _coo_matmat_lowering(ctx, data, row, col, B, spinfo=spinfo, transpose=transpose) + if spinfo.rows_sorted: shape = spinfo.shape elif spinfo.cols_sorted: diff --git a/tests/sparse_test.py b/tests/sparse_test.py index c704d6671df9..2a2a3fcd54b8 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -508,7 +508,7 @@ def test_coo_sorted_indices_gpu_lowerings(self): self.assertFalse(mat_cols_sorted._rows_sorted) self.assertTrue(mat_cols_sorted._cols_sorted) - mat_unsorted = sparse.COO(mat_rows_sorted._bufs, shape=mat_rows_sorted.shape) + mat_unsorted = sparse.COO(mat_rows_sorted._bufs, shape=mat_rows_sorted.shape, padded=False) self.assertFalse(mat_unsorted._rows_sorted) self.assertFalse(mat_unsorted._cols_sorted) @@ -582,10 +582,45 @@ def test_gpu_translation_rule(self): ) def test_extra_nse(self, shape, dtype, mat_type): rng = rand_sparse(self.rng()) + rng_dense = jtu.rand_default(self.rng()) M = rng(shape, dtype) nse = (M != 0).sum() + 5 - M_out = mat_type.fromdense(M, nse=nse, index_dtype=jnp.int32).todense() - self.assertArraysEqual(M, M_out) + M_sp = mat_type.fromdense(M, nse=nse) + + with self.subTest("todense"): + def todense1(M, _): + assert isinstance(M, np.ndarray) + return M + def todense2(_, M): + assert isinstance(M, mat_type) + return M.todense() + args_maker = lambda: [M, M_sp] + self._CheckAgainstNumpy(todense1, todense2, args_maker) + self._CompileAndCheck(todense2, args_maker) + + with self.subTest("matvec"): + v = rng_dense(M.shape[-1:], dtype) + args_maker = lambda: [M, M_sp, v] + def matvec1(M, _, v): + assert isinstance(M, np.ndarray) + return M @ v + def matvec2(_, M, v): + assert isinstance(M, mat_type) + return M @ v + self._CheckAgainstNumpy(matvec1, matvec2, args_maker) + self._CompileAndCheck(matvec2, args_maker) + + with self.subTest("matmat"): + B = rng_dense(M.shape[::-1], dtype) + args_maker = lambda: [M, M_sp, B] + def matmat1(M, _, B): + assert isinstance(M, np.ndarray) + return M @ B + def matmat2(_, M, B): + assert isinstance(M, mat_type) + return M @ B + self._CheckAgainstNumpy(matmat1, matmat2, args_maker) + self._CompileAndCheck(matmat2, args_maker) @jtu.sample_product( shape=[(5, 8), (8, 5), (5, 5), (8, 8)],