[sparse] bcoo_dot_general_sampled: faster special case
jakevdp committed Feb 22, 2023
1 parent 7e001d8 commit 54bd631
Showing 2 changed files with 59 additions and 3 deletions.
38 changes: 35 additions & 3 deletions jax/experimental/sparse/bcoo.py
@@ -1086,11 +1086,43 @@ def bcoo_dot_general_sampled(A: Array, B: Array, indices: Array, *, dimension_nu
  return bcoo_dot_general_sampled_p.bind(A, B, indices,
                                         dimension_numbers=(cdims, bdims))

+def _bcoo_dot_general_sampled_slow(A, B, indices, *, dimension_numbers):
+  return _bcoo_extract(indices, lax.dot_general(A, B, dimension_numbers=dimension_numbers))
+
+def _bcoo_dot_general_sampled_simple(A, B, indices, *, dimension_numbers):
+  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
+  assert not (lhs_contract or rhs_contract or lhs_batch or rhs_batch)
+  assert A.ndim == B.ndim == 1
+  n_batch = indices.ndim - 2
+  n_sparse = indices.shape[-1]
+  nse = indices.shape[-2]
+  assert n_batch + n_sparse == 2
+  if n_batch == 0:
+    return A[indices[:, 0]] * B[indices[:, 1]]
+  elif n_batch == 1:
+    return A[:, None] * B[indices[..., 0]]
+  elif n_batch == 2:
+    out = A[:, None, None] * B[None, :, None]
+    return lax.broadcast_in_dim(out, (len(A), len(B), nse), (0, 1, 2))
+  else:
+    raise ValueError("too many batch dimensions.")
+
@bcoo_dot_general_sampled_p.def_impl
def _bcoo_dot_general_sampled_impl(A, B, indices, *, dimension_numbers):
-  # TODO(jakevdp): use a more efficient implementation that avoids the full dot product.
-  dense_result = lax.dot_general(A, B, dimension_numbers=dimension_numbers)
-  return _bcoo_extract(indices, dense_result)
+  A = jnp.asarray(A)
+  B = jnp.asarray(B)
+  indices = jnp.asarray(indices)
+  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
+  n_batch = indices.ndim - 2
+  n_sparse = indices.shape[-1]
+
+  # TODO(jakevdp): add fast approach for more general cases.
+  if (not (lhs_contract or rhs_contract or lhs_batch or rhs_batch)
+      and A.ndim == B.ndim == 1 and n_sparse + n_batch == 2):
+    return _bcoo_dot_general_sampled_simple(A, B, indices, dimension_numbers=dimension_numbers)
+
+  return _bcoo_dot_general_sampled_slow(A, B, indices, dimension_numbers=dimension_numbers)


@bcoo_dot_general_sampled_p.def_abstract_eval
def _bcoo_dot_general_sampled_abstract_eval(A, B, indices, *, dimension_numbers):
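Note (not part of the commit): the fast path dispatches on the layout of the BCOO indices array, whose trailing two dimensions are (nse, n_sparse) and whose leading dimensions are batch dimensions. A minimal sketch of how those quantities are read off, assuming the public jax.experimental.sparse BCOO API; the matrix and variable names here are illustrative only:

    # Sketch: inferring n_batch, n_sparse, and nse from BCOO indices, mirroring
    # the dispatch in _bcoo_dot_general_sampled_impl above.
    import jax.numpy as jnp
    from jax.experimental import sparse

    mat = jnp.arange(15.0).reshape(3, 5)
    for n_batch in (0, 1, 2):
      pattern = sparse.BCOO.fromdense(mat, n_batch=n_batch)
      indices = pattern.indices            # shape: batch_shape + (nse, n_sparse)
      inferred_n_batch = indices.ndim - 2
      n_sparse = indices.shape[-1]
      nse = indices.shape[-2]
      # For a 2D matrix, n_batch + n_sparse == 2, so the new fast path applies
      # whenever A and B are 1D and dimension_numbers is empty.
      print(n_batch, indices.shape, inferred_n_batch, n_sparse, nse)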
24 changes: 24 additions & 0 deletions tests/sparse_test.py
@@ -1294,6 +1294,30 @@ def sparse_fun(lhs, rhs, indices):
    # TODO(jakevdp) fix forward-mode autodiff & enable tests here.
    self._CheckGradsSparse(dense_fun, sparse_fun, args_maker, modes=['rev'], argnums=[0, 1])

+  @jtu.sample_product(
+      xshape=[(3,), (5,)],
+      yshape=[(3,), (5,)],
+      dtype=jtu.dtypes.floating + jtu.dtypes.complex,
+      n_batch=[0, 1, 2],
+  )
+  def test_bcoo_dot_general_sampled_fast(self, xshape, yshape, n_batch, dtype):
+    rng = jtu.rand_default(self.rng())
+    sprng = sptu.rand_bcoo(self.rng(), n_batch=n_batch)
+
+    dimension_numbers = (([], []), ([], []))
+    args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype),
+                          sprng(xshape + yshape, dtype).indices]
+
+    def f1(x, y, indices):
+      mat_full = lax.dot_general(x, y, dimension_numbers=dimension_numbers)
+      return sparse_bcoo._bcoo_extract(indices, mat_full)
+
+    def f2(x, y, indices):
+      return sparse.bcoo_dot_general_sampled(x, y, indices, dimension_numbers=dimension_numbers)
+
+    self._CheckAgainstNumpy(f1, f2, args_maker)
+    self._CompileAndCheck(f2, args_maker)
+
  @jtu.sample_product(
      [dict(n_batch=n_batch, n_dense=n_dense, lhs_shape=lhs_shape,
            rhs_shape=rhs_shape, dimension_numbers=dimension_numbers)
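Note (a usage sketch, not part of the commit): with empty dimension_numbers the operation is a sampled outer product, which is exactly the case the new test above covers. Under that assumption, the primitive should agree with a dense dot_general followed by reading off the stored indices:

    # Sketch: sampled outer product via bcoo_dot_general_sampled (n_batch=0 case).
    import jax.numpy as jnp
    from jax import lax
    from jax.experimental import sparse

    x = jnp.arange(1.0, 4.0)   # shape (3,)
    y = jnp.arange(1.0, 6.0)   # shape (5,)

    # Sparsity pattern of the 3x5 output; only these entries are materialized.
    pattern = sparse.BCOO.fromdense(jnp.eye(3, 5))
    dimension_numbers = (([], []), ([], []))  # no contraction, no batching

    fast = sparse.bcoo_dot_general_sampled(x, y, pattern.indices,
                                           dimension_numbers=dimension_numbers)

    # Reference: dense outer product, then extract the same entries.
    dense = lax.dot_general(x, y, dimension_numbers=dimension_numbers)
    reference = dense[pattern.indices[:, 0], pattern.indices[:, 1]]
    assert jnp.allclose(fast, reference)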
