fix softmax focal loss algorithm #2893

Status: Open · wants to merge 9 commits into base: main
116 changes: 77 additions & 39 deletions mmcv/ops/csrc/common/cuda/softmax_focal_loss_cuda_kernel.cuh
@@ -10,62 +10,100 @@

template <typename T>
__global__ void softmax_focal_loss_forward_cuda_kernel(
const int nthreads, const T* softmax, const int64_t* target,
const T* weight, T* output, const T gamma, const T alpha,
const int num_classes) {
const int nthreads, const T* __restrict__ log_softmax_prob,
const int64_t* __restrict__ target, const T* __restrict__ weight,
T* __restrict__ output,
const T gamma, const T alpha, const int num_classes) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int64_t label = target[index];
T pred = softmax[index * num_classes + label];
const int n = index / num_classes;
const int c = index % num_classes;

if (label >= 0) {
output[index] =
-alpha * pow((T)1. - pred, gamma) * log(max(pred, (T)FLT_MIN));
// focal loss
// FL(p) = - alpha * (1-p)^gamma * log(p) if curr_class == label
//
// note that log_softmax_prob is calculated in Python part
// by using PyTorch API F.log_softmax()
const int64_t label = target[n];
if (c == label) {
const T w = (weight != NULL) ? weight[label] : T(1);
const T alpha_fac = ((label == 0) * (1 - alpha) + (label >= 1) * alpha) * w;

const T log_pred = log_softmax_prob[index];
const T pred = exp(log_pred);

output[index] = -alpha_fac * pow(1 - pred, gamma) * log_pred;
} else {
output[index] = 0;
}
if (weight != NULL) {
output[index] *= weight[label];
}
}
}
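
The rewritten forward kernel above fills only the target-class column of each row: it reads the log-probability produced by F.log_softmax, recovers p = exp(log_p), and scales the focal term by an alpha factor that uses (1 - alpha) for the background class 0 and alpha for every foreground class, multiplied by the optional per-class weight. A minimal CPU sketch of that per-row computation follows for reference; it is not part of the PR, and the function name and container types are illustrative only.

#include <cmath>
#include <cstdio>
#include <vector>

// Reference of the new forward contract: log_prob holds row-wise log-softmax
// values, and out[n * C + c] is non-zero only at c == target[n].
std::vector<double> softmax_focal_loss_forward_ref(
    const std::vector<std::vector<double>>& log_prob,  // [N][C]
    const std::vector<int>& target,                    // [N], labels in [0, C)
    const std::vector<double>& weight,                 // [C]; empty => unweighted
    double gamma, double alpha) {
  const size_t N = log_prob.size(), C = log_prob[0].size();
  std::vector<double> out(N * C, 0.0);
  for (size_t n = 0; n < N; ++n) {
    const int label = target[n];
    const double w = weight.empty() ? 1.0 : weight[label];
    // class 0 is treated as background: weight it by (1 - alpha), others by alpha
    const double alpha_fac = (label == 0 ? 1.0 - alpha : alpha) * w;
    const double log_p = log_prob[n][label];  // log-softmax of the target class
    const double p = std::exp(log_p);
    out[n * C + label] = -alpha_fac * std::pow(1.0 - p, gamma) * log_p;
  }
  return out;
}

int main() {
  // one sample, three classes; the row is already log-softmax normalized
  const std::vector<std::vector<double>> log_prob = {
      {std::log(0.2), std::log(0.7), std::log(0.1)}};
  const auto out = softmax_focal_loss_forward_ref(log_prob, {1}, {}, 2.0, 0.25);
  std::printf("loss[0][1] = %f\n", out[1]);  // -0.25 * (1 - 0.7)^2 * log(0.7)
  return 0;
}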

template <typename T>
__global__ void softmax_focal_loss_backward_cuda1_kernel(
const int nthreads, const T* softmax, const int64_t* target,
const T* weight, T* buff, const T gamma, const T alpha,
const int num_classes) {
__global__ void softmax_focal_loss_backward_cuda_kernel(
const int nthreads, const T* __restrict__ log_softmax_prob,
const int64_t* __restrict__ target, const T* __restrict__ weight,
T* __restrict__ sum_buff_along_class, T* __restrict__ grad_input,
const T gamma, const T alpha, const int num_classes) {
// forward node:   x ----> p ----> FL
//          func:      SM      FL
//
// backward node:  x <---- p <---- FL
//         index:  j       i      FL
//
// For simplicity, the alpha of FL is ignored here
// dFL/dp = - [((1-p)^gamma) / p
//             - gamma * (1-p)^(gamma-1) * log(p)]
// dp_i/dx_j = dSM/dx_j
//           = p_i * (1-p_j)   i==j;
//             p_i * (0-p_j)   i!=j;
//           = p_i * (delta - p_j)   where delta is the Kronecker delta
//
// Replacing the p of dFL/dp with p_i, then
// dFL/dx_j = dFL/dp_i * dp_i/dx_j
//          = - (delta - p_j) * [ (1-p_i)^gamma
//              - gamma * (1-p_i)^(gamma-1) * log(p_i) * p_i]
//          = (delta - p_j) * [- (1-p_i)^gamma +
//              gamma * (1-p_i)^(gamma-1) * log(p_i) * p_i]
//
// Let B_i denote [- (1-p_i)^gamma +
//                 gamma * (1-p_i)^(gamma-1) * log(p_i) * p_i],
// where the index i is summed over all classes at position j,
// since x_j receives gradients from every p_i.
// Then, dFL/dx_j = sum_i{ (delta - p_j) * B_i }
//                = sum_i{ delta*B_i - p_j*B_i }
//                = B_j - (p_j * sum_i{B_i})

CUDA_1D_KERNEL_LOOP(index, nthreads) {
int64_t label = target[index];
T pred = softmax[index * num_classes + label];
// B_i
const int n = index / num_classes;
const int c = index % num_classes;

if (label >= 0) {
buff[index] = alpha * (-pow((T)1. - pred, gamma) +
gamma * pow((T)1. - pred, gamma - 1) * pred *
log(max(pred, (T)FLT_MIN)));
const int64_t label = target[n];
if (c == label) {
const T w = (weight != NULL) ? weight[label] : T(1);
const T alpha_fac = ((label == 0) * (1 - alpha) + (label >= 1) * alpha) * w;

const T log_pred = log_softmax_prob[index];
const T pred = exp(log_pred);
const T one_minus_pred = 1 - pred;

const T buff = alpha_fac * (
-pow(one_minus_pred, gamma) +
gamma * pow(one_minus_pred, gamma - 1) * log_pred * pred
);
grad_input[index] = buff;
sum_buff_along_class[n] += buff;
} else {
buff[index] = 0;
}
if (weight != NULL) {
buff[index] *= weight[label];
grad_input[index] = 0;
}
}
}

template <typename T>
__global__ void softmax_focal_loss_backward_cuda2_kernel(
const int nthreads, const T* softmax, const int64_t* target, const T* buff,
T* grad_input, const int num_classes) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int n = index / num_classes;
int c = index % num_classes;
int64_t label = target[n];
// dFL/dx_j
const int n = index / num_classes;

if (label >= 0) {
T flag = (label == c ? (T)1. : (T)0.);
grad_input[index] = buff[n] * (flag - softmax[index]);
} else {
grad_input[index] = 0;
}
const T pred = exp(log_softmax_prob[index]);
grad_input[index] -= pred * sum_buff_along_class[n];
}
}
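
Following the derivation in the comment above, B_i is non-zero only at the label column, so dFL/dx_j reduces to (delta_{j,label} - p_j) * B_label, which is exactly what the single backward kernel accumulates through sum_buff_along_class. The standalone CPU sketch below (not part of the PR; the helper names and the finite-difference check are illustrative, and the alpha factor is omitted as in the comment) verifies that formula against central differences for a single sample.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Softmax probabilities of a single row of logits.
static std::vector<double> softmax(const std::vector<double>& x) {
  double m = x[0];
  for (double v : x) m = std::max(m, v);
  double denom = 0.0;
  for (double v : x) denom += std::exp(v - m);
  std::vector<double> p(x.size());
  for (size_t i = 0; i < x.size(); ++i) p[i] = std::exp(x[i] - m) / denom;
  return p;
}

// Focal loss of one sample given raw logits (alpha omitted, as in the derivation).
static double focal_loss(const std::vector<double>& x, int label, double gamma) {
  const double p = softmax(x)[label];
  return -std::pow(1.0 - p, gamma) * std::log(p);
}

int main() {
  const std::vector<double> logits = {0.3, -1.2, 2.0, 0.5};
  const int label = 2;
  const double gamma = 2.0;

  const std::vector<double> p = softmax(logits);
  const double pl = p[label];
  // B_i is non-zero only at i == label, so sum_i{B_i} == B.
  const double B = -std::pow(1.0 - pl, gamma) +
                   gamma * std::pow(1.0 - pl, gamma - 1.0) * std::log(pl) * pl;

  for (size_t j = 0; j < logits.size(); ++j) {
    // analytic gradient: dFL/dx_j = B_j - p_j * sum_i{B_i}
    const double analytic = (static_cast<int>(j) == label ? B : 0.0) - p[j] * B;

    // central finite differences on the forward loss
    const double eps = 1e-6;
    std::vector<double> xp = logits, xm = logits;
    xp[j] += eps;
    xm[j] -= eps;
    const double numeric =
        (focal_loss(xp, label, gamma) - focal_loss(xm, label, gamma)) / (2 * eps);

    std::printf("j=%zu analytic=% .8f numeric=% .8f\n", j, analytic, numeric);
  }
  return 0;
}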

61 changes: 40 additions & 21 deletions mmcv/ops/csrc/pytorch/cuda/cudabind.cpp
@@ -409,13 +409,17 @@ void SigmoidFocalLossBackwardCUDAKernelLauncher(Tensor input, Tensor target,
const float gamma,
const float alpha);

void SoftmaxFocalLossForwardCUDAKernelLauncher(Tensor softmax, Tensor target,
Tensor weight, Tensor output,
void SoftmaxFocalLossForwardCUDAKernelLauncher(const Tensor log_softmax_prob,
const Tensor target,
const Tensor weight,
Tensor output,
const float gamma,
const float alpha);

void SoftmaxFocalLossBackwardCUDAKernelLauncher(Tensor softmax, Tensor target,
Tensor weight, Tensor buff,
void SoftmaxFocalLossBackwardCUDAKernelLauncher(const Tensor log_softmax_prob,
const Tensor target,
const Tensor weight,
Tensor sum_buff_along_class,
Tensor grad_input,
const float gamma,
const float alpha);
@@ -433,18 +437,26 @@ void sigmoid_focal_loss_backward_cuda(Tensor input, Tensor target,
gamma, alpha);
}

void softmax_focal_loss_forward_cuda(Tensor input, Tensor target, Tensor weight,
Tensor output, float gamma, float alpha) {
SoftmaxFocalLossForwardCUDAKernelLauncher(input, target, weight, output,
gamma, alpha);
void softmax_focal_loss_forward_cuda(const Tensor log_softmax_prob,
const Tensor target,
const Tensor weight,
Tensor output,
const float gamma,
const float alpha) {
SoftmaxFocalLossForwardCUDAKernelLauncher(log_softmax_prob, target, weight,
output, gamma, alpha);
}

void softmax_focal_loss_backward_cuda(Tensor input, Tensor target,
Tensor weight, Tensor buff,
Tensor grad_input, float gamma,
float alpha) {
SoftmaxFocalLossBackwardCUDAKernelLauncher(input, target, weight, buff,
grad_input, gamma, alpha);
void softmax_focal_loss_backward_cuda(const Tensor log_softmax_prob,
const Tensor target,
const Tensor weight,
Tensor sum_buff_along_class,
Tensor grad_input,
const float gamma,
const float alpha) {
SoftmaxFocalLossBackwardCUDAKernelLauncher(log_softmax_prob, target, weight,
sum_buff_along_class, grad_input,
gamma, alpha);
}

void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
@@ -454,13 +466,20 @@ void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target,
Tensor weight, Tensor grad_input,
float gamma, float alpha);

void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
Tensor output, float gamma, float alpha);

void softmax_focal_loss_backward_impl(Tensor input, Tensor target,
Tensor weight, Tensor buff,
Tensor grad_input, float gamma,
float alpha);
void softmax_focal_loss_forward_impl(const Tensor log_softmax_prob,
const Tensor target,
const Tensor weight,
Tensor output,
const float gamma,
const float alpha);

void softmax_focal_loss_backward_impl(const Tensor log_softmax_prob,
const Tensor target,
const Tensor weight,
Tensor sum_buff_along_class,
Tensor grad_input,
const float gamma,
const float alpha);

REGISTER_DEVICE_IMPL(sigmoid_focal_loss_forward_impl, CUDA,
sigmoid_focal_loss_forward_cuda);
59 changes: 24 additions & 35 deletions mmcv/ops/csrc/pytorch/cuda/focal_loss_cuda.cu
@@ -47,64 +47,53 @@ void SigmoidFocalLossBackwardCUDAKernelLauncher(Tensor input, Tensor target,
AT_CUDA_CHECK(cudaGetLastError());
}

void SoftmaxFocalLossForwardCUDAKernelLauncher(Tensor softmax, Tensor target,
Tensor weight, Tensor output,
void SoftmaxFocalLossForwardCUDAKernelLauncher(const Tensor log_softmax_prob,
const Tensor target,
const Tensor weight,
Tensor output,
const float gamma,
const float alpha) {
int output_size = output.numel();
int num_classes = softmax.size(1);
int num_classes = log_softmax_prob.size(1);

AT_ASSERTM(target.max().item<int64_t>() <= (int64_t)num_classes,
"target label should smaller or equal than num classes");
at::cuda::CUDAGuard device_guard(softmax.device());
at::cuda::CUDAGuard device_guard(log_softmax_prob.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
softmax.scalar_type(), "softmax_focal_loss_forward_cuda_kernel", [&] {
log_softmax_prob.scalar_type(), "softmax_focal_loss_forward_cuda_kernel", [&] {
softmax_focal_loss_forward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, softmax.data_ptr<scalar_t>(),
target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
output.data_ptr<scalar_t>(), gamma, alpha, num_classes);
output_size,
log_softmax_prob.data_ptr<scalar_t>(), target.data_ptr<int64_t>(),
weight.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
gamma, alpha, num_classes);
});

AT_CUDA_CHECK(cudaGetLastError());
}

void SoftmaxFocalLossBackwardCUDAKernelLauncher(Tensor softmax, Tensor target,
Tensor weight, Tensor buff,
void SoftmaxFocalLossBackwardCUDAKernelLauncher(const Tensor log_softmax_prob,
const Tensor target,
const Tensor weight,
Tensor sum_buff_along_class,
Tensor grad_input,
const float gamma,
const float alpha) {
int num_classes = softmax.size(1);
int output_size = grad_input.numel();
int num_classes = log_softmax_prob.size(1);

int output_size = buff.numel();
at::cuda::CUDAGuard device_guard(grad_input.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_input.scalar_type(),
"softmax_focal_loss_backward_cuda1_"
"kernel",
[&] {
softmax_focal_loss_backward_cuda1_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, softmax.data_ptr<scalar_t>(),
target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
buff.data_ptr<scalar_t>(), gamma, alpha, num_classes);
});

AT_CUDA_CHECK(cudaGetLastError());

output_size = grad_input.numel();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_input.scalar_type(),
"softmax_focal_loss_backward_cuda2_"
"kernel",
[&] {
softmax_focal_loss_backward_cuda2_kernel<scalar_t>
log_softmax_prob.scalar_type(), "softmax_focal_loss_backward_cuda_kernel", [&] {
softmax_focal_loss_backward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, softmax.data_ptr<scalar_t>(),
target.data_ptr<int64_t>(), buff.data_ptr<scalar_t>(),
grad_input.data_ptr<scalar_t>(), num_classes);
output_size,
log_softmax_prob.data_ptr<scalar_t>(), target.data_ptr<int64_t>(),
weight.data_ptr<scalar_t>(), sum_buff_along_class.data_ptr<scalar_t>(),
grad_input.data_ptr<scalar_t>(),
gamma, alpha, num_classes);
});

AT_CUDA_CHECK(cudaGetLastError());
52 changes: 35 additions & 17 deletions mmcv/ops/csrc/pytorch/focal_loss.cpp
@@ -25,18 +25,26 @@ void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target,
grad_input, gamma, alpha);
}

void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
Tensor output, float gamma, float alpha) {
DISPATCH_DEVICE_IMPL(softmax_focal_loss_forward_impl, input, target, weight,
output, gamma, alpha);
void softmax_focal_loss_forward_impl(const Tensor log_softmax_prob,
const Tensor target,
const Tensor weight,
Tensor output,
const float gamma,
const float alpha) {
DISPATCH_DEVICE_IMPL(softmax_focal_loss_forward_impl, log_softmax_prob,
target, weight, output, gamma, alpha);
}

void softmax_focal_loss_backward_impl(Tensor input, Tensor target,
Tensor weight, Tensor buff,
Tensor grad_input, float gamma,
float alpha) {
DISPATCH_DEVICE_IMPL(softmax_focal_loss_backward_impl, input, target, weight,
buff, grad_input, gamma, alpha);
void softmax_focal_loss_backward_impl(const Tensor log_softmax_prob,
const Tensor target,
const Tensor weight,
Tensor sum_buff_along_class,
Tensor grad_input,
const float gamma,
const float alpha) {
DISPATCH_DEVICE_IMPL(softmax_focal_loss_backward_impl, log_softmax_prob,
target, weight, sum_buff_along_class, grad_input,
gamma, alpha);
}

#ifdef MMCV_WITH_DIOPI
@@ -127,14 +135,24 @@ void sigmoid_focal_loss_backward(Tensor input, Tensor target, Tensor weight,
#endif
}

void softmax_focal_loss_forward(Tensor input, Tensor target, Tensor weight,
Tensor output, float gamma, float alpha) {
softmax_focal_loss_forward_impl(input, target, weight, output, gamma, alpha);
void softmax_focal_loss_forward(const Tensor log_softmax_prob,
const Tensor target,
const Tensor weight,
Tensor output,
const float gamma,
const float alpha) {
softmax_focal_loss_forward_impl(log_softmax_prob, target, weight,
output, gamma, alpha);
}

void softmax_focal_loss_backward(Tensor input, Tensor target, Tensor weight,
Tensor buff, Tensor grad_input, float gamma,
float alpha) {
softmax_focal_loss_backward_impl(input, target, weight, buff, grad_input,
void softmax_focal_loss_backward(const Tensor log_softmax_prob,
const Tensor target,
const Tensor weight,
Tensor sum_buff_along_class,
Tensor grad_input,
const float gamma,
const float alpha) {
softmax_focal_loss_backward_impl(log_softmax_prob, target, weight,
sum_buff_along_class, grad_input,
gamma, alpha);
}
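
As the kernel comment notes, the reworked op now consumes log-probabilities produced by F.log_softmax rather than softmax probabilities. For orientation only, here is a hedged C++/libtorch sketch of how the new entry point could be exercised end to end; the tensor shapes, the direct call to softmax_focal_loss_forward, and the assumption that a CUDA build of the extension is linked are all illustrative — in practice the op is reached through mmcv's Python wrapper.

#include <torch/torch.h>

#include <iostream>

// Declaration copied from this PR's focal_loss.cpp; the symbol is resolved at
// link time against the built extension (illustrative setup, not part of the PR).
void softmax_focal_loss_forward(const at::Tensor log_softmax_prob,
                                const at::Tensor target, const at::Tensor weight,
                                at::Tensor output, const float gamma,
                                const float alpha);

int main() {
  const int64_t N = 4, C = 10;
  const auto fopts = torch::dtype(torch::kFloat32).device(torch::kCUDA);

  const auto logits = torch::randn({N, C}, fopts);
  // The op expects log-probabilities, i.e. F.log_softmax(input, dim=1).
  const auto log_prob = torch::log_softmax(logits, /*dim=*/1);
  const auto target = torch::randint(
      0, C, {N}, torch::dtype(torch::kLong).device(torch::kCUDA));
  const auto weight = torch::ones({C}, fopts);  // per-class weights
  auto output = torch::zeros({N, C}, fopts);    // per-element losses, filled by the kernel

  softmax_focal_loss_forward(log_prob, target, weight, output,
                             /*gamma=*/2.0f, /*alpha=*/0.25f);

  // Only the target-class column of each row is non-zero; reduce to a mean loss.
  std::cout << output.sum() / static_cast<double>(N) << std::endl;
  return 0;
}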