Skip to content

Commit

Permalink
Speedup _expand_bool_indices when passing basic integer indices
Browse files Browse the repository at this point in the history
  • Loading branch information
lgeiger committed May 5, 2022
1 parent c9d6e76 commit 0c5b132
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3893,6 +3893,8 @@ def _eliminate_deprecated_list_indexing(idx):
# non-tuple sequence containing slice objects, [Ellipses, or newaxis
# objects]". Detects this and raises a TypeError.
if not isinstance(idx, tuple):
if isinstance(idx, (int, slice)):
return (idx,)
if isinstance(idx, Sequence) and not isinstance(idx, (ndarray, np.ndarray)):
# As of numpy 1.16, some non-tuple sequences of indices result in a warning, while
# others are converted to arrays, based on a set of somewhat convoluted heuristics
Expand All @@ -3912,13 +3914,11 @@ def _eliminate_deprecated_list_indexing(idx):
return idx

def _is_boolean_index(i):
try:
abstract_i = core.get_aval(i)
except TypeError:
abstract_i = None
return (isinstance(abstract_i, ShapedArray) and issubdtype(abstract_i.dtype, bool_)
or isinstance(i, list) and i and _all(_is_scalar(e)
and issubdtype(_dtype(e), np.bool_) for e in i))
if isinstance(i, core.Tracer):
i = i.aval
return (isinstance(i, bool) or getattr(i, "dtype", None) == np.bool_ or
isinstance(i, list) and i and _all(_is_scalar(e)
and _dtype(e) == np.bool_ for e in i))

def _expand_bool_indices(idx, shape):
"""Converts concrete bool indexes into advanced integer indexes."""
Expand All @@ -3932,11 +3932,11 @@ def _expand_bool_indices(idx, shape):
if e is not None and e is not Ellipsis)
ellipsis_offset = 0
for dim_number, i in enumerate(idx):
try:
abstract_i = core.get_aval(i)
except TypeError:
abstract_i = None
if _is_boolean_index(i):
try:
abstract_i = core.get_aval(i)
except TypeError:
abstract_i = None
if isinstance(i, list):
i = array(i)
abstract_i = core.get_aval(i)
Expand Down

0 comments on commit 0c5b132

Please sign in to comment.