From a024ee161769b7bc7db5b60cfd4c217836c770fc Mon Sep 17 00:00:00 2001 From: Adeel Hassan Date: Wed, 5 Apr 2023 14:35:27 +0500 Subject: [PATCH] factor out numpy-like array indexing implementation and add unit tests --- rastervision_core/rastervision/core/box.py | 9 +- .../core/data/label_source/label_source.py | 30 +---- .../object_detection_label_source.py | 29 +---- .../core/data/raster_source/raster_source.py | 32 +---- .../data/raster_source/rasterio_source.py | 36 +----- .../rastervision/core/data/utils/misc.py | 87 ++++++++++++- tests/core/data/utils/test_misc.py | 116 +++++++++++++++++- 7 files changed, 219 insertions(+), 120 deletions(-) diff --git a/rastervision_core/rastervision/core/box.py b/rastervision_core/rastervision/core/box.py index 705e215cc4..30dcce6f1d 100644 --- a/rastervision_core/rastervision/core/box.py +++ b/rastervision_core/rastervision/core/box.py @@ -260,9 +260,12 @@ def to_rasterio(self) -> RioWindow: """Convert to a Rasterio Window.""" return RioWindow.from_slices(*self.normalize().to_slices()) - def to_slices(self) -> Tuple[slice, slice]: - """Convert to slices: ymin:ymax, xmin:xmax""" - return slice(self.ymin, self.ymax), slice(self.xmin, self.xmax) + def to_slices(self, + h_step: Optional[int] = None, + w_step: Optional[int] = None) -> Tuple[slice, slice]: + """Convert to slices: ymin:ymax[:h_step], xmin:xmax[:w_step]""" + return slice(self.ymin, self.ymax, h_step), slice( + self.xmin, self.xmax, w_step) def translate(self, dy: int, dx: int) -> 'Box': """Translate window along y and x axes by the given distances.""" diff --git a/rastervision_core/rastervision/core/data/label_source/label_source.py b/rastervision_core/rastervision/core/data/label_source/label_source.py index 65c85d9e44..ee7cfe463c 100644 --- a/rastervision_core/rastervision/core/data/label_source/label_source.py +++ b/rastervision_core/rastervision/core/data/label_source/label_source.py @@ -1,6 +1,8 @@ from typing import TYPE_CHECKING, Any, Optional from abc import 
ABC, abstractmethod, abstractproperty +from rastervision.core.data.utils import parse_array_slices + if TYPE_CHECKING: from rastervision.core.box import Box from rastervision.core.data import CRSTransformer, Labels @@ -50,31 +52,5 @@ def set_extent(self, extent: 'Box') -> None: def __getitem__(self, key: Any) -> Any: if isinstance(key, Box): raise NotImplementedError() - elif isinstance(key, slice): - key = [key] - elif isinstance(key, tuple): - pass - else: - raise TypeError('Unsupported key type.') - slices = list(key) - assert 1 <= len(slices) <= 2 - assert all(s is not None for s in slices) - assert isinstance(slices[0], slice) - if len(slices) == 1: - h, = slices - w = slice(None, None) - else: - assert isinstance(slices[1], slice) - h, w = slices - - if any(x is not None and x < 0 - for x in [h.start, h.stop, w.start, w.stop]): - raise NotImplementedError() - - ymin, xmin, ymax, xmax = self.extent - _ymin = 0 if h.start is None else h.start - _xmin = 0 if w.start is None else w.start - _ymax = ymax if h.stop is None else h.stop - _xmax = xmax if w.stop is None else w.stop - window = Box(_ymin, _xmin, _ymax, _xmax) + window, _ = parse_array_slices(key, extent=self.extent, dims=2) return self[window] diff --git a/rastervision_core/rastervision/core/data/label_source/object_detection_label_source.py b/rastervision_core/rastervision/core/data/label_source/object_detection_label_source.py index 37aa350b44..86b10afda6 100644 --- a/rastervision_core/rastervision/core/data/label_source/object_detection_label_source.py +++ b/rastervision_core/rastervision/core/data/label_source/object_detection_label_source.py @@ -6,6 +6,7 @@ from rastervision.core.data.label import ObjectDetectionLabels from rastervision.core.data.label_source import LabelSource from rastervision.core.data.vector_source import VectorSource +from rastervision.core.data.utils import parse_array_slices if TYPE_CHECKING: from rastervision.core.data import CRSTransformer @@ -88,34 +89,8 @@ class labels 
def __getitem__(self, key: Any) -> 'np.ndarray':
    """Read a chip via a Box or numpy-like array indexing.

    Args:
        key (Any): Either a Box, in which case this is equivalent to
            ``get_chip(key)``, or a slice / tuple of indices that is parsed
            by ``parse_array_slices`` with ``dims=3``. The first two
            indices (height and width) must be slices; the optional third
            one is applied to the last axis of the chip.

    Returns:
        np.ndarray: The chip read from the window described by ``key``,
        sub-sampled by the slice steps (if any) and indexed along its last
        axis by the third index (if one was given).
    """
    if isinstance(key, Box):
        return self.get_chip(key)

    window, (h, w, c) = parse_array_slices(key, extent=self.extent, dims=3)

    chip = self.get_chip(window)
    if h.step is not None or w.step is not None:
        chip = chip[::h.step, ::w.step]

    # parse_array_slices leaves an unspecified third index as None. In
    # numpy, indexing with None means np.newaxis, so applying it blindly
    # would append a spurious axis of length 1 instead of being a no-op.
    if c is not None:
        chip = chip[..., c]

    return chip
def parse_array_slices(key: Union[tuple, slice], extent: 'Box',
                       dims: int = 2) -> Tuple['Box', List[Optional[Any]]]:
    """Parse multi-dim array-indexing inputs into a Box and slices.

    The first two dims are always interpreted as height and width and must
    be indexed with slices. Indices for any further dims are passed through
    untouched; dims that were not indexed are reported as None.

    Args:
        key (Union[tuple, slice]): Input to __getitem__.
        extent (Box): Extent of the raster/label source being indexed.
        dims (int, optional): Total available indexable dims. Must be at
            least 2 (height and width). Defaults to 2.

    Raises:
        NotImplementedError: If dims < 2.
        TypeError: If key is not a slice or tuple.
        IndexError: If not (1 <= len(key) <= dims).
        TypeError: If the index for any of the dims is None.
        ValueError: If more than one Ellipsis ("...") in the input.
        ValueError: If h and w indices (first 2 dims) are not slices.
        NotImplementedError: If input contains negative values.

    Returns:
        Tuple[Box, list]: A Box representing the h and w slices and a list
        containing slices/index-values for all the dims.
    """
    # h and w are unconditionally unpacked below, so fewer than 2 dims cannot
    # be represented; fail with a clear error instead of an incidental
    # IndexError from list indexing.
    if dims < 2:
        raise NotImplementedError(
            'At least 2 dims (height and width) are required.')

    if isinstance(key, slice):
        input_slices = [key]
    elif isinstance(key, tuple):
        input_slices = list(key)
    else:
        raise TypeError('Unsupported key type.')

    if not (1 <= len(input_slices) <= dims):
        raise IndexError(f'Too many indices for {dims}-dimensional source.')
    if any(s is None for s in input_slices):
        raise TypeError('None is not a valid index.')

    # Locate "..." by identity. Equality-based lookups (in/count/index) would
    # evaluate `index == Ellipsis`, which is an element-wise comparison (and
    # then an ambiguous truth value) when an index is a numpy array.
    ellipsis_idxs = [i for i, s in enumerate(input_slices) if s is Ellipsis]
    if len(ellipsis_idxs) > 1:
        raise ValueError('Only one ellipsis is allowed.')

    if ellipsis_idxs:
        idx = ellipsis_idxs[0]
        # the ellipsis itself does not consume a dim, hence the "- 1"
        num_missing_dims = dims - (len(input_slices) - 1)
        filler_slices = [None] * num_missing_dims
        # same expression whether "..." is at the start, middle, or end
        dim_slices = (
            input_slices[:idx] + filler_slices + input_slices[(idx + 1):])
    else:
        num_missing_dims = dims - len(input_slices)
        dim_slices = input_slices + [None] * num_missing_dims

    # unspecified h/w default to the full range
    for i in range(2):
        if dim_slices[i] is None:
            dim_slices[i] = slice(None, None)
    h, w = dim_slices[:2]
    if not (isinstance(h, slice) and isinstance(w, slice)):
        raise ValueError('h and w indices (first 2 dims) must be slices.')

    if any(x is not None and x < 0
           for x in [h.start, h.stop, h.step, w.start, w.stop, w.step]):
        raise NotImplementedError(
            'Negative indices are currently not supported.')

    # slices with missing endpoints get expanded to the extent limits
    H, W = extent.size
    _ymin = 0 if h.start is None else h.start
    _xmin = 0 if w.start is None else w.start
    _ymax = H if h.stop is None else h.stop
    _xmax = W if w.stop is None else w.stop
    window = Box(_ymin, _xmin, _ymax, _xmax)

    # report h/w as fully-resolved slices; other dims are passed through
    dim_slices = list(window.to_slices(h.step, w.step)) + dim_slices[2:]

    return window, dim_slices
self.assertRaises(NotImplementedError, lambda: source[::-1]) + self.assertRaises(NotImplementedError, lambda: source[:, ::-1]) + + def test_window(self): + source = self.MockSource(dims=2, extent=Box(0, 0, 100, 100)) + + window, _ = source[5:10, 15:20] + self.assertEqual(window, Box(5, 15, 10, 20)) + + window, _ = source[5:10, :] + self.assertEqual(window, Box(5, 0, 10, 100)) + + window, _ = source[:, 15:20] + self.assertEqual(window, Box(0, 15, 100, 20)) + + window, _ = source[5:10] + self.assertEqual(window, Box(5, 0, 10, 100)) + + def test_dim_slices(self): + source = self.MockSource(dims=3, extent=Box(0, 0, 100, 100)) + + _, dim_slices = source[5:10, 15:20] + self.assertListEqual(dim_slices, [slice(5, 10), slice(15, 20), None]) + + _, dim_slices = source[5:10, 15:20, 0] + self.assertListEqual(dim_slices, [slice(5, 10), slice(15, 20), 0]) + + _, dim_slices = source[5:10, 15:20, 1:4] + self.assertListEqual( + dim_slices, + [slice(5, 10), slice(15, 20), + slice(1, 4)]) + + _, dim_slices = source[5:10, 15:20, [3, 1]] + self.assertListEqual(dim_slices, [slice(5, 10), slice(15, 20), [3, 1]]) + + source = self.MockSource(dims=4, extent=Box(0, 0, 100, 100)) + _, dim_slices = source[5:10, 15:20, 0] + self.assertListEqual( + dim_slices, [slice(5, 10), slice(15, 20), 0, None]) + + def test_ellipsis(self): + source = self.MockSource(dims=3, extent=Box(0, 0, 100, 100)) + + window, dim_slices = source[5:10, 15:20, ...] + self.assertEqual(window, Box(5, 15, 10, 20)) + self.assertListEqual(dim_slices, [slice(5, 10), slice(15, 20), None]) + + window, dim_slices = source[5:10, ...] 
+ self.assertEqual(window, Box(5, 0, 10, 100)) + self.assertListEqual(dim_slices, [slice(5, 10), slice(0, 100), None]) + + window, dim_slices = source[5:10, ..., 0] + self.assertEqual(window, Box(5, 0, 10, 100)) + self.assertListEqual(dim_slices, [slice(5, 10), slice(0, 100), 0]) + + window, dim_slices = source[..., 15:20, 0] + self.assertEqual(window, Box(0, 15, 100, 20)) + self.assertListEqual(dim_slices, [slice(0, 100), slice(15, 20), 0]) + + window, dim_slices = source[..., 0] + self.assertEqual(window, Box(0, 0, 100, 100)) + self.assertListEqual(dim_slices, [slice(0, 100), slice(0, 100), 0]) + + def test_cropped_extent(self): + source = self.MockSource(dims=2, extent=Box(20, 30, 80, 70)) + + window, dim_slices = source[5:10, 15:20] + self.assertEqual(window, Box(5, 15, 10, 20)) + self.assertListEqual(dim_slices, [slice(5, 10), slice(15, 20)]) + + window, dim_slices = source[:, :] + self.assertEqual(window, Box(0, 0, 60, 40)) + self.assertListEqual(dim_slices, [slice(0, 60), slice(0, 40)]) + + window, dim_slices = source[:] + self.assertEqual(window, Box(0, 0, 60, 40)) + self.assertListEqual(dim_slices, [slice(0, 60), slice(0, 40)]) + + window, dim_slices = source[..., :] + self.assertEqual(window, Box(0, 0, 60, 40)) + self.assertListEqual(dim_slices, [slice(0, 60), slice(0, 40)]) + + def test_step(self): + source = self.MockSource(dims=2, extent=Box(20, 30, 80, 70)) + + window, dim_slices = source[5:10:2, 15:20:3] + self.assertEqual(window, Box(5, 15, 10, 20)) + self.assertListEqual(dim_slices, [slice(5, 10, 2), slice(15, 20, 3)]) + + window, dim_slices = source[::2, ::3] + self.assertEqual(window, Box(0, 0, 60, 40)) + self.assertListEqual(dim_slices, [slice(0, 60, 2), slice(0, 40, 3)]) + + if __name__ == '__main__': unittest.main()