diff --git a/mmcv/ops/__init__.py b/mmcv/ops/__init__.py index 338e32b652..b5a06c7614 100644 --- a/mmcv/ops/__init__.py +++ b/mmcv/ops/__init__.py @@ -24,6 +24,7 @@ from .gather_points import gather_points from .info import (get_compiler_version, get_compiling_cuda_version, get_onnxruntime_op_path) +from .iou3d import boxes_iou_bev, nms_bev, nms_normal_bev from .knn import knn from .masked_conv import MaskedConv2d, masked_conv2d from .modulated_deform_conv import (ModulatedDeformConv2d, @@ -53,82 +54,27 @@ from .voxelize import Voxelization, voxelization __all__ = [ - 'bbox_overlaps', - 'CARAFE', - 'CARAFENaive', - 'CARAFEPack', - 'carafe', - 'carafe_naive', - 'CornerPool', - 'DeformConv2d', - 'DeformConv2dPack', - 'deform_conv2d', - 'DeformRoIPool', - 'DeformRoIPoolPack', - 'ModulatedDeformRoIPoolPack', - 'deform_roi_pool', - 'SigmoidFocalLoss', - 'SoftmaxFocalLoss', - 'sigmoid_focal_loss', - 'softmax_focal_loss', - 'get_compiler_version', - 'get_compiling_cuda_version', - 'get_onnxruntime_op_path', - 'MaskedConv2d', - 'masked_conv2d', - 'ModulatedDeformConv2d', - 'ModulatedDeformConv2dPack', - 'modulated_deform_conv2d', - 'batched_nms', - 'nms', - 'soft_nms', - 'nms_match', - 'RoIAlign', - 'roi_align', - 'RoIPool', - 'roi_pool', - 'SyncBatchNorm', - 'Conv2d', - 'ConvTranspose2d', - 'Linear', - 'MaxPool2d', - 'CrissCrossAttention', - 'PSAMask', - 'point_sample', - 'rel_roi_point_to_rel_img_point', - 'SimpleRoIAlign', - 'SAConv2d', - 'TINShift', - 'tin_shift', - 'assign_score_withk', - 'box_iou_rotated', - 'RoIPointPool3d', - 'nms_rotated', - 'knn', - 'ball_query', - 'upfirdn2d', - 'FusedBiasLeakyReLU', - 'fused_bias_leakyrelu', - 'RoIAlignRotated', - 'roi_align_rotated', - 'pixel_group', - 'contour_expand', - 'three_nn', - 'three_interpolate', - 'MultiScaleDeformableAttention', - 'Voxelization', - 'voxelization', - 'dynamic_scatter', - 'DynamicScatter', - 'BorderAlign', - 'border_align', - 'gather_points', - 'furthest_point_sample', - 'furthest_point_sample_with_dist', - 'PointsSampler', - 'Correlation', - 'RoIAwarePool3d', - 'points_in_boxes_part', - 'points_in_boxes_cpu', - 'points_in_boxes_all', + 'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe', + 'carafe_naive', 'CornerPool', 'DeformConv2d', 'DeformConv2dPack', + 'deform_conv2d', 'DeformRoIPool', 'DeformRoIPoolPack', + 'ModulatedDeformRoIPoolPack', 'deform_roi_pool', 'SigmoidFocalLoss', + 'SoftmaxFocalLoss', 'sigmoid_focal_loss', 'softmax_focal_loss', + 'get_compiler_version', 'get_compiling_cuda_version', + 'get_onnxruntime_op_path', 'MaskedConv2d', 'masked_conv2d', + 'ModulatedDeformConv2d', 'ModulatedDeformConv2dPack', + 'modulated_deform_conv2d', 'batched_nms', 'nms', 'soft_nms', 'nms_match', + 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d', + 'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask', + 'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign', + 'SAConv2d', 'TINShift', 'tin_shift', 'assign_score_withk', + 'box_iou_rotated', 'RoIPointPool3d', 'nms_rotated', 'knn', 'ball_query', + 'upfirdn2d', 'FusedBiasLeakyReLU', 'boxes_iou_bev', 'nms_bev', + 'nms_normal_bev', 'fused_bias_leakyrelu', 'RoIAlignRotated', + 'roi_align_rotated', 'pixel_group', 'contour_expand', 'three_nn', + 'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign', + 'border_align', 'gather_points', 'furthest_point_sample', + 'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation', + 'Voxelization', 'voxelization', 'dynamic_scatter', 'DynamicScatter', + 
'RoIAwarePool3d', 'points_in_boxes_part', 'points_in_boxes_cpu', + 'points_in_boxes_all' ] diff --git a/mmcv/ops/csrc/common/cuda/iou3d_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/iou3d_cuda_kernel.cuh new file mode 100644 index 0000000000..4e261cbd0c --- /dev/null +++ b/mmcv/ops/csrc/common/cuda/iou3d_cuda_kernel.cuh @@ -0,0 +1,369 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef IOU3D_CUDA_KERNEL_CUH +#define IOU3D_CUDA_KERNEL_CUH + +#ifdef MMCV_USE_PARROTS +#include "parrots_cuda_helper.hpp" +#else +#include "pytorch_cuda_helper.hpp" +#endif + +const int THREADS_PER_BLOCK_IOU3D = 16; +const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8; +__device__ const float EPS = 1e-8; + +struct Point { + float x, y; + __device__ Point() {} + __device__ Point(double _x, double _y) { x = _x, y = _y; } + + __device__ void set(float _x, float _y) { + x = _x; + y = _y; + } + + __device__ Point operator+(const Point &b) const { + return Point(x + b.x, y + b.y); + } + + __device__ Point operator-(const Point &b) const { + return Point(x - b.x, y - b.y); + } +}; + +__device__ inline float cross(const Point &a, const Point &b) { + return a.x * b.y - a.y * b.x; +} + +__device__ inline float cross(const Point &p1, const Point &p2, + const Point &p0) { + return (p1.x - p0.x) * (p2.y - p0.y) - (p2.x - p0.x) * (p1.y - p0.y); +} + +__device__ int check_rect_cross(const Point &p1, const Point &p2, + const Point &q1, const Point &q2) { + int ret = min(p1.x, p2.x) <= max(q1.x, q2.x) && + min(q1.x, q2.x) <= max(p1.x, p2.x) && + min(p1.y, p2.y) <= max(q1.y, q2.y) && + min(q1.y, q2.y) <= max(p1.y, p2.y); + return ret; +} + +__device__ inline int check_in_box2d(const float *box, const Point &p) { + // params: box (5) [x1, y1, x2, y2, angle] + const float MARGIN = 1e-5; + + float center_x = (box[0] + box[2]) / 2; + float center_y = (box[1] + box[3]) / 2; + float angle_cos = cos(-box[4]), + angle_sin = + sin(-box[4]); // rotate the point in the opposite direction of box + float rot_x = + (p.x - center_x) * angle_cos - (p.y - center_y) * angle_sin + center_x; + float rot_y = + (p.x - center_x) * angle_sin + (p.y - center_y) * angle_cos + center_y; + + return (rot_x > box[0] - MARGIN && rot_x < box[2] + MARGIN && + rot_y > box[1] - MARGIN && rot_y < box[3] + MARGIN); +} + +__device__ inline int intersection(const Point &p1, const Point &p0, + const Point &q1, const Point &q0, + Point &ans_point) { + // fast exclusion + if (check_rect_cross(p0, p1, q0, q1) == 0) return 0; + + // check cross standing + float s1 = cross(q0, p1, p0); + float s2 = cross(p1, q1, p0); + float s3 = cross(p0, q1, q0); + float s4 = cross(q1, p1, q0); + + if (!(s1 * s2 > 0 && s3 * s4 > 0)) return 0; + + // calculate intersection of two lines + float s5 = cross(q1, p1, p0); + if (fabs(s5 - s1) > EPS) { + ans_point.x = (s5 * q0.x - s1 * q1.x) / (s5 - s1); + ans_point.y = (s5 * q0.y - s1 * q1.y) / (s5 - s1); + + } else { + float a0 = p0.y - p1.y, b0 = p1.x - p0.x, c0 = p0.x * p1.y - p1.x * p0.y; + float a1 = q0.y - q1.y, b1 = q1.x - q0.x, c1 = q0.x * q1.y - q1.x * q0.y; + float D = a0 * b1 - a1 * b0; + + ans_point.x = (b0 * c1 - b1 * c0) / D; + ans_point.y = (a1 * c0 - a0 * c1) / D; + } + + return 1; +} + +__device__ inline void rotate_around_center(const Point ¢er, + const float angle_cos, + const float angle_sin, Point &p) { + float new_x = + (p.x - center.x) * angle_cos - (p.y - center.y) * angle_sin + center.x; + float new_y = + (p.x - center.x) * angle_sin + (p.y - center.y) * angle_cos + center.y; + p.set(new_x, new_y); +} + 
+__device__ inline int point_cmp(const Point &a, const Point &b, + const Point ¢er) { + return atan2(a.y - center.y, a.x - center.x) > + atan2(b.y - center.y, b.x - center.x); +} + +__device__ inline float box_overlap(const float *box_a, const float *box_b) { + // params: box_a (5) [x1, y1, x2, y2, angle] + // params: box_b (5) [x1, y1, x2, y2, angle] + + float a_x1 = box_a[0], a_y1 = box_a[1], a_x2 = box_a[2], a_y2 = box_a[3], + a_angle = box_a[4]; + float b_x1 = box_b[0], b_y1 = box_b[1], b_x2 = box_b[2], b_y2 = box_b[3], + b_angle = box_b[4]; + + Point center_a((a_x1 + a_x2) / 2, (a_y1 + a_y2) / 2); + Point center_b((b_x1 + b_x2) / 2, (b_y1 + b_y2) / 2); + + Point box_a_corners[5]; + box_a_corners[0].set(a_x1, a_y1); + box_a_corners[1].set(a_x2, a_y1); + box_a_corners[2].set(a_x2, a_y2); + box_a_corners[3].set(a_x1, a_y2); + + Point box_b_corners[5]; + box_b_corners[0].set(b_x1, b_y1); + box_b_corners[1].set(b_x2, b_y1); + box_b_corners[2].set(b_x2, b_y2); + box_b_corners[3].set(b_x1, b_y2); + + // get oriented corners + float a_angle_cos = cos(a_angle), a_angle_sin = sin(a_angle); + float b_angle_cos = cos(b_angle), b_angle_sin = sin(b_angle); + + for (int k = 0; k < 4; k++) { + rotate_around_center(center_a, a_angle_cos, a_angle_sin, box_a_corners[k]); + rotate_around_center(center_b, b_angle_cos, b_angle_sin, box_b_corners[k]); + } + + box_a_corners[4] = box_a_corners[0]; + box_b_corners[4] = box_b_corners[0]; + + // get intersection of lines + Point cross_points[16]; + Point poly_center; + int cnt = 0, flag = 0; + + poly_center.set(0, 0); + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) { + flag = intersection(box_a_corners[i + 1], box_a_corners[i], + box_b_corners[j + 1], box_b_corners[j], + cross_points[cnt]); + if (flag) { + poly_center = poly_center + cross_points[cnt]; + cnt++; + } + } + } + + // check corners + for (int k = 0; k < 4; k++) { + if (check_in_box2d(box_a, box_b_corners[k])) { + poly_center = poly_center + box_b_corners[k]; + cross_points[cnt] = box_b_corners[k]; + cnt++; + } + if (check_in_box2d(box_b, box_a_corners[k])) { + poly_center = poly_center + box_a_corners[k]; + cross_points[cnt] = box_a_corners[k]; + cnt++; + } + } + + poly_center.x /= cnt; + poly_center.y /= cnt; + + // sort the points of polygon + Point temp; + for (int j = 0; j < cnt - 1; j++) { + for (int i = 0; i < cnt - j - 1; i++) { + if (point_cmp(cross_points[i], cross_points[i + 1], poly_center)) { + temp = cross_points[i]; + cross_points[i] = cross_points[i + 1]; + cross_points[i + 1] = temp; + } + } + } + + // get the overlap areas + float area = 0; + for (int k = 0; k < cnt - 1; k++) { + area += cross(cross_points[k] - cross_points[0], + cross_points[k + 1] - cross_points[0]); + } + + return fabs(area) / 2.0; +} + +__device__ inline float iou_bev(const float *box_a, const float *box_b) { + // params: box_a (5) [x1, y1, x2, y2, angle] + // params: box_b (5) [x1, y1, x2, y2, angle] + float sa = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1]); + float sb = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1]); + float s_overlap = box_overlap(box_a, box_b); + return s_overlap / fmaxf(sa + sb - s_overlap, EPS); +} + +__global__ void iou3d_boxes_overlap_bev_forward_cuda_kernel( + const int num_a, const float *boxes_a, const int num_b, + const float *boxes_b, float *ans_overlap) { + const int a_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y; + const int b_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x; + + if (a_idx >= num_a || b_idx >= num_b) { + return; + } + const float *cur_box_a = 
boxes_a + a_idx * 5; + const float *cur_box_b = boxes_b + b_idx * 5; + float s_overlap = box_overlap(cur_box_a, cur_box_b); + ans_overlap[a_idx * num_b + b_idx] = s_overlap; +} + +__global__ void iou3d_boxes_iou_bev_forward_cuda_kernel(const int num_a, + const float *boxes_a, + const int num_b, + const float *boxes_b, + float *ans_iou) { + const int a_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y; + const int b_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x; + + if (a_idx >= num_a || b_idx >= num_b) { + return; + } + + const float *cur_box_a = boxes_a + a_idx * 5; + const float *cur_box_b = boxes_b + b_idx * 5; + float cur_iou_bev = iou_bev(cur_box_a, cur_box_b); + ans_iou[a_idx * num_b + b_idx] = cur_iou_bev; +} + +__global__ void nms_forward_cuda_kernel(const int boxes_num, + const float nms_overlap_thresh, + const float *boxes, + unsigned long long *mask) { + // params: boxes (N, 5) [x1, y1, x2, y2, ry] + // params: mask (N, N/THREADS_PER_BLOCK_NMS) + + const int row_start = blockIdx.y; + const int col_start = blockIdx.x; + + // if (row_start > col_start) return; + + const int row_size = fminf(boxes_num - row_start * THREADS_PER_BLOCK_NMS, + THREADS_PER_BLOCK_NMS); + const int col_size = fminf(boxes_num - col_start * THREADS_PER_BLOCK_NMS, + THREADS_PER_BLOCK_NMS); + + __shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 5]; + + if (threadIdx.x < col_size) { + block_boxes[threadIdx.x * 5 + 0] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 0]; + block_boxes[threadIdx.x * 5 + 1] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 1]; + block_boxes[threadIdx.x * 5 + 2] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 2]; + block_boxes[threadIdx.x * 5 + 3] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 3]; + block_boxes[threadIdx.x * 5 + 4] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 4]; + } + __syncthreads(); + + if (threadIdx.x < row_size) { + const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x; + const float *cur_box = boxes + cur_box_idx * 5; + + int i = 0; + unsigned long long t = 0; + int start = 0; + if (row_start == col_start) { + start = threadIdx.x + 1; + } + for (i = start; i < col_size; i++) { + if (iou_bev(cur_box, block_boxes + i * 5) > nms_overlap_thresh) { + t |= 1ULL << i; + } + } + const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS); + mask[cur_box_idx * col_blocks + col_start] = t; + } +} + +__device__ inline float iou_normal(float const *const a, float const *const b) { + float left = fmaxf(a[0], b[0]), right = fminf(a[2], b[2]); + float top = fmaxf(a[1], b[1]), bottom = fminf(a[3], b[3]); + float width = fmaxf(right - left, 0.f), height = fmaxf(bottom - top, 0.f); + float interS = width * height; + float Sa = (a[2] - a[0]) * (a[3] - a[1]); + float Sb = (b[2] - b[0]) * (b[3] - b[1]); + return interS / fmaxf(Sa + Sb - interS, EPS); +} + +__global__ void nms_normal_forward_cuda_kernel(const int boxes_num, + const float nms_overlap_thresh, + const float *boxes, + unsigned long long *mask) { + // params: boxes (N, 5) [x1, y1, x2, y2, ry] + // params: mask (N, N/THREADS_PER_BLOCK_NMS) + + const int row_start = blockIdx.y; + const int col_start = blockIdx.x; + + // if (row_start > col_start) return; + + const int row_size = fminf(boxes_num - row_start * THREADS_PER_BLOCK_NMS, + THREADS_PER_BLOCK_NMS); + const int col_size = fminf(boxes_num - col_start * THREADS_PER_BLOCK_NMS, + THREADS_PER_BLOCK_NMS); + + __shared__ float 
block_boxes[THREADS_PER_BLOCK_NMS * 5]; + + if (threadIdx.x < col_size) { + block_boxes[threadIdx.x * 5 + 0] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 0]; + block_boxes[threadIdx.x * 5 + 1] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 1]; + block_boxes[threadIdx.x * 5 + 2] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 2]; + block_boxes[threadIdx.x * 5 + 3] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 3]; + block_boxes[threadIdx.x * 5 + 4] = + boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 4]; + } + __syncthreads(); + + if (threadIdx.x < row_size) { + const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x; + const float *cur_box = boxes + cur_box_idx * 5; + + int i = 0; + unsigned long long t = 0; + int start = 0; + if (row_start == col_start) { + start = threadIdx.x + 1; + } + for (i = start; i < col_size; i++) { + if (iou_normal(cur_box, block_boxes + i * 5) > nms_overlap_thresh) { + t |= 1ULL << i; + } + } + const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS); + mask[cur_box_idx * col_blocks + col_start] = t; + } +} + +#endif // IOU3D_CUDA_KERNEL_CUH diff --git a/mmcv/ops/csrc/common/pytorch_cpp_helper.hpp b/mmcv/ops/csrc/common/pytorch_cpp_helper.hpp index b812e62713..c7f9f35b7b 100644 --- a/mmcv/ops/csrc/common/pytorch_cpp_helper.hpp +++ b/mmcv/ops/csrc/common/pytorch_cpp_helper.hpp @@ -6,6 +6,8 @@ using namespace at; +#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) + #define CHECK_CUDA(x) \ TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CPU(x) \ diff --git a/mmcv/ops/csrc/pytorch/cuda/iou3d_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/iou3d_cuda.cu new file mode 100644 index 0000000000..0643c16044 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/cuda/iou3d_cuda.cu @@ -0,0 +1,86 @@ +// Modified from +// https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/iou3d_nms/src/iou3d_nms_kernel.cu + +/* +3D IoU Calculation and Rotated NMS(modified from 2D NMS written by others) +Written by Shaoshuai Shi +All Rights Reserved 2019-2020. 
+*/
+
+#include <stdio.h>
+
+#include "iou3d_cuda_kernel.cuh"
+#include "pytorch_cuda_helper.hpp"
+
+void IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(const int num_a,
+                                                   const Tensor boxes_a,
+                                                   const int num_b,
+                                                   const Tensor boxes_b,
+                                                   Tensor ans_overlap) {
+  at::cuda::CUDAGuard device_guard(boxes_a.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // blockIdx.x(col), blockIdx.y(row)
+  dim3 blocks(DIVUP(num_b, THREADS_PER_BLOCK_IOU3D),
+              DIVUP(num_a, THREADS_PER_BLOCK_IOU3D));
+  dim3 threads(THREADS_PER_BLOCK_IOU3D, THREADS_PER_BLOCK_IOU3D);
+
+  iou3d_boxes_overlap_bev_forward_cuda_kernel<<<blocks, threads, 0, stream>>>(
+      num_a, boxes_a.data_ptr<float>(), num_b, boxes_b.data_ptr<float>(),
+      ans_overlap.data_ptr<float>());
+
+  AT_CUDA_CHECK(cudaGetLastError());
+}
+
+void IoU3DBoxesIoUBevForwardCUDAKernelLauncher(const int num_a,
+                                               const Tensor boxes_a,
+                                               const int num_b,
+                                               const Tensor boxes_b,
+                                               Tensor ans_iou) {
+  at::cuda::CUDAGuard device_guard(boxes_a.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // blockIdx.x(col), blockIdx.y(row)
+  dim3 blocks(DIVUP(num_b, THREADS_PER_BLOCK_IOU3D),
+              DIVUP(num_a, THREADS_PER_BLOCK_IOU3D));
+  dim3 threads(THREADS_PER_BLOCK_IOU3D, THREADS_PER_BLOCK_IOU3D);
+
+  iou3d_boxes_iou_bev_forward_cuda_kernel<<<blocks, threads, 0, stream>>>(
+      num_a, boxes_a.data_ptr<float>(), num_b, boxes_b.data_ptr<float>(),
+      ans_iou.data_ptr<float>());
+
+  AT_CUDA_CHECK(cudaGetLastError());
+}
+
+void IoU3DNMSForwardCUDAKernelLauncher(const Tensor boxes,
+                                       unsigned long long *mask, int boxes_num,
+                                       float nms_overlap_thresh) {
+  at::cuda::CUDAGuard device_guard(boxes.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  dim3 blocks(DIVUP(boxes_num, THREADS_PER_BLOCK_NMS),
+              DIVUP(boxes_num, THREADS_PER_BLOCK_NMS));
+  dim3 threads(THREADS_PER_BLOCK_NMS);
+
+  nms_forward_cuda_kernel<<<blocks, threads, 0, stream>>>(
+      boxes_num, nms_overlap_thresh, boxes.data_ptr<float>(), mask);
+
+  AT_CUDA_CHECK(cudaGetLastError());
+}
+
+void IoU3DNMSNormalForwardCUDAKernelLauncher(const Tensor boxes,
+                                             unsigned long long *mask,
+                                             int boxes_num,
+                                             float nms_overlap_thresh) {
+  at::cuda::CUDAGuard device_guard(boxes.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  dim3 blocks(DIVUP(boxes_num, THREADS_PER_BLOCK_NMS),
+              DIVUP(boxes_num, THREADS_PER_BLOCK_NMS));
+  dim3 threads(THREADS_PER_BLOCK_NMS);
+
+  nms_normal_forward_cuda_kernel<<<blocks, threads, 0, stream>>>(
+      boxes_num, nms_overlap_thresh, boxes.data_ptr<float>(), mask);
+
+  AT_CUDA_CHECK(cudaGetLastError());
+}
diff --git a/mmcv/ops/csrc/pytorch/iou3d.cpp b/mmcv/ops/csrc/pytorch/iou3d.cpp
new file mode 100644
index 0000000000..eecfdf224a
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/iou3d.cpp
@@ -0,0 +1,244 @@
+// Modified from
+// https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/iou3d_nms/src/iou3d_nms.cpp
+
+/*
+3D IoU Calculation and Rotated NMS(modified from 2D NMS written by others)
+Written by Shaoshuai Shi
+All Rights Reserved 2019-2020.
+*/
+
+#include "pytorch_cpp_helper.hpp"
+
+const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8;
+
+#ifdef MMCV_WITH_CUDA
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+
+#define CHECK_ERROR(state) \
+  { gpuAssert((state), __FILE__, __LINE__); }
+inline void gpuAssert(cudaError_t code, const char *file, int line,
+                      bool abort = true) {
+  if (code != cudaSuccess) {
+    fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file,
+            line);
+    if (abort) exit(code);
+  }
+}
+
+void IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(const int num_a,
+                                                   const Tensor boxes_a,
+                                                   const int num_b,
+                                                   const Tensor boxes_b,
+                                                   Tensor ans_overlap);
+void iou3d_boxes_overlap_bev_forward_cuda(const int num_a, const Tensor boxes_a,
+                                          const int num_b, const Tensor boxes_b,
+                                          Tensor ans_overlap) {
+  IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(num_a, boxes_a, num_b, boxes_b,
+                                                ans_overlap);
+};
+
+void IoU3DBoxesIoUBevForwardCUDAKernelLauncher(const int num_a,
+                                               const Tensor boxes_a,
+                                               const int num_b,
+                                               const Tensor boxes_b,
+                                               Tensor ans_iou);
+void iou3d_boxes_iou_bev_forward_cuda(const int num_a, const Tensor boxes_a,
+                                      const int num_b, const Tensor boxes_b,
+                                      Tensor ans_iou) {
+  IoU3DBoxesIoUBevForwardCUDAKernelLauncher(num_a, boxes_a, num_b, boxes_b,
+                                            ans_iou);
+};
+
+void IoU3DNMSForwardCUDAKernelLauncher(const Tensor boxes,
+                                       unsigned long long *mask, int boxes_num,
+                                       float nms_overlap_thresh);
+
+void iou3d_nms_forward_cuda(const Tensor boxes, unsigned long long *mask,
+                            int boxes_num, float nms_overlap_thresh) {
+  IoU3DNMSForwardCUDAKernelLauncher(boxes, mask, boxes_num, nms_overlap_thresh);
+};
+
+void IoU3DNMSNormalForwardCUDAKernelLauncher(const Tensor boxes,
+                                             unsigned long long *mask,
+                                             int boxes_num,
+                                             float nms_overlap_thresh);
+
+void iou3d_nms_normal_forward_cuda(const Tensor boxes, unsigned long long *mask,
+                                   int boxes_num, float nms_overlap_thresh) {
+  IoU3DNMSNormalForwardCUDAKernelLauncher(boxes, mask, boxes_num,
+                                          nms_overlap_thresh);
+};
+#endif
+
+void iou3d_boxes_overlap_bev_forward(Tensor boxes_a, Tensor boxes_b,
+                                     Tensor ans_overlap) {
+  // params boxes_a: (N, 5) [x1, y1, x2, y2, ry]
+  // params boxes_b: (M, 5)
+  // params ans_overlap: (N, M)
+
+  if (boxes_a.device().is_cuda()) {
+#ifdef MMCV_WITH_CUDA
+    CHECK_CUDA_INPUT(boxes_a);
+    CHECK_CUDA_INPUT(boxes_b);
+    CHECK_CUDA_INPUT(ans_overlap);
+
+    int num_a = boxes_a.size(0);
+    int num_b = boxes_b.size(0);
+
+    iou3d_boxes_overlap_bev_forward_cuda(num_a, boxes_a, num_b, boxes_b,
+                                         ans_overlap);
+#else
+    AT_ERROR("iou3d_boxes_overlap_bev is not compiled with GPU support");
+#endif
+  } else {
+    AT_ERROR("iou3d_boxes_overlap_bev is not implemented on CPU");
+  }
+}
+
+void iou3d_boxes_iou_bev_forward(Tensor boxes_a, Tensor boxes_b,
+                                 Tensor ans_iou) {
+  // params boxes_a: (N, 5) [x1, y1, x2, y2, ry]
+  // params boxes_b: (M, 5)
+  // params ans_iou: (N, M)
+
+  if (boxes_a.device().is_cuda()) {
+#ifdef MMCV_WITH_CUDA
+    CHECK_CUDA_INPUT(boxes_a);
+    CHECK_CUDA_INPUT(boxes_b);
+    CHECK_CUDA_INPUT(ans_iou);
+
+    int num_a = boxes_a.size(0);
+    int num_b = boxes_b.size(0);
+
+    iou3d_boxes_iou_bev_forward_cuda(num_a, boxes_a, num_b, boxes_b, ans_iou);
+#else
+    AT_ERROR("iou3d_boxes_iou_bev is not compiled with GPU support");
+#endif
+  } else {
+    AT_ERROR("iou3d_boxes_iou_bev is not implemented on CPU");
+  }
+}
+
+int iou3d_nms_forward(Tensor boxes, Tensor keep, float nms_overlap_thresh) {
+  // params boxes: (N, 5) [x1, y1, x2, y2, ry]
+  // params keep: (N)
+
+  if (boxes.device().is_cuda()) {
+#ifdef MMCV_WITH_CUDA
+    CHECK_CUDA_INPUT(boxes);
+    CHECK_CONTIGUOUS(keep);
+
+    int boxes_num = boxes.size(0);
+    int64_t *keep_data = keep.data_ptr<int64_t>();
+
+    const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
+
+    unsigned long long *mask_data = NULL;
+    CHECK_ERROR(
+        cudaMalloc((void **)&mask_data,
+                   boxes_num * col_blocks * sizeof(unsigned long long)));
+    iou3d_nms_forward_cuda(boxes, mask_data, boxes_num, nms_overlap_thresh);
+
+    // unsigned long long mask_cpu[boxes_num * col_blocks];
+    // unsigned long long *mask_cpu = new unsigned long long [boxes_num *
+    // col_blocks];
+    std::vector<unsigned long long> mask_cpu(boxes_num * col_blocks);
+
+    // printf("boxes_num=%d, col_blocks=%d\n", boxes_num, col_blocks);
+    CHECK_ERROR(cudaMemcpy(&mask_cpu[0], mask_data,
+                           boxes_num * col_blocks * sizeof(unsigned long long),
+                           cudaMemcpyDeviceToHost));
+
+    cudaFree(mask_data);
+
+    unsigned long long *remv_cpu = new unsigned long long[col_blocks]();
+
+    int num_to_keep = 0;
+
+    for (int i = 0; i < boxes_num; i++) {
+      int nblock = i / THREADS_PER_BLOCK_NMS;
+      int inblock = i % THREADS_PER_BLOCK_NMS;
+
+      if (!(remv_cpu[nblock] & (1ULL << inblock))) {
+        keep_data[num_to_keep++] = i;
+        unsigned long long *p = &mask_cpu[0] + i * col_blocks;
+        for (int j = nblock; j < col_blocks; j++) {
+          remv_cpu[j] |= p[j];
+        }
+      }
+    }
+    delete[] remv_cpu;
+    if (cudaSuccess != cudaGetLastError()) printf("Error!\n");
+
+    return num_to_keep;
+
+#else
+    AT_ERROR("iou3d_nms is not compiled with GPU support");
+#endif
+  } else {
+    AT_ERROR("iou3d_nms is not implemented on CPU");
+  }
+}
+
+int iou3d_nms_normal_forward(Tensor boxes, Tensor keep,
+                             float nms_overlap_thresh) {
+  // params boxes: (N, 5) [x1, y1, x2, y2, ry]
+  // params keep: (N)
+
+  if (boxes.device().is_cuda()) {
+#ifdef MMCV_WITH_CUDA
+    CHECK_CUDA_INPUT(boxes);
+    CHECK_CONTIGUOUS(keep);
+
+    int boxes_num = boxes.size(0);
+    int64_t *keep_data = keep.data_ptr<int64_t>();
+
+    const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
+
+    unsigned long long *mask_data = NULL;
+    CHECK_ERROR(
+        cudaMalloc((void **)&mask_data,
+                   boxes_num * col_blocks * sizeof(unsigned long long)));
+    iou3d_nms_normal_forward_cuda(boxes, mask_data, boxes_num,
+                                  nms_overlap_thresh);
+
+    // unsigned long long mask_cpu[boxes_num * col_blocks];
+    // unsigned long long *mask_cpu = new unsigned long long [boxes_num *
+    // col_blocks];
+    std::vector<unsigned long long> mask_cpu(boxes_num * col_blocks);
+
+    CHECK_ERROR(cudaMemcpy(&mask_cpu[0], mask_data,
+                           boxes_num * col_blocks * sizeof(unsigned long long),
+                           cudaMemcpyDeviceToHost));
+
+    cudaFree(mask_data);
+
+    unsigned long long *remv_cpu = new unsigned long long[col_blocks]();
+
+    int num_to_keep = 0;
+
+    for (int i = 0; i < boxes_num; i++) {
+      int nblock = i / THREADS_PER_BLOCK_NMS;
+      int inblock = i % THREADS_PER_BLOCK_NMS;
+
+      if (!(remv_cpu[nblock] & (1ULL << inblock))) {
+        keep_data[num_to_keep++] = i;
+        unsigned long long *p = &mask_cpu[0] + i * col_blocks;
+        for (int j = nblock; j < col_blocks; j++) {
+          remv_cpu[j] |= p[j];
+        }
+      }
+    }
+    delete[] remv_cpu;
+    if (cudaSuccess != cudaGetLastError()) printf("Error!\n");
+
+    return num_to_keep;
+
+#else
+    AT_ERROR("iou3d_nms_normal is not compiled with GPU support");
+#endif
+  } else {
+    AT_ERROR("iou3d_nms_normal is not implemented on CPU");
+  }
+}
diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp
index c5e3d1b697..7b39a5e443 100644
--- a/mmcv/ops/csrc/pytorch/pybind.cpp
+++ b/mmcv/ops/csrc/pytorch/pybind.cpp
@@ -105,6 +105,17 @@ void three_nn_forward(int b, int n, int m, Tensor unknown_tensor,
 void bbox_overlaps(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
                    const int mode, const bool aligned, const int offset);
+void iou3d_boxes_overlap_bev_forward(Tensor boxes_a, Tensor boxes_b,
+                                     Tensor ans_overlap);
+
+void iou3d_boxes_iou_bev_forward(Tensor boxes_a, Tensor boxes_b,
+                                 Tensor ans_iou);
+
+int iou3d_nms_forward(Tensor boxes, Tensor keep, float nms_overlap_thresh);
+
+int iou3d_nms_normal_forward(Tensor boxes, Tensor keep,
+                             float nms_overlap_thresh);
+
 void knn_forward(int b, int n, int m, int nsample, Tensor xyz_tensor,
                  Tensor new_xyz_tensor, Tensor idx_tensor, Tensor dist2_tensor);
@@ -442,6 +453,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("bbox_overlaps", &bbox_overlaps, "bbox_overlaps", py::arg("bboxes1"),
         py::arg("bboxes2"), py::arg("ious"), py::arg("mode"),
         py::arg("aligned"), py::arg("offset"));
+  m.def("iou3d_boxes_overlap_bev_forward", &iou3d_boxes_overlap_bev_forward,
+        "iou3d_boxes_overlap_bev_forward", py::arg("boxes_a"),
+        py::arg("boxes_b"), py::arg("ans_overlap"));
+  m.def("iou3d_boxes_iou_bev_forward", &iou3d_boxes_iou_bev_forward,
+        "iou3d_boxes_iou_bev_forward", py::arg("boxes_a"), py::arg("boxes_b"),
+        py::arg("ans_iou"));
+  m.def("iou3d_nms_forward", &iou3d_nms_forward, "iou3d_nms_forward",
+        py::arg("boxes"), py::arg("keep"), py::arg("nms_overlap_thresh"));
+  m.def("iou3d_nms_normal_forward", &iou3d_nms_normal_forward,
+        "iou3d_nms_normal_forward", py::arg("boxes"), py::arg("keep"),
+        py::arg("nms_overlap_thresh"));
   m.def("furthest_point_sampling_forward", &furthest_point_sampling_forward,
         "furthest_point_sampling_forward", py::arg("b"), py::arg("n"),
         py::arg("m"), py::arg("points_tensor"), py::arg("temp_tensor"),
diff --git a/mmcv/ops/iou3d.py b/mmcv/ops/iou3d.py
new file mode 100644
index 0000000000..f22a9c82c0
--- /dev/null
+++ b/mmcv/ops/iou3d.py
@@ -0,0 +1,83 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+    'iou3d_boxes_iou_bev_forward', 'iou3d_nms_forward',
+    'iou3d_nms_normal_forward'
+])
+
+
+def boxes_iou_bev(boxes_a, boxes_b):
+    """Calculate boxes IoU in the Bird's Eye View.
+
+    Args:
+        boxes_a (torch.Tensor): Input boxes a with shape (M, 5).
+        boxes_b (torch.Tensor): Input boxes b with shape (N, 5).
+
+    Returns:
+        torch.Tensor: IoU result with shape (M, N).
+    """
+    ans_iou = boxes_a.new_zeros(
+        torch.Size((boxes_a.shape[0], boxes_b.shape[0])))
+
+    ext_module.iou3d_boxes_iou_bev_forward(boxes_a.contiguous(),
+                                           boxes_b.contiguous(), ans_iou)
+
+    return ans_iou
+
+
+def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None):
+    """NMS function GPU implementation (for BEV boxes). The overlap of two
+    boxes for IoU calculation is defined as the exact overlapping area of the
+    two boxes. In this function, one can also set ``pre_max_size`` and
+    ``post_max_size``.
+
+    Args:
+        boxes (torch.Tensor): Input boxes with the shape of [N, 5]
+            ([x1, y1, x2, y2, ry]).
+        scores (torch.Tensor): Scores of boxes with the shape of [N].
+        thresh (float): Overlap threshold of NMS.
+        pre_max_size (int, optional): Max number of boxes kept before NMS.
+            Default: None.
+        post_max_size (int, optional): Max number of boxes kept after NMS.
+            Default: None.
+
+    Returns:
+        torch.Tensor: Indices of the boxes kept after NMS.
+    """
+    order = scores.sort(0, descending=True)[1]
+
+    if pre_max_size is not None:
+        order = order[:pre_max_size]
+    boxes = boxes[order].contiguous()
+
+    keep = torch.zeros(boxes.size(0), dtype=torch.long)
+    num_out = ext_module.iou3d_nms_forward(boxes, keep, thresh)
+    keep = order[keep[:num_out].cuda(boxes.device)].contiguous()
+    if post_max_size is not None:
+        keep = keep[:post_max_size]
+    return keep
+
+
+def nms_normal_bev(boxes, scores, thresh):
+    """Normal NMS function GPU implementation (for BEV boxes). The overlap of
+    two boxes for IoU calculation is defined as the exact overlapping area of
+    the two boxes WITH their yaw angle set to 0.
+
+    Args:
+        boxes (torch.Tensor): Input boxes with shape (N, 5).
+        scores (torch.Tensor): Scores of predicted boxes with shape (N).
+        thresh (float): Overlap threshold of NMS.
+
+    Returns:
+        torch.Tensor: Remaining indices with scores in descending order.
+    """
+    order = scores.sort(0, descending=True)[1]
+
+    boxes = boxes[order].contiguous()
+
+    keep = torch.zeros(boxes.size(0), dtype=torch.long)
+    num_out = ext_module.iou3d_nms_normal_forward(boxes, keep, thresh)
+    return order[keep[:num_out].cuda(boxes.device)].contiguous()
diff --git a/tests/test_ops/test_iou3d.py b/tests/test_ops/test_iou3d.py
new file mode 100644
index 0000000000..9747e131f0
--- /dev/null
+++ b/tests/test_ops/test_iou3d.py
@@ -0,0 +1,58 @@
+import numpy as np
+import pytest
+import torch
+
+from mmcv.ops import boxes_iou_bev, nms_bev, nms_normal_bev
+
+
+@pytest.mark.skipif(
+    not torch.cuda.is_available(), reason='requires CUDA support')
+def test_boxes_iou_bev():
+    np_boxes1 = np.asarray(
+        [[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],
+         [7.0, 7.0, 8.0, 8.0, 0.4]],
+        dtype=np.float32)
+    np_boxes2 = np.asarray(
+        [[0.0, 2.0, 2.0, 5.0, 0.3], [2.0, 1.0, 3.0, 3.0, 0.5],
+         [5.0, 5.0, 6.0, 7.0, 0.4]],
+        dtype=np.float32)
+    np_expect_ious = np.asarray(
+        [[0.2621, 0.2948, 0.0000], [0.0549, 0.1587, 0.0000],
+         [0.0000, 0.0000, 0.0000]],
+        dtype=np.float32)
+
+    boxes1 = torch.from_numpy(np_boxes1).cuda()
+    boxes2 = torch.from_numpy(np_boxes2).cuda()
+
+    ious = boxes_iou_bev(boxes1, boxes2)
+    assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)
+
+
+@pytest.mark.skipif(
+    not torch.cuda.is_available(), reason='requires CUDA support')
+def test_nms_gpu():
+    np_boxes = np.array([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0],
+                         [3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]],
+                        dtype=np.float32)
+    np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32)
+    np_inds = np.array([1, 0, 3])
+    boxes = torch.from_numpy(np_boxes)
+    scores = torch.from_numpy(np_scores)
+    inds = nms_bev(boxes.cuda(), scores.cuda(), thresh=0.3)
+
+    assert np.allclose(inds.cpu().numpy(), np_inds)  # test gpu
+
+
+@pytest.mark.skipif(
+    not torch.cuda.is_available(), reason='requires CUDA support')
+def test_nms_normal_gpu():
+    np_boxes = np.array([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0],
+                         [3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]],
+                        dtype=np.float32)
+    np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32)
+    np_inds = np.array([1, 2, 0, 3])
+    boxes = torch.from_numpy(np_boxes)
+    scores = torch.from_numpy(np_scores)
+    inds = nms_normal_bev(boxes.cuda(), scores.cuda(), thresh=0.3)
+
+    assert np.allclose(inds.cpu().numpy(), np_inds)  # test gpu
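A minimal usage sketch of the Python entry points introduced by this patch (boxes_iou_bev, nms_bev, nms_normal_bev), assuming a CUDA-enabled build of mmcv; the box values mirror the unit tests above, the [x1, y1, x2, y2, ry] layout follows the docstrings, and the snippet is illustrative rather than part of the diff:

# Pairwise rotated IoU in bird's eye view, then rotated and axis-aligned NMS.
import torch

from mmcv.ops import boxes_iou_bev, nms_bev, nms_normal_bev

# BEV boxes in the [x1, y1, x2, y2, ry] layout used throughout this patch.
boxes = torch.tensor([[1.0, 1.0, 3.0, 4.0, 0.5],
                      [2.0, 2.0, 3.0, 4.0, 0.6],
                      [7.0, 7.0, 8.0, 8.0, 0.4]], device='cuda')
scores = torch.tensor([0.9, 0.6, 0.3], device='cuda')

# (3, 3) IoU matrix computed by the rotated BEV overlap kernel.
ious = boxes_iou_bev(boxes, boxes)

# Rotated NMS with optional pre-/post-NMS truncation.
keep = nms_bev(boxes, scores, thresh=0.3, pre_max_size=100, post_max_size=50)

# Same interface, but overlaps are computed with the yaw angle treated as 0.
keep_normal = nms_normal_bev(boxes, scores, thresh=0.3)

Because nms_bev sorts by score and maps the kept positions back through that ordering, the returned indices point into the original boxes and scores tensors, highest score first.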