Skip to content

Commit

Permalink
add charmferback
Browse files Browse the repository at this point in the history
  • Loading branch information
momo609 committed Dec 20, 2023
1 parent fcab9be commit b89ace6
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
Expand All @@ -8,8 +7,8 @@ void chamfer_distance_forward_npu(Tensor XYZ1, Tensor XYZ2, Tensor dist1,
Tensor dist2, Tensor idx1, Tensor idx2) {
at::Tensor xyz1 = at::ones_like(XYZ1);
at::Tensor xyz2 = at::ones_like(XYZ2);
xyz1 = XYZ1.transpose(1,2);
xyz2 = XYZ2.transpose(1,2);
xyz1 = XYZ1.transpose(1, 2);
xyz2 = XYZ2.transpose(1, 2);
OpCommand cmd;
cmd.Name("ChamferDistance")
.Input(xyz1)
Expand All @@ -21,7 +20,20 @@ void chamfer_distance_forward_npu(Tensor XYZ1, Tensor XYZ2, Tensor dist1,
.Run();
}

void chamfer_distance_backward_npu(Tensor xyz1, Tensor xyz2,
Tensor idx1, Tensor idx2, Tensor grad_dist1, Tensor grad_dist2,
Tensor grad_xyz1, Tensor grad_xyz2) {
EXEC_NPU_CMD(aclnnChamferDistanceBackward, xyz1, xyz2, idx1, idx2,
grad_dist1, grad_dist2, grad_xyz1, grad_xyz2);
}

void chamfer_distance_forward_impl(Tensor XYZ1, Tensor XYZ2, Tensor dist1,
Tensor dist2, Tensor idx1, Tensor idx2);
REGISTER_NPU_IMPL(chamfer_distance_forward_impl,
chamfer_distance_forward_npu);

void chamfer_distance_backward_impl(Tensor xyz1, Tensor xyz2, Tensor idx1, Tensor idx2,
Tensor grad_dist1, Tensor grad_dist2,
Tensor grad_xyz1, Tensor grad_xyz2);
REGISTER_NPU_IMPL(chamfer_distance_backward_impl,
chamfer_distance_backward_npu);

0 comments on commit b89ace6

Please sign in to comment.