[Feature] Add group points ops from mmdet3d #1415

Merged 10 commits on Oct 23, 2021
1 change: 1 addition & 0 deletions docs/understand_mmcv/ops.md
@@ -16,6 +16,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
- FurthestPointSample
- FurthestPointSampleWithDist
- GeneralizedAttention
- GroupPoints
- KNN
- MaskedConv
- NMS
13 changes: 7 additions & 6 deletions mmcv/ops/__init__.py
@@ -22,6 +22,7 @@
furthest_point_sample_with_dist)
from .fused_bias_leakyrelu import FusedBiasLeakyReLU, fused_bias_leakyrelu
from .gather_points import gather_points
from .group_points import GroupAll, QueryAndGroup, grouping_operation
from .info import (get_compiler_version, get_compiling_cuda_version,
get_onnxruntime_op_path)
from .iou3d import boxes_iou_bev, nms_bev, nms_normal_bev
@@ -68,13 +69,13 @@
'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
'SAConv2d', 'TINShift', 'tin_shift', 'assign_score_withk',
'box_iou_rotated', 'RoIPointPool3d', 'nms_rotated', 'knn', 'ball_query',
'upfirdn2d', 'FusedBiasLeakyReLU', 'boxes_iou_bev', 'nms_bev',
'nms_normal_bev', 'fused_bias_leakyrelu', 'RoIAlignRotated',
'roi_align_rotated', 'pixel_group', 'contour_expand', 'three_nn',
'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'QueryAndGroup',
'GroupAll', 'grouping_operation', 'contour_expand', 'three_nn',
'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign',
'border_align', 'gather_points', 'furthest_point_sample',
'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation',
'Voxelization', 'voxelization', 'dynamic_scatter', 'DynamicScatter',
'RoIAwarePool3d', 'points_in_boxes_part', 'points_in_boxes_cpu',
'points_in_boxes_all'
'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization',
'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d',
'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all'
]
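
For reviewers trying the new exports: a minimal usage sketch, assuming the Python wrapper keeps the mmdet3d signature grouping_operation(features, indices) and that a CUDA device is available (the op has no CPU kernel); the shapes follow the kernel comments further down.

import torch
from mmcv.ops import grouping_operation

B, C, N, npoints, nsample = 2, 16, 1024, 128, 32
features = torch.randn(B, C, N, device='cuda')  # (B, C, N) point features
idx = torch.randint(0, N, (B, npoints, nsample),
                    device='cuda', dtype=torch.int32)  # (B, npoints, nsample)
grouped = grouping_operation(features, idx)  # (B, C, npoints, nsample)
assert grouped.shape == (B, C, npoints, nsample)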
63 changes: 63 additions & 0 deletions mmcv/ops/csrc/common/cuda/group_points_cuda_kernel.cuh
@@ -0,0 +1,63 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/group_points_gpu.cu
#ifndef GROUP_POINTS_CUDA_KERNEL_CUH
#define GROUP_POINTS_CUDA_KERNEL_CUH

#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif

template <typename T>
__global__ void group_points_forward_cuda_kernel(int b, int c, int n,
int npoints, int nsample,
const T *points,
const int *__restrict__ idx,
T *out) {
// points: (B, C, N)
// idx: (B, npoints, nsample)
// output:
// out: (B, C, npoints, nsample)
int bs_idx = blockIdx.z;
int c_idx = blockIdx.y;
int index = blockIdx.x * blockDim.x + threadIdx.x;
int pt_idx = index / nsample;
if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;

int sample_idx = index % nsample;

idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;
int in_idx = bs_idx * c * n + c_idx * n + idx[0];
int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample +
pt_idx * nsample + sample_idx;

out[out_idx] = points[in_idx];
}

template <typename T>
__global__ void group_points_backward_cuda_kernel(int b, int c, int n,
int npoints, int nsample,
const T *grad_out,
const int *__restrict__ idx,
T *grad_points) {
// grad_out: (B, C, npoints, nsample)
// idx: (B, npoints, nsample)
// output:
// grad_points: (B, C, N)
int bs_idx = blockIdx.z;
int c_idx = blockIdx.y;
int index = blockIdx.x * blockDim.x + threadIdx.x;
int pt_idx = index / nsample;
if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;

int sample_idx = index % nsample;
grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample +
pt_idx * nsample + sample_idx;
idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;

atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0], grad_out[0]);
}

#endif // GROUP_POINTS_CUDA_KERNEL_CUH
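
Functionally, the forward kernel gathers points[b, c, idx[b, p, s]] into out[b, c, p, s], and the backward kernel scatter-adds gradients back along the N dimension. A pure-PyTorch reference of the same indexing (for checking only, not part of this PR) could look like this:

import torch

def group_points_reference(points: torch.Tensor, idx: torch.Tensor):
    # points: (B, C, N) float, idx: (B, npoints, nsample) integer indices into N
    B, C, N = points.shape
    _, npoints, nsample = idx.shape
    flat_idx = idx.reshape(B, 1, npoints * nsample).expand(B, C, -1).long()
    # forward: out[b, c, p, s] = points[b, c, idx[b, p, s]]
    return points.gather(2, flat_idx).reshape(B, C, npoints, nsample)

def group_points_backward_reference(grad_out: torch.Tensor, idx: torch.Tensor, n: int):
    # backward: grad_points[b, c, idx[b, p, s]] += grad_out[b, c, p, s]
    B, C, npoints, nsample = grad_out.shape
    flat_idx = idx.reshape(B, 1, npoints * nsample).expand(B, C, -1).long()
    grad_points = grad_out.new_zeros(B, C, n)
    grad_points.scatter_add_(2, flat_idx, grad_out.reshape(B, C, -1))
    return grad_points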
61 changes: 61 additions & 0 deletions mmcv/ops/csrc/pytorch/cuda/group_points_cuda.cu
@@ -0,0 +1,61 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/group_points_gpu.cu
#include <stdio.h>
#include <stdlib.h>

#include "group_points_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"

void GroupPointsForwardCUDAKernelLauncher(int b, int c, int n, int npoints,
int nsample, const Tensor points,
const Tensor idx, Tensor out) {
// points: (B, C, N)
// idx: (B, npoints, nsample)
// output:
// out: (B, C, npoints, nsample)

at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

// blockIdx.x(col), blockIdx.y(row)
dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b);
dim3 threads(THREADS_PER_BLOCK);

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
points.scalar_type(), "group_points_forward_cuda_kernel", [&] {
group_points_forward_cuda_kernel<scalar_t>
<<<blocks, threads, 0, stream>>>(
b, c, n, npoints, nsample, points.data_ptr<scalar_t>(),
idx.data_ptr<int>(), out.data_ptr<scalar_t>());
});

AT_CUDA_CHECK(cudaGetLastError());
}

void GroupPointsBackwardCUDAKernelLauncher(int b, int c, int n, int npoints,
int nsample, const Tensor grad_out,
const Tensor idx,
Tensor grad_points) {
// grad_out: (B, C, npoints, nsample)
// idx: (B, npoints, nsample)
// output:
// grad_points: (B, C, N)

at::cuda::CUDAGuard device_guard(grad_out.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

// blockIdx.x(col), blockIdx.y(row)
dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b);
dim3 threads(THREADS_PER_BLOCK);

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_out.scalar_type(), "group_points_backward_cuda_kernel", [&] {
group_points_backward_cuda_kernel<scalar_t>
<<<blocks, threads, 0, stream>>>(
b, c, n, npoints, nsample, grad_out.data_ptr<scalar_t>(),
idx.data_ptr<int>(), grad_points.data_ptr<scalar_t>());
});

AT_CUDA_CHECK(cudaGetLastError());
}
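
Both launchers use the same launch geometry: one thread per output element, with blockIdx.x covering npoints * nsample, blockIdx.y the channels and blockIdx.z the batch. A small sketch of the arithmetic (THREADS_PER_BLOCK comes from the mmcv CUDA helpers; 256 is only an assumed value for illustration):

def group_points_launch_geometry(b, c, npoints, nsample, threads_per_block=256):
    # mirrors DIVUP(npoints * nsample, THREADS_PER_BLOCK) in the launchers above
    blocks_x = (npoints * nsample + threads_per_block - 1) // threads_per_block
    grid = (blocks_x, c, b)  # (blockIdx.x, blockIdx.y, blockIdx.z)
    return grid, (threads_per_block,)

# within a block, a thread's linear index maps to one output element:
#   pt_idx = index // nsample, sample_idx = index % nsample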
58 changes: 58 additions & 0 deletions mmcv/ops/csrc/pytorch/group_points.cpp
@@ -0,0 +1,58 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/group_points.cpp

#include "pytorch_cpp_helper.hpp"

#ifdef MMCV_WITH_CUDA
void GroupPointsForwardCUDAKernelLauncher(int b, int c, int n, int npoints,
int nsample, const Tensor points,
const Tensor idx, Tensor out);
void group_points_forward_cuda(int b, int c, int n, int npoints, int nsample,
const Tensor points, const Tensor idx,
Tensor out) {
GroupPointsForwardCUDAKernelLauncher(b, c, n, npoints, nsample, points, idx,
out);
};

void GroupPointsBackwardCUDAKernelLauncher(int b, int c, int n, int npoints,
int nsample, const Tensor grad_out,
const Tensor idx,
Tensor grad_points);
void group_points_backward_cuda(int b, int c, int n, int npoints, int nsample,
const Tensor grad_out, const Tensor idx,
Tensor grad_points) {
GroupPointsBackwardCUDAKernelLauncher(b, c, n, npoints, nsample, grad_out,
idx, grad_points);
};
#endif

void group_points_forward(int b, int c, int n, int npoints, int nsample,
Tensor points_tensor, Tensor idx_tensor,
Tensor out_tensor) {
if (points_tensor.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
group_points_forward_cuda(b, c, n, npoints, nsample, points_tensor,
idx_tensor, out_tensor);
#else
AT_ERROR("group_points is not compiled with GPU support");
#endif
} else {
AT_ERROR("group_points is not implemented on CPU");
}
}

void group_points_backward(int b, int c, int n, int npoints, int nsample,
Tensor grad_out_tensor, Tensor idx_tensor,
Tensor grad_points_tensor) {
if (grad_out_tensor.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
group_points_backward_cuda(b, c, n, npoints, nsample, grad_out_tensor,
idx_tensor, grad_points_tensor);
#else
AT_ERROR("group_points is not compiled with GPU support");
#endif
} else {
AT_ERROR("group_points is not implemented on CPU");
}
}
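
Since the dispatcher above raises AT_ERROR on CPU tensors, a Python caller has to hand it contiguous CUDA tensors with int32 indices. A hedged pre-flight check (the helper name is ours, not part of the PR):

import torch

def _check_group_points_inputs(features: torch.Tensor, idx: torch.Tensor):
    # group_points has no CPU path ("group_points is not implemented on CPU")
    assert features.is_cuda and idx.is_cuda, 'group_points is CUDA-only'
    # the kernels read idx as const int*, so int32 is expected
    assert idx.dtype == torch.int32, 'idx must be an int32 tensor'
    return features.contiguous(), idx.contiguous()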
20 changes: 20 additions & 0 deletions mmcv/ops/csrc/pytorch/pybind.cpp
@@ -65,6 +65,14 @@ void deform_roi_pool_backward(Tensor grad_output, Tensor input, Tensor rois,
int pooled_width, float spatial_scale,
int sampling_ratio, float gamma);

void group_points_forward(int b, int c, int n, int npoints, int nsample,
Tensor points_tensor, Tensor idx_tensor,
Tensor out_tensor);

void group_points_backward(int b, int c, int n, int npoints, int nsample,
Tensor grad_out_tensor, Tensor idx_tensor,
Tensor grad_points_tensor);

void roipoint_pool3d_forward(Tensor xyz, Tensor boxes3d, Tensor pts_feature,
Tensor pooled_features, Tensor pooled_empty_flag);

@@ -453,6 +461,18 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("bbox_overlaps", &bbox_overlaps, "bbox_overlaps", py::arg("bboxes1"),
py::arg("bboxes2"), py::arg("ious"), py::arg("mode"),
py::arg("aligned"), py::arg("offset"));
m.def("group_points_forward", &group_points_forward, "group_points_forward",
py::arg("b"), py::arg("c"), py::arg("n"), py::arg("npoints"),
py::arg("nsample"), py::arg("points_tensor"), py::arg("idx_tensor"),
py::arg("out_tensor"));
m.def("group_points_backward", &group_points_backward,
"group_points_backward", py::arg("b"), py::arg("c"), py::arg("n"),
py::arg("npoints"), py::arg("nsample"), py::arg("grad_out_tensor"),
py::arg("idx_tensor"), py::arg("grad_points_tensor"));
m.def("knn_forward", &knn_forward, "knn_forward", py::arg("b"), py::arg("n"),
py::arg("m"), py::arg("nsample"), py::arg("xyz_tensor"),
py::arg("new_xyz_tensor"), py::arg("idx_tensor"),
py::arg("dist2_tensor"));
m.def("iou3d_boxes_overlap_bev_forward", &iou3d_boxes_overlap_bev_forward,
"iou3d_boxes_overlap_bev_forward", py::arg("boxes_a"),
py::arg("boxes_b"), py::arg("ans_overlap"));
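
Putting the pieces together, the two bindings registered above are typically paired in an autograd Function on the Python side. A minimal sketch, assuming the extension is loaded through mmcv's ext_loader with the names registered in this pybind.cpp (the actual wrapper in mmcv/ops/group_points.py may differ in details):

import torch
from torch.autograd import Function
from mmcv.utils import ext_loader

ext_module = ext_loader.load_ext(
    '_ext', ['group_points_forward', 'group_points_backward'])

class GroupingOperationSketch(Function):
    # pairs the forward gather with the scatter-add backward shown earlier

    @staticmethod
    def forward(ctx, features, idx):
        # features: (B, C, N) CUDA float tensor, idx: (B, npoints, nsample) int32
        B, C, N = features.size()
        _, npoints, nsample = idx.size()
        out = features.new_zeros(B, C, npoints, nsample)
        ext_module.group_points_forward(B, C, N, npoints, nsample,
                                        features, idx, out)
        ctx.for_backwards = (idx, N)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        idx, N = ctx.for_backwards
        B, C, npoints, nsample = grad_out.size()
        grad_features = grad_out.new_zeros(B, C, N)
        ext_module.group_points_backward(B, C, N, npoints, nsample,
                                         grad_out.contiguous(), idx,
                                         grad_features)
        return grad_features, None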