# Focal Loss

In [2]:
import torch
import torch.nn as nn 
import torch.nn.functional as F 

In [53]:
class FocalLoss(nn.Module):
    def __init__(self, gamma=2., alpha=0.25, size_average=True):
        super(FocalLoss, self).__init__()

        self.gamma = gamma 
        self.alpha = alpha 
        self.size_average = size_average

        if isinstance(alpha, (float, int)):
            self.alpha = torch.Tensor([alpha, 1. - alpha])
        elif isinstance(alpha, list):
            self.alpha = torch.Tensor(alpha)

    def forward(self, input, target):
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1) # N, C, H, W => N, C, H * W
            input = input.transpose(1, 2)
            input = input.contiguous().view(-1, input.size(2)) # N, C, H * W -> N * H * W, C
            
        target = target.view(-1, 1)
        
        logpt = F.log_softmax(input, dim=1)
        logpt = logpt.gather(1, target)
        logpt = logpt.view(-1)
        pt = logpt.exp()
        
        if self.alpha is not None:
            if self.alpha.type() != input.data.type():
                self.alpha = self.alpha.type_as(input.data)
            
            at = self.alpha.gather(0, target.data.view(-1))
            logpt = logpt * at
            
        loss = -1 * pow(1-pt, self.gamma) * logpt
        
        if self.size_average:
            return loss.mean()
        
        return loss.sum()





# Focal Test

In [54]:
import torch
import torch.nn as nn 
import torch.optim as optim
import torch.nn.functional as F 

import os, sys, random, time

In [55]:
start_time = time.time()
max_err = 0

x = torch.rand(128000, 2) * random.randint(1, 10)
#x = x.cuda()

target = torch.rand(128000).ge(0.1).long()
#target = target.cuda()

output0 = FocalLoss()(x, target)
output1 = nn.CrossEntropyLoss()(x, target)
output2 = nn.BCEWithLogitsLoss()(x, torch.rand(x.shape).random_(2))


print("Focal Loss: {}".format(output0))
print("CE Loss: {}".format(output1))
print("BCE Loss: {}".format(output2))

RuntimeError: The size of tensor a (256000) must match the size of tensor b (128000) at non-singleton dimension 0

# RetinaNet

In [1]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
import numpy as np 

In [38]:
class ClassificationModel(nn.Module):
    def __init__(self, features_in, features_out=256, num_anchors=9, num_classes=80):
        super(ClassificationModel, self).__init__()

        self.num_anchors = num_anchors
        self.num_classes = num_classes

        self.conv1 = nn.Conv2d(features_in, features_out, kernel_size=3, padding=1)
        self.act1 = nn.ReLU()

        self.conv2 = nn.Conv2d(features_out, features_out, kernel_size=3, padding=1)
        self.act2 = nn.ReLU()

        self.conv3 = nn.Conv2d(features_out, features_out, kernel_size=3, padding=1)
        self.act3 = nn.ReLU()

        self.conv4 = nn.Conv2d(features_out, features_out, kernel_size=3, padding=1)
        self.act4 = nn.ReLU()

        self.output = nn.Conv2d(features_out, num_anchors * num_classes, kernel_size=3, padding=1)
        self.output_act = nn.Sigmoid()

    def forward(self, x):
        out = self.conv1(x)
        print("layer 1: shape {}".format(out.shape))
        out = self.act1(out)

        out = self.conv2(out)
        print("layer 2: shape {}".format(out.shape))
        out = self.act2(out)

        out = self.conv3(out)
        print("layer 3: shape {}".format(out.shape))
        out = self.act3(out)

        out = self.conv4(out)
        print("layer 4: shape {}".format(out.shape))
        out = self.act4(out)

        out = self.output(out)
        print("layer output: shape {}".format(out.shape))
        out = self.output_act(out)

        # out is B x C x W x H, with C = n_classes * n_anchors
        print(out.shape)
        out1 = out.permute(0, 2, 3, 1) # B, C, W, H => B, W, H, C

        batch_size, width, height, channels = out1.shape

        out2 = out1.view(batch_size, width, height, self.num_anchors, self.num_classes) # B, W, H, C => B, W, H, #anchors, #classes

        return out2.contiguous().view(x.shape[0], -1, self.num_classes) # B, W, H, #anchors, #classes => B, W * H * #anchors, #classes


In [35]:
x = torch.randn(1, 256, 256, 3) * 255

x = torch.clamp(x, 0, 255)

In [39]:
classification = ClassificationModel(features_in=256)(x)

layer 1: shape torch.Size([1, 256, 256, 3])
layer 2: shape torch.Size([1, 256, 256, 3])
layer 3: shape torch.Size([1, 256, 256, 3])
layer 4: shape torch.Size([1, 256, 256, 3])
layer output: shape torch.Size([1, 720, 256, 3])
torch.Size([1, 720, 256, 3])


In [37]:
classification

tensor([[[0.6675, 0.8112, 0.5319,  ..., 0.4915, 0.5656, 0.4535],
         [0.4240, 0.4811, 0.6332,  ..., 0.4638, 0.4593, 0.3847],
         [0.4679, 0.5647, 0.7307,  ..., 0.4766, 0.5034, 0.5730],
         ...,
         [0.4156, 0.6728, 0.4103,  ..., 0.7202, 0.5673, 0.5896],
         [0.4846, 0.7695, 0.6661,  ..., 0.1779, 0.2749, 0.2846],
         [0.6998, 0.4297, 0.5478,  ..., 0.3712, 0.7847, 0.5753]]],
       grad_fn=<ViewBackward>)

In [41]:
alpha = 0.25
gamma = 2.0

targets = torch.zeros(classification.shape)
targets = torch.where(torch.lt(classification, 0.4), targets, 1. - targets)
targets

tensor([[[0., 0., 1.,  ..., 1., 0., 1.],
         [0., 0., 1.,  ..., 1., 1., 0.],
         [0., 1., 1.,  ..., 1., 0., 0.],
         ...,
         [0., 1., 1.,  ..., 1., 0., 0.],
         [1., 0., 1.,  ..., 1., 1., 1.],
         [1., 0., 1.,  ..., 1., 0., 1.]]])

In [43]:
alpha_factor = torch.ones(targets.shape) * alpha
alpha_factor = torch.where(torch.eq(targets, 1.), alpha_factor, 1 - alpha_factor)
alpha_factor

tensor([[[0.7500, 0.7500, 0.2500,  ..., 0.2500, 0.7500, 0.2500],
         [0.7500, 0.7500, 0.2500,  ..., 0.2500, 0.2500, 0.7500],
         [0.7500, 0.2500, 0.2500,  ..., 0.2500, 0.7500, 0.7500],
         ...,
         [0.7500, 0.2500, 0.2500,  ..., 0.2500, 0.7500, 0.7500],
         [0.2500, 0.7500, 0.2500,  ..., 0.2500, 0.2500, 0.2500],
         [0.2500, 0.7500, 0.2500,  ..., 0.2500, 0.7500, 0.2500]]])

In [45]:
focal_weight = torch.where(torch.eq(targets, 1.), 1. - classification, classification)
focal_weight

tensor([[[0.3734, 0.2478, 0.3728,  ..., 0.4098, 0.1884, 0.3931],
         [0.1092, 0.3380, 0.2173,  ..., 0.3912, 0.5951, 0.3356],
         [0.3776, 0.3895, 0.4681,  ..., 0.2609, 0.2131, 0.2418],
         ...,
         [0.2737, 0.3216, 0.4888,  ..., 0.5191, 0.3035, 0.1805],
         [0.2249, 0.3027, 0.4130,  ..., 0.2277, 0.3505, 0.2988],
         [0.3613, 0.2720, 0.5198,  ..., 0.3241, 0.2999, 0.4858]]],
       grad_fn=<SWhereBackward>)

In [46]:
focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
focal_weight

tensor([[[0.1046, 0.0460, 0.0347,  ..., 0.0420, 0.0266, 0.0386],
         [0.0089, 0.0857, 0.0118,  ..., 0.0383, 0.0885, 0.0845],
         [0.1069, 0.0379, 0.0548,  ..., 0.0170, 0.0341, 0.0439],
         ...,
         [0.0562, 0.0259, 0.0597,  ..., 0.0674, 0.0691, 0.0244],
         [0.0126, 0.0687, 0.0426,  ..., 0.0130, 0.0307, 0.0223],
         [0.0326, 0.0555, 0.0675,  ..., 0.0263, 0.0674, 0.0590]]],
       grad_fn=<MulBackward0>)

In [47]:
bce = -(targets * torch.log(classification)) + (1. - targets)*torch.log(1. - classification)
bce

tensor([[[-0.4675, -0.2847,  0.4665,  ...,  0.5273, -0.2088,  0.4993],
         [-0.1156, -0.4125,  0.2450,  ...,  0.4962,  0.9040, -0.4088],
         [-0.4741,  0.4934,  0.6312,  ...,  0.3024, -0.2397, -0.2768],
         ...,
         [-0.3198,  0.3880,  0.6711,  ...,  0.7320, -0.3617, -0.1990],
         [ 0.2548, -0.3606,  0.5328,  ...,  0.2584,  0.4316,  0.3550],
         [ 0.4483, -0.3174,  0.7335,  ...,  0.3918, -0.3565,  0.6651]]],
       grad_fn=<AddBackward0>)

In [48]:
clss_loss = focal_weight * bce
clss_loss

tensor([[[-0.0489, -0.0131,  0.0162,  ...,  0.0221, -0.0056,  0.0193],
         [-0.0010, -0.0353,  0.0029,  ...,  0.0190,  0.0800, -0.0345],
         [-0.0507,  0.0187,  0.0346,  ...,  0.0051, -0.0082, -0.0121],
         ...,
         [-0.0180,  0.0100,  0.0401,  ...,  0.0493, -0.0250, -0.0049],
         [ 0.0032, -0.0248,  0.0227,  ...,  0.0033,  0.0133,  0.0079],
         [ 0.0146, -0.0176,  0.0495,  ...,  0.0103, -0.0240,  0.0392]]],
       grad_fn=<MulBackward0>)

In [None]:
# Another implementation
def focal_loss(y_true, y_pred):
    pt1 = torch.where(torch.eq(targets, 1.), )