Skip to content

Commit

Permalink
[sparse] Implement several BCOO methods via sparsify
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jul 2, 2021
1 parent c97d63d commit bb6e463
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
26 changes: 26 additions & 0 deletions jax/experimental/sparse/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1469,3 +1469,29 @@ def tree_unflatten(cls, aux_data, children):
+ tuple(sparse_shape)
+ tuple(data.shape[n_batch + 1:]))
return cls(children, shape=shape)

# TODO(jakevdp): refactor to avoid circular imports - we can use the same strategy
# we use when adding methods to DeviceArray within lax_numpy.py
def __neg__(self):
from jax.experimental.sparse import sparsify
return sparsify(jnp.negative)(self)

def __mul__(self, other):
from jax.experimental.sparse import sparsify
return sparsify(jnp.multiply)(self, other)

def __rmul__(self, other):
from jax.experimental.sparse import sparsify
return sparsify(jnp.multiply)(other, self)

def __add__(self, other):
from jax.experimental.sparse import sparsify
return sparsify(jnp.add)(self, other)

def __radd__(self, other):
from jax.experimental.sparse import sparsify
return sparsify(jnp.add)(other, self)

def sum(self, *args, **kwargs):
from jax.experimental.sparse import sparsify
return sparsify(lambda x: x.sum(*args, **kwargs))(self)
14 changes: 14 additions & 0 deletions tests/sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1117,6 +1117,20 @@ def test_matmul(self, shape, dtype, Obj, bshape):

self.assertAllClose(M @ x, Msp @ x, rtol=MATMUL_TOL)

def test_bcoo_methods(self):
M = jnp.arange(12).reshape(3, 4)
Msp = sparse.BCOO.fromdense(M)

self.assertArraysEqual(-M, (-Msp).todense())

self.assertArraysEqual(2 * M, (2 * Msp).todense())
self.assertArraysEqual(M * 2, (Msp * 2).todense())

self.assertArraysEqual(M + M, (Msp + Msp).todense())

self.assertArraysEqual(M.sum(0), Msp.sum(0).todense())
self.assertArraysEqual(M.sum(1), Msp.sum(1).todense())
self.assertArraysEqual(M.sum(), Msp.sum())

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit bb6e463

Please sign in to comment.