From 6ebc1d3f030641da43e1c9a97cd2a9bfcadedc11 Mon Sep 17 00:00:00 2001 From: "hudingchang.vendor" Date: Mon, 20 Sep 2021 00:27:52 +0800 Subject: [PATCH 1/7] add op (group points) and its related ops (ball query and knn) in mmdet3d --- docs/understand_mmcv/ops.md | 3 + mmcv/ops/__init__.py | 6 +- mmcv/ops/ball_query.py | 49 ++++ .../common/cuda/ball_query_cuda_kernel.cuh | 55 +++++ .../common/cuda/group_points_cuda_kernel.cuh | 61 +++++ mmcv/ops/csrc/common/cuda/knn_cuda_kernel.cuh | 89 +++++++ mmcv/ops/csrc/pytorch/ball_query.cpp | 37 +++ mmcv/ops/csrc/pytorch/cuda/ball_query_cuda.cu | 38 +++ .../csrc/pytorch/cuda/group_points_cuda.cu | 60 +++++ mmcv/ops/csrc/pytorch/cuda/knn_cuda.cu | 35 +++ mmcv/ops/csrc/pytorch/group_points.cpp | 57 +++++ mmcv/ops/csrc/pytorch/knn.cpp | 33 +++ mmcv/ops/csrc/pytorch/pybind.cpp | 42 +++- mmcv/ops/group_points.py | 233 ++++++++++++++++++ mmcv/ops/knn.py | 73 ++++++ tests/test_ops/test_ball_query.py | 54 ++++ tests/test_ops/test_group_points.py | 76 ++++++ tests/test_ops/test_knn.py | 54 ++++ 18 files changed, 1048 insertions(+), 7 deletions(-) create mode 100644 mmcv/ops/ball_query.py create mode 100644 mmcv/ops/csrc/common/cuda/ball_query_cuda_kernel.cuh create mode 100644 mmcv/ops/csrc/common/cuda/group_points_cuda_kernel.cuh create mode 100644 mmcv/ops/csrc/common/cuda/knn_cuda_kernel.cuh create mode 100644 mmcv/ops/csrc/pytorch/ball_query.cpp create mode 100644 mmcv/ops/csrc/pytorch/cuda/ball_query_cuda.cu create mode 100644 mmcv/ops/csrc/pytorch/cuda/group_points_cuda.cu create mode 100644 mmcv/ops/csrc/pytorch/cuda/knn_cuda.cu create mode 100644 mmcv/ops/csrc/pytorch/group_points.cpp create mode 100644 mmcv/ops/csrc/pytorch/knn.cpp create mode 100644 mmcv/ops/group_points.py create mode 100644 mmcv/ops/knn.py create mode 100644 tests/test_ops/test_ball_query.py create mode 100644 tests/test_ops/test_group_points.py create mode 100644 tests/test_ops/test_knn.py diff --git a/docs/understand_mmcv/ops.md 
b/docs/understand_mmcv/ops.md index 8460682c1f..bb54ce2099 100644 --- a/docs/understand_mmcv/ops.md +++ b/docs/understand_mmcv/ops.md @@ -2,6 +2,7 @@ We implement common CUDA ops used in detection, segmentation, etc. +- BallQuery - BBoxOverlaps - CARAFE - CrissCrossAttention @@ -10,6 +11,8 @@ We implement common CUDA ops used in detection, segmentation, etc. - Deformable Convolution v1/v2 - Deformable RoIPool - GeneralizedAttention +- GroupPoints +- KNN - MaskedConv - NMS - PSAMask diff --git a/mmcv/ops/__init__.py b/mmcv/ops/__init__.py index 359a13c06a..41f0bfdd3e 100644 --- a/mmcv/ops/__init__.py +++ b/mmcv/ops/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .ball_query import ball_query from .bbox import bbox_overlaps from .border_align import BorderAlign, border_align from .box_iou_rotated import box_iou_rotated @@ -16,8 +17,10 @@ from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss, sigmoid_focal_loss, softmax_focal_loss) from .fused_bias_leakyrelu import FusedBiasLeakyReLU, fused_bias_leakyrelu +from .group_points import GroupAll, QueryAndGroup, grouping_operation from .info import (get_compiler_version, get_compiling_cuda_version, get_onnxruntime_op_path) +from .knn import knn from .masked_conv import MaskedConv2d, masked_conv2d from .modulated_deform_conv import (ModulatedDeformConv2d, ModulatedDeformConv2dPack, @@ -51,6 +54,7 @@ 'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign', 'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated', 'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu', - 'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'contour_expand', + 'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'QueryAndGroup', + 'GroupAll', 'grouping_operation', 'knn', 'ball_query', 'contour_expand', 'MultiScaleDeformableAttention', 'BorderAlign', 'border_align' ] diff --git a/mmcv/ops/ball_query.py b/mmcv/ops/ball_query.py new file mode 100644 index 0000000000..01686b5b64 
class BallQuery(Function):
    """Ball Query.

    Find nearby points in spherical space: for every query center, gather
    the indices of up to ``sample_num`` points lying between ``min_radius``
    and ``max_radius`` of that center.
    """

    @staticmethod
    def forward(ctx, min_radius: float, max_radius: float, sample_num: int,
                xyz: torch.Tensor, center_xyz: torch.Tensor) -> torch.Tensor:
        """
        Args:
            min_radius (float): minimum radius of the balls.
            max_radius (float): maximum radius of the balls.
            sample_num (int): maximum number of features in the balls.
            xyz (Tensor): (B, N, 3) xyz coordinates of the features.
            center_xyz (Tensor): (B, npoint, 3) centers of the ball query.

        Returns:
            Tensor: (B, npoint, sample_num) int32 tensor with the indices of
                the features that form the query balls.
        """
        assert center_xyz.is_contiguous()
        assert xyz.is_contiguous()
        assert min_radius < max_radius

        B, N, _ = xyz.size()
        npoint = center_xyz.size(1)
        # Allocate next to the inputs. ``torch.cuda.IntTensor`` would use the
        # *current* CUDA device, which is the wrong one when the inputs live
        # on another GPU.
        idx = xyz.new_zeros((B, npoint, sample_num), dtype=torch.int)

        # The extension takes the query centers before the point cloud.
        ext_module.ball_query_forward(B, N, npoint, min_radius, max_radius,
                                      sample_num, center_xyz, xyz, idx)
        ctx.mark_non_differentiable(idx)
        return idx

    @staticmethod
    def backward(ctx, a=None):
        # One entry per ``forward`` input (ctx excluded); the index output is
        # non-differentiable, so every gradient is None.
        return None, None, None, None, None


ball_query = BallQuery.apply
// For every query center, collect up to `nsample` indices of points whose
// squared distance d2 to the center satisfies
//   d2 == 0  ||  min_radius^2 <= d2 < max_radius^2.
// One thread handles one (batch, query point) pair: blockIdx.y selects the
// batch element, the x dimension selects the query point.
template <typename T>
__global__ void ball_query_forward_cuda_kernel(int b, int n, int m,
                                               float min_radius,
                                               float max_radius, int nsample,
                                               const T* new_xyz, const T* xyz,
                                               int* idx) {
  // new_xyz: (B, M, 3)
  // xyz: (B, N, 3)
  // output:
  //      idx: (B, M, nsample)
  int bs_idx = blockIdx.y;
  int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (bs_idx >= b || pt_idx >= m) return;

  // Move every pointer to the slice owned by this thread.
  new_xyz += bs_idx * m * 3 + pt_idx * 3;
  xyz += bs_idx * n * 3;
  idx += bs_idx * m * nsample + pt_idx * nsample;

  // Compare squared distances so no square root is needed.
  float max_radius2 = max_radius * max_radius;
  float min_radius2 = min_radius * min_radius;
  T new_x = new_xyz[0];
  T new_y = new_xyz[1];
  T new_z = new_xyz[2];

  int cnt = 0;
  for (int k = 0; k < n; ++k) {
    T x = xyz[k * 3 + 0];
    T y = xyz[k * 3 + 1];
    T z = xyz[k * 3 + 2];
    T d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
           (new_z - z) * (new_z - z);
    if (d2 == 0 || (d2 >= min_radius2 && d2 < max_radius2)) {
      if (cnt == 0) {
        // First hit: pre-fill every slot with it, so slots that never get a
        // hit of their own still hold a valid index.
        for (int l = 0; l < nsample; ++l) {
          idx[l] = k;
        }
      }
      idx[cnt] = k;
      ++cnt;
      if (cnt >= nsample) break;
    }
  }
}
// Gather: out[b, c, p, s] = points[b, c, idx[b, p, s]].
// blockIdx.z selects the batch element, blockIdx.y the channel, and the x
// dimension enumerates the flattened (point, sample) pairs.
template <typename T>
__global__ void group_points_forward_cuda_kernel(int b, int c, int n,
                                                 int npoints, int nsample,
                                                 const T *points,
                                                 const int *__restrict__ idx,
                                                 T *out) {
  // points: (B, C, N)
  // idx: (B, npoints, nsample)
  // output:
  //      out: (B, C, npoints, nsample)
  int bs_idx = blockIdx.z;
  int c_idx = blockIdx.y;
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int pt_idx = index / nsample;
  if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;

  int sample_idx = index % nsample;

  idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;
  int in_idx = bs_idx * c * n + c_idx * n + idx[0];
  int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample +
                pt_idx * nsample + sample_idx;

  out[out_idx] = points[in_idx];
}

// Scatter-add: grad_points[b, c, idx[b, p, s]] += grad_out[b, c, p, s].
// Uses the same thread layout as the forward kernel.
template <typename T>
__global__ void group_points_backward_cuda_kernel(int b, int c, int n,
                                                  int npoints, int nsample,
                                                  const T *grad_out,
                                                  const int *__restrict__ idx,
                                                  T *grad_points) {
  // grad_out: (B, C, npoints, nsample)
  // idx: (B, npoints, nsample)
  // output:
  //      grad_points: (B, C, N)
  int bs_idx = blockIdx.z;
  int c_idx = blockIdx.y;
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int pt_idx = index / nsample;
  if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;

  int sample_idx = index % nsample;
  grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample +
              pt_idx * nsample + sample_idx;
  idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;

  // Several (point, sample) pairs may reference the same source point, so
  // the accumulation has to be atomic.
  atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0], grad_out[0]);
}
// Capacity of the per-thread candidate heap used by knn_forward_cuda_kernel.
#define KNN_MAX_NSAMPLE 100

__device__ void swap_float(float *x, float *y) {
  float tmp = *x;
  *x = *y;
  *y = tmp;
}

__device__ void swap_int(int *x, int *y) {
  int tmp = *x;
  *x = *y;
  *y = tmp;
}

// Sift the root of a max-heap of size k (keyed on dist) down to its place,
// moving the matching idx entries along with the distances.
__device__ void reheap(float *dist, int *idx, int k) {
  int root = 0;
  int child = root * 2 + 1;
  while (child < k) {
    // Pick the larger of the two children.
    if (child + 1 < k && dist[child + 1] > dist[child]) child++;
    if (dist[root] > dist[child]) return;
    swap_float(&dist[root], &dist[child]);
    swap_int(&idx[root], &idx[child]);
    root = child;
    child = root * 2 + 1;
  }
}

// Turn the max-heap into ascending order in place (heap sort).
__device__ void heap_sort(float *dist, int *idx, int k) {
  int i;
  for (i = k - 1; i > 0; i--) {
    swap_float(&dist[0], &dist[i]);
    swap_int(&idx[0], &idx[i]);
    reheap(dist, idx, i);
  }
}

// input: xyz (b, n, 3) new_xyz (b, m, 3)
// output: idx (b, m, nsample) dist2 (b, m, nsample)
// One thread handles one (batch, query point) pair and keeps the nsample
// closest candidates seen so far in a max-heap, so the root is always the
// worst of the current best candidates.
template <typename T>
__global__ void knn_forward_cuda_kernel(int b, int n, int m, int nsample,
                                        const T *xyz, const T *new_xyz,
                                        int *__restrict__ idx, T *dist2) {
  int bs_idx = blockIdx.y;
  int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (bs_idx >= b || pt_idx >= m) return;
  // The heap below has a fixed capacity; a larger nsample would write past
  // the end of the arrays. The outputs are left untouched in that case, so
  // the host side should reject such requests before launching.
  if (nsample > KNN_MAX_NSAMPLE) return;

  new_xyz += bs_idx * m * 3 + pt_idx * 3;
  xyz += bs_idx * n * 3;
  idx += bs_idx * m * nsample + pt_idx * nsample;
  dist2 += bs_idx * m * nsample + pt_idx * nsample;

  T new_x = new_xyz[0];
  T new_y = new_xyz[1];
  T new_z = new_xyz[2];

  float best_dist[KNN_MAX_NSAMPLE];
  int best_idx[KNN_MAX_NSAMPLE];
  for (int i = 0; i < nsample; i++) {
    best_dist[i] = 1e10;  // sentinel "infinitely far" distance
    best_idx[i] = 0;
  }
  for (int i = 0; i < n; i++) {
    T x = xyz[i * 3 + 0];
    T y = xyz[i * 3 + 1];
    T z = xyz[i * 3 + 2];
    T d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
           (new_z - z) * (new_z - z);
    if (d2 < best_dist[0]) {
      // Closer than the current worst candidate: replace the root and
      // restore the heap property.
      best_dist[0] = d2;
      best_idx[0] = i;
      reheap(best_dist, best_idx, nsample);
    }
  }
  heap_sort(best_dist, best_idx, nsample);
  for (int i = 0; i < nsample; i++) {
    idx[i] = best_idx[i];
    dist2[i] = best_dist[i];
  }
}
// Launch ball_query_forward_cuda_kernel on the device and stream of new_xyz,
// dispatching on its floating-point type (including half).
void BallQueryForwardCUDAKernelLauncher(int b, int n, int m, float min_radius,
                                        float max_radius, int nsample,
                                        const Tensor new_xyz, const Tensor xyz,
                                        int *idx) {
  // new_xyz: (B, M, 3)
  // xyz: (B, N, 3)
  // output:
  //      idx: (B, M, nsample)

  at::cuda::CUDAGuard device_guard(new_xyz.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // Query points are spread over x, batch elements over y.
  dim3 blocks(DIVUP(m, THREADS_PER_BLOCK),
              b);  // blockIdx.x(col), blockIdx.y(row)
  dim3 threads(THREADS_PER_BLOCK);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      new_xyz.scalar_type(), "ball_query_forward_cuda_kernel", [&] {
        ball_query_forward_cuda_kernel<scalar_t>
            <<<blocks, threads, 0, stream>>>(
                b, n, m, min_radius, max_radius, nsample,
                new_xyz.data_ptr<scalar_t>(), xyz.data_ptr<scalar_t>(), idx);
      });

  AT_CUDA_CHECK(cudaGetLastError());
}
"group_points_forward_cuda_kernel", [&] { + group_points_forward_cuda_kernel + <<>>( + b, c, n, npoints, nsample, points.data_ptr(), + idx.data_ptr(), out.data_ptr()); + }); + + AT_CUDA_CHECK(cudaGetLastError()); +} + +void GroupPointsBackwardCUDAKernelLauncher(int b, int c, int n, int npoints, + int nsample, const Tensor grad_out, + const Tensor idx, + Tensor grad_points) { + // grad_out: (B, C, npoints, nsample) + // idx: (B, npoints, nsample) + // output: + // grad_points: (B, C, N) + + at::cuda::CUDAGuard device_guard(grad_out.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, + b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_out.scalar_type(), "group_points_backward_cuda_kernel", [&] { + group_points_backward_cuda_kernel + <<>>( + b, c, n, npoints, nsample, grad_out.data_ptr(), + idx.data_ptr(), grad_points.data_ptr()); + }); + + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/cuda/knn_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/knn_cuda.cu new file mode 100644 index 0000000000..3197d6d76d --- /dev/null +++ b/mmcv/ops/csrc/pytorch/cuda/knn_cuda.cu @@ -0,0 +1,35 @@ +// Modified from +// https://github.com/CVMI-Lab/PAConv/tree/main/scene_seg/lib/pointops/src/knnquery_heap + +#include +#include + +#include "knn_cuda_kernel.cuh" +#include "pytorch_cuda_helper.hpp" + +#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) + +void KNNForwardCUDAKernelLauncher(int b, int n, int m, int nsample, + const Tensor xyz, const Tensor new_xyz, + Tensor idx, Tensor dist2) { + // param new_xyz: (B, m, 3) + // param xyz: (B, n, 3) + // param idx: (B, m, nsample) + + at::cuda::CUDAGuard device_guard(new_xyz.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), + b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + 
AT_DISPATCH_FLOATING_TYPES_AND_HALF( + new_xyz.scalar_type(), "knn_forward_cuda_kernel", [&] { + knn_forward_cuda_kernel<<>>( + b, n, m, nsample, xyz.data_ptr(), + new_xyz.data_ptr(), idx.data_ptr(), + dist2.data_ptr()); + }); + + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/group_points.cpp b/mmcv/ops/csrc/pytorch/group_points.cpp new file mode 100644 index 0000000000..954dffe5d6 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/group_points.cpp @@ -0,0 +1,57 @@ +// Modified from +// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/group_points.cpp + +#include "pytorch_cpp_helper.hpp" + +#ifdef MMCV_WITH_CUDA +void GroupPointsForwardCUDAKernelLauncher(int b, int c, int n, int npoints, + int nsample, const Tensor points, + const Tensor idx, Tensor out); +void group_points_forward_cuda(int b, int c, int n, int npoints, int nsample, + const Tensor points, const Tensor idx, + Tensor out) { + GroupPointsForwardCUDAKernelLauncher(b, c, n, npoints, nsample, points, idx, + out); +}; + +void GroupPointsBackwardCUDAKernelLauncher(int b, int c, int n, int npoints, + int nsample, const Tensor grad_out, + const Tensor idx, + Tensor grad_points); +void group_points_backward_cuda(int b, int c, int n, int npoints, int nsample, + const Tensor grad_out, const Tensor idx, + Tensor grad_points) { + GroupPointsBackwardCUDAKernelLauncher(b, c, n, npoints, nsample, grad_out, + idx, grad_points); +}; +#endif + +void group_points_forward(int b, int c, int n, int npoints, int nsample, + Tensor points_tensor, Tensor idx_tensor, + Tensor out_tensor) { + if (points_tensor.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + group_points_forward_cuda(b, c, n, npoints, nsample, points_tensor, + idx_tensor, out_tensor); +#else + AT_ERROR("group_points is not compiled with GPU support"); +#endif + } else { + AT_ERROR("group_points is not implemented on CPU"); + } +} + +void group_points_backward(int b, int c, int n, int npoints, int nsample, + Tensor 
grad_out_tensor, Tensor idx_tensor, + Tensor grad_points_tensor) { + if (grad_out_tensor.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + group_points_backward_cuda(b, c, n, npoints, nsample, grad_out_tensor, + idx_tensor, grad_points_tensor); +#else + AT_ERROR("group_points is not compiled with GPU support"); +#endif + } else { + AT_ERROR("group_points is not implemented on CPU"); + } +} diff --git a/mmcv/ops/csrc/pytorch/knn.cpp b/mmcv/ops/csrc/pytorch/knn.cpp new file mode 100644 index 0000000000..fbbbfc8f2b --- /dev/null +++ b/mmcv/ops/csrc/pytorch/knn.cpp @@ -0,0 +1,33 @@ +// Modified from +// https://github.com/CVMI-Lab/PAConv/tree/main/scene_seg/lib/pointops/src/knnquery_heap + +#include "pytorch_cpp_helper.hpp" + +#ifdef MMCV_WITH_CUDA +void KNNForwardCUDAKernelLauncher(int b, int n, int m, int nsample, + const Tensor xyz, const Tensor new_xyz, + Tensor idx, Tensor dist2); + +void knn_forward_cuda(int b, int n, int m, int nsample, const Tensor xyz, + const Tensor new_xyz, Tensor idx, Tensor dist2) { + KNNForwardCUDAKernelLauncher(b, n, m, nsample, xyz, new_xyz, idx, dist2); +} +#endif + +void knn_forward(int b, int n, int m, int nsample, Tensor xyz_tensor, + Tensor new_xyz_tensor, Tensor idx_tensor, + Tensor dist2_tensor) { + if (new_xyz_tensor.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CUDA_INPUT(new_xyz_tensor); + CHECK_CUDA_INPUT(xyz_tensor); + + knn_forward_cuda(b, n, m, nsample, xyz_tensor, new_xyz_tensor, idx_tensor, + dist2_tensor); +#else + AT_ERROR("knn is not compiled with GPU support"); +#endif + } else { + AT_ERROR("knn is not implemented on CPU"); + } +} diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp index 6ecdf763cf..878d46cc19 100644 --- a/mmcv/ops/csrc/pytorch/pybind.cpp +++ b/mmcv/ops/csrc/pytorch/pybind.cpp @@ -53,6 +53,13 @@ void deform_roi_pool_backward(Tensor grad_output, Tensor input, Tensor rois, int pooled_width, float spatial_scale, int sampling_ratio, float gamma); +void 
group_points_forward(int b, int c, int n, int npoints, int nsample, + Tensor points_tensor, Tensor idx_tensor, + Tensor out_tensor); +void group_points_backward(int b, int c, int n, int npoints, int nsample, + Tensor grad_out_tensor, Tensor idx_tensor, + Tensor grad_points_tensor); + void sigmoid_focal_loss_forward(Tensor input, Tensor target, Tensor weight, Tensor output, float gamma, float alpha); @@ -69,6 +76,9 @@ void softmax_focal_loss_backward(Tensor input, Tensor target, Tensor weight, void bbox_overlaps(const Tensor bboxes1, const Tensor bboxes2, Tensor ious, const int mode, const bool aligned, const int offset); +void knn_forward(int b, int n, int m, int nsample, Tensor xyz_tensor, + Tensor new_xyz_tensor, Tensor idx_tensor, Tensor dist2_tensor); + void masked_im2col_forward(const Tensor im, const Tensor mask_h_idx, const Tensor mask_w_idx, Tensor col, const int kernel_h, const int kernel_w, @@ -111,16 +121,16 @@ Tensor nms(Tensor boxes, Tensor scores, float iou_threshold, int offset); Tensor softnms(Tensor boxes, Tensor scores, Tensor dets, float iou_threshold, float sigma, float min_score, int method, int offset); -std::vector > nms_match(Tensor dets, float iou_threshold); +std::vector> nms_match(Tensor dets, float iou_threshold); -std::vector > pixel_group( +std::vector> pixel_group( Tensor score, Tensor mask, Tensor embedding, Tensor kernel_label, Tensor kernel_contour, int kernel_region_num, float distance_threshold); -std::vector > contour_expand(Tensor kernel_mask, - Tensor internal_kernel_label, - int min_kernel_area, - int kernel_num); +std::vector> contour_expand(Tensor kernel_mask, + Tensor internal_kernel_label, + int min_kernel_area, + int kernel_num); void roi_align_forward(Tensor input, Tensor rois, Tensor output, Tensor argmax_y, Tensor argmax_x, int aligned_height, @@ -172,6 +182,10 @@ void tin_shift_forward(Tensor input, Tensor shift, Tensor output); void tin_shift_backward(Tensor grad_output, Tensor shift, Tensor grad_input); +void 
ball_query_forward(int b, int n, int m, float min_radius, float max_radius, + int nsample, Tensor new_xyz_tensor, Tensor xyz_tensor, + Tensor idx_tensor); + Tensor bottom_pool_forward(Tensor input); Tensor bottom_pool_backward(Tensor input, Tensor grad_output); @@ -300,6 +314,18 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("bbox_overlaps", &bbox_overlaps, "bbox_overlaps", py::arg("bboxes1"), py::arg("bboxes2"), py::arg("ious"), py::arg("mode"), py::arg("aligned"), py::arg("offset")); + m.def("group_points_forward", &group_points_forward, "group_points_forward", + py::arg("b"), py::arg("c"), py::arg("n"), py::arg("npoints"), + py::arg("nsample"), py::arg("points_tensor"), py::arg("idx_tensor"), + py::arg("out_tensor")); + m.def("group_points_backward", &group_points_backward, + "group_points_backward", py::arg("b"), py::arg("c"), py::arg("n"), + py::arg("npoints"), py::arg("nsample"), py::arg("grad_out_tensor"), + py::arg("idx_tensor"), py::arg("grad_points_tensor")); + m.def("knn_forward", &knn_forward, "knn_forward", py::arg("b"), py::arg("n"), + py::arg("m"), py::arg("nsample"), py::arg("xyz_tensor"), + py::arg("new_xyz_tensor"), py::arg("idx_tensor"), + py::arg("dist2_tensor")); m.def("masked_im2col_forward", &masked_im2col_forward, "masked_im2col_forward", py::arg("im"), py::arg("mask_h_idx"), py::arg("mask_w_idx"), py::arg("col"), py::arg("kernel_h"), @@ -415,6 +441,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("nms_rotated", &nms_rotated, "NMS for rotated boxes", py::arg("dets"), py::arg("scores"), py::arg("order"), py::arg("dets_sorted"), py::arg("iou_threshold"), py::arg("multi_label")); + m.def("ball_query_forward", &ball_query_forward, "ball_query_forward", + py::arg("b"), py::arg("n"), py::arg("m"), py::arg("min_radius"), + py::arg("max_radius"), py::arg("nsample"), py::arg("new_xyz_tensor"), + py::arg("xyz_tensor"), py::arg("idx_tensor")); m.def("roi_align_rotated_forward", &roi_align_rotated_forward, "roi_align_rotated forward", 
py::arg("input"), py::arg("rois"), py::arg("output"), py::arg("pooled_height"), py::arg("pooled_width"), diff --git a/mmcv/ops/group_points.py b/mmcv/ops/group_points.py new file mode 100644 index 0000000000..46b629ac1f --- /dev/null +++ b/mmcv/ops/group_points.py @@ -0,0 +1,233 @@ +from typing import Tuple + +import torch +from torch import nn as nn +from torch.autograd import Function + +from ..utils import ext_loader +from .ball_query import ball_query +from .knn import knn + +ext_module = ext_loader.load_ext( + '_ext', ['group_points_forward', 'group_points_backward']) + + +class QueryAndGroup(nn.Module): + """Query and Group. + + Groups with a ball query of radius + + Args: + max_radius (float): The maximum radius of the balls. + If None is given, we will use kNN sampling instead of ball query. + sample_num (int): Maximum number of features to gather in the ball. + min_radius (float, optional): The minimum radius of the balls. + Default: 0. + use_xyz (bool, optional): Whether to use xyz. + Default: True. + return_grouped_xyz (bool, optional): Whether to return grouped xyz. + Default: False. + normalize_xyz (bool, optional): Whether to normalize xyz. + Default: False. + uniform_sample (bool, optional): Whether to sample uniformly. + Default: False + return_unique_cnt (bool, optional): Whether to return the count of + unique samples. Default: False. + return_grouped_idx (bool, optional): Whether to return grouped idx. + Default: False. 
+ """ + + def __init__(self, + max_radius, + sample_num, + min_radius=0, + use_xyz=True, + return_grouped_xyz=False, + normalize_xyz=False, + uniform_sample=False, + return_unique_cnt=False, + return_grouped_idx=False): + super(QueryAndGroup, self).__init__() + self.max_radius = max_radius + self.min_radius = min_radius + self.sample_num = sample_num + self.use_xyz = use_xyz + self.return_grouped_xyz = return_grouped_xyz + self.normalize_xyz = normalize_xyz + self.uniform_sample = uniform_sample + self.return_unique_cnt = return_unique_cnt + self.return_grouped_idx = return_grouped_idx + if self.return_unique_cnt: + assert self.uniform_sample, \ + 'uniform_sample should be True when ' \ + 'returning the count of unique samples' + if self.max_radius is None: + assert not self.normalize_xyz, \ + 'can not normalize grouped xyz when max_radius is None' + + def forward(self, points_xyz, center_xyz, features=None): + """forward. + + Args: + points_xyz (Tensor): (B, N, 3) xyz coordinates of the features. + center_xyz (Tensor): (B, npoint, 3) Centriods. + features (Tensor): (B, C, N) Descriptors of the features. + + Return: + Tensor: (B, 3 + C, npoint, sample_num) Grouped feature. 
class GroupAll(nn.Module):
    """Group All.

    Treat the whole point cloud as one single group and, optionally,
    concatenate the coordinates with the features.

    Args:
        use_xyz (bool): Whether to prepend the xyz coordinates to the
            features. Only consulted when features are given.
            Default: True.
    """

    def __init__(self, use_xyz: bool = True):
        super().__init__()
        self.use_xyz = use_xyz

    def forward(self,
                xyz: torch.Tensor,
                new_xyz: torch.Tensor,
                features: torch.Tensor = None):
        """forward.

        Args:
            xyz (Tensor): (B, N, 3) xyz coordinates of the features.
            new_xyz (Tensor): Ignored.
            features (Tensor, optional): (B, C, N) features to group.
                Default: None.

        Return:
            Tensor: Grouped feature of shape

            - (B, 3 + C, 1, N) if ``features`` is given and ``use_xyz`` is
              True (coordinates first),
            - (B, C, 1, N) if ``features`` is given and ``use_xyz`` is False,
            - (B, 3, 1, N) if ``features`` is None, regardless of
              ``use_xyz``.
        """
        # (B, N, 3) -> (B, 3, N) -> (B, 3, 1, N)
        grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
        if features is None:
            return grouped_xyz

        # (B, C, N) -> (B, C, 1, N)
        grouped_features = features.unsqueeze(2)
        if self.use_xyz:
            return torch.cat([grouped_xyz, grouped_features], dim=1)
        return grouped_features
class KNN(Function):
    r"""KNN (CUDA) based on heap data structure.
    Modified from `PAConv <https://github.com/CVMI-Lab/PAConv/tree/main/
    scene_seg/lib/pointops/src/knnquery_heap>`_.

    Find k-nearest points.
    """

    @staticmethod
    def forward(ctx,
                k: int,
                xyz: torch.Tensor,
                center_xyz: torch.Tensor = None,
                transposed: bool = False) -> torch.Tensor:
        """
        Args:
            k (int): number of nearest neighbors, in the range [1, 100].
            xyz (Tensor): (B, N, 3) if transposed == False, else (B, 3, N).
                xyz coordinates of the features.
            center_xyz (Tensor): (B, npoint, 3) if transposed == False,
                else (B, 3, npoint). centers of the knn query. If None,
                ``xyz`` is used as the query centers as well.
            transposed (bool): whether the input tensors are transposed.
                defaults to False. Should not explicitly use this keyword
                when calling knn (=KNN.apply), just add the fourth param.

        Returns:
            Tensor: (B, k, npoint) int32 tensor with the indices of
                the features that form k-nearest neighbours.
        """
        # The CUDA kernel keeps its candidates in fixed 100-element
        # per-thread arrays, so a larger k would write out of bounds.
        assert 0 < k <= 100, 'k should be in range (0, 100]'

        if center_xyz is None:
            center_xyz = xyz

        if transposed:
            xyz = xyz.transpose(2, 1).contiguous()
            center_xyz = center_xyz.transpose(2, 1).contiguous()

        assert xyz.is_contiguous()  # [B, N, 3]
        assert center_xyz.is_contiguous()  # [B, npoint, 3]

        center_xyz_device = center_xyz.get_device()
        assert center_xyz_device == xyz.get_device(), \
            'center_xyz and xyz should be put on the same device'
        # NOTE: this switches the current CUDA device as a side effect.
        if torch.cuda.current_device() != center_xyz_device:
            torch.cuda.set_device(center_xyz_device)

        B, npoint, _ = center_xyz.shape
        N = xyz.shape[1]

        idx = center_xyz.new_zeros((B, npoint, k), dtype=torch.int)
        # The kernel writes the squared distances with the same scalar type
        # as the coordinates, so keep the input dtype instead of forcing
        # float32.
        dist2 = center_xyz.new_zeros((B, npoint, k))

        ext_module.knn_forward(B, N, npoint, k, xyz, center_xyz, idx, dist2)
        # idx shape to [B, k, npoint]
        idx = idx.transpose(2, 1).contiguous()
        ctx.mark_non_differentiable(idx)
        return idx

    @staticmethod
    def backward(ctx, a=None):
        # One entry per ``forward`` input (ctx excluded); the index output is
        # non-differentiable, so every gradient is None.
        return None, None, None, None


knn = KNN.apply
2.4952, + -0.1708], [-0.7188, 0.9956, -0.5096], + [-2.0668, 6.0278, -0.4875], [-1.9304, 3.3092, 0.6610], + [0.0949, 1.4332, 0.3140], [-1.2879, 2.0008, -0.7791], + [-0.7252, 0.9611, -0.6371], [0.4066, 1.4211, -0.2947], + [0.3220, 1.4447, 0.3548], [-0.9744, 2.3856, + -1.2000]]]).cuda() + + idx = ball_query(0, 0.2, 5, xyz, new_xyz) + expected_idx = torch.tensor([[[0, 0, 0, 0, 0], [6, 6, 6, 6, 6], + [2, 2, 2, 2, 2], [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]], + [[0, 0, 0, 0, 0], [2, 2, 2, 2, 2], + [7, 7, 7, 7, 7], [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]]]).cuda() + assert torch.all(idx == expected_idx) + + # test dilated ball query + idx = ball_query(0.2, 0.4, 5, xyz, new_xyz) + expected_idx = torch.tensor([[[0, 5, 7, 0, 0], [6, 6, 6, 6, 6], + [2, 3, 2, 2, 2], [0, 5, 7, 0, 0], + [0, 5, 7, 0, 0]], + [[0, 0, 0, 0, 0], [2, 2, 2, 2, 2], + [7, 7, 7, 7, 7], [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]]]).cuda() + assert torch.all(idx == expected_idx) diff --git a/tests/test_ops/test_group_points.py b/tests/test_ops/test_group_points.py new file mode 100644 index 0000000000..a4c588eb97 --- /dev/null +++ b/tests/test_ops/test_group_points.py @@ -0,0 +1,76 @@ +import pytest +import torch + +from mmcv.ops import grouping_operation + + +def test_grouping_points(): + if not torch.cuda.is_available(): + pytest.skip() + idx = torch.tensor([[[0, 0, 0], [3, 3, 3], [8, 8, 8], [0, 0, 0], [0, 0, 0], + [0, 0, 0]], + [[0, 0, 0], [6, 6, 6], [9, 9, 9], [0, 0, 0], [0, 0, 0], + [0, 0, 0]]]).int().cuda() + festures = torch.tensor([[[ + 0.5798, -0.7981, -0.9280, -1.3311, 1.3687, 0.9277, -0.4164, -1.8274, + 0.9268, 0.8414 + ], + [ + 5.4247, 1.5113, 2.3944, 1.4740, 5.0300, + 5.1030, 1.9360, 2.1939, 2.1581, 3.4666 + ], + [ + -1.6266, -1.0281, -1.0393, -1.6931, -1.3982, + -0.5732, -1.0830, -1.7561, -1.6786, -1.6967 + ]], + [[ + -0.0380, -0.1880, -1.5724, 0.6905, -0.3190, + 0.7798, -0.3693, -0.9457, -0.2942, -1.8527 + ], + [ + 1.1773, 1.5009, 2.6399, 5.9242, 1.0962, + 2.7346, 6.0865, 1.5555, 4.3303, 2.8229 + ], + [ + 
-0.6646, -0.6870, -0.1125, -0.2224, -0.3445, + -1.4049, 0.4990, -0.7037, -0.9924, 0.0386 + ]]]).cuda() + + output = grouping_operation(festures, idx) + expected_output = torch.tensor([[[[0.5798, 0.5798, 0.5798], + [-1.3311, -1.3311, -1.3311], + [0.9268, 0.9268, 0.9268], + [0.5798, 0.5798, 0.5798], + [0.5798, 0.5798, 0.5798], + [0.5798, 0.5798, 0.5798]], + [[5.4247, 5.4247, 5.4247], + [1.4740, 1.4740, 1.4740], + [2.1581, 2.1581, 2.1581], + [5.4247, 5.4247, 5.4247], + [5.4247, 5.4247, 5.4247], + [5.4247, 5.4247, 5.4247]], + [[-1.6266, -1.6266, -1.6266], + [-1.6931, -1.6931, -1.6931], + [-1.6786, -1.6786, -1.6786], + [-1.6266, -1.6266, -1.6266], + [-1.6266, -1.6266, -1.6266], + [-1.6266, -1.6266, -1.6266]]], + [[[-0.0380, -0.0380, -0.0380], + [-0.3693, -0.3693, -0.3693], + [-1.8527, -1.8527, -1.8527], + [-0.0380, -0.0380, -0.0380], + [-0.0380, -0.0380, -0.0380], + [-0.0380, -0.0380, -0.0380]], + [[1.1773, 1.1773, 1.1773], + [6.0865, 6.0865, 6.0865], + [2.8229, 2.8229, 2.8229], + [1.1773, 1.1773, 1.1773], + [1.1773, 1.1773, 1.1773], + [1.1773, 1.1773, 1.1773]], + [[-0.6646, -0.6646, -0.6646], + [0.4990, 0.4990, 0.4990], + [0.0386, 0.0386, 0.0386], + [-0.6646, -0.6646, -0.6646], + [-0.6646, -0.6646, -0.6646], + [-0.6646, -0.6646, -0.6646]]]]).cuda() + assert torch.allclose(output, expected_output) diff --git a/tests/test_ops/test_knn.py b/tests/test_ops/test_knn.py new file mode 100644 index 0000000000..33237dcd76 --- /dev/null +++ b/tests/test_ops/test_knn.py @@ -0,0 +1,54 @@ +import pytest +import torch + +from mmcv.ops import knn + + +def test_knn(): + if not torch.cuda.is_available(): + pytest.skip() + new_xyz = torch.tensor([[[-0.0740, 1.3147, -1.3625], + [-2.2769, 2.7817, -0.2334], + [-0.4003, 2.4666, -0.5116], + [-0.0740, 1.3147, -1.3625], + [-0.0740, 1.3147, -1.3625]], + [[-2.0289, 2.4952, -0.1708], + [-2.0668, 6.0278, -0.4875], + [0.4066, 1.4211, -0.2947], + [-2.0289, 2.4952, -0.1708], + [-2.0289, 2.4952, -0.1708]]]).cuda() + + xyz = torch.tensor([[[-0.0740, 
1.3147, -1.3625], [0.5555, 1.0399, -1.3634], + [-0.4003, 2.4666, + -0.5116], [-0.5251, 2.4379, -0.8466], + [-0.9691, 1.1418, + -1.3733], [-0.2232, 0.9561, -1.3626], + [-2.2769, 2.7817, -0.2334], + [-0.2822, 1.3192, -1.3645], [0.1533, 1.5024, -1.0432], + [0.4917, 1.1529, -1.3496]], + [[-2.0289, 2.4952, + -0.1708], [-0.7188, 0.9956, -0.5096], + [-2.0668, 6.0278, -0.4875], [-1.9304, 3.3092, 0.6610], + [0.0949, 1.4332, 0.3140], [-1.2879, 2.0008, -0.7791], + [-0.7252, 0.9611, -0.6371], [0.4066, 1.4211, -0.2947], + [0.3220, 1.4447, 0.3548], [-0.9744, 2.3856, + -1.2000]]]).cuda() + + idx = knn(5, xyz, new_xyz) + new_xyz_ = new_xyz.unsqueeze(2).repeat(1, 1, xyz.shape[1], 1) + xyz_ = xyz.unsqueeze(1).repeat(1, new_xyz.shape[1], 1, 1) + dist = ((new_xyz_ - xyz_) * (new_xyz_ - xyz_)).sum(-1) + expected_idx = dist.topk(k=5, dim=2, largest=False)[1].transpose(2, 1) + assert torch.all(idx == expected_idx) + + idx = knn(5, + xyz.transpose(1, 2).contiguous(), + new_xyz.transpose(1, 2).contiguous(), True) + assert torch.all(idx == expected_idx) + + idx = knn(5, xyz, xyz) + xyz_ = xyz.unsqueeze(2).repeat(1, 1, xyz.shape[1], 1) + xyz__ = xyz.unsqueeze(1).repeat(1, xyz.shape[1], 1, 1) + dist = ((xyz_ - xyz__) * (xyz_ - xyz__)).sum(-1) + expected_idx = dist.topk(k=5, dim=2, largest=False)[1].transpose(2, 1) + assert torch.all(idx == expected_idx) From 0959dc8e0658565de67488d371d44a0da157e613 Mon Sep 17 00:00:00 2001 From: "hudingchang.vendor" Date: Mon, 20 Sep 2021 15:09:02 +0800 Subject: [PATCH 2/7] refactor code --- mmcv/ops/csrc/pytorch/ball_query.cpp | 10 +++++----- mmcv/ops/csrc/pytorch/cuda/ball_query_cuda.cu | 5 +++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/ball_query.cpp b/mmcv/ops/csrc/pytorch/ball_query.cpp index 11f67e3530..0a0892ba16 100644 --- a/mmcv/ops/csrc/pytorch/ball_query.cpp +++ b/mmcv/ops/csrc/pytorch/ball_query.cpp @@ -7,11 +7,12 @@ void BallQueryForwardCUDAKernelLauncher(int b, int n, int m, float min_radius, float 
max_radius, int nsample, const Tensor new_xyz, const Tensor xyz, - int *idx); + Tensor idx); void ball_query_forward_cuda(int b, int n, int m, float min_radius, - float max_radius, int nsample, Tensor new_xyz, - Tensor xyz, int *idx) { + float max_radius, int nsample, + const Tensor new_xyz, const Tensor xyz, + Tensor idx) { BallQueryForwardCUDAKernelLauncher(b, n, m, min_radius, max_radius, nsample, new_xyz, xyz, idx); }; @@ -24,10 +25,9 @@ void ball_query_forward(int b, int n, int m, float min_radius, float max_radius, #ifdef MMCV_WITH_CUDA CHECK_CUDA_INPUT(new_xyz_tensor); CHECK_CUDA_INPUT(xyz_tensor); - int *idx = idx_tensor.data_ptr(); ball_query_forward_cuda(b, n, m, min_radius, max_radius, nsample, - new_xyz_tensor, xyz_tensor, idx); + new_xyz_tensor, xyz_tensor, idx_tensor); #else AT_ERROR("ball_query is not compiled with GPU support"); #endif diff --git a/mmcv/ops/csrc/pytorch/cuda/ball_query_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/ball_query_cuda.cu index 664c21bf65..fbf3f9f458 100644 --- a/mmcv/ops/csrc/pytorch/cuda/ball_query_cuda.cu +++ b/mmcv/ops/csrc/pytorch/cuda/ball_query_cuda.cu @@ -13,7 +13,7 @@ void BallQueryForwardCUDAKernelLauncher(int b, int n, int m, float min_radius, float max_radius, int nsample, const Tensor new_xyz, const Tensor xyz, - int *idx) { + Tensor idx) { // new_xyz: (B, M, 3) // xyz: (B, N, 3) // output: @@ -31,7 +31,8 @@ void BallQueryForwardCUDAKernelLauncher(int b, int n, int m, float min_radius, ball_query_forward_cuda_kernel <<>>( b, n, m, min_radius, max_radius, nsample, - new_xyz.data_ptr(), xyz.data_ptr(), idx); + new_xyz.data_ptr(), xyz.data_ptr(), + idx.data_ptr()); }); AT_CUDA_CHECK(cudaGetLastError()); From 7835e7585f342cc199223e851cb4483328f0f11d Mon Sep 17 00:00:00 2001 From: "hudingchang.vendor" Date: Wed, 22 Sep 2021 11:54:42 +0800 Subject: [PATCH 3/7] fix typo --- mmcv/ops/ball_query.py | 5 +---- .../common/cuda/ball_query_cuda_kernel.cuh | 4 ++++ .../common/cuda/group_points_cuda_kernel.cuh | 2 ++ 
mmcv/ops/csrc/common/cuda/knn_cuda_kernel.cuh | 4 ++++ mmcv/ops/csrc/pytorch/cuda/ball_query_cuda.cu | 3 +-- .../csrc/pytorch/cuda/group_points_cuda.cu | 2 -- mmcv/ops/csrc/pytorch/cuda/knn_cuda.cu | 3 +-- mmcv/ops/group_points.py | 19 +++++-------------- tests/test_ops/test_ball_query.py | 2 ++ tests/test_ops/test_group_points.py | 2 ++ tests/test_ops/test_knn.py | 2 ++ 11 files changed, 24 insertions(+), 24 deletions(-) diff --git a/mmcv/ops/ball_query.py b/mmcv/ops/ball_query.py index 01686b5b64..22b8151587 100644 --- a/mmcv/ops/ball_query.py +++ b/mmcv/ops/ball_query.py @@ -8,10 +8,7 @@ class BallQuery(Function): - """Ball Query. - - Find nearby points in spherical space. - """ + """Find nearby points in spherical space.""" @staticmethod def forward(ctx, min_radius: float, max_radius: float, sample_num: int, diff --git a/mmcv/ops/csrc/common/cuda/ball_query_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/ball_query_cuda_kernel.cuh index f61b8027dc..588002a282 100644 --- a/mmcv/ops/csrc/common/cuda/ball_query_cuda_kernel.cuh +++ b/mmcv/ops/csrc/common/cuda/ball_query_cuda_kernel.cuh @@ -1,4 +1,6 @@ // Copyright (c) OpenMMLab. 
All rights reserved +// Modified from +// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query_gpu.cu #ifndef BALL_QUERY_CUDA_KERNEL_CUH #define BALL_QUERY_CUDA_KERNEL_CUH @@ -8,6 +10,8 @@ #include "pytorch_cuda_helper.hpp" #endif +#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) + template __global__ void ball_query_forward_cuda_kernel(int b, int n, int m, float min_radius, diff --git a/mmcv/ops/csrc/common/cuda/group_points_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/group_points_cuda_kernel.cuh index 6eab59992d..1b42651a23 100644 --- a/mmcv/ops/csrc/common/cuda/group_points_cuda_kernel.cuh +++ b/mmcv/ops/csrc/common/cuda/group_points_cuda_kernel.cuh @@ -8,6 +8,8 @@ #include "pytorch_cuda_helper.hpp" #endif +#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) + template __global__ void group_points_forward_cuda_kernel(int b, int c, int n, int npoints, int nsample, diff --git a/mmcv/ops/csrc/common/cuda/knn_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/knn_cuda_kernel.cuh index 1b4ec99512..b84ca13834 100644 --- a/mmcv/ops/csrc/common/cuda/knn_cuda_kernel.cuh +++ b/mmcv/ops/csrc/common/cuda/knn_cuda_kernel.cuh @@ -1,4 +1,6 @@ // Copyright (c) OpenMMLab. All rights reserved +// Modified from +// https://github.com/CVMI-Lab/PAConv/tree/main/scene_seg/lib/pointops/src/knnquery_heap #ifndef KNN_CUDA_KERNEL_CUH #define KNN_CUDA_KERNEL_CUH @@ -8,6 +10,8 @@ #include "pytorch_cuda_helper.hpp" #endif +#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) + __device__ void swap_float(float *x, float *y) { float tmp = *x; *x = *y; diff --git a/mmcv/ops/csrc/pytorch/cuda/ball_query_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/ball_query_cuda.cu index fbf3f9f458..b13321ae46 100644 --- a/mmcv/ops/csrc/pytorch/cuda/ball_query_cuda.cu +++ b/mmcv/ops/csrc/pytorch/cuda/ball_query_cuda.cu @@ -1,3 +1,4 @@ +// Copyright (c) OpenMMLab. 
All rights reserved // Modified from // https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query_gpu.cu @@ -8,8 +9,6 @@ #include "ball_query_cuda_kernel.cuh" #include "pytorch_cuda_helper.hpp" -#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) - void BallQueryForwardCUDAKernelLauncher(int b, int n, int m, float min_radius, float max_radius, int nsample, const Tensor new_xyz, const Tensor xyz, diff --git a/mmcv/ops/csrc/pytorch/cuda/group_points_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/group_points_cuda.cu index 117636cec3..4286d5ec75 100644 --- a/mmcv/ops/csrc/pytorch/cuda/group_points_cuda.cu +++ b/mmcv/ops/csrc/pytorch/cuda/group_points_cuda.cu @@ -4,8 +4,6 @@ #include "group_points_cuda_kernel.cuh" #include "pytorch_cuda_helper.hpp" -#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) - void GroupPointsForwardCUDAKernelLauncher(int b, int c, int n, int npoints, int nsample, const Tensor points, const Tensor idx, Tensor out) { diff --git a/mmcv/ops/csrc/pytorch/cuda/knn_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/knn_cuda.cu index 3197d6d76d..e6d1e80a67 100644 --- a/mmcv/ops/csrc/pytorch/cuda/knn_cuda.cu +++ b/mmcv/ops/csrc/pytorch/cuda/knn_cuda.cu @@ -1,3 +1,4 @@ +// Copyright (c) OpenMMLab. All rights reserved // Modified from // https://github.com/CVMI-Lab/PAConv/tree/main/scene_seg/lib/pointops/src/knnquery_heap @@ -7,8 +8,6 @@ #include "knn_cuda_kernel.cuh" #include "pytorch_cuda_helper.hpp" -#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) - void KNNForwardCUDAKernelLauncher(int b, int n, int m, int nsample, const Tensor xyz, const Tensor new_xyz, Tensor idx, Tensor dist2) { diff --git a/mmcv/ops/group_points.py b/mmcv/ops/group_points.py index 46b629ac1f..ab2e001ba3 100644 --- a/mmcv/ops/group_points.py +++ b/mmcv/ops/group_points.py @@ -13,9 +13,7 @@ class QueryAndGroup(nn.Module): - """Query and Group. - - Groups with a ball query of radius + """Groups with a ball query of radius. Args: max_radius (float): The maximum radius of the balls. 
@@ -66,8 +64,7 @@ def __init__(self, 'can not normalize grouped xyz when max_radius is None' def forward(self, points_xyz, center_xyz, features=None): - """forward. - + """ Args: points_xyz (Tensor): (B, N, 3) xyz coordinates of the features. center_xyz (Tensor): (B, npoint, 3) Centriods. @@ -134,9 +131,7 @@ def forward(self, points_xyz, center_xyz, features=None): class GroupAll(nn.Module): - """Group All. - - Group xyz with feature. + """Group xyz with feature. Args: use_xyz (bool): Whether to use xyz. @@ -175,10 +170,7 @@ def forward(self, class GroupingOperation(Function): - """Grouping Operation. - - Group feature with given index. - """ + """Group feature with given index.""" @staticmethod def forward(ctx, features: torch.Tensor, @@ -209,8 +201,7 @@ def forward(ctx, features: torch.Tensor, @staticmethod def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """backward. - + """ Args: grad_out (Tensor): (B, C, npoint, nsample) tensor of the gradients of the output from forward. 
diff --git a/tests/test_ops/test_ball_query.py b/tests/test_ops/test_ball_query.py index 7bb56ee96b..438f7efdcf 100644 --- a/tests/test_ops/test_ball_query.py +++ b/tests/test_ops/test_ball_query.py @@ -4,6 +4,8 @@ from mmcv.ops import ball_query +@pytest.mark.skipif( + not torch.cuda.is_available(), reason='requires CUDA support') def test_ball_query(): if not torch.cuda.is_available(): pytest.skip() diff --git a/tests/test_ops/test_group_points.py b/tests/test_ops/test_group_points.py index a4c588eb97..3f45d2fc3a 100644 --- a/tests/test_ops/test_group_points.py +++ b/tests/test_ops/test_group_points.py @@ -4,6 +4,8 @@ from mmcv.ops import grouping_operation +@pytest.mark.skipif( + not torch.cuda.is_available(), reason='requires CUDA support') def test_grouping_points(): if not torch.cuda.is_available(): pytest.skip() diff --git a/tests/test_ops/test_knn.py b/tests/test_ops/test_knn.py index 33237dcd76..b2b33823ab 100644 --- a/tests/test_ops/test_knn.py +++ b/tests/test_ops/test_knn.py @@ -4,6 +4,8 @@ from mmcv.ops import knn +@pytest.mark.skipif( + not torch.cuda.is_available(), reason='requires CUDA support') def test_knn(): if not torch.cuda.is_available(): pytest.skip() From 8668144876ea4149ef64c9b77502c35c1096b94c Mon Sep 17 00:00:00 2001 From: "hudingchang.vendor" Date: Sun, 10 Oct 2021 03:51:10 +0800 Subject: [PATCH 4/7] refactor code --- tests/test_ops/test_ball_query.py | 2 -- tests/test_ops/test_group_points.py | 2 -- tests/test_ops/test_knn.py | 2 -- 3 files changed, 6 deletions(-) diff --git a/tests/test_ops/test_ball_query.py b/tests/test_ops/test_ball_query.py index 438f7efdcf..cf30a7efab 100644 --- a/tests/test_ops/test_ball_query.py +++ b/tests/test_ops/test_ball_query.py @@ -7,8 +7,6 @@ @pytest.mark.skipif( not torch.cuda.is_available(), reason='requires CUDA support') def test_ball_query(): - if not torch.cuda.is_available(): - pytest.skip() new_xyz = torch.tensor([[[-0.0740, 1.3147, -1.3625], [-2.2769, 2.7817, -0.2334], [-0.4003, 2.4666, 
-0.5116], diff --git a/tests/test_ops/test_group_points.py b/tests/test_ops/test_group_points.py index 3f45d2fc3a..1b495c2850 100644 --- a/tests/test_ops/test_group_points.py +++ b/tests/test_ops/test_group_points.py @@ -7,8 +7,6 @@ @pytest.mark.skipif( not torch.cuda.is_available(), reason='requires CUDA support') def test_grouping_points(): - if not torch.cuda.is_available(): - pytest.skip() idx = torch.tensor([[[0, 0, 0], [3, 3, 3], [8, 8, 8], [0, 0, 0], [0, 0, 0], [0, 0, 0]], [[0, 0, 0], [6, 6, 6], [9, 9, 9], [0, 0, 0], [0, 0, 0], diff --git a/tests/test_ops/test_knn.py b/tests/test_ops/test_knn.py index b2b33823ab..2740cb5e1b 100644 --- a/tests/test_ops/test_knn.py +++ b/tests/test_ops/test_knn.py @@ -7,8 +7,6 @@ @pytest.mark.skipif( not torch.cuda.is_available(), reason='requires CUDA support') def test_knn(): - if not torch.cuda.is_available(): - pytest.skip() new_xyz = torch.tensor([[[-0.0740, 1.3147, -1.3625], [-2.2769, 2.7817, -0.2334], [-0.4003, 2.4666, -0.5116], From 85b4c07fd87e767e69dc94bd4a1eb49bf69cef95 Mon Sep 17 00:00:00 2001 From: hdc Date: Wed, 20 Oct 2021 22:55:50 +0800 Subject: [PATCH 5/7] fix typo --- mmcv/ops/group_points.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmcv/ops/group_points.py b/mmcv/ops/group_points.py index 48a5b87fc5..1fd0faa155 100644 --- a/mmcv/ops/group_points.py +++ b/mmcv/ops/group_points.py @@ -178,7 +178,7 @@ def forward(ctx, features: torch.Tensor, """ Args: features (Tensor): (B, C, N) tensor of features to group. - indices (Tensor): (B, npoint, nsample) the indicies of + indices (Tensor): (B, npoint, nsample) the indices of features to group with. 
Returns: From bd064caced03bcd2dee2532ebcacb2991906c7e5 Mon Sep 17 00:00:00 2001 From: hdc Date: Thu, 21 Oct 2021 21:09:13 +0800 Subject: [PATCH 6/7] refactor code --- mmcv/ops/csrc/common/cuda/group_points_cuda_kernel.cuh | 4 +++- mmcv/ops/csrc/pytorch/cuda/group_points_cuda.cu | 3 +++ mmcv/ops/csrc/pytorch/group_points.cpp | 1 + mmcv/ops/group_points.py | 1 + 4 files changed, 8 insertions(+), 1 deletion(-) diff --git a/mmcv/ops/csrc/common/cuda/group_points_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/group_points_cuda_kernel.cuh index 6eab59992d..9cfc2dc865 100644 --- a/mmcv/ops/csrc/common/cuda/group_points_cuda_kernel.cuh +++ b/mmcv/ops/csrc/common/cuda/group_points_cuda_kernel.cuh @@ -1,4 +1,6 @@ -// Copyright (c) OpenMMLab. All rights reserved +// Copyright (c) OpenMMLab. All rights reserved. +// Modified from +// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/group_points_gpu.cu #ifndef GROUP_POINTS_CUDA_KERNEL_CUH #define GROUP_POINTS_CUDA_KERNEL_CUH diff --git a/mmcv/ops/csrc/pytorch/cuda/group_points_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/group_points_cuda.cu index bcad948276..e7c57b018a 100644 --- a/mmcv/ops/csrc/pytorch/cuda/group_points_cuda.cu +++ b/mmcv/ops/csrc/pytorch/cuda/group_points_cuda.cu @@ -1,3 +1,6 @@ +// Copyright (c) OpenMMLab. All rights reserved. +// Modified from +// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/group_points_gpu.cu #include #include diff --git a/mmcv/ops/csrc/pytorch/group_points.cpp b/mmcv/ops/csrc/pytorch/group_points.cpp index 954dffe5d6..1ebc947a19 100644 --- a/mmcv/ops/csrc/pytorch/group_points.cpp +++ b/mmcv/ops/csrc/pytorch/group_points.cpp @@ -1,3 +1,4 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
// Modified from // https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/group_points.cpp diff --git a/mmcv/ops/group_points.py b/mmcv/ops/group_points.py index 1fd0faa155..12ef4523bc 100644 --- a/mmcv/ops/group_points.py +++ b/mmcv/ops/group_points.py @@ -1,3 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. from typing import Tuple import torch From 9f566e6d3bce0d6c011c80502e82a1d32e4a1c3a Mon Sep 17 00:00:00 2001 From: zhouzaida Date: Fri, 22 Oct 2021 12:43:46 +0800 Subject: [PATCH 7/7] make input contiguous --- mmcv/ops/group_points.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mmcv/ops/group_points.py b/mmcv/ops/group_points.py index 12ef4523bc..5afd227944 100644 --- a/mmcv/ops/group_points.py +++ b/mmcv/ops/group_points.py @@ -185,8 +185,8 @@ def forward(ctx, features: torch.Tensor, Returns: Tensor: (B, C, npoint, nsample) Grouped features. """ - assert features.is_contiguous() - assert indices.is_contiguous() + features = features.contiguous() + indices = indices.contiguous() B, nfeatures, nsample = indices.size() _, C, N = features.size()