Skip to content

Commit

Permalink
Change the NPU ms_deform_attn backward implementation to the aclnn operator API.
Browse files Browse the repository at this point in the history
  • Loading branch information
RRaoyzee committed May 6, 2024
1 parent f0497a6 commit be17839
Showing 1 changed file with 3 additions and 53 deletions.
56 changes: 3 additions & 53 deletions mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,59 +91,9 @@ void ms_deform_attn_backward_npu(const Tensor &value, const Tensor &spatial_shap
Tensor &grad_sampling_loc,
Tensor &grad_attn_weight, const int im2col_step) {
check_support(value, attn_weight);
int64_t num_heads = value.size(2);
int64_t embed_dims = value.size(3);
int64_t num_points = attn_weight.size(4);
TORCH_CHECK(embed_dims % 32 == 0, "embed_dims must be a multiple of 32, but embed_dims is", embed_dims, ".");
TORCH_CHECK(num_points % 4 == 0, "num_points must be a multiple of four, but num_points is", num_points, ".");
TORCH_CHECK(num_heads % 4 == 0, "num_heads must be a multiple of four, but num_heads is", num_heads, ".");
at::Tensor value_fp32 = value;
at::Tensor spatial_shapes_int32 = spatial_shapes;
at::Tensor level_start_index_int32 = level_start_index;
at::Tensor sampling_loc_fp32 = sampling_loc.transpose(4, 5).contiguous();
at::Tensor attn_weight_fp32 = attn_weight;
at::Tensor grad_output_fp32 = grad_output;
if (value.scalar_type() != at::kFloat) {
value_fp32 = value.to(at::kFloat);
}
if (spatial_shapes.scalar_type() != at::kInt) {
spatial_shapes_int32 = spatial_shapes.to(at::kInt);
}
if (level_start_index.scalar_type() != at::kInt) {
level_start_index_int32 = level_start_index.to(at::kInt);
}
if (sampling_loc.scalar_type() != at::kFloat) {
sampling_loc_fp32 = sampling_loc_fp32.to(at::kFloat);
}
if (attn_weight.scalar_type() != at::kFloat) {
attn_weight_fp32 = attn_weight.to(at::kFloat);
}
if (grad_output.scalar_type() != at::kFloat) {
grad_output_fp32 = grad_output.to(at::kFloat);
}
ori_type = value.scalar_type();
at::Tensor grad_value_temp = at::zeros(value_fp32.sizes(), value_fp32.options());
at::Tensor grad_sampling_loc_temp = at::zeros(sampling_loc_fp32.sizes(), sampling_loc_fp32.options());
at::Tensor grad_attn_weight_temp = at::zeros(attn_weight_fp32.sizes(), attn_weight_fp32.options());

OpCommand cmd;
cmd.Name("MultiScaleDeformableAttentionGrad")
.Input(value_fp32)
.Input(spatial_shapes_int32)
.Input(level_start_index_int32)
.Input(sampling_loc_fp32)
.Input(attn_weight_fp32)
.Input(grad_output_fp32)
.Output(grad_value_temp)
.Output(grad_sampling_loc_temp)
.Output(grad_attn_weight_temp)
.Run();
grad_value_temp = grad_value_temp.to(ori_type);
grad_sampling_loc_temp = grad_sampling_loc_temp.transpose(4, 5).contiguous().to(ori_type);
grad_attn_weight_temp = grad_attn_weight_temp.to(ori_type);
grad_value.copy_(grad_value_temp);
grad_sampling_loc.copy_(grad_sampling_loc_temp);
grad_attn_weight.copy_(grad_attn_weight_temp);
EXEC_NPU_CMD(aclnnMultiScaleDeformableAttentionGrad, value, spatial_shapes, level_start_index,
sampling_loc, attn_weight, grad_output
grad_value, grad_sampling_loc, grad_attn_weight);
}

// Register the NPU backward kernel with mmcv's dispatcher so
// ms_deform_attn_impl_backward resolves to this implementation on NPU devices.
REGISTER_NPU_IMPL(ms_deform_attn_impl_backward, ms_deform_attn_backward_npu);

0 comments on commit be17839

Please sign in to comment.