From 2359be1feb58b2e9f4d25e43acdae5d26eddc47e Mon Sep 17 00:00:00 2001 From: Adeel Hassan Date: Thu, 8 Feb 2024 16:43:49 -0500 Subject: [PATCH] allow specifying class probability thresholds for SS label vectorization --- .../data/label/semantic_segmentation_labels.py | 16 ++++++++++++++++ .../semantic_segmentation_label_store.py | 8 ++++---- .../semantic_segmentation_label_store_config.py | 9 +++++++++ .../test_semantic_segmentation_label_store.py | 4 +++- 4 files changed, 32 insertions(+), 5 deletions(-) diff --git a/rastervision_core/rastervision/core/data/label/semantic_segmentation_labels.py b/rastervision_core/rastervision/core/data/label/semantic_segmentation_labels.py index 5a6078158..192015b72 100644 --- a/rastervision_core/rastervision/core/data/label/semantic_segmentation_labels.py +++ b/rastervision_core/rastervision/core/data/label/semantic_segmentation_labels.py @@ -61,6 +61,22 @@ def get_label_arr(self, window: Box, input window. """ + @abstractmethod + def get_score_arr(self, window: Box, + null_class_id: int = -1) -> np.ndarray: + """Get (C, H, W) array of pixel scores.""" + + def get_class_mask(self, + window: Box, + class_id: int, + threshold: Optional[float] = None) -> np.ndarray: + """Get boolean (H, W) mask of pixels with class score >= threshold.""" + scores = self.get_score_arr(window) + if threshold is None: + threshold = (1 / self.num_classes) + mask = scores[class_id] >= threshold + return mask + def get_windows(self, **kwargs) -> List[Box]: """Generate sliding windows over the local extent. 
diff --git a/rastervision_core/rastervision/core/data/label_store/semantic_segmentation_label_store.py b/rastervision_core/rastervision/core/data/label_store/semantic_segmentation_label_store.py index 238e63a35..1715d346e 100644 --- a/rastervision_core/rastervision/core/data/label_store/semantic_segmentation_label_store.py +++ b/rastervision_core/rastervision/core/data/label_store/semantic_segmentation_label_store.py @@ -311,13 +311,13 @@ def write_vector_outputs(self, labels: SemanticSegmentationLabels, log.info('Writing vector outputs to disk.') - label_arr = labels.get_label_arr(labels.extent, - self.class_config.null_class_id) - + extent = labels.extent with tqdm(self.vector_outputs, desc='Vectorizing predictions') as bar: for vo in bar: bar.set_postfix(vo.dict()) - class_mask = (label_arr == vo.class_id).astype(np.uint8) + class_mask = labels.get_class_mask(extent, vo.class_id, + vo.threshold) + class_mask = class_mask.astype(np.uint8) polys = vo.vectorize(class_mask) polys = [ self.crs_transformer.pixel_to_map(p, bbox=self.bbox) diff --git a/rastervision_core/rastervision/core/data/label_store/semantic_segmentation_label_store_config.py b/rastervision_core/rastervision/core/data/label_store/semantic_segmentation_label_store_config.py index 57f2c369f..115b94b28 100644 --- a/rastervision_core/rastervision/core/data/label_store/semantic_segmentation_label_store_config.py +++ b/rastervision_core/rastervision/core/data/label_store/semantic_segmentation_label_store_config.py @@ -41,6 +41,15 @@ class VectorOutputConfig(Config): 'intensive (especially for large images). Larger values will remove ' 'more noise and make vectorization faster but might also remove ' 'legitimate detections.') + threshold: Optional[float] = Field( + None, + description='Probability threshold for creating the binary mask for ' + 'the pixels of this class. Pixels will be considered to belong to ' + 'this class if their probability for this class is >= ``threshold``. 
' + 'Note that Raster Vision treats classes as mutually exclusive so the ' + 'threshold should vary with the number of total classes. ' + '``None`` is equivalent to setting this to (1 / num_classes). ' + 'Defaults to ``None``.') def vectorize(self, mask: 'np.ndarray') -> Iterator['BaseGeometry']: """Vectorize binary mask representing the target class into polygons. diff --git a/tests/core/data/label_store/test_semantic_segmentation_label_store.py b/tests/core/data/label_store/test_semantic_segmentation_label_store.py index d06df262a..8d3638797 100644 --- a/tests/core/data/label_store/test_semantic_segmentation_label_store.py +++ b/tests/core/data/label_store/test_semantic_segmentation_label_store.py @@ -81,7 +81,9 @@ def test_saving_and_loading(self): bbox=None, smooth_output=True, smooth_as_uint8=True, - vector_outputs=[PolygonVectorOutputConfig(class_id=1)]) + vector_outputs=[ + PolygonVectorOutputConfig(class_id=1, threshold=0.3) + ]) labels = SemanticSegmentationSmoothLabels( extent=Box(0, 0, 10, 10), num_classes=len(class_config)) labels.pixel_scores += make_random_scores(