Skip to content

Commit

Permalink
Merge pull request #7304 from hadia206/continue_3655
Browse files Browse the repository at this point in the history
Continue PR#3655: add support for np.average
  • Loading branch information
sklam committed Aug 19, 2021
2 parents fe2e735 + 2bf1df3 commit e031b70
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 2 deletions.
37 changes: 37 additions & 0 deletions numba/np/arraymath.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,43 @@ def flat_any(a):
return flat_any


@overload(np.average)
def np_average(arr, axis=None, weights=None):

if weights is None or isinstance(weights, types.NoneType):
def np_average_impl(arr, axis=None, weights=None):
arr = np.asarray(arr)
return np.mean(arr)
else:
if axis is None or isinstance(axis, types.NoneType):
def np_average_impl(arr, axis=None, weights=None):
arr = np.asarray(arr)
weights = np.asarray(weights)

if arr.shape != weights.shape:
if axis is None:
raise TypeError(
"Numba does not support average when shapes of "
"a and weights differ.")
if weights.ndim != 1:
raise TypeError(
"1D weights expected when shapes of "
"a and weights differ.")

scl = np.sum(weights)
if scl == 0.0:
raise ZeroDivisionError(
"Weights sum to zero, can't be normalized.")

avg = np.sum(np.multiply(arr, weights)) / scl
return avg
else:
def np_average_impl(arr, axis=None, weights=None):
raise TypeError("Numba does not support average with axis.")

return np_average_impl


def get_isnan(dtype):
"""
A generic isnan() function
Expand Down
2 changes: 0 additions & 2 deletions numba/tests/test_array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,6 @@ def array_dot_chain(a, b):
def array_ctor(n, dtype):
return np.ones(n, dtype=dtype)


class TestArrayMethods(MemoryLeakMixin, TestCase):
"""
Test various array methods and array-related functions.
Expand Down Expand Up @@ -1472,7 +1471,6 @@ def test_array_ctor_with_dtype_arg(self):
args = n, np.dtype('f4')
np.testing.assert_array_equal(pyfunc(*args), cfunc(*args))


class TestArrayComparisons(TestCase):

def test_identity(self):
Expand Down
89 changes: 89 additions & 0 deletions numba/tests/test_np_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,10 @@ def np_trapz_x_dx(y, x, dx):
return np.trapz(y, x, dx)


def np_average(a, axis=None, weights=None):
return np.average(a, axis=axis, weights=weights)


def interp(x, xp, fp):
return np.interp(x, xp, fp)

Expand Down Expand Up @@ -3083,6 +3087,91 @@ def check_not_ok(params):

self.assertIn('y cannot be a scalar', str(e.exception))

def test_average(self):

#array of random numbers
N = 100
a = np.random.ranf(N) * 100
w = np.random.ranf(N) * 100
w0 = np.zeros(N)

#boolean array and weights
a_bool = np.random.ranf(N) > 0.5
w_bool = np.random.ranf(N) > 0.5

#array of random ints
a_int = np.random.randint(101, size=N)
w_int = np.random.randint(101, size=N)

#3D array of random numbers
d0 = 100
d1 = 50
d2 = 25
a_3d = np.random.rand(d0,d1,d2) * 100
w_3d = np.random.rand(d0,d1,d2) * 100

pyfunc = np_average
cfunc = jit(nopython=True)(pyfunc)

#test case for average with weights
#(number of elements in array and weight array are equal)
self.assertAlmostEqual( pyfunc(a,weights=w),
cfunc(a,weights=w), places=10)
self.assertAlmostEqual( pyfunc(a_3d,weights=w_3d),
cfunc(a_3d,weights=w_3d), places=10)

#test case for average with array and weights with
#int datatype (number of elements in array and weight array are equal)
self.assertAlmostEqual( pyfunc(a_int,weights=w_int),
cfunc(a_int,weights=w_int), places=10)

#test case for average with boolean weights
self.assertAlmostEqual( pyfunc(a,weights=w_bool),
cfunc(a,weights=w_bool), places=10)
self.assertAlmostEqual( pyfunc(a_bool,weights=w),
cfunc(a_bool,weights=w), places=10)
self.assertAlmostEqual( pyfunc(a_bool, weights=w_bool),
cfunc(a_bool, weights=w_bool), places=10)

#test case for average without weights
self.assertAlmostEqual(pyfunc(a), cfunc(a), places=10)
self.assertAlmostEqual(pyfunc(a_3d), cfunc(a_3d), places=10)

def test_weights_zero_sum(data, weights):
with self.assertRaises(ZeroDivisionError) as e:
cfunc(data, weights=weights)
err = e.exception
self.assertEqual(str(err),
"Weights sum to zero, can't be normalized.")

#test case when sum of weights is zero
test_weights_zero_sum(a, weights=w0)

def test_1D_weights(data, weights):
with self.assertRaises(TypeError) as e:
cfunc(data, weights=weights)
err = e.exception
self.assertEqual(str(err),
"Numba does not support average when shapes of "
"a and weights differ.")

def test_1D_weights_axis(data, axis, weights):
with self.assertRaises(TypeError) as e:
cfunc(data,axis=axis, weights=weights)
err = e.exception
self.assertEqual(str(err),
"Numba does not support average with axis.")

#small case to test exceptions for 2D array and 1D weights
data = np.arange(6).reshape((3,2,1))
w = np.asarray([1. / 4, 3. / 4])

#test without axis argument
test_1D_weights(data, weights=w)

#test with axis argument
test_1D_weights_axis(data, axis=1, weights=w)

def test_interp_basic(self):
pyfunc = interp
cfunc = jit(nopython=True)(pyfunc)
Expand Down

0 comments on commit e031b70

Please sign in to comment.