From 397afec8609d736cfc0be5b40a2022d399d43c2b Mon Sep 17 00:00:00 2001 From: "hudingchang.vendor" Date: Mon, 27 Sep 2021 01:03:53 +0800 Subject: [PATCH 1/3] add ops (roiaware pool3d) in mmdet3d --- docs/understand_mmcv/ops.md | 1 + mmcv/ops/__init__.py | 7 +- .../cuda/points_in_boxes_cuda_kernel.cuh | 96 ++++++ .../cuda/roiaware_pool3d_cuda_kernel.cuh | 294 ++++++++++++++++++ .../csrc/pytorch/cuda/points_in_boxes_cuda.cu | 55 ++++ .../csrc/pytorch/cuda/roiaware_pool3d_cuda.cu | 120 +++++++ mmcv/ops/csrc/pytorch/points_in_boxes.cpp | 92 ++++++ mmcv/ops/csrc/pytorch/points_in_boxes_cpu.cpp | 53 ++++ mmcv/ops/csrc/pytorch/pybind.cpp | 33 ++ mmcv/ops/csrc/pytorch/roiaware_pool3d.cpp | 115 +++++++ mmcv/ops/points_in_boxes.py | 133 ++++++++ mmcv/ops/roiaware_pool3d.py | 115 +++++++ tests/test_ops/test_roiaware_pool3d.py | 151 +++++++++ 13 files changed, 1264 insertions(+), 1 deletion(-) create mode 100644 mmcv/ops/csrc/common/cuda/points_in_boxes_cuda_kernel.cuh create mode 100644 mmcv/ops/csrc/common/cuda/roiaware_pool3d_cuda_kernel.cuh create mode 100644 mmcv/ops/csrc/pytorch/cuda/points_in_boxes_cuda.cu create mode 100644 mmcv/ops/csrc/pytorch/cuda/roiaware_pool3d_cuda.cu create mode 100644 mmcv/ops/csrc/pytorch/points_in_boxes.cpp create mode 100644 mmcv/ops/csrc/pytorch/points_in_boxes_cpu.cpp create mode 100644 mmcv/ops/csrc/pytorch/roiaware_pool3d.cpp create mode 100644 mmcv/ops/points_in_boxes.py create mode 100644 mmcv/ops/roiaware_pool3d.py create mode 100644 tests/test_ops/test_roiaware_pool3d.py diff --git a/docs/understand_mmcv/ops.md b/docs/understand_mmcv/ops.md index 8460682c1f..53c341f673 100644 --- a/docs/understand_mmcv/ops.md +++ b/docs/understand_mmcv/ops.md @@ -15,6 +15,7 @@ We implement common CUDA ops used in detection, segmentation, etc. 
- PSAMask - RoIPool - RoIAlign +- RoIAwarePool3d - SimpleRoIAlign - SigmoidFocalLoss - SoftmaxFocalLoss diff --git a/mmcv/ops/__init__.py b/mmcv/ops/__init__.py index 359a13c06a..1e90dda0c2 100644 --- a/mmcv/ops/__init__.py +++ b/mmcv/ops/__init__.py @@ -27,10 +27,13 @@ from .pixel_group import pixel_group from .point_sample import (SimpleRoIAlign, point_sample, rel_roi_point_to_rel_img_point) +from .points_in_boxes import (points_in_boxes_all, points_in_boxes_cpu, + points_in_boxes_part) from .psa_mask import PSAMask from .roi_align import RoIAlign, roi_align from .roi_align_rotated import RoIAlignRotated, roi_align_rotated from .roi_pool import RoIPool, roi_pool +from .roiaware_pool3d import RoIAwarePool3d from .saconv import SAConv2d from .sync_bn import SyncBatchNorm from .tin_shift import TINShift, tin_shift @@ -52,5 +55,7 @@ 'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated', 'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu', 'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'contour_expand', - 'MultiScaleDeformableAttention', 'BorderAlign', 'border_align' + 'MultiScaleDeformableAttention', 'BorderAlign', 'RoIAwarePool3d', + 'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all', + 'border_align' ] diff --git a/mmcv/ops/csrc/common/cuda/points_in_boxes_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/points_in_boxes_cuda_kernel.cuh new file mode 100644 index 0000000000..bb310c712d --- /dev/null +++ b/mmcv/ops/csrc/common/cuda/points_in_boxes_cuda_kernel.cuh @@ -0,0 +1,96 @@ +// Copyright (c) OpenMMLab. 
All rights reserved +#ifndef POINT_IN_BOXES_CUDA_KERNEL_CUH +#define POINT_IN_BOXES_CUDA_KERNEL_CUH + +#ifdef MMCV_USE_PARROTS +#include "parrots_cuda_helper.hpp" +#else +#include "pytorch_cuda_helper.hpp" +#endif + +#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) + +template +__device__ inline void lidar_to_local_coords(T shift_x, T shift_y, T rz, + T &local_x, T &local_y) { + T cosa = cos(-rz), sina = sin(-rz); + local_x = shift_x * cosa + shift_y * (-sina); + local_y = shift_x * sina + shift_y * cosa; +} + +template +__device__ inline int check_pt_in_box3d(const T *pt, const T *box3d, T &local_x, + T &local_y) { + // param pt: (x, y, z) + // param box3d: (cx, cy, cz, x_size, y_size, z_size, rz) in LiDAR coordinate, + // cz in the bottom center + T x = pt[0], y = pt[1], z = pt[2]; + T cx = box3d[0], cy = box3d[1], cz = box3d[2]; + T x_size = box3d[3], y_size = box3d[4], z_size = box3d[5], rz = box3d[6]; + cz += z_size / + 2.0; // shift to the center since cz in box3d is the bottom center + + if (fabsf(z - cz) > z_size / 2.0) return 0; + lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y); + float in_flag = (local_x > -x_size / 2.0) & (local_x < x_size / 2.0) & + (local_y > -y_size / 2.0) & (local_y < y_size / 2.0); + return in_flag; +} + +template +__global__ void points_in_boxes_part_forward_cuda_kernel( + int batch_size, int boxes_num, int pts_num, const T *boxes, const T *pts, + int *box_idx_of_points) { + // params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR + // coordinate, z is the bottom center, each box DO NOT overlaps params pts: + // (B, npoints, 3) [x, y, z] in LiDAR coordinate params boxes_idx_of_points: + // (B, npoints), default -1 + + int bs_idx = blockIdx.y; + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (bs_idx >= batch_size || pt_idx >= pts_num) return; + + boxes += bs_idx * boxes_num * 7; + pts += bs_idx * pts_num * 3 + pt_idx * 3; + box_idx_of_points += bs_idx * pts_num + pt_idx; + + T local_x = 0, 
local_y = 0; + int cur_in_flag = 0; + for (int k = 0; k < boxes_num; k++) { + cur_in_flag = check_pt_in_box3d(pts, boxes + k * 7, local_x, local_y); + if (cur_in_flag) { + box_idx_of_points[0] = k; + break; + } + } +} + +template +__global__ void points_in_boxes_all_forward_cuda_kernel( + int batch_size, int boxes_num, int pts_num, const T *boxes, const T *pts, + int *box_idx_of_points) { + // params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR + // coordinate, z is the bottom center, each box DO NOT overlaps params pts: + // (B, npoints, 3) [x, y, z] in LiDAR coordinate params boxes_idx_of_points: + // (B, npoints), default -1 + + int bs_idx = blockIdx.y; + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (bs_idx >= batch_size || pt_idx >= pts_num) return; + + boxes += bs_idx * boxes_num * 7; + pts += bs_idx * pts_num * 3 + pt_idx * 3; + box_idx_of_points += bs_idx * pts_num * boxes_num + pt_idx * boxes_num; + + T local_x = 0, local_y = 0; + int cur_in_flag = 0; + for (int k = 0; k < boxes_num; k++) { + cur_in_flag = check_pt_in_box3d(pts, boxes + k * 7, local_x, local_y); + if (cur_in_flag) { + box_idx_of_points[k] = 1; + } + cur_in_flag = 0; + } +} + +#endif // POINT_IN_BOXES_CUDA_KERNEL_CUH diff --git a/mmcv/ops/csrc/common/cuda/roiaware_pool3d_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/roiaware_pool3d_cuda_kernel.cuh new file mode 100644 index 0000000000..6e6c265aa3 --- /dev/null +++ b/mmcv/ops/csrc/common/cuda/roiaware_pool3d_cuda_kernel.cuh @@ -0,0 +1,294 @@ +// Copyright (c) OpenMMLab. 
All rights reserved +#ifndef ROIAWARE_POOL3D_CUDA_KERNEL_CUH +#define ROIAWARE_POOL3D_CUDA_KERNEL_CUH + +#ifdef MMCV_USE_PARROTS +#include "parrots_cuda_helper.hpp" +#else +#include "pytorch_cuda_helper.hpp" +#endif + +#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) + +template +__device__ inline void lidar_to_local_coords(T shift_x, T shift_y, T rz, + T &local_x, T &local_y) { + T cosa = cos(-rz), sina = sin(-rz); + local_x = shift_x * cosa + shift_y * (-sina); + local_y = shift_x * sina + shift_y * cosa; +} + +template +__device__ inline int check_pt_in_box3d(const T *pt, const T *box3d, T &local_x, + T &local_y) { + // param pt: (x, y, z) + // param box3d: (cx, cy, cz, x_size, y_size, z_size, rz) in LiDAR coordinate, + // cz in the bottom center + T x = pt[0], y = pt[1], z = pt[2]; + T cx = box3d[0], cy = box3d[1], cz = box3d[2]; + T x_size = box3d[3], y_size = box3d[4], z_size = box3d[5], rz = box3d[6]; + cz += z_size / + 2.0; // shift to the center since cz in box3d is the bottom center + + if (fabsf(z - cz) > z_size / 2.0) return 0; + lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y); + float in_flag = (local_x > -x_size / 2.0) & (local_x < x_size / 2.0) & + (local_y > -y_size / 2.0) & (local_y < y_size / 2.0); + return in_flag; +} + +template +__global__ void generate_pts_mask_for_box3d(int boxes_num, int pts_num, + int out_x, int out_y, int out_z, + const T *rois, const T *pts, + int *pts_mask) { + // params rois: (N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR + // coordinate params pts: (npoints, 3) [x, y, z] params pts_mask: (N, + // npoints): -1 means point does not in this box, otherwise: encode (x_idxs, + // y_idxs, z_idxs) by binary bit + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + int box_idx = blockIdx.y; + if (pt_idx >= pts_num || box_idx >= boxes_num) return; + + pts += pt_idx * 3; + rois += box_idx * 7; + pts_mask += box_idx * pts_num + pt_idx; + + T local_x = 0, local_y = 0; + int cur_in_flag = check_pt_in_box3d(pts, 
rois, local_x, local_y); + + pts_mask[0] = -1; + if (cur_in_flag > 0) { + T local_z = pts[2] - rois[2]; + T x_size = rois[3], y_size = rois[4], z_size = rois[5]; + + T x_res = x_size / out_x; + T y_res = y_size / out_y; + T z_res = z_size / out_z; + + unsigned int x_idx = int((local_x + x_size / 2) / x_res); + unsigned int y_idx = int((local_y + y_size / 2) / y_res); + unsigned int z_idx = int(local_z / z_res); + + x_idx = min(max(x_idx, 0), out_x - 1); + y_idx = min(max(y_idx, 0), out_y - 1); + z_idx = min(max(z_idx, 0), out_z - 1); + + unsigned int idx_encoding = (x_idx << 16) + (y_idx << 8) + z_idx; +#ifdef DEBUG + printf( + "mask: pts_%d(%.3f, %.3f, %.3f), local(%.3f, %.3f, %.3f), idx(%d, %d, " + "%d), res(%.3f, %.3f, %.3f), idx_encoding=%x\n", + pt_idx, pts[0], pts[1], pts[2], local_x, local_y, local_z, x_idx, y_idx, + z_idx, x_res, y_res, z_res, idx_encoding); +#endif + + pts_mask[0] = idx_encoding; + } +} + +template +__global__ void collect_inside_pts_for_box3d(int boxes_num, int pts_num, + int max_pts_each_voxel, int out_x, + int out_y, int out_z, + const int *pts_mask, + T *pts_idx_of_voxels) { + // params pts_mask: (N, npoints) 0 or 1 + // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel) + + int box_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (box_idx >= boxes_num) return; + + int max_num_pts = max_pts_each_voxel - 1; // index 0 is the counter + pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel; + + for (int k = 0; k < pts_num; k++) { + if (pts_mask[box_idx * pts_num + k] != -1) { + unsigned int idx_encoding = pts_mask[box_idx * pts_num + k]; + unsigned int x_idx = (idx_encoding >> 16) & 0xFF; + unsigned int y_idx = (idx_encoding >> 8) & 0xFF; + unsigned int z_idx = idx_encoding & 0xFF; + unsigned int base_offset = x_idx * out_y * out_z * max_pts_each_voxel + + y_idx * out_z * max_pts_each_voxel + + z_idx * max_pts_each_voxel; + unsigned int cnt = pts_idx_of_voxels[base_offset]; + if (cnt < max_num_pts) 
{ + pts_idx_of_voxels[base_offset + cnt + 1] = k; + pts_idx_of_voxels[base_offset]++; + } +#ifdef DEBUG + printf("collect: pts_%d, idx(%d, %d, %d), idx_encoding=%x\n", k, x_idx, + y_idx, z_idx, idx_encoding); +#endif + } + } +} + +template +__global__ void roiaware_maxpool3d(int boxes_num, int pts_num, int channels, + int max_pts_each_voxel, int out_x, int out_y, + int out_z, const T *pts_feature, + const int *pts_idx_of_voxels, + T *pooled_features, int *argmax) { + // params pts_feature: (npoints, C) + // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel), + // index 0 is the counter params pooled_features: (N, out_x, out_y, out_z, C) + // params argmax: (N, out_x, out_y, out_z, C) + + int box_idx = blockIdx.z; + int channel_idx = blockIdx.y; + int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x; + + int x_idx = voxel_idx_flat / (out_y * out_z); + int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z; + int z_idx = voxel_idx_flat % out_z; + if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x || + y_idx >= out_y || z_idx >= out_z) + return; + +#ifdef DEBUG + printf("src pts_idx_of_voxels: (%p, ), argmax: %p\n", pts_idx_of_voxels, + argmax); +#endif + + int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx; + pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel + + offset_base * max_pts_each_voxel; + pooled_features += box_idx * out_x * out_y * out_z * channels + + offset_base * channels + channel_idx; + argmax += box_idx * out_x * out_y * out_z * channels + + offset_base * channels + channel_idx; + + int argmax_idx = -1; + float max_val = -1e50; + + int total_pts = pts_idx_of_voxels[0]; + + for (int k = 1; k <= total_pts; k++) { + if (pts_feature[pts_idx_of_voxels[k] * channels + channel_idx] > max_val) { + max_val = pts_feature[pts_idx_of_voxels[k] * channels + channel_idx]; + argmax_idx = pts_idx_of_voxels[k]; + } + } + + if (argmax_idx != -1) { + pooled_features[0] = max_val; + 
} + argmax[0] = argmax_idx; + +#ifdef DEBUG + printf( + "channel_%d idx(%d, %d, %d), argmax_idx=(%d, %.3f), total=%d, after " + "pts_idx: %p, argmax: (%p, %d)\n", + channel_idx, x_idx, y_idx, z_idx, argmax_idx, max_val, total_pts, + pts_idx_of_voxels, argmax, argmax_idx); +#endif +} + +template +__global__ void roiaware_avgpool3d(int boxes_num, int pts_num, int channels, + int max_pts_each_voxel, int out_x, int out_y, + int out_z, const T *pts_feature, + const int *pts_idx_of_voxels, + T *pooled_features) { + // params pts_feature: (npoints, C) + // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel), + // index 0 is the counter params pooled_features: (N, out_x, out_y, out_z, C) + // params argmax: (N, out_x, out_y, out_z, C) + + int box_idx = blockIdx.z; + int channel_idx = blockIdx.y; + int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x; + + int x_idx = voxel_idx_flat / (out_y * out_z); + int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z; + int z_idx = voxel_idx_flat % out_z; + if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x || + y_idx >= out_y || z_idx >= out_z) + return; + + int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx; + pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel + + offset_base * max_pts_each_voxel; + pooled_features += box_idx * out_x * out_y * out_z * channels + + offset_base * channels + channel_idx; + + float sum_val = 0; + int total_pts = pts_idx_of_voxels[0]; + + for (int k = 1; k <= total_pts; k++) { + sum_val += pts_feature[pts_idx_of_voxels[k] * channels + channel_idx]; + } + + if (total_pts > 0) { + pooled_features[0] = sum_val / total_pts; + } +} + +template +__global__ void roiaware_maxpool3d_backward(int boxes_num, int channels, + int out_x, int out_y, int out_z, + const int *argmax, + const T *grad_out, T *grad_in) { + // params argmax: (N, out_x, out_y, out_z, C) + // params grad_out: (N, out_x, out_y, out_z, C) + // params grad_in: 
(npoints, C), return value + + int box_idx = blockIdx.z; + int channel_idx = blockIdx.y; + int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x; + + int x_idx = voxel_idx_flat / (out_y * out_z); + int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z; + int z_idx = voxel_idx_flat % out_z; + if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x || + y_idx >= out_y || z_idx >= out_z) + return; + + int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx; + argmax += box_idx * out_x * out_y * out_z * channels + + offset_base * channels + channel_idx; + grad_out += box_idx * out_x * out_y * out_z * channels + + offset_base * channels + channel_idx; + + if (argmax[0] == -1) return; + + atomicAdd(grad_in + argmax[0] * channels + channel_idx, grad_out[0] * 1); +} + +template +__global__ void roiaware_avgpool3d_backward(int boxes_num, int channels, + int out_x, int out_y, int out_z, + int max_pts_each_voxel, + const int *pts_idx_of_voxels, + const T *grad_out, T *grad_in) { + // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel) + // params grad_out: (N, out_x, out_y, out_z, C) + // params grad_in: (npoints, C), return value + + int box_idx = blockIdx.z; + int channel_idx = blockIdx.y; + int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x; + + int x_idx = voxel_idx_flat / (out_y * out_z); + int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z; + int z_idx = voxel_idx_flat % out_z; + if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x || + y_idx >= out_y || z_idx >= out_z) + return; + + int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx; + pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel + + offset_base * max_pts_each_voxel; + grad_out += box_idx * out_x * out_y * out_z * channels + + offset_base * channels + channel_idx; + + int total_pts = pts_idx_of_voxels[0]; + float cur_grad = 1 / fmaxf(float(total_pts), 1.0); + for (int k = 1; k <= 
total_pts; k++) { + atomicAdd(grad_in + pts_idx_of_voxels[k] * channels + channel_idx, + grad_out[0] * cur_grad); + } +} + +#endif // ROIAWARE_POOL3D_CUDA_KERNEL_CUH diff --git a/mmcv/ops/csrc/pytorch/cuda/points_in_boxes_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/points_in_boxes_cuda.cu new file mode 100644 index 0000000000..5af86121a3 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/cuda/points_in_boxes_cuda.cu @@ -0,0 +1,55 @@ +#include + +#include "points_in_boxes_cuda_kernel.cuh" +#include "pytorch_cuda_helper.hpp" + +void PointsInBoxesPartForwardCUDAKernelLauncher(int batch_size, int boxes_num, + int pts_num, const Tensor boxes, + const Tensor pts, + Tensor box_idx_of_points) { + // params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR + // coordinate, z is + // the bottom center, each box DO NOT overlaps params pts: (B, npoints, 3) [x, + // y, z] in LiDAR coordinate params boxes_idx_of_points: (B, npoints), default + // -1 + + at::cuda::CUDAGuard device_guard(boxes.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 blocks(DIVUP(pts_num, THREADS_PER_BLOCK), batch_size); + dim3 threads(THREADS_PER_BLOCK); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + boxes.scalar_type(), "points_in_boxes_part_forward_cuda_kernel", [&] { + points_in_boxes_part_forward_cuda_kernel<<>>( + batch_size, boxes_num, pts_num, boxes.data_ptr(), + pts.data_ptr(), box_idx_of_points.data_ptr()); + }); + + AT_CUDA_CHECK(cudaGetLastError()); +} + +void PointsInBoxesAllForwardCUDAKernelLauncher(int batch_size, int boxes_num, + int pts_num, const Tensor boxes, + const Tensor pts, + Tensor box_idx_of_points) { + // params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR + // coordinate, z is the bottom center, each box params pts: (B, npoints, 3) + // [x, y, z] in LiDAR coordinate params boxes_idx_of_points: (B, npoints), + // default -1 + + at::cuda::CUDAGuard device_guard(boxes.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + 
dim3 blocks(DIVUP(pts_num, THREADS_PER_BLOCK), batch_size); + dim3 threads(THREADS_PER_BLOCK); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + boxes.scalar_type(), "points_in_boxes_all_forward_cuda_kernel", [&] { + points_in_boxes_all_forward_cuda_kernel<<>>( + batch_size, boxes_num, pts_num, boxes.data_ptr(), + pts.data_ptr(), box_idx_of_points.data_ptr()); + }); + + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/cuda/roiaware_pool3d_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/roiaware_pool3d_cuda.cu new file mode 100644 index 0000000000..1840553394 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/cuda/roiaware_pool3d_cuda.cu @@ -0,0 +1,120 @@ +// Modified from +// https://github.com/sshaoshuai/PCDet/blob/master/pcdet/ops/roiaware_pool3d/src/roiaware_pool3d_kernel.cu +// Written by Shaoshuai Shi +// All Rights Reserved 2019. + +#include + +#include "pytorch_cuda_helper.hpp" +#include "roiaware_pool3d_cuda_kernel.cuh" + +void RoiawarePool3dForwardCUDAKernelLauncher( + int boxes_num, int pts_num, int channels, int max_pts_each_voxel, int out_x, + int out_y, int out_z, const Tensor rois, const Tensor pts, + const Tensor pts_feature, Tensor argmax, Tensor pts_idx_of_voxels, + Tensor pooled_features, int pool_method) { + // params rois: (N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR + // coordinate params pts: (npoints, 3) [x, y, z] in LiDAR coordinate params + // pts_feature: (npoints, C) params argmax: (N, out_x, out_y, out_z, C) params + // pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel) params + // pooled_features: (N, out_x, out_y, out_z, C) params pool_method: 0: + // max_pool 1: avg_pool + + at::cuda::CUDAGuard device_guard(pts_feature.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int *pts_mask = NULL; + cudaMalloc(&pts_mask, boxes_num * pts_num * sizeof(int)); // (N, M) + cudaMemset(pts_mask, -1, boxes_num * pts_num * sizeof(int)); + + dim3 blocks_mask(DIVUP(pts_num, THREADS_PER_BLOCK), boxes_num); + dim3 
threads(THREADS_PER_BLOCK); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + rois.scalar_type(), "generate_pts_mask_for_box3d", [&] { + generate_pts_mask_for_box3d<<>>( + boxes_num, pts_num, out_x, out_y, out_z, rois.data_ptr(), + pts.data_ptr(), pts_mask); + }); + + AT_CUDA_CHECK(cudaGetLastError()); + + // TODO: Merge the collect and pool functions, SS + + dim3 blocks_collect(DIVUP(boxes_num, THREADS_PER_BLOCK)); + + AT_DISPATCH_INTEGRAL_TYPES( + pts_idx_of_voxels.scalar_type(), "collect_inside_pts_for_box3d", [&] { + collect_inside_pts_for_box3d<<>>( + boxes_num, pts_num, max_pts_each_voxel, out_x, out_y, out_z, + pts_mask, pts_idx_of_voxels.data_ptr()); + }); + + AT_CUDA_CHECK(cudaGetLastError()); + + dim3 blocks_pool(DIVUP(out_x * out_y * out_z, THREADS_PER_BLOCK), channels, + boxes_num); + if (pool_method == 0) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + pts_feature.scalar_type(), "roiaware_maxpool3d", [&] { + roiaware_maxpool3d<<>>( + boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y, + out_z, pts_feature.data_ptr(), + pts_idx_of_voxels.data_ptr(), + pooled_features.data_ptr(), argmax.data_ptr()); + }); + } else if (pool_method == 1) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + pts_feature.scalar_type(), "roiaware_avgpool3d", [&] { + roiaware_avgpool3d<<>>( + boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y, + out_z, pts_feature.data_ptr(), + pts_idx_of_voxels.data_ptr(), + pooled_features.data_ptr()); + }); + } + + AT_CUDA_CHECK(cudaGetLastError()); + cudaFree(pts_mask); + +#ifdef DEBUG + cudaDeviceSynchronize(); // for using printf in kernel function +#endif +} + +void RoiawarePool3dBackwardCUDAKernelLauncher( + int boxes_num, int out_x, int out_y, int out_z, int channels, + int max_pts_each_voxel, const Tensor pts_idx_of_voxels, const Tensor argmax, + const Tensor grad_out, Tensor grad_in, int pool_method) { + // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel) + // params argmax: (N, out_x, out_y, out_z, C) + // 
params grad_out: (N, out_x, out_y, out_z, C) + // params grad_in: (npoints, C), return value + // params pool_method: 0: max_pool, 1: avg_pool + + at::cuda::CUDAGuard device_guard(grad_out.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 blocks(DIVUP(out_x * out_y * out_z, THREADS_PER_BLOCK), channels, + boxes_num); + dim3 threads(THREADS_PER_BLOCK); + + if (pool_method == 0) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_in.scalar_type(), "roiaware_maxpool3d_backward", [&] { + roiaware_maxpool3d_backward<<>>( + boxes_num, channels, out_x, out_y, out_z, argmax.data_ptr(), + grad_out.data_ptr(), grad_in.data_ptr()); + }); + } else if (pool_method == 1) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_in.scalar_type(), "roiaware_avgpool3d_backward", [&] { + roiaware_avgpool3d_backward<<>>( + boxes_num, channels, out_x, out_y, out_z, max_pts_each_voxel, + pts_idx_of_voxels.data_ptr(), grad_out.data_ptr(), + grad_in.data_ptr()); + }); + } + + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/points_in_boxes.cpp b/mmcv/ops/csrc/pytorch/points_in_boxes.cpp new file mode 100644 index 0000000000..9ebeec9ab8 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/points_in_boxes.cpp @@ -0,0 +1,92 @@ +#include "pytorch_cpp_helper.hpp" + +#ifdef MMCV_WITH_CUDA +void PointsInBoxesPartForwardCUDAKernelLauncher(int batch_size, int boxes_num, + int pts_num, const Tensor boxes, + const Tensor pts, + Tensor box_idx_of_points); + +void points_in_boxes_part_forward_cuda(int batch_size, int boxes_num, + int pts_num, const Tensor boxes, + const Tensor pts, + Tensor box_idx_of_points) { + PointsInBoxesPartForwardCUDAKernelLauncher(batch_size, boxes_num, pts_num, + boxes, pts, box_idx_of_points); +}; + +void PointsInBoxesAllForwardCUDAKernelLauncher(int batch_size, int boxes_num, + int pts_num, const Tensor boxes, + const Tensor pts, + Tensor box_idx_of_points); + +void points_in_boxes_all_forward_cuda(int batch_size, int boxes_num, + int pts_num, 
const Tensor boxes, + const Tensor pts, + Tensor box_idx_of_points) { + PointsInBoxesAllForwardCUDAKernelLauncher(batch_size, boxes_num, pts_num, + boxes, pts, box_idx_of_points); +}; +#endif + +void points_in_boxes_part_forward(Tensor boxes_tensor, Tensor pts_tensor, + Tensor box_idx_of_points_tensor) { + // params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR + // coordinate, z is the bottom center, each box params pts: (B, npoints, 3) + // [x, y, z] in LiDAR coordinate params boxes_idx_of_points: (B, npoints), + // default -1 + + if (pts_tensor.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CUDA_INPUT(boxes_tensor); + CHECK_CUDA_INPUT(pts_tensor); + CHECK_CUDA_INPUT(box_idx_of_points_tensor); + + int batch_size = boxes_tensor.size(0); + int boxes_num = boxes_tensor.size(1); + int pts_num = pts_tensor.size(1); + + const float *boxes = boxes_tensor.data_ptr(); + const float *pts = pts_tensor.data_ptr(); + int *box_idx_of_points = box_idx_of_points_tensor.data_ptr(); + + points_in_boxes_part_forward_cuda(batch_size, boxes_num, pts_num, + boxes_tensor, pts_tensor, + box_idx_of_points_tensor); +#else + AT_ERROR("points_in_boxes_part is not compiled with GPU support"); +#endif + } else { + AT_ERROR("points_in_boxes_part is not implemented on CPU"); + } +} + +void points_in_boxes_all_forward(Tensor boxes_tensor, Tensor pts_tensor, + Tensor box_idx_of_points_tensor) { + // params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR + // coordinate, z is the bottom center. 
params pts: (B, npoints, 3) [x, y, z] + // in LiDAR coordinate params boxes_idx_of_points: (B, npoints), default -1 + + if (pts_tensor.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CUDA_INPUT(boxes_tensor); + CHECK_CUDA_INPUT(pts_tensor); + CHECK_CUDA_INPUT(box_idx_of_points_tensor); + + int batch_size = boxes_tensor.size(0); + int boxes_num = boxes_tensor.size(1); + int pts_num = pts_tensor.size(1); + + const float *boxes = boxes_tensor.data_ptr(); + const float *pts = pts_tensor.data_ptr(); + int *box_idx_of_points = box_idx_of_points_tensor.data_ptr(); + + points_in_boxes_all_forward_cuda(batch_size, boxes_num, pts_num, + boxes_tensor, pts_tensor, + box_idx_of_points_tensor); +#else + AT_ERROR("points_in_boxes_all is not compiled with GPU support"); +#endif + } else { + AT_ERROR("points_in_boxes_all is not implemented on CPU"); + } +} diff --git a/mmcv/ops/csrc/pytorch/points_in_boxes_cpu.cpp b/mmcv/ops/csrc/pytorch/points_in_boxes_cpu.cpp new file mode 100644 index 0000000000..c16baa4cca --- /dev/null +++ b/mmcv/ops/csrc/pytorch/points_in_boxes_cpu.cpp @@ -0,0 +1,53 @@ +#include "pytorch_cpp_helper.hpp" + +inline void lidar_to_local_coords_cpu(float shift_x, float shift_y, float rz, + float &local_x, float &local_y) { + float cosa = cos(-rz), sina = sin(-rz); + local_x = shift_x * cosa + shift_y * (-sina); + local_y = shift_x * sina + shift_y * cosa; +} + +inline int check_pt_in_box3d_cpu(const float *pt, const float *box3d, + float &local_x, float &local_y) { + // param pt: (x, y, z) + // param box3d: (cx, cy, cz, x_size, y_size, z_size, rz) in LiDAR coordinate, + // cz in the bottom center + float x = pt[0], y = pt[1], z = pt[2]; + float cx = box3d[0], cy = box3d[1], cz = box3d[2]; + float x_size = box3d[3], y_size = box3d[4], z_size = box3d[5], rz = box3d[6]; + cz += z_size / + 2.0; // shift to the center since cz in box3d is the bottom center + + if (fabsf(z - cz) > z_size / 2.0) return 0; + lidar_to_local_coords_cpu(x - cx, y - cy, rz, local_x, 
local_y); + float in_flag = (local_x > -x_size / 2.0) & (local_x < x_size / 2.0) & + (local_y > -y_size / 2.0) & (local_y < y_size / 2.0); + return in_flag; +} + +void points_in_boxes_cpu_forward(Tensor boxes_tensor, Tensor pts_tensor, + Tensor pts_indices_tensor) { + // params boxes: (N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR + // coordinate, z is the bottom center, each box DO NOT overlaps params pts: + // (npoints, 3) [x, y, z] in LiDAR coordinate params pts_indices: (N, npoints) + + CHECK_CONTIGUOUS(boxes_tensor); + CHECK_CONTIGUOUS(pts_tensor); + CHECK_CONTIGUOUS(pts_indices_tensor); + + int boxes_num = boxes_tensor.size(0); + int pts_num = pts_tensor.size(0); + + const float *boxes = boxes_tensor.data_ptr(); + const float *pts = pts_tensor.data_ptr(); + int *pts_indices = pts_indices_tensor.data_ptr(); + + float local_x = 0, local_y = 0; + for (int i = 0; i < boxes_num; i++) { + for (int j = 0; j < pts_num; j++) { + int cur_in_flag = + check_pt_in_box3d_cpu(pts + j * 3, boxes + i * 7, local_x, local_y); + pts_indices[i * pts_num + j] = cur_in_flag; + } + } +} diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp index 6ecdf763cf..a73a21d1a0 100644 --- a/mmcv/ops/csrc/pytorch/pybind.cpp +++ b/mmcv/ops/csrc/pytorch/pybind.cpp @@ -221,6 +221,22 @@ void border_align_backward(const Tensor &grad_output, const Tensor &boxes, const Tensor &argmax_idx, Tensor grad_input, const int pool_size); +void points_in_boxes_cpu_forward(Tensor boxes_tensor, Tensor pts_tensor, + Tensor pts_indices_tensor); + +void points_in_boxes_part_forward(Tensor boxes_tensor, Tensor pts_tensor, + Tensor box_idx_of_points_tensor); + +void points_in_boxes_all_forward(Tensor boxes_tensor, Tensor pts_tensor, + Tensor box_idx_of_points_tensor); + +void roiaware_pool3d_forward(Tensor rois, Tensor pts, Tensor pts_feature, + Tensor argmax, Tensor pts_idx_of_voxels, + Tensor pooled_features, int pool_method); + +void roiaware_pool3d_backward(Tensor 
pts_idx_of_voxels, Tensor argmax, + Tensor grad_out, Tensor grad_in, int pool_method); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"), py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"), @@ -444,4 +460,21 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "backward function of border_align", py::arg("grad_output"), py::arg("boxes"), py::arg("argmax_idx"), py::arg("grad_input"), py::arg("pool_size")); + m.def("points_in_boxes_cpu_forward", &points_in_boxes_cpu_forward, + "points_in_boxes_cpu_forward", py::arg("boxes_tensor"), + py::arg("pts_tensor"), py::arg("pts_indices_tensor")); + m.def("points_in_boxes_part_forward", &points_in_boxes_part_forward, + "points_in_boxes_part_forward", py::arg("boxes_tensor"), + py::arg("pts_tensor"), py::arg("box_idx_of_points_tensor")); + m.def("points_in_boxes_all_forward", &points_in_boxes_all_forward, + "points_in_boxes_all_forward", py::arg("boxes_tensor"), + py::arg("pts_tensor"), py::arg("box_idx_of_points_tensor")); + m.def("roiaware_pool3d_forward", &roiaware_pool3d_forward, + "roiaware_pool3d_forward", py::arg("rois"), py::arg("pts"), + py::arg("pts_feature"), py::arg("argmax"), py::arg("pts_idx_of_voxels"), + py::arg("pooled_features"), py::arg("pool_method")); + m.def("roiaware_pool3d_backward", &roiaware_pool3d_backward, + "roiaware_pool3d_backward", py::arg("pts_idx_of_voxels"), + py::arg("argmax"), py::arg("grad_out"), py::arg("grad_in"), + py::arg("pool_method")); } diff --git a/mmcv/ops/csrc/pytorch/roiaware_pool3d.cpp b/mmcv/ops/csrc/pytorch/roiaware_pool3d.cpp new file mode 100644 index 0000000000..c7e267f8f0 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/roiaware_pool3d.cpp @@ -0,0 +1,115 @@ +#include "pytorch_cpp_helper.hpp" + +#ifdef MMCV_WITH_CUDA +void RoiawarePool3dForwardCUDAKernelLauncher( + int boxes_num, int pts_num, int channels, int max_pts_each_voxel, int out_x, + int out_y, int out_z, const Tensor rois, const Tensor pts, + 
const Tensor pts_feature, Tensor argmax, Tensor pts_idx_of_voxels, + Tensor pooled_features, int pool_method); + +void roiaware_pool3d_forward_cuda(int boxes_num, int pts_num, int channels, + int max_pts_each_voxel, int out_x, int out_y, + int out_z, const Tensor rois, + const Tensor pts, const Tensor pts_feature, + Tensor argmax, Tensor pts_idx_of_voxels, + Tensor pooled_features, int pool_method) { + RoiawarePool3dForwardCUDAKernelLauncher( + boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y, out_z, + rois, pts, pts_feature, argmax, pts_idx_of_voxels, pooled_features, + pool_method); +}; + +void RoiawarePool3dBackwardCUDAKernelLauncher( + int boxes_num, int out_x, int out_y, int out_z, int channels, + int max_pts_each_voxel, const Tensor pts_idx_of_voxels, const Tensor argmax, + const Tensor grad_out, Tensor grad_in, int pool_method); + +void roiaware_pool3d_backward_cuda(int boxes_num, int out_x, int out_y, + int out_z, int channels, + int max_pts_each_voxel, + const Tensor pts_idx_of_voxels, + const Tensor argmax, const Tensor grad_out, + Tensor grad_in, int pool_method) { + RoiawarePool3dBackwardCUDAKernelLauncher( + boxes_num, out_x, out_y, out_z, channels, max_pts_each_voxel, + pts_idx_of_voxels, argmax, grad_out, grad_in, pool_method); +}; +#endif + +void roiaware_pool3d_forward(Tensor rois, Tensor pts, Tensor pts_feature, + Tensor argmax, Tensor pts_idx_of_voxels, + Tensor pooled_features, int pool_method) { + // params rois: (N, 7) [x, y, z, x_size, y_size, z_size, ry] in LiDAR + // coordinate + // params pts: (npoints, 3) [x, y, z] in LiDAR coordinate + // params pts_feature: (npoints, C) + // params argmax: (N, out_x, out_y, out_z, C) + // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel) + // params pooled_features: (N, out_x, out_y, out_z, C) + // params pool_method: 0: max_pool 1: avg_pool + if (pts.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CUDA_INPUT(rois); + CHECK_CUDA_INPUT(pts); + 
CHECK_CUDA_INPUT(pts_feature); + CHECK_CUDA_INPUT(argmax); + CHECK_CUDA_INPUT(pts_idx_of_voxels); + CHECK_CUDA_INPUT(pooled_features); + + int boxes_num = rois.size(0); + int pts_num = pts.size(0); + int channels = pts_feature.size(1); + int max_pts_each_voxel = + pts_idx_of_voxels.size(4); // index 0 is the counter + int out_x = pts_idx_of_voxels.size(1); + int out_y = pts_idx_of_voxels.size(2); + int out_z = pts_idx_of_voxels.size(3); + assert((out_x < 256) && (out_y < 256) && + (out_z < 256)); // we encode index with 8bit + + roiaware_pool3d_forward_cuda(boxes_num, pts_num, channels, + max_pts_each_voxel, out_x, out_y, out_z, rois, + pts, pts_feature, argmax, pts_idx_of_voxels, + pooled_features, pool_method); +#else + AT_ERROR("roiaware_pool3d is not compiled with GPU support"); +#endif + } else { + AT_ERROR("roiaware_pool3d is not implemented on CPU"); + } +} + +void roiaware_pool3d_backward(Tensor pts_idx_of_voxels, Tensor argmax, + Tensor grad_out, Tensor grad_in, + int pool_method) { + // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel) + // params argmax: (N, out_x, out_y, out_z, C) + // params grad_out: (N, out_x, out_y, out_z, C) + // params grad_in: (npoints, C), return value + // params pool_method: 0: max_pool 1: avg_pool + + if (grad_in.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CUDA_INPUT(pts_idx_of_voxels); + CHECK_CUDA_INPUT(argmax); + CHECK_CUDA_INPUT(grad_out); + CHECK_CUDA_INPUT(grad_in); + + int boxes_num = pts_idx_of_voxels.size(0); + int out_x = pts_idx_of_voxels.size(1); + int out_y = pts_idx_of_voxels.size(2); + int out_z = pts_idx_of_voxels.size(3); + int max_pts_each_voxel = + pts_idx_of_voxels.size(4); // index 0 is the counter + int channels = grad_out.size(4); + + roiaware_pool3d_backward_cuda(boxes_num, out_x, out_y, out_z, channels, + max_pts_each_voxel, pts_idx_of_voxels, argmax, + grad_out, grad_in, pool_method); +#else + AT_ERROR("roiaware_pool3d is not compiled with GPU support"); +#endif + } else 
{ + AT_ERROR("roiaware_pool3d is not implemented on CPU"); + } +} diff --git a/mmcv/ops/points_in_boxes.py b/mmcv/ops/points_in_boxes.py new file mode 100644 index 0000000000..cec354e5a5 --- /dev/null +++ b/mmcv/ops/points_in_boxes.py @@ -0,0 +1,133 @@ +import torch + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', [ + 'points_in_boxes_part_forward', 'points_in_boxes_cpu_forward', + 'points_in_boxes_all_forward' +]) + + +def points_in_boxes_part(points, boxes): + """Find the box in which each point is (CUDA). + + Args: + points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate + boxes (torch.Tensor): [B, T, 7], + num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz] in + LiDAR/DEPTH coordinate, (x, y, z) is the bottom center + + Returns: + box_idxs_of_pts (torch.Tensor): (B, M), default background = -1 + """ + assert points.shape[0] == boxes.shape[0], \ + f'Points and boxes should have the same batch size, ' \ + f'got {points.shape[0]} and {boxes.shape[0]}' + assert boxes.shape[2] == 7, \ + f'boxes dimension should be 7, ' \ + f'got unexpected shape {boxes.shape[2]}' + assert points.shape[2] == 3, \ + f'points dimension should be 3, ' \ + f'got unexpected shape {points.shape[2]}' + batch_size, num_points, _ = points.shape + + box_idxs_of_pts = points.new_zeros((batch_size, num_points), + dtype=torch.int).fill_(-1) + + # If manually put the tensor 'points' or 'boxes' on a device + # which is not the current device, some temporary variables + # will be created on the current device in the cuda op, + # and the output will be incorrect. + # Therefore, we force the current device to be the same + # as the device of the tensors if it was not. + # Please refer to https://github.com/open-mmlab/mmdetection3d/issues/305 + # for the incorrect output before the fix. 
def points_in_boxes_all(points, boxes):
    """Find all boxes in which each point is (CUDA).

    Args:
        points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate
        boxes (torch.Tensor): [B, T, 7],
            num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz],
            (x, y, z) is the bottom center.

    Returns:
        box_idxs_of_pts (torch.Tensor): (B, M, T), default background = 0.

    Raises:
        AssertionError: If the batch sizes differ, the last dimensions are
            not 3 (points) / 7 (boxes), or the tensors live on different
            devices.
    """
    # The message reports the points batch size first, then the boxes one,
    # matching the "Points and boxes" wording.
    assert boxes.shape[0] == points.shape[0], \
        f'Points and boxes should have the same batch size, ' \
        f'got {points.shape[0]} and {boxes.shape[0]}'
    assert boxes.shape[2] == 7, \
        f'boxes dimension should be 7, ' \
        f'got unexpected shape {boxes.shape[2]}'
    assert points.shape[2] == 3, \
        f'points dimension should be 3, ' \
        f'got unexpected shape {points.shape[2]}'
    batch_size, num_points, _ = points.shape
    num_boxes = boxes.shape[1]

    # new_zeros already yields the background value 0.
    box_idxs_of_pts = points.new_zeros((batch_size, num_points, num_boxes),
                                       dtype=torch.int)

    # If the tensors sit on a device that is not the current CUDA device,
    # temporaries inside the cuda op are created on the current device and
    # the output is incorrect. Make the tensors' device current first.
    points_device = points.get_device()
    assert points_device == boxes.get_device(), \
        'Points and boxes should be put on the same device'
    if torch.cuda.current_device() != points_device:
        torch.cuda.set_device(points_device)

    ext_module.points_in_boxes_all_forward(boxes.contiguous(),
                                           points.contiguous(),
                                           box_idxs_of_pts)

    return box_idxs_of_pts
class RoIAwarePool3dFunction(Function):
    """Autograd function around the ``roiaware_pool3d`` extension ops."""

    @staticmethod
    def forward(ctx, rois, pts, pts_feature, out_size, max_pts_per_voxel,
                mode):
        """Pool point features into a voxel grid for every RoI.

        Args:
            rois (torch.Tensor): [N, 7], in LiDAR coordinate,
                (x, y, z) is the bottom center of rois
            pts (torch.Tensor): [npoints, 3]
            pts_feature (torch.Tensor): [npoints, C]
            out_size (int or tuple): n or [n1, n2, n3]
            max_pts_per_voxel (int): m
            mode (int): 0 (max pool) or 1 (average pool)

        Returns:
            pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C]
        """
        # Normalise out_size to an (x, y, z) triple.
        if isinstance(out_size, int):
            voxel_grid = (out_size, out_size, out_size)
        else:
            assert len(out_size) == 3
            assert mmcv.is_tuple_of(out_size, int)
            dim_x, dim_y, dim_z = out_size
            voxel_grid = (dim_x, dim_y, dim_z)

        n_rois = rois.shape[0]
        n_channels = pts_feature.shape[-1]
        n_points = pts.shape[0]

        # Output buffers, filled in place by the extension op.
        feature_shape = (n_rois, ) + voxel_grid + (n_channels, )
        pooled = pts_feature.new_zeros(feature_shape)
        argmax_idx = pts_feature.new_zeros(feature_shape, dtype=torch.int)
        voxel_pts_idx = pts_feature.new_zeros(
            (n_rois, ) + voxel_grid + (max_pts_per_voxel, ), dtype=torch.int)

        ext_module.roiaware_pool3d_forward(rois, pts, pts_feature, argmax_idx,
                                           voxel_pts_idx, pooled, mode)

        # Everything backward() needs to rebuild the [npoints, C] gradient.
        ctx.roiaware_pool3d_for_backward = (voxel_pts_idx, argmax_idx, mode,
                                            n_points, n_channels)
        return pooled

    @staticmethod
    def backward(ctx, grad_out):
        """Propagate the gradient back to the point features.

        Args:
            grad_out (torch.Tensor): [N, out_x, out_y, out_z, C]

        Returns:
            tuple: One entry per ``forward`` input; only the third
                (``pts_feature``) carries a gradient, of shape [npoints, C].
        """
        (voxel_pts_idx, argmax_idx, mode, n_points,
         n_channels) = ctx.roiaware_pool3d_for_backward

        grad_in = grad_out.new_zeros((n_points, n_channels))
        ext_module.roiaware_pool3d_backward(voxel_pts_idx, argmax_idx,
                                            grad_out.contiguous(), grad_in,
                                            mode)

        return None, None, grad_in, None, None, None
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_points_in_boxes_part():
    """Check points_in_boxes_part on a batched case, a rotated box, and
    (when available) a second GPU."""
    # Case 1: batch of two, one box per sample.
    boxes = torch.tensor(
        [[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.3]],
         [[-10.0, 23.0, 16.0, 10, 20, 20, 0.5]]],
        dtype=torch.float32).cuda(
        )  # boxes (b, t, 7) with bottom center in lidar coordinate
    pts = torch.tensor(
        [[[1, 2, 3.3], [1.2, 2.5, 3.0], [0.8, 2.1, 3.5], [1.6, 2.6, 3.6],
          [0.8, 1.2, 3.9], [-9.2, 21.0, 18.2], [3.8, 7.9, 6.3],
          [4.7, 3.5, -12.2]],
         [[3.8, 7.6, -2], [-10.6, -12.9, -20], [-16, -18, 9], [-21.3, -52, -5],
          [0, 0, 0], [6, 7, 8], [-2, -3, -4], [6, 4, 9]]],
        dtype=torch.float32).cuda()  # points (b, m, 3) in lidar coordinate

    point_indices = points_in_boxes_part(points=pts, boxes=boxes)
    expected_point_indices = torch.tensor(
        [[0, 0, 0, 0, 0, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1]],
        dtype=torch.int32).cuda()
    assert point_indices.shape == torch.Size([2, 8])
    assert (point_indices == expected_point_indices).all()

    # Case 2: a single long, thin box rotated by 30 degrees.
    boxes = torch.tensor([[[0.0, 0.0, 0.0, 1.0, 20.0, 1.0, 0.523598]]],
                         dtype=torch.float32).cuda()  # 30 degrees
    pts = torch.tensor(
        [[[4, 6.928, 0], [6.928, 4, 0], [4, -6.928, 0], [6.928, -4, 0],
          [-4, 6.928, 0], [-6.928, 4, 0], [-4, -6.928, 0], [-6.928, -4, 0]]],
        dtype=torch.float32).cuda()
    point_indices = points_in_boxes_part(points=pts, boxes=boxes)
    expected_point_indices = torch.tensor([[-1, -1, 0, -1, 0, -1, -1, -1]],
                                          dtype=torch.int32).cuda()
    assert (point_indices == expected_point_indices).all()

    # Case 3: same inputs as case 2 on a non-default GPU.
    if torch.cuda.device_count() > 1:
        pts = pts.to('cuda:1')
        boxes = boxes.to('cuda:1')
        expected_point_indices = expected_point_indices.to('cuda:1')
        point_indices = points_in_boxes_part(points=pts, boxes=boxes)
        # The inputs here have batch size 1, so compare against the expected
        # tensor's own shape instead of the (2, 8) shape from case 1.
        assert point_indices.shape == expected_point_indices.shape
        assert (point_indices == expected_point_indices).all()
[[[1, 2, 3.3], [1.2, 2.5, 3.0], [0.8, 2.1, 3.5], [1.6, 2.6, 3.6], + [0.8, 1.2, 3.9], [-9.2, 21.0, 18.2], [3.8, 7.9, 6.3], + [4.7, 3.5, -12.2], [3.8, 7.6, -2], [-10.6, -12.9, -20], [ + -16, -18, 9 + ], [-21.3, -52, -5], [0, 0, 0], [6, 7, 8], [-2, -3, -4]]], + dtype=torch.float32).cuda() # points (n, 3) in lidar coordinate + + point_indices = points_in_boxes_all(points=pts, boxes=boxes) + expected_point_indices = torch.tensor( + [[[1, 0], [1, 0], [1, 0], [1, 0], [1, 0], [0, 1], [0, 0], [0, 0], + [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]]], + dtype=torch.int32).cuda() + assert point_indices.shape == torch.Size([1, 15, 2]) + assert (point_indices == expected_point_indices).all() + + if torch.cuda.device_count() > 1: + pts = pts.to('cuda:1') + boxes = boxes.to('cuda:1') + expected_point_indices = expected_point_indices.to('cuda:1') + point_indices = points_in_boxes_all(points=pts, boxes=boxes) + assert point_indices.shape == torch.Size([1, 15, 2]) + assert (point_indices == expected_point_indices).all() From 689ddec103f1b24dda87fc4b61e4ef28e73513b0 Mon Sep 17 00:00:00 2001 From: "hudingchang.vendor" Date: Sun, 10 Oct 2021 03:59:14 +0800 Subject: [PATCH 2/3] refactor code --- .../csrc/pytorch/cuda/roiaware_pool3d_cuda.cu | 22 ++++--- mmcv/ops/roiaware_pool3d.py | 64 +++++++++---------- tests/test_ops/test_roiaware_pool3d.py | 16 ----- 3 files changed, 42 insertions(+), 60 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/cuda/roiaware_pool3d_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/roiaware_pool3d_cuda.cu index 1840553394..2fe3f7b5cf 100644 --- a/mmcv/ops/csrc/pytorch/cuda/roiaware_pool3d_cuda.cu +++ b/mmcv/ops/csrc/pytorch/cuda/roiaware_pool3d_cuda.cu @@ -32,9 +32,10 @@ void RoiawarePool3dForwardCUDAKernelLauncher( AT_DISPATCH_FLOATING_TYPES_AND_HALF( rois.scalar_type(), "generate_pts_mask_for_box3d", [&] { - generate_pts_mask_for_box3d<<>>( - boxes_num, pts_num, out_x, out_y, out_z, rois.data_ptr(), - pts.data_ptr(), pts_mask); + generate_pts_mask_for_box3d + 
<<>>( + boxes_num, pts_num, out_x, out_y, out_z, + rois.data_ptr(), pts.data_ptr(), pts_mask); }); AT_CUDA_CHECK(cudaGetLastError()); @@ -45,9 +46,10 @@ void RoiawarePool3dForwardCUDAKernelLauncher( AT_DISPATCH_INTEGRAL_TYPES( pts_idx_of_voxels.scalar_type(), "collect_inside_pts_for_box3d", [&] { - collect_inside_pts_for_box3d<<>>( - boxes_num, pts_num, max_pts_each_voxel, out_x, out_y, out_z, - pts_mask, pts_idx_of_voxels.data_ptr()); + collect_inside_pts_for_box3d + <<>>( + boxes_num, pts_num, max_pts_each_voxel, out_x, out_y, out_z, + pts_mask, pts_idx_of_voxels.data_ptr()); }); AT_CUDA_CHECK(cudaGetLastError()); @@ -57,7 +59,7 @@ void RoiawarePool3dForwardCUDAKernelLauncher( if (pool_method == 0) { AT_DISPATCH_FLOATING_TYPES_AND_HALF( pts_feature.scalar_type(), "roiaware_maxpool3d", [&] { - roiaware_maxpool3d<<>>( + roiaware_maxpool3d<<>>( boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y, out_z, pts_feature.data_ptr(), pts_idx_of_voxels.data_ptr(), @@ -66,7 +68,7 @@ void RoiawarePool3dForwardCUDAKernelLauncher( } else if (pool_method == 1) { AT_DISPATCH_FLOATING_TYPES_AND_HALF( pts_feature.scalar_type(), "roiaware_avgpool3d", [&] { - roiaware_avgpool3d<<>>( + roiaware_avgpool3d<<>>( boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y, out_z, pts_feature.data_ptr(), pts_idx_of_voxels.data_ptr(), @@ -102,14 +104,14 @@ void RoiawarePool3dBackwardCUDAKernelLauncher( if (pool_method == 0) { AT_DISPATCH_FLOATING_TYPES_AND_HALF( grad_in.scalar_type(), "roiaware_maxpool3d_backward", [&] { - roiaware_maxpool3d_backward<<>>( + roiaware_maxpool3d_backward<<>>( boxes_num, channels, out_x, out_y, out_z, argmax.data_ptr(), grad_out.data_ptr(), grad_in.data_ptr()); }); } else if (pool_method == 1) { AT_DISPATCH_FLOATING_TYPES_AND_HALF( grad_in.scalar_type(), "roiaware_avgpool3d_backward", [&] { - roiaware_avgpool3d_backward<<>>( + roiaware_avgpool3d_backward<<>>( boxes_num, channels, out_x, out_y, out_z, max_pts_each_voxel, 
pts_idx_of_voxels.data_ptr(), grad_out.data_ptr(), grad_in.data_ptr()); diff --git a/mmcv/ops/roiaware_pool3d.py b/mmcv/ops/roiaware_pool3d.py index 38ff927515..52b021bcd7 100644 --- a/mmcv/ops/roiaware_pool3d.py +++ b/mmcv/ops/roiaware_pool3d.py @@ -10,30 +10,34 @@ class RoIAwarePool3d(nn.Module): + """Encode the geometry-specific features of each 3D proposal. Paper + reference: https://arxiv.org/pdf/1907.03670.pdf. + + Args: + out_size (int or tuple): The size of output features. n or + [n1, n2, n3]. + max_pts_per_voxel (int, optional): The maximum number of points per + voxel. Default: 128. + mode (str, optional): Pooling method of RoIAware, 'max' or 'avg'. + Default: 'max'. + """ def __init__(self, out_size, max_pts_per_voxel=128, mode='max'): super().__init__() - """RoIAwarePool3d module - Args: - out_size (int or tuple): n or [n1, n2, n3] - max_pts_per_voxel (int): m - mode (str): 'max' or 'avg' - """ self.out_size = out_size self.max_pts_per_voxel = max_pts_per_voxel assert mode in ['max', 'avg'] - pool_method_map = {'max': 0, 'avg': 1} - self.mode = pool_method_map[mode] + pool_mapping = {'max': 0, 'avg': 1} + self.mode = pool_mapping[mode] def forward(self, rois, pts, pts_feature): - """RoIAwarePool3d module forward. - + """ Args: - rois (torch.Tensor): [N, 7],in LiDAR coordinate, - (x, y, z) is the bottom center of rois - pts (torch.Tensor): [npoints, 3] - pts_feature (torch.Tensor): [npoints, C] + rois (torch.Tensor): [N, 7], in LiDAR coordinate, + (x, y, z) is the bottom center of rois. + pts (torch.Tensor): [npoints, 3], coordinates of input points. + pts_feature (torch.Tensor): [npoints, C], features of input points. Returns: pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C] @@ -49,19 +53,22 @@ class RoIAwarePool3dFunction(Function): @staticmethod def forward(ctx, rois, pts, pts_feature, out_size, max_pts_per_voxel, mode): - """RoIAwarePool3d function forward. 
- + """ Args: rois (torch.Tensor): [N, 7], in LiDAR coordinate, - (x, y, z) is the bottom center of rois - pts (torch.Tensor): [npoints, 3] - pts_feature (torch.Tensor): [npoints, C] - out_size (int or tuple): n or [n1, n2, n3] - max_pts_per_voxel (int): m - mode (int): 0 (max pool) or 1 (average pool) + (x, y, z) is the bottom center of rois. + pts (torch.Tensor): [npoints, 3], coordinates of input points. + pts_feature (torch.Tensor): [npoints, C], features of input points. + out_size (int or tuple): The size of output features. n or + [n1, n2, n3]. + max_pts_per_voxel (int): The maximum number of points per voxel. + Default: 128. + mode (int): Pooling method of RoIAware, 0 (max pool) or 1 (average + pool). Returns: - pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C] + pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C], output + pooled features. """ if isinstance(out_size, int): @@ -93,13 +100,6 @@ def forward(ctx, rois, pts, pts_feature, out_size, max_pts_per_voxel, @staticmethod def backward(ctx, grad_out): - """RoIAwarePool3d function forward. 
- - Args: - grad_out (torch.Tensor): [N, out_x, out_y, out_z, C] - Returns: - grad_in (torch.Tensor): [npoints, C] - """ ret = ctx.roiaware_pool3d_for_backward pts_idx_of_voxels, argmax, mode, num_pts, num_channels = ret @@ -109,7 +109,3 @@ def backward(ctx, grad_out): mode) return None, None, grad_in, None, None, None - - -if __name__ == '__main__': - pass diff --git a/tests/test_ops/test_roiaware_pool3d.py b/tests/test_ops/test_roiaware_pool3d.py index c7085f591d..1d63e398da 100644 --- a/tests/test_ops/test_roiaware_pool3d.py +++ b/tests/test_ops/test_roiaware_pool3d.py @@ -74,14 +74,6 @@ def test_points_in_boxes_part(): dtype=torch.int32).cuda() assert (point_indices == expected_point_indices).all() - if torch.cuda.device_count() > 1: - pts = pts.to('cuda:1') - boxes = boxes.to('cuda:1') - expected_point_indices = expected_point_indices.to('cuda:1') - point_indices = points_in_boxes_part(points=pts, boxes=boxes) - assert point_indices.shape == torch.Size([2, 8]) - assert (point_indices == expected_point_indices).all() - def test_points_in_boxes_cpu(): boxes = torch.tensor( @@ -141,11 +133,3 @@ def test_points_in_boxes_all(): dtype=torch.int32).cuda() assert point_indices.shape == torch.Size([1, 15, 2]) assert (point_indices == expected_point_indices).all() - - if torch.cuda.device_count() > 1: - pts = pts.to('cuda:1') - boxes = boxes.to('cuda:1') - expected_point_indices = expected_point_indices.to('cuda:1') - point_indices = points_in_boxes_all(points=pts, boxes=boxes) - assert point_indices.shape == torch.Size([1, 15, 2]) - assert (point_indices == expected_point_indices).all() From 1b30e0995f14b349d5e7e00fc15501efdf24b39a Mon Sep 17 00:00:00 2001 From: hdc Date: Wed, 20 Oct 2021 22:56:40 +0800 Subject: [PATCH 3/3] fix typo --- .../csrc/pytorch/cuda/points_in_boxes_cuda.cu | 19 +++++++++++++------ mmcv/ops/roiaware_pool3d.py | 7 ++++--- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/cuda/points_in_boxes_cuda.cu 
b/mmcv/ops/csrc/pytorch/cuda/points_in_boxes_cuda.cu index 5af86121a3..17e6441ba4 100644 --- a/mmcv/ops/csrc/pytorch/cuda/points_in_boxes_cuda.cu +++ b/mmcv/ops/csrc/pytorch/cuda/points_in_boxes_cuda.cu @@ -1,3 +1,8 @@ +// Modified from +// https://github.com/sshaoshuai/PCDet/blob/master/pcdet/ops/roiaware_pool3d/src/roiaware_pool3d_kernel.cu +// Written by Shaoshuai Shi +// All Rights Reserved 2019. + #include #include "points_in_boxes_cuda_kernel.cuh" @@ -21,9 +26,10 @@ void PointsInBoxesPartForwardCUDAKernelLauncher(int batch_size, int boxes_num, AT_DISPATCH_FLOATING_TYPES_AND_HALF( boxes.scalar_type(), "points_in_boxes_part_forward_cuda_kernel", [&] { - points_in_boxes_part_forward_cuda_kernel<<>>( - batch_size, boxes_num, pts_num, boxes.data_ptr(), - pts.data_ptr(), box_idx_of_points.data_ptr()); + points_in_boxes_part_forward_cuda_kernel + <<>>( + batch_size, boxes_num, pts_num, boxes.data_ptr(), + pts.data_ptr(), box_idx_of_points.data_ptr()); }); AT_CUDA_CHECK(cudaGetLastError()); @@ -46,9 +52,10 @@ void PointsInBoxesAllForwardCUDAKernelLauncher(int batch_size, int boxes_num, AT_DISPATCH_FLOATING_TYPES_AND_HALF( boxes.scalar_type(), "points_in_boxes_all_forward_cuda_kernel", [&] { - points_in_boxes_all_forward_cuda_kernel<<>>( - batch_size, boxes_num, pts_num, boxes.data_ptr(), - pts.data_ptr(), box_idx_of_points.data_ptr()); + points_in_boxes_all_forward_cuda_kernel + <<>>( + batch_size, boxes_num, pts_num, boxes.data_ptr(), + pts.data_ptr(), box_idx_of_points.data_ptr()); }); AT_CUDA_CHECK(cudaGetLastError()); diff --git a/mmcv/ops/roiaware_pool3d.py b/mmcv/ops/roiaware_pool3d.py index e2925aab13..e593c7052f 100644 --- a/mmcv/ops/roiaware_pool3d.py +++ b/mmcv/ops/roiaware_pool3d.py @@ -1,3 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. import torch from torch import nn as nn from torch.autograd import Function @@ -10,10 +11,10 @@ class RoIAwarePool3d(nn.Module): - """Encode the geometry-specific features of each 3D proposal. Paper. 
+ """Encode the geometry-specific features of each 3D proposal. - Please refer to `Paper of PartA2 `_ - for more details. + Please refer to `PartA2 `_ for more + details. Args: out_size (int or tuple): The size of output features. n or