Skip to content

Commit

Permalink
[sparse] bcoo_dynamic_slice: remove unnecessary padding from output
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Nov 9, 2022
1 parent fa0217b commit 46d9cac
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
9 changes: 7 additions & 2 deletions jax/experimental/sparse/bcoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1964,7 +1964,7 @@ def bcoo_slice(mat, *, start_indices: Sequence[int], limit_indices: Sequence[int

return BCOO((new_data, new_indices), shape=new_shape)

def bcoo_dynamic_slice(mat, start_indices: Sequence[Any], slice_sizes: Sequence[int]):
def bcoo_dynamic_slice(mat: BCOO, start_indices: Sequence[Any], slice_sizes: Sequence[int]) -> BCOO:
"""Sparse implementation of {func}`jax.lax.dynamic_slice`.
Args:
Expand Down Expand Up @@ -2028,11 +2028,16 @@ def bcoo_dynamic_slice(mat, start_indices: Sequence[Any], slice_sizes: Sequence[
sparse_shape = jnp.expand_dims(sparse_shape, range(mat.n_batch + 1))

keep = jnp.all((new_indices >= starts) & (new_indices < starts + sizes), -1, keepdims=True)
new_indices = jnp.where(keep, new_indices - starts, sparse_shape)
new_indices = jnp.where(keep, new_indices - starts, sizes)

keep_data = lax.expand_dims(keep[..., 0], range(mat.n_batch + 1, mat.n_batch + 1 + mat.n_dense))
new_data = jnp.where(keep_data, new_data, 0)

if mat.nse > np.prod(size_sparse):
new_nse = np.prod(size_sparse)
new_data, new_indices = _bcoo_sum_duplicates(
new_data, new_indices, spinfo=BCOOInfo(shape=new_shape), nse=new_nse)

return BCOO((new_data, new_indices), shape=new_shape)


Expand Down
11 changes: 11 additions & 0 deletions tests/sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,17 @@ def test_bcoo_dynamic_slice(self, shape, dtype, n_batch, n_dense):
sparse_result = sparse.bcoo_dynamic_slice(Msp, **kwds)
sparse_result_jit = partial(sparse.bcoo_dynamic_slice, slice_sizes=slice_sizes)(Msp, start_indices)

# Array layout is the same
self.assertEqual(sparse_result.n_batch, Msp.n_batch)
self.assertEqual(sparse_result.n_sparse, Msp.n_sparse)
self.assertEqual(sparse_result.n_dense, Msp.n_dense)

# Unnecessary padding eliminated
max_nse = np.prod(sparse_result.shape[Msp.n_batch: Msp.n_batch + Msp.n_sparse])
self.assertLessEqual(sparse_result.nse, max_nse)
self.assertLessEqual(sparse_result_jit.nse, max_nse)

# Result matches dense computation
self.assertArraysEqual(dense_result, sparse_result.todense())
self.assertArraysEqual(dense_result, sparse_result_jit.todense())

Expand Down

0 comments on commit 46d9cac

Please sign in to comment.