Skip to content

Commit

Permalink
[sparse] add BCOO._dedupe() method
Browse files (browse the repository at this point in the history)
  • Loading branch information
jakevdp committed Sep 3, 2021
1 parent ca3135c commit 4bb7018
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
16 changes: 10 additions & 6 deletions jax/experimental/sparse/ops.py
Expand Up @@ -565,13 +565,13 @@ def _bcoo_nse(mat, n_batch=0, n_dense=0):
mask = mask.sum(list(range(n_batch, mask.ndim)))
return mask.max()

def _dedupe_bcoo(data, indices):
def _dedupe_bcoo(data, indices, shape):
n_batch, _, _ = _validate_bcoo(data, indices, shape)
if indices.shape[:n_batch] != data.shape[:n_batch]:
# TODO: handle broadcasted dimensions.
raise NotImplementedError("dedupe_bcoo for broadcasted dimensions.")
f = _dedupe_bcoo_one
n_batch = indices.ndim - 2
for s1, s2 in safe_zip(indices.shape[:n_batch], data.shape[:n_batch]):
if s1 != s2:
# TODO: handle broadcasted dimensions.
raise NotImplementedError("dedupe_bcoo for broadcasted dimensions.")
for _ in range(n_batch):
f = vmap(f)
return f(data, indices)

Expand Down Expand Up @@ -1496,6 +1496,10 @@ def _unbatch(self):
"""Return an unbatched representation of the BCOO matrix."""
return BCOO(_unbatch_bcoo(self.data, self.indices, self.shape), shape=self.shape)

def _dedupe(self):
  """Return a de-duplicated representation of the BCOO matrix.

  The matrix's ``data``, ``indices`` and ``shape`` are handed to
  ``_dedupe_bcoo``; its result is wrapped in a new ``BCOO`` that keeps
  the same ``shape`` as this matrix.
  """
  deduped_buffers = _dedupe_bcoo(self.data, self.indices, self.shape)
  return BCOO(deduped_buffers, shape=self.shape)

@api.jit
def todense(self):
"""Create a dense version of the array."""
Expand Down
14 changes: 5 additions & 9 deletions tests/sparse_test.py
Expand Up @@ -25,7 +25,7 @@
from jax import config
from jax import dtypes
from jax.experimental import sparse
from jax.experimental.sparse.ops import _bcoo_nse, _dedupe_bcoo
from jax.experimental.sparse.ops import _bcoo_nse
from jax import lax
from jax.lib import cusparse
from jax.lib import xla_bridge
Expand Down Expand Up @@ -963,15 +963,11 @@ def sparse_fun(lhs, rhs, indices):
def test_bcoo_dedupe(self, shape, dtype, n_batch, n_dense):
rng = self.rng()
rng_sparse = rand_sparse(self.rng())
M = rng_sparse(shape, dtype)
data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)
M = sparse.BCOO.fromdense(rng_sparse(shape, dtype))
for i, s in enumerate(shape[n_batch:len(shape) - n_dense]):
indices = indices.at[..., i, :].set(rng.randint(0, s, size=indices.shape[-1]))
data2, indices2 = _dedupe_bcoo(data, indices)
M1 = sparse.bcoo_todense(data, indices, shape=shape)
M2 = sparse.bcoo_todense(data2, indices2, shape=shape)

self.assertAllClose(M1, M2)
M.indices = M.indices.at[..., i, :].set(rng.randint(0, s, size=M.indices.shape[-1]))
M_dedup = M._dedupe()
self.assertAllClose(M.todense(), M_dedup.todense())

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_nbatch={}_ndense={}_axes={}".format(
Expand Down

0 comments on commit 4bb7018

Please sign in to comment.