Skip to content

Commit

Permalink
[sparse] add bcoo_rdot_general
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jun 16, 2021
1 parent 8f5a784 commit 6003f79
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 6 deletions.
15 changes: 12 additions & 3 deletions jax/experimental/sparse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,15 @@ def bcoo_dot_general(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_shape
return bcoo_dot_general_p.bind(jnp.asarray(lhs_data), jnp.asarray(lhs_indices), jnp.asarray(rhs),
dimension_numbers=dimension_numbers, lhs_shape=tuple(lhs_shape))

def bcoo_rdot_general(lhs, rhs_data, rhs_indices, *, dimension_numbers, rhs_shape):
# TODO(jakevdp): perhaps this should be part of the bcoo_dot_general primitive?
result = bcoo_dot_general(rhs_data, rhs_indices, lhs, lhs_shape=rhs_shape,
dimension_numbers=[d[::-1] for d in dimension_numbers])
n_contract, n_batch = (len(d[0]) for d in dimension_numbers)
n_swap = len(rhs_shape) - n_contract
permutation = tuple([*range(n_batch), *range(n_swap, result.ndim), *range(n_batch, n_swap)])
return lax.transpose(result, permutation)

@bcoo_dot_general_p.def_impl
def _bcoo_dot_general_impl(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_shape):
lhs_data = jnp.asarray(lhs_data)
Expand Down Expand Up @@ -1246,9 +1255,9 @@ def __rmatmul__(self, other):
if self.ndim > 2 or other.ndim > 2:
raise NotImplementedError("sparse matmul for dimensions larger than 2")
dtype = jnp.promote_types(self.dtype, other.dtype)
return bcoo_dot_general(self.data.astype(dtype), self.indices, other.astype(dtype),
lhs_shape=self.shape,
dimension_numbers=(([0], [other.ndim - 1]), ([], []))).T
return bcoo_rdot_general(other.astype(dtype), self.data.astype(dtype), self.indices,
rhs_shape=self.shape,
dimension_numbers=(([other.ndim - 1], [0]), ([], [])))

def transpose(self):
if self.n_batch or self.n_dense:
Expand Down
48 changes: 45 additions & 3 deletions tests/sparse_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ def test_bcoo_todense_partial_batch(self, shape, dtype, n_batch, n_dense):
for n_dense in range(len(lhs_shape) - max(lhs_contracting, default=0))
for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
def test_bcoo_dot_general_contract_only(self, lhs_shape, rhs_shape, dtype,
lhs_contracting, rhs_contracting, n_dense):
lhs_contracting, rhs_contracting, n_dense):
rng = jtu.rand_small(self.rng())
rng_sparse = rand_sparse(self.rng())
def args_maker():
Expand All @@ -565,7 +565,7 @@ def f_sparse(data, indices, lhs, rhs):

self._CheckAgainstNumpy(f_dense, f_sparse, args_maker)
self._CheckAgainstNumpy(f_dense, jit(f_sparse), args_maker)
# In rare cases, this fails python_should_be_executing check. Why?
# TODO(jakevdp): In rare cases, this fails python_should_be_executing check. Why?
# self._CompileAndCheck(f_sparse, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
Expand Down Expand Up @@ -606,7 +606,49 @@ def f_sparse(data, indices, lhs, rhs):

self._CheckAgainstNumpy(f_dense, f_sparse, args_maker)
self._CheckAgainstNumpy(f_dense, jit(f_sparse), args_maker)
# In rare cases, this fails python_should_be_executing check. Why?
# TODO(jakevdp): In rare cases, this fails python_should_be_executing check. Why?
# self._CompileAndCheck(f_sparse, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_lhs_shape={}_rhs_shape={}_dimension_numbers={}_n_batch={}_n_dense={}"
.format(jtu.format_shape_dtype_string(lhs_shape, dtype),
jtu.format_shape_dtype_string(rhs_shape, dtype),
dimension_numbers, n_batch, n_dense),
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
"dimension_numbers": dimension_numbers,
"n_batch": n_batch, "n_dense": n_dense}
for lhs_shape, rhs_shape, dimension_numbers, n_batch, n_dense in [
((3, 2, 4), (3, 3, 2), (([1], [2]), ([0], [0])), 1, 0),
((3, 2, 4), (3, 3, 2), (([1], [2]), ([0], [0])), 2, 0),
((2, 3, 4), (3, 3, 2), (([0], [2]), ([1], [0])), 1, 0),
((2, 3, 4), (3, 3, 2), (([0], [2]), ([1], [0])), 2, 0),
((3, 4, 3, 2), (3, 4, 2, 4), (([3], [2]), ([0], [0])), 1, 0),
((3, 4, 3, 2), (3, 4, 2, 4), (([3], [2]), ([0, 1], [0, 1])), 2, 0),
((3, 4, 3, 2), (3, 4, 2, 4), (([3], [2]), ([0, 1], [0, 1])), 2, 1),
]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
def test_bcoo_rdot_general_contract_and_batch(self, lhs_shape, rhs_shape, dtype,
dimension_numbers, n_batch, n_dense):
rng = jtu.rand_small(self.rng())
rng_sparse = rand_sparse(self.rng())
def args_maker():
lhs = rng(lhs_shape, dtype)
rhs = rng_sparse(rhs_shape, dtype)
data, indices = sparse_ops.bcoo_fromdense(rhs, n_batch=n_batch, n_dense=n_dense)
return data, indices, lhs, rhs

def f_dense(data, indices, lhs, rhs):
return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)

def f_sparse(data, indices, lhs, rhs):
return sparse_ops.bcoo_rdot_general(lhs, data, indices,
rhs_shape=rhs.shape,
dimension_numbers=dimension_numbers)

self._CheckAgainstNumpy(f_dense, f_sparse, args_maker)
self._CheckAgainstNumpy(f_dense, jit(f_sparse), args_maker)
# TODO(jakevdp): In rare cases, this fails python_should_be_executing check. Why?
# self._CompileAndCheck(f_sparse, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
Expand Down

0 comments on commit 6003f79

Please sign in to comment.