diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index e22c5ffd974a..2e726fbd2d3f 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -1086,11 +1086,43 @@ def bcoo_dot_general_sampled(A: Array, B: Array, indices: Array, *, dimension_nu return bcoo_dot_general_sampled_p.bind(A, B, indices, dimension_numbers=(cdims, bdims)) +def _bcoo_dot_general_sampled_slow(A, B, indices, *, dimension_numbers): + return _bcoo_extract(indices, lax.dot_general(A, B, dimension_numbers=dimension_numbers)) + +def _bcoo_dot_general_sampled_simple(A, B, indices, *, dimension_numbers): + (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers + assert not (lhs_contract or rhs_contract or lhs_batch or rhs_batch) + assert A.ndim == B.ndim == 1 + n_batch = indices.ndim - 2 + n_sparse = indices.shape[-1] + nse = indices.shape[-2] + assert n_batch + n_sparse == 2 + if n_batch == 0: + return A[indices[:, 0]] * B[indices[:, 1]] + elif n_batch == 1: + return A[:, None] * B[indices[..., 0]] + elif n_batch == 2: + out = A[:, None, None] * B[None, :, None] + return lax.broadcast_in_dim(out, (len(A), len(B), nse), (0, 1, 2)) + else: + raise ValueError("too many batch dimensions.") + @bcoo_dot_general_sampled_p.def_impl def _bcoo_dot_general_sampled_impl(A, B, indices, *, dimension_numbers): - # TODO(jakevdp): use a more efficient implementation that avoids the full dot product. - dense_result = lax.dot_general(A, B, dimension_numbers=dimension_numbers) - return _bcoo_extract(indices, dense_result) + A = jnp.asarray(A) + B = jnp.asarray(B) + indices = jnp.asarray(indices) + (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers + n_batch = indices.ndim - 2 + n_sparse = indices.shape[-1] + + # TODO(jakevdp): add fast approach for more general cases. + if (not (lhs_contract or rhs_contract or lhs_batch or rhs_batch) + and A.ndim == B.ndim == 1 and n_sparse + n_batch == 2): + return _bcoo_dot_general_sampled_simple(A, B, indices, dimension_numbers=dimension_numbers) + + return _bcoo_dot_general_sampled_slow(A, B, indices, dimension_numbers=dimension_numbers) + @bcoo_dot_general_sampled_p.def_abstract_eval def _bcoo_dot_general_sampled_abstract_eval(A, B, indices, *, dimension_numbers): diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 2a2a3fcd54b8..fac17497f3be 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -1294,6 +1294,30 @@ def sparse_fun(lhs, rhs, indices): # TODO(jakevdp) fix forward-mode autodiff & enable tests here. self._CheckGradsSparse(dense_fun, sparse_fun, args_maker, modes=['rev'], argnums=[0, 1]) + @jtu.sample_product( + xshape=[(3,), (5,)], + yshape=[(3,), (5,)], + dtype=jtu.dtypes.floating + jtu.dtypes.complex, + n_batch=[0, 1, 2], + ) + def test_bcoo_dot_general_sampled_fast(self, xshape, yshape, n_batch, dtype): + rng = jtu.rand_default(self.rng()) + sprng = sptu.rand_bcoo(self.rng(), n_batch=n_batch) + + dimension_numbers = (([], []), ([], [])) + args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype), + sprng(xshape + yshape, dtype).indices] + + def f1(x, y, indices): + mat_full = lax.dot_general(x, y, dimension_numbers=dimension_numbers) + return sparse_bcoo._bcoo_extract(indices, mat_full) + + def f2(x, y, indices): + return sparse.bcoo_dot_general_sampled(x, y, indices, dimension_numbers=dimension_numbers) + + self._CheckAgainstNumpy(f1, f2, args_maker) + self._CompileAndCheck(f2, args_maker) + @jtu.sample_product( [dict(n_batch=n_batch, n_dense=n_dense, lhs_shape=lhs_shape, rhs_shape=rhs_shape, dimension_numbers=dimension_numbers)