Skip to content

Commit

Permalink
Merge pull request #3907 from mhvk/ma/allow-subclass-in-ufunc
Browse files Browse the repository at this point in the history
BUG allow subclasses in MaskedArray ufuncs -- for non-ndarray _data
  • Loading branch information
charris committed Jun 17, 2015
2 parents f4e0bdd + 3c6b6ba commit 6c1e1de
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 104 deletions.
188 changes: 94 additions & 94 deletions numpy/ma/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,14 +880,10 @@ def __call__ (self, a, *args, **kwargs):
except TypeError:
pass
# Transform to
if isinstance(a, MaskedArray):
subtype = type(a)
else:
subtype = MaskedArray
result = result.view(subtype)
result._mask = m
result._update_from(a)
return result
masked_result = result.view(get_masked_subclass(a))
masked_result._mask = m
masked_result._update_from(result)
return masked_result
#
def __str__ (self):
return "Masked version of %s. [Invalid values are masked]" % str(self.f)
Expand Down Expand Up @@ -928,8 +924,12 @@ def __init__ (self, mbfunc, fillx=0, filly=0):
def __call__ (self, a, b, *args, **kwargs):
"Execute the call behavior."
# Get the data, as ndarray
(da, db) = (getdata(a, subok=False), getdata(b, subok=False))
# Get the mask
(da, db) = (getdata(a), getdata(b))
# Get the result
with np.errstate():
np.seterr(divide='ignore', invalid='ignore')
result = self.f(da, db, *args, **kwargs)
# Get the mask for the result
(ma, mb) = (getmask(a), getmask(b))
if ma is nomask:
if mb is nomask:
Expand All @@ -940,66 +940,64 @@ def __call__ (self, a, b, *args, **kwargs):
m = umath.logical_or(ma, getmaskarray(b))
else:
m = umath.logical_or(ma, mb)
# Get the result
with np.errstate(divide='ignore', invalid='ignore'):
result = self.f(da, db, *args, **kwargs)
# check it worked
if result is NotImplemented:
return NotImplemented
# Case 1. : scalar
if not result.ndim:
if m:
return masked
return result
# Case 2. : array
# Revert result to da where masked
if m is not nomask:
np.copyto(result, da, casting='unsafe', where=m)
if m is not nomask and m.any():
# any errors, just abort; impossible to guarantee masked values
try:
np.copyto(result, 0, casting='unsafe', where=m)
# avoid using "*" since this may be overlaid
masked_da = umath.multiply(m, da)
# only add back if it can be cast safely
if np.can_cast(masked_da.dtype, result.dtype, casting='safe'):
result += masked_da
except:
pass
# Transforms to a (subclass of) MaskedArray
result = result.view(get_masked_subclass(a, b))
result._mask = m
# Update the optional info from the inputs
if isinstance(b, MaskedArray):
if isinstance(a, MaskedArray):
result._update_from(a)
else:
result._update_from(b)
elif isinstance(a, MaskedArray):
result._update_from(a)
return result

masked_result = result.view(get_masked_subclass(a, b))
masked_result._mask = m
masked_result._update_from(result)
return masked_result

def reduce(self, target, axis=0, dtype=None):
"""Reduce `target` along the given `axis`."""
if isinstance(target, MaskedArray):
tclass = type(target)
else:
tclass = MaskedArray
tclass = get_masked_subclass(target)
m = getmask(target)
t = filled(target, self.filly)
if t.shape == ():
t = t.reshape(1)
if m is not nomask:
m = make_mask(m, copy=1)
m.shape = (1,)

if m is nomask:
return self.f.reduce(t, axis).view(tclass)
t = t.view(tclass)
t._mask = m
tr = self.f.reduce(getdata(t), axis, dtype=dtype or t.dtype)
mr = umath.logical_and.reduce(m, axis)
tr = tr.view(tclass)
if mr.ndim > 0:
tr._mask = mr
return tr
elif mr:
return masked
return tr
tr = self.f.reduce(t, axis)
mr = nomask
else:
tr = self.f.reduce(t, axis, dtype=dtype or t.dtype)
mr = umath.logical_and.reduce(m, axis)

def outer (self, a, b):
if not tr.shape:
if mr:
return masked
else:
return tr
masked_tr = tr.view(tclass)
masked_tr._mask = mr
masked_tr._update_from(tr)
return masked_tr

def outer(self, a, b):
"""Return the function applied to the outer product of a and b.
"""
(da, db) = (getdata(a), getdata(b))
d = self.f.outer(da, db)
ma = getmask(a)
mb = getmask(b)
if ma is nomask and mb is nomask:
Expand All @@ -1010,31 +1008,28 @@ def outer (self, a, b):
m = umath.logical_or.outer(ma, mb)
if (not m.ndim) and m:
return masked
(da, db) = (getdata(a), getdata(b))
d = self.f.outer(da, db)
# check it worked
if d is NotImplemented:
return NotImplemented
if m is not nomask:
np.copyto(d, da, where=m)
if d.shape:
d = d.view(get_masked_subclass(a, b))
d._mask = m
return d
if not d.shape:
return d
masked_d = d.view(get_masked_subclass(a, b))
masked_d._mask = m
masked_d._update_from(d)
return masked_d

def accumulate (self, target, axis=0):
def accumulate(self, target, axis=0):
"""Accumulate `target` along `axis` after filling with y fill
value.
"""
if isinstance(target, MaskedArray):
tclass = type(target)
else:
tclass = MaskedArray
tclass = get_masked_subclass(target)
t = filled(target, self.filly)
return self.f.accumulate(t, axis).view(tclass)
result = self.f.accumulate(t, axis)
masked_result = result.view(tclass)
masked_result._update_from(result)
return masked_result

def __str__ (self):
def __str__(self):
return "Masked version of " + str(self.f)


Expand Down Expand Up @@ -1074,19 +1069,15 @@ def __init__ (self, dbfunc, domain, fillx=0, filly=0):

def __call__(self, a, b, *args, **kwargs):
"Execute the call behavior."
# Get the data and the mask
(da, db) = (getdata(a, subok=False), getdata(b, subok=False))
(ma, mb) = (getmask(a), getmask(b))
# Get the data
(da, db) = (getdata(a), getdata(b))
# Get the result
with np.errstate(divide='ignore', invalid='ignore'):
result = self.f(da, db, *args, **kwargs)
# check it worked
if result is NotImplemented:
return NotImplemented
# Get the mask as a combination of ma, mb and invalid
# Get the mask as a combination of the source masks and invalid
m = ~umath.isfinite(result)
m |= ma
m |= mb
m |= getmask(a)
m |= getmask(b)
# Apply the domain
domain = ufunc_domain.get(self.f, None)
if domain is not None:
Expand All @@ -1097,18 +1088,23 @@ def __call__(self, a, b, *args, **kwargs):
return masked
else:
return result
# When the mask is True, put back da
np.copyto(result, da, casting='unsafe', where=m)
result = result.view(get_masked_subclass(a, b))
result._mask = m
if isinstance(b, MaskedArray):
if isinstance(a, MaskedArray):
result._update_from(a)
else:
result._update_from(b)
elif isinstance(a, MaskedArray):
result._update_from(a)
return result
# When the mask is True, put back da if possible
# any errors, just abort; impossible to guarantee masked values
try:
np.copyto(result, 0, casting='unsafe', where=m)
# avoid using "*" since this may be overlaid
masked_da = umath.multiply(m, da)
# only add back if it can be cast safely
if np.can_cast(masked_da.dtype, result.dtype, casting='safe'):
result += masked_da
except:
pass

# Transforms to a (subclass of) MaskedArray
masked_result = result.view(get_masked_subclass(a, b))
masked_result._mask = m
masked_result._update_from(result)
return masked_result

def __str__ (self):
return "Masked version of " + str(self.f)
Expand Down Expand Up @@ -1361,7 +1357,7 @@ def getmaskarray(arr):
"""
mask = getmask(arr)
if mask is nomask:
mask = make_mask_none(np.shape(arr), getdata(arr).dtype)
mask = make_mask_none(np.shape(arr), getattr(arr, 'dtype', None))
return mask

def is_mask(m):
Expand Down Expand Up @@ -3756,34 +3752,38 @@ def __ne__(self, other):
return check
#
def __add__(self, other):
"Add other to self, and return a new masked array."
"Add self to other, and return a new masked array."
if self._delegate_binop(other):
return NotImplemented
return add(self, other)
#
def __radd__(self, other):
"Add other to self, and return a new masked array."
return add(self, other)
# In analogy with __rsub__ and __rdiv__, use original order:
# we get here from `other + self`.
return add(other, self)
#
def __sub__(self, other):
"Subtract other to self, and return a new masked array."
"Subtract other from self, and return a new masked array."
if self._delegate_binop(other):
return NotImplemented
return subtract(self, other)
#
def __rsub__(self, other):
"Subtract other to self, and return a new masked array."
"Subtract self from other, and return a new masked array."
return subtract(other, self)
#
def __mul__(self, other):
"Multiply other by self, and return a new masked array."
"Multiply self by other, and return a new masked array."
if self._delegate_binop(other):
return NotImplemented
return multiply(self, other)
#
def __rmul__(self, other):
"Multiply other by self, and return a new masked array."
return multiply(self, other)
# In analogy with __rsub__ and __rdiv__, use original order:
# we get here from `other * self`.
return multiply(other, self)
#
def __div__(self, other):
"Divide other into self, and return a new masked array."
Expand All @@ -3798,7 +3798,7 @@ def __truediv__(self, other):
return true_divide(self, other)
#
def __rtruediv__(self, other):
"Divide other into self, and return a new masked array."
"Divide self into other, and return a new masked array."
return true_divide(other, self)
#
def __floordiv__(self, other):
Expand All @@ -3808,7 +3808,7 @@ def __floordiv__(self, other):
return floor_divide(self, other)
#
def __rfloordiv__(self, other):
"Divide other into self, and return a new masked array."
"Divide self into other, and return a new masked array."
return floor_divide(other, self)
#
def __pow__(self, other):
Expand All @@ -3818,7 +3818,7 @@ def __pow__(self, other):
return power(self, other)
#
def __rpow__(self, other):
"Raise self to the power other, masking the potential NaNs/Infs"
"Raise other to the power self, masking the potential NaNs/Infs"
return power(other, self)
#............................................
def __iadd__(self, other):
Expand Down
23 changes: 23 additions & 0 deletions numpy/ma/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1754,6 +1754,29 @@ def __rmul__(self, other):
assert_(me * a == "My mul")
assert_(a * me == "My rmul")

# and that __array_priority__ is respected
class MyClass2(object):
__array_priority__ = 100

def __mul__(self, other):
return "Me2mul"

def __rmul__(self, other):
return "Me2rmul"

def __rdiv__(self, other):
return "Me2rdiv"

__rtruediv__ = __rdiv__

me_too = MyClass2()
assert_(a.__mul__(me_too) is NotImplemented)
assert_(all(multiply.outer(a, me_too) == "Me2rmul"))
assert_(a.__truediv__(me_too) is NotImplemented)
assert_(me_too * a == "Me2mul")
assert_(a * me_too == "Me2rmul")
assert_(a / me_too == "Me2rdiv")


#------------------------------------------------------------------------------
class TestMaskedArrayInPlaceArithmetics(TestCase):
Expand Down

0 comments on commit 6c1e1de

Please sign in to comment.