Skip to content

Commit

Permalink
Merge pull request #6341 from esc/re_roll_6279
Browse files Browse the repository at this point in the history
Re roll 6279
  • Loading branch information
sklam committed Oct 12, 2020
2 parents 68802b1 + 8074a41 commit 01ccc3a
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/numpysupported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ The following top-level functions are supported:
* :func:`numpy.array_equal`
* :func:`numpy.array_split`
* :func:`numpy.asarray` (only the 2 first arguments)
* :func:`numpy.asarray_chkfinite` (only the 2 first arguments)
* :func:`numpy.asfarray`
* :func:`numpy.asfortranarray` (only the first argument)
* :func:`numpy.atleast_1d`
Expand Down
25 changes: 25 additions & 0 deletions numba/np/arraymath.py
Original file line number Diff line number Diff line change
Expand Up @@ -4175,6 +4175,31 @@ def np_select_arr_impl(condlist, choicelist, default=0):

return np_select_arr_impl


@overload(np.asarray_chkfinite)
def np_asarray_chkfinite(a, dtype=None):

msg = "The argument to np.asarray_chkfinite must be array-like"
if not isinstance(a, (types.Array, types.Sequence, types.Tuple)):
raise TypingError(msg)

if is_nonelike(dtype):
dt = a.dtype
else:
try:
dt = as_dtype(dtype)
except NotImplementedError:
raise TypingError('dtype must be a valid Numpy dtype')

def impl(a, dtype=None):
a = np.asarray(a, dtype=dt)
for i in np.nditer(a):
if not np.isfinite(i):
raise ValueError("array must not contain infs or NaNs")
return a

return impl

#----------------------------------------------------------------------------
# Windowing functions
# - translated from the numpy implementations found in:
Expand Down
73 changes: 73 additions & 0 deletions numba/tests/test_np_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,10 @@ def flip_ud(a):
return np.flipud(a)


def np_asarray_chkfinite(a, dtype=None):
return np.asarray_chkfinite(a, dtype)


def array_contains(a, key):
return key in a

Expand Down Expand Up @@ -3866,6 +3870,75 @@ def test_cross2d_exceptions(self):
str(raises.exception)
)

def test_asarray_chkfinite(self):
pyfunc = np_asarray_chkfinite
cfunc = jit(nopython=True)(pyfunc)
self.disable_leak_check()

pairs = [
#1D array with all args
(
np.array([1, 2, 3]),
np.float32,
),
#1D array
(
np.array([1, 2, 3]),
),
#1D array-like
(
[1, 2, 3, 4],
),
# 2x2 (n-dims)
(
np.array([[1, 2], [3, 4]]),
np.float32,
),
# 2x2 array-like (n-dims)
(
((1, 2), (3, 4)),
np.int64
),
# 2x2 (1-dim) with type promotion
(
np.array([1, 2], dtype=np.int64),
),
# 3x2 (with higher order broadcasting)
(
np.arange(36).reshape(6, 2, 3),
),
]

for pair in pairs:
expected = pyfunc(*pair)
got = cfunc(*pair)
self.assertPreciseEqual(expected, got)

def test_asarray_chkfinite_exceptions(self):
cfunc = jit(nopython=True)(np_asarray_chkfinite)
self.disable_leak_check()

#test for single value
with self.assertRaises(TypingError) as e:
cfunc(2)
msg = "The argument to np.asarray_chkfinite must be array-like"
self.assertIn(msg, str(e.exception))

#test for NaNs
with self.assertRaises(ValueError) as e:
cfunc(np.array([2, 4, np.nan, 5]))
self.assertIn("array must not contain infs or NaNs", str(e.exception))

#test for infs
with self.assertRaises(ValueError) as e:
cfunc(np.array([1, 2, np.inf, 4]))
self.assertIn("array must not contain infs or NaNs", str(e.exception))

#test for dtype
with self.assertRaises(TypingError) as e:
cfunc(np.array([1, 2, 3, 4]), 'float32')
self.assertIn("dtype must be a valid Numpy dtype", str(e.exception))


class TestNPMachineParameters(TestCase):
# tests np.finfo, np.iinfo, np.MachAr
Expand Down

0 comments on commit 01ccc3a

Please sign in to comment.