Skip to content

Commit

Permalink
Merge pull request #1843 from AdeelH/fix-1842
Browse files Browse the repository at this point in the history
Fix #1842
  • Loading branch information
AdeelH committed Aug 1, 2023
2 parents e229483 + ba1e155 commit dc23f1f
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,9 @@ def get_scores(self) -> 'SemanticSegmentationSmoothLabels':
score_arr = score_arr.astype(np.float16)
score_arr /= 255

_, h, w = score_arr.shape
labels = SemanticSegmentationSmoothLabels(
extent=Box(0, 0, *score_arr.shape),
num_classes=len(self.class_config))
extent=Box(0, 0, h, w), num_classes=len(self.class_config))
labels.pixel_scores = score_arr * hits_arr
labels.pixel_hits = hits_arr
return labels
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
from typing import Callable
from os.path import join
import unittest

import numpy as np

from rastervision.core.data import (BuildingVectorOutputConfig, ClassConfig,
PolygonVectorOutputConfig,
VectorOutputConfig)
from rastervision.pipeline.file_system.utils import get_tmp_dir, file_exists
from rastervision.core.box import Box
from rastervision.core.data import (
BuildingVectorOutputConfig, ClassConfig, IdentityCRSTransformer,
PolygonVectorOutputConfig, SemanticSegmentationLabelStore,
SemanticSegmentationSmoothLabels, VectorOutputConfig)
from tests.core.data.label.test_semantic_segmentation_labels import (
make_random_scores)


class TestVectorOutputConfig(unittest.TestCase):
Expand Down Expand Up @@ -57,5 +64,64 @@ def test_denoise(self):
self.assertEqual(len(polys), 1)


class TestSemanticSegmentationLabelStore(unittest.TestCase):
def assertNoError(self, fn: Callable, msg: str = ''):
try:
fn()
except Exception:
self.fail(msg)

def test_saving_and_loading(self):
with get_tmp_dir() as tmp_dir:
class_config = ClassConfig(names=['bg', 'fg'], null_class='bg')
label_store = SemanticSegmentationLabelStore(
uri=tmp_dir,
crs_transformer=IdentityCRSTransformer(),
class_config=class_config,
bbox=None,
smooth_output=True,
smooth_as_uint8=True,
vector_outputs=[PolygonVectorOutputConfig(class_id=1)])
labels = SemanticSegmentationSmoothLabels(
extent=Box(0, 0, 10, 10), num_classes=len(class_config))
labels.pixel_scores += make_random_scores(
len(class_config), 10, 10)
labels.pixel_hits += 1
label_store.save(labels)

self.assertTrue(file_exists(join(tmp_dir, 'labels.tif')))
self.assertTrue(file_exists(join(tmp_dir, 'scores.tif')))
self.assertTrue(file_exists(join(tmp_dir, 'pixel_hits.npy')))

del label_store

# test compatibility validation
args = dict(
uri=tmp_dir,
crs_transformer=IdentityCRSTransformer(),
class_config=ClassConfig(names=['bg', 'fg', 'null']),
smooth_output=True,
smooth_as_uint8=True,
)
with self.assertRaises(FileExistsError):
label_store = SemanticSegmentationLabelStore(**args)

args = dict(
uri=tmp_dir,
crs_transformer=IdentityCRSTransformer(),
class_config=class_config,
smooth_output=True,
smooth_as_uint8=True,
)
label_store = SemanticSegmentationLabelStore(**args)
self.assertIsNotNone(label_store.label_source)
self.assertIsNotNone(label_store.score_source)

self.assertNoError(lambda: label_store.get_labels())
self.assertNoError(lambda: label_store.get_scores())

self.assertNoError(lambda: label_store.save(labels))


if __name__ == '__main__':
unittest.main()

0 comments on commit dc23f1f

Please sign in to comment.