Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] Differentiable rotated IoU #1854

Merged
merged 13 commits into from Apr 15, 2022
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/en/understand_mmcv/ops.md
Expand Up @@ -13,6 +13,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
- CornerPool
- Deformable Convolution v1/v2
- Deformable RoIPool
- DiffIoURotated
- DynamicScatter
- GatherPoints
- FurthestPointSample
Expand Down
1 change: 1 addition & 0 deletions docs/zh_cn/understand_mmcv/ops.md
Expand Up @@ -13,6 +13,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
- CornerPool
- Deformable Convolution v1/v2
- Deformable RoIPool
- DiffIoURotated
- DynamicScatter
- GatherPoints
- FurthestPointSample
Expand Down
3 changes: 2 additions & 1 deletion mmcv/ops/__init__.py
Expand Up @@ -18,6 +18,7 @@
from .deprecated_wrappers import ConvTranspose2d_deprecated as ConvTranspose2d
from .deprecated_wrappers import Linear_deprecated as Linear
from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d
from .diff_iou_rotated import diff_iou_rotated_2d, diff_iou_rotated_3d
from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss,
sigmoid_focal_loss, softmax_focal_loss)
from .furthest_point_sample import (furthest_point_sample,
Expand Down Expand Up @@ -96,5 +97,5 @@
'SparseMaxPool2d', 'SparseMaxPool3d', 'SparseConvTensor', 'scatter_nd',
'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all',
'points_in_polygons', 'min_area_polygons', 'active_rotated_filter',
'convex_iou', 'convex_giou'
'convex_iou', 'convex_giou', 'diff_iou_rotated_2d', 'diff_iou_rotated_3d'
]
136 changes: 136 additions & 0 deletions mmcv/ops/csrc/common/cuda/diff_iou_rotated_cuda_kernel.cuh
@@ -0,0 +1,136 @@
// Copyright (c) OpenMMLab. All rights reserved
// Adapted from
// https://github.com/lilanxiao/Rotated_IoU/cuda_op/sort_vert_kernel.cu # noqa
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif

#define MAX_NUM_VERT_IDX 9
#define INTERSECTION_OFFSET 8
#define EPSILON 1e-8
grimoire marked this conversation as resolved.
Show resolved Hide resolved

inline int opt_n_thread(int work_size) {
const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
return max(min(1 << pow_2, THREADS_PER_BLOCK), 1);
}

/*
compare normalized vertices (vertices around (0,0))
if vertex1 < vertex2 return true.
order: minimum at x-aixs, become larger in anti-clockwise direction
*/
__device__ bool compare_vertices(float x1, float y1, float x2, float y2) {
if (fabs(x1 - x2) < EPSILON && fabs(y2 - y1) < EPSILON)
return false; // if equal, return false

if (y1 > 0 && y2 < 0) return true;
if (y1 < 0 && y2 > 0) return false;

float n1 = x1 * x1 + y1 * y1 + EPSILON;
float n2 = x2 * x2 + y2 * y2 + EPSILON;
float diff = fabs(x1) * x1 / n1 - fabs(x2) * x2 / n2;

if (y1 > 0 && y2 > 0) {
if (diff > EPSILON)
return true;
else
return false;
}
if (y1 < 0 && y2 < 0) {
if (diff < EPSILON)
return true;
else
return false;
}
}

__global__ void diff_iou_rotated_sort_vertices_forward_cuda_kernel(
int b, int n, int m, const float *__restrict__ vertices,
const bool *__restrict__ mask, const int *__restrict__ num_valid,
int *__restrict__ idx) {
int batch_idx = blockIdx.x;
vertices += batch_idx * n * m * 2;
mask += batch_idx * n * m;
num_valid += batch_idx * n;
idx += batch_idx * n * MAX_NUM_VERT_IDX;

int index = threadIdx.x; // index of polygon
int stride = blockDim.x;
for (int i = index; i < n; i += stride) {
int pad; // index of arbitrary invalid intersection point (not box corner!)
for (int j = INTERSECTION_OFFSET; j < m; ++j) {
if (!mask[i * m + j]) {
pad = j;
break;
}
}
if (num_valid[i] < 3) {
// not enough vertices, take an invalid intersection point
// (zero padding)
for (int j = 0; j < MAX_NUM_VERT_IDX; ++j) {
idx[i * MAX_NUM_VERT_IDX + j] = pad;
}
} else {
// sort the valid vertices
// note the number of valid vertices is known
// note: check that num_valid[i] < MAX_NUM_VERT_IDX
for (int j = 0; j < num_valid[i]; ++j) {
// initialize with a "big" value
float x_min = 1;
float y_min = -EPSILON;
int i_take = 0;
int i2;
float x2, y2;
if (j != 0) {
i2 = idx[i * MAX_NUM_VERT_IDX + j - 1];
x2 = vertices[i * m * 2 + i2 * 2 + 0];
y2 = vertices[i * m * 2 + i2 * 2 + 1];
}
for (int k = 0; k < m; ++k) {
float x = vertices[i * m * 2 + k * 2 + 0];
float y = vertices[i * m * 2 + k * 2 + 1];
if (mask[i * m + k] && compare_vertices(x, y, x_min, y_min)) {
if ((j == 0) || (j != 0 && compare_vertices(x2, y2, x, y))) {
x_min = x;
y_min = y;
i_take = k;
}
}
}
idx[i * MAX_NUM_VERT_IDX + j] = i_take;
}
// duplicate the first idx
idx[i * MAX_NUM_VERT_IDX + num_valid[i]] = idx[i * MAX_NUM_VERT_IDX + 0];

// pad zeros
for (int j = num_valid[i] + 1; j < MAX_NUM_VERT_IDX; ++j) {
idx[i * MAX_NUM_VERT_IDX + j] = pad;
}

// for corner case: the two boxes are exactly the same.
// in this case, idx would have duplicate elements, which makes the
// shoelace formula broken because of the definition, the duplicate
// elements only appear in the first 8 positions (they are "corners in
// box", not "intersection of edges")
if (num_valid[i] == 8) {
int counter = 0;
for (int j = 0; j < 4; ++j) {
int check = idx[i * MAX_NUM_VERT_IDX + j];
for (int k = 4; k < INTERSECTION_OFFSET; ++k) {
if (idx[i * MAX_NUM_VERT_IDX + k] == check) counter++;
}
}
if (counter == 4) {
idx[i * MAX_NUM_VERT_IDX + 4] = idx[i * MAX_NUM_VERT_IDX + 0];
for (int j = 5; j < MAX_NUM_VERT_IDX; ++j) {
idx[i * MAX_NUM_VERT_IDX + j] = pad;
}
}
}

// TODO: still might need to cover some other corner cases :(
}
}
}
16 changes: 16 additions & 0 deletions mmcv/ops/csrc/pytorch/cuda/cudabind.cpp
Expand Up @@ -1699,3 +1699,19 @@ void convex_giou_impl(const Tensor pointsets, const Tensor polygons,

REGISTER_DEVICE_IMPL(convex_iou_impl, CUDA, convex_iou_cuda);
REGISTER_DEVICE_IMPL(convex_giou_impl, CUDA, convex_giou_cuda);

Tensor DiffIoURotatedSortVerticesCUDAKernelLauncher(Tensor vertices,
Tensor mask,
Tensor num_valid);

Tensor diff_iou_rotated_sort_vertices_forward_cuda(Tensor vertices, Tensor mask,
Tensor num_valid) {
return DiffIoURotatedSortVerticesCUDAKernelLauncher(vertices, mask,
num_valid);
}

Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices, Tensor mask,
Tensor num_valid);

REGISTER_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl, CUDA,
diff_iou_rotated_sort_vertices_forward_cuda);
35 changes: 35 additions & 0 deletions mmcv/ops/csrc/pytorch/cuda/diff_iou_rotated_cuda.cu
@@ -0,0 +1,35 @@
// Copyright (c) OpenMMLab. All rights reserved
// Adapted from
// https://github.com/lilanxiao/Rotated_IoU/cuda_op/sort_vert_kernel.cu # noqa
#include "diff_iou_rotated_cuda_kernel.cuh"
#include "pytorch_cpp_helper.hpp"
#include "pytorch_cuda_helper.hpp"

at::Tensor DiffIoURotatedSortVerticesCUDAKernelLauncher(at::Tensor vertices,
at::Tensor mask,
at::Tensor num_valid) {
at::cuda::CUDAGuard device_guard(vertices.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

CHECK_CONTIGUOUS(vertices);
CHECK_CONTIGUOUS(mask);
CHECK_CONTIGUOUS(num_valid);
CHECK_CUDA(vertices);
CHECK_CUDA(mask);
CHECK_CUDA(num_valid);

int b = vertices.size(0);
int n = vertices.size(1);
int m = vertices.size(2);
at::Tensor idx =
torch::zeros({b, n, MAX_NUM_VERT_IDX},
at::device(vertices.device()).dtype(at::ScalarType::Int));

diff_iou_rotated_sort_vertices_forward_cuda_kernel<<<b, opt_n_thread(n), 0,
stream>>>(
b, n, m, vertices.data_ptr<float>(), mask.data_ptr<bool>(),
num_valid.data_ptr<int>(), idx.data_ptr<int>());
AT_CUDA_CHECK(cudaGetLastError());

return idx;
}
14 changes: 14 additions & 0 deletions mmcv/ops/csrc/pytorch/diff_iou_rotated.cpp
@@ -0,0 +1,14 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"

Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices, Tensor mask,
Tensor num_valid) {
return DISPATCH_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl,
vertices, mask, num_valid);
}

Tensor diff_iou_rotated_sort_vertices_forward(Tensor vertices, Tensor mask,
Tensor num_valid) {
return diff_iou_rotated_sort_vertices_forward_impl(vertices, mask, num_valid);
}
8 changes: 8 additions & 0 deletions mmcv/ops/csrc/pytorch/pybind.cpp
Expand Up @@ -400,6 +400,10 @@ void convex_iou(const Tensor pointsets, const Tensor polygons, Tensor ious);

void convex_giou(const Tensor pointsets, const Tensor polygons, Tensor output);

at::Tensor diff_iou_rotated_sort_vertices_forward(at::Tensor vertices,
at::Tensor mask,
at::Tensor num_valid);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
Expand Down Expand Up @@ -809,4 +813,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("polygons"), py::arg("ious"));
m.def("convex_giou", &convex_giou, "convex_giou", py::arg("pointsets"),
py::arg("polygons"), py::arg("output"));
m.def("diff_iou_rotated_sort_vertices_forward",
&diff_iou_rotated_sort_vertices_forward,
"diff_iou_rotated_sort_vertices_forward", py::arg("vertices"),
py::arg("mask"), py::arg("num_valid"));
}