Skip to content

Commit

Permalink
[sparse] bcoo_mul: support mixing batch & sparse dims
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Mar 18, 2022
1 parent e392af3 commit b2e0dea
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 16 deletions.
21 changes: 15 additions & 6 deletions jax/experimental/sparse/bcoo.py
Expand Up @@ -1209,21 +1209,30 @@ def bcoo_multiply_sparse(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_sp
# Similar requirement as lax.mul:
raise TypeError("bcoo_multiply_sparse: arrays must have same number of dimensions, "
f"got {lhs_shape}, {rhs_shape}")
if (lhs.n_batch, lhs.n_sparse, lhs.n_dense) != (rhs.n_batch, rhs.n_sparse, rhs.n_dense):
if lhs.n_dense != rhs.n_dense:
raise NotImplementedError("bcoo_multiply_sparse: arrays with differing numbers of "
f"batch & dense dimensions: {lhs}, {rhs}")
f"dense dimensions: {lhs}, {rhs}")
n_batch = min(lhs.n_batch, rhs.n_batch)
_mul = functools.partial(_bcoo_multiply_sparse_unbatched,
lhs_shape=lhs_shape[lhs.n_batch:],
rhs_shape=rhs_shape[rhs.n_batch:])
for _ in range(lhs.n_batch):
lhs_shape=lhs_shape[n_batch:],
rhs_shape=rhs_shape[n_batch:])
for _ in range(n_batch):
_mul = broadcasting_vmap(_mul)
data, indices = _mul(lhs_data, lhs_indices, rhs_data, rhs_indices)
return data, indices, jnp.broadcast_shapes(lhs_shape, rhs_shape)

def _bcoo_multiply_sparse_unbatched(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_shape, rhs_shape):
lhs = _validate_bcoo(lhs_data, lhs_indices, lhs_shape)
rhs = _validate_bcoo(rhs_data, rhs_indices, rhs_shape)
assert lhs.n_batch == rhs.n_batch == 0
assert (lhs.n_batch == 0) or (rhs.n_batch == 0) # Ensured at call site above

# TODO(jakevdp): this can be made more efficient by utilizing batch structure.
if lhs.n_batch:
lhs_data, lhs_indices = _unbatch_bcoo(lhs_data, lhs_indices, lhs_shape)
lhs = _validate_bcoo(lhs_data, lhs_indices, lhs_shape)
elif rhs.n_batch:
rhs_data, rhs_indices = _unbatch_bcoo(rhs_data, rhs_indices, rhs_shape)
rhs = _validate_bcoo(rhs_data, rhs_indices, rhs_shape)
dims = jnp.array([i for i, (s1, s2) in enumerate(safe_zip(lhs_shape[:lhs.n_sparse], rhs_shape[:rhs.n_sparse]))
if s1 != 1 and s2 != 1], dtype=int)

Expand Down
21 changes: 11 additions & 10 deletions tests/sparse_test.py
Expand Up @@ -1659,33 +1659,34 @@ def test_bcoo_mul_dense(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, n_batc
self.assertAllClose(out1, out2, rtol=tol)
self.assertAllClose(out1, out3, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_{}_n_batch={}_n_dense={}".format(
jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
jtu.format_shape_dtype_string(rhs_shape, rhs_dtype),
n_batch, n_dense),
{"testcase_name": "_{}_n_batch={}_{}_n_batch={}_n_dense={}".format(
jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), lhs_n_batch,
jtu.format_shape_dtype_string(rhs_shape, rhs_dtype), rhs_n_batch, n_dense),
"lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
"rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
"n_batch": n_batch, "n_dense": n_dense,
"lhs_n_batch": lhs_n_batch, "rhs_n_batch": rhs_n_batch, "n_dense": n_dense,
}
# TODO(jakevdp): add broadcasted shapes (from bcoo_mul_dense) once sparse-sparse mul
# supports inputs of differing rank.
for lhs_shape, rhs_shape in [[(3,), (1,)], [(3,), (3,)],
[(3, 4), (1, 1)], [(3, 4), (1, 4)], [(3, 4), (3, 1)], [(3, 4), (3, 4)],
[(3, 4, 5), (1, 4, 5)], [(3, 4, 5), (3, 1, 1)], [(3, 4, 5), (1, 4, 1)]]
# TODO(jakevdp): add tests for batch & dense dimensions.
for n_batch in range(len(lhs_shape) + 1)
for n_dense in range(len(lhs_shape) + 1 - n_batch)
for lhs_n_batch in range(len(lhs_shape) + 1)
for rhs_n_batch in range(len(lhs_shape) + 1)
for n_dense in range(len(lhs_shape) + 1 - max(lhs_n_batch, rhs_n_batch))
for lhs_dtype in all_dtypes
for rhs_dtype in all_dtypes))
def test_bcoo_mul_sparse(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, n_batch, n_dense):
def test_bcoo_mul_sparse(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, lhs_n_batch, rhs_n_batch, n_dense):
rng = rand_sparse(self.rng())
lhs = jnp.array(rng(lhs_shape, lhs_dtype))
rhs = jnp.array(rng(rhs_shape, rhs_dtype))

sp = lambda x: sparse.BCOO.fromdense(x, n_batch=n_batch, n_dense=n_dense)
lhs_sp = sparse.BCOO.fromdense(lhs, n_batch=lhs_n_batch, n_dense=n_dense)
rhs_sp = sparse.BCOO.fromdense(rhs, n_batch=rhs_n_batch, n_dense=n_dense)

out1 = lhs * rhs
out2 = (sp(lhs) * sp(rhs)).todense()
out2 = (lhs_sp * rhs_sp).todense()

tol = {np.float64: 1E-13, np.complex128: 1E-13,
np.float32: 1E-6, np.complex64: 1E-6}
Expand Down

0 comments on commit b2e0dea

Please sign in to comment.