diff --git a/rastervision/data/label/object_detection_labels.py b/rastervision/data/label/object_detection_labels.py index 37f37a514b..d9a75c0729 100644 --- a/rastervision/data/label/object_detection_labels.py +++ b/rastervision/data/label/object_detection_labels.py @@ -61,11 +61,14 @@ def filter_by_aoi(self, aoi_polygons): box_poly = box.to_shapely() for aoi in aoi_polygons: if box_poly.within(aoi): - new_boxes.append(box) + new_boxes.append(box.npbox_format()) new_class_ids.append(class_id) new_scores.append(score) break + if len(new_boxes) == 0: + return ObjectDetectionLabels.make_empty() + return ObjectDetectionLabels( np.array(new_boxes), np.array(new_class_ids), np.array(new_scores)) diff --git a/tests/data/label/test_chip_classification_labels.py b/tests/data/label/test_chip_classification_labels.py index 75b69311ec..0449cb7cbe 100644 --- a/tests/data/label/test_chip_classification_labels.py +++ b/tests/data/label/test_chip_classification_labels.py @@ -63,6 +63,22 @@ def test_extend(self): self.assertEqual(len(cells), 3) self.assertTrue(cell3 in cells) + def test_filter_by_aoi(self): + aois = [Box.make_square(0, 0, 2).to_shapely()] + filt_labels = self.labels.filter_by_aoi(aois) + + exp_labels = ChipClassificationLabels() + cell1 = Box.make_square(0, 0, 2) + class_id1 = 1 + exp_labels.set_cell(cell1, class_id1) + self.assertEqual(filt_labels, exp_labels) + + aois = [Box.make_square(4, 4, 2).to_shapely()] + filt_labels = self.labels.filter_by_aoi(aois) + + exp_labels = ChipClassificationLabels() + self.assertEqual(filt_labels, exp_labels) + if __name__ == '__main__': unittest.main() diff --git a/tests/data/label/test_object_detection_labels.py b/tests/data/label/test_object_detection_labels.py index d194758b50..4f429906dd 100644 --- a/tests/data/label/test_object_detection_labels.py +++ b/tests/data/label/test_object_detection_labels.py @@ -182,6 +182,21 @@ def test_prune_duplicates(self): scores=expected_scores[pruned_inds]) pruned_labels.assert_equal(expected_labels) + def test_filter_by_aoi(self): + aois = [Box.make_square(0, 0, 2).to_shapely()] + filt_labels = self.labels.filter_by_aoi(aois) + + npboxes = np.array([[0., 0., 2., 2.]]) + class_ids = np.array([1]) + scores = np.array([0.9]) + exp_labels = ObjectDetectionLabels(npboxes, class_ids, scores=scores) + self.assertEqual(filt_labels, exp_labels) + + aois = [Box.make_square(4, 4, 2).to_shapely()] + filt_labels = self.labels.filter_by_aoi(aois) + exp_labels = ObjectDetectionLabels.make_empty() + self.assertEqual(filt_labels, exp_labels) + if __name__ == '__main__': unittest.main()