[sparse] avoid implicit rank promotion
jakevdp committed Jan 25, 2022
1 parent 2e3b483 commit fa24395
Showing 3 changed files with 26 additions and 10 deletions.
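
For context, the commit targets JAX's jax_numpy_rank_promotion setting: when operands of different rank are broadcast together, NumPy semantics implicitly prepend dimensions to the lower-rank operand, and the "raise" mode turns that into an error. A minimal sketch of the pattern adopted throughout this commit, not taken from the diff itself (array names and shapes here are illustrative):

import jax
import jax.numpy as jnp

# Disallow implicit rank promotion, as the updated tests below do via
# @jtu.with_config(jax_numpy_rank_promotion="raise").
jax.config.update("jax_numpy_rank_promotion", "raise")

indices = jnp.zeros((5, 2), dtype=jnp.int32)   # e.g. (nse, n_sparse)
bounds = jnp.array([3, 4], dtype=jnp.int32)    # rank-1 shape vector

# (indices < bounds) would promote the rank-1 operand to rank 2 and error here.
# Expanding it to a matching rank keeps the broadcast explicit:
fill_value = jnp.expand_dims(bounds, range(indices.ndim - 1))  # shape (1, 2)
mask = (indices < fill_value).all(-1)                          # shape (5,)
print(mask.shape)
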
25 changes: 16 additions & 9 deletions jax/experimental/sparse/bcoo.py
@@ -98,7 +98,8 @@ def _bcoo_sum_duplicates_unbatched(data, indices, *, shape, nse, remove_zeros):
     data_unique = jnp.zeros_like(data, shape=(nse, *data.shape[1:])).at[0].set(data.sum(0))
     indices_unique = jnp.zeros_like(indices, shape=(nse, 0))
     return data_unique, indices_unique, nse
-  fill_value = jnp.array(shape[:props.n_sparse], dtype=indices.dtype)
+  fill_value = jnp.expand_dims(jnp.array(shape[:props.n_sparse], dtype=indices.dtype),
+                               range(indices.ndim - 1))
   out_of_bounds = (indices >= fill_value).any(-1, keepdims=True)
   if remove_zeros:
     data_all_zero = (data == 0).all(range(props.n_batch + 1, data.ndim))[:, None]
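
In this function, fill_value doubles as the padding index used below with jnp.unique's size/fill_value arguments: padded rows hold an index equal to the sparse shape, which is out of bounds and therefore zeroed out. A rough illustration of that convention (values are invented for this sketch):

import jax.numpy as jnp

shape = (3, 4)                       # sparse dims of a hypothetical BCOO array
indices = jnp.array([[0, 1],
                     [2, 3],
                     [3, 4],         # == shape: a padding row, not a real entry
                     [0, 1]])        # duplicate of row 0

fill_value = jnp.expand_dims(jnp.array(shape), range(indices.ndim - 1))  # (1, 2)
out_of_bounds = (indices >= fill_value).any(-1, keepdims=True)
print(out_of_bounds.ravel())         # [False False  True False]
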
@@ -115,7 +116,7 @@ def _bcoo_sum_duplicates_unbatched(data, indices, *, shape, nse, remove_zeros):
       size=nse, fill_value=fill_value)
   data_shape = [indices_unique.shape[0], *data.shape[1:]]
   data_unique = jnp.zeros(data_shape, data.dtype).at[inv_idx].add(data)
-  oob_mask = jnp.all(indices_unique == jnp.array(shape[:props.n_sparse]), 1)
+  oob_mask = jnp.all(indices_unique == fill_value, 1)
   data_unique = jnp.where(oob_mask[(...,) + props.n_dense * (None,)], 0, data_unique)
   return data_unique, indices_unique, nse

@@ -316,7 +317,8 @@ def _nonzero(a):
     indices = jnp.moveaxis(jnp.array(indices, index_dtype), 0, n_batch + 1)
   data = bcoo_extract(indices, mat)
 
-  true_nonzeros = jnp.arange(nse) < mask.sum(list(range(n_batch, mask.ndim)))[..., None]
+  true_nonzeros = (lax.broadcasted_iota(jnp.int32, (1,) * n_batch + (nse,), n_batch) <
+                   mask.sum(list(range(n_batch, mask.ndim)))[..., None])
   true_nonzeros = true_nonzeros[(n_batch + 1) * (slice(None),) + n_dense * (None,)]
   data = jnp.where(true_nonzeros, data, 0)
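
lax.broadcasted_iota is used above instead of jnp.arange because it produces the counting index directly at the batched rank, so the comparison against the per-batch nonzero counts never promotes a rank-1 operand. A small sketch of the shapes involved (n_batch, nse, and the counts are chosen arbitrarily here):

import jax.numpy as jnp
from jax import lax

n_batch, nse = 2, 4
counts = jnp.array([[1, 3],
                    [2, 0]])                 # per-batch nonzero counts, shape (2, 2)

# jnp.arange(nse) has shape (4,) and would be rank-promoted when compared with
# counts[..., None] of shape (2, 2, 1). broadcasted_iota builds the same values
# at full rank: shape (1, 1, 4), counting along the last axis.
iota = lax.broadcasted_iota(jnp.int32, (1,) * n_batch + (nse,), n_batch)
true_nonzeros = iota < counts[..., None]     # shape (2, 2, 4), no promotion
print(true_nonzeros.shape)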

@@ -814,9 +816,13 @@ def _bcoo_spdot_general_unbatched(lhs_data, lhs_indices, rhs_data, rhs_indices,
   # jnp.isin() currently doesn't help much, because it also does all() over an outer
   # comparison.
   overlap = (lhs_i[:, None] == rhs_i[None, :]).all(-1)
-  lhs_valid = (lhs_i < jnp.array([lhs_shape[d] for d in lhs_contracting])).all(-1)
-  rhs_valid = (rhs_i < jnp.array([rhs_shape[d] for d in rhs_contracting])).all(-1)
-  out_data = jnp.where(overlap & lhs_valid[:, None] & rhs_valid,
+  lhs_fill_value = jnp.expand_dims(
+    jnp.array([lhs_shape[d] for d in lhs_contracting]), range(lhs_i.ndim - 1))
+  rhs_fill_value = jnp.expand_dims(
+    jnp.array([rhs_shape[d] for d in rhs_contracting]), range(rhs_i.ndim - 1))
+  lhs_valid = (lhs_i < lhs_fill_value).all(-1)
+  rhs_valid = (rhs_i < rhs_fill_value).all(-1)
+  out_data = jnp.where(overlap & lhs_valid[:, None] & rhs_valid[None, :],
                        lhs_data[:, None] * rhs_data[None, :], 0).ravel()
 
   out_indices = jnp.empty([lhs.nse, rhs.nse, lhs_j.shape[-1] + rhs_j.shape[-1]],
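
Note that the mask combination above now indexes both flag vectors explicitly, lhs_valid[:, None] and rhs_valid[None, :], so every operand already has matching rank rather than relying on implicit promotion of the rank-1 rhs_valid. A small sketch of that outer-broadcasting pattern (array values are invented):

import jax.numpy as jnp

lhs_valid = jnp.array([True, False, True])        # shape (3,), one flag per lhs entry
rhs_valid = jnp.array([True, True, False, True])  # shape (4,), one flag per rhs entry
overlap = jnp.ones((3, 4), dtype=bool)            # stand-in for the (lhs.nse, rhs.nse) mask

# Both flag vectors are expanded to rank 2 before combining, so no implicit
# rank promotion occurs:
pair_valid = overlap & lhs_valid[:, None] & rhs_valid[None, :]
print(pair_valid.shape)   # (3, 4)
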
@@ -1017,8 +1023,9 @@ def bcoo_reduce_sum(data, indices, *, spinfo, axes):
     data = data.sum(dense_axes)
   if n_sparse:
     # zero-out data corresponding to invalid indices.
-    sparse_shape = jnp.array(shape[n_batch: n_batch + n_sparse])
-    mask = jnp.all(indices < sparse_shape, -1)
+    fill_value = jnp.expand_dims(
+      jnp.array(shape[n_batch: n_batch + n_sparse]), range(indices.ndim - 1))
+    mask = jnp.all(indices < fill_value, -1)
     if data.ndim > mask.ndim:
       mask = lax.expand_dims(mask, tuple(range(mask.ndim, data.ndim)))
     data = jnp.where(mask, data, 0)
@@ -1093,7 +1100,7 @@ def _bcoo_multiply_sparse_unbatched(lhs_data, lhs_indices, rhs_data, rhs_indices
 
   # TODO(jakevdp): this is pretty inefficient. Can we do this membership check
   # without constructing the full (lhs.nse, rhs.nse) masking matrix?
-  mask = jnp.all(lhs_indices[:, None, dims] == rhs_indices[:, dims], -1)
+  mask = jnp.all(lhs_indices[:, None, dims] == rhs_indices[None, :, dims], -1)
   i_lhs, i_rhs = jnp.nonzero(mask, size=nse, fill_value=(lhs.nse, rhs.nse))
   data = (lhs_data.at[i_lhs].get(mode='fill', fill_value=0) *
           rhs_data.at[i_rhs].get(mode='fill', fill_value=0))
10 changes: 9 additions & 1 deletion tests/sparse_test.py
@@ -116,6 +116,7 @@ def _rand_sparse(shape, dtype, nse=nse):
   return _rand_sparse
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class cuSparseTest(jtu.JaxTestCase):
   def gpu_dense_conversion_warning_context(self, dtype):
     if jtu.device_under_test() == "gpu" and np.issubdtype(dtype, np.integer):
@@ -553,6 +554,8 @@ def test_coo_matmul_ad(self, shape, dtype, bshape):
     self.assertAllClose(primals_dense[0], primals_sparse[0], atol=tol, rtol=tol)
     self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol)
 
+
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class BCOOTest(jtu.JaxTestCase):
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_{}_nbatch={}_ndense={}".format(
@@ -1310,7 +1313,7 @@ def test_bcoo_sort_indices(self, shape, dtype, n_batch, n_dense):
     if indices.size > 0:
       flatind = indices.reshape(-1, *indices.shape[-2:]).transpose(0, 2, 1)
       sorted = jax.vmap(jnp.lexsort)(flatind[:, ::-1])
-      self.assertTrue(jnp.all(sorted == jnp.arange(sorted.shape[-1])))
+      self.assertArraysEqual(sorted, lax.broadcasted_iota(sorted.dtype, sorted.shape, sorted.ndim - 1))
 
   def test_bcoo_sum_duplicates_inferred_nse(self):
     x = sparse.BCOO.fromdense(jnp.diag(jnp.arange(4)))
@@ -1418,6 +1421,7 @@ def test_bcoo_matmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
       for n_dense in range(len(lhs_shape) + 1 - n_batch)
       for lhs_dtype in all_dtypes
       for rhs_dtype in all_dtypes))
+  @jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
   def test_bcoo_mul_dense(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, n_batch, n_dense):
     rng_lhs = rand_sparse(self.rng())
     rng_rhs = jtu.rand_default(self.rng())
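
On the test side, the classes are wrapped with jtu.with_config(jax_numpy_rank_promotion="raise") so any remaining implicit promotion in the sparse code fails loudly, while test_bcoo_mul_dense opts back out with the jax.numpy_rank_promotion('allow') decorator because it deliberately mixes ranks. The same flag also scopes to a single block as a context manager; a rough sketch, not part of this commit:

import jax
import jax.numpy as jnp

x = jnp.ones((3, 2))
y = jnp.ones(2)

# Deliberate mixed-rank arithmetic, explicitly allowed for this block only:
with jax.numpy_rank_promotion('allow'):
  z = x + y              # y is promoted to shape (1, 2), then broadcast

# Under 'raise', the same expression would error; an explicit rank match keeps it valid:
with jax.numpy_rank_promotion('raise'):
  z = x + y[None, :]
print(z.shape)           # (3, 2)
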
@@ -1558,6 +1562,7 @@ def test_bcoo_bad_fillvals(self):
     self.assertArraysEqual((y_sp @ x_sp).todense(), y_de @ x_de)
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class SparseGradTest(jtu.JaxTestCase):
   def test_sparse_grad(self):
     rng_sparse = rand_sparse(self.rng())
@@ -1580,6 +1585,7 @@ def f(X, y):
     self.assertArraysEqual(grad_sparse.todense(), grad_sparse_from_dense)
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class SparseObjectTest(jtu.JaxTestCase):
   def test_repr(self):
     M = sparse.BCOO.fromdense(jnp.arange(5, dtype='float32'))
@@ -1765,6 +1771,8 @@ def test_bcoo_methods(self):
     self.assertArraysEqual(M.sum(1), Msp.sum(1).todense())
     self.assertArraysEqual(M.sum(), Msp.sum())
 
+
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class SparseRandomTest(jtu.JaxTestCase):
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_{}_indices_dtype={}_nbatch={}_ndense={}".format(
1 change: 1 addition & 0 deletions tests/sparsify_test.py
@@ -30,6 +30,7 @@
 config.parse_flags_with_absl()
 
 
+@jtu.with_config(jax_numpy_rank_promotion="raise")
 class SparsifyTest(jtu.JaxTestCase):
   @classmethod
   def sparsify(cls, f):