Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add np.asfarray impl #5418

Merged
merged 14 commits into from
Aug 27, 2020
1 change: 1 addition & 0 deletions docs/source/reference/numpysupported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ The following top-level functions are supported:
* :func:`numpy.array` (only the 2 first arguments)
* :func:`numpy.array_equal`
* :func:`numpy.asarray` (only the 2 first arguments)
* :func:`numpy.asfarray`
* :func:`numpy.asfortranarray` (only the first argument)
* :func:`numpy.atleast_1d`
* :func:`numpy.atleast_2d`
Expand Down
13 changes: 13 additions & 0 deletions numba/np/arraymath.py
Original file line number Diff line number Diff line change
Expand Up @@ -4070,6 +4070,19 @@ def impl(a, dtype=None):
return impl


@overload(np.asfarray)
def np_asfarray(a, dtype=np.float64):
dtype = as_dtype(dtype)
if not np.issubdtype(dtype, np.inexact):
dx = types.float64
else:
dx = dtype

def impl(a, dtype=np.float64):
return np.asarray(a, dx)
return impl


@overload(np.extract)
def np_extract(condition, arr):

Expand Down
26 changes: 26 additions & 0 deletions numba/tests/test_np_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,10 @@ def asarray_kws(a, dtype):
return np.asarray(a, dtype=dtype)


def asfarray(a, dtype=np.float64):
return np.asfarray(a, dtype=dtype)


def extract(condition, arr):
return np.extract(condition, arr)

Expand Down Expand Up @@ -3208,6 +3212,28 @@ def make_unicode_list():
test_reject(make_nested_list_with_dict())
test_reject(make_unicode_list())

def test_asfarray(self):
def inputs():
yield np.array([1, 2, 3]), None
yield np.array([2, 3], dtype=np.float32), np.float32
yield np.array([2, 3], dtype=np.int8), np.int8
guilhermeleobas marked this conversation as resolved.
Show resolved Hide resolved
yield np.array([2, 3], dtype=np.int8), np.complex64
yield np.array([2, 3], dtype=np.int8), np.complex128

pyfunc = asfarray
cfunc = jit(nopython=True)(pyfunc)

for arr, dt in inputs():
if dt is None:
expected = pyfunc(arr)
got = cfunc(arr)
else:
expected = pyfunc(arr, dtype=dt)
got = cfunc(arr, dtype=dt)

self.assertPreciseEqual(expected, got)
self.assertTrue(np.issubdtype(got.dtype, np.inexact), got.dtype)

def test_repeat(self):
# np.repeat(a, repeats)
np_pyfunc = np_repeat
Expand Down