Skip to content

Commit

Permalink
[sparse] incremental improvement to docs
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed May 4, 2022
1 parent 8c6f916 commit 9e371c7
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 11 deletions.
14 changes: 14 additions & 0 deletions docs/jax.experimental.sparse.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,17 @@ API

BCOO
sparsify
bcoo_broadcast_in_dim
bcoo_concatenate
bcoo_dot_general
bcoo_dot_general_sampled
bcoo_extract
bcoo_fromdense
bcoo_multiply_dense
bcoo_multiply_sparse
bcoo_reduce_sum
bcoo_reshape
bcoo_sort_indices
bcoo_sum_duplicates
bcoo_todense
bcoo_transpose
3 changes: 2 additions & 1 deletion jax/experimental/sparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
-----------------------------------------
The main high-level sparse object currently available in JAX is the :class:`BCOO`,
or *batched coordinate* sparse array, which offers a compressed storage format compatible
with JAX functions.
with JAX transformations, in particular JIT (e.g. :func:`jax.jit`), batching
(e.g. :func:`jax.vmap`) and autodiff (e.g. :func:`jax.grad`).
Here is an example of creating a sparse array from a dense array:
Expand Down
16 changes: 8 additions & 8 deletions jax/experimental/sparse/bcoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1619,14 +1619,14 @@ def bcoo_reshape(mat, *, new_sizes, dimensions):
"""Sparse implementation of {func}`jax.lax.reshape`.
Args:
operand: BCOO array to be reshaped.
new_sizes: sequence of integers specifying the resulting shape. The size
of the final array must match the size of the input. This must be specified
such that batch, sparse, and dense dimensions do not mix.
dimensions: optional sequence of integers specifying the permutation order of
the input shape. If specified, the length must match ``operand.shape``.
Additionally, dimensions must only permute among like dimensions of mat:
batch, sparse, and dense dimensions cannot be permuted.
operand: BCOO array to be reshaped.
new_sizes: sequence of integers specifying the resulting shape. The size
of the final array must match the size of the input. This must be specified
such that batch, sparse, and dense dimensions do not mix.
dimensions: optional sequence of integers specifying the permutation order of
the input shape. If specified, the length must match ``operand.shape``.
Additionally, dimensions must only permute among like dimensions of mat:
batch, sparse, and dense dimensions cannot be permuted.
Returns:
out: reshaped array.
Expand Down
7 changes: 6 additions & 1 deletion jax/experimental/sparse/coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,12 @@ class COOInfo(NamedTuple):

@tree_util.register_pytree_node_class
class COO(JAXSparse):
"""Experimental COO matrix implemented in JAX; API subject to change."""
"""Experimental COO matrix implemented in JAX.
Note: this class has minimal compatibility with JAX transforms such as
grad and autodiff, and offers very little functionality. In general you
should prefer :class:`jax.experimental.sparse.BCOO`.
"""
data: jnp.ndarray
row: jnp.ndarray
col: jnp.ndarray
Expand Down
7 changes: 6 additions & 1 deletion jax/experimental/sparse/csr.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@

@tree_util.register_pytree_node_class
class CSR(JAXSparse):
"""Experimental CSR matrix implemented in JAX; API subject to change."""
"""Experimental CSR matrix implemented in JAX.
Note: this class has minimal compatibility with JAX transforms such as
grad and autodiff, and offers very little functionality. In general you
should prefer :class:`jax.experimental.sparse.BCOO`.
"""
data: jnp.ndarray
indices: jnp.ndarray
indptr: jnp.ndarray
Expand Down

0 comments on commit 9e371c7

Please sign in to comment.