[sparse] add support for bcoo equivalent of lax.slice
jakevdp committed Aug 25, 2022
1 parent 5527966 commit 269e752
Showing 4 changed files with 136 additions and 1 deletion.
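
For orientation, the new public entry point mirrors `jax.lax.slice`. The following is an illustrative sketch (not part of the commit) based on the API added below; the array values and shapes are arbitrary:

    import jax.numpy as jnp
    from jax import lax
    from jax.experimental import sparse

    M = jnp.array([[1., 0., 2., 0.],
                   [0., 3., 0., 4.]])
    Msp = sparse.BCOO.fromdense(M)

    # Sparse counterpart of lax.slice(M, (0, 1), (2, 3)):
    out = sparse.bcoo_slice(Msp, start_indices=(0, 1), limit_indices=(2, 3))
    print(out.todense())                 # [[0. 2.]
                                         #  [3. 0.]]
    print(lax.slice(M, (0, 1), (2, 3)))  # same result, computed densely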
1 change: 1 addition & 0 deletions jax/experimental/sparse/__init__.py
@@ -203,6 +203,7 @@
bcoo_update_layout as bcoo_update_layout,
bcoo_reduce_sum as bcoo_reduce_sum,
bcoo_reshape as bcoo_reshape,
bcoo_slice as bcoo_slice,
bcoo_sort_indices as bcoo_sort_indices,
bcoo_sort_indices_p as bcoo_sort_indices_p,
bcoo_spdot_general_p as bcoo_spdot_general_p,
65 changes: 64 additions & 1 deletion jax/experimental/sparse/bcoo.py
@@ -17,7 +17,7 @@
import functools
from functools import partial
import operator
from typing import Any, NamedTuple, Sequence, Tuple
from typing import Any, NamedTuple, Optional, Sequence, Tuple
import warnings

import numpy as np
@@ -1785,6 +1785,69 @@ def bcoo_reshape(mat, *, new_sizes, dimensions):
return BCOO((data, new_indices), shape=new_sizes)


def bcoo_slice(mat, *, start_indices: Sequence[int], limit_indices: Sequence[int],
               strides: Optional[Sequence[int]] = None):
  """Sparse implementation of {func}`jax.lax.slice`.

  Args:
    mat: BCOO array to be sliced.
    start_indices: sequence of integers of length `mat.ndim` specifying the starting
      indices of each slice.
    limit_indices: sequence of integers of length `mat.ndim` specifying the ending
      indices of each slice.
    strides: sequence of integers of length `mat.ndim` specifying the stride for
      each slice. Note: non-unit strides are not yet implemented.

  Returns:
    out: BCOO array containing the slice.
  """
if not isinstance(mat, BCOO):
raise ValueError(f"bcoo_slice: input should be BCOO array, got type(mat)={type(mat)}")
start_indices = [operator.index(i) for i in start_indices]
limit_indices = [operator.index(i) for i in limit_indices]
if strides is not None:
strides = [operator.index(i) for i in strides]
else:
strides = [1] * mat.ndim
  if not len(start_indices) == len(limit_indices) == len(strides) == mat.ndim:
    raise ValueError(f"bcoo_slice: indices must have size mat.ndim={mat.ndim}")
if strides != [1] * mat.ndim:
raise NotImplementedError(f"non-unit strides; got {strides}")

if not all(0 <= start <= end <= size
for start, end, size in safe_zip(start_indices, limit_indices, mat.shape)):
raise ValueError(f"bcoo_slice: invalid indices. Got start_indices={start_indices}, "
f"limit_indices={limit_indices} and shape={mat.shape}")

start_batch, start_sparse, start_dense = split_list(start_indices, [mat.n_batch, mat.n_sparse])
end_batch, end_sparse, end_dense = split_list(limit_indices, [mat.n_batch, mat.n_sparse])

data_slices = []
index_slices = []
for i, (start, end) in enumerate(zip(start_batch, end_batch)):
data_slices.append(slice(None) if mat.data.shape[i] != mat.shape[i] else slice(start, end))
index_slices.append(slice(None) if mat.indices.shape[i] != mat.shape[i] else slice(start, end))
data_slices.append(slice(None))
index_slices.extend([slice(None), slice(None)])
for i, (start, end) in enumerate(zip(start_dense, end_dense)):
data_slices.append(slice(start, end))
new_data = mat.data[tuple(data_slices)]
new_indices = mat.indices[tuple(index_slices)]
new_shape = [end - start for start, end in safe_zip(start_indices, limit_indices)]

if mat.n_sparse:
starts = jnp.expand_dims(jnp.array(start_sparse, dtype=new_indices.dtype), range(mat.n_batch + 1))
ends = jnp.expand_dims(jnp.array(end_sparse, dtype=new_indices.dtype), range(mat.n_batch + 1))
sparse_shape = jnp.array(mat.shape[mat.n_batch: mat.n_batch + mat.n_sparse], dtype=new_indices.dtype)

keep = jnp.all((new_indices >= starts) & (new_indices < ends), -1, keepdims=True)
new_indices = jnp.where(keep, new_indices - starts, sparse_shape)

keep_data = lax.expand_dims(keep[..., 0], range(mat.n_batch + 1, mat.n_batch + 1 + mat.n_dense))
new_data = jnp.where(keep_data, new_data, 0)

return BCOO((new_data, new_indices), shape=new_shape)

def _tuple_replace(tup, ind, val):
return tuple(val if i == ind else t for i, t in enumerate(tup))

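To make the index bookkeeping above concrete: nonzeros that fall outside the slice are not physically dropped. Their data is zeroed and their indices are rewritten to the original sparse extent, which is out of bounds for the sliced shape and therefore acts as padding, while in-range indices are shifted by the slice start. A small hand-checked sketch (not from the commit; assumes default `fromdense` behavior with `n_batch = n_dense = 0`):

    import jax.numpy as jnp
    from jax.experimental import sparse

    mat = sparse.BCOO.fromdense(jnp.array([5., 0., 7., 0.]))
    # mat.data = [5. 7.], mat.indices = [[0] [2]]

    out = sparse.bcoo_slice(mat, start_indices=(1,), limit_indices=(4,))
    # The nonzero at position 0 lies outside [1, 4): its data is zeroed and its
    # index is remapped to 4 (the original extent), which is out of bounds for
    # the new shape (3,) and so is ignored as padding.  The nonzero at position 2
    # is shifted to 2 - 1 = 1.
    print(out.shape)              # (3,)
    print(out.data, out.indices)  # [0. 7.] [[4] [1]]
    print(out.todense())          # [0. 7. 0.]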
26 changes: 26 additions & 0 deletions jax/experimental/sparse/transform.py
@@ -759,6 +759,13 @@ def _todense_sparse_rule(spenv, spvalue, *, tree):

sparse_rules[sparse.todense_p] = _todense_sparse_rule

def _slice_sparse_rule(spenv, *operands, **params):
args = spvalues_to_arrays(spenv, operands)
out = sparse.bcoo_slice(*args, **params)
return arrays_to_spvalues(spenv, [out])

sparse_rules[lax.slice_p] = _slice_sparse_rule


#------------------------------------------------------------------------------
# BCOO methods derived from sparsify
@@ -775,6 +782,25 @@ def _reshape(self, *args, **kwargs):
def _sparse_rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
mode=None, fill_value=None):
# mirrors lax_numpy._rewriting_take.

# Handle some special cases, falling back if error messages might differ.
if (arr.ndim > 0 and isinstance(idx, (int, np.integer)) and
not isinstance(idx, (bool, np.bool_)) and isinstance(arr.shape[0], int)):
if 0 <= idx < arr.shape[0]:
return sparsify(lambda arr: lax.index_in_dim(arr, idx, keepdims=False))(arr)
if (arr.ndim > 0 and isinstance(arr.shape[0], int) and
isinstance(idx, slice) and
(type(idx.start) is int or idx.start is None) and
(type(idx.stop) is int or idx.stop is None) and
(type(idx.step) is int or idx.step is None)):
n = arr.shape[0]
start = idx.start if idx.start is not None else 0
stop = idx.stop if idx.stop is not None else n
step = idx.step if idx.step is not None else 1
if (0 <= start < n and 0 <= stop <= n and 0 < step and
(start, stop, step) != (0, n, 1)):
return sparsify(lambda arr: lax.slice_in_dim(arr, start, stop, step))(arr)

treedef, static_idx, dynamic_idx = lax_numpy._split_index_for_jit(idx, arr.shape)
result = sparsify(
lambda arr, idx: lax_numpy._gather(arr, treedef, static_idx, idx, indices_are_sorted,
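Registering a rule for `lax.slice_p` means slicing also works under the `sparsify` transform, and the `_sparse_rewriting_take` special cases above route basic integer and slice indexing through the same path. A short sketch of what this is meant to enable (illustrative, with arbitrary values; integer and simple-slice indexing are exactly what `test_bcoo_getitem` below exercises):

    from jax import lax
    from jax.experimental import sparse
    import jax.numpy as jnp

    M = jnp.arange(12.0).reshape(3, 4)
    Msp = sparse.BCOO.fromdense(M)

    # lax.slice inside a sparsified function now dispatches to bcoo_slice:
    f = sparse.sparsify(lambda x: lax.slice(x, (0, 1), (2, 3)))
    print(f(Msp).todense())    # matches lax.slice(M, (0, 1), (2, 3))

    # Simple integer and slice indexing on BCOO uses the new special cases:
    print(Msp[1].todense())    # matches M[1]
    print(Msp[1:3].todense())  # matches M[1:3]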
45 changes: 45 additions & 0 deletions tests/sparse_test.py
@@ -36,6 +36,7 @@
from jax._src.lib import xla_extension_version
from jax._src.lib import gpu_sparse
from jax._src.lib import xla_bridge
from jax._src.util import unzip2
from jax import jit
from jax import tree_util
from jax import vmap
@@ -926,6 +927,50 @@ def trans(M):
Msp = sparse.BCOO.fromdense(M, n_batch=n_batch, n_dense=n_dense)
self.assertArraysEqual(trans(M), trans(Msp).todense())

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_nbatch={}_ndense={}".format(
jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
"shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
for dtype in jtu.dtypes.floating
for n_batch in range(len(shape) + 1)
for n_dense in range(len(shape) + 1 - n_batch)))
def test_bcoo_slice(self, shape, dtype, n_batch, n_dense):
rng = self.rng()
sprng = rand_sparse(rng)
M = sprng(shape, dtype)
Msp = sparse.BCOO.fromdense(M, n_batch=n_batch, n_dense=n_dense)

rng = self.rng()
slices = rng.randint(0, M.shape, (2, M.ndim)).T
slices.sort(1)
start_indices, limit_indices = unzip2(slices)
strides = None # strides currently not implemented
kwds = dict(start_indices=start_indices, limit_indices=limit_indices, strides=strides)

dense_result = lax.slice(M, **kwds)
sparse_result = sparse.bcoo_slice(Msp, **kwds)

self.assertArraysEqual(dense_result, sparse_result.todense())

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_nbatch={}_ndense={}_idx={}".format(
jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense, idx),
"shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense,
"idx": idx}
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
for dtype in jtu.dtypes.floating
for n_batch in range(len(shape) + 1)
for n_dense in range(len(shape) + 1 - n_batch)
for idx in [1, slice(1, 3)]))
def test_bcoo_getitem(self, shape, dtype, n_batch, n_dense, idx):
# Note: __getitem__ is currently only supported for simple slices and indexing
rng = self.rng()
sprng = rand_sparse(rng)
M = sprng(shape, dtype)
Msp = sparse.BCOO.fromdense(M, n_batch=n_batch, n_dense=n_dense)
self.assertArraysEqual(M[idx], Msp[idx].todense())

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_nbatch={}_ndense={}".format(
jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
