Skip to content

Commit

Permalink
account for bbox when saving predictions (#1931)
Browse files Browse the repository at this point in the history
Co-authored-by: Adeel Hassan <ahassan@element84.com>
  • Loading branch information
AdeelH and AdeelH committed Sep 29, 2023
1 parent 1d23e46 commit ac69538
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 15 deletions.
Expand Up @@ -155,20 +155,26 @@ def extend(self, labels: 'ChipClassificationLabels') -> None:
for cell in labels.get_cells():
self.set_cell(cell, *labels[cell])

def save(self, uri: str, class_config: 'ClassConfig',
crs_transformer: 'CRSTransformer') -> None:
def save(self,
uri: str,
class_config: 'ClassConfig',
crs_transformer: 'CRSTransformer',
bbox: Optional[Box] = None) -> None:
"""Save labels as a GeoJSON file.
Args:
uri (str): URI of the output file.
class_config (ClassConfig): ClassConfig to map class IDs to names.
crs_transformer (CRSTransformer): CRSTransformer to convert from
pixel-coords to map-coords before saving.
bbox (Optional[Box]): User-specified crop of the extent. Must be
provided if the corresponding RasterSource has bbox != extent.
"""
from rastervision.core.data import ChipClassificationGeoJSONStore

label_store = ChipClassificationGeoJSONStore(
uri=uri,
class_config=class_config,
crs_transformer=crs_transformer)
crs_transformer=crs_transformer,
bbox=bbox)
label_store.save(self)
Expand Up @@ -294,20 +294,26 @@ def prune_duplicates(
score_threshold=score_thresh)
return ObjectDetectionLabels.from_boxlist(pruned_boxlist)

def save(self, uri: str, class_config: 'ClassConfig',
crs_transformer: 'CRSTransformer') -> None:
def save(self,
uri: str,
class_config: 'ClassConfig',
crs_transformer: 'CRSTransformer',
bbox: Optional[Box] = None) -> None:
"""Save labels as a GeoJSON file.
Args:
uri (str): URI of the output file.
class_config (ClassConfig): ClassConfig to map class IDs to names.
crs_transformer (CRSTransformer): CRSTransformer to convert from
pixel-coords to map-coords before saving.
bbox (Optional[Box]): User-specified crop of the extent. Must be
provided if the corresponding RasterSource has bbox != extent.
"""
from rastervision.core.data import ObjectDetectionGeoJSONStore

label_store = ObjectDetectionGeoJSONStore(
uri=uri,
class_config=class_config,
crs_transformer=crs_transformer)
crs_transformer=crs_transformer,
bbox=bbox)
label_store.save(self)
Expand Up @@ -357,6 +357,7 @@ def save(self,
uri: str,
crs_transformer: 'CRSTransformer',
class_config: 'ClassConfig',
bbox: Optional[Box] = None,
tmp_dir: Optional[str] = None,
save_as_rgb: bool = False,
raster_output: bool = True,
Expand All @@ -373,6 +374,8 @@ def save(self,
crs_transformer (CRSTransformer): CRSTransformer to configure CRS
and affine transform of the output GeoTiff.
class_config (ClassConfig): The ClassConfig.
bbox (Optional[Box]): User-specified crop of the extent. Must be
provided if the corresponding RasterSource has bbox != extent.
tmp_dir (Optional[str], optional): Temporary directory to use. If
None, will be auto-generated. Defaults to None.
save_as_rgb (bool, optional): If True, Saves labels as an RGB
Expand All @@ -397,6 +400,7 @@ def save(self,
uri=uri,
crs_transformer=crs_transformer,
class_config=class_config,
bbox=bbox,
tmp_dir=tmp_dir,
save_as_rgb=save_as_rgb,
discrete_output=raster_output,
Expand Down Expand Up @@ -529,6 +533,7 @@ def save(self,
uri: str,
crs_transformer: 'CRSTransformer',
class_config: 'ClassConfig',
bbox: Optional[Box] = None,
tmp_dir: Optional[str] = None,
save_as_rgb: bool = False,
discrete_output: bool = True,
Expand All @@ -547,6 +552,8 @@ def save(self,
crs_transformer (CRSTransformer): CRSTransformer to configure CRS
and affine transform of the output GeoTiff(s).
class_config (ClassConfig): The ClassConfig.
bbox (Optional[Box]): User-specified crop of the extent. Must be
provided if the corresponding RasterSource has bbox != extent.
tmp_dir (Optional[str], optional): Temporary directory to use. If
None, will be auto-generated. Defaults to None.
save_as_rgb (bool, optional): If True, saves labels as an RGB
Expand Down Expand Up @@ -577,6 +584,7 @@ def save(self,
uri=uri,
crs_transformer=crs_transformer,
class_config=class_config,
bbox=bbox,
tmp_dir=tmp_dir,
save_as_rgb=save_as_rgb,
discrete_output=discrete_output,
Expand Down
Expand Up @@ -30,7 +30,8 @@ def __init__(self,
in GeoJSON file to pixel coords.
bbox (Optional[Box], optional): User-specified crop of the extent.
If provided, only labels falling inside it are returned by
:meth:`.ChipClassificationGeoJSONStore.get_labels`.
:meth:`.ChipClassificationGeoJSONStore.get_labels`. Must be
provided if the corresponding RasterSource has bbox != extent.
"""
self.uri = uri
self.class_config = class_config
Expand All @@ -51,7 +52,8 @@ def save(self, labels: ChipClassificationLabels) -> None:
class_ids,
self.crs_transformer,
self.class_config,
scores=scores)
scores=scores,
bbox=self.bbox)
json_to_file(geojson, self.uri)

def get_labels(self) -> ChipClassificationLabels:
Expand Down
Expand Up @@ -32,7 +32,8 @@ def __init__(self,
in GeoJSON file to pixel coords.
bbox (Optional[Box], optional): User-specified crop of the extent.
If provided, only labels falling inside it are returned by
:meth:`.ObjectDetectionGeoJSONStore.get_labels`.
:meth:`.ObjectDetectionGeoJSONStore.get_labels`. Must be
provided if the corresponding RasterSource has bbox != extent.
"""
self.uri = uri
self.class_config = class_config
Expand Down
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import rasterio as rio
import rasterio.windows as rio_windows
from tqdm.auto import tqdm

from rastervision.pipeline.file_system import (
Expand Down Expand Up @@ -58,7 +59,8 @@ def __init__(
class_config (ClassConfig): Class config.
bbox (Optional[Box], optional): User-specified crop of the extent.
If provided, only labels falling inside it are returned by
:meth:`.SemanticSegmentationLabelStore.get_labels`.
:meth:`.SemanticSegmentationLabelStore.get_labels`. Must be
provided if the corresponding RasterSource has bbox != extent.
tmp_dir (Optional[str], optional): Temporary directory to use. If
None, will be auto-generated. Defaults to None.
vector_outputs (Optional[Sequence[VectorOutputConfig]], optional):
Expand Down Expand Up @@ -207,11 +209,17 @@ def save(self,
make_dir(local_root)

height, width = labels.extent.size
if self.bbox is not None:
bbox_rio_window = self.bbox.rasterio_format()
transform = rio_windows.transform(bbox_rio_window,
self.crs_transformer.transform)
else:
transform = self.crs_transformer.transform
out_profile = dict(
driver='GTiff',
height=height,
width=width,
transform=self.crs_transformer.transform,
transform=transform,
crs=self.crs_transformer.image_crs,
blockxsize=min(self.rasterio_block_size, width),
blockysize=min(self.rasterio_block_size, height))
Expand Down Expand Up @@ -257,6 +265,7 @@ def write_smooth_raster_output(
out_profile.update(dict(count=num_bands, dtype=dtype))

extent = labels.extent

with rio.open(scores_path, 'w', **out_profile) as ds:
windows = [Box.from_rasterio(w) for _, w in ds.block_windows(1)]
with tqdm(windows, desc='Saving pixel scores') as bar:
Expand Down Expand Up @@ -310,7 +319,10 @@ def write_vector_outputs(self, labels: SemanticSegmentationLabels,
bar.set_postfix(vo.dict())
class_mask = (label_arr == vo.class_id).astype(np.uint8)
polys = vo.vectorize(class_mask)
polys = [self.crs_transformer.pixel_to_map(p) for p in polys]
polys = [
self.crs_transformer.pixel_to_map(p, bbox=self.bbox)
for p in polys
]
geojson = geoms_to_geojson(polys)
out_uri = vo.get_uri(vector_output_dir, self.class_config)
json_to_file(geojson, out_uri)
Expand Down
11 changes: 8 additions & 3 deletions rastervision_core/rastervision/core/data/label_store/utils.py
Expand Up @@ -16,8 +16,8 @@ def boxes_to_geojson(
class_ids: Sequence[int],
crs_transformer: 'CRSTransformer',
class_config: 'ClassConfig',
scores: Optional[Sequence[Union[float, Sequence[float]]]] = None
) -> dict:
scores: Optional[Sequence[Union[float, Sequence[float]]]] = None,
bbox: Optional['Box'] = None) -> dict:
"""Convert boxes and associated data into a GeoJSON dict.
Args:
Expand All @@ -30,6 +30,8 @@ def boxes_to_geojson(
Optional list of score or scores. If floats (one for each box),
property name will be "score". If lists of floats, property name
will be "scores". Defaults to None.
bbox (Optional[Box]): User-specified crop of the extent. Must be
provided if the corresponding RasterSource has bbox != extent.
Returns:
dict: Serialized GeoJSON.
Expand All @@ -46,7 +48,10 @@ def boxes_to_geojson(
boxes,
desc='Transforming boxes to map coords',
delay=PROGRESSBAR_DELAY_SEC) as bar:
geoms = [crs_transformer.pixel_to_map(box.to_shapely()) for box in bar]
geoms = [
crs_transformer.pixel_to_map(box.to_shapely(), bbox=bbox)
for box in bar
]

# add box properties (ID and name of predicted class)
with tqdm(
Expand Down

0 comments on commit ac69538

Please sign in to comment.