Skip to content

Commit

Permalink
fix conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
zhouzaida committed Oct 21, 2021
2 parents 86f3a65 + 4c8bfb4 commit 7012aed
Show file tree
Hide file tree
Showing 37 changed files with 2,732 additions and 34 deletions.
14 changes: 7 additions & 7 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ Welcome to MMCV's documentation!
You can switch between Chinese and English documents in the lower-left corner of the layout.

.. toctree::
:maxdepth: 1
:maxdepth: 2
:caption: Get Started

get_started/introduction.md
get_started/installation.md
get_started/build.md

.. toctree::
:maxdepth: 1
:maxdepth: 2
:caption: Understand MMCV

understand_mmcv/config.md
Expand All @@ -26,7 +26,7 @@ You can switch between Chinese and English documents in the lower-left corner of
understand_mmcv/utils.md

.. toctree::
:maxdepth: 1
:maxdepth: 2
:caption: Deployment

deployment/onnx.md
Expand All @@ -36,26 +36,26 @@ You can switch between Chinese and English documents in the lower-left corner of
deployment/tensorrt_custom_ops.md

.. toctree::
:maxdepth: 1
:maxdepth: 2
:caption: Compatibility

compatibility.md

.. toctree::
:maxdepth: 1
:maxdepth: 2
:caption: FAQ

faq.md

.. toctree::
:maxdepth: 1
:maxdepth: 2
:caption: Community

community/contributing.md
community/pr.md

.. toctree::
:maxdepth: 1
:maxdepth: 2
:caption: API Reference

api.rst
Expand Down
3 changes: 3 additions & 0 deletions docs/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
- CornerPool
- Deformable Convolution v1/v2
- Deformable RoIPool
- DynamicScatter
- GatherPoints
- FurthestPointSample
- FurthestPointSampleWithDist
Expand All @@ -22,11 +23,13 @@ We implement common CUDA ops used in detection, segmentation, etc.
- RoIPointPool3d
- RoIPool
- RoIAlign
- RoIAwarePool3d
- SimpleRoIAlign
- SigmoidFocalLoss
- SoftmaxFocalLoss
- SoftNMS
- Synchronized BatchNorm
- Voxelization
- ThreeInterpolate
- ThreeNN
- Weight standardization
Expand Down
14 changes: 7 additions & 7 deletions docs_zh_CN/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
您可以在页面左下角切换中英文文档。

.. toctree::
:maxdepth: 1
:maxdepth: 2
:caption: 介绍与安装

get_started/introduction.md
get_started/installation.md
get_started/build.md

.. toctree::
:maxdepth: 1
:maxdepth: 2
:caption: 深入理解 MMCV

understand_mmcv/config.md
Expand All @@ -26,7 +26,7 @@
understand_mmcv/utils.md

.. toctree::
:maxdepth: 1
:maxdepth: 2
:caption: 部署

deployment/onnx.md
Expand All @@ -36,26 +36,26 @@
deployment/tensorrt_custom_ops.md

.. toctree::
:maxdepth: 1
:maxdepth: 2
:caption: 兼容性

compatibility.md

.. toctree::
:maxdepth: 1
:maxdepth: 2
:caption: 常见问题

faq.md

.. toctree::
:maxdepth: 1
:maxdepth: 2
:caption: 社区

community/contributing.md
community/pr.md

.. toctree::
:maxdepth: 1
:maxdepth: 2
:caption: API 文档

api.rst
Expand Down
3 changes: 3 additions & 0 deletions docs_zh_CN/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
- CornerPool
- Deformable Convolution v1/v2
- Deformable RoIPool
- DynamicScatter
- GatherPoints
- FurthestPointSample
- FurthestPointSampleWithDist
Expand All @@ -22,11 +23,13 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
- RoIPointPool3d
- RoIPool
- RoIAlign
- RoIAwarePool3d
- SimpleRoIAlign
- SigmoidFocalLoss
- SoftmaxFocalLoss
- SoftNMS
- Synchronized BatchNorm
- Voxelization
- ThreeInterpolate
- ThreeNN
- Weight standardization
Expand Down
2 changes: 1 addition & 1 deletion docs_zh_CN/understand_mmcv/registry.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class Converter1(object):
self.a = a
self.b = b
```
使用注册器管理模块的关键步骤是,将实现的模块到注册到注册表 `CONVERTERS` 中。通过 `@CONVERTERS.register_module()` 装饰所实现的模块,字符串和类之间的映射就可以由 `CONVERTERS` 构建和维护,如下所示:
使用注册器管理模块的关键步骤是,将实现的模块注册到注册表 `CONVERTERS` 中。通过 `@CONVERTERS.register_module()` 装饰所实现的模块,字符串和类之间的映射就可以由 `CONVERTERS` 构建和维护,如下所示:

通过这种方式,就可以通过 `CONVERTERS` 建立字符串与类之间的映射,如下所示:

Expand Down
10 changes: 9 additions & 1 deletion mmcv/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,23 @@
from .pixel_group import pixel_group
from .point_sample import (SimpleRoIAlign, point_sample,
rel_roi_point_to_rel_img_point)
from .points_in_boxes import (points_in_boxes_all, points_in_boxes_cpu,
points_in_boxes_part)
from .points_sampler import PointsSampler
from .psa_mask import PSAMask
from .roi_align import RoIAlign, roi_align
from .roi_align_rotated import RoIAlignRotated, roi_align_rotated
from .roi_pool import RoIPool, roi_pool
from .roiaware_pool3d import RoIAwarePool3d
from .roipoint_pool3d import RoIPointPool3d
from .saconv import SAConv2d
from .scatter_points import DynamicScatter, dynamic_scatter
from .sync_bn import SyncBatchNorm
from .three_interpolate import three_interpolate
from .three_nn import three_nn
from .tin_shift import TINShift, tin_shift
from .upfirdn2d import upfirdn2d
from .voxelize import Voxelization, voxelization

__all__ = [
'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe',
Expand All @@ -68,5 +73,8 @@
'roi_align_rotated', 'pixel_group', 'contour_expand', 'three_nn',
'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign',
'border_align', 'gather_points', 'furthest_point_sample',
'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation'
'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation',
'Voxelization', 'voxelization', 'dynamic_scatter', 'DynamicScatter',
'RoIAwarePool3d', 'points_in_boxes_part', 'points_in_boxes_cpu',
'points_in_boxes_all'
]
4 changes: 2 additions & 2 deletions mmcv/ops/correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ class Correlation(nn.Module):
where :math:`\star` is the valid 2d sliding window convolution operator,
and :math:`\mathcal{S}` means shifting the input features (auto-complete
zero marginal), and :math:`dx, dy` are shifting distance, :math:`dx, dy \in
[-\text{max_displacement} \times \text{dilation_patch},
\text{max_displacement} \times \text{dilation_patch}]`.
[-\text{max\_displacement} \times \text{dilation\_patch},
\text{max\_displacement} \times \text{dilation\_patch}]`.
Args:
kernel_size (int): The size of sliding window i.e. local neighborhood
Expand Down
93 changes: 93 additions & 0 deletions mmcv/ops/csrc/common/cuda/points_in_boxes_cuda_kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef POINT_IN_BOXES_CUDA_KERNEL_CUH
#define POINT_IN_BOXES_CUDA_KERNEL_CUH

#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif

template <typename T>
__device__ inline void lidar_to_local_coords(T shift_x, T shift_y, T rz,
T &local_x, T &local_y) {
T cosa = cos(-rz), sina = sin(-rz);
local_x = shift_x * cosa + shift_y * (-sina);
local_y = shift_x * sina + shift_y * cosa;
}

template <typename T>
__device__ inline int check_pt_in_box3d(const T *pt, const T *box3d, T &local_x,
T &local_y) {
// param pt: (x, y, z)
// param box3d: (cx, cy, cz, x_size, y_size, z_size, rz) in LiDAR coordinate,
// cz in the bottom center
T x = pt[0], y = pt[1], z = pt[2];
T cx = box3d[0], cy = box3d[1], cz = box3d[2];
T x_size = box3d[3], y_size = box3d[4], z_size = box3d[5], rz = box3d[6];
cz += z_size /
2.0; // shift to the center since cz in box3d is the bottom center

if (fabsf(z - cz) > z_size / 2.0) return 0;
lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y);
float in_flag = (local_x > -x_size / 2.0) & (local_x < x_size / 2.0) &
(local_y > -y_size / 2.0) & (local_y < y_size / 2.0);
return in_flag;
}

template <typename T>
__global__ void points_in_boxes_part_forward_cuda_kernel(
int batch_size, int boxes_num, int pts_num, const T *boxes, const T *pts,
int *box_idx_of_points) {
// params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
// coordinate, z is the bottom center, each box DO NOT overlaps params pts:
// (B, npoints, 3) [x, y, z] in LiDAR coordinate params boxes_idx_of_points:
// (B, npoints), default -1

int bs_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= batch_size || pt_idx >= pts_num) return;

boxes += bs_idx * boxes_num * 7;
pts += bs_idx * pts_num * 3 + pt_idx * 3;
box_idx_of_points += bs_idx * pts_num + pt_idx;

T local_x = 0, local_y = 0;
int cur_in_flag = 0;
for (int k = 0; k < boxes_num; k++) {
cur_in_flag = check_pt_in_box3d(pts, boxes + k * 7, local_x, local_y);
if (cur_in_flag) {
box_idx_of_points[0] = k;
break;
}
}
}

template <typename T>
__global__ void points_in_boxes_all_forward_cuda_kernel(
int batch_size, int boxes_num, int pts_num, const T *boxes, const T *pts,
int *box_idx_of_points) {
// params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
// coordinate, z is the bottom center, each box DO NOT overlaps params pts:
// (B, npoints, 3) [x, y, z] in LiDAR coordinate params boxes_idx_of_points:
// (B, npoints), default -1

int bs_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= batch_size || pt_idx >= pts_num) return;

boxes += bs_idx * boxes_num * 7;
pts += bs_idx * pts_num * 3 + pt_idx * 3;
box_idx_of_points += bs_idx * pts_num * boxes_num + pt_idx * boxes_num;

T local_x = 0, local_y = 0;
for (int k = 0; k < boxes_num; k++) {
const int cur_in_flag =
check_pt_in_box3d(pts, boxes + k * 7, local_x, local_y);
if (cur_in_flag) {
box_idx_of_points[k] = 1;
}
}
}

#endif // POINT_IN_BOXES_CUDA_KERNEL_CUH
Loading

0 comments on commit 7012aed

Please sign in to comment.