# Implementation of Focal Loss.

In [2]:
from importlib.util import find_spec
if find_spec("model") is None:
    import sys
    sys.path.append('..')

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [4]:
from model.matcher import Matcher
from model.box_regression import Box2BoxTransform

## Focal Loss.

paper: [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002)

Section 3. Focal Loss -> 3.2. Focal Loss Definition.

In [5]:
def sigmoid_focal_loss(inputs: torch.Tensor, targets: torch.Tensor, gamma: float = 2, alpha: float = .25, reduction: str = "none") -> torch.Tensor:
    """Element-wise sigmoid focal loss computed from raw logits.

    Implements FL(p_t) = -alpha_t * (1 - p_t) ** gamma * log(p_t), where
    p_t is the predicted probability of the ground-truth class
    (Focal Loss for Dense Object Detection, Section 3.2).

    Args:
        inputs: Raw (pre-sigmoid) logits.
        targets: Binary labels, expected to have the same shape as ``inputs``
            (1 for positive, 0 for negative). Integer and bool tensors are
            converted to ``inputs.dtype``; floating-point tensors are used
            as they are.
        gamma: Exponent of the modulating factor ``(1 - p_t)``; larger values
            down-weight well-classified examples more strongly.
        alpha: Weight applied to positive targets (negatives get
            ``1 - alpha``). A negative value disables the class weighting.
        reduction: ``"mean"`` or ``"sum"`` reduce the loss to a scalar; any
            other value (including the default ``"none"``) returns the
            per-element loss unchanged.

    Returns:
        The per-element loss tensor, or a scalar tensor when ``reduction`` is
        ``"mean"`` or ``"sum"``.
    """
    # The arithmetic below is float-only (BCE-with-logits and `1 - targets`),
    # so hard 0/1 labels stored as integer or bool tensors are promoted first.
    # Floating-point targets are left untouched to keep existing behaviour.
    if not targets.is_floating_point():
        targets = targets.to(dtype=inputs.dtype)

    p = torch.sigmoid(inputs)
    # Computed from the logits rather than from `p`; this term equals -log(p_t).
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # Probability assigned to the ground-truth class of each element.
    p_t = p * targets + (1 - p) * (1 - targets)

    # Modulating factor: confident, correct predictions contribute little.
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        # Class-balancing weight: alpha for positives, (1 - alpha) for negatives.
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()

    return loss

In [None]:
class FocalLoss(nn.Module):
    """Module wrapper that stores the focal-loss hyperparameters.

    The forward pass is not implemented yet; see ``forward``.

    Args:
        gamma: Focusing parameter stored on the instance as ``self.gamma``.
        alpha: Class-balancing weight stored on the instance as ``self.alpha``.
    """

    def __init__(self, gamma=2, alpha=.25):
        # nn.Module keeps its parameter/buffer/hook registries on the instance;
        # they are only created by Module.__init__, so it must run before the
        # module can be called, moved to a device, or inspected.
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, anchors, encoded_labels):
        """Placeholder for the loss computation; currently returns ``None``.

        TODO: implement. Presumably this should combine ``anchors`` and
        ``encoded_labels`` into a focal loss using ``self.gamma`` and
        ``self.alpha`` -- confirm the intended inputs before filling this in.
        """
        pass
        