Skip to content

Commit

Permalink
Merge pull request #2058 from AdeelH/fix-xr
Browse files Browse the repository at this point in the history
Fix handing of some edge cases when reading chips from `XarraySource`
  • Loading branch information
AdeelH committed Feb 12, 2024
2 parents 2359be1 + dd2562c commit cf95e94
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,13 @@ def __init__(self,
self.full_extent = Box(0, 0, height, width)
if bbox is None:
bbox = self.full_extent
else:
if bbox not in self.full_extent:
new_bbox = bbox.intersection(self.full_extent)
log.warning(f'Clipping ({bbox}) to the DataArray\'s '
f'full extent ({self.full_extent}). '
f'New bbox={new_bbox}')
bbox = new_bbox

super().__init__(
channel_order,
Expand Down Expand Up @@ -133,20 +140,23 @@ def _get_chip(self,
out_shape: Optional[Tuple[int, ...]] = None) -> np.ndarray:
window = window.to_global_coords(self.bbox)

yslice, xsclice = window.to_slices()
window_within_bbox = window.intersection(self.bbox)

yslice, xslice = window_within_bbox.to_slices()
if self.temporal:
chip = self.data_array.isel(
x=xsclice, y=yslice, band=bands, time=time).to_numpy()
x=xslice, y=yslice, band=bands, time=time).to_numpy()
else:
chip = self.data_array.isel(
x=xsclice, y=yslice, band=bands).to_numpy()
x=xslice, y=yslice, band=bands).to_numpy()

*batch_dims, h, w, c = chip.shape
if window.size != (h, w):
window_actual = window.intersection(self.full_extent)
yslice, xsclice = window_actual.to_local_coords(window).to_slices()
if window != window_within_bbox:
*batch_dims, h, w, c = chip.shape
# coords of window_within_bbox within window
yslice, xslice = window_within_bbox.to_local_coords(
window).to_slices()
tmp = np.zeros((*batch_dims, *window.size, c))
tmp[..., yslice, xsclice, :] = chip
tmp[..., yslice, xslice, :] = chip
chip = tmp

chip = fill_overflow(self.bbox, window, chip)
Expand Down
29 changes: 29 additions & 0 deletions tests/core/data/raster_source/test_xarray_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,35 @@ def test_get_raw_chip(self):
chip_expected = np.array([[[0, 1, 2, 3]]], dtype=arr.dtype)
np.testing.assert_array_equal(chip, chip_expected)

def test_get_raw_chip_overflowing_window(self):
arr = np.arange(100).reshape(10, 10, 1)
da = DataArray(arr, dims=['y', 'x', 'band'])
rs = XarraySource(da, IdentityCRSTransformer(), bbox=Box(2, 2, 7, 7))

chip = rs.get_raw_chip(Box(3, 3, 7, 7))
chip_expected = np.zeros((4, 4, 1))
chip_expected[:2, :2] = arr[5:7, 5:7]
np.testing.assert_array_equal(chip, chip_expected)

chip = rs.get_raw_chip(Box(-2, -2, 2, 2))
chip_expected = np.zeros((4, 4, 1))
chip_expected[2:, 2:] = arr[2:4, 2:4]
np.testing.assert_array_equal(chip, chip_expected)

chip = rs.get_raw_chip(Box(-5, -5, 0, 0))
chip_expected = np.zeros((5, 5, 1))
np.testing.assert_array_equal(chip, chip_expected)

chip = rs.get_raw_chip(Box(6, 6, 9, 9))
chip_expected = np.zeros((3, 3, 1))
np.testing.assert_array_equal(chip, chip_expected)

def test_get_bbox_overflows_full_extent(self):
arr = np.empty((5, 5, 1))
da = DataArray(arr, dims=['y', 'x', 'band'])
rs = XarraySource(da, IdentityCRSTransformer(), bbox=Box(2, 2, 5, 7))
self.assertEqual(rs.bbox, Box(2, 2, 5, 5))

def test_get_chip(self):
arr = np.ones((5, 5, 4), dtype=np.uint8)
arr *= np.arange(4, dtype=np.uint8)
Expand Down

0 comments on commit cf95e94

Please sign in to comment.