Skip to content

Commit

Permalink
Backport PR astropy#14995: Bugfix for bitmasks passed to nddata
Browse files Browse the repository at this point in the history
  • Loading branch information
bmorris3 authored and meeseeksmachine committed Jul 5, 2023
1 parent 2cc7c5b commit 697faa5
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 2 deletions.
4 changes: 2 additions & 2 deletions astropy/nddata/mixins/ndarithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,10 +520,10 @@ def _arithmetic_mask(self, operation, operand, handle_mask, axis=None, **kwds):
elif self.mask is None and operand is not None:
# Make a copy so there is no reference in the result.
return deepcopy(operand.mask)
elif operand is None:
elif operand.mask is None:
return deepcopy(self.mask)
else:
# Now lets calculate the resulting mask (operation enforces copy)
# Now let's calculate the resulting mask (operation enforces copy)
return handle_mask(self.mask, operand.mask, **kwds)

def _arithmetic_wcs(self, operation, operand, compare_wcs, **kwds):
Expand Down
39 changes: 39 additions & 0 deletions astropy/nddata/mixins/tests/test_ndarithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1310,3 +1310,42 @@ def test_raise_method_not_supported():
# raise error for unsupported propagation operations:
with pytest.raises(ValueError):
ndd1.uncertainty.propagate(np.mod, ndd2, result, correlation)


def test_nddata_bitmask_arithmetic():
# NDData.mask is usually assumed to be boolean, but could be
# a bitmask. Ensure bitmask works:
array = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]])
mask = np.array([[0, 1, 64], [8, 0, 1], [2, 1, 0]])

nref_nomask = NDDataRef(array)
nref_masked = NDDataRef(array, mask=mask)

# multiply no mask by constant (no mask * no mask)
assert nref_nomask.multiply(1.0, handle_mask=np.bitwise_or).mask is None

# multiply no mask by itself (no mask * no mask)
assert nref_nomask.multiply(nref_nomask, handle_mask=np.bitwise_or).mask is None

# multiply masked by constant (mask * no mask)
np.testing.assert_equal(
nref_masked.multiply(1.0, handle_mask=np.bitwise_or).mask, mask
)

# multiply masked by itself (mask * mask)
np.testing.assert_equal(
nref_masked.multiply(nref_masked, handle_mask=np.bitwise_or).mask, mask
)

# multiply masked by no mask (mask * no mask)
np.testing.assert_equal(
nref_masked.multiply(nref_nomask, handle_mask=np.bitwise_or).mask, mask
)

# check bitwise logic still works
other_mask = np.array([[64, 1, 0], [2, 1, 0], [8, 0, 2]])
nref_mask_other = NDDataRef(array, mask=other_mask)
np.testing.assert_equal(
nref_mask_other.multiply(nref_masked, handle_mask=np.bitwise_or).mask,
np.bitwise_or(mask, other_mask),
)
2 changes: 2 additions & 0 deletions docs/changes/nddata/14995.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Restore bitmask propagation behavior in ``NDData.mask``, plus a fix
for arithmetic between masked and unmasked ``NDData`` objects.

0 comments on commit 697faa5

Please sign in to comment.