In [1]:
## Takes in a number of labels and predictions and computes the focal loss
## Inputs: 
#  - labels as N x 1 integer categories
#  - logits as N x num_of_classes represting class probabilities

In [2]:
import torch

In [13]:
## Test Inputs

Labels = torch.Tensor([1,0,4])
Logits_pre_softmax  = torch.Tensor([[-0.0431,  2.2473, -2.6419,  0.0818,  0.2348], 
                                    [-0.7288, -0.0937, -0.6486,  0.6135,  0.6113], 
                                    [-1.3231,  1.1315, -0.9422,  0.5064, -1.3047]])
Logits = torch.nn.functional.softmax(Logits_pre_softmax, dim=1)
Logits

tensor([[0.0746, 0.7369, 0.0055, 0.0845, 0.0985],
        [0.0861, 0.1624, 0.0933, 0.3295, 0.3288],
        [0.0468, 0.5452, 0.0685, 0.2918, 0.0477]])

In [64]:
## Current Version of Function
def focal_loss_softmax(labels, logits, alpha=0.5, gamma=2):
    """
     github.com/tensorflow/models/blob/master/\
         research/object_detection/core/losses.py
     Computer focal loss for binary classification
     Args:
       labels: A int32 tensor of shape [batch_size]. N x 1
       logits: A float32 tensor of shape [batch_size]. N x C
       alpha: A scalar for focal loss alpha hyper-parameter.
       If positive samples number > negtive samples number,
       alpha < 0.5 and vice versa.
       gamma: A scalar for focal loss gamma hyper-parameter.
     Returns:
       A tensor of the same shape as `labels`
     """
    # print("from focal_loss_softmax")
    # probs = softmax(logits, dim=1)
    probs = logits
    labels = torch.nn.functional.one_hot(labels.squeeze().long(), num_classes=probs.shape[1])
    
    print(f"logits shape: {logits.shape}")
    print("logits(softmax):")
    print(logits)
    print()
    
    print(f"labels shape: {labels.shape}")
    print("labels:")
    print(labels)
    print()

    alphas = 1/labels.sum(dim=0)
    print(f"alphas_pre_correction: {alphas}")
#     alphas[torch.isinf(alphas)] = torch.sum(alphas[torch.logical_not(torch.isinf(alphas))])
    alphas[torch.isinf(alphas)] = 0
    print(f"alphas shape: {alphas.shape}")
    print(f"alphas_post_correction: {alphas}")
    print()

    modulating = torch.pow(1-probs, gamma)
    print(f"modulation(gamma={gamma}):")
    print(modulating)
    print()
    
    pt = torch.sum(probs * labels, dim=1)
    focal_term = (1-pt)**gamma
    print("pt:")
    print(pt)
    print("focal_term:")
    print(focal_term)
    print()
    
    neg_log = -torch.log(probs)  # == log_p
    print("neg_log:")
    print(neg_log)
    print("cross_entropy")
    print(neg_log*labels)
    loss = torch.sum(neg_log*labels*alphas, dim=1)
    print("loss")
    print(loss)
    print()
    
    correction = focal_term*loss
    print("correction:")
    print(correction)
    print()
    
    focal_cross_entropy = alphas * modulating * neg_log
    # print(f"focal_cross_entropy.shape: {focal_cross_entropy.shape}")
    # print(f"focal_cross_entropy: {focal_cross_entropy}")
    return focal_cross_entropy

res1 = focal_loss_softmax(Labels, Logits, gamma=2)
print("Current Version Result, res1:")
print(res1)
tensor([0.0211, 2.0485, 2.7595])

Current Version Result, res1:
tensor([0.0211, 2.0485, 2.7595])


In [59]:
## Quick Check on if softmax and torch.log gives same as log_softmax
A = torch.log(torch.nn.functional.softmax(Logits_pre_softmax, dim=1))
B = torch.nn.functional.log_softmax(Logits_pre_softmax, dim=1)
print("A:")
print(A)
print("B:")
print(B)

A:
tensor([[-2.5958, -0.3054, -5.1946, -2.4708, -2.3179],
        [-2.4526, -1.8175, -2.3724, -1.1103, -1.1125],
        [-3.0613, -0.6067, -2.6804, -1.2318, -3.0429]])
B:
tensor([[-2.5958, -0.3054, -5.1946, -2.4708, -2.3178],
        [-2.4526, -1.8175, -2.3724, -1.1103, -1.1125],
        [-3.0613, -0.6067, -2.6804, -1.2318, -3.0429]])


In [51]:
##  Running the stock focal_loss
from typing import Optional, Sequence
from torch import Tensor
from torch import nn
from torch.nn import functional as F


class FocalLoss(nn.Module):
    """ Focal Loss, as described in https://arxiv.org/abs/1708.02002.
    It is essentially an enhancement to cross entropy loss and is
    useful for classification tasks when there is a large class imbalance.
    x is expected to contain raw, unnormalized scores for each class.
    y is expected to contain class labels.
    Shape:
        - x: (batch_size, C) or (batch_size, C, d1, d2, ..., dK), K > 0.
        - y: (batch_size,) or (batch_size, d1, d2, ..., dK), K > 0.
    """

    def __init__(self,
                 alpha: Optional[Tensor] = None,
                 gamma: float = 0.,
                 reduction: str = 'mean',
                 ignore_index: int = -100):
        """Constructor.
        Args:
            alpha (Tensor, optional): Weights for each class. Defaults to None.
            gamma (float, optional): A constant, as described in the paper.
                Defaults to 0.
            reduction (str, optional): 'mean', 'sum' or 'none'.
                Defaults to 'mean'.
            ignore_index (int, optional): class label to ignore.
                Defaults to -100.
        """
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(
                'Reduction must be one of: "mean", "sum", "none".')

        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.reduction = reduction

        self.nll_loss = nn.NLLLoss(
            weight=alpha, reduction='none', ignore_index=ignore_index)

    def __repr__(self):
        arg_keys = ['alpha', 'gamma', 'ignore_index', 'reduction']
        arg_vals = [self.__dict__[k] for k in arg_keys]
        arg_strs = [f'{k}={v!r}' for k, v in zip(arg_keys, arg_vals)]
        arg_str = ', '.join(arg_strs)
        return f'{type(self).__name__}({arg_str})'

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        if x.ndim > 2:
            # (N, C, d1, d2, ..., dK) --> (N * d1 * ... * dK, C)
            c = x.shape[1]
            x = x.permute(0, *range(2, x.ndim), 1).reshape(-1, c)
            # (N, d1, d2, ..., dK) --> (N * d1 * ... * dK,)
            y = y.view(-1)

        unignored_mask = y != self.ignore_index
        y = y[unignored_mask]
        if len(y) == 0:
            return torch.tensor(0.)
        x = x[unignored_mask]

        # compute weighted cross entropy term: -alpha * log(pt)
        # (alpha is already part of self.nll_loss)
        
        log_p = F.log_softmax(x, dim=-1)
        print("log_p")
        print(log_p)
        print("y")
        print(y)
        y = y.type(torch.LongTensor)
        ce = self.nll_loss(log_p, y)

        # get true class column from each row
        all_rows = torch.arange(len(x))
        log_pt = log_p[all_rows, y]

        # compute focal term: (1 - pt)^gamma
        pt = log_pt.exp()
        focal_term = (1 - pt)**self.gamma

        # the full loss: -alpha * ((1 - pt)^gamma) * log(pt)
        loss = focal_term * ce

        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()

        return loss


def focal_loss(alpha: Optional[Sequence] = None,
               gamma: float = 0.,
               reduction: str = 'mean',
               ignore_index: int = -100,
               device='cpu',
               dtype=torch.float32) -> FocalLoss:
    """Factory function for FocalLoss.
    Args:
        alpha (Sequence, optional): Weights for each class. Will be converted
            to a Tensor if not None. Defaults to None.
        gamma (float, optional): A constant, as described in the paper.
            Defaults to 0.
        reduction (str, optional): 'mean', 'sum' or 'none'.
            Defaults to 'mean'.
        ignore_index (int, optional): class label to ignore.
            Defaults to -100.
        device (str, optional): Device to move alpha to. Defaults to 'cpu'.
        dtype (torch.dtype, optional): dtype to cast alpha to.
            Defaults to torch.float32.
    Returns:
        A FocalLoss object
    """
    if alpha is not None:
        if not isinstance(alpha, Tensor):
            alpha = torch.tensor(alpha)
        alpha = alpha.to(device=device, dtype=dtype)

    fl = FocalLoss(
        alpha=alpha,
        gamma=gamma,
        reduction=reduction,
        ignore_index=ignore_index)
    return fl


In [53]:
labels = torch.nn.functional.one_hot(Labels.squeeze().long(), num_classes=Logits.shape[1])
print("labels")
print(labels)
print()

alphas = 1/labels.sum(dim=0)
alphas[torch.isinf(alphas)] = torch.sum(alphas[torch.logical_not(torch.isinf(alphas))])
print("alphas")
print(alphas)
print()

FL = focal_loss(alpha=alphas, gamma=2, reduction='none')
FL(Logits_pre_softmax, Labels)


labels
tensor([[0, 1, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 0, 0, 0, 1]])

alphas
tensor([1., 1., 3., 3., 1.])

log_p
tensor([[-2.5958, -0.3054, -5.1946, -2.4708, -2.3178],
        [-2.4526, -1.8175, -2.3724, -1.1103, -1.1125],
        [-3.0613, -0.6067, -2.6804, -1.2318, -3.0429]])
y
tensor([1., 0., 4.])


tensor([0.0211, 2.0485, 2.7595])

In [46]:
# test_nll_loss = torch.nn.NLLLoss(
#             weight=alphas, reduction='none')
test_nll_loss = torch.nn.NLLLoss(
            reduction='none')

In [48]:
ax = torch.Tensor([[-2.5958, -0.3054, -5.1946, -2.4708, -2.3178],
        [-2.4526, -1.8175, -2.3724, -1.1103, -1.1125],
        [-3.0613, -0.6067, -2.6804, -1.2318, -3.0429]])
by = torch.Tensor([1, 0, 4])
by = by.type(torch.LongTensor)
test_nll_loss(ax, by)

tensor([0.3054, 2.4526, 3.0429])

In [55]:
## Test Ru's version
def focal_loss_sigmoid(labels, logits, alpha=0.5, gamma=2):
    """
     github.com/tensorflow/models/blob/master/\
         research/object_detection/core/losses.py
     Computer focal loss for binary classification
     Args:
       labels: A int32 tensor of shape [batch_size]. N x 1
       logits: A float32 tensor of shape [batch_size]. N x C
       alpha: A scalar for focal loss alpha hyper-parameter.
       If positive samples number > negtive samples number,
       alpha < 0.5 and vice versa.
       gamma: A scalar for focal loss gamma hyper-parameter.
     Returns:
       A tensor of the same shape as `labels`
     """
     
    prob = logits.sigmoid()    #### [CORRECTION] Should be softmax since multi-class classification; assuming the focal cross works this way
    print("prob(sigmoids):")
    print(prob)
    print()
    
    labels = torch.nn.functional.one_hot(labels.squeeze().long(), num_classes=prob.shape[1])

    cross_ent = torch.clamp(logits, min=0) - logits * labels + torch.log(1+torch.exp(-torch.abs(logits)))
    print("cross_ent:")
    print(cross_ent)
    print()
    
    prob_t = (labels*prob) + (1-labels) * (1-prob)
    print("prob_t:")
    print(prob_t)
    print()
    
    modulating = torch.pow(1-prob_t, gamma)
    print("modulating:")
    print(modulating)
    print()
    
    alpha_weight = (labels*alpha)+(1-labels)*(1-alpha)
    print("alpha_weight:")
    print(alpha_weight)
    print()

    focal_cross_entropy = modulating * alpha_weight * cross_ent
    print("focal_cross_entropy:")
    print(focal_cross_entropy)
    print()

    return focal_cross_entropy

In [57]:
focal_loss_sigmoid(Labels, Logits_pre_softmax)

prob(sigmoids):
tensor([[0.4892, 0.9044, 0.0665, 0.5204, 0.5584],
        [0.3255, 0.4766, 0.3433, 0.6487, 0.6482],
        [0.2103, 0.7561, 0.2805, 0.6240, 0.2134]])

cross_ent:
tensor([[0.6718, 0.1005, 0.0688, 0.7349, 0.8174],
        [1.1225, 0.6474, 0.4205, 1.0462, 1.0448],
        [0.2361, 1.4111, 0.3291, 0.9781, 1.5447]])

prob_t:
tensor([[0.5108, 0.9044, 0.9335, 0.4796, 0.4416],
        [0.3255, 0.5234, 0.6567, 0.3513, 0.3518],
        [0.7897, 0.2439, 0.7195, 0.3760, 0.2134]])

modulating:
tensor([[0.2393, 0.0091, 0.0044, 0.2709, 0.3118],
        [0.4550, 0.2271, 0.1179, 0.4209, 0.4202],
        [0.0442, 0.5717, 0.0787, 0.3893, 0.6188]])

alpha_weight:
tensor([[0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000]])

focal_cross_entropy:
tensor([[8.0399e-02, 4.5892e-04, 1.5209e-04, 9.9524e-02, 1.2746e-01],
        [2.5538e-01, 7.3525e-02, 2.4782e-02, 2.2016e-01, 2.1952e-01],
        [5.2212e-0

tensor([[8.0399e-02, 4.5892e-04, 1.5209e-04, 9.9524e-02, 1.2746e-01],
        [2.5538e-01, 7.3525e-02, 2.4782e-02, 2.2016e-01, 2.1952e-01],
        [5.2212e-03, 4.0336e-01, 1.2944e-02, 1.9039e-01, 4.7791e-01]])