Skip to content

Commit

Permalink
Validate shapes for boolean indices
Browse files: browse the repository at this point in the history
  • Loading branch information
jakevdp committed Aug 3, 2021
1 parent 10bbd62 commit 08e1c83
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 7 deletions.
18 changes: 12 additions & 6 deletions jax/_src/numpy/lax_numpy.py
Expand Up @@ -4922,6 +4922,7 @@ def _unique_axis_sorted_mask(ar, axis):
size, *out_shape = aux.shape
aux = aux.reshape(size, _prod(out_shape)).T
if aux.shape[0] == 0:
size = 1
perm = zeros(1, dtype=int)
else:
perm = lexsort(aux[::-1])
Expand Down Expand Up @@ -5005,7 +5006,7 @@ def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False):
# All supported cases of indexing can be implemented as an XLA gather,
# followed by an optional reverse and broadcast_in_dim.
arr = asarray(arr)
treedef, static_idx, dynamic_idx = _split_index_for_jit(idx)
treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
unique_indices)

Expand Down Expand Up @@ -5065,7 +5066,7 @@ def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
"newaxis_dims",
])

def _split_index_for_jit(idx):
def _split_index_for_jit(idx, shape):
"""Splits indices into necessarily-static and dynamic parts.
Used to pass indices into `jit`-ted function.
Expand All @@ -5075,7 +5076,7 @@ def _split_index_for_jit(idx):

# Expand any (concrete) boolean indices. We can then use advanced integer
# indexing logic to handle them.
idx = _expand_bool_indices(idx)
idx = _expand_bool_indices(idx, shape)

leaves, treedef = tree_flatten(idx)
dynamic = [None] * len(leaves)
Expand Down Expand Up @@ -5328,16 +5329,16 @@ def _eliminate_deprecated_list_indexing(idx):
idx = (idx,)
return idx

def _expand_bool_indices(idx):
def _expand_bool_indices(idx, shape):
"""Converts concrete bool indexes into advanced integer indexes."""
out = []
for i in idx:
for dim_number, i in enumerate(idx):
try:
abstract_i = core.get_aval(i)
except TypeError:
abstract_i = None
if (isinstance(abstract_i, ShapedArray) and issubdtype(abstract_i.dtype, bool_)
or isinstance(i, list) and _all(_is_scalar(e) and issubdtype(_dtype(e), np.bool_) for e in i)):
or isinstance(i, list) and i and _all(_is_scalar(e) and issubdtype(_dtype(e), np.bool_) for e in i)):
if isinstance(i, list):
i = array(i)
abstract_i = core.get_aval(i)
Expand All @@ -5346,6 +5347,11 @@ def _expand_bool_indices(idx):
# TODO(mattjj): improve this error by tracking _why_ the indices are not concrete
raise errors.NonConcreteBooleanIndexError(abstract_i)
else:
i_shape = _shape(i)
expected_shape = shape[len(out): len(out) + _ndim(i)]
if i_shape != expected_shape:
raise IndexError("boolean index did not match shape of indexed array in index "
f"{dim_number}: got {i_shape}, expected {expected_shape}")
out.extend(np.where(i))
else:
out.append(i)
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/ops/scatter.py
Expand Up @@ -64,7 +64,7 @@ def _scatter_update(x, idx, y, scatter_op, indices_are_sorted,
y = jnp.asarray(y)
# XLA gathers and scatters are very similar in structure; the scatter logic
# is more or less a transpose of the gather equivalent.
treedef, static_idx, dynamic_idx = jnp._split_index_for_jit(idx)
treedef, static_idx, dynamic_idx = jnp._split_index_for_jit(idx, x.shape)
return _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
indices_are_sorted, unique_indices, normalize_indices)

Expand Down
19 changes: 19 additions & 0 deletions tests/lax_numpy_indexing_test.py
Expand Up @@ -833,6 +833,25 @@ def testBooleanIndexingWithEmptyResult(self):
expected = np.array([-1])[np.array([False])]
self.assertAllClose(ans, expected, check_dtypes=False)

def testBooleanIndexingShapeMismatch(self):
# Regression test for https://github.com/google/jax/issues/7329
x = jnp.arange(4)
idx = jnp.array([True, False])
with self.assertRaisesRegex(IndexError, "boolean index did not match shape.*"):
x[idx]

def testNontrivialBooleanIndexing(self):
# Test nontrivial corner case in boolean indexing shape validation
rng = jtu.rand_default(self.rng())
index = (rng((2, 3), np.bool_), rng((6,), np.bool_))

args_maker = lambda: [rng((2, 3, 6), np.int32)]
np_fun = lambda x: np.asarray(x)[index]
jnp_fun = lambda x: jnp.asarray(x)[index]

self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

def testFloatIndexingError(self):
BAD_INDEX_TYPE_ERROR = "Indexer must have integer or boolean type, got indexer with type"
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
Expand Down

0 comments on commit 08e1c83

Please sign in to comment.