diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index 4ff81374ba..f100558fc4 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -13,6 +13,7 @@ We implement common CUDA ops used in detection, segmentation, etc. - CornerPool - Deformable Convolution v1/v2 - Deformable RoIPool +- DiffIoURotated - DynamicScatter - GatherPoints - FurthestPointSample diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md index 0d0a9c9103..776b05536f 100644 --- a/docs/zh_cn/understand_mmcv/ops.md +++ b/docs/zh_cn/understand_mmcv/ops.md @@ -13,6 +13,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子 - CornerPool - Deformable Convolution v1/v2 - Deformable RoIPool +- DiffIoURotated - DynamicScatter - GatherPoints - FurthestPointSample diff --git a/mmcv/ops/__init__.py b/mmcv/ops/__init__.py index bdd39fcae7..fd112f049d 100644 --- a/mmcv/ops/__init__.py +++ b/mmcv/ops/__init__.py @@ -18,6 +18,7 @@ from .deprecated_wrappers import ConvTranspose2d_deprecated as ConvTranspose2d from .deprecated_wrappers import Linear_deprecated as Linear from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d +from .diff_iou_rotated import diff_iou_rotated_2d, diff_iou_rotated_3d from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss, sigmoid_focal_loss, softmax_focal_loss) from .furthest_point_sample import (furthest_point_sample, @@ -96,5 +97,5 @@ 'SparseMaxPool2d', 'SparseMaxPool3d', 'SparseConvTensor', 'scatter_nd', 'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all', 'points_in_polygons', 'min_area_polygons', 'active_rotated_filter', - 'convex_iou', 'convex_giou' + 'convex_iou', 'convex_giou', 'diff_iou_rotated_2d', 'diff_iou_rotated_3d' ] diff --git a/mmcv/ops/csrc/common/cuda/diff_iou_rotated_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/diff_iou_rotated_cuda_kernel.cuh new file mode 100644 index 0000000000..3ee1814e12 --- /dev/null +++ b/mmcv/ops/csrc/common/cuda/diff_iou_rotated_cuda_kernel.cuh @@ -0,0 +1,136 @@ +// Copyright (c) OpenMMLab. All rights reserved +// Adapted from +// https://github.com/lilanxiao/Rotated_IoU/cuda_op/sort_vert_kernel.cu # noqa +#ifdef MMCV_USE_PARROTS +#include "parrots_cuda_helper.hpp" +#else +#include "pytorch_cuda_helper.hpp" +#endif + +#define MAX_NUM_VERT_IDX 9 +#define INTERSECTION_OFFSET 8 +#define EPSILON 1e-8 + +inline int opt_n_thread(int work_size) { + const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); + return max(min(1 << pow_2, THREADS_PER_BLOCK), 1); +} + +/* +compare normalized vertices (vertices around (0,0)) +if vertex1 < vertex2 return true. +order: minimum at x-aixs, become larger in anti-clockwise direction +*/ +__device__ bool compare_vertices(float x1, float y1, float x2, float y2) { + if (fabs(x1 - x2) < EPSILON && fabs(y2 - y1) < EPSILON) + return false; // if equal, return false + + if (y1 > 0 && y2 < 0) return true; + if (y1 < 0 && y2 > 0) return false; + + float n1 = x1 * x1 + y1 * y1 + EPSILON; + float n2 = x2 * x2 + y2 * y2 + EPSILON; + float diff = fabs(x1) * x1 / n1 - fabs(x2) * x2 / n2; + + if (y1 > 0 && y2 > 0) { + if (diff > EPSILON) + return true; + else + return false; + } + if (y1 < 0 && y2 < 0) { + if (diff < EPSILON) + return true; + else + return false; + } +} + +__global__ void diff_iou_rotated_sort_vertices_forward_cuda_kernel( + int b, int n, int m, const float *__restrict__ vertices, + const bool *__restrict__ mask, const int *__restrict__ num_valid, + int *__restrict__ idx) { + int batch_idx = blockIdx.x; + vertices += batch_idx * n * m * 2; + mask += batch_idx * n * m; + num_valid += batch_idx * n; + idx += batch_idx * n * MAX_NUM_VERT_IDX; + + int index = threadIdx.x; // index of polygon + int stride = blockDim.x; + for (int i = index; i < n; i += stride) { + int pad; // index of arbitrary invalid intersection point (not box corner!) + for (int j = INTERSECTION_OFFSET; j < m; ++j) { + if (!mask[i * m + j]) { + pad = j; + break; + } + } + if (num_valid[i] < 3) { + // not enough vertices, take an invalid intersection point + // (zero padding) + for (int j = 0; j < MAX_NUM_VERT_IDX; ++j) { + idx[i * MAX_NUM_VERT_IDX + j] = pad; + } + } else { + // sort the valid vertices + // note the number of valid vertices is known + // note: check that num_valid[i] < MAX_NUM_VERT_IDX + for (int j = 0; j < num_valid[i]; ++j) { + // initialize with a "big" value + float x_min = 1; + float y_min = -EPSILON; + int i_take = 0; + int i2; + float x2, y2; + if (j != 0) { + i2 = idx[i * MAX_NUM_VERT_IDX + j - 1]; + x2 = vertices[i * m * 2 + i2 * 2 + 0]; + y2 = vertices[i * m * 2 + i2 * 2 + 1]; + } + for (int k = 0; k < m; ++k) { + float x = vertices[i * m * 2 + k * 2 + 0]; + float y = vertices[i * m * 2 + k * 2 + 1]; + if (mask[i * m + k] && compare_vertices(x, y, x_min, y_min)) { + if ((j == 0) || (j != 0 && compare_vertices(x2, y2, x, y))) { + x_min = x; + y_min = y; + i_take = k; + } + } + } + idx[i * MAX_NUM_VERT_IDX + j] = i_take; + } + // duplicate the first idx + idx[i * MAX_NUM_VERT_IDX + num_valid[i]] = idx[i * MAX_NUM_VERT_IDX + 0]; + + // pad zeros + for (int j = num_valid[i] + 1; j < MAX_NUM_VERT_IDX; ++j) { + idx[i * MAX_NUM_VERT_IDX + j] = pad; + } + + // for corner case: the two boxes are exactly the same. + // in this case, idx would have duplicate elements, which makes the + // shoelace formula broken because of the definition, the duplicate + // elements only appear in the first 8 positions (they are "corners in + // box", not "intersection of edges") + if (num_valid[i] == 8) { + int counter = 0; + for (int j = 0; j < 4; ++j) { + int check = idx[i * MAX_NUM_VERT_IDX + j]; + for (int k = 4; k < INTERSECTION_OFFSET; ++k) { + if (idx[i * MAX_NUM_VERT_IDX + k] == check) counter++; + } + } + if (counter == 4) { + idx[i * MAX_NUM_VERT_IDX + 4] = idx[i * MAX_NUM_VERT_IDX + 0]; + for (int j = 5; j < MAX_NUM_VERT_IDX; ++j) { + idx[i * MAX_NUM_VERT_IDX + j] = pad; + } + } + } + + // TODO: still might need to cover some other corner cases :( + } + } +} diff --git a/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp b/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp index b92ad6791b..93b19d4b6b 100644 --- a/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp +++ b/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp @@ -1699,3 +1699,19 @@ void convex_giou_impl(const Tensor pointsets, const Tensor polygons, REGISTER_DEVICE_IMPL(convex_iou_impl, CUDA, convex_iou_cuda); REGISTER_DEVICE_IMPL(convex_giou_impl, CUDA, convex_giou_cuda); + +Tensor DiffIoURotatedSortVerticesCUDAKernelLauncher(Tensor vertices, + Tensor mask, + Tensor num_valid); + +Tensor diff_iou_rotated_sort_vertices_forward_cuda(Tensor vertices, Tensor mask, + Tensor num_valid) { + return DiffIoURotatedSortVerticesCUDAKernelLauncher(vertices, mask, + num_valid); +} + +Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices, Tensor mask, + Tensor num_valid); + +REGISTER_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl, CUDA, + diff_iou_rotated_sort_vertices_forward_cuda); diff --git a/mmcv/ops/csrc/pytorch/cuda/diff_iou_rotated_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/diff_iou_rotated_cuda.cu new file mode 100644 index 0000000000..62dbf5da35 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/cuda/diff_iou_rotated_cuda.cu @@ -0,0 +1,35 @@ +// Copyright (c) OpenMMLab. All rights reserved +// Adapted from +// https://github.com/lilanxiao/Rotated_IoU/cuda_op/sort_vert_kernel.cu # noqa +#include "diff_iou_rotated_cuda_kernel.cuh" +#include "pytorch_cpp_helper.hpp" +#include "pytorch_cuda_helper.hpp" + +at::Tensor DiffIoURotatedSortVerticesCUDAKernelLauncher(at::Tensor vertices, + at::Tensor mask, + at::Tensor num_valid) { + at::cuda::CUDAGuard device_guard(vertices.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + CHECK_CONTIGUOUS(vertices); + CHECK_CONTIGUOUS(mask); + CHECK_CONTIGUOUS(num_valid); + CHECK_CUDA(vertices); + CHECK_CUDA(mask); + CHECK_CUDA(num_valid); + + int b = vertices.size(0); + int n = vertices.size(1); + int m = vertices.size(2); + at::Tensor idx = + torch::zeros({b, n, MAX_NUM_VERT_IDX}, + at::device(vertices.device()).dtype(at::ScalarType::Int)); + + diff_iou_rotated_sort_vertices_forward_cuda_kernel<<>>( + b, n, m, vertices.data_ptr(), mask.data_ptr(), + num_valid.data_ptr(), idx.data_ptr()); + AT_CUDA_CHECK(cudaGetLastError()); + + return idx; +} diff --git a/mmcv/ops/csrc/pytorch/diff_iou_rotated.cpp b/mmcv/ops/csrc/pytorch/diff_iou_rotated.cpp new file mode 100644 index 0000000000..2361b7fbe5 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/diff_iou_rotated.cpp @@ -0,0 +1,14 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include "pytorch_cpp_helper.hpp" +#include "pytorch_device_registry.hpp" + +Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices, Tensor mask, + Tensor num_valid) { + return DISPATCH_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl, + vertices, mask, num_valid); +} + +Tensor diff_iou_rotated_sort_vertices_forward(Tensor vertices, Tensor mask, + Tensor num_valid) { + return diff_iou_rotated_sort_vertices_forward_impl(vertices, mask, num_valid); +} diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp index 01c84c948f..b53ef3fb10 100644 --- a/mmcv/ops/csrc/pytorch/pybind.cpp +++ b/mmcv/ops/csrc/pytorch/pybind.cpp @@ -400,6 +400,10 @@ void convex_iou(const Tensor pointsets, const Tensor polygons, Tensor ious); void convex_giou(const Tensor pointsets, const Tensor polygons, Tensor output); +at::Tensor diff_iou_rotated_sort_vertices_forward(at::Tensor vertices, + at::Tensor mask, + at::Tensor num_valid); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"), py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"), @@ -809,4 +813,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("polygons"), py::arg("ious")); m.def("convex_giou", &convex_giou, "convex_giou", py::arg("pointsets"), py::arg("polygons"), py::arg("output")); + m.def("diff_iou_rotated_sort_vertices_forward", + &diff_iou_rotated_sort_vertices_forward, + "diff_iou_rotated_sort_vertices_forward", py::arg("vertices"), + py::arg("mask"), py::arg("num_valid")); } diff --git a/mmcv/ops/diff_iou_rotated.py b/mmcv/ops/diff_iou_rotated.py new file mode 100644 index 0000000000..26bdbecf6e --- /dev/null +++ b/mmcv/ops/diff_iou_rotated.py @@ -0,0 +1,293 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Adapted from https://github.com/lilanxiao/Rotated_IoU/blob/master/box_intersection_2d.py # noqa +# Adapted from https://github.com/lilanxiao/Rotated_IoU/blob/master/oriented_iou_loss.py # noqa +import torch +from torch.autograd import Function + +from ..utils import ext_loader + +EPSILON = 1e-8 +ext_module = ext_loader.load_ext('_ext', + ['diff_iou_rotated_sort_vertices_forward']) + + +class SortVertices(Function): + + @staticmethod + def forward(ctx, vertices, mask, num_valid): + idx = ext_module.diff_iou_rotated_sort_vertices_forward( + vertices, mask, num_valid) + ctx.mark_non_differentiable(idx) + return idx + + @staticmethod + def backward(ctx, gradout): + return () + + +def box_intersection(corners1, corners2): + """Find intersection points of rectangles. + Convention: if two edges are collinear, there is no intersection point. + + Args: + corners1 (Tensor): (B, N, 4, 2) First batch of boxes. + corners2 (Tensor): (B, N, 4, 2) Second batch of boxes. + + Returns: + Tuple: + - Tensor: (B, N, 4, 4, 2) Intersections. + - Tensor: (B, N, 4, 4) Valid intersections mask. + """ + # build edges from corners + # B, N, 4, 4: Batch, Box, edge, point + line1 = torch.cat([corners1, corners1[:, :, [1, 2, 3, 0], :]], dim=3) + line2 = torch.cat([corners2, corners2[:, :, [1, 2, 3, 0], :]], dim=3) + # duplicate data to pair each edges from the boxes + # (B, N, 4, 4) -> (B, N, 4, 4, 4) : Batch, Box, edge1, edge2, point + line1_ext = line1.unsqueeze(3) + line2_ext = line2.unsqueeze(2) + x1, y1, x2, y2 = line1_ext.split([1, 1, 1, 1], dim=-1) + x3, y3, x4, y4 = line2_ext.split([1, 1, 1, 1], dim=-1) + # math: https://en.wikipedia.org/wiki/Line%E2%80%93line_intersection + numerator = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4) + denumerator_t = (x1 - x3) * (y3 - y4) - (y1 - y3) * (x3 - x4) + t = denumerator_t / numerator + t[numerator == .0] = -1. + mask_t = (t > 0) & (t < 1) # intersection on line segment 1 + denumerator_u = (x1 - x2) * (y1 - y3) - (y1 - y2) * (x1 - x3) + u = -denumerator_u / numerator + u[numerator == .0] = -1. + mask_u = (u > 0) & (u < 1) # intersection on line segment 2 + mask = mask_t * mask_u + # overwrite with EPSILON. otherwise numerically unstable + t = denumerator_t / (numerator + EPSILON) + intersections = torch.stack([x1 + t * (x2 - x1), y1 + t * (y2 - y1)], + dim=-1) + intersections = intersections * mask.float().unsqueeze(-1) + return intersections, mask + + +def box1_in_box2(corners1, corners2): + """Check if corners of box1 lie in box2. + Convention: if a corner is exactly on the edge of the other box, + it's also a valid point. + + Args: + corners1 (Tensor): (B, N, 4, 2) First batch of boxes. + corners2 (Tensor): (B, N, 4, 2) Second batch of boxes. + + Returns: + Tensor: (B, N, 4) Intersection. + """ + # a, b, c, d - 4 vertices of box2 + a = corners2[:, :, 0:1, :] # (B, N, 1, 2) + b = corners2[:, :, 1:2, :] # (B, N, 1, 2) + d = corners2[:, :, 3:4, :] # (B, N, 1, 2) + # ab, am, ad - vectors between corresponding vertices + ab = b - a # (B, N, 1, 2) + am = corners1 - a # (B, N, 4, 2) + ad = d - a # (B, N, 1, 2) + prod_ab = torch.sum(ab * am, dim=-1) # (B, N, 4) + norm_ab = torch.sum(ab * ab, dim=-1) # (B, N, 1) + prod_ad = torch.sum(ad * am, dim=-1) # (B, N, 4) + norm_ad = torch.sum(ad * ad, dim=-1) # (B, N, 1) + # NOTE: the expression looks ugly but is stable if the two boxes + # are exactly the same also stable with different scale of bboxes + cond1 = (prod_ab / norm_ab > -1e-6) * (prod_ab / norm_ab < 1 + 1e-6 + ) # (B, N, 4) + cond2 = (prod_ad / norm_ad > -1e-6) * (prod_ad / norm_ad < 1 + 1e-6 + ) # (B, N, 4) + return cond1 * cond2 + + +def box_in_box(corners1, corners2): + """Check if corners of two boxes lie in each other. + + Args: + corners1 (Tensor): (B, N, 4, 2) First batch of boxes. + corners2 (Tensor): (B, N, 4, 2) Second batch of boxes. + + Returns: + Tuple: + - Tensor: (B, N, 4) True if i-th corner of box1 is in box2. + - Tensor: (B, N, 4) True if i-th corner of box2 is in box1. + """ + c1_in_2 = box1_in_box2(corners1, corners2) + c2_in_1 = box1_in_box2(corners2, corners1) + return c1_in_2, c2_in_1 + + +def build_vertices(corners1, corners2, c1_in_2, c2_in_1, intersections, + valid_mask): + """Find vertices of intersection area. + + Args: + corners1 (Tensor): (B, N, 4, 2) First batch of boxes. + corners2 (Tensor): (B, N, 4, 2) Second batch of boxes. + c1_in_2 (Tensor): (B, N, 4) True if i-th corner of box1 is in box2. + c2_in_1 (Tensor): (B, N, 4) True if i-th corner of box2 is in box1. + intersections (Tensor): (B, N, 4, 4, 2) Intersections. + valid_mask (Tensor): (B, N, 4, 4) Valid intersections mask. + + Returns: + Tuple: + - Tensor: (B, N, 24, 2) Vertices of intersection area; + only some elements are valid. + - Tensor: (B, N, 24) Mask of valid elements in vertices. + """ + # NOTE: inter has elements equals zero and has zeros gradient + # (masked by multiplying with 0); can be used as trick + B = corners1.size()[0] + N = corners1.size()[1] + # (B, N, 4 + 4 + 16, 2) + vertices = torch.cat( + [corners1, corners2, + intersections.view([B, N, -1, 2])], dim=2) + # Bool (B, N, 4 + 4 + 16) + mask = torch.cat([c1_in_2, c2_in_1, valid_mask.view([B, N, -1])], dim=2) + return vertices, mask + + +def sort_indices(vertices, mask): + """Sort indices. + Note: + why 9? the polygon has maximal 8 vertices. + +1 to duplicate the first element. + the index should have following structure: + (A, B, C, ... , A, X, X, X) + and X indicates the index of arbitrary elements in the last + 16 (intersections not corners) with value 0 and mask False. + (cause they have zero value and zero gradient) + + Args: + vertices (Tensor): (B, N, 24, 2) Box vertices. + mask (Tensor): (B, N, 24) Mask. + + Returns: + Tensor: (B, N, 9) Sorted indices. + + """ + num_valid = torch.sum(mask.int(), dim=2).int() # (B, N) + mean = torch.sum( + vertices * mask.float().unsqueeze(-1), dim=2, + keepdim=True) / num_valid.unsqueeze(-1).unsqueeze(-1) + vertices_normalized = vertices - mean # normalization makes sorting easier + return SortVertices.apply(vertices_normalized, mask, num_valid).long() + + +def calculate_area(idx_sorted, vertices): + """Calculate area of intersection. + + Args: + idx_sorted (Tensor): (B, N, 9) Sorted vertex ids. + vertices (Tensor): (B, N, 24, 2) Vertices. + + Returns: + Tuple: + - Tensor (B, N): Area of intersection. + - Tensor: (B, N, 9, 2) Vertices of polygon with zero padding. + """ + idx_ext = idx_sorted.unsqueeze(-1).repeat([1, 1, 1, 2]) + selected = torch.gather(vertices, 2, idx_ext) + total = selected[:, :, 0:-1, 0] * selected[:, :, 1:, 1] \ + - selected[:, :, 0:-1, 1] * selected[:, :, 1:, 0] + total = torch.sum(total, dim=2) + area = torch.abs(total) / 2 + return area, selected + + +def oriented_box_intersection_2d(corners1, corners2): + """Calculate intersection area of 2d rotated boxes. + + Args: + corners1 (Tensor): (B, N, 4, 2) First batch of boxes. + corners2 (Tensor): (B, N, 4, 2) Second batch of boxes. + + Returns: + Tuple: + - Tensor (B, N): Area of intersection. + - Tensor (B, N, 9, 2): Vertices of polygon with zero padding. + """ + intersections, valid_mask = box_intersection(corners1, corners2) + c12, c21 = box_in_box(corners1, corners2) + vertices, mask = build_vertices(corners1, corners2, c12, c21, + intersections, valid_mask) + sorted_indices = sort_indices(vertices, mask) + return calculate_area(sorted_indices, vertices) + + +def box2corners(box): + """Convert rotated 2d box coordinate to corners. + + Args: + box (Tensor): (B, N, 5) with x, y, w, h, alpha. + + Returns: + Tensor: (B, N, 4, 2) Corners. + """ + B = box.size()[0] + x, y, w, h, alpha = box.split([1, 1, 1, 1, 1], dim=-1) + x4 = torch.FloatTensor([0.5, -0.5, -0.5, 0.5]).to(box.device) + x4 = x4 * w # (B, N, 4) + y4 = torch.FloatTensor([0.5, 0.5, -0.5, -0.5]).to(box.device) + y4 = y4 * h # (B, N, 4) + corners = torch.stack([x4, y4], dim=-1) # (B, N, 4, 2) + sin = torch.sin(alpha) + cos = torch.cos(alpha) + row1 = torch.cat([cos, sin], dim=-1) + row2 = torch.cat([-sin, cos], dim=-1) # (B, N, 2) + rot_T = torch.stack([row1, row2], dim=-2) # (B, N, 2, 2) + rotated = torch.bmm(corners.view([-1, 4, 2]), rot_T.view([-1, 2, 2])) + rotated = rotated.view([B, -1, 4, 2]) # (B * N, 4, 2) -> (B, N, 4, 2) + rotated[..., 0] += x + rotated[..., 1] += y + return rotated + + +def diff_iou_rotated_2d(box1, box2): + """Calculate differentiable iou of rotated 2d boxes. + + Args: + box1 (Tensor): (B, N, 5) First box. + box2 (Tensor): (B, N, 5) Second box. + + Returns: + Tensor: (B, N) IoU. + """ + corners1 = box2corners(box1) + corners2 = box2corners(box2) + intersection, _ = oriented_box_intersection_2d(corners1, + corners2) # (B, N) + area1 = box1[:, :, 2] * box1[:, :, 3] + area2 = box2[:, :, 2] * box2[:, :, 3] + union = area1 + area2 - intersection + iou = intersection / union + return iou + + +def diff_iou_rotated_3d(box3d1, box3d2): + """Calculate differentiable iou of rotated 3d boxes. + + Args: + box3d1 (Tensor): (B, N, 3+3+1) First box (x,y,z,w,h,l,alpha). + box3d2 (Tensor): (B, N, 3+3+1) Second box (x,y,z,w,h,l,alpha). + + Returns: + Tensor: (B, N) IoU. + """ + box1 = box3d1[..., [0, 1, 3, 4, 6]] # 2d box + box2 = box3d2[..., [0, 1, 3, 4, 6]] + corners1 = box2corners(box1) + corners2 = box2corners(box2) + intersection, _ = oriented_box_intersection_2d(corners1, corners2) + zmax1 = box3d1[..., 2] + box3d1[..., 5] * 0.5 + zmin1 = box3d1[..., 2] - box3d1[..., 5] * 0.5 + zmax2 = box3d2[..., 2] + box3d2[..., 5] * 0.5 + zmin2 = box3d2[..., 2] - box3d2[..., 5] * 0.5 + z_overlap = (torch.min(zmax1, zmax2) - + torch.max(zmin1, zmin2)).clamp_(min=0.) + intersection_3d = intersection * z_overlap + volume1 = box3d1[..., 3] * box3d1[..., 4] * box3d1[..., 5] + volume2 = box3d2[..., 3] * box3d2[..., 4] * box3d2[..., 5] + union_3d = volume1 + volume2 - intersection_3d + return intersection_3d / union_3d diff --git a/tests/test_ops/test_diff_iou_rotated.py b/tests/test_ops/test_diff_iou_rotated.py new file mode 100644 index 0000000000..01e05551b0 --- /dev/null +++ b/tests/test_ops/test_diff_iou_rotated.py @@ -0,0 +1,49 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import pytest +import torch + +from mmcv.ops import diff_iou_rotated_2d, diff_iou_rotated_3d + + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason='requires CUDA support') +def test_diff_iou_rotated_2d(): + np_boxes1 = np.asarray([[[0.5, 0.5, 1., 1., .0], [0.5, 0.5, 1., 1., .0], + [0.5, 0.5, 1., 1., .0], [0.5, 0.5, 1., 1., .0], + [0.5, 0.5, 1., 1., .0]]], + dtype=np.float32) + np_boxes2 = np.asarray( + [[[0.5, 0.5, 1., 1., .0], [0.5, 0.5, 1., 1., np.pi / 2], + [0.5, 0.5, 1., 1., np.pi / 4], [1., 1., 1., 1., .0], + [1.5, 1.5, 1., 1., .0]]], + dtype=np.float32) + + boxes1 = torch.from_numpy(np_boxes1).cuda() + boxes2 = torch.from_numpy(np_boxes2).cuda() + + np_expect_ious = np.asarray([[1., 1., .7071, 1 / 7, .0]]) + ious = diff_iou_rotated_2d(boxes1, boxes2) + assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason='requires CUDA support') +def test_diff_iou_rotated_3d(): + np_boxes1 = np.asarray( + [[[.5, .5, .5, 1., 1., 1., .0], [.5, .5, .5, 1., 1., 1., .0], + [.5, .5, .5, 1., 1., 1., .0], [.5, .5, .5, 1., 1., 1., .0], + [.5, .5, .5, 1., 1., 1., .0]]], + dtype=np.float32) + np_boxes2 = np.asarray( + [[[.5, .5, .5, 1., 1., 1., .0], [.5, .5, .5, 1., 1., 2., np.pi / 2], + [.5, .5, .5, 1., 1., 1., np.pi / 4], [1., 1., 1., 1., 1., 1., .0], + [-1.5, -1.5, -1.5, 2.5, 2.5, 2.5, .0]]], + dtype=np.float32) + + boxes1 = torch.from_numpy(np_boxes1).cuda() + boxes2 = torch.from_numpy(np_boxes2).cuda() + + np_expect_ious = np.asarray([[1., .5, .7071, 1 / 15, .0]]) + ious = diff_iou_rotated_3d(boxes1, boxes2) + assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)