From f7c52ab544fac0190c6444c7295dd31ca1057841 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sok=C3=B3=C5=82?= <8431159+mtsokol@users.noreply.github.com> Date: Thu, 11 Mar 2021 23:59:16 +0100 Subject: [PATCH] BUG: Fixed ``where`` keyword for ``np.mean`` & ``np.var`` methods (gh-18560) * Fixed keyword bug * Added test case * Reverted to original notation * Added tests for var and std Closes gh-18552 --- numpy/core/_methods.py | 4 ++-- numpy/core/tests/test_multiarray.py | 26 ++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py index c730e2035f36..fc118326a572 100644 --- a/numpy/core/_methods.py +++ b/numpy/core/_methods.py @@ -164,7 +164,7 @@ def _mean(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True): is_float16_result = False rcount = _count_reduce_items(arr, axis, keepdims=keepdims, where=where) - if rcount == 0 if where is True else umr_any(rcount == 0): + if rcount == 0 if where is True else umr_any(rcount == 0, axis=None): warnings.warn("Mean of empty slice.", RuntimeWarning, stacklevel=2) # Cast bool, unsigned int, and int to float64 by default @@ -197,7 +197,7 @@ def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, rcount = _count_reduce_items(arr, axis, keepdims=keepdims, where=where) # Make this warning show up on top. - if ddof >= rcount if where is True else umr_any(ddof >= rcount): + if ddof >= rcount if where is True else umr_any(ddof >= rcount, axis=None): warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2) diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index bd8c51ab78fd..3ce46c43f472 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -5713,6 +5713,15 @@ def test_mean_where(self): np.array(_res)) assert_allclose(np.mean(a, axis=_ax, where=_wh), np.array(_res)) + + a3d = np.arange(16).reshape((2, 2, 4)) + _wh_partial = np.array([False, True, True, False]) + _res = [[1.5, 5.5], [9.5, 13.5]] + assert_allclose(a3d.mean(axis=2, where=_wh_partial), + np.array(_res)) + assert_allclose(np.mean(a3d, axis=2, where=_wh_partial), + np.array(_res)) + with pytest.warns(RuntimeWarning) as w: assert_allclose(a.mean(axis=1, where=wh_partial), np.array([np.nan, 5.5, 9.5, np.nan])) @@ -5788,6 +5797,15 @@ def test_var_where(self): np.array(_res)) assert_allclose(np.var(a, axis=_ax, where=_wh), np.array(_res)) + + a3d = np.arange(16).reshape((2, 2, 4)) + _wh_partial = np.array([False, True, True, False]) + _res = [[0.25, 0.25], [0.25, 0.25]] + assert_allclose(a3d.var(axis=2, where=_wh_partial), + np.array(_res)) + assert_allclose(np.var(a3d, axis=2, where=_wh_partial), + np.array(_res)) + assert_allclose(np.var(a, axis=1, where=wh_full), np.var(a[wh_full].reshape((5, 3)), axis=1)) assert_allclose(np.var(a, axis=0, where=wh_partial), @@ -5827,6 +5845,14 @@ def test_std_where(self): assert_allclose(a.std(axis=_ax, where=_wh), _res) assert_allclose(np.std(a, axis=_ax, where=_wh), _res) + a3d = np.arange(16).reshape((2, 2, 4)) + _wh_partial = np.array([False, True, True, False]) + _res = [[0.5, 0.5], [0.5, 0.5]] + assert_allclose(a3d.std(axis=2, where=_wh_partial), + np.array(_res)) + assert_allclose(np.std(a3d, axis=2, where=_wh_partial), + np.array(_res)) + assert_allclose(a.std(axis=1, where=whf), np.std(a[whf].reshape((5,3)), axis=1)) assert_allclose(np.std(a, axis=1, where=whf),