Generalize BCOO.__matmul__
jakevdp committed Jun 14, 2021
1 parent 1e4d28a commit 510a777
Showing 2 changed files with 61 additions and 10 deletions.
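The diff below replaces the jitted `matvec`/`matmat` helpers with `__matmul__` and `__rmatmul__` methods, so a `BCOO` matrix composes with Python's `@` operator against 1-D or 2-D dense operands on either side, with dtypes promoted across the two arguments. A minimal usage sketch mirroring the new test; the example values are illustrative, not from the commit:

```python
import jax.numpy as jnp
from jax.experimental import sparse_ops

M = sparse_ops.BCOO.fromdense(jnp.array([[1., 0., 2.],
                                         [0., 3., 0.]]))  # sparse, shape (2, 3)
v = jnp.array([1., 2., 3.])                               # dense 1-D operand
B = jnp.ones((4, 2))                                      # dense 2-D operand

Mv = M @ v   # sparse @ vector via __matmul__, dense result of shape (2,)
BM = B @ M   # dense @ sparse via __rmatmul__, dense result of shape (4, 3)
```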
41 changes: 31 additions & 10 deletions jax/experimental/sparse_ops.py
@@ -875,7 +875,7 @@ def result(out_array, lhs_data, lhs_indices, rhs):
    idx_right, idx_out = idx[:n_contracting], idx[n_contracting:]
    ctc = [0] if n_contracting else []
    prod = lax.dot_general(lhs_data, rhs[idx_right], (([], []), (ctc, ctc)))
    return out_array.at[idx_out].add(prod) if idx_out else prod.sum(0)
    return out_array.at[idx_out].add(prod) if idx_out else prod.sum(0, dtype=out_array.dtype)
  for i in range(n_batch)[::-1]:
    axes_in = [0, 0, 0, 0]
    if lhs_data.shape[i] == 1:
@@ -1040,6 +1040,10 @@ class JAXSparse:
  nnz: property
  dtype: property

  @property
  def ndim(self):
    return len(self.shape)

  def __init__(self, args, *, shape):
    self.shape = shape

@@ -1202,6 +1206,7 @@ class BCOO(JAXSparse):
  n_batch = property(lambda self: self.indices.ndim - 2)
  n_sparse = property(lambda self: self.indices.shape[-2])
  n_dense = property(lambda self: self.data.ndim - 1 - self.n_batch)
  shape = Tuple[int, ...]

  def __init__(self, args, *, shape):
    self.data, self.indices = map(jnp.asarray, args)
@@ -1215,15 +1220,31 @@ def fromdense(cls, mat, *, nnz=None, index_dtype=np.int32, n_dense=0, n_batch=0)
  def todense(self):
    return bcoo_todense(self.data, self.indices, shape=self.shape)

  @api.jit
  def matvec(self, v):
    return bcoo_dot_general(self.data, self.indices, v, lhs_shape=self.shape,
                            dimension_numbers=(([1], [0]), ([], [])))

  @api.jit
  def matmat(self, B):
    return bcoo_dot_general(self.data, self.indices, B, lhs_shape=self.shape,
                            dimension_numbers=(([1], [0]), ([], [])))
  def __matmul__(self, other):
    if isinstance(other, JAXSparse):
      raise NotImplementedError("sparse-sparse matmul")
    other = jnp.asarray(other)
    if self.ndim == 0 or other.ndim == 0:
      raise ValueError("matmul inputs cannot be zero-dimensional.")
    if self.ndim > 2 or other.ndim > 2:
      raise NotImplementedError("sparse matmul for dimensions larger than 2")
    dtype = jnp.promote_types(self.dtype, other.dtype)
    return bcoo_dot_general(self.data.astype(dtype), self.indices, other.astype(dtype),
                            lhs_shape=self.shape,
                            dimension_numbers=(([self.ndim - 1], [0]), ([], [])))

  def __rmatmul__(self, other):
    if isinstance(other, JAXSparse):
      raise NotImplementedError("sparse-sparse matmul")
    other = jnp.asarray(other)
    if self.ndim == 0 or other.ndim == 0:
      raise ValueError("matmul inputs cannot be zero-dimensional.")
    if self.ndim > 2 or other.ndim > 2:
      raise NotImplementedError("sparse matmul for dimensions larger than 2")
    dtype = jnp.promote_types(self.dtype, other.dtype)
    return bcoo_dot_general(self.data.astype(dtype), self.indices, other.astype(dtype),
                            lhs_shape=self.shape,
                            dimension_numbers=(([0], [other.ndim - 1]), ([], []))).T

  def transpose(self):
    if self.n_batch or self.n_dense:
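A note on the `__rmatmul__` definition above: `bcoo_dot_general` always takes the sparse array as its left operand, so the reflected product `other @ self` is obtained by contracting `self`'s leading axis against `other`'s trailing axis and transposing the result. A dense sketch of that identity using `lax.dot_general`, with example shapes of my own:

```python
import jax.numpy as jnp
from jax import lax

A = jnp.arange(12.0).reshape(3, 4)  # stands in for the sparse operand (self)
B = jnp.arange(6.0).reshape(2, 3)   # dense left operand (other)

expected = B @ A                    # shape (2, 4)

# Contract A's axis 0 with B's last axis; dot_general orders the output as
# (A's free axes, B's free axes) = (4, 2), so a final transpose recovers B @ A.
via_dot_general = lax.dot_general(
    A, B, dimension_numbers=(([0], [B.ndim - 1]), ([], []))).T

assert jnp.allclose(expected, via_dot_general)
```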
30 changes: 30 additions & 0 deletions tests/sparse_ops_test.py
@@ -44,6 +44,8 @@
    np.complex128: 1E-10,
}

all_dtypes = jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex


def rand_sparse(rng, nnz=0.5, post=lambda x: x):
  def _rand_sparse(shape, dtype, nnz=nnz):
@@ -735,6 +737,34 @@ def test_bcoo_reduce_sum(self, shape, dtype, n_batch, n_dense, axes):
    tol = {np.float32: 1E-6, np.float64: 1E-14}
    self.assertAllClose(result_dense, result_sparse, atol=tol, rtol=tol)

  @unittest.skipIf(jtu.device_under_test() == "tpu", "TPU has insufficient precision")
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}".format(
          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
      }
      for lhs_shape, rhs_shape in [[(3,), (3,)],
                                   [(3, 4), (4,)],
                                   [(4,), (4, 5)],
                                   [(3, 4), (4, 5)]]
      for lhs_dtype in all_dtypes
      for rhs_dtype in all_dtypes))
  def test_bcoo_matmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
    rng = jtu.rand_default(self.rng())
    lhs = jnp.array(rng(lhs_shape, lhs_dtype))
    rhs = jnp.array(rng(rhs_shape, rhs_dtype))

    out1 = lhs @ rhs
    out2 = sparse_ops.BCOO.fromdense(lhs) @ rhs
    out3 = lhs @ sparse_ops.BCOO.fromdense(rhs)

    tol = {np.float64: 1E-13, np.complex128: 1E-13,
           np.float32: 1E-6, np.complex64: 1E-6}
    self.assertAllClose(out1, out2, rtol=tol)
    self.assertAllClose(out1, out3, rtol=tol)


class SparseObjectTest(jtu.JaxTestCase):
  @parameterized.named_parameters(