diff --git a/rastervision_core/rastervision/core/data/scene.py b/rastervision_core/rastervision/core/data/scene.py index f125942c4..a1f8cd8d1 100644 --- a/rastervision_core/rastervision/core/data/scene.py +++ b/rastervision_core/rastervision/core/data/scene.py @@ -1,5 +1,7 @@ from typing import TYPE_CHECKING, Any, Optional, Tuple +from rastervision.core.data.utils import match_extents + if TYPE_CHECKING: from rastervision.core.box import Box from rastervision.core.data import (RasterSource, LabelSource, LabelStore) @@ -14,19 +16,30 @@ def __init__(self, label_source: Optional['LabelSource'] = None, label_store: Optional['LabelStore'] = None, aoi_polygons: Optional[list] = None): - """Construct a new Scene. + """Constructor. + + During initialization, ``Scene`` attempts to set the extents of the + given ``label_source`` and the ``label_store`` to be identical to the + extent of the given ``raster_source``. Args: - id: ID for this scene - raster_source: RasterSource for this scene - ground_truth_label_store: optional LabelSource - label_store: optional LabelStore - aoi: Optional list of AOI polygons in pixel coordinates + id: ID for this scene. + raster_source: Source of imagery for this scene. + label_source: Source of labels for this scene. + label_store: Store of predictions for this scene. + aoi: Optional list of AOI polygons in pixel coordinates. """ + if label_source is not None: + match_extents(raster_source, label_source) + + if label_store is not None: + match_extents(raster_source, label_store) + self.id = id self.raster_source = raster_source self.label_source = label_source self.label_store = label_store + if aoi_polygons is None: self.aoi_polygons = [] else: diff --git a/rastervision_core/rastervision/core/data/utils/misc.py b/rastervision_core/rastervision/core/data/utils/misc.py index 485acaca1..d56fb116a 100644 --- a/rastervision_core/rastervision/core/data/utils/misc.py +++ b/rastervision_core/rastervision/core/data/utils/misc.py @@ -1,8 +1,14 @@ -from typing import List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, Union +import logging import numpy as np from PIL import ImageColor +if TYPE_CHECKING: + from rastervision.core.data import (RasterSource, LabelSource, LabelStore) + +log = logging.getLogger(__name__) + def color_to_triple( color: Optional[Union[str, Sequence]] = None) -> Tuple[int, int, int]: @@ -87,3 +93,31 @@ def listify_uris(uris: Union[str, List[str]]) -> List[str]: else: raise TypeError(f'Expected str or List[str], but got {type(uris)}.') return uris + + +def match_extents(raster_source: 'RasterSource', + label_source: Union['LabelSource', 'LabelStore']) -> None: + """Set ``label_souce`` extent equal to ``raster_source`` extent. + + Logs a warning if ``raster_source`` and ``label_source`` extents don't + intersect when converted to map coordinates. + + Args: + raster_source (RasterSource): Source of imagery for a scene. + label_source (Union[LabelSource, LabelStore]): Source of labels for a + scene. Can be a ``LabelStore``. + """ + crs_tf_img = raster_source.crs_transformer + crs_tf_label = label_source.crs_transformer + extent_img_map = crs_tf_img.pixel_to_map(raster_source.extent) + if label_source.extent is not None: + extent_label_map = crs_tf_label.pixel_to_map(label_source.extent) + if not extent_img_map.intersects(extent_label_map): + rs_cls = type(raster_source).__name__ + ls_cls = type(label_source).__name__ + log.warning(f'{rs_cls} extent ({extent_img_map}) does ' + f'not intersect with {ls_cls} extent ' + f'({extent_label_map}).') + # set LabelStore extent to RasterSource extent + extent_label_pixel = crs_tf_label.map_to_pixel(extent_img_map) + label_source.set_extent(extent_label_pixel) diff --git a/tests/core/data/utils/test_misc.py b/tests/core/data/utils/test_misc.py new file mode 100644 index 000000000..337df0e25 --- /dev/null +++ b/tests/core/data/utils/test_misc.py @@ -0,0 +1,94 @@ +import unittest +from os.path import join + +from rastervision.pipeline.file_system.utils import get_tmp_dir, json_to_file +from rastervision.core.box import Box +from rastervision.core.data import ( + ClassConfig, GeoJSONVectorSource, RasterioSource, + ChipClassificationLabelSource, ChipClassificationLabelSourceConfig, + ChipClassificationGeoJSONStore, ObjectDetectionLabelSource, + ObjectDetectionGeoJSONStore, SemanticSegmentationLabelSource, + SemanticSegmentationLabelStore) +from rastervision.core.data.utils.geojson import geoms_to_geojson +from rastervision.core.data.utils.misc import (match_extents) + +from tests import data_file_path + + +class TestMatchExtents(unittest.TestCase): + def setUp(self) -> None: + self.class_config = ClassConfig(names=['class_1']) + self.rs_path = data_file_path( + 'multi_raster_source/const_100_600x600.tiff') + self.extent_rs = Box(4, 4, 8, 8) + self.raster_source = RasterioSource( + self.rs_path, extent=self.extent_rs) + self.crs_tf = self.raster_source.crs_transformer + + self.extent_ls = Box(0, 0, 12, 12) + geoms = [b.to_shapely() for b in self.extent_ls.get_windows(2, 2)] + geoms = [self.crs_tf.pixel_to_map(g) for g in geoms] + properties = [dict(class_id=0) for _ in geoms] + geojson = geoms_to_geojson(geoms, properties) + self._tmp_dir = get_tmp_dir() + self.tmp_dir = self._tmp_dir.name + uri = join(self.tmp_dir, 'labels.json') + json_to_file(geojson, uri) + self.vector_source = GeoJSONVectorSource( + uri, self.raster_source.crs_transformer, ignore_crs_field=True) + + def tearDown(self) -> None: + self._tmp_dir.cleanup() + + def test_cc_label_source(self): + label_source = ChipClassificationLabelSource( + ChipClassificationLabelSourceConfig(), + vector_source=self.vector_source) + self.assertEqual(label_source.extent, self.extent_ls) + match_extents(self.raster_source, label_source) + self.assertEqual(label_source.extent, self.raster_source.extent) + + def test_cc_label_store(self): + uri = join(self.tmp_dir, 'cc_labels.json') + label_store = ChipClassificationGeoJSONStore(uri, self.class_config, + self.crs_tf) + self.assertIsNone(label_store.extent) + match_extents(self.raster_source, label_store) + self.assertEqual(label_store.extent, self.raster_source.extent) + + def test_od_label_source(self): + label_source = ObjectDetectionLabelSource( + vector_source=self.vector_source) + self.assertEqual(label_source.extent, self.extent_ls) + match_extents(self.raster_source, label_source) + self.assertEqual(label_source.extent, self.raster_source.extent) + + def test_od_label_store(self): + uri = join(self.tmp_dir, 'od_labels.json') + label_store = ObjectDetectionGeoJSONStore(uri, self.class_config, + self.crs_tf) + self.assertIsNone(label_store.extent) + match_extents(self.raster_source, label_store) + self.assertEqual(label_store.extent, self.raster_source.extent) + + def test_ss_label_source(self): + label_source = SemanticSegmentationLabelSource( + self.raster_source, class_config=self.class_config) + self.assertEqual(label_source.extent, self.extent_rs) + match_extents(self.raster_source, label_source) + self.assertEqual(label_source.extent, self.raster_source.extent) + + def test_ss_label_store(self): + uri = join(self.tmp_dir, 'ss_labels') + label_store = SemanticSegmentationLabelStore( + uri, + extent=self.extent_ls, + crs_transformer=self.crs_tf, + class_config=self.class_config) + self.assertEqual(label_store.extent, self.extent_ls) + match_extents(self.raster_source, label_store) + self.assertEqual(label_store.extent, self.raster_source.extent) + + +if __name__ == '__main__': + unittest.main()