Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add roiaware pool3d ops from mmdet3d #1382

Merged
merged 6 commits into from
Oct 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
- RoIPointPool3d
- RoIPool
- RoIAlign
- RoIAwarePool3d
- SimpleRoIAlign
- SigmoidFocalLoss
- SoftmaxFocalLoss
Expand Down
1 change: 1 addition & 0 deletions docs_zh_CN/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
- RoIPointPool3d
- RoIPool
- RoIAlign
- RoIAwarePool3d
- SimpleRoIAlign
- SigmoidFocalLoss
- SoftmaxFocalLoss
Expand Down
101 changes: 81 additions & 20 deletions mmcv/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,14 @@
from .pixel_group import pixel_group
from .point_sample import (SimpleRoIAlign, point_sample,
rel_roi_point_to_rel_img_point)
from .points_in_boxes import (points_in_boxes_all, points_in_boxes_cpu,
points_in_boxes_part)
from .points_sampler import PointsSampler
from .psa_mask import PSAMask
from .roi_align import RoIAlign, roi_align
from .roi_align_rotated import RoIAlignRotated, roi_align_rotated
from .roi_pool import RoIPool, roi_pool
from .roiaware_pool3d import RoIAwarePool3d
from .roipoint_pool3d import RoIPointPool3d
from .saconv import SAConv2d
from .scatter_points import DynamicScatter, dynamic_scatter
Expand All @@ -50,24 +53,82 @@
from .voxelize import Voxelization, voxelization

__all__ = [
'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe',
'carafe_naive', 'CornerPool', 'DeformConv2d', 'DeformConv2dPack',
'deform_conv2d', 'DeformRoIPool', 'DeformRoIPoolPack',
'ModulatedDeformRoIPoolPack', 'deform_roi_pool', 'SigmoidFocalLoss',
'SoftmaxFocalLoss', 'sigmoid_focal_loss', 'softmax_focal_loss',
'get_compiler_version', 'get_compiling_cuda_version',
'get_onnxruntime_op_path', 'MaskedConv2d', 'masked_conv2d',
'ModulatedDeformConv2d', 'ModulatedDeformConv2dPack',
'modulated_deform_conv2d', 'batched_nms', 'nms', 'soft_nms', 'nms_match',
'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d',
'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
'SAConv2d', 'TINShift', 'tin_shift', 'assign_score_withk',
'box_iou_rotated', 'RoIPointPool3d', 'nms_rotated', 'knn', 'ball_query',
'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'contour_expand',
'three_nn', 'three_interpolate', 'MultiScaleDeformableAttention',
'Voxelization', 'voxelization', 'dynamic_scatter', 'DynamicScatter',
'BorderAlign', 'border_align', 'gather_points', 'furthest_point_sample',
'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation'
'bbox_overlaps',
'CARAFE',
'CARAFENaive',
'CARAFEPack',
'carafe',
'carafe_naive',
'CornerPool',
'DeformConv2d',
'DeformConv2dPack',
'deform_conv2d',
'DeformRoIPool',
'DeformRoIPoolPack',
'ModulatedDeformRoIPoolPack',
'deform_roi_pool',
'SigmoidFocalLoss',
'SoftmaxFocalLoss',
'sigmoid_focal_loss',
'softmax_focal_loss',
'get_compiler_version',
'get_compiling_cuda_version',
'get_onnxruntime_op_path',
'MaskedConv2d',
'masked_conv2d',
'ModulatedDeformConv2d',
'ModulatedDeformConv2dPack',
'modulated_deform_conv2d',
'batched_nms',
'nms',
'soft_nms',
'nms_match',
'RoIAlign',
'roi_align',
'RoIPool',
'roi_pool',
'SyncBatchNorm',
'Conv2d',
'ConvTranspose2d',
'Linear',
'MaxPool2d',
'CrissCrossAttention',
'PSAMask',
'point_sample',
'rel_roi_point_to_rel_img_point',
'SimpleRoIAlign',
'SAConv2d',
'TINShift',
'tin_shift',
'assign_score_withk',
'box_iou_rotated',
'RoIPointPool3d',
'nms_rotated',
'knn',
'ball_query',
'upfirdn2d',
'FusedBiasLeakyReLU',
'fused_bias_leakyrelu',
'RoIAlignRotated',
'roi_align_rotated',
'pixel_group',
'contour_expand',
'three_nn',
'three_interpolate',
'MultiScaleDeformableAttention',
'Voxelization',
'voxelization',
'dynamic_scatter',
'DynamicScatter',
'BorderAlign',
'border_align',
'gather_points',
'furthest_point_sample',
'furthest_point_sample_with_dist',
'PointsSampler',
'Correlation',
'RoIAwarePool3d',
'points_in_boxes_part',
'points_in_boxes_cpu',
'points_in_boxes_all',
]
93 changes: 93 additions & 0 deletions mmcv/ops/csrc/common/cuda/points_in_boxes_cuda_kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef POINT_IN_BOXES_CUDA_KERNEL_CUH
#define POINT_IN_BOXES_CUDA_KERNEL_CUH

#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif

template <typename T>
__device__ inline void lidar_to_local_coords(T shift_x, T shift_y, T rz,
T &local_x, T &local_y) {
T cosa = cos(-rz), sina = sin(-rz);
local_x = shift_x * cosa + shift_y * (-sina);
local_y = shift_x * sina + shift_y * cosa;
}

template <typename T>
__device__ inline int check_pt_in_box3d(const T *pt, const T *box3d, T &local_x,
T &local_y) {
// param pt: (x, y, z)
// param box3d: (cx, cy, cz, x_size, y_size, z_size, rz) in LiDAR coordinate,
// cz in the bottom center
T x = pt[0], y = pt[1], z = pt[2];
T cx = box3d[0], cy = box3d[1], cz = box3d[2];
T x_size = box3d[3], y_size = box3d[4], z_size = box3d[5], rz = box3d[6];
cz += z_size /
2.0; // shift to the center since cz in box3d is the bottom center

if (fabsf(z - cz) > z_size / 2.0) return 0;
lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y);
float in_flag = (local_x > -x_size / 2.0) & (local_x < x_size / 2.0) &
(local_y > -y_size / 2.0) & (local_y < y_size / 2.0);
return in_flag;
}

template <typename T>
__global__ void points_in_boxes_part_forward_cuda_kernel(
int batch_size, int boxes_num, int pts_num, const T *boxes, const T *pts,
int *box_idx_of_points) {
// params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
// coordinate, z is the bottom center, each box DO NOT overlaps params pts:
// (B, npoints, 3) [x, y, z] in LiDAR coordinate params boxes_idx_of_points:
// (B, npoints), default -1

int bs_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= batch_size || pt_idx >= pts_num) return;

boxes += bs_idx * boxes_num * 7;
pts += bs_idx * pts_num * 3 + pt_idx * 3;
box_idx_of_points += bs_idx * pts_num + pt_idx;

T local_x = 0, local_y = 0;
int cur_in_flag = 0;
for (int k = 0; k < boxes_num; k++) {
cur_in_flag = check_pt_in_box3d(pts, boxes + k * 7, local_x, local_y);
if (cur_in_flag) {
box_idx_of_points[0] = k;
break;
}
}
}

template <typename T>
__global__ void points_in_boxes_all_forward_cuda_kernel(
int batch_size, int boxes_num, int pts_num, const T *boxes, const T *pts,
int *box_idx_of_points) {
// params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
// coordinate, z is the bottom center, each box DO NOT overlaps params pts:
// (B, npoints, 3) [x, y, z] in LiDAR coordinate params boxes_idx_of_points:
// (B, npoints), default -1

int bs_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= batch_size || pt_idx >= pts_num) return;

boxes += bs_idx * boxes_num * 7;
pts += bs_idx * pts_num * 3 + pt_idx * 3;
box_idx_of_points += bs_idx * pts_num * boxes_num + pt_idx * boxes_num;

T local_x = 0, local_y = 0;
for (int k = 0; k < boxes_num; k++) {
const int cur_in_flag =
check_pt_in_box3d(pts, boxes + k * 7, local_x, local_y);
if (cur_in_flag) {
box_idx_of_points[k] = 1;
}
}
}

#endif // POINT_IN_BOXES_CUDA_KERNEL_CUH
Loading