Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check type getitem #5549

Merged
merged 6 commits into from
May 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 7 additions & 4 deletions numba/core/types/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,23 +188,26 @@ def __getitem__(self, args):
def _determine_array_spec(self, args):
# XXX non-contiguous by default, even for 1d arrays,
# doesn't sound very intuitive
if isinstance(args, (tuple, list)):
def validate_slice(s):
return isinstance(s, slice) and s.start is None and s.stop is None

if isinstance(args, (tuple, list)) and all(map(validate_slice, args)):
ndim = len(args)
if args[0].step == 1:
layout = 'F'
elif args[-1].step == 1:
layout = 'C'
else:
layout = 'A'
elif isinstance(args, slice):
elif validate_slice(args):
ndim = 1
if args.step == 1:
layout = 'C'
else:
layout = 'A'
else:
ndim = 1
layout = 'A'
# Raise a KeyError to not be handled by collection constructors (e.g. list).
raise KeyError(f"Can only index numba types with slices with no start or stop, got {args}.")

return ndim, layout

Expand Down
12 changes: 6 additions & 6 deletions numba/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import warnings
from types import ModuleType
from importlib import import_module
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
import numpy as np

from inspect import signature as pysignature # noqa: F401
Expand Down Expand Up @@ -478,8 +478,7 @@ def unified_function_type(numba_types, require_precise=True):

Parameters
----------
numba_types : tuple
Numba type instances.
numba_types : Sequence of numba Type instances.
require_precise : bool
If True, the returned Numba function type must be precise.

Expand All @@ -501,9 +500,10 @@ def unified_function_type(numba_types, require_precise=True):
"""
from numba.core.errors import NumbaExperimentalFeatureWarning

if not (numba_types
and isinstance(numba_types[0],
(types.Dispatcher, types.FunctionType))):
if not (isinstance(numba_types, Sequence) and
len(numba_types) > 0 and
isinstance(numba_types[0],
(types.Dispatcher, types.FunctionType))):
return

warnings.warn("First-class function type feature is experimental",
Expand Down
5 changes: 4 additions & 1 deletion numba/stencils/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,10 @@ def add_indices_to_kernel(self, kernel, index_names, ndim,
const_index_vars[dim], loc)
new_body.append(ir.Assign(getitemcall, getitemvar, loc))
# Get the type of this particular part of the index tuple.
one_index_typ = stmt_index_var_typ[dim]
if isinstance(stmt_index_var_typ, types.ConstSized):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this change needed?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nvm, I found the affected test test_basic81 (numba.tests.test_stencils.TestManyStencils)

one_index_typ = stmt_index_var_typ[dim]
else:
one_index_typ = stmt_index_var_typ[:]
# If the array is indexed with a slice then we
# have to add the index value with a call to
# slice_addition.
Expand Down
14 changes: 14 additions & 0 deletions numba/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,27 @@ def check(arrty, scalar, ndim, layout):
self.assertIs(arrty.dtype, scalar)
self.assertEqual(arrty.ndim, ndim)
self.assertEqual(arrty.layout, layout)

def check_index_error(callable):
with self.assertRaises(KeyError) as raises:
callable()
self.assertIn(
"Can only index numba types with slices with no start or "
"stop, got", str(raises.exception))

scalar = types.int32
check(scalar[:], scalar, 1, 'A')
check(scalar[::1], scalar, 1, 'C')
check(scalar[:, :], scalar, 2, 'A')
check(scalar[:, ::1], scalar, 2, 'C')
check(scalar[::1, :], scalar, 2, 'F')

check_index_error(lambda: scalar[0])
check_index_error(lambda: scalar[:, 4])
check_index_error(lambda: scalar[::1, 1:])
check_index_error(lambda: scalar[:2])
check_index_error(lambda: list(scalar))

def test_array_notation_for_dtype(self):
def check(arrty, scalar, ndim, layout):
self.assertIs(arrty.dtype, scalar)
Expand Down