Skip to content

Commit

Permalink
[sparse] add bcoo_add_batchdim
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Apr 6, 2022
1 parent 4012267 commit 93a24f3
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 0 deletions.
1 change: 1 addition & 0 deletions jax/experimental/sparse/__init__.py
Expand Up @@ -188,6 +188,7 @@
value_and_grad as value_and_grad,
)
from jax.experimental.sparse.bcoo import (
bcoo_add_batch_dim as bcoo_add_batch_dim,
bcoo_broadcast_in_dim as bcoo_broadcast_in_dim,
bcoo_dot_general as bcoo_dot_general,
bcoo_dot_general_p as bcoo_dot_general_p,
Expand Down
39 changes: 39 additions & 0 deletions jax/experimental/sparse/bcoo.py
Expand Up @@ -1228,6 +1228,45 @@ def _bcoo_spdot_general_jvp(primals, tangents, **kwds):

#----------------------------------------------------------------------
# BCOO functions that maybe should be primitives?


def bcoo_add_batch_dim(M):
"""Convert a sparse dimension to a batch dimension
Please note that this function may result in a far less efficient storage scheme
for the matrix (storage required will increase by a factor of `M.shape[0] * M.nse`).
This utility is provided for convenience, e.g. to allow vmapping over non-batched
matrices.
Args:
M: BCOO matrix
Returns:
M2: BCOO matrix with n_batch = M.n_batch + 1 and n_sparse = M.n_sparse - 1
"""
# TODO(jakevdp): allow user-specified nse?
if M.n_sparse == 0:
raise ValueError("Cannot add a batch dimension to a matrix with n_sparse=0")
f = _add_batch_dim
for _ in range(M.n_batch):
f = vmap(f)
return f(M)

def _add_batch_dim(M):
assert M.n_batch == 0
assert M.n_sparse > 0
data = jnp.zeros_like(M.data, shape=(M.shape[0], *M.data.shape))
data = data.at[M.indices[:, 0], jnp.arange(M.nse)].set(M.data)
indices_shape = (M.shape[0], M.nse, M.n_sparse - 1)
if M.n_sparse > 1:
fill_value = jnp.array(M.shape[M.n_batch + 1: M.n_batch + M.n_sparse])
indices = jnp.full_like(M.indices, shape=indices_shape, fill_value=fill_value)
indices = indices.at[M.indices[:, 0], jnp.arange(M.nse)].set(M.indices[:, 1:])
else:
indices = jnp.empty_like(M.indices, shape=indices_shape)
return BCOO((data, indices), shape=M.shape)


def bcoo_broadcast_in_dim(mat, *, shape, broadcast_dimensions):
"""Expand the size and rank of a BCOO array by duplicating the data.
Expand Down
18 changes: 18 additions & 0 deletions tests/sparse_test.py
Expand Up @@ -1772,6 +1772,24 @@ def test_bcoo_unbatch(self, shape, dtype, n_batch, n_dense):
self.assertEqual(M1.dtype, M2.dtype)
self.assertArraysEqual(M1.todense(), M2.todense())

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_nbatch={}_ndense={}".format(
jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
"shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex
for n_batch in range(len(shape))
for n_dense in range(len(shape) - n_batch)))
def test_bcoo_add_batch_dim(self, shape, dtype, n_batch, n_dense):
rng_sparse = rand_sparse(self.rng())
M1 = sparse.BCOO.fromdense(rng_sparse(shape, dtype), n_batch=n_batch, n_dense=n_dense)
M2 = sparse.bcoo_add_batch_dim(M1)
self.assertEqual(M2.n_batch, M1.n_batch + 1)
self.assertEqual(M1.n_dense, M2.n_dense)
self.assertEqual(M1.shape, M2.shape)
self.assertEqual(M1.dtype, M2.dtype)
self.assertArraysEqual(M1.todense(), M2.todense())

def test_bcoo_bad_fillvals(self):
# Extra values have 100 rather than zero. This lets us check that logic is
# properly ignoring these indices.
Expand Down

0 comments on commit 93a24f3

Please sign in to comment.