In [1]:
from polyloss import PolyLoss, PolyFocalLoss

In [2]:
from math import pow

In [3]:
itern = 1000

In [4]:
def crossentropyloss(prob):
    loss = 0.0
    for i in range(itern):
        loss += 1/(i+1) * pow(1-prob, i+1)
    return loss

In [5]:
def focalloss(prob, gamma):
    loss = 0.0
    for i in range(itern):
        loss += 1/(i+1) * pow(1-prob, gamma+i+1)
    return loss

In [6]:
ploss = PolyLoss(epsilon=0)

In [7]:
pfocal = PolyFocalLoss(epsilon=0, gamma=2, alpha=[0.25, 0, 0])

In [8]:
import torch

In [9]:
a = torch.rand([1, 3])

In [10]:
m = torch.nn.Softmax(dim=1)

In [11]:
ip = m(a)

In [12]:
ip

tensor([[0.5086, 0.2517, 0.2397]])

In [13]:
tgt = torch.Tensor([2]).long()

In [14]:
ploss(a, tgt)

tensor(1.4283)

In [15]:
pfocal(a, tgt)

tensor(0.)

In [16]:
crossentropyloss(ip[0][2])

1.4283041468903026

In [17]:
focalloss(ip[0][2], 2)

0.8256071168755332

In [18]:
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.CrossEntropyLoss):
    ''' Focal loss for classification tasks on imbalanced datasets '''

    def __init__(self, gamma, alpha=None, ignore_index=-100, reduction='none'):
        super().__init__(weight=alpha, ignore_index=ignore_index, reduction='none')
        self.reduction = reduction
        self.gamma = gamma

    def forward(self, input_, target):
        cross_entropy = super().forward(input_, target)
        print(cross_entropy)
        # Temporarily mask out ignore index to '0' for valid gather-indices input.
        # This won't contribute final loss as the cross_entropy contribution
        # for these would be zero.
        target = target * (target != self.ignore_index).long()
        input_prob = torch.gather(F.softmax(input_, 1), 1, target.unsqueeze(1))
        print(input_prob)
        loss = torch.pow(1 - input_prob, self.gamma) * cross_entropy
        if self.reduction == 'mean':
            return torch.mean(loss) 
        elif self.reduction == 'sum':
            return torch.sum(loss)
        return loss

In [19]:
fl = FocalLoss(2, reduction='mean')

In [20]:
fl(a, tgt)

tensor(1.4283)
tensor([[0.2397]])


tensor(0.8256)