Skip to content

Commit

Permalink
allow specifying class probability thresholds for SS label vectorization
Browse files Browse the repository at this point in the history
  • Loading branch information
AdeelH committed Feb 9, 2024
1 parent 44891f0 commit 2359be1
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 5 deletions.
Expand Up @@ -61,6 +61,22 @@ def get_label_arr(self, window: Box,
input window.
"""

@abstractmethod
def get_score_arr(self, window: Box,
null_class_id: int = -1) -> np.ndarray:
"""Get (C, H, W) array of pixel scores."""

def get_class_mask(self,
window: Box,
class_id: int,
threshold: Optional[float] = None) -> np.ndarray:
"""Get a binary mask representing all pixels of a class."""
scores = self.get_score_arr(window)
if threshold is None:
threshold = (1 / self.num_classes)
mask = scores[class_id] >= threshold
return mask

def get_windows(self, **kwargs) -> List[Box]:
"""Generate sliding windows over the local extent.
Expand Down
Expand Up @@ -311,13 +311,13 @@ def write_vector_outputs(self, labels: SemanticSegmentationLabels,

log.info('Writing vector outputs to disk.')

label_arr = labels.get_label_arr(labels.extent,
self.class_config.null_class_id)

extent = labels.extent
with tqdm(self.vector_outputs, desc='Vectorizing predictions') as bar:
for vo in bar:
bar.set_postfix(vo.dict())
class_mask = (label_arr == vo.class_id).astype(np.uint8)
class_mask = labels.get_class_mask(extent, vo.class_id,
vo.threshold)
class_mask = class_mask.astype(np.uint8)
polys = vo.vectorize(class_mask)
polys = [
self.crs_transformer.pixel_to_map(p, bbox=self.bbox)
Expand Down
Expand Up @@ -41,6 +41,15 @@ class VectorOutputConfig(Config):
'intensive (especially for large images). Larger values will remove '
'more noise and make vectorization faster but might also remove '
'legitimate detections.')
threshold: Optional[float] = Field(
None,
description='Probability threshold for creating the binary mask for '
'the pixels of this class. Pixels will be considered to belong to '
'this class if their probability for this class is >= ``threshold``. '
'Note that Raster Vision treats classes as mutually exclusive so the '
'threshold should vary with the number of total classes. '
'``None`` is equivalent to setting this to (1 / num_classes). '
'Defaults to ``None``.')

def vectorize(self, mask: 'np.ndarray') -> Iterator['BaseGeometry']:
"""Vectorize binary mask representing the target class into polygons.
Expand Down
Expand Up @@ -81,7 +81,9 @@ def test_saving_and_loading(self):
bbox=None,
smooth_output=True,
smooth_as_uint8=True,
vector_outputs=[PolygonVectorOutputConfig(class_id=1)])
vector_outputs=[
PolygonVectorOutputConfig(class_id=1, threshold=0.3)
])
labels = SemanticSegmentationSmoothLabels(
extent=Box(0, 0, 10, 10), num_classes=len(class_config))
labels.pixel_scores += make_random_scores(
Expand Down

0 comments on commit 2359be1

Please sign in to comment.