Skip to content

Commit

Permalink
[sparse] add sparse support for dynamic_slice
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Sep 1, 2022
1 parent 1dbfa88 commit 47b9f21
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 3 deletions.
1 change: 1 addition & 0 deletions jax/experimental/sparse/__init__.py
Expand Up @@ -194,6 +194,7 @@
bcoo_dot_general_p as bcoo_dot_general_p,
bcoo_dot_general_sampled as bcoo_dot_general_sampled,
bcoo_dot_general_sampled_p as bcoo_dot_general_sampled_p,
bcoo_dynamic_slice as bcoo_dynamic_slice,
bcoo_extract as bcoo_extract,
bcoo_extract_p as bcoo_extract_p,
bcoo_fromdense as bcoo_fromdense,
Expand Down
78 changes: 75 additions & 3 deletions jax/experimental/sparse/bcoo.py
Expand Up @@ -1790,13 +1790,13 @@ def bcoo_slice(mat, *, start_indices: Sequence[int], limit_indices: Sequence[int
"""Sparse implementation of {func}`jax.lax.slice`.
Args:
operand: BCOO array to be reshaped.
mat: BCOO array to be reshaped.
start_indices: sequence of integers of length `mat.ndim` specifying the starting
indices of each slice.
limit_indices: sequence of integers of length `mat.ndim` specifying the ending
indices of each slice
strides: sequence of integers of length `mat.ndim` specifying the stride for
each slice
strides: (not implemented) sequence of integers of length `mat.ndim` specifying
the stride for each slice
Returns:
out: BCOO array containing the slice.
Expand Down Expand Up @@ -1848,6 +1848,78 @@ def bcoo_slice(mat, *, start_indices: Sequence[int], limit_indices: Sequence[int

return BCOO((new_data, new_indices), shape=new_shape)

def bcoo_dynamic_slice(mat, start_indices: Sequence[Any], slice_sizes: Sequence[int]):
"""Sparse implementation of {func}`jax.lax.dynamic_slice`.
Args:
mat: BCOO array to slice.
start_indices: a list of scalar indices, one per dimension. These values
may be dynamic.
slice_sizes: the size of the slice. Must be a sequence of non-negative
integers with length equal to `ndim(operand)`. Inside a JIT compiled
function, only static values are supported (all JAX arrays inside JIT
must have statically known size).
Returns:
out: BCOO array containing the slice.
"""
if not isinstance(mat, BCOO):
raise ValueError(f"bcoo_slice: input should be BCOO array, got type(mat)={type(mat)}")
start_indices = tuple(jnp.asarray(i) for i in start_indices)
assert all(jnp.issubdtype(i.dtype, np.integer) for i in start_indices)
assert all(i.shape == () for i in start_indices)
slice_sizes = tuple(operator.index(i) for i in slice_sizes)
if len(start_indices) != len(slice_sizes) != mat.ndim:
raise ValueError(f"bcoo_dynamic_slice: indices must have size mat.ndim={mat.ndim}")
if not all(0 <= slice_size <= axis_size for slice_size, axis_size in zip(slice_sizes, mat.shape)):
raise TypeError("slice_sizes must be less than or equal to operand shape, "
f"got slice_sizes {slice_sizes} for operand shape {mat.shape}")

start_batch, start_sparse, start_dense = split_list(start_indices, [mat.n_batch, mat.n_sparse])
size_batch, size_sparse, size_dense = split_list(slice_sizes, [mat.n_batch, mat.n_sparse])

data_start = []
data_sizes = []
indices_start = []
indices_sizes = []
for i, (start, size) in enumerate(zip(start_batch, size_batch)):
data_is_broadcast = mat.data.shape[i] != mat.shape[i]
indices_is_broadcast = mat.indices.shape[i] != mat.shape[i]
data_start.append(0 if data_is_broadcast else start)
data_sizes.append(1 if data_is_broadcast else size)
indices_start.append(0 if indices_is_broadcast else start)
indices_sizes.append(1 if indices_is_broadcast else size)
data_start.append(0)
data_sizes.append(mat.nse)
indices_start.extend([0, 0])
indices_sizes.extend([mat.nse, mat.n_sparse])
data_start.extend(start_dense)
data_sizes.extend(size_dense)

new_data = lax.dynamic_slice(mat.data, data_start, data_sizes)
new_indices = lax.dynamic_slice(mat.indices, indices_start, indices_sizes)
new_shape = slice_sizes

if mat.n_sparse:
starts = jnp.array(start_sparse, dtype=new_indices.dtype)
sizes = jnp.array(size_sparse, dtype=new_indices.dtype)
sparse_shape = jnp.array(mat.shape[mat.n_batch: mat.n_batch + mat.n_sparse], dtype=new_indices.dtype)
starts = jnp.where(starts < 0, starts + sparse_shape, starts)
starts = jnp.clip(starts, 0, sparse_shape - sizes)

starts = jnp.expand_dims(starts, range(mat.n_batch + 1))
sizes = jnp.expand_dims(sizes, range(mat.n_batch + 1))
sparse_shape = jnp.expand_dims(sparse_shape, range(mat.n_batch + 1))

keep = jnp.all((new_indices >= starts) & (new_indices < starts + sizes), -1, keepdims=True)
new_indices = jnp.where(keep, new_indices - starts, sparse_shape)

keep_data = lax.expand_dims(keep[..., 0], range(mat.n_batch + 1, mat.n_batch + 1 + mat.n_dense))
new_data = jnp.where(keep_data, new_data, 0)

return BCOO((new_data, new_indices), shape=new_shape)


def _tuple_replace(tup, ind, val):
return tuple(val if i == ind else t for i, t in enumerate(tup))

Expand Down
7 changes: 7 additions & 0 deletions jax/experimental/sparse/transform.py
Expand Up @@ -766,6 +766,13 @@ def _slice_sparse_rule(spenv, *operands, **params):

sparse_rules[lax.slice_p] = _slice_sparse_rule

def _dynamic_slice_sparse_rule(spenv, *operands, **params):
args = spvalues_to_arrays(spenv, operands)
out = sparse.bcoo_dynamic_slice(args[0], args[1:], **params)
return arrays_to_spvalues(spenv, [out])

sparse_rules[lax.dynamic_slice_p] = _dynamic_slice_sparse_rule


#------------------------------------------------------------------------------
# BCOO methods derived from sparsify
Expand Down
29 changes: 29 additions & 0 deletions tests/sparse_test.py
Expand Up @@ -950,8 +950,37 @@ def test_bcoo_slice(self, shape, dtype, n_batch, n_dense):

dense_result = lax.slice(M, **kwds)
sparse_result = sparse.bcoo_slice(Msp, **kwds)
sparse_result_jit = jax.jit(partial(sparse.bcoo_slice, **kwds))(Msp)

self.assertArraysEqual(dense_result, sparse_result.todense())
self.assertArraysEqual(dense_result, sparse_result_jit.todense())

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_nbatch={}_ndense={}".format(
jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
"shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
for dtype in jtu.dtypes.floating
for n_batch in range(len(shape) + 1)
for n_dense in range(len(shape) + 1 - n_batch)))
def test_bcoo_dynamic_slice(self, shape, dtype, n_batch, n_dense):
rng = self.rng()
sprng = rand_sparse(rng)
M = sprng(shape, dtype)
Msp = sparse.BCOO.fromdense(M, n_batch=n_batch, n_dense=n_dense)

rng = self.rng()
# Note: test out-of-range start indices
start_indices = rng.randint(-max(M.shape), max(M.shape), M.ndim)
slice_sizes = rng.randint(0, M.shape, M.ndim)
kwds = dict(start_indices=start_indices, slice_sizes=slice_sizes)

dense_result = lax.dynamic_slice(M, **kwds)
sparse_result = sparse.bcoo_dynamic_slice(Msp, **kwds)
sparse_result_jit = partial(sparse.bcoo_dynamic_slice, slice_sizes=slice_sizes)(Msp, start_indices)

self.assertArraysEqual(dense_result, sparse_result.todense())
self.assertArraysEqual(dense_result, sparse_result_jit.todense())

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_nbatch={}_ndense={}_idx={}".format(
Expand Down
24 changes: 24 additions & 0 deletions tests/sparsify_test.py
Expand Up @@ -514,6 +514,30 @@ def func(M):
self.assertArraysEqual(jit(func)(M), M + 1)
self.assertArraysEqual(jit(func)(Msp), M + 1)

def testSparseSlice(self):
M = jnp.arange(24).reshape(2, 3, 4)
Msp = BCOO.fromdense(M)
@self.sparsify
def func(M):
return lax.slice(M, (0, 1, 2), (1, 3, 3))
expected = M[:1, 1:3, 2:3]
self.assertArraysEqual(func(M), expected)
self.assertArraysEqual(func(Msp).todense(), expected)
self.assertArraysEqual(jit(func)(M), expected)
self.assertArraysEqual(jit(func)(Msp).todense(), expected)

def testSparseDynamicSlice(self):
M = jnp.arange(24).reshape(2, 3, 4)
Msp = BCOO.fromdense(M)
@self.sparsify
def func(M):
return lax.dynamic_slice(M, (0, 1, 2), (1, 1, 3))
expected = M[:1, 1:2, 1:4]
self.assertArraysEqual(func(M), expected)
self.assertArraysEqual(func(Msp).todense(), expected)
self.assertArraysEqual(jit(func)(M), expected)
self.assertArraysEqual(jit(func)(Msp).todense(), expected)

def testWeakTypes(self):
# Regression test for https://github.com/google/jax/issues/8267
M = jnp.arange(12, dtype='int32').reshape(3, 4)
Expand Down

0 comments on commit 47b9f21

Please sign in to comment.