Skip to content

Commit

Permalink
factor out numpy-like array indexing implementation and add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AdeelH committed Apr 5, 2023
1 parent 3f8e7e4 commit a024ee1
Show file tree
Hide file tree
Showing 7 changed files with 219 additions and 120 deletions.
9 changes: 6 additions & 3 deletions rastervision_core/rastervision/core/box.py
Expand Up @@ -260,9 +260,12 @@ def to_rasterio(self) -> RioWindow:
"""Convert to a Rasterio Window."""
return RioWindow.from_slices(*self.normalize().to_slices())

def to_slices(self) -> Tuple[slice, slice]:
    """Return this box as a (rows, cols) pair of slices.

    Returns:
        Tuple[slice, slice]: ``ymin:ymax`` and ``xmin:xmax``, in that order.
    """
    rows = slice(self.ymin, self.ymax)
    cols = slice(self.xmin, self.xmax)
    return rows, cols
def to_slices(self,
              h_step: Optional[int] = None,
              w_step: Optional[int] = None) -> Tuple[slice, slice]:
    """Return this box as a (rows, cols) pair of slices.

    Args:
        h_step (Optional[int]): Step for the row (y) slice. Defaults to None.
        w_step (Optional[int]): Step for the column (x) slice. Defaults to
            None.

    Returns:
        Tuple[slice, slice]: ``ymin:ymax[:h_step]`` and
        ``xmin:xmax[:w_step]``, in that order.
    """
    rows = slice(self.ymin, self.ymax, h_step)
    cols = slice(self.xmin, self.xmax, w_step)
    return rows, cols

def translate(self, dy: int, dx: int) -> 'Box':
"""Translate window along y and x axes by the given distances."""
Expand Down
@@ -1,6 +1,8 @@
from typing import TYPE_CHECKING, Any, Optional
from abc import ABC, abstractmethod, abstractproperty

from rastervision.core.data.utils import parse_array_slices

if TYPE_CHECKING:
from rastervision.core.box import Box
from rastervision.core.data import CRSTransformer, Labels
Expand Down Expand Up @@ -50,31 +52,5 @@ def set_extent(self, extent: 'Box') -> None:
def __getitem__(self, key: Any) -> Any:
if isinstance(key, Box):
raise NotImplementedError()
elif isinstance(key, slice):
key = [key]
elif isinstance(key, tuple):
pass
else:
raise TypeError('Unsupported key type.')
slices = list(key)
assert 1 <= len(slices) <= 2
assert all(s is not None for s in slices)
assert isinstance(slices[0], slice)
if len(slices) == 1:
h, = slices
w = slice(None, None)
else:
assert isinstance(slices[1], slice)
h, w = slices

if any(x is not None and x < 0
for x in [h.start, h.stop, w.start, w.stop]):
raise NotImplementedError()

ymin, xmin, ymax, xmax = self.extent
_ymin = 0 if h.start is None else h.start
_xmin = 0 if w.start is None else w.start
_ymax = ymax if h.stop is None else h.stop
_xmax = xmax if w.stop is None else w.stop
window = Box(_ymin, _xmin, _ymax, _xmax)
window, _ = parse_array_slices(key, extent=self.extent, dims=2)
return self[window]
Expand Up @@ -6,6 +6,7 @@
from rastervision.core.data.label import ObjectDetectionLabels
from rastervision.core.data.label_source import LabelSource
from rastervision.core.data.vector_source import VectorSource
from rastervision.core.data.utils import parse_array_slices

if TYPE_CHECKING:
from rastervision.core.data import CRSTransformer
Expand Down Expand Up @@ -88,34 +89,8 @@ class labels for each of the boxes.
npboxes = labels.get_npboxes()
npboxes = ObjectDetectionLabels.global_to_local(npboxes, window)
return npboxes, class_ids, 'yxyx'
elif isinstance(key, slice):
key = [key]
elif isinstance(key, tuple):
pass
else:
raise TypeError('Unsupported key type.')
slices = list(key)
assert 1 <= len(slices) <= 2
assert all(s is not None for s in slices)
assert isinstance(slices[0], slice)
if len(slices) == 1:
h, = slices
w = slice(None, None)
else:
assert isinstance(slices[1], slice)
h, w = slices

if any(x is not None and x < 0
for x in [h.start, h.stop, w.start, w.stop]):
raise NotImplementedError()

ymin, xmin, ymax, xmax = self.extent
_ymin = 0 if h.start is None else h.start
_xmin = 0 if w.start is None else w.start
_ymax = ymax if h.stop is None else h.stop
_xmax = xmax if w.stop is None else w.stop
window = Box(_ymin, _xmin, _ymax, _xmax)

window, (h, w) = parse_array_slices(key, extent=self.extent, dims=2)
npboxes, class_ids, fmt = self[window]

# rescale if steps specified
Expand Down
Expand Up @@ -2,6 +2,7 @@
from abc import ABC, abstractmethod, abstractproperty

from rastervision.core.box import Box
from rastervision.core.data.utils import parse_array_slices

if TYPE_CHECKING:
from rastervision.core.data import (CRSTransformer, RasterTransformer)
Expand Down Expand Up @@ -103,38 +104,13 @@ def _get_chip(self, window: 'Box') -> 'np.ndarray':
def __getitem__(self, key: Any) -> 'np.ndarray':
if isinstance(key, Box):
return self.get_chip(key)
elif isinstance(key, slice):
key = [key]
elif isinstance(key, tuple):
pass
else:
raise TypeError('Unsupported key type.')
slices = list(key)

assert 1 <= len(slices) <= 2
assert all(s is not None for s in slices)
assert isinstance(slices[0], slice)
if len(slices) == 1:
h, = slices
w = slice(None, None)
else:
assert isinstance(slices[1], slice)
h, w = slices

if any(x is not None and x < 0
for x in [h.start, h.stop, w.start, w.stop]):
raise NotImplementedError()

ymin, xmin, ymax, xmax = self.extent
_ymin = 0 if h.start is None else h.start
_xmin = 0 if w.start is None else w.start
_ymax = ymax if h.stop is None else h.stop
_xmax = xmax if w.stop is None else w.stop
window = Box(_ymin, _xmin, _ymax, _xmax)

window, (h, w, c) = parse_array_slices(key, extent=self.extent, dims=3)
chip = self.get_chip(window)
if h.step is not None or w.step is not None:
chip = chip[::h.step, ::w.step]
chip = chip[..., c]

return chip

def get_chip(self, window: 'Box') -> 'np.ndarray':
Expand Down
Expand Up @@ -11,7 +11,7 @@
from rastervision.core.box import Box
from rastervision.core.data.crs_transformer import RasterioCRSTransformer
from rastervision.core.data.raster_source import RasterSource
from rastervision.core.data.utils import listify_uris
from rastervision.core.data.utils import listify_uris, parse_array_slices

if TYPE_CHECKING:
from rasterio.io import DatasetReader
Expand Down Expand Up @@ -349,38 +349,8 @@ def get_chip(self,
def __getitem__(self, key: Any) -> 'np.ndarray':
if isinstance(key, Box):
return self.get_chip(key)
elif isinstance(key, slice):
key = [key]
elif isinstance(key, tuple):
pass
else:
raise TypeError('Unsupported key type.')

slices = list(key)
assert 1 <= len(slices) <= 3
assert all(s is not None for s in slices)
assert isinstance(slices[0], slice)
if len(slices) == 1:
h, = slices
w = slice(None, None)
c = None
elif len(slices) == 2:
assert isinstance(slices[1], slice)
h, w = slices
c = None
else:
h, w, c = slices

if any(x is not None and x < 0
for x in [h.start, h.stop, w.start, w.stop]):
raise NotImplementedError()

ymin, xmin, ymax, xmax = self.extent
_ymin = 0 if h.start is None else h.start
_xmin = 0 if w.start is None else w.start
_ymax = ymax if h.stop is None else h.stop
_xmax = xmax if w.stop is None else w.stop
window = Box(_ymin, _xmin, _ymax, _xmax)

window, (h, w, c) = parse_array_slices(key, extent=self.extent, dims=3)

out_shape = None
if h.step is not None or w.step is not None:
Expand Down
87 changes: 86 additions & 1 deletion rastervision_core/rastervision/core/data/utils/misc.py
@@ -1,9 +1,11 @@
from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union
import logging

import numpy as np
from PIL import ImageColor

from rastervision.core.box import Box

if TYPE_CHECKING:
from rastervision.core.data import (RasterSource, LabelSource, LabelStore)

Expand Down Expand Up @@ -121,3 +123,86 @@ def match_extents(raster_source: 'RasterSource',
# set LabelStore extent to RasterSource extent
extent_label_pixel = crs_tf_label.map_to_pixel(extent_img_map)
label_source.set_extent(extent_label_pixel)


def parse_array_slices(key: Union[tuple, slice], extent: 'Box',
                       dims: int = 2) -> Tuple['Box', List[Optional[Any]]]:
    """Parse multi-dim array-indexing inputs into a Box and slices.

    The first two dims are treated as height and width and must be indexed
    by slices (or left out). Dims that are not specified, either because
    the key is shorter than ``dims`` or because of an Ellipsis, are filled
    in: h and w become full slices, and any remaining dim is returned as
    None.

    Args:
        key (Union[tuple, slice]): Input to __getitem__.
        extent (Box): Extent of the raster/label source being indexed. Only
            its size is read, to fill in missing slice stops.
        dims (int, optional): Total available indexable dims. Defaults to 2.

    Raises:
        NotImplementedError: If not (1 <= dims <= 3).
        TypeError: If key is not a slice or tuple.
        IndexError: If key is empty or has more than ``dims`` indices (an
            Ellipsis does not count as an index).
        TypeError: If the index for any of the dims is None.
        ValueError: If more than one Ellipsis ("...") in the input.
        ValueError: If h and w indices (first 2 dims) are not slices.
        NotImplementedError: If input contains negative values.
        ValueError: If the h or w step is zero.

    Returns:
        Tuple[Box, list]: A Box representing the h and w slices and a list
        containing slices/index-values for all the dims.
    """
    if not (1 <= dims <= 3):
        raise NotImplementedError(
            f'Only 1 <= dims <= 3 is supported. Got dims={dims}.')

    if isinstance(key, slice):
        input_slices = [key]
    elif isinstance(key, tuple):
        input_slices = list(key)
    else:
        raise TypeError('Unsupported key type.')

    # Locate ellipses by identity: "in"/"count"/"index" compare with "==",
    # which is elementwise (and so not a valid truth value) when an index
    # is a multi-element numpy array.
    ellipsis_idxs = [i for i, s in enumerate(input_slices) if s is Ellipsis]
    # An Ellipsis is a placeholder, not an index into a dim.
    num_explicit = len(input_slices) - len(ellipsis_idxs)

    if not input_slices:
        raise IndexError('At least one index is required.')
    if num_explicit > dims:
        raise IndexError(f'Too many indices for {dims}-dimensional source.')
    if any(s is None for s in input_slices):
        raise TypeError('None is not a valid index.')
    if len(ellipsis_idxs) > 1:
        raise ValueError('Only one ellipsis is allowed.')

    filler_slices = [None] * (dims - num_explicit)
    if ellipsis_idxs:
        # List slicing handles an ellipsis at the start, middle or end.
        idx = ellipsis_idxs[0]
        dim_slices = (
            input_slices[:idx] + filler_slices + input_slices[(idx + 1):])
    else:
        dim_slices = input_slices + filler_slices

    h = dim_slices[0]
    # With dims == 1 there is no w index; treat it as unspecified.
    w = dim_slices[1] if dims > 1 else None
    if h is None:
        h = slice(None, None)
    if w is None:
        w = slice(None, None)
    if not (isinstance(h, slice) and isinstance(w, slice)):
        raise ValueError('h and w indices (first 2 dims) must be slices.')

    if any(x is not None and x < 0
           for x in [h.start, h.stop, h.step, w.start, w.stop, w.step]):
        raise NotImplementedError(
            'Negative indices are currently not supported.')
    if h.step == 0 or w.step == 0:
        raise ValueError('Slice step cannot be zero.')

    # slices with missing endpoints get expanded to the extent limits
    H, W = extent.size
    _ymin = 0 if h.start is None else h.start
    _xmin = 0 if w.start is None else w.start
    _ymax = H if h.stop is None else h.stop
    _xmax = W if w.stop is None else w.stop
    window = Box(_ymin, _xmin, _ymax, _xmax)

    # Return exactly one entry per dim: the normalized h (and w) slices
    # followed by whatever was given for the remaining dims.
    hw_slices = list(window.to_slices(h.step, w.step))
    dim_slices = hw_slices[:dims] + dim_slices[2:]

    return window, dim_slices

0 comments on commit a024ee1

Please sign in to comment.