Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Reimplementing np.mean with overload and adding axis parameter #4480

Closed
wants to merge 26 commits into from
Closed
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
6b7f05a
implementing np.mean with overload and adding axis parameter
luk-f-a Aug 23, 2019
0411aa1
Merge branch 'sum_with_dtype' into mean_axis_overload
luk-f-a Sep 25, 2019
3b35bcf
adding tests for different dtypes
luk-f-a Sep 25, 2019
ec87918
Merge branch 'sum_with_dtype' into mean_axis_overload
luk-f-a Sep 25, 2019
69e8095
improving np.mean tests
luk-f-a Sep 26, 2019
568c045
np.mean implementation with axis parameter including tests
luk-f-a Oct 1, 2019
c416f50
Merge branch 'sum_with_dtype' into mean_axis_overload
luk-f-a Oct 1, 2019
8336f12
flake8 fix
luk-f-a Oct 1, 2019
e6c3a9b
Merge branch 'master' into mean_axis_overload
luk-f-a Oct 1, 2019
d9b4f51
mean implemantation with tests
luk-f-a Oct 11, 2019
f5d082d
mean implemantation with tests2
luk-f-a Oct 11, 2019
f750c88
Merge branch 'master' of https://github.com/numba/numba into mean_axi…
luk-f-a Oct 11, 2019
e0e7e78
fix for axis is None case
luk-f-a Oct 11, 2019
102b64a
Merge branch 'mean_axis_overload' of github.com:luk-f-a/numba into me…
luk-f-a Oct 11, 2019
ab58749
expanded tests
luk-f-a Oct 11, 2019
00ddcd2
working on timedelta64 test
luk-f-a Jan 6, 2020
1a73ff6
Merge branch 'master' of https://github.com/numba/numba into mean_axi…
luk-f-a Jan 6, 2020
817a8e6
Merge branch 'master' of https://github.com/numba/numba into mean_axi…
luk-f-a Apr 19, 2020
4b29b4d
improvements to the typing of timedelta functions
luk-f-a Apr 20, 2020
44b6f5a
Merge branch 'sum_timedelta' into mean_axis_overload
luk-f-a May 15, 2020
7fef5b9
Merge branch 'timedelta_div' into mean_axis_overload
luk-f-a May 16, 2020
2723995
merged timedelta division
luk-f-a May 16, 2020
63514f7
Merge branch 'timedelta_div' into mean_axis_overload
luk-f-a May 16, 2020
ce56eb2
flake8 fix
luk-f-a May 16, 2020
be93617
flake8 fix
luk-f-a May 16, 2020
0f61b68
Merge branch 'main' into pr/4480
sklam Jul 8, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 8 additions & 1 deletion docs/source/reference/numpysupported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,6 @@ The following methods of Numpy arrays are supported in their basic form
* :meth:`~numpy.ndarray.cumprod`
* :meth:`~numpy.ndarray.cumsum`
* :meth:`~numpy.ndarray.max`
* :meth:`~numpy.ndarray.mean`
* :meth:`~numpy.ndarray.min`
* :meth:`~numpy.ndarray.nonzero`
* :meth:`~numpy.ndarray.prod`
Expand Down Expand Up @@ -287,6 +286,14 @@ Reductions
The following reduction functions are supported:

* :func:`numpy.diff` (only the 2 first arguments)
* :func:`numpy.mean` (with or without the ``axis`` argument)

* ``axis`` only supports ``integer`` values.
* If the ``axis`` argument is a compile-time constant, all valid values
are supported.
An out-of-range value will result in a ``LoweringError`` at compile-time.
* If the ``axis`` argument is not a compile-time constant, only values
from 0 to 3 are supported.
* :func:`numpy.median` (only the first argument)
* :func:`numpy.nancumprod` (only the first argument, requires NumPy >= 1.12))
* :func:`numpy.nancumsum` (only the first argument, requires NumPy >= 1.12))
Expand Down
58 changes: 44 additions & 14 deletions numba/targets/arraymath.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,22 +399,52 @@ def array_cumprod_impl(arr):
return impl_ret_new_ref(context, builder, sig.return_type, res)


@lower_builtin(np.mean, types.Array)
@lower_builtin("array.mean", types.Array)
def array_mean(context, builder, sig, args):
zero = sig.return_type(0)
@register_jitable
def sum_array(arr, dtype):
return np.sum(arr, dtype=dtype)

def array_mean_impl(arr):
# Can't use the naive `arr.sum() / arr.size`, as it would return
# a wrong result on integer sum overflow.
c = zero
for v in np.nditer(arr):
c += v.item()
return c / arr.size

res = context.compile_internal(builder, array_mean_impl, sig, args,
locals=dict(c=sig.return_type))
return impl_ret_untracked(context, builder, sig.return_type, res)
@register_jitable
def sum_array_axis(arr, axis, dtype):
return np.sum(arr, axis=axis, dtype=dtype)


@overload(np.mean)
@overload_method(types.Array, 'mean')
def array_mean(arr, axis=None):
if isinstance(arr, types.Array):
# determine accumulator type
if isinstance(arr.dtype, (types.Integer, types.Boolean)):
ret_dtype = np.float64
elif isinstance(arr.dtype, (types.Float, types.Complex)):
ret_dtype = arr.dtype
else:
raise TypeError(("np.mean is not supported on {} arrays. "
"It supports boolean, integer, float "
"and complex arrays").format(arr.dtype))
# dispatch based on whether there's an axis parameter and its type
if axis is None:
luk-f-a marked this conversation as resolved.
Show resolved Hide resolved

def mean_impl(arr, axis=None):
return sum_array(arr, dtype=ret_dtype) / arr.size

return mean_impl
elif isinstance(axis, types.Integer):

def mean_impl(arr, axis=None):
if axis >= arr.ndim:
raise ValueError("'axis' entry is out of bounds")
return sum_array_axis(arr, axis=axis, dtype=ret_dtype) / arr.shape[axis]

return mean_impl
elif isinstance(axis, types.IntegerLiteral):
if axis.literal_value >= arr.ndim:
raise ValueError("'axis' entry is out of bounds")

def mean_impl(arr, axis=None):
return sum_array_axis(arr, axis=axis, dtype=ret_dtype) / arr.shape[axis]

return mean_impl


@lower_builtin(np.var, types.Array)
Expand Down
42 changes: 41 additions & 1 deletion numba/tests/test_array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,12 @@ def array_sum_const_axis_neg_one(a, axis):
# "const_multi" variant would raise errors
return a.sum(axis=-1)

def array_mean(a):
return a.mean()

def array_mean_axis(a, axis):
return a.mean(axis)

def array_cumsum(a, *args):
return a.cumsum(*args)

Expand Down Expand Up @@ -784,7 +790,8 @@ def test_sum(self):
pyfunc = array_sum
cfunc = jit(nopython=True)(pyfunc)
all_dtypes = [np.float64, np.float32, np.int64, np.int32,
np.complex64, np.complex128, np.uint32, np.uint64, np.timedelta64]
np.complex64, np.complex128, np.uint32, np.uint64,
np.timedelta64, np.bool_]
all_test_arrays = [
[np.ones((7, 6, 5, 4, 3), arr_dtype),
np.ones(1, arr_dtype),
Expand Down Expand Up @@ -1006,6 +1013,39 @@ def foo(arr):
# Just check for the "out of bounds" phrase in it.
self.assertIn("out of bounds", str(raises.exception))

def test_mean_axis(self):
""" tests np.mean with and without axis parameter
"""
pyfunc = array_mean
cfunc = jit(nopython=True)(pyfunc)
pyfunc_axis = array_mean_axis
cfunc_axis = jit(nopython=True)(pyfunc_axis)
# a complete list
# all_dtypes = [np.float64, np.float32, np.int64, np.int32, np.uint32,
# np.uint64, np.complex64, np.complex128]
# a reduced list to save test execution time
all_dtypes = [np.float32, np.int32, np.uint32, np.complex64, np.bool_]
all_dtypes = [np.float64]
all_test_arrays = [np.ones((7, 6, 5, 4, 3), arr_dtype) for arr_dtype
in all_dtypes]

for arr in all_test_arrays:
with self.subTest("no axis - dtype: {}".format(arr.dtype)):
self.assertPreciseEqual(pyfunc(arr), cfunc(arr))
# with self.subTest("axis 0 as integer literal - dtype: {}".format(arr.dtype)):
# self.assertPreciseEqual(pyfunc_axis(arr, 0), cfunc_axis(arr, 0))
# with self.subTest("axis 1 as integer variable - dtype: {}".format(arr.dtype)):
# axis = 1
# self.assertPreciseEqual(pyfunc_axis(arr, axis), cfunc_axis(arr, axis))
# with self.subTest("axis 2 as integer variable - dtype: {}".format(arr.dtype)):
# axis = 2
# self.assertPreciseEqual(pyfunc_axis(arr, axis), cfunc_axis(arr, axis))
# with self.subTest("axis -1 as integer literal- dtype: {}".format(arr.dtype)):
# # axis -1 is only supported for IntegerLiterals
# pyfunc2 = lambda x: x.mean(axis=-1)
# cfunc2 = jit(nopython=True)(pyfunc2)
# self.assertPreciseEqual(pyfunc2(arr), cfunc2(arr))

def test_cumsum(self):
pyfunc = array_cumsum
cfunc = jit(nopython=True)(pyfunc)
Expand Down
4 changes: 2 additions & 2 deletions numba/typing/arraydecl.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,8 +730,8 @@ def array_attribute_attachment(self, ary):
install_array_method(fname, generic_expand_cumulative)

# Functions that require integer arrays get promoted to float64 return
for fName in ["mean"]:
install_array_method(fName, generic_hetero_real)
# for fName in ["mean"]:
# install_array_method(fName, generic_hetero_real)

# var and std by definition return in real space and int arrays
# get promoted to float64 return
Expand Down
2 changes: 1 addition & 1 deletion numba/typing/npydecl.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ def _numpy_redirect(fname):
dict(key=numpy_function, method_name=fname))
infer_global(numpy_function, types.Function(cls))

for func in ['min', 'max', 'sum', 'prod', 'mean', 'var', 'std',
for func in ['min', 'max', 'sum', 'prod', 'var', 'std', #'mean',
'cumsum', 'cumprod', 'argmin', 'argmax', 'argsort',
'nonzero', 'ravel']:
_numpy_redirect(func)
Expand Down