In [25]:
import sys
sys.path.append('..')
import torch
import torch.nn as nn
import numpy as np

In [38]:
class Focal2DLoss(nn.Module):
    """Focal-style re-weighted MSE loss for a dense binary target map.

    The loss is split into a positive term (cells where ``label`` is 1) and a
    negative term (cells where ``label`` is 0). Each term is an ``nn.MSELoss``
    (mean over *all* elements) scaled by a class weight and a modulating
    factor raised to ``gamma``:

    * positive: ``alpha * pos_beta ** gamma`` with
      ``pos_beta = (sum(label) - sum(logit * label)) / (sum(label) + 1e-5)``
    * negative: ``(1 - alpha) * neg_beta ** gamma`` with
      ``neg_beta = 1 - (sum(1 - label) - sum(logit * (1 - label)))
      / (sum(1 - label) + 1e-5)``

    The modulating factors are treated as constants: no gradient flows
    through them, only through the MSE terms.

    NOTE(review): despite its name, ``logit`` is presumably expected to hold
    values in ``[0, 1]`` (e.g. post-sigmoid scores); the factors above are only
    guaranteed to stay in ``[0, 1]`` under that assumption -- confirm with callers.

    Args:
        alpha: Weight of the positive term; the negative term gets ``1 - alpha``.
        gamma: Exponent applied to the modulating factors.
    """

    def __init__(self, alpha=0.98, gamma=2):
        super().__init__()
        self.base_criterion = nn.MSELoss()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logit, label):
        """Compute the positive and negative loss terms.

        Args:
            logit: Predicted score map (see the class-level note on its range).
            label: Binary target map broadcastable with ``logit``. Bool or
                integer masks are accepted and converted to ``logit.dtype``.

        Returns:
            Tuple ``(pos_loss, neg_loss)`` of 0-dim tensors. The two terms are
            returned separately; they are not summed here.
        """
        # `1 - label` is undefined for bool tensors and MSE expects matching
        # floating dtypes, so promote mask-like labels up front.
        if not label.is_floating_point():
            label = label.to(logit.dtype)

        pos_label = label
        neg_label = 1 - label
        pos_logit = logit * pos_label
        neg_logit = logit * neg_label

        eps = 1e-5  # guards the divisions when a class has no cells
        # The factors are constants w.r.t. autograd. Computing them as tensors
        # under no_grad (instead of via `.item()`) avoids one host/device sync
        # per scalar and reduces each label sum only once.
        with torch.no_grad():
            pos_count = pos_label.sum()
            neg_count = neg_label.sum()
            pos_beta = (pos_count - pos_logit.sum()) / (pos_count + eps)
            neg_beta = 1 - (neg_count - neg_logit.sum()) / (neg_count + eps)

        pos_loss = (
            self.alpha
            * (pos_beta ** self.gamma)
            * self.base_criterion(pos_logit, pos_label)
        )
        # Adding `label` back makes positive cells equal to their target, so
        # only negative cells contribute to this MSE.
        neg_loss = (
            (1 - self.alpha)
            * (neg_beta ** self.gamma)
            * self.base_criterion(neg_logit + label, label)
        )

        return pos_loss, neg_loss

In [63]:
# Uniform 7x7 prediction map: every cell scores 0.5.
logit = torch.full((7, 7), 0.5)
logit

tensor([[0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000]])

In [64]:
# 7x7 target map with a single positive cell at row 2, column 3.
label = torch.zeros(7, 7)
label[2, 3] = 1
label

tensor([[0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.]])

In [65]:
# Baseline: element-wise binary cross-entropy, split by the label mask so the
# positive and negative contributions can be compared.
base_criterion = nn.BCELoss(reduction='none')
base_loss = base_criterion(logit, label)
neg_mask = 1 - label
pos_loss, neg_loss = base_loss * label, base_loss * neg_mask
pos_total = pos_loss.sum().item()
neg_total = neg_loss.sum().item()
print('pos_loss: {} | neg_loss: {} | neg_loss/pos_loss: {}'.format(pos_total, neg_total, neg_total / pos_total))

pos_loss: 0.6931471824645996 | neg_loss: 33.271080017089844 | neg_loss/pos_loss: 48.00002201377925


In [66]:
# Same positive/negative comparison, this time with the Focal2DLoss criterion.
improve_criterion = Focal2DLoss()
pos_loss, neg_loss = improve_criterion(logit, label)
pos_total, neg_total = pos_loss.sum().item(), neg_loss.sum().item()
print('pos_loss: {} | neg_loss: {} | neg_loss/pos_loss: {}'.format(pos_total, neg_total, neg_total / pos_total))

pos_loss: 0.00124997494276613 | neg_loss: 0.0012244903482496738 | neg_loss/pos_loss: 0.9796119156916377
