Skip to content

Commit

Permalink
[Feature] torch_npu support aclnn and add op (#2998)
Browse files Browse the repository at this point in the history
  • Loading branch information
momo609 committed Jan 7, 2024
1 parent 2e44eae commit c7c02a7
Show file tree
Hide file tree
Showing 20 changed files with 906 additions and 93 deletions.
6 changes: 3 additions & 3 deletions docs/en/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ We implement common ops used in detection, segmentation, etc.
| ---------------------------- | --- | ---- | --- | --- | ------ |
| ActiveRotatedFilter ||| | ||
| AssignScoreWithK | || | | |
| BallQuery | ||| | |
| BallQuery | ||| | |
| BBoxOverlaps | |||||
| BorderAlign | || | | |
| BoxIouRotated |||| ||
| BoxIouQuadri ||| | | |
| CARAFE | ||| | |
| ChamferDistance | || | | |
| ChamferDistance | || | | |
| CrissCrossAttention | || | | |
| ContourExpand || | | | |
| ConvexIoU | || | | |
Expand Down Expand Up @@ -44,7 +44,7 @@ We implement common ops used in detection, segmentation, etc.
| RotatedFeatureAlign |||| ||
| RoIPointPool3d | ||| | |
| RoIPool | ||| ||
| RoIAlignRotated |||| | |
| RoIAlignRotated |||| | |
| RiRoIAlignRotated | || | | |
| RoIAlign |||| ||
| RoIAwarePool3d | ||| | |
Expand Down
6 changes: 3 additions & 3 deletions docs/zh_cn/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ MMCV 提供了检测、分割等任务中常用的算子
| ---------------------------- | --- | ---- | --- | --- | ------ |
| ActiveRotatedFilter ||| | ||
| AssignScoreWithK | || | | |
| BallQuery | ||| | |
| BallQuery | ||| | |
| BBoxOverlaps | |||||
| BorderAlign | || | | |
| BoxIouRotated |||| ||
| BoxIouQuadri ||| | | |
| CARAFE | ||| | |
| ChamferDistance | || | | |
| ChamferDistance | || | | |
| CrissCrossAttention | || | | |
| ContourExpand || | | | |
| ConvexIoU | || | | |
Expand Down Expand Up @@ -44,7 +44,7 @@ MMCV 提供了检测、分割等任务中常用的算子
| RotatedFeatureAlign |||| ||
| RoIPointPool3d | ||| | |
| RoIPool | ||| ||
| RoIAlignRotated |||| | |
| RoIAlignRotated |||| | |
| RiRoIAlignRotated | || | | |
| RoIAlign |||| ||
| RoIAwarePool3d | ||| | |
Expand Down
8 changes: 4 additions & 4 deletions mmcv/ops/chamfer_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def forward(ctx, xyz1: Tensor, xyz2: Tensor) -> Sequence[Tensor]:
xyz1 = xyz1.contiguous()
xyz2 = xyz2.contiguous()

dist1 = torch.zeros(batch_size, n).to(device)
dist2 = torch.zeros(batch_size, m).to(device)
dist1 = torch.zeros(batch_size, n).type(xyz1.dtype).to(device)
dist2 = torch.zeros(batch_size, m).type(xyz2.dtype).to(device)
idx1 = torch.zeros(batch_size, n).type(torch.IntTensor).to(device)
idx2 = torch.zeros(batch_size, m).type(torch.IntTensor).to(device)

Expand Down Expand Up @@ -81,8 +81,8 @@ def backward(ctx,
device = grad_dist1.device
grad_dist1 = grad_dist1.contiguous()
grad_dist2 = grad_dist2.contiguous()
grad_xyz1 = torch.zeros(xyz1.size()).to(device)
grad_xyz2 = torch.zeros(xyz2.size()).to(device)
grad_xyz1 = torch.zeros(xyz1.size()).type(xyz1.dtype).to(device)
grad_xyz2 = torch.zeros(xyz2.size()).type(xyz2.dtype).to(device)

ext_module.chamfer_distance_backward(xyz1, xyz2, idx1, idx2,
grad_dist1, grad_dist2, grad_xyz1,
Expand Down
2 changes: 1 addition & 1 deletion mmcv/ops/csrc/common/pytorch_npu_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
#ifndef PYTORCH_NPU_HELPER_HPP_
#define PYTORCH_NPU_HELPER_HPP_

#include <torch_npu/csrc/aten/CustomFunctions.h>
#include <torch_npu/csrc/framework/utils/CalcuOpUtil.h>
#include <torch_npu/csrc/framework/utils/OpAdapter.h>

#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
#include "pytorch_npu_util.hpp"

#define NPU_NAME_SPACE at_npu::native

Expand Down
Loading

0 comments on commit c7c02a7

Please sign in to comment.