In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [2]:
class FocalLoss2d(nn.Module):
    def __init__(self, gamma=0, weight=1, reduction="mean"):
        super().__init__()
        self.gamma = gamma
        self.weight = weight
        if reduction.lower() == "none":
            self.reduction_op = None
        elif reduction.lower() == "mean":
            self.reduction_op = torch.mean
        elif reduction.lower() == "sum":
            self.reduction_op = torch.sum
        else:
            raise ValueError("expected one of ('none', 'mean', 'sum'), got {}".format(reduction))
        
    def forward(self, input, target):
        if input.dim() == 4:
            input = input.contiguous().view(-1, input.size(1))
        elif input.dim() != 2:
            raise ValueError("expected input of size 4 or 2, got {}".format(input.size()))
            
        if target.dim() == 3:
            target = target.contiguous().view(-1)
        elif target.dim() != 1:
            raise ValueError("expected target of size 3 or 1, got {}".format(target.size()))
            
        m = input.size(0)
        probabilities = F.softmax(input[range(m), target], dim=0)
        focal = self.weight * (1 - probabilities).pow(self.gamma)
        ce = F.cross_entropy(input, target, reduction="none")
        loss = focal * ce
        
        if self.reduction_op is not None:
            return self.reduction_op(loss)
        else:
            return loss    

In [50]:
def onehot(tensor, num_classes):
    tensor = tensor.unsqueeze(1)
    print(tensor.size())
    onehot = torch.zeros(tensor.size(0), num_classes, tensor.size(2), tensor.size(3))
    onehot.scatter_(1, tensor, 1)
    
    return onehot

In [4]:
loss = FocalLoss2d()

target = torch.Tensor([1]).long()
out = torch.Tensor([[-100, 100, -100, -50]]).float()
print("Target:\n", target)
print("Model out:\n", out)
print("Loss:\n", loss.forward(out, target))

Target:
 tensor([1])
Model out:
 tensor([[-100.,  100., -100.,  -50.]])
Loss:
 tensor(0.)


In [5]:
target = torch.Tensor([1, 0, 0]).long()
out = torch.Tensor([[-100, 100, -100, -50], [-100, -100, -100, 50], [100, -100, -100, -50]]).float()
print("Target:\n", target)
print("Model out:\n", out)
print("Loss:\n", loss.forward(out, target))

Target:
 tensor([1, 0, 0])
Model out:
 tensor([[-100.,  100., -100.,  -50.],
        [-100., -100., -100.,   50.],
        [ 100., -100., -100.,  -50.]])
Loss:
 tensor(50.)


In [51]:
target = torch.randint(3, (2, 5, 5)).long()
out = torch.randint(3, (2, 5, 5)).long()
out = onehot(out, 3).float()
print("Target:\n", target[0])
print("Model out:\n", out[0, :, :, :])
print("Loss:\n", loss.forward(out, target))

torch.Size([2, 1, 5, 5])
Target:
 tensor([[2, 1, 1, 2, 1],
        [2, 2, 2, 1, 2],
        [2, 2, 1, 2, 1],
        [1, 0, 0, 0, 2],
        [2, 2, 0, 0, 1]])
Model out:
 tensor([[[0., 0., 0., 0., 0.],
         [1., 1., 1., 0., 1.],
         [0., 1., 0., 1., 0.],
         [1., 1., 0., 1., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 1., 1., 0.],
         [0., 0., 0., 1., 0.],
         [0., 0., 1., 0., 1.],
         [0., 0., 1., 0., 0.],
         [0., 0., 1., 0., 1.]],

        [[1., 1., 0., 0., 1.],
         [0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 1.],
         [1., 1., 0., 1., 0.]]])
Loss:
 tensor(1.1430)


In [59]:
target = torch.randint(3, (2, 5, 5)).long()
out = onehot(target, 3).float()
assert torch.eq(torch.max(out, 1)[1], target)
print("Target:\n", target[0])
print("Model out:\n", out[0, :, :, :])
print("Loss:\n", loss.forward(out, target))

torch.Size([2, 1, 5, 5])


RuntimeError: bool value of Tensor with more than one value is ambiguous

In [62]:
torch.eq(torch.max(out, 1)[1], target).all()

tensor(1, dtype=torch.uint8)

In [55]:
b

tensor([[[0, 1, 2, 2, 2],
         [1, 2, 0, 1, 1],
         [1, 1, 1, 1, 2],
         [1, 2, 1, 1, 1],
         [2, 2, 1, 1, 1]],

        [[2, 1, 0, 2, 0],
         [0, 2, 0, 1, 2],
         [2, 2, 1, 2, 2],
         [0, 0, 1, 1, 2],
         [1, 0, 2, 2, 1]]])