Skip to content

Commit

Permalink
Fix filter_by_aoi for ObjectDetectionLabels
Browse files Browse the repository at this point in the history
  • Loading branch information
lewfish committed Apr 3, 2019
1 parent 0b33c7b commit dbe3909
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 1 deletion.
5 changes: 4 additions & 1 deletion rastervision/data/label/object_detection_labels.py
Expand Up @@ -61,11 +61,14 @@ def filter_by_aoi(self, aoi_polygons):
box_poly = box.to_shapely()
for aoi in aoi_polygons:
if box_poly.within(aoi):
new_boxes.append(box)
new_boxes.append(box.npbox_format())
new_class_ids.append(class_id)
new_scores.append(score)
break

if len(new_boxes) == 0:
return ObjectDetectionLabels.make_empty()

return ObjectDetectionLabels(
np.array(new_boxes), np.array(new_class_ids), np.array(new_scores))

Expand Down
16 changes: 16 additions & 0 deletions tests/data/label/test_chip_classification_labels.py
Expand Up @@ -63,6 +63,22 @@ def test_extend(self):
self.assertEqual(len(cells), 3)
self.assertTrue(cell3 in cells)

def test_filter_by_aoi(self):
aois = [Box.make_square(0, 0, 2).to_shapely()]
filt_labels = self.labels.filter_by_aoi(aois)

exp_labels = ChipClassificationLabels()
cell1 = Box.make_square(0, 0, 2)
class_id1 = 1
exp_labels.set_cell(cell1, class_id1)
self.assertEqual(filt_labels, exp_labels)

aois = [Box.make_square(4, 4, 2).to_shapely()]
filt_labels = self.labels.filter_by_aoi(aois)

exp_labels = ChipClassificationLabels()
self.assertEqual(filt_labels, exp_labels)


if __name__ == '__main__':
unittest.main()
15 changes: 15 additions & 0 deletions tests/data/label/test_object_detection_labels.py
Expand Up @@ -182,6 +182,21 @@ def test_prune_duplicates(self):
scores=expected_scores[pruned_inds])
pruned_labels.assert_equal(expected_labels)

def test_filter_by_aoi(self):
aois = [Box.make_square(0, 0, 2).to_shapely()]
filt_labels = self.labels.filter_by_aoi(aois)

npboxes = np.array([[0., 0., 2., 2.]])
class_ids = np.array([1])
scores = np.array([0.9])
exp_labels = ObjectDetectionLabels(npboxes, class_ids, scores=scores)
self.assertEqual(filt_labels, exp_labels)

aois = [Box.make_square(4, 4, 2).to_shapely()]
filt_labels = self.labels.filter_by_aoi(aois)
exp_labels = ObjectDetectionLabels.make_empty()
self.assertEqual(filt_labels, exp_labels)


if __name__ == '__main__':
unittest.main()

0 comments on commit dbe3909

Please sign in to comment.