Skip to content

Commit

Permalink
Adds Generalized IOU (pytorch#2642)
Browse files Browse the repository at this point in the history
* tries adding genaralized_iou

* fixes linting

* Adds docs for giou, iou and box area

* fixes lint

* removes docs to fixup in other PR

* linter fix

* Cleans comments

* Adds tests for box area, iou and giou

* typo fix for testCase

* fixes typo

* fixes box area test

* fixes implementation

* updates tests to tolerance
  • Loading branch information
oke-aditya authored and bryant1410 committed Nov 22, 2020
1 parent 544d978 commit c4c98f9
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/source/ops.rst
Expand Up @@ -15,6 +15,7 @@ torchvision.ops
.. autofunction:: clip_boxes_to_image
.. autofunction:: box_area
.. autofunction:: box_iou
.. autofunction:: generalized_box_iou
.. autofunction:: roi_align
.. autofunction:: ps_roi_align
.. autofunction:: roi_pool
Expand Down
46 changes: 46 additions & 0 deletions test/test_ops.py
Expand Up @@ -647,5 +647,51 @@ def test_convert_boxes_to_roi_format(self):
self.assertTrue(torch.equal(ref_tensor, ops._utils.convert_boxes_to_roi_format(box_sequence)))


class BoxAreaTester(unittest.TestCase):
def test_box_area(self):
# A bounding box of area 10000 and a degenerate case
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float)
expected = torch.tensor([10000, 0])
calc_area = ops.box_area(box_tensor)
assert calc_area.size() == torch.Size([2])
assert calc_area.dtype == box_tensor.dtype
assert torch.all(torch.eq(calc_area, expected)).item() is True


class BoxIouTester(unittest.TestCase):
def test_iou(self):
# Boxes to test Iou
boxes1 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float)
boxes2 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float)

# Expected IoU matrix for these boxes
expected = torch.tensor([[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]])

out = ops.box_iou(boxes1, boxes2)

# Check if all elements of tensor are as expected.
assert out.size() == torch.Size([3, 3])
tolerance = 1e-4
assert ((out - expected).abs().max() < tolerance).item() is True


class GenBoxIouTester(unittest.TestCase):
def test_gen_iou(self):
# Test Generalized IoU
boxes1 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float)
boxes2 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float)

# Expected gIoU matrix for these boxes
expected = torch.tensor([[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611],
[-0.7778, -0.8611, 1.0]])

out = ops.generalized_box_iou(boxes1, boxes2)

# Check if all elements of tensor are as expected.
assert out.size() == torch.Size([3, 3])
tolerance = 1e-4
assert ((out - expected).abs().max() < tolerance).item() is True


if __name__ == '__main__':
unittest.main()
4 changes: 2 additions & 2 deletions torchvision/ops/__init__.py
@@ -1,4 +1,4 @@
from .boxes import nms, batched_nms, remove_small_boxes, clip_boxes_to_image, box_area, box_iou
from .boxes import nms, batched_nms, remove_small_boxes, clip_boxes_to_image, box_area, box_iou, generalized_box_iou
from .new_empty_tensor import _new_empty_tensor
from .deform_conv import deform_conv2d, DeformConv2d
from .roi_align import roi_align, RoIAlign
Expand All @@ -15,7 +15,7 @@

__all__ = [
'deform_conv2d', 'DeformConv2d', 'nms', 'batched_nms', 'remove_small_boxes',
'clip_boxes_to_image', 'box_area', 'box_iou', 'roi_align', 'RoIAlign', 'roi_pool',
'clip_boxes_to_image', 'box_area', 'box_iou', 'generalized_box_iou', 'roi_align', 'RoIAlign', 'roi_pool',
'RoIPool', '_new_empty_tensor', 'ps_roi_align', 'PSRoIAlign', 'ps_roi_pool',
'PSRoIPool', 'MultiScaleRoIAlign', 'FeaturePyramidNetwork'
]
46 changes: 44 additions & 2 deletions torchvision/ops/boxes.py
Expand Up @@ -161,8 +161,7 @@ def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
boxes2 (Tensor[M, 4])
Returns:
iou (Tensor[N, M]): the NxM matrix containing the pairwise
IoU values for every element in boxes1 and boxes2
iou (Tensor[N, M]): the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
"""
area1 = box_area(boxes1)
area2 = box_area(boxes2)
Expand All @@ -175,3 +174,46 @@ def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:

iou = inter / (area1[:, None] + area2 - inter)
return iou


# Implementation adapted from https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
"""
Return generalized intersection-over-union (Jaccard index) of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
Arguments:
boxes1 (Tensor[N, 4])
boxes2 (Tensor[M, 4])
Returns:
generalized_iou (Tensor[N, M]): the NxM matrix containing the pairwise generalized_IoU values
for every element in boxes1 and boxes2
"""

# degenerate boxes gives inf / nan results
# so do an early check
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()

area1 = box_area(boxes1)
area2 = box_area(boxes2)

lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]

wh = (rb - lt).clamp(min=0) # [N,M,2]
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]

union = area1[:, None] + area2 - inter

iou = inter / union

lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

whi = (rbi - lti).clamp(min=0) # [N,M,2]
areai = whi[:, :, 0] * whi[:, :, 1]

return iou - (areai - union) / areai

0 comments on commit c4c98f9

Please sign in to comment.