Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

BUG: ma: ma.average didn't handle complex arrays correctly (issue gh-…

  • Loading branch information...
commit db3e22900d4392031337f99cbdf1a0c8bd941fcc 1 parent e823627
@WarrenWeckesser WarrenWeckesser authored
Showing with 52 additions and 9 deletions.
  1. +7 −7 numpy/ma/extras.py
  2. +45 −2 numpy/ma/tests/test_extras.py
View
14 numpy/ma/extras.py
@@ -453,7 +453,8 @@ def average(a, axis=None, weights=None, returned=False):
The weights array can either be 1-D (in which case its length must be
the size of `a` along the given axis) or of the same shape as `a`.
If ``weights=None``, then all data in `a` are assumed to have a
- weight equal to one.
+ weight equal to one. If `weights` is complex, the imaginary parts
+ are ignored.
returned : bool, optional
Flag indicating whether a tuple ``(result, sum of weights)``
should be returned as output (True), or just the result (False).
@@ -513,7 +514,7 @@ def average(a, axis=None, weights=None, returned=False):
if mask is nomask:
if weights is None:
d = ash[axis] * 1.0
- n = add.reduce(a._data, axis, dtype=float)
+ n = add.reduce(a._data, axis)
else:
w = filled(weights, 0.0)
wsh = w.shape
@@ -529,14 +530,14 @@ def average(a, axis=None, weights=None, returned=False):
r = [None] * len(ash)
r[axis] = slice(None, None, 1)
w = eval ("w[" + repr(tuple(r)) + "] * ones(ash, float)")
- n = add.reduce(a * w, axis, dtype=float)
+ n = add.reduce(a * w, axis)
d = add.reduce(w, axis, dtype=float)
del w, r
else:
raise ValueError('average: weights wrong shape.')
else:
if weights is None:
- n = add.reduce(a, axis, dtype=float)
+ n = add.reduce(a, axis)
d = umath.add.reduce((-mask), axis=axis, dtype=float)
else:
w = filled(weights, 0.0)
@@ -545,7 +546,7 @@ def average(a, axis=None, weights=None, returned=False):
wsh = (1,)
if wsh == ash:
w = array(w, dtype=float, mask=mask, copy=0)
- n = add.reduce(a * w, axis, dtype=float)
+ n = add.reduce(a * w, axis)
d = add.reduce(w, axis, dtype=float)
elif wsh == (ash[axis],):
ni = ash[axis]
@@ -553,7 +554,7 @@ def average(a, axis=None, weights=None, returned=False):
r[axis] = slice(None, None, 1)
w = eval ("w[" + repr(tuple(r)) + \
"] * masked_array(ones(ash, float), mask)")
- n = add.reduce(a * w, axis, dtype=float)
+ n = add.reduce(a * w, axis)
d = add.reduce(w, axis, dtype=float)
else:
raise ValueError('average: weights wrong shape.')
@@ -578,7 +579,6 @@ def average(a, axis=None, weights=None, returned=False):
return result
-
def median(a, axis=None, out=None, overwrite_input=False):
"""
Compute the median along the specified axis.
View
47 numpy/ma/tests/test_extras.py
@@ -24,8 +24,7 @@
compress_rowcols, mask_rowcols,
clump_masked, clump_unmasked,
flatnotmasked_contiguous, notmasked_contiguous, notmasked_edges,
- masked_all, masked_all_like,
- )
+ masked_all, masked_all_like)
class TestGeneric(TestCase):
@@ -199,6 +198,50 @@ def test_onintegers_with_mask(self):
a = average(array([1, 2, 3, 4], mask=[False, False, True, True]))
assert_equal(a, 1.5)
+ def test_complex(self):
+ # Test with complex data.
+ # (Regression test for https://github.com/numpy/numpy/issues/2684)
+ mask = np.array([[0, 0, 0, 1, 0],
+ [0, 1, 0, 0, 0]], dtype=bool)
+ a = masked_array([[0, 1+2j, 3+4j, 5+6j, 7+8j],
+ [9j, 0+1j, 2+3j, 4+5j, 7+7j]],
+ mask=mask)
+
+ av = average(a)
+ expected = np.average(a.compressed())
+ assert_almost_equal(av.real, expected.real)
+ assert_almost_equal(av.imag, expected.imag)
+
+ av0 = average(a, axis=0)
+ expected0 = average(a.real, axis=0) + average(a.imag, axis=0)*1j
+ assert_almost_equal(av0.real, expected0.real)
+ assert_almost_equal(av0.imag, expected0.imag)
+
+ av1 = average(a, axis=1)
+ expected1 = average(a.real, axis=1) + average(a.imag, axis=1)*1j
+ assert_almost_equal(av1.real, expected1.real)
+ assert_almost_equal(av1.imag, expected1.imag)
+
+ # Test with the 'weights' argument.
+ wts = np.array([[0.5, 1.0, 2.0, 1.0, 0.5],
+ [1.0, 1.0, 1.0, 1.0, 1.0]])
+ wav = average(a, weights=wts)
+ expected = np.average(a.compressed(), weights=wts[~mask])
+ assert_almost_equal(wav.real, expected.real)
+ assert_almost_equal(wav.imag, expected.imag)
+
+ wav0 = average(a, weights=wts, axis=0)
+ expected0 = (average(a.real, weights=wts, axis=0) +
+ average(a.imag, weights=wts, axis=0)*1j)
+ assert_almost_equal(wav0.real, expected0.real)
+ assert_almost_equal(wav0.imag, expected0.imag)
+
+ wav1 = average(a, weights=wts, axis=1)
+ expected1 = (average(a.real, weights=wts, axis=1) +
+ average(a.imag, weights=wts, axis=1)*1j)
+ assert_almost_equal(wav1.real, expected1.real)
+ assert_almost_equal(wav1.imag, expected1.imag)
+
class TestConcatenator(TestCase):
"""
Please sign in to comment.
Something went wrong with that request. Please try again.