Skip to content

Commit

Permalink
Merge pull request #9504 from shourya5/issue-fix
Browse files Browse the repository at this point in the history
added np.size() overload and added tests
  • Loading branch information
sklam committed Apr 16, 2024
2 parents 4a61446 + 83c1782 commit e749f05
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 0 deletions.
4 changes: 4 additions & 0 deletions docs/upcoming_changes/9504.np_support.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Added support for ``np.size()``
-------------------------------

Added ``np.size()`` support for numpy, which was previously unsupported.
10 changes: 10 additions & 0 deletions numba/np/arrayobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -2556,6 +2556,16 @@ def impl(a):
return np.asarray(a).shape
return impl


@overload(np.size)
def np_size(a):
if not type_can_asarray(a):
raise errors.TypingError("The argument to np.size must be array-like")

def impl(a):
return np.asarray(a).size
return impl

# ------------------------------------------------------------------------------


Expand Down
27 changes: 27 additions & 0 deletions numba/tests/test_array_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ def numpy_fill_diagonal(arr, val, wrap=False):
def numpy_shape(arr):
return np.shape(arr)

def numpy_size(arr):
return np.size(arr)


def numpy_flatnonzero(a):
return np.flatnonzero(a)
Expand Down Expand Up @@ -1219,6 +1222,30 @@ def check(x):
self.assertIn("The argument to np.shape must be array-like",
str(raises.exception))

def test_size(self):
pyfunc = numpy_size
cfunc = jit(nopython=True)(pyfunc)

def check(x):
expected = pyfunc(x)
got = cfunc(x)
self.assertPreciseEqual(got, expected)

# check arrays
for t in [(), (1,), (2, 3,), (4, 5, 6)]:
arr = np.empty(t)
check(arr)

# check scalar values
for t in [1, False, 3.14, np.int8(4), np.float32(2.718)]:
check(t)

with self.assertRaises(TypingError) as raises:
cfunc('a')

self.assertIn("The argument to np.size must be array-like",
str(raises.exception))

def test_flatnonzero_basic(self):
pyfunc = numpy_flatnonzero
cfunc = jit(nopython=True)(pyfunc)
Expand Down
18 changes: 18 additions & 0 deletions numba/tests/test_record_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,10 @@ def get_shape(rec):
return np.shape(rec.j)


def get_size(rec):
return np.size(rec.j)


def get_charseq(ary, i):
return ary[i].n

Expand Down Expand Up @@ -1616,6 +1620,20 @@ def test_shape(self):
arr_res = cfunc(arg)
np.testing.assert_equal(arr_res, arr_expected)

def test_size(self):
# test getting the size of a nestedarray inside a record
nbarr = np.recarray(2, dtype=recordwith2darray)
nbarr[0] = np.array([(1, ((1, 2), (4, 5), (2, 3)))],
dtype=recordwith2darray)[0]

arg = nbarr[0]
pyfunc = get_size
ty = typeof(arg)
arr_expected = pyfunc(arg)
cfunc = self.get_cfunc(pyfunc, (ty,))
arr_res = cfunc(arg)
np.testing.assert_equal(arr_res, arr_expected)

def test_corner_slice(self):
# testing corner cases while slicing nested arrays
nbarr = np.recarray((1, 2, 3, 5, 7, 13, 17), dtype=recordwith4darray,
Expand Down

0 comments on commit e749f05

Please sign in to comment.