From 2b9e7909e1019bf439b25b492f3469bed6fdcc57 Mon Sep 17 00:00:00 2001 From: James McClain Date: Wed, 12 Dec 2018 14:01:36 -0500 Subject: [PATCH] Method for Transform --- rastervision/data/crs_transformer/crs_transformer.py | 3 +++ .../data/crs_transformer/rasterio_crs_transformer.py | 3 +++ .../data/label_store/semantic_segmentation_raster_store.py | 6 +----- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/rastervision/data/crs_transformer/crs_transformer.py b/rastervision/data/crs_transformer/crs_transformer.py index 09e8a24324..2213f99fcd 100644 --- a/rastervision/data/crs_transformer/crs_transformer.py +++ b/rastervision/data/crs_transformer/crs_transformer.py @@ -34,3 +34,6 @@ def get_image_crs(self): def get_map_crs(self): return self.map_crs + + def get_affine_transform(self): + raise NotImplementedError() diff --git a/rastervision/data/crs_transformer/rasterio_crs_transformer.py b/rastervision/data/crs_transformer/rasterio_crs_transformer.py index 7ec946a8b3..e8ae5b7b81 100644 --- a/rastervision/data/crs_transformer/rasterio_crs_transformer.py +++ b/rastervision/data/crs_transformer/rasterio_crs_transformer.py @@ -59,3 +59,6 @@ def from_dataset(cls, dataset, map_crs='epsg:4326'): transform = dataset.transform image_crs = dataset.crs['init'] return cls(transform, image_crs, map_crs) + + def get_affine_transform(self): + return self.transform diff --git a/rastervision/data/label_store/semantic_segmentation_raster_store.py b/rastervision/data/label_store/semantic_segmentation_raster_store.py index f0f23c72f0..f5acaf9fb4 100644 --- a/rastervision/data/label_store/semantic_segmentation_raster_store.py +++ b/rastervision/data/label_store/semantic_segmentation_raster_store.py @@ -119,7 +119,6 @@ def save(self, labels): if self.vector_output: import mask_to_polygons.vectorification as m2p - from affine import Affine for vo in self.vector_output: uri = vo['uri'] @@ -128,10 +127,7 @@ def save(self, labels): class_mask = np.array(mask == class_id, dtype=np.uint8) local_geojson_path = get_local_path(uri, self.tmp_dir) - if isinstance(self.crs_transformer, RasterioCRSTransformer): - transform = self.crs_transformer.transform - else: - transform = Affine.identity() + transform = self.crs_transformer.get_affine_transform() if uri and mode == 'buildings': geojson = m2p.geojson_from_mask(class_mask, transform)