Skip to content

Commit

Permalink
Do not call concrete_aval for basic integer index checks
Browse files Browse the repository at this point in the history
  • Loading branch information
lgeiger committed May 5, 2022
1 parent 38cee9f commit c9d6e76
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3653,8 +3653,16 @@ def _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx):
idx.append(s)
return treedef.unflatten(idx)

def _int(aval):
return not aval.shape and issubdtype(aval.dtype, integer)
def _is_basic_int_index(x):
if isinstance(x, core.Tracer):
aval = x.aval
return (isinstance(aval, (ConcreteArray, ShapedArray)) and
not aval.shape and issubdtype(aval.dtype, integer))
try:
operator.index(x)
return not isinstance(x, bool)
except TypeError:
return False

def _index_to_gather(x_shape, idx, normalize_indices=True):
# Remove ellipses and add trailing slice(None)s.
Expand Down Expand Up @@ -3696,7 +3704,7 @@ def _index_to_gather(x_shape, idx, normalize_indices=True):
collapsed_slice_dims = []
start_index_map = []

use_64bit_index = _any([not core.is_constant_dim(d) or d >= (1 << 31) for d in x_shape])
use_64bit_index = _any(not core.is_constant_dim(d) or d >= (1 << 31) for d in x_shape)
index_dtype = int64 if use_64bit_index else int32

# Gather indices.
Expand Down Expand Up @@ -3747,12 +3755,8 @@ def _index_to_gather(x_shape, idx, normalize_indices=True):
gather_slice_shape.append(1)
continue

try:
abstract_i = core.get_aval(i)
except TypeError:
abstract_i = None
# Handle basic int indexes.
if isinstance(abstract_i, (ConcreteArray, ShapedArray)) and _int(abstract_i):
if _is_basic_int_index(i):
if core.symbolic_equal_dim(x_shape[x_axis], 0):
# XLA gives error when indexing into an axis of size 0
raise IndexError(f"index is out of bounds for axis {x_axis} with size 0")
Expand Down Expand Up @@ -3833,6 +3837,10 @@ def _index_to_gather(x_shape, idx, normalize_indices=True):
y_axis += 1
x_axis += 1
else:
try:
abstract_i = core.get_aval(i)
except TypeError:
abstract_i = None
if (abstract_i is not None and
not (issubdtype(abstract_i.dtype, integer) or issubdtype(abstract_i.dtype, bool_))):
msg = ("Indexer must have integer or boolean type, got indexer "
Expand Down Expand Up @@ -3967,7 +3975,7 @@ def _is_advanced_int_indexer(idx):
# https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
assert isinstance(idx, tuple)
if _all(e is None or e is Ellipsis or isinstance(e, slice)
or _is_scalar(e) and issubdtype(_dtype(e), np.integer) for e in idx):
or _is_basic_int_index(e) for e in idx):
return False
return _all(e is None or e is Ellipsis or isinstance(e, slice)
or _is_int_arraylike(e) for e in idx)
Expand Down

0 comments on commit c9d6e76

Please sign in to comment.