Skip to content

Commit

Permalink
Backport PR astropy#16499: Ensure Masked quantity works with np.block
Browse files Browse the repository at this point in the history
  • Loading branch information
mhvk authored and meeseeksmachine committed May 27, 2024
1 parent 38fb74e commit 20c5698
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 17 deletions.
34 changes: 18 additions & 16 deletions astropy/units/quantity_helper/function_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,32 +423,34 @@ def concatenate(arrays, axis=0, out=None, **kwargs):
return (arrays,), kwargs, unit, out


def _block(arrays, max_depth, result_ndim, depth=0):
# Block by concatenation, copied from np._core.shape_base,
# but ensuring that we call regular concatenate.
if depth < max_depth:
arrs = [_block(arr, max_depth, result_ndim, depth+1)
for arr in arrays]
# The one difference with the numpy code.
return np.concatenate(arrs, axis=-(max_depth-depth))
else:
return np_core.shape_base._atleast_nd(arrays, result_ndim)


@dispatched_function
def block(arrays):
# We need to override block since the numpy implementation can take two
# different paths, one for concatenation, one for creating a large empty
# result array in which parts are set. Each assumes array input and
# cannot be used directly. Since it would be very costly to inspect all
# arrays and then turn them back into a nested list, we just copy here the
# second implementation, np.core.shape_base._block_slicing, since it is
# shortest and easiest.
# first implementation, np.core.shape_base._block, which is the easiest to
# adjust while making sure that both units and class are properly kept.
(arrays, list_ndim, result_ndim, final_size) = np_core.shape_base._block_setup(
arrays
)
shape, slices, arrays = np_core.shape_base._block_info_recursion(
arrays, list_ndim, result_ndim
)
# Here, one line of difference!
arrays, unit = _quantities2arrays(*arrays)
# Back to _block_slicing
dtype = np.result_type(*[arr.dtype for arr in arrays])
F_order = all(arr.flags["F_CONTIGUOUS"] for arr in arrays)
C_order = all(arr.flags["C_CONTIGUOUS"] for arr in arrays)
order = "F" if F_order and not C_order else "C"
result = np.empty(shape=shape, dtype=dtype, order=order)
for the_slice, arr in zip(slices, arrays):
result[(Ellipsis,) + the_slice] = arr
return result, unit, None
result = _block(arrays, list_ndim, result_ndim)
if list_ndim == 0:
result = result.copy()
return result, None, None


@function_helper
Expand Down
12 changes: 11 additions & 1 deletion astropy/utils/masked/tests/test_function_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pytest
from numpy.testing import assert_array_equal

import astropy.units as u
from astropy.units.tests.test_quantity_non_ufuncs import (
CheckSignatureCompatibilityBase,
get_covered_functions,
Expand Down Expand Up @@ -532,12 +533,21 @@ def test_dstack(self):

def test_block(self):
self.check(np.block)

# Check that this also works on MaskedQuantity, properly propagating
# the fact that we are based on MaskedNDArray.
self.check(np.block, ma_list=[self.ma << u.m, self.mc << u.km])
# And check a mix of float and masked values, with different dtype.
out = np.block([[0.0, Masked(1.0, True)], [Masked(1, False), Masked(2, False)]])
expected = np.array([[0, 1.0], [1, 2]])
expected_mask = np.array([[False, True], [False, False]])
assert_array_equal(out.unmasked, expected)
assert_array_equal(out.mask, expected_mask)
# And check single array.
in2 = Masked([1.0], [True])
out2 = np.block(Masked([1.0], [True]))
assert not np.may_share_memory(out2, in2)
assert_array_equal(out2.unmasked, in2.unmasked)
assert_array_equal(out2.mask, in2.mask)

def test_append(self):
out = np.append(self.ma, self.mc, axis=1)
Expand Down
1 change: 1 addition & 0 deletions docs/changes/utils/16499.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
``MaskedQuantity`` now works properly with ``np.block``.

0 comments on commit 20c5698

Please sign in to comment.