Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re roll 6279 #6341

Merged
merged 12 commits into from
Oct 12, 2020
1 change: 1 addition & 0 deletions docs/source/reference/numpysupported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ The following top-level functions are supported:
* :func:`numpy.array_equal`
* :func:`numpy.array_split`
* :func:`numpy.asarray` (only the 2 first arguments)
* :func:`numpy.asarray_chkfinite` (only the 2 first arguments)
* :func:`numpy.asfarray`
* :func:`numpy.asfortranarray` (only the first argument)
* :func:`numpy.atleast_1d`
Expand Down
25 changes: 25 additions & 0 deletions numba/np/arraymath.py
Original file line number Diff line number Diff line change
Expand Up @@ -4175,6 +4175,31 @@ def np_select_arr_impl(condlist, choicelist, default=0):

return np_select_arr_impl


@overload(np.asarray_chkfinite)
def np_asarray_chkfinite(a, dtype=None):

msg = "The argument to np.asarray_chkfinite must be array-like"
if not isinstance(a, (types.Array, types.Sequence, types.Tuple)):
raise TypingError(msg)

if is_nonelike(dtype):
dt = a.dtype
else:
try:
dt = as_dtype(dtype)
except NotImplementedError:
raise TypingError('dtype must be a valid Numpy dtype')

def impl(a, dtype=None):
a = np.asarray(a, dtype=dt)
for i in np.nditer(a):
if not np.isfinite(i):
raise ValueError("array must not contain infs or NaNs")
return a

return impl

#----------------------------------------------------------------------------
# Windowing functions
# - translated from the numpy implementations found in:
Expand Down
73 changes: 73 additions & 0 deletions numba/tests/test_np_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,10 @@ def flip_ud(a):
return np.flipud(a)


def np_asarray_chkfinite(a, dtype=None):
return np.asarray_chkfinite(a, dtype)


def array_contains(a, key):
return key in a

Expand Down Expand Up @@ -3866,6 +3870,75 @@ def test_cross2d_exceptions(self):
str(raises.exception)
)

def test_asarray_chkfinite(self):
pyfunc = np_asarray_chkfinite
cfunc = jit(nopython=True)(pyfunc)
self.disable_leak_check()

pairs = [
#1D array with all args
(
np.array([1, 2, 3]),
np.float32,
),
#1D array
(
np.array([1, 2, 3]),
),
#1D array-like
(
[1, 2, 3, 4],
),
# 2x2 (n-dims)
(
np.array([[1, 2], [3, 4]]),
np.float32,
),
# 2x2 array-like (n-dims)
(
((1, 2), (3, 4)),
np.int64
),
# 2x2 (1-dim) with type promotion
(
np.array([1, 2], dtype=np.int64),
),
# 3x2 (with higher order broadcasting)
(
np.arange(36).reshape(6, 2, 3),
),
]

for pair in pairs:
expected = pyfunc(*pair)
got = cfunc(*pair)
self.assertPreciseEqual(expected, got)

def test_asarray_chkfinite_exceptions(self):
cfunc = jit(nopython=True)(np_asarray_chkfinite)
self.disable_leak_check()

#test for single value
with self.assertRaises(TypingError) as e:
cfunc(2)
msg = "The argument to np.asarray_chkfinite must be array-like"
self.assertIn(msg, str(e.exception))

#test for NaNs
with self.assertRaises(ValueError) as e:
cfunc(np.array([2, 4, np.nan, 5]))
self.assertIn("array must not contain infs or NaNs", str(e.exception))

#test for infs
with self.assertRaises(ValueError) as e:
cfunc(np.array([1, 2, np.inf, 4]))
self.assertIn("array must not contain infs or NaNs", str(e.exception))

#test for dtype
with self.assertRaises(TypingError) as e:
cfunc(np.array([1, 2, 3, 4]), 'float32')
self.assertIn("dtype must be a valid Numpy dtype", str(e.exception))


class TestNPMachineParameters(TestCase):
# tests np.finfo, np.iinfo, np.MachAr
Expand Down