Skip to content

Commit

Permalink
[Refactor] Replace tin_shift op of MLU backend with mlu-ops (open-mml…
Browse files Browse the repository at this point in the history
  • Loading branch information
ClowDragon committed Oct 25, 2023
1 parent f3cec77 commit 2b6a805
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 444 deletions.
307 changes: 0 additions & 307 deletions mmcv/ops/csrc/common/mlu/tin_shift_mlu_kernel.mlu

This file was deleted.

28 changes: 14 additions & 14 deletions mmcv/ops/csrc/pytorch/mlu/focal_loss_sigmoid_mlu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

#include "mlu_common_helper.h"

void sigmoid_focal_loss_forward_mlu(Tensor input, Tensor target,
Tensor weight, Tensor output,
const float gamma, const float alpha) {
void sigmoid_focal_loss_forward_mlu(Tensor input, Tensor target, Tensor weight,
Tensor output, const float gamma,
const float alpha) {
// params check
TORCH_CHECK(gamma >= 0, "gamma should be greater than or equal to 0. ",
"But now gamma is ", gamma, ".");
Expand Down Expand Up @@ -82,15 +82,15 @@ void sigmoid_focal_loss_forward_mlu(Tensor input, Tensor target,
auto handle = mluOpGetCurrentHandle();

// launch kernel
TORCH_MLUOP_CHECK(mluOpFocalLossSigmoidForward(handle, prefer, reduction, input_desc.desc(),
input_ptr, target_desc.desc(), target_ptr,
weight_desc.desc(), weight_ptr, alpha, gamma,
output_desc.desc(), output_ptr));
TORCH_MLUOP_CHECK(mluOpFocalLossSigmoidForward(
handle, prefer, reduction, input_desc.desc(), input_ptr,
target_desc.desc(), target_ptr, weight_desc.desc(), weight_ptr, alpha,
gamma, output_desc.desc(), output_ptr));
}

void sigmoid_focal_loss_backward_mlu(Tensor input, Tensor target,
Tensor weight, Tensor output,
const float gamma, const float alpha) {
void sigmoid_focal_loss_backward_mlu(Tensor input, Tensor target, Tensor weight,
Tensor output, const float gamma,
const float alpha) {
// params check
TORCH_CHECK(gamma >= 0, "gamma should be greater than or equal to 0. ",
"But now gamma is ", gamma, ".");
Expand Down Expand Up @@ -158,10 +158,10 @@ void sigmoid_focal_loss_backward_mlu(Tensor input, Tensor target,
auto handle = mluOpGetCurrentHandle();

// launch kernel
TORCH_MLUOP_CHECK(mluOpFocalLossSigmoidBackward(handle, prefer, reduction, input_desc.desc(),
input_ptr, target_desc.desc(), target_ptr,
weight_desc.desc(), weight_ptr, alpha, gamma,
output_desc.desc(), output_ptr));
TORCH_MLUOP_CHECK(mluOpFocalLossSigmoidBackward(
handle, prefer, reduction, input_desc.desc(), input_ptr,
target_desc.desc(), target_ptr, weight_desc.desc(), weight_ptr, alpha,
gamma, output_desc.desc(), output_ptr));
}

void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
Expand Down
4 changes: 2 additions & 2 deletions mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
#include "pytorch_device_registry.hpp"

#define MLUOP_MAJOR 0
#define MLUOP_MINOR 7
#define MLUOP_PATCHLEVEL 1
#define MLUOP_MINOR 8
#define MLUOP_PATCHLEVEL 0

/*************************************************************************
* This MACRO contains operations of simple tensor to mlu-tensor.
Expand Down
Loading

0 comments on commit 2b6a805

Please sign in to comment.