Skip to content

Commit

Permalink
fix: ensure SS datasets always return label array with correct dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
AdeelH committed Oct 5, 2023
1 parent 142df38 commit e4e10ad
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,8 @@ def semantic_segmentation_transformer(
y = np.array(y)
out = apply_transform(transform, image=x, mask=y)
x, y = out['image'], out['mask']
y = y.astype(int)
if y is not None:
y = y.astype(int)
return x, y


Expand Down
26 changes: 25 additions & 1 deletion tests/pytorch_learner/dataset/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import albumentations as A

from rastervision.pytorch_learner.dataset.transform import (
yxyx_to_albu, albu_to_yxyx, xywh_to_albu, apply_transform)
yxyx_to_albu, albu_to_yxyx, xywh_to_albu, apply_transform,
semantic_segmentation_transformer)


class TestTransforms(unittest.TestCase):
Expand Down Expand Up @@ -65,6 +66,29 @@ def test_box_format_conversions_xywh(self):
boxes_albu = xywh_to_albu(boxes, (10, 10))
np.testing.assert_allclose(boxes_albu, boxes_albu_gt)

def test_semantic_segmentation_transformer(self):
# w/ y, w/o transform
x_in, y_in = np.zeros((10, 10, 3), dtype=np.uint8), np.zeros((10, 10))
x_out, y_out = semantic_segmentation_transformer((x_in, y_in), None)
np.issubdtype(y_out.dtype, int)

# w/ y, w/ transform
x_out, y_out = semantic_segmentation_transformer((x_in, y_in),
A.Resize(20, 20))
self.assertEqual(x_out.shape, (20, 20, 3))
self.assertEqual(y_out.shape, (20, 20))
np.issubdtype(y_out.dtype, int)

# w/o y, w/o transform
x_out, y_out = semantic_segmentation_transformer((x_in, None), None)
self.assertIsNone(y_out)

# w/o y, w/ transform
x_out, y_out = semantic_segmentation_transformer((x_in, None),
A.Resize(20, 20))
self.assertEqual(x_out.shape, (20, 20, 3))
self.assertIsNone(y_out)


if __name__ == '__main__':
unittest.main()

0 comments on commit e4e10ad

Please sign in to comment.