Skip to content

Commit

Permalink
Merge pull request #1740 from AdeelH/label_source_extent
Browse files Browse the repository at this point in the history
Ensure `RasterSource` and `LabelSource` extents match up in `Scene`
  • Loading branch information
AdeelH committed Apr 4, 2023
2 parents 0889207 + f29f4d0 commit 279184c
Show file tree
Hide file tree
Showing 26 changed files with 484 additions and 149 deletions.
56 changes: 35 additions & 21 deletions rastervision_core/rastervision/core/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ def __init__(self, ymin, xmin, ymax, xmax):
xmax: maximum x value
"""
ymin, ymax = sorted((ymin, ymax))
xmin, xmax = sorted((xmin, xmax))
self.ymin = ymin
self.xmin = xmin
self.ymax = ymax
Expand Down Expand Up @@ -69,6 +67,12 @@ def area(self) -> int:
"""Return area of Box."""
return self.height * self.width

def normalize(self) -> 'Box':
"""Ensure ymin <= ymax and xmin <= xmax."""
ymin, ymax = sorted((self.ymin, self.ymax))
xmin, xmax = sorted((self.xmin, self.xmax))
return Box(ymin, xmin, ymax, xmax)

def rasterio_format(self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
"""Return Box in Rasterio format: ((ymin, ymax), (xmin, xmax))."""
return ((self.ymin, self.ymax), (self.xmin, self.xmax))
Expand Down Expand Up @@ -150,15 +154,17 @@ def make_random_box_container(self, out_h: int, out_w: int) -> 'Box':
if out_w < self_w:
raise BoxSizeError('size of random container cannot be < width')

lb = self.ymin - (out_h - self_h)
ub = self.ymin
ymin = random.randint(int(lb), int(ub))
ymin, xmin, _, _ = self.normalize()

lb = ymin - (out_h - self_h)
ub = ymin
out_ymin = random.randint(int(lb), int(ub))

lb = self.xmin - (out_w - self_w)
ub = self.xmin
xmin = random.randint(int(lb), int(ub))
lb = xmin - (out_w - self_w)
ub = xmin
out_xmin = random.randint(int(lb), int(ub))

return Box(ymin, xmin, ymin + out_h, xmin + out_w)
return Box(out_ymin, out_xmin, out_ymin + out_h, out_xmin + out_w)

def make_random_square(self, size: int) -> 'Box':
"""Return new randomly positioned square Box that lies inside this Box.
Expand All @@ -173,12 +179,14 @@ def make_random_square(self, size: int) -> 'Box':
if size >= self.height:
raise BoxSizeError('size of random square cannot be >= height')

lb = self.ymin
ub = self.ymax - size
ymin, xmin, ymax, xmax = self.normalize()

lb = ymin
ub = ymax - size
rand_y = random.randint(int(lb), int(ub))

lb = self.xmin
ub = self.xmax - size
lb = xmin
ub = xmax - size
rand_x = random.randint(int(lb), int(ub))

return Box.make_square(rand_y, rand_x, size)
Expand All @@ -195,16 +203,22 @@ def intersection(self, other: 'Box') -> 'Box':
"""
if not self.intersects(other):
return Box(0, 0, 0, 0)
xmin = max(self.xmin, other.xmin)
ymin = max(self.ymin, other.ymin)
xmax = min(self.xmax, other.xmax)
ymax = min(self.ymax, other.ymax)

box1 = self.normalize()
box2 = other.normalize()

xmin = max(box1.xmin, box2.xmin)
ymin = max(box1.ymin, box2.ymin)
xmax = min(box1.xmax, box2.xmax)
ymax = min(box1.ymax, box2.ymax)
return Box(xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax)

def intersects(self, other: 'Box') -> bool:
if self.ymax <= other.ymin or self.ymin >= other.ymax:
box1 = self.normalize()
box2 = other.normalize()
if box1.ymax <= box2.ymin or box1.ymin >= box2.ymax:
return False
if self.xmax <= other.xmin or self.xmin >= other.xmax:
if box1.xmax <= box2.xmin or box1.xmin >= box2.xmax:
return False
return True

Expand Down Expand Up @@ -240,11 +254,11 @@ def to_points(self) -> np.ndarray:

def to_shapely(self) -> Polygon:
"""Convert to shapely Polygon."""
return Polygon.from_bounds(*(self.shapely_format()))
return Polygon.from_bounds(*self.shapely_format())

def to_rasterio(self) -> RioWindow:
"""Convert to a Rasterio Window."""
return RioWindow.from_slices(*self.to_slices())
return RioWindow.from_slices(*self.normalize().to_slices())

def to_slices(self) -> Tuple[slice, slice]:
"""Convert to slices: ymin:ymax, xmin:xmax"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional, overload, Tuple

from rasterio.windows import Window
from shapely.ops import transform
from shapely.geometry.base import BaseGeometry

Expand Down Expand Up @@ -38,10 +37,6 @@ def map_to_pixel(self, inp: Tuple['np.array', 'np.array']
def map_to_pixel(self, inp: Box) -> Box:
...

@overload
def map_to_pixel(self, inp: Window) -> Window:
...

@overload
def map_to_pixel(self, inp: BaseGeometry) -> BaseGeometry:
...
Expand All @@ -50,9 +45,8 @@ def map_to_pixel(self, inp):
"""Transform input from map to pixel coords.
Args:
inp: (x, y) tuple or Box or rasterio Window or shapely geometry in
pixel coordinates. If tuple, x and y can be single values or
array-like.
inp: (x, y) tuple or Box or shapely geometry in map coordinates.
If tuple, x and y can be single values or array-like.
Returns:
Coordinate-transformed input in the same format.
Expand All @@ -64,14 +58,6 @@ def map_to_pixel(self, inp):
xmax_tf, ymax_tf = self._map_to_pixel((xmax, ymax))
box_out = Box(ymin_tf, xmin_tf, ymax_tf, xmax_tf)
return box_out
elif isinstance(inp, Window):
window_in = inp
(ymin, ymax), (xmin, xmax) = window_in.toranges()
xmin_tf, ymin_tf = self._map_to_pixel((xmin, ymin))
xmax_tf, ymax_tf = self._map_to_pixel((xmax, ymax))
window_out = Window.from_slices(
slice(ymin_tf, ymax_tf), slice(xmin_tf, xmax_tf))
return window_out
elif isinstance(inp, BaseGeometry):
geom_in = inp
geom_out = transform(
Expand All @@ -80,8 +66,8 @@ def map_to_pixel(self, inp):
elif len(inp) == 2:
return self._map_to_pixel(inp)
else:
raise TypeError('Input must be 2-tuple or Box or rasterio Window '
'or shapely geometry.')
raise TypeError(
'Input must be 2-tuple or Box or shapely geometry.')

@overload
def pixel_to_map(self, inp: Tuple[float, float]) -> Tuple[float, float]:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import TYPE_CHECKING, Any, Iterable, List, Optional

import numpy as np
import geopandas as gpd

from rastervision.core.data.label import ChipClassificationLabels
Expand All @@ -9,7 +8,7 @@

if TYPE_CHECKING:
from rastervision.core.data import (ChipClassificationLabelSourceConfig,
VectorSource)
CRSTransformer, VectorSource)


def infer_cells(cells: List[Box], labels_df: gpd.GeoDataFrame,
Expand Down Expand Up @@ -110,16 +109,9 @@ def read_labels(labels_df: gpd.GeoDataFrame,
Returns:
ChipClassificationLabels
"""
boxes = [Box.from_shapely(g).to_int() for g in labels_df.geometry]
if extent is not None:
extent_polygon = extent.to_shapely()
labels_df = labels_df[labels_df.intersects(extent_polygon)]
boxes = np.array([
Box.from_shapely(c).to_int().shift_origin(extent)
for c in labels_df.geometry
])
else:
boxes = np.array(
[Box.from_shapely(c).to_int() for c in labels_df.geometry])
boxes = [b for b in boxes if b.intersects(extent)]
class_ids = labels_df['class_id'].astype(int)
cells_to_class_id = {
cell: (class_id, None)
Expand All @@ -143,26 +135,30 @@ class ChipClassificationLabelSource(LabelSource):
def __init__(self,
label_source_config: 'ChipClassificationLabelSourceConfig',
vector_source: 'VectorSource',
extent: Box = None,
extent: Optional[Box] = None,
lazy: bool = False):
"""Constructs a LabelSource for chip classification.
Args:
label_source_config (ChipClassificationLabelSourceConfig): Config
for class inference.
vector_source (VectorSource): Source of vector labels.
extent (Box): Box used to filter the labels by extent or
compute grid.
lazy (bool, optional): If True, labels are not populated during
extent (Optional[Box]): User-specified extent. If None, the full
extent of the vector source is used.
lazy (bool): If True, labels are not populated during
initialization. Defaults to False.
"""
self.cfg = label_source_config
self.vector_source = vector_source
if extent is None:
extent = vector_source.extent
self._extent = extent
self.lazy = lazy
self.labels_df = vector_source.get_dataframe()
self.validate_labels(self.labels_df)

self.labels = ChipClassificationLabels.make_empty()
if not lazy:
if not self.lazy:
self.populate_labels()

def populate_labels(self, cells: Optional[Iterable[Box]] = None) -> None:
Expand Down Expand Up @@ -245,3 +241,12 @@ def validate_labels(self, df: gpd.GeoDataFrame) -> None:
@property
def extent(self) -> Box:
return self._extent

@property
def crs_transformer(self) -> 'CRSTransformer':
return self.vector_source.crs_transformer

def set_extent(self, extent: 'Box') -> None:
self._extent = extent
if not self.lazy:
self.populate_labels()
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ def ensure_bg_class_id_if_inferring(cls, values: dict) -> dict:
'background_class_id is required if infer_cells=True.')
return values

def build(self, class_config, crs_transformer, extent=None, tmp_dir=None):
def build(self, class_config, crs_transformer, extent=None,
tmp_dir=None) -> ChipClassificationLabelSource:
if self.vector_source is None:
raise ValueError('Cannot build with a None vector_source.')
if self.infer_cells and self.cell_sz is None and not self.lazy:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
from typing import Any
from typing import TYPE_CHECKING, Any, Optional
from abc import ABC, abstractmethod, abstractproperty

from rastervision.core.box import Box
if TYPE_CHECKING:
from rastervision.core.box import Box
from rastervision.core.data import CRSTransformer, Labels


class LabelSource(ABC):
"""An interface for storage of labels for a scene.
An LabelSource is a read source of labels for a scene
A LabelSource is a read-only source of labels for a scene
that could be backed by a file, a database, an API, etc. The difference
between LabelSources and Labels can be understood by analogy to the
difference between a database and result sets queried from a database.
"""

@abstractmethod
def get_labels(self, window=None):
def get_labels(self, window: Optional['Box'] = None) -> 'Labels':
"""Return labels overlapping with window.
Args:
Expand All @@ -27,7 +29,22 @@ def get_labels(self, window=None):
pass

@abstractproperty
def extent(self) -> Box:
def extent(self) -> 'Box':
pass

@abstractproperty
def crs_transformer(self) -> 'CRSTransformer':
pass

@abstractmethod
def set_extent(self, extent: 'Box') -> None:
"""Set self.extent to the given value.
.. note:: This method is idempotent.
Args:
extent (Box): User-specified extent in pixel coordinates.
"""
pass

def __getitem__(self, key: Any) -> Any:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,26 @@
from typing import TYPE_CHECKING, Optional

from rastervision.pipeline.config import Config, register_config

if TYPE_CHECKING:
from rastervision.core.box import Box
from rastervision.core.data import (ClassConfig, CRSTransformer,
LabelSource, SceneConfig)
from rastervision.core.rv_pipeline import RVPipelineConfig


@register_config('label_source')
class LabelSourceConfig(Config):
"""Configure a :class:`.LabelSource`."""

def build(self, class_config, crs_transformer, extent, tmp_dir):
def build(self,
class_config: 'ClassConfig',
crs_transformer: 'CRSTransformer',
extent: Optional['Box'] = None,
tmp_dir: Optional[str] = None) -> 'LabelSource':
raise NotImplementedError()

def update(self, pipeline=None, scene=None):
def update(self,
pipeline: Optional['RVPipelineConfig'] = None,
scene: Optional['SceneConfig'] = None) -> None:
pass
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Optional, Tuple
from typing import TYPE_CHECKING, Any, Optional, Tuple

import numpy as np

Expand All @@ -7,20 +7,24 @@
from rastervision.core.data.label_source import LabelSource
from rastervision.core.data.vector_source import VectorSource

if TYPE_CHECKING:
from rastervision.core.data import CRSTransformer


class ObjectDetectionLabelSource(LabelSource):
"""A read-only label source for object detection."""

def __init__(self,
vector_source: VectorSource,
extent: Box,
extent: Optional[Box] = None,
ioa_thresh: Optional[float] = None,
clip: bool = False):
"""Constructor.
Args:
vector_source (VectorSource): A VectorSource.
extent (Box): Box used to filter the labels by extent.
extent (Optional[Box]): User-specified extent. If None, the full
extent of the vector source is used.
ioa_thresh (Optional[float], optional): IOA threshold to apply when
retrieving labels for a window. Defaults to None.
clip (bool, optional): Clip bounding boxes to window limits when
Expand All @@ -31,6 +35,8 @@ def __init__(self,
self.validate_geojson(geojson)
self.labels = ObjectDetectionLabels.from_geojson(
geojson, extent=extent)
if extent is None:
extent = vector_source.extent
self._extent = extent
self.ioa_thresh = ioa_thresh if ioa_thresh is not None else 1e-6
self.clip = clip
Expand Down Expand Up @@ -138,3 +144,10 @@ def validate_geojson(self, geojson: dict) -> None:
@property
def extent(self) -> Box:
return self._extent

@property
def crs_transformer(self) -> 'CRSTransformer':
return self.vector_source.crs_transformer

def set_extent(self, extent: 'Box') -> None:
self._extent = extent

0 comments on commit 279184c

Please sign in to comment.