Skip to content

Commit

Permalink
Merge pull request numba#8238 from kc611/adv_idx_1
Browse files Browse the repository at this point in the history
Advanced Indexing Support #1
  • Loading branch information
sklam committed Mar 14, 2023
2 parents 99cc983 + 21f0e91 commit 26bc501
Show file tree
Hide file tree
Showing 4 changed files with 239 additions and 18 deletions.
66 changes: 52 additions & 14 deletions numba/core/typing/arraydecl.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,50 +34,88 @@ def get_array_index_type(ary, idx):
right_indices = []
ellipsis_met = False
advanced = False
has_integer = False
num_newaxis = 0

if not isinstance(idx, types.BaseTuple):
idx = [idx]

# Here, a subspace is considered as a contiguous group of advanced indices.
# num_subspaces keeps track of the number of such
# contiguous groups.
in_subspace = False
num_subspaces = 0
array_indices = 0

# Walk indices
for ty in idx:
if ty is types.ellipsis:
if ellipsis_met:
raise NumbaTypeError("only one ellipsis allowed in array index "
"(got %s)" % (idx,))
raise NumbaTypeError(
"Only one ellipsis allowed in array indices "
"(got %s)" % (idx,))
ellipsis_met = True
in_subspace = False
elif isinstance(ty, types.SliceType):
pass
# If we encounter a non-advanced index while in a
# subspace then that subspace ends.
in_subspace = False
# In advanced indexing, any index broadcastable to an
# array is considered an advanced index. Hence all the
# branches below are considered as advanced indices.
elif isinstance(ty, types.Integer):
# Normalize integer index
ty = types.intp if ty.signed else types.uintp
# Integer indexing removes the given dimension
ndim -= 1
has_integer = True
# If we're within a subspace/contiguous group of
# advanced indices then no action is necessary
# since we've already counted that subspace once.
if not in_subspace:
# If we're not within a subspace and we encounter
# this branch then we have a new subspace/group.
num_subspaces += 1
in_subspace = True
elif (isinstance(ty, types.Array) and ty.ndim == 0
and isinstance(ty.dtype, types.Integer)):
# 0-d array used as integer index
ndim -= 1
has_integer = True
if not in_subspace:
num_subspaces += 1
in_subspace = True
elif (isinstance(ty, types.Array)
and ty.ndim == 1
and isinstance(ty.dtype, (types.Integer, types.Boolean))):
if advanced or has_integer:
# We don't support the complicated combination of
# advanced indices (and integers are considered part
# of them by Numpy).
msg = "only one advanced index supported"
raise NumbaNotImplementedError(msg)
if ty.ndim > 1:
# Advanced indexing limitation # 1
raise NumbaTypeError(
"Multi-dimensional indices are not supported.")
array_indices += 1
# The condition for activating advanced indexing is simply
# having at least one array with size > 1.
advanced = True
if not in_subspace:
num_subspaces += 1
in_subspace = True
elif (is_nonelike(ty)):
ndim += 1
num_newaxis += 1
else:
raise NumbaTypeError("unsupported array index type %s in %s"
raise NumbaTypeError("Unsupported array index type %s in %s"
% (ty, idx))
(right_indices if ellipsis_met else left_indices).append(ty)

if advanced:
if array_indices > 1:
# Advanced indexing limitation # 2
msg = "Using more than one non-scalar array index is unsupported."
raise NumbaTypeError(msg)

if num_subspaces > 1:
# Advanced indexing limitation # 3
msg = ("Using more than one indexing subspace is unsupported."
" An indexing subspace is a group of one or more"
" consecutive indices comprising integer or array types.")
raise NumbaTypeError(msg)

# Only Numpy arrays support advanced indexing
if advanced and not isinstance(ary, types.Array):
return
Expand Down
4 changes: 2 additions & 2 deletions numba/tests/test_array_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,13 +791,13 @@ def test_bad_index_npm(self):
arraytype2 = types.Array(types.int32, 2, 'C')
compile_isolated(bad_index, (arraytype1, arraytype2),
flags=no_pyobj_flags)
self.assertIn('unsupported array index type', str(raises.exception))
self.assertIn('Unsupported array index type', str(raises.exception))

def test_bad_float_index_npm(self):
with self.assertTypingError() as raises:
compile_isolated(bad_float_index,
(types.Array(types.float64, 2, 'C'),))
self.assertIn('unsupported array index type float64',
self.assertIn('Unsupported array index type float64',
str(raises.exception))

def test_fill_diagonal_basic(self):
Expand Down
185 changes: 184 additions & 1 deletion numba/tests/test_fancy_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from numba import jit, typeof, njit
from numba.core import types
from numba.core.errors import TypingError
from numba.tests.support import MemoryLeakMixin, TestCase, tag
from numba.tests.support import MemoryLeakMixin, TestCase


def getitem_usecase(a, b):
Expand Down Expand Up @@ -301,5 +301,188 @@ def np_new_axis_setitem(a, idx, item):
np.testing.assert_equal(expected, got)


class TestFancyIndexingMultiDim(MemoryLeakMixin, TestCase):
# Every case has exactly one, one-dimensional array,
# otherwise it's not fancy indexing.
shape = (5, 6, 7, 8, 9, 10)
indexing_cases = [
# Slices + Integers
(slice(4, 5), 3, np.array([0, 1, 3, 4, 2]), 1),
(3, np.array([0,1,3,4,2]), slice(None), slice(4)),

# Ellipsis + Integers
(Ellipsis, 1, np.array([0,1,3,4,2])),
(np.array([0,1,3,4,2]), 3, Ellipsis),

# Ellipsis + Slices + Integers
(Ellipsis, 1, np.array([0,1,3,4,2]), 3, slice(1,5)),
(np.array([0,1,3,4,2]), 3, Ellipsis, slice(1,5)),

# Boolean Arrays + Integers
(slice(4, 5), 3,
np.array([True, False, True, False, True, False, False]),
1),
(3, np.array([True, False, True, False, True, False]),
slice(None), slice(4)),
]

def setUp(self):
super().setUp()
self.rng = np.random.default_rng(1)

def generate_random_indices(self):
N = min(self.shape)
slice_choices = [slice(None, None, None),
slice(1, N - 1, None),
slice(0, None, 2),
slice(N - 1, None, -2),
slice(-N + 1, -1, None),
slice(-1, -N, -2),
slice(0, N - 1, None),
slice(-1, -N, -2)
]
integer_choices = list(np.arange(N))

indices = []

# Generate K random slice cases. The value of K is arbitrary, the intent is
# to create plenty of variation.
K = 20
for _ in range(K):
array_idx = self.rng.integers(0, 5, size=15)
# Randomly select 4 slices from our list
curr_idx = self.rng.choice(slice_choices, size=4).tolist()
# Replace one of the slice with the array index
_array_idx = self.rng.choice(4)
curr_idx[_array_idx] = array_idx
indices.append(tuple(curr_idx))
# Generate K random integer cases
for _ in range(K):
array_idx = self.rng.integers(0, 5, size=15)
# Randomly select 4 integers from our list
curr_idx = self.rng.choice(integer_choices, size=4).tolist()
# Replace one of the slice with the array index
_array_idx = self.rng.choice(4)
curr_idx[_array_idx] = array_idx
indices.append(tuple(curr_idx))

# Generate K random ellipsis cases
for _ in range(K):
array_idx = self.rng.integers(0, 5, size=15)
# Randomly select 4 slices from our list
curr_idx = self.rng.choice(slice_choices, size=4).tolist()
# Generate two seperate random indices, replace one with
# array and second with Ellipsis
_array_idx = self.rng.choice(4, size=2, replace=False)
curr_idx[_array_idx[0]] = array_idx
curr_idx[_array_idx[1]] = Ellipsis
indices.append(tuple(curr_idx))

# Generate K random boolean cases
for _ in range(K):
array_idx = self.rng.integers(0, 5, size=15)
# Randomly select 4 slices from our list
curr_idx = self.rng.choice(slice_choices, size=4).tolist()
# Replace one of the slice with the boolean array index
_array_idx = self.rng.choice(4)
bool_arr_shape = self.shape[_array_idx]
curr_idx[_array_idx] = np.array(
self.rng.choice(2, size=bool_arr_shape),
dtype=bool
)
indices.append(tuple(curr_idx))

return indices

def check_getitem_indices(self, arr_shape, index):
@njit
def numba_get_item(array, idx):
return array[idx]

arr = np.random.randint(0, 11, size=arr_shape)
get_item = numba_get_item.py_func
orig_base = arr.base or arr

expected = get_item(arr, index)
got = numba_get_item(arr, index)
# Sanity check: In advanced indexing, the result is always a copy.
self.assertNotIn(expected.base, orig_base)

# Note: Numba may not return the same array strides and
# contiguity as NumPy
self.assertEqual(got.shape, expected.shape)
self.assertEqual(got.dtype, expected.dtype)
np.testing.assert_equal(got, expected)

# Check a copy was *really* returned by Numba
self.assertFalse(np.may_share_memory(got, expected))

def check_setitem_indices(self, arr_shape, index):
@njit
def set_item(array, idx, item):
array[idx] = item

arr = np.random.randint(0, 11, size=arr_shape)
src = arr[index]
expected = np.zeros_like(arr)
got = np.zeros_like(arr)

set_item.py_func(expected, index, src)
set_item(got, index, src)

# Note: Numba may not return the same array strides and
# contiguity as NumPy
self.assertEqual(got.shape, expected.shape)
self.assertEqual(got.dtype, expected.dtype)

np.testing.assert_equal(got, expected)

def test_getitem(self):
# Cases with a combination of integers + other objects
indices = self.indexing_cases.copy()

# Cases with permutations of either integers or objects
indices += self.generate_random_indices()

for idx in indices:
with self.subTest(idx=idx):
self.check_getitem_indices(self.shape, idx)

def test_setitem(self):
# Cases with a combination of integers + other objects
indices = self.indexing_cases.copy()

# Cases with permutations of either integers or objects
indices += self.generate_random_indices()

for idx in indices:
with self.subTest(idx=idx):
self.check_setitem_indices(self.shape, idx)

def test_unsupported_condition_exceptions(self):
err_idx_cases = [
# Cases with multi-dimensional indexing array
('Multi-dimensional indices are not supported.',
(0, 3, np.array([[1, 2], [2, 3]]))),
# Cases with more than one indexing array
('Using more than one non-scalar array index is unsupported.',
(0, 3, np.array([1, 2]), np.array([1, 2]))),
# Cases with more than one indexing subspace
# (The subspaces here are separated by slice(None))
("Using more than one indexing subspace is unsupported." + \
" An indexing subspace is a group of one or more consecutive" + \
" indices comprising integer or array types.",
(0, np.array([1, 2]), slice(None), 3, 4))
]

for err, idx in err_idx_cases:
with self.assertRaises(TypingError) as raises:
self.check_getitem_indices(self.shape, idx)
self.assertIn(
err,
str(raises.exception)
)


if __name__ == '__main__':
unittest.main()
2 changes: 1 addition & 1 deletion numba/tests/test_record_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -1549,7 +1549,7 @@ def test_setitem_whole_array_error(self):
nbarr2 = np.recarray(1, dtype=recordwith2darray)
args = (nbarr1, nbarr2)
pyfunc = record_setitem_array
errmsg = "unsupported array index type"
errmsg = "Unsupported array index type"
with self.assertRaisesRegex(TypingError, errmsg):
self.get_cfunc(pyfunc, tuple((typeof(arg) for arg in args)))

Expand Down

0 comments on commit 26bc501

Please sign in to comment.