In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Binary focal loss 

In [2]:
class BinaryFocalWithLogitsLoss(nn.Module):
    """Computes the focal loss with logits for binary data.

    The Focal Loss is designed to address the one-stage object detection scenario in
    which there is an extreme imbalance between foreground and background classes during
    training (e.g., 1:1000). Focal loss is defined as:

    FL = alpha(1 - p)^gamma * CE(p, y)
    where p are the probabilities, after applying the Softmax layer to the logits,
    alpha is a balancing parameter, gamma is the focusing parameter, and CE(p, y) is the
    cross entropy loss. When gamma=0 and alpha=1 the focal loss equals cross entropy.

    See: https://arxiv.org/abs/1708.02002

    Arguments:
        num_classes (int): number of classes in the classification problem
        gamma (float, optional): focusing parameter. Default: 2.
        alpha (float, optional): balancing parameter. Default: 0.25.
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
            'mean': the sum of the output will be divided by the number of
            elements in the output, 'sum': the output will be summed. Default: 'mean'
        eps (float, optional): small value to avoid division by zero. Default: 1e-6.

    """

    def __init__(self, gamma=2, alpha=0.25, reduction="mean"):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        if reduction.lower() == "none":
            self.reduction_op = None
        elif reduction.lower() == "mean":
            self.reduction_op = torch.mean
        elif reduction.lower() == "sum":
            self.reduction_op = torch.sum
        else:
            raise ValueError(
                "expected one of ('none', 'mean', 'sum'), got {}".format(reduction)
            )        
    
    def forward(self, input, target):
        if input.size() != target.size():
            raise ValueError(
                "size mismatch, {} != {}".format(input.size(), target.size())
            )
        elif target.unique(sorted=True).tolist() not in [[0, 1], [0], [1]]:
            raise ValueError("target values are not binary")
            
        input = input.view(-1)
        target = target.view(-1)
        
        # Following the paper: probabilities = probabilities if y=1; otherwise, probabilities = 1-probabilities
        probabilities = torch.sigmoid(input)
        probabilities = torch.where(target == 1, probabilities, 1 - probabilities)
        
        # Compute the loss
        focal = self.alpha * (1 - probabilities).pow(self.gamma)
        bce = nn.functional.binary_cross_entropy_with_logits(input, target, reduction="none")
        loss = focal * bce

        if self.reduction_op is not None:
            return self.reduction_op(loss)
        else:
            return loss
        
    def forward_heng(self, logits, labels):
        """https://www.kaggle.com/c/carvana-image-masking-challenge/discussion/39951"""
        probs  = torch.sigmoid(logits)

        w_pos = torch.pow((1-probs), self.gamma)
        w_neg = torch.pow((probs), self.gamma)
        weights = (labels==1).float()*w_pos + (labels==0).float()*w_neg

        inputs   = logits.view (-1)
        targets = labels.view(-1)
        weights   = weights.view (-1)

        loss = weights * inputs.clamp(min=0) - weights * inputs * targets + weights * torch.log(1 + torch.exp(-inputs.abs()))
        loss = loss.sum() / weights.sum()

        return loss
    
    def forward_adrien(self, input, target):
        """https://becominghuman.ai/investigating-focal-and-dice-loss-for-the-kaggle-2018-data-science-bowl-65fb9af4f36c"""
        # Inspired by the implementation of binary_cross_entropy_with_logits
        if not (target.size() == input.size()):
            raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))

        max_val = (-input).clamp(min=0)
        loss = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log()

        # This formula gives us the log sigmoid of 1-p if y is 0 and of p if y is 1
        invprobs = nn.functional.logsigmoid(-input * (target * 2 - 1))
        loss = (invprobs * self.gamma).exp() * loss
        
        return loss.mean()

The following losses should match very closely:

In [3]:
loss = BinaryFocalWithLogitsLoss(alpha=1)

target = torch.Tensor([1])
out = torch.Tensor([2.2])
print("Target:\n", target)
print("Model out:\n", out)
print("BF Loss x100:\n", loss.forward(out, target) * 100)
print("BF Loss Heng x100:\n", loss.forward_heng(out, target) * 100)
print("Loss Adrien x100:\n", loss.forward_adrien(out, target) * 100)
print("BCE Loss:\n", nn.functional.binary_cross_entropy_with_logits(out, target))

Target:
 tensor([1.])
Model out:
 tensor([2.2000])
BF Loss x100:
 tensor(0.1046)
BF Loss Heng x100:
 tensor(10.5083)
Loss Adrien x100:
 tensor(0.1046)
BCE Loss:
 tensor(0.1051)


In [4]:
target = torch.Tensor([1])
out = torch.Tensor([3.43])
print("Target:\n", target)
print("Model out:\n", out)
print("BF Loss x1000:\n", loss.forward(out, target) * 1000)
print("BF Loss Heng x1000:\n", loss.forward_heng(out, target) * 1000)
print("Loss Adrien x1000:\n", loss.forward_adrien(out, target) * 1000)
print("BCE Loss:\n", nn.functional.binary_cross_entropy_with_logits(out, target))

Target:
 tensor([1.])
Model out:
 tensor([3.4300])
BF Loss x1000:
 tensor(0.0314)
BF Loss Heng x1000:
 tensor(31.8735)
Loss Adrien x1000:
 tensor(0.0314)
BCE Loss:
 tensor(0.0319)


Some more tests

In [5]:
target = torch.Tensor([1, 0])
out = torch.Tensor([100, -50])
print("Target:\n", target)
print("Model out:\n", out)
print("BF Loss:\n", loss.forward(out, target))
print("BF Loss Heng:\n", loss.forward_heng(out, target))
print("Loss Adrien:\n", loss.forward_adrien(out, target))
print("BCE Loss:\n", nn.functional.binary_cross_entropy_with_logits(out, target))

Target:
 tensor([1., 0.])
Model out:
 tensor([100., -50.])
BF Loss:
 tensor(0.)
BF Loss Heng:
 tensor(0.)
Loss Adrien:
 tensor(0.)
BCE Loss:
 tensor(0.)


In [6]:
target = torch.Tensor([1, 0, 0, 0, 1])
out = torch.Tensor([-5, -2.5, -6, -10, -2])
print("Target:\n", target)
print("Model out:\n", out)
print("BF Loss:\n", loss.forward(out, target))
print("BF Loss Heng:\n", loss.forward_heng(out, target))
print("Loss Adrien:\n", loss.forward_adrien(out, target))
print("BCE Loss:\n", nn.functional.binary_cross_entropy_with_logits(out, target))

Target:
 tensor([1., 0., 0., 0., 1.])
Model out:
 tensor([ -5.0000,  -2.5000,  -6.0000, -10.0000,  -2.0000])
BF Loss:
 tensor(1.3181)
BF Loss Heng:
 tensor(3.7272)
Loss Adrien:
 tensor(1.3181)
BCE Loss:
 tensor(1.4430)


In [7]:
target = torch.randint(2, (2, 5, 5))
out = torch.randint(2, (2, 5, 5)) * 3.44
print("Target:\n", target.size())
print("Model out:\n", out.size())
print("BF Loss:\n", loss.forward(out, target))
print("BF Loss Heng:\n", loss.forward_heng(out, target))
print("Loss Adrien:\n", loss.forward_adrien(out, target))
print("BCE Loss:\n", nn.functional.binary_cross_entropy_with_logits(out, target))

Target:
 torch.Size([2, 5, 5])
Model out:
 torch.Size([2, 5, 5])
BF Loss:
 tensor(1.1296)
BF Loss Heng:
 tensor(2.6541)
Loss Adrien:
 tensor(1.1296)
BCE Loss:
 tensor(1.4632)


In [8]:
target = torch.randint(2, (2, 5, 5)).float()
out = (target * 100) - 50
print("Target:\n", target.size())
print("Model out:\n", out.size())
print("BF Loss:\n", loss.forward(out, target))
print("BF Loss Heng:\n", loss.forward_heng(out, target))
print("Loss Adrien:\n", loss.forward_adrien(out, target))
print("BCE Loss:\n", nn.functional.binary_cross_entropy_with_logits(out, target))

Target:
 torch.Size([2, 5, 5])
Model out:
 torch.Size([2, 5, 5])
BF Loss:
 tensor(0.)
BF Loss Heng:
 tensor(0.)
Loss Adrien:
 tensor(0.)
BCE Loss:
 tensor(0.)


In [9]:
target = torch.randint(2, (5, 2048, 2048))
out = target * 100
%timeit loss.forward(out, target)

1.4 s ± 254 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Multi-class focal loss

In [10]:
class FocalWithLogitsLoss(nn.Module):
    """Computes the focal loss with logits.

    The Focal Loss is designed to address the one-stage object detection scenario in
    which there is an extreme imbalance between foreground and background classes during
    training (e.g., 1:1000). Focal loss is defined as:

    FL = alpha(1 - p)^gamma * CE(p, y)
    where p are the probabilities, after applying the Softmax layer to the logits,
    alpha is a balancing parameter, gamma is the focusing parameter, and CE(p, y) is the
    cross entropy loss. When gamma=0 and alpha=1 the focal loss equals cross entropy.

    See: https://arxiv.org/abs/1708.02002

    Arguments:
        gamma (float, optional): focusing parameter. Default: 2.
        alpha (float, optional): balancing parameter. Default: 0.25.
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
            'mean': the sum of the output will be divided by the number of
            elements in the output, 'sum': the output will be summed. Default: 'mean'
        eps (float, optional): small value to avoid division by zero. Default: 1e-6.

    """

    def __init__(self, gamma=2, alpha=0.25, reduction="mean"):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        if reduction.lower() == "none":
            self.reduction_op = None
        elif reduction.lower() == "mean":
            self.reduction_op = torch.mean
        elif reduction.lower() == "sum":
            self.reduction_op = torch.sum
        else:
            raise ValueError(
                "expected one of ('none', 'mean', 'sum'), got {}".format(reduction)
            )

    def forward(self, input, target):
        if input.dim() == 4:
            input = input.permute(0, 2, 3, 1)
            input = input.contiguous().view(-1, input.size(-1))
        elif input.dim() != 2:
            raise ValueError(
                "expected input of size 4 or 2, got {}".format(input.dim())
            )

        if target.dim() == 3:
            target = target.contiguous().view(-1)
        elif target.dim() != 1:
            raise ValueError(
                "expected target of size 3 or 1, got {}".format(target.dim())
            )

        if target.dim() != input.dim() - 1:
            raise ValueError(
                "expected target dimension {} for input dimension {}, got {}".format(
                    input.dim() - 1, input.dim(), target.dim()
                )
            )

        m = input.size(0)
        probabilities = nn.functional.softmax(input[range(m), target], dim=0)
        focal = self.alpha * (1 - probabilities).pow(self.gamma)
        ce = nn.functional.cross_entropy(
            input, target, reduction="none"
        )
        loss = focal * ce

        if self.reduction_op is not None:
            return self.reduction_op(loss)
        else:
            return loss
        
    def forward_onehot(self, input, target):
        if input.dim() != 2 and input.dim() != 4:
            raise ValueError("expected input of size 4 or 2, got {}".format(input.dim()))
            
        if target.dim() != 1 and target.dim() != 3:
            raise ValueError("expected target of size 3 or 1, got {}".format(target.dim()))
            
        target_onehot = to_onehot(target, input.size(1))
            
        m = input.size(0)
        probabilities = torch.sum(target_onehot * F.softmax(input, dim=0), dim=1)
        focal = self.alpha * (1 - probabilities).pow(self.gamma)
        ce = F.cross_entropy(input, target, reduction="none")
        loss = focal * ce
        
        if self.reduction_op is not None:
            return self.reduction_op(loss)
        else:
            return loss

In [11]:
def to_onehot(tensor, num_classes):    
    tensor = tensor.unsqueeze(1)
    onehot = torch.zeros(tensor.size(0), num_classes, *tensor.size()[2:])
    onehot.scatter_(1, tensor, 1)
    
    return onehot

In [12]:
loss = FocalWithLogitsLoss()

target = torch.Tensor([1]).long()
out = torch.Tensor([[-100, 100, -100, -50]]).float()
print("Target:\n", target)
print("Model out:\n", out)
print("Loss:\n", loss.forward(out, target))
print("Onehot loss:\n", loss.forward_onehot(out, target))

Target:
 tensor([1])
Model out:
 tensor([[-100.,  100., -100.,  -50.]])
Loss:
 tensor(0.)
Onehot loss:
 tensor(0.)


In [13]:
target = torch.Tensor([1, 0, 0]).long()
out = torch.Tensor([[-100, 100, -100, -50], [-100, -100, -100, 50], [100, -100, -100, -50]]).float()
print("Target:\n", target)
print("Model out:\n", out)
print("Loss:\n", loss.forward(out, target))
print("Onehot loss:\n", loss.forward_onehot(out, target))

Target:
 tensor([1, 0, 0])
Model out:
 tensor([[-100.,  100., -100.,  -50.],
        [-100., -100., -100.,   50.],
        [ 100., -100., -100.,  -50.]])
Loss:
 tensor(12.5000)
Onehot loss:
 tensor(12.5000)


In [14]:
target = torch.randint(3, (2, 5, 5)).long()
out = torch.randint(3, (2, 5, 5)).long()
out = to_onehot(out, 3).float() * 100
print("Target:\n", target.size())
print("Model out:\n", out.size())
print("Loss:\n", loss.forward(out, target))
print("Onehot loss:\n", loss.forward_onehot(out, target))

Target:
 torch.Size([2, 5, 5])
Model out:
 torch.Size([2, 3, 5, 5])
Loss:
 tensor(15.)
Onehot loss:
 tensor(6.3750)


In [15]:
target = torch.randint(3, (2, 5, 5)).long()
out = to_onehot(target, 3).float() * 100
print("Target:\n", target.size())
print("Model out:\n", out.size())
print("Loss:\n", loss.forward(out, target))
print("Onehot loss:\n", loss.forward_onehot(out, target))

Target:
 torch.Size([2, 5, 5])
Model out:
 torch.Size([2, 3, 5, 5])
Loss:
 tensor(0.)
Onehot loss:
 tensor(0.)


In [16]:
target = torch.randint(3, (5, 2048, 2048)).long()
out = to_onehot(target, 3).float() * 100
%timeit loss.forward(out, target)

3.32 s ± 144 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [17]:
%timeit loss.forward_onehot(out, target)

5.44 s ± 58.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


The one-hot solution is almost 2 times slower than the permute-view solution. Will use permute-view solution.