Skip to content

Commit

Permalink
[sparse] add bcoo_sort_indices
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jan 18, 2022
1 parent 6411f8a commit 16d6c4d
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
22 changes: 22 additions & 0 deletions jax/experimental/sparse/bcoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,23 @@ def _bcoo_sum_duplicates_unbatched(data, indices, *, shape, nse, remove_zeros):
data_unique = jnp.where(oob_mask[(...,) + props.n_dense * (None,)], 0, data_unique)
return data_unique, indices_unique, nse

def _bcoo_sort_indices(data, indices, shape):
props = _validate_bcoo(data, indices, shape)
if props.n_sparse == 0:
return data, indices
def f(data, indices):
_, N = indices.shape
idx_cols = (indices[:, i] for i in range(N))
if data.ndim > 1:
*indices, i = lax.sort((*idx_cols, lax.iota(indices.dtype, len(data))), num_keys=N)
data = data[i]
else:
*indices, data = lax.sort((*idx_cols, data), num_keys=N)
return data, jnp.column_stack(indices)
for _ in range(props.n_batch):
f = broadcasting_vmap(f)
return f(data, indices)

def _unbatch_bcoo(data, indices, shape):
n_batch = _validate_bcoo(data, indices, shape).n_batch
if n_batch == 0:
Expand Down Expand Up @@ -1241,6 +1258,11 @@ def sum_duplicates(self, nse=None, remove_zeros=True):
nse=nse, remove_zeros=remove_zeros)
return BCOO((data, indices), shape=self.shape)

def sort_indices(self):
"""Return a copy of the matrix with indices sorted."""
data, indices = _bcoo_sort_indices(self.data, self.indices, self.shape)
return BCOO((data, indices), shape=self.shape)

def todense(self):
"""Create a dense version of the array."""
return bcoo_todense(self.data, self.indices, spinfo=self._info)
Expand Down
22 changes: 22 additions & 0 deletions tests/sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1290,6 +1290,28 @@ def test_bcoo_sum_duplicates(self, shape, dtype, n_batch, n_dense, nse, remove_z
self.assertAllClose(M.todense(), M_dedup.todense())
self.assertEqual(M_dedup.nse, nse)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_nbatch={}_ndense={}".format(
jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
"shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex
for n_batch in range(len(shape) + 1)
for n_dense in range(len(shape) + 1 - n_batch)))
def test_bcoo_sort_indices(self, shape, dtype, n_batch, n_dense):
rng_sparse = rand_sparse(self.rng(), rand_method=jtu.rand_some_zero)
M = sparse.BCOO.fromdense(rng_sparse(shape, dtype), n_batch=n_batch, n_dense=n_dense)
M.indices = M.indices[..., ::-1, :]

M_sorted = M.sort_indices()
self.assertArraysEqual(M.todense(), M_sorted.todense())

indices = M_sorted.indices
if indices.size > 0:
flatind = indices.reshape(-1, *indices.shape[-2:]).transpose(0, 2, 1)
sorted = jax.vmap(jnp.lexsort)(flatind[:, ::-1])
self.assertTrue(jnp.all(sorted == jnp.arange(sorted.shape[-1])))

def test_bcoo_sum_duplicates_inferred_nse(self):
x = sparse.BCOO.fromdense(jnp.diag(jnp.arange(4)))
self.assertEqual(x.nse, 3)
Expand Down

0 comments on commit 16d6c4d

Please sign in to comment.