In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [2]:
class FocalLoss2d(nn.Module):
    """Focal loss for multi-class classification and dense (per-pixel) prediction.

    loss = weight * (1 - p_t) ** gamma * CE, where p_t is the softmax probability
    the model assigns to the target class and CE = -log(p_t) is the per-element
    cross-entropy (Lin et al., "Focal Loss for Dense Object Detection", 2017).
    With gamma=0 and weight=1 this is exactly cross-entropy.

    Args:
        gamma: focusing exponent; values > 0 down-weight well-classified elements.
        weight: constant factor multiplied into every element's loss.
        reduction: 'none', 'mean' or 'sum' (case-insensitive).

    Raises:
        ValueError: if ``reduction`` is not one of the supported values.
    """

    def __init__(self, gamma=0, weight=1, reduction="mean"):
        super().__init__()
        self.gamma = gamma
        self.weight = weight
        # None means "return the unreduced per-element losses".
        reduction_ops = {"none": None, "mean": torch.mean, "sum": torch.sum}
        try:
            self.reduction_op = reduction_ops[reduction.lower()]
        except KeyError:
            raise ValueError(
                "expected one of ('none', 'mean', 'sum'), got {}".format(reduction)
            ) from None

    def forward(self, input, target):
        """Focal loss from raw logits and integer class targets.

        Args:
            input: logits of shape (N, C) or (N, C, H, W).
            target: class indices of shape (N,) or (N, H, W).

        Returns:
            The reduced loss, or with reduction='none' a 1-D tensor of per-element
            losses (of length N*H*W for 4-D input, in (n, h, w) order).

        Raises:
            ValueError: if ``input`` or ``target`` has an unsupported rank.
        """
        if input.dim() == 4:
            # (N, C, H, W) -> (N*H*W, C): every pixel becomes one classification
            # sample, flattened in the same (n, h, w) order as target.view(-1).
            num_classes = input.size(1)
            input = input.permute(0, 2, 3, 1).contiguous().view(-1, num_classes)
        elif input.dim() != 2:
            raise ValueError("expected input of size 4 or 2, got {}".format(input.size()))

        if target.dim() == 3:
            target = target.contiguous().view(-1)
        elif target.dim() != 1:
            raise ValueError("expected target of size 3 or 1, got {}".format(target.size()))

        ce = F.cross_entropy(input, target, reduction="none")
        # ce = -log softmax(input, dim=1)[i, target[i]], so exp(-ce) is the
        # target-class probability normalised over the classes (not the batch).
        p_t = torch.exp(-ce)
        return self._focal_reduce(p_t, ce)

    def forward_onehot(self, input, target):
        """Alternative to ``forward`` that selects p_t with a one-hot mask.

        Unlike ``forward`` it does not flatten 4-D input, so with reduction='none'
        the result has shape (N, H, W) for a 4-D input and (N,) for a 2-D input.

        Args:
            input: logits of shape (N, C) or (N, C, H, W).
            target: class indices of shape (N,) or (N, H, W).

        Raises:
            ValueError: if ``input`` or ``target`` has an unsupported rank.
        """
        if input.dim() != 2 and input.dim() != 4:
            raise ValueError("expected input of size 4 or 2, got {}".format(input.dim()))

        if target.dim() != 1 and target.dim() != 3:
            raise ValueError("expected target of size 3 or 1, got {}".format(target.dim()))

        # Mask with the same layout as input ((N, C) or (N, C, H, W)), moved to
        # input's device/dtype so the product below never mixes devices.
        target_onehot = onehot(target, input.size(1)).to(input)
        # Softmax over the class dimension (1), then keep only the target class.
        p_t = torch.sum(target_onehot * F.softmax(input, dim=1), dim=1)
        ce = F.cross_entropy(input, target, reduction="none")
        return self._focal_reduce(p_t, ce)

    def _focal_reduce(self, p_t, ce):
        """Scale ``ce`` by ``weight * (1 - p_t) ** gamma`` and apply the configured reduction."""
        loss = self.weight * (1 - p_t).pow(self.gamma) * ce
        if self.reduction_op is None:
            return loss
        return self.reduction_op(loss)

In [3]:
def onehot(tensor, num_classes, dtype=None):
    """One-hot encode an integer class-index tensor along a new dimension 1.

    Args:
        tensor: integer class indices of shape (N,) or (N, d1, d2, ...), with
            values in [0, num_classes).
        num_classes: size of the inserted class dimension.
        dtype: dtype of the result; None uses torch's default float dtype.

    Returns:
        Tensor of shape (N, num_classes, d1, d2, ...) holding 1 at each element's
        class and 0 elsewhere, on the same device as ``tensor``.
    """
    # scatter_ requires int64 indices; unsqueeze puts the class axis at dim 1.
    index = tensor.long().unsqueeze(1)
    # Allocate on the input's device so scatter_ never mixes CPU and GPU tensors.
    encoded = torch.zeros(index.size(0), num_classes, *index.size()[2:],
                          dtype=dtype, device=tensor.device)
    return encoded.scatter_(1, index, 1)

In [4]:
loss = FocalLoss2d()

# One confidently-correct prediction: the loss should be ~0.
target = torch.tensor([1])
out = torch.tensor([[-100., 100., -100., -50.]])
print("Target:\n", target)
print("Model out:\n", out)
# Call the module (not .forward) so nn.Module hooks run as they would in training.
print("Loss:\n", loss(out, target))
print("Onehot loss:\n", loss.forward_onehot(out, target))
# With gamma=0 the focal loss must reduce to plain cross-entropy on both code paths.
assert torch.allclose(loss(out, target), F.cross_entropy(out, target))
assert torch.allclose(loss.forward_onehot(out, target), F.cross_entropy(out, target))

Target:
 tensor([1])
Model out:
 tensor([[-100.,  100., -100.,  -50.]])
Loss:
 tensor(0.)
Onehot loss:
 tensor(0.)


In [5]:
# Rows 0 and 2 are confidently correct; row 1 puts logit 50 on class 3 while its
# target is class 0 (logit -100), so its CE is ~150 and the mean loss is ~150 / 3 = 50.
target = torch.tensor([1, 0, 0])
out = torch.tensor([[-100., 100., -100., -50.],
                    [-100., -100., -100., 50.],
                    [100., -100., -100., -50.]])
print("Target:\n", target)
print("Model out:\n", out)
print("Loss:\n", loss(out, target))
print("Onehot loss:\n", loss.forward_onehot(out, target))
assert torch.allclose(loss(out, target), F.cross_entropy(out, target))
assert torch.allclose(loss.forward_onehot(out, target), F.cross_entropy(out, target))

Target:
 tensor([1, 0, 0])
Model out:
 tensor([[-100.,  100., -100.,  -50.],
        [-100., -100., -100.,   50.],
        [ 100., -100., -100.,  -50.]])
Loss:
 tensor(50.)
Onehot loss:
 tensor(50.)


In [6]:
torch.manual_seed(0)  # make the random maps (and the printed loss) reproducible
target = torch.randint(3, (2, 5, 5))
pred_labels = torch.randint(3, (2, 5, 5))
# One-hot "logits" with margin 100: per-pixel CE is ~0 where the predicted label
# matches the target and ~100 elsewhere, so the mean is ~100 x mismatch fraction.
out = onehot(pred_labels, 3).float() * 100
print("Target:\n", target.size())
print("Model out:\n", out.size())
print("Loss:\n", loss(out, target))
print("Onehot loss:\n", loss.forward_onehot(out, target))
assert torch.allclose(loss(out, target), F.cross_entropy(out, target))
assert torch.allclose(loss.forward_onehot(out, target), F.cross_entropy(out, target))

Target:
 torch.Size([2, 5, 5])
Model out:
 torch.Size([2, 3, 5, 5])
Loss:
 tensor(82.)
Onehot loss:
 tensor(82.)


In [7]:
torch.manual_seed(0)
target = torch.randint(3, (2, 5, 5))
# Perfect one-hot "logits" (margin 100) at every pixel: the loss should be ~0.
out = onehot(target, 3).float() * 100
print("Target:\n", target.size())
print("Model out:\n", out.size())
print("Loss:\n", loss(out, target))
print("Onehot loss:\n", loss.forward_onehot(out, target))
assert torch.allclose(loss(out, target), F.cross_entropy(out, target))
assert torch.allclose(loss.forward_onehot(out, target), F.cross_entropy(out, target))

Target:
 torch.Size([2, 5, 5])
Model out:
 torch.Size([2, 3, 5, 5])
Loss:
 tensor(0.)
Onehot loss:
 tensor(0.)


In [10]:
target = torch.randint(3, (2, 5, 5)).long()
out = onehot(target, 3).float() * 100
%timeit -n 1000 loss.forward(out, target)

2.47 ms ± 159 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [11]:
%timeit -n 1000 loss.forward_onehot(out, target)

5.33 ms ± 1.74 ms per loop (mean ± std. dev. of 7 runs, 1000 loops each)


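The one-hot implementation is roughly 2× slower than the permute-view implementation: about 5.3 ms vs 2.5 ms per call, although the one-hot timing is noisy (±1.7 ms). The permute-view implementation (`forward`) will therefore be used.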