Skip to content

Commit

Permalink
Add x.at[idx].get().
Browse files Browse the repository at this point in the history
This allows the sorted/unique keyword arguments to be passed to indexed gather operations.
  • Loading branch information
hawkinsp committed Jul 7, 2021
1 parent 5c2bb8b commit 2168483
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 10 deletions.
30 changes: 23 additions & 7 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4980,18 +4980,20 @@ def unique(ar, return_index=False, return_inverse=False,

### Indexing

def _rewriting_take(arr, idx):
def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False):
# Computes arr[idx].
# All supported cases of indexing can be implemented as an XLA gather,
# followed by an optional reverse and broadcast_in_dim.
arr = asarray(arr)
treedef, static_idx, dynamic_idx = _split_index_for_jit(idx)
return _gather(arr, treedef, static_idx, dynamic_idx)
return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
unique_indices)

# TODO(phawkins): re-enable jit after fixing excessive recompilation for
# slice indexes (e.g., slice(0, 5, None), slice(10, 15, None), etc.).
# @partial(jit, static_argnums=(1, 2))
def _gather(arr, treedef, static_idx, dynamic_idx):
def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
unique_indices):
idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
indexer = _index_to_gather(shape(arr), idx) # shared with _scatter_update
y = arr
Expand All @@ -5003,10 +5005,10 @@ def _gather(arr, treedef, static_idx, dynamic_idx):

# We avoid generating a gather when indexer.gather_indices.size is empty.
if not core.is_empty_shape(indexer.gather_indices.shape):
y = lax.gather(y, indexer.gather_indices, indexer.dnums,
indexer.gather_slice_shape,
unique_indices=indexer.unique_indices,
indices_are_sorted=indexer.indices_are_sorted)
y = lax.gather(
y, indexer.gather_indices, indexer.dnums, indexer.gather_slice_shape,
unique_indices=unique_indices or indexer.unique_indices,
indices_are_sorted=indices_are_sorted or indexer.indices_are_sorted)

# Reverses axes with negative strides.
if indexer.reversed_y_dims:
Expand Down Expand Up @@ -6064,6 +6066,20 @@ def __init__(self, array, index):
def __repr__(self):
return f"_IndexUpdateRef({repr(self.array)}, {repr(self.index)})"

def get(self, indices_are_sorted=False, unique_indices=False):
"""Equivalent to ``x[idx]``.
Returns the value of ``x`` that would result from the NumPy-style
:mod:indexing <numpy.doc.indexing>` ``x[idx]``. This function differs from
the usual array indexing syntax in that it allows additional keyword
arguments ``indices_are_sorted`` and ``unique_indices`` to be passed.
See :mod:`jax.ops` for details.
"""
return _rewriting_take(self.array, self.index,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices)

def set(self, values, indices_are_sorted=False, unique_indices=False):
"""Pure equivalent of ``x[idx] = y``.
Expand Down
7 changes: 4 additions & 3 deletions jax/_src/ops/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,10 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
inserted_window_dims=indexer.dnums.collapsed_slice_dims,
scatter_dims_to_operand_dims=indexer.dnums.start_index_map
)
out = scatter_op(x, indexer.gather_indices, y, dnums,
indices_are_sorted=indices_are_sorted,
unique_indices=indexer.unique_indices or unique_indices)
out = scatter_op(
x, indexer.gather_indices, y, dnums,
indices_are_sorted=indexer.indices_are_sorted or indices_are_sorted,
unique_indices=indexer.unique_indices or unique_indices)
return lax.convert_element_type(out, dtype)


Expand Down
15 changes: 15 additions & 0 deletions tests/lax_numpy_indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,21 @@ def testStaticIndexing(self, shape, dtype, indexer):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list({
"testcase_name": "{}_inshape={}_indexer={}".format(
name, jtu.format_shape_dtype_string( shape, dtype), indexer),
"shape": shape, "dtype": dtype, "indexer": indexer
} for name, index_specs in STATIC_INDEXING_TESTS
for shape, indexer in index_specs
for dtype in all_dtypes))
def testStaticIndexingWithAtGet(self, shape, dtype, indexer):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
np_fun = lambda x: np.asarray(x)[indexer]
jnp_fun = lambda x: jnp.asarray(x).at[indexer].get()
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@parameterized.named_parameters({
"testcase_name":
"{}_inshape={}_indexer={}".format(name,
Expand Down

0 comments on commit 2168483

Please sign in to comment.