Skip to content

Commit

Permalink
Make mask_invalid consistent with mask_where when copy is set to Fals…
Browse files Browse the repository at this point in the history
…e. Add test for type erroring.
  • Loading branch information
cmarmo committed Jul 26, 2022
1 parent 45bc13e commit 44c8da9
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 13 deletions.
17 changes: 4 additions & 13 deletions numpy/ma/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2356,20 +2356,11 @@ def masked_invalid(a, copy=True):
fill_value=1e+20)
"""
a = np.array(a, copy=copy, subok=True)
mask = getattr(a, '_mask', None)
if mask is not None:
condition = ~(np.isfinite(getdata(a)))
if mask is not nomask:
condition |= mask
cls = type(a)
else:
condition = ~(np.isfinite(a))
cls = MaskedArray
result = a.view(cls)
result._mask = condition
return result

try:
return masked_where(~(np.isfinite(getdata(a))), a, copy=copy)
except TypeError:
raise

###############################################################################
# Printing options #
Expand Down
7 changes: 7 additions & 0 deletions numpy/ma/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4496,6 +4496,13 @@ def test_where_structured_masked(self):
assert_equal(ma, expected)
assert_equal(ma.mask, expected.mask)

def test_masked_invalid_error(self):
a = np.arange(5, dtype=object)
a[3] = np.PINF
a[2] = np.NaN
with pytest.raises(TypeError, match="not supported for the input types"):
np.ma.masked_invalid(a)

def test_choose(self):
# Test choose
choices = [[0, 1, 2, 3], [10, 11, 12, 13],
Expand Down

0 comments on commit 44c8da9

Please sign in to comment.