-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
iou_box3d.py
168 lines (138 loc) · 4.78 KB
/
iou_box3d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
from typing import Tuple
import torch
import torch.nn.functional as F
from pytorch3d import _C
from torch.autograd import Function
# -------------------------------------------------- #
#                      CONSTANTS                      #
# -------------------------------------------------- #

# _box_planes and _box_triangles define the 4- and 3-connectivity of the
# 8 box corners (corner indices as in the box3d_overlap docstring):
#   _box_planes    - the 6 quad faces of the box, 4 corner indices per face.
#   _box_triangles - the same 6 faces, each split into 2 triangles
#                    (12 in total), 3 corner indices per triangle.
# The per-row comments name each face for the unit box [0, 1]^3 whose
# corner list is given in the box3d_overlap docstring.
_box_planes = [
    [0, 1, 2, 3],  # z = 0
    [3, 2, 6, 7],  # y = 1
    [0, 1, 5, 4],  # y = 0
    [0, 3, 7, 4],  # x = 0
    [1, 2, 6, 5],  # x = 1
    [4, 5, 6, 7],  # z = 1
]
_box_triangles = [
    # z = 0
    [0, 1, 2],
    [0, 3, 2],
    # z = 1
    [4, 5, 6],
    [4, 6, 7],
    # x = 1
    [1, 5, 6],
    [1, 6, 2],
    # x = 0
    [0, 4, 7],
    [0, 7, 3],
    # y = 1
    [3, 2, 6],
    [3, 6, 7],
    # y = 0
    [0, 1, 5],
    [0, 4, 5],
]
def _check_coplanar(boxes: torch.Tensor, eps: float = 1e-4) -> None:
    """
    Checks that the 4 corners of every quad face of every box are coplanar.

    For each face in ``_box_planes``, a unit normal is built from the first
    three corners (v0, v1, v2). The signed distance of the fourth corner v3
    from the plane through v0 with that normal is then measured. Every such
    distance, for every face of every box, must satisfy ``|d| < eps``.

    Args:
        boxes: tensor of box corners. It is reshaped below as (B, P, 4, 3),
            i.e. it is expected to be (B, 8, 3) as validated by
            box3d_overlap. An empty batch (B == 0) passes trivially.
        eps: tolerance on the absolute point-to-plane distance of the
            fourth corner of each face.

    Raises:
        ValueError: if any face of any box is not planar within ``eps``.
    """
    faces = torch.tensor(_box_planes, dtype=torch.int64, device=boxes.device)
    verts = boxes.index_select(index=faces.view(-1), dim=1)
    B = boxes.shape[0]
    P, V = faces.shape
    # (B, P, 4, 3) -> 4 x (B, P, 3)
    v0, v1, v2, v3 = verts.reshape(B, P, V, 3).unbind(2)

    # Unit normal of the plane through v0, v1, v2, one per face: (B, P, 3).
    e0 = F.normalize(v1 - v0, dim=-1)
    e1 = F.normalize(v2 - v0, dim=-1)
    normal = F.normalize(torch.cross(e0, e1, dim=-1), dim=-1)

    # Signed distance of the fourth vertex from its face plane: (B, P).
    # Each face is tested on its own. Reducing the dot products over all
    # faces first would let deviations of opposite sign on different faces
    # cancel out, letting non-planar boxes through.
    dist = ((v3 - v0) * normal).sum(dim=-1)
    if not bool((dist.abs() < eps).all()):
        msg = "Plane vertices are not coplanar"
        raise ValueError(msg)
def _check_nonzero(boxes: torch.Tensor, eps: float = 1e-4) -> None:
    """
    Checks that the sides of the box have a non zero area.

    Every face is split into the triangles listed in ``_box_triangles``.
    The area of each triangle (half the norm of the cross product of two
    of its edges) must be at least ``eps``.

    Args:
        boxes: tensor of box corners, gathered along dim 1 and reshaped to
            (B, T, 3, 3) below; expected to be (B, 8, 3).
        eps: minimum allowed triangle area.

    Raises:
        ValueError: if any triangle of any box has an area below ``eps``.
    """
    tri = torch.tensor(_box_triangles, dtype=torch.int64, device=boxes.device)
    num_tris, corners_per_tri = tri.shape
    gathered = torch.index_select(boxes, 1, tri.flatten())
    # 3 x (B, T, 3): the three corners of every triangle of every box.
    a, b, c = gathered.reshape(
        boxes.shape[0], num_tris, corners_per_tri, 3
    ).unbind(2)

    # |(b - a) x (c - a)| is twice the area of triangle (a, b, c).
    twice_area = torch.cross(b - a, c - a, dim=-1).norm(dim=-1)
    too_small = twice_area / 2 < eps
    if too_small.any().item():
        raise ValueError("Planes have zero areas")
class _box3d_overlap(Function):
    """
    Torch autograd Function around the compiled ``_C.iou_box3d`` kernel
    (C++/CUDA implementations of box3d_overlap).

    Only the forward pass exists; backward always raises.
    """

    @staticmethod
    def forward(ctx, boxes1, boxes2):
        """
        Runs the compiled kernel on the two batches of boxes.

        Argument definitions are the same as in the box3d_overlap function.
        Returns the (vol, iou) pair produced by the kernel.
        """
        intersection_vol, iou_scores = _C.iou_box3d(boxes1, boxes2)
        return intersection_vol, iou_scores

    @staticmethod
    def backward(ctx, grad_vol, grad_iou):
        """Gradients are not available for box3d_overlap; always raises."""
        message = "box3d_overlap backward is not supported"
        raise ValueError(message)
def box3d_overlap(
    boxes1: torch.Tensor, boxes2: torch.Tensor, eps: float = 1e-4
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes the intersection of 3D boxes1 and boxes2.

    Inputs boxes1, boxes2 are tensors of shape (N, 8, 3) and (M, 8, 3)
    (N and M do not have to be equal). They contain the 8 corners of each
    box, as follows:

        (4) +---------+. (5)
            | ` .     |  ` .
            | (0) +---+-----+ (1)
            |     |   |     |
        (7) +-----+---+. (6)|
            ` .   |     ` . |
            (3) ` +---------+ (2)

    NOTE: Throughout this implementation, we assume that boxes are defined
    by their 8 corners exactly in the order specified in the diagram above
    for the function to give correct results. In addition, the vertices on
    each plane must be coplanar.
    As an alternative to the diagram, this is a unit bounding box which has
    the correct vertex ordering:

        box_corner_vertices = [
            [0, 0, 0],
            [1, 0, 0],
            [1, 1, 0],
            [0, 1, 0],
            [0, 0, 1],
            [1, 0, 1],
            [1, 1, 1],
            [0, 1, 1],
        ]

    Args:
        boxes1: tensor of shape (N, 8, 3) of the coordinates of the 1st boxes
        boxes2: tensor of shape (M, 8, 3) of the coordinates of the 2nd boxes
        eps: tolerance passed to the input-validation checks, i.e. the
            coplanarity check of each quad face and the non-zero-area check
            of each face triangle.

    Returns:
        vol: (N, M) tensor of the volume of the intersecting convex shapes
        iou: (N, M) tensor of the intersection over union which is
            defined as: `iou = vol / (vol1 + vol2 - vol)`

    Raises:
        ValueError: if either input is not a batch of (8, 3) corner sets,
            or if the inputs fail the coplanarity / non-zero-area checks.
    """
    for box in (boxes1, boxes2):
        if box.shape[1:] != (8, 3):
            raise ValueError("Each box in the batch must be of shape (8, 3)")

    # Coplanarity is verified for both inputs before any face-area check,
    # so a coplanarity error is reported first when both would fail.
    for check in (_check_coplanar, _check_nonzero):
        check(boxes1, eps)
        check(boxes2, eps)

    intersection_vol, iou_scores = _box3d_overlap.apply(boxes1, boxes2)
    return intersection_vol, iou_scores