# Putting All of It Together (Loss)

In [1]:
from importlib.util import find_spec
if find_spec("model") is None:
    import sys
    sys.path.append('..')

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
from model.matcher import Matcher
from model.box_regression import Box2BoxTransform
from model.loss import sigmoid_focal_loss, smooth_l1_loss
from utils.box_utils import pairwise_iou, cat_boxes

In [4]:
class RetinaLoss(nn.Module):
    def __init__(self,
                 num_classes=80,
                 focal_loss_alpha=0.25, 
                 focal_loss_gamma=2.0, 
                 smooth_l1_beta=0.1,
                 test_score_thresh=0.05,
                 test_topk_candidates=1000,
                 test_nms_thresh=0.5, 
                 max_detection_per_image=100
                ):
        super().__init__()
        self.anchor_matcher = Matcher([0.4, 0.5], [-1, 0, 1])
        self.box2box_transform = Box2BoxTransform([1., 1., 1., 1.])
        self.num_classes = num_classes
        
        # Loss params.
        self.focal_loss_alpha = focal_loss_alpha
        self.focal_loss_gamma = focal_loss_gamma
        self.smooth_l1_beta = smooth_l1_beta
        
        # Inference params.
        self.test_score_thresh = test_score_thresh
        self.test_topk_candidates = test_topk_candidates
        self.test_nms_thresh = test_nms_thresh
        self.max_detection_per_image = max_detection_per_image
        
        """
        In Detectron1, loss is normalized by number of foreground samples in the batch.
        When batch size is 1 per GPU, #foreground has a large variance and
        using it lead to lower performance. Here we maintain an EMA of #foreground to
        stabilize the normalizer.
        """
        self.loss_normalizer = 100  # initialize with any reasonable #fg that's not too small
        self.loss_normalizer_momentum = 0.9
        
    def forward(self, pred_logits, pred_anchor_deltas, anchors, boxes, labels):
        gt_labels, gt_boxes = self.label_anchors(anchors, boxes, labels)
        losses = self.losses(anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes)
        return losses 
    
    def label_anchors(self, anchors, boxes, labels):
        anchors = cat_boxes(anchors)
        gt_labels = []
        matched_gt_boxes = []
        bs = len(boxes)
        
        for i in range(bs):
            matched_quality_matrix = pairwise_iou(boxes[i], anchors)
            matched_idxs, anchor_labels = self.anchor_matcher(matched_quality_matrix)
            del matched_quality_matrix
            
            if len(boxes[i]) > 0:
                matched_gt_boxes_i = boxes[i][matched_idxs]
                gt_labels_i = labels[i][matched_idxs]
                # Label 0 means background. 
                gt_labels_i[anchor_labels == 0] = self.num_classes
                # Label -1 means ignore.
                gt_labels_i[anchor_labels == -1] = -1
            
            else:
                matched_gt_boxes_i = torch.zeros_like(anchors)
                gt_labels_i = torch.zeros_like(matched_idxs) + self.num_classes
                
            gt_labels.append(gt_labels_i)
            matched_gt_boxes.append(matched_gt_boxes_i)
        
        return gt_labels, matched_gt_boxes
    
    def losses(self, anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes):
        """
        Args: 
            anchors (list[Tensors]): a list of feature maps for each level.
            gt_labels, gt_boxes: see output of :meth: `self.label_anchors`.
                Their shapes are (N, R) and (N, R, 4), respectively, where R
                is the total number of anchors across all feature levels.
                i.e. sum(Hi x Wi x Ai)
            pred_logits, pred_anchor_deltas: both are list[Tensor]. Each
                element in the list corresponds to one level and has shape
                (N, Hi x Wi x Ai, K or 4) - where K is the number of classes.
            
        Returns:
            dict[str, Tensor]:
                mapping from named loss to a scalar tensor storing the
                loss for classification and bbox regression. Used during
                training only. The dict keys are: "loss_cls" and "loss_box_reg".
        """
        
        num_images = len(gt_labels)
        gt_labels = torch.stack(gt_labels).squeeze() # (N, R)
        anchors = cat_boxes(anchors)
        gt_anchor_deltas = [self.box2box_transform.get_deltas(anchors, k) for k in gt_boxes]
        gt_anchor_deltas = torch.stack(gt_anchor_deltas) # (N, R, 4)
        
        valid_mask = gt_labels >= 0
        pos_mask = (gt_labels >= 0) & (gt_labels != self.num_classes)
        num_pos_anchors = pos_mask.sum().item()
        self.loss_normalizer = self.loss_normalizer_momentum * self.loss_normalizer + (
            1 - self.loss_normalizer_momentum
        ) * max(num_pos_anchors, 1)
        
        # Classification and regression loss.
        gt_labels_target = F.one_hot(gt_labels[valid_mask], num_classes=self.num_classes + 1)[
            :, :-1
        ]  # no loss for the last (bg) class
        loss_cls = sigmoid_focal_loss(
            torch.cat(pred_logits, dim=1)[valid_mask],
            gt_labels_target.to(pred_logits[0].dtype),
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum"
        )
        
        loss_box_reg = smooth_l1_loss(
            torch.cat(pred_anchor_deltas, dim=1)[pos_mask],
            gt_anchor_deltas[pos_mask],
            beta=self.smooth_l1_beta,
            reduction="sum",
        )
        return {
            "loss_cls": loss_cls / self.loss_normalizer,
            "loss_box_reg": loss_box_reg / self.loss_normalizer,
        }

## Test Loss

In [5]:
import random 
import torch

In [6]:
from model.model import RetinaNet500

In [7]:
model = RetinaNet500()

In [8]:
loss = RetinaLoss()

In [9]:
data = torch.randn((16, 3, 512, 512))
objs = [random.randint(1, 7) for _ in range(16)]
labels = [torch.randint(0, 79, (num_o, 1)) for num_o in objs]
boxes = [torch.randn((num_o, 4)) for num_o in objs]

In [10]:
pred_logits, pred_bboxes, anchors = model(data)

In [11]:
losses = loss(pred_logits, pred_bboxes, anchors, boxes, labels)

In [12]:
losses['loss_cls'].item(), losses['loss_box_reg'].item()

(0.0, 0.0)

In [13]:
from model.loss import RetinaLoss

In [14]:
losses = loss(pred_logits, pred_bboxes, anchors, boxes, labels)

In [15]:
losses['loss_cls'].item(), losses['loss_box_reg'].item()

(0.0, 0.0)