Skip to content

Commit

Permalink
ensure raster and label source/store extents match up in scene
Browse files Browse the repository at this point in the history
  • Loading branch information
AdeelH committed Mar 28, 2023
1 parent cb2b90b commit f29f4d0
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 7 deletions.
25 changes: 19 additions & 6 deletions rastervision_core/rastervision/core/data/scene.py
@@ -1,5 +1,7 @@
from typing import TYPE_CHECKING, Any, Optional, Tuple

from rastervision.core.data.utils import match_extents

if TYPE_CHECKING:
from rastervision.core.box import Box
from rastervision.core.data import (RasterSource, LabelSource, LabelStore)
Expand All @@ -14,19 +16,30 @@ def __init__(self,
label_source: Optional['LabelSource'] = None,
label_store: Optional['LabelStore'] = None,
aoi_polygons: Optional[list] = None):
"""Construct a new Scene.
"""Constructor.
During initialization, ``Scene`` attempts to set the extents of the
given ``label_source`` and the ``label_store`` to be identical to the
extent of the given ``raster_source``.
Args:
id: ID for this scene
raster_source: RasterSource for this scene
ground_truth_label_store: optional LabelSource
label_store: optional LabelStore
aoi: Optional list of AOI polygons in pixel coordinates
id: ID for this scene.
raster_source: Source of imagery for this scene.
label_source: Source of labels for this scene.
label_store: Store of predictions for this scene.
aoi: Optional list of AOI polygons in pixel coordinates.
"""
if label_source is not None:
match_extents(raster_source, label_source)

if label_store is not None:
match_extents(raster_source, label_store)

self.id = id
self.raster_source = raster_source
self.label_source = label_source
self.label_store = label_store

if aoi_polygons is None:
self.aoi_polygons = []
else:
Expand Down
36 changes: 35 additions & 1 deletion rastervision_core/rastervision/core/data/utils/misc.py
@@ -1,8 +1,14 @@
from typing import List, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, Union
import logging

import numpy as np
from PIL import ImageColor

if TYPE_CHECKING:
from rastervision.core.data import (RasterSource, LabelSource, LabelStore)

log = logging.getLogger(__name__)


def color_to_triple(
color: Optional[Union[str, Sequence]] = None) -> Tuple[int, int, int]:
Expand Down Expand Up @@ -87,3 +93,31 @@ def listify_uris(uris: Union[str, List[str]]) -> List[str]:
else:
raise TypeError(f'Expected str or List[str], but got {type(uris)}.')
return uris


def match_extents(raster_source: 'RasterSource',
label_source: Union['LabelSource', 'LabelStore']) -> None:
"""Set ``label_souce`` extent equal to ``raster_source`` extent.
Logs a warning if ``raster_source`` and ``label_source`` extents don't
intersect when converted to map coordinates.
Args:
raster_source (RasterSource): Source of imagery for a scene.
label_source (Union[LabelSource, LabelStore]): Source of labels for a
scene. Can be a ``LabelStore``.
"""
crs_tf_img = raster_source.crs_transformer
crs_tf_label = label_source.crs_transformer
extent_img_map = crs_tf_img.pixel_to_map(raster_source.extent)
if label_source.extent is not None:
extent_label_map = crs_tf_label.pixel_to_map(label_source.extent)
if not extent_img_map.intersects(extent_label_map):
rs_cls = type(raster_source).__name__
ls_cls = type(label_source).__name__
log.warning(f'{rs_cls} extent ({extent_img_map}) does '
f'not intersect with {ls_cls} extent '
f'({extent_label_map}).')
# set LabelStore extent to RasterSource extent
extent_label_pixel = crs_tf_label.map_to_pixel(extent_img_map)
label_source.set_extent(extent_label_pixel)
94 changes: 94 additions & 0 deletions tests/core/data/utils/test_misc.py
@@ -0,0 +1,94 @@
import unittest
from os.path import join

from rastervision.pipeline.file_system.utils import get_tmp_dir, json_to_file
from rastervision.core.box import Box
from rastervision.core.data import (
ClassConfig, GeoJSONVectorSource, RasterioSource,
ChipClassificationLabelSource, ChipClassificationLabelSourceConfig,
ChipClassificationGeoJSONStore, ObjectDetectionLabelSource,
ObjectDetectionGeoJSONStore, SemanticSegmentationLabelSource,
SemanticSegmentationLabelStore)
from rastervision.core.data.utils.geojson import geoms_to_geojson
from rastervision.core.data.utils.misc import (match_extents)

from tests import data_file_path


class TestMatchExtents(unittest.TestCase):
def setUp(self) -> None:
self.class_config = ClassConfig(names=['class_1'])
self.rs_path = data_file_path(
'multi_raster_source/const_100_600x600.tiff')
self.extent_rs = Box(4, 4, 8, 8)
self.raster_source = RasterioSource(
self.rs_path, extent=self.extent_rs)
self.crs_tf = self.raster_source.crs_transformer

self.extent_ls = Box(0, 0, 12, 12)
geoms = [b.to_shapely() for b in self.extent_ls.get_windows(2, 2)]
geoms = [self.crs_tf.pixel_to_map(g) for g in geoms]
properties = [dict(class_id=0) for _ in geoms]
geojson = geoms_to_geojson(geoms, properties)
self._tmp_dir = get_tmp_dir()
self.tmp_dir = self._tmp_dir.name
uri = join(self.tmp_dir, 'labels.json')
json_to_file(geojson, uri)
self.vector_source = GeoJSONVectorSource(
uri, self.raster_source.crs_transformer, ignore_crs_field=True)

def tearDown(self) -> None:
self._tmp_dir.cleanup()

def test_cc_label_source(self):
label_source = ChipClassificationLabelSource(
ChipClassificationLabelSourceConfig(),
vector_source=self.vector_source)
self.assertEqual(label_source.extent, self.extent_ls)
match_extents(self.raster_source, label_source)
self.assertEqual(label_source.extent, self.raster_source.extent)

def test_cc_label_store(self):
uri = join(self.tmp_dir, 'cc_labels.json')
label_store = ChipClassificationGeoJSONStore(uri, self.class_config,
self.crs_tf)
self.assertIsNone(label_store.extent)
match_extents(self.raster_source, label_store)
self.assertEqual(label_store.extent, self.raster_source.extent)

def test_od_label_source(self):
label_source = ObjectDetectionLabelSource(
vector_source=self.vector_source)
self.assertEqual(label_source.extent, self.extent_ls)
match_extents(self.raster_source, label_source)
self.assertEqual(label_source.extent, self.raster_source.extent)

def test_od_label_store(self):
uri = join(self.tmp_dir, 'od_labels.json')
label_store = ObjectDetectionGeoJSONStore(uri, self.class_config,
self.crs_tf)
self.assertIsNone(label_store.extent)
match_extents(self.raster_source, label_store)
self.assertEqual(label_store.extent, self.raster_source.extent)

def test_ss_label_source(self):
label_source = SemanticSegmentationLabelSource(
self.raster_source, class_config=self.class_config)
self.assertEqual(label_source.extent, self.extent_rs)
match_extents(self.raster_source, label_source)
self.assertEqual(label_source.extent, self.raster_source.extent)

def test_ss_label_store(self):
uri = join(self.tmp_dir, 'ss_labels')
label_store = SemanticSegmentationLabelStore(
uri,
extent=self.extent_ls,
crs_transformer=self.crs_tf,
class_config=self.class_config)
self.assertEqual(label_store.extent, self.extent_ls)
match_extents(self.raster_source, label_store)
self.assertEqual(label_store.extent, self.raster_source.extent)


if __name__ == '__main__':
unittest.main()

0 comments on commit f29f4d0

Please sign in to comment.