Skip to content

Commit

Permalink
[sparse] bcoo_extract: add assume_unique keyword
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jan 12, 2023
1 parent 34e10e3 commit e37e3a9
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 34 deletions.
102 changes: 71 additions & 31 deletions jax/experimental/sparse/bcoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,59 +381,82 @@ def _bcoo_fromdense_batching_rule(batched_args, batch_dims, *, nse, n_batch, n_d

bcoo_extract_p = core.Primitive('bcoo_extract')

def bcoo_extract(indices: Array, mat: Array) -> Array:
def bcoo_extract(indices: Array, mat: Array, *, assume_unique=True) -> Array:
"""Extract BCOO data values from a dense matrix at given BCOO indices.
Args:
indices: An ndarray; see BCOO indices.
mat: A dense matrix.
assume_unique: bool, default=True
If True, then indices will be assumed unique and a value will be extracted
from mat for each index. Otherwise, extra work will be done to de-duplicate
indices to zero-out duplicate extracted values.
Returns:
An ndarray; see BCOO data.
"""
return bcoo_extract_p.bind(indices, mat)
return bcoo_extract_p.bind(indices, mat, assume_unique=assume_unique)

@bcoo_extract_p.def_impl
def _bcoo_extract_impl(indices, mat):
def _bcoo_extract_impl(indices, mat, *, assume_unique):
mat = jnp.asarray(mat)
n_batch, n_sparse, _, nse = _validate_bcoo_indices(indices, mat.shape)
props = _validate_bcoo_indices(indices, mat.shape)
if not assume_unique:
indices, sort_ind = _unique_indices(indices, shape=mat.shape, return_index=True)
original_props = props
props = _validate_bcoo_indices(indices, mat.shape)

ind_slices = tuple(np.zeros(s, int) if i_s == 1 else np.arange(s)
for s, i_s in zip(mat.shape[:n_batch], indices.shape[:n_batch]))
for s, i_s in zip(mat.shape[:props.n_batch], indices.shape[:props.n_batch]))
grid = tuple(np.meshgrid(*ind_slices, indexing='ij', sparse=True))
sparse_ind = tuple(indices[grid + (slice(None), i)] for i in range(n_sparse))
sparse_ind = tuple(indices[grid + (slice(None), i)] for i in range(props.n_sparse))

batch_slices = tuple(np.arange(s) for s in mat.shape[:n_batch])
batch_slices = tuple(np.arange(s) for s in mat.shape[:props.n_batch])
grid = np.meshgrid(*batch_slices, np.arange(1), indexing='ij', sparse=True)
batch_ind = tuple(grid)[:-1]

if not sparse_ind + batch_ind:
result = mat[None]
else:
result = mat.at[batch_ind + sparse_ind].get(mode='fill', fill_value=0)
if n_sparse == 0 and nse != 1:
result = lax.broadcast_in_dim(
result, _tuple_replace(result.shape, n_batch, nse), range(result.ndim))
if props.n_sparse == 0 and props.nse != 1:
if assume_unique:
result = lax.broadcast_in_dim(
result, _tuple_replace(result.shape, props.n_batch, props.nse), range(result.ndim))
else:
out_shape = _tuple_replace(result.shape, props.n_batch, original_props.nse)
ind = props.n_batch * (slice(None),) + (slice(1),)
result = jnp.zeros_like(result, shape=out_shape).at[ind].set(result)
if not assume_unique:
unbatched_out_shape = (original_props.nse, *result.shape[props.n_batch + 1:])
def f(r, i):
return jnp.zeros_like(r, shape=unbatched_out_shape).at[i].add(r)
for _ in range(props.n_batch):
f = vmap(f)
result = f(result, sort_ind)
return result

@bcoo_extract_p.def_abstract_eval
def _bcoo_extract_abstract_eval(indices, mat):
def _bcoo_extract_abstract_eval(indices, mat, *, assume_unique):
_ = bool(assume_unique)
n_batch, _, n_dense, nse = _validate_bcoo_indices(indices, mat.shape)
out_shape = mat.shape[:n_batch] + (nse,) + mat.shape[mat.ndim - n_dense:]
return core.ShapedArray(out_shape, mat.dtype)

def _bcoo_extract_jvp(mat_dot, indices, mat):
def _bcoo_extract_jvp(mat_dot, indices, mat, *, assume_unique):
assert mat_dot.shape == mat.shape
return bcoo_extract(indices, mat_dot)
return bcoo_extract(indices, mat_dot, assume_unique=assume_unique)

def _bcoo_extract_transpose(ct, indices, mat):
def _bcoo_extract_transpose(ct, indices, mat, *, assume_unique):
if not assume_unique:
raise NotImplementedError("transpose of bcoo_extract with assume_unique=False")
assert ad.is_undefined_primal(mat)
if ad.is_undefined_primal(indices):
raise ValueError("Cannot transpose with respect to sparse indices")
assert ct.dtype == mat.aval.dtype
return indices, _bcoo_todense(ct, indices, spinfo=SparseInfo(mat.aval.shape))

def _bcoo_extract_batching_rule(batched_args, batch_dims):
def _bcoo_extract_batching_rule(batched_args, batch_dims, *, assume_unique):
indices, mat = batched_args
assert any(b is not None for b in batch_dims)
if batch_dims[0] is None:
Expand All @@ -452,7 +475,7 @@ def _bcoo_extract_batching_rule(batched_args, batch_dims):
n_batch = indices.ndim - 2
if bdim >= n_batch:
raise ValueError(f"{batch_dims=} out of range for indices with {n_batch=}")
return bcoo_extract(indices, mat), bdim
return bcoo_extract(indices, mat, assume_unique=assume_unique), bdim

ad.defjvp(bcoo_extract_p, None, _bcoo_extract_jvp)
ad.primitive_transposes[bcoo_extract_p] = _bcoo_extract_transpose
Expand Down Expand Up @@ -1068,7 +1091,7 @@ def _bcoo_dot_general_sampled_transpose(ct, A, B, indices, *, dimension_numbers)
B_shape = B.aval.shape if hasattr(B, 'aval') else B.shape
mat_shape = _dot_general_validated_shape(A_shape, B_shape, dimension_numbers)
mat = ad.UndefinedPrimal(core.ShapedArray(mat_shape, ct.dtype))
indices, ct = _bcoo_extract_transpose(ct, indices, mat)
indices, ct = _bcoo_extract_transpose(ct, indices, mat, assume_unique=True)
kwds = {'dimension_numbers': dimension_numbers,
'precision': None,
'preferred_element_type': None}
Expand Down Expand Up @@ -1397,9 +1420,8 @@ def _bcoo_sum_duplicates(data: Array, indices: Array, *, spinfo: SparseInfo, nse
@bcoo_sum_duplicates_p.def_impl
def _bcoo_sum_duplicates_impl(data, indices, *, spinfo, nse):
props = _validate_bcoo(data, indices, spinfo.shape)
f = nfold_vmap(functools.partial(_bcoo_sum_duplicates_unbatched, shape=spinfo.shape[props.n_batch:]),
N=props.n_batch, broadcasted=False)
indices_out, mapping, nse_batched = f(indices)
indices_out, mapping, nse_batched = _unique_indices(
indices, shape=spinfo.shape, return_inverse=True, return_true_size=True)
if nse is None:
nse = 1 if props.n_sparse == 0 else nse_batched.max()
indices_out = _adjust_indices_nse(indices_out, nse=nse, shape=spinfo.shape)
Expand All @@ -1425,22 +1447,40 @@ def _adjust_indices_nse(indices, *, nse, shape):
indices = lax.concatenate([indices, fill], dimension=indices.ndim - 2)
return indices

def _bcoo_sum_duplicates_unbatched(indices, *, shape):
def _unique_indices(indices, *, shape, return_inverse=False,
return_index=False, return_true_size=False):
props = _validate_bcoo_indices(indices, shape)
f = partial(_unique_indices_unbatched, shape=shape[props.n_batch:],
return_inverse=return_inverse, return_index=return_index,
return_true_size=return_true_size)
f = nfold_vmap(f, props.n_batch, broadcasted=False)
return f(indices)

def _unique_indices_unbatched(indices, *, shape, return_inverse=False,
return_index=False, return_true_size=False):
props = _validate_bcoo_indices(indices, shape)
if props.n_sparse == 0:
nse = 1
mapping = jnp.zeros(nse, dtype='int32')
indices_out = jnp.zeros_like(indices, shape=(nse, props.n_sparse))
return indices_out, mapping, nse
indices_out = jnp.zeros_like(indices, shape=(nse, 0))
out = (indices_out,)
if return_index:
out = (*out, jnp.zeros(nse, dtype='int32'))
if return_inverse:
out = (*out, jnp.zeros(nse, dtype='int32'))
if return_true_size:
out = (*out, nse)
return out[0] if len(out) == 1 else out
fill_value = jnp.expand_dims(jnp.array(shape[:props.n_sparse], dtype=indices.dtype), (0,))
out_of_bounds = (indices >= fill_value).any(-1, keepdims=True)
indices = jnp.where(out_of_bounds, fill_value, indices)
# TODO: check if `indices_sorted` is True.
indices_unique, inv_idx, nse = _unique(
indices, axis=0, return_inverse=True, return_true_size=True,
size=props.nse, fill_value=fill_value)
nse = nse - (indices == fill_value).any().astype(nse.dtype)
return indices_unique, inv_idx, nse
out = _unique(indices, axis=0, return_inverse=return_inverse, return_index=return_index,
return_true_size=return_true_size, size=props.nse, fill_value=fill_value)
if return_true_size:
nse = out[-1]
nse = nse - (indices == fill_value).any().astype(nse.dtype)
out = (*out[:-1], nse)
return out

@bcoo_sum_duplicates_p.def_abstract_eval
def _bcoo_sum_duplicates_abstract_eval(data, indices, *, spinfo, nse):
Expand Down Expand Up @@ -1472,8 +1512,8 @@ def _bcoo_sum_duplicates_jvp(primals, tangents, *, spinfo, nse):

data, indices = primals
data_dot, _ = tangents
f = nfold_vmap(functools.partial(_bcoo_sum_duplicates_unbatched, shape=spinfo.shape[props.n_batch:]), props.n_batch)
indices_out, mapping, nse_batched = f(indices)
indices_out, mapping, nse_batched = _unique_indices(
indices, shape=spinfo.shape, return_inverse=True, return_true_size=True)
if nse is None:
nse = jnp.sum(nse_batched)
try:
Expand Down
34 changes: 31 additions & 3 deletions tests/sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,18 +804,46 @@ def test_bcoo_fromdense_sorted_and_unique_indices(self):
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
for layout in iter_sparse_layouts(shape)],
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
assume_unique=[True, False]
)
def test_bcoo_extract(self, shape, dtype, n_batch, n_dense):
def test_bcoo_extract(self, shape, dtype, n_batch, n_dense, assume_unique):
rng = rand_sparse(self.rng())
M = rng(shape, dtype)
nse = sparse.util._count_stored_elements(M, n_batch=n_batch,
n_dense=n_dense)
data, indices = sparse_bcoo._bcoo_fromdense(M, nse=nse)
data2 = sparse.bcoo_extract(indices, M)
bcoo_extract = partial(sparse.bcoo_extract, assume_unique=assume_unique)

data2 = bcoo_extract(indices, M)
self.assertArraysEqual(data, data2)
data3 = jit(sparse.bcoo_extract)(indices, M)

data3 = jit(bcoo_extract)(indices, M)
self.assertArraysEqual(data, data3)

def test_bcoo_extract_duplicate_indices(self):
data = jnp.array([1, 3, 9, 27, 81, 243])
indices = jnp.array([[0], [5], [0], [3], [2], [3]])
shape = (6,)
mat = sparse.BCOO((data, indices), shape=shape).todense()

data1 = sparse.bcoo_extract(indices, mat, assume_unique=True)
self.assertArraysEqual(data1, jnp.array([10, 3, 10, 270, 81, 270]))

data2 = sparse.bcoo_extract(indices, mat, assume_unique=False)
self.assertArraysEqual(data2, jnp.array([10, 3, 0, 270, 81, 0]))

def test_bcoo_extract_duplicate_indices_n_sparse_0(self):
data = jnp.arange(6).reshape(3, 2)
indices = jnp.empty((3, 2, 0), dtype=int)
shape = (3,)
mat = sparse.BCOO((data, indices), shape=shape).todense()

data1 = sparse.bcoo_extract(indices, mat, assume_unique=True)
self.assertArraysEqual(data1, jnp.array([[1, 1], [5, 5], [9, 9]]))

data2 = sparse.bcoo_extract(indices, mat, assume_unique=False)
self.assertArraysEqual(data2, jnp.array([[1, 0], [5, 0], [9, 0]]))

def test_bcoo_extract_batching(self):
# https://github.com/google/jax/issues/9431
indices = jnp.zeros((4, 1, 1), dtype=int)
Expand Down

0 comments on commit e37e3a9

Please sign in to comment.