Skip to content

Commit

Permalink
Merge pull request #1899 from AdeelH/rs_cfg_bbox
Browse files Browse the repository at this point in the history
Fix inconsistent handling of `RasterSourceConfig.bbox`'s type
  • Loading branch information
AdeelH committed Sep 6, 2023
2 parents 6249e01 + 9dc948f commit 1c9cc25
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 21 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pydantic import conint, conlist

from rastervision.pipeline.config import (Field, register_config, validator)
from rastervision.core.box import Box
from rastervision.core.data.raster_source import (RasterSourceConfig,
MultiRasterSource)

Expand Down Expand Up @@ -66,6 +67,8 @@ def build(self, tmp_dir: str,
built_raster_sources = [
rs.build(tmp_dir, use_transformers) for rs in self.raster_sources
]
bbox = Box(*self.bbox) if self.bbox is not None else None

if self.temporal:
from rastervision.core.data.raster_source import (
TemporalMultiRasterSource)
Expand All @@ -74,15 +77,15 @@ def build(self, tmp_dir: str,
primary_source_idx=self.primary_source_idx,
force_same_dtype=self.force_same_dtype,
raster_transformers=raster_transformers,
bbox=self.bbox)
bbox=bbox)
else:
multi_raster_source = MultiRasterSource(
raster_sources=built_raster_sources,
primary_source_idx=self.primary_source_idx,
force_same_dtype=self.force_same_dtype,
channel_order=self.channel_order,
raster_transformers=raster_transformers,
bbox=self.bbox)
bbox=bbox)
return multi_raster_source

def update(self, pipeline=None, scene=None):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import TYPE_CHECKING, List, Optional, Tuple

from rastervision.core.box import Box
from rastervision.pipeline.config import (Config, register_config, Field,
validator, ConfigError)
ConfigError)
from rastervision.core.data.raster_transformer import RasterTransformerConfig

if TYPE_CHECKING:
Expand Down Expand Up @@ -54,9 +53,3 @@ def update(self,
scene: Optional['SceneConfig'] = None) -> None:
for t in self.transformers:
t.update(pipeline, scene)

@validator('bbox')
def validate_bbox(cls, v):
if v is None:
return None
return Box(*v)
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List, Union

from rastervision.core.box import Box
from rastervision.core.data.raster_source import RasterSourceConfig, RasterioSource
from rastervision.pipeline.config import ConfigError, Field, register_config

Expand Down Expand Up @@ -37,11 +38,11 @@ class RasterioSourceConfig(RasterSourceConfig):
def build(self, tmp_dir, use_transformers=True):
raster_transformers = ([rt.build() for rt in self.transformers]
if use_transformers else [])

bbox = Box(*self.bbox) if self.bbox is not None else None
return RasterioSource(
uris=self.uris,
raster_transformers=raster_transformers,
tmp_dir=tmp_dir,
allow_streaming=self.allow_streaming,
channel_order=self.channel_order,
bbox=self.bbox)
bbox=bbox)
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
from typing import TYPE_CHECKING, Optional

from rastervision.core.data.raster_source import (RasterizedSource)
from rastervision.core.data.vector_source import (VectorSourceConfig)
from rastervision.core.data.vector_transformer import (
ClassInferenceTransformerConfig, BufferTransformerConfig)
from rastervision.pipeline.config import (register_config, Config, Field,
validator)

if TYPE_CHECKING:
from rastervision.core.box import Box
from rastervision.core.data import ClassConfig, CRSTransformer


@register_config('rasterizer')
class RasterizerConfig(Config):
Expand Down Expand Up @@ -55,7 +61,10 @@ def update(self, pipeline=None, scene=None):
super().update(pipeline, scene)
self.vector_source.update(pipeline, scene)

def build(self, class_config, crs_transformer, bbox) -> RasterizedSource:
def build(self,
class_config: 'ClassConfig',
crs_transformer: 'CRSTransformer',
bbox: Optional['Box'] = None) -> RasterizedSource:
vector_source = self.vector_source.build(class_config, crs_transformer)
return RasterizedSource(
vector_source=vector_source,
Expand Down
9 changes: 5 additions & 4 deletions tests/core/data/raster_source/test_multi_raster_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ def test_build(self):
cfg = make_cfg()
self.assertNoError(lambda: cfg.build(tmp_dir=get_tmp_dir()))

def test_build_with_bbox(self):
cfg = make_cfg(bbox=(0, 0, 1, 1))
rs = cfg.build(tmp_dir=get_tmp_dir())
self.assertEqual(rs.bbox, Box(0, 0, 1, 1))

def test_build_temporal(self):
cfg = make_cfg(temporal=True)
rs = cfg.build(tmp_dir=get_tmp_dir())
Expand Down Expand Up @@ -128,10 +133,6 @@ def test_bbox(self):
self.assertEqual(rs.bbox, Box(0, 0, 256, 256))
self.assertEqual(rs.extent, Box(0, 0, 256, 256))

# test validators
cfg = make_cfg('small-rgb-tile.tif', bbox=(64, 64, 192, 192))
self.assertIsInstance(cfg.bbox, Box)

# /w user specified extent
cfg_crop = make_cfg('small-rgb-tile.tif', bbox=(64, 64, 192, 192))
rs_crop = cfg_crop.build(tmp_dir=self.tmp_dir)
Expand Down
24 changes: 20 additions & 4 deletions tests/core/data/raster_source/test_rasterio_source.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Callable
import unittest
from os.path import join
from tempfile import NamedTemporaryFile
Expand All @@ -16,6 +17,25 @@
from tests import data_file_path


class TestRasterioSourceConfig(unittest.TestCase):
def assertNoError(self, fn: Callable, msg: str = ''):
try:
fn()
except Exception:
self.fail(msg)

def test_build(self):
img_path = data_file_path('small-rgb-tile.tif')
cfg = RasterioSourceConfig(uris=[img_path])
self.assertNoError(lambda: cfg.build(tmp_dir=get_tmp_dir()))

def test_build_with_bbox(self):
img_path = data_file_path('small-rgb-tile.tif')
cfg = RasterioSourceConfig(uris=[img_path], bbox=(0, 0, 1, 1))
rs = cfg.build(tmp_dir=get_tmp_dir())
self.assertEqual(rs.bbox, Box(0, 0, 1, 1))


class TestRasterioSource(unittest.TestCase):
def setUp(self):
self.tmp_dir_obj = get_tmp_dir()
Expand Down Expand Up @@ -226,10 +246,6 @@ def test_bbox(self):
self.assertEqual(rs_crop.bbox, Box(64, 64, 192, 192))
self.assertEqual(rs_crop.extent, Box(0, 0, 128, 128))

# test validators
rs_cfg = RasterioSourceConfig(uris=[img_path], bbox=(0, 0, 1, 1))
self.assertIsInstance(rs_cfg.bbox, Box)

def test_extent_overflow(self):
arr = np.ones((100, 100), dtype=np.uint8)
with NamedTemporaryFile('wb') as fp:
Expand Down

0 comments on commit 1c9cc25

Please sign in to comment.