# Implementation of RetinaNet Anchor Generator.

In [5]:
from importlib.util import find_spec
if find_spec("model") is None:
    import sys
    sys.path.append('..')

In [22]:
from typing import Optional, List, Union, Tuple
import math
import torch
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F

## Retina Anchors.

paper: [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002)

Section 4. RetinaNet Detector -> Anchors.

In [37]:
class BufferList(nn.Module):
    def __init__(self, buffers):
        super().__init__()
        for i, buffer in enumerate(buffers):
            self.register_buffer(str(i), buffer)
            
    def __len__(self):
        return len(self._buffers)
    
    def __iter__(self):
        return iter(self._buffers.values())
    
    def __repr__(self):
        return str(self._buffers.values())

In [38]:
def _broadcast_params(params: Union[List[float], Tuple[float]], num_features: int):
    if not isinstance(params[0], (list, tuple)):
        return [params] * num_features
    if len(params) == 1:
        return list(params) * num_features
    
    assert len(params) == num_features
    return params

In [51]:
class AnchorBoxGenerator(nn.Module):
    def __init__(self, 
                 sizes: List[float],
                 aspect_ratios: List[float],
                 strides: List[int],
                 scales: Optional[List[float]] = [1.0],
                 offset: Optional[float]=0.5
                ):
        """
        Compute anchors in the standard way described in
        "Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks"
        paper.
        """
        super().__init__()
        
        self.strides = strides
        self.num_features = len(self.strides)
        sizes = _broadcast_params([[size * scale for scale in scales] for size in sizes], self.num_features)
        aspect_ratios = _broadcast_params(aspect_ratios, self.num_features)
        self.cell_anchors = self._calculate_anchors(sizes , aspect_ratios)
        
        
        print(f"sizes: {sizes}")
        print(f"aspect ratios: {aspect_ratios}")
        self.offset = offset
        assert 0.0 <= self.offset < 1.0
        
    def _calculate_anchors(self, sizes, aspect_ratios):
        cell_anchors = [
            self.generate_anchor_boxes(s, a).float() for s, a in zip(sizes, aspect_ratios)
        ]
        return cell_anchors
        return BufferList(cell_anchors)
        
    def generate_anchor_boxes(self, sizes=(32, 128, 256, 512), aspect_ratios=(0.5, 1, 2)):
        """
        Generate a tensor storing canonical anchor boxes of different 
        sizes and aspect ratios centered at (0, 0). 
        
        Returns:
            Tensor of shape (len(sizes) * len(aspect_ratios), 4) storing anchor boxes in
            bounding-box (X_min, Y_min, X_max, Y_max) coords.
        """
        
        anchors = []
        for size in sizes:
            area = size ** 2.0
            for aspect_ratio in aspect_ratios:
                w = math.sqrt(area / aspect_ratio)
                h = aspect_ratio * w
                x0, y0, x1, y1 = -w / 2.0, -h / 2.0, w / 2.0, h / 2.0 # centered @ (0, 0)
                anchors.append([x0, y0, x1, y1])
                
        return torch.tensor(anchors)
    
    def forward(self, features: List[torch.Tensor]):
        """
        Args:
            features: list of backbone feature maps on which to generate anchors.
        
        """
        # Note that the generated anchors depend on the feature maps and not the 
        # gt images themselves. We are generating anchors w.r.t to feature maps.
        
        grid_sizes = [feature_map.shape[-2:] for feature_map in features]
        anchors_over_all_feature_maps = self.grid_anchors(grid_sizes)
        
        return anchors_over_all_feature_maps
        
    def _grid_anchors(self, grid_sizes: List[List[int]]):
        """
        Returns:
            list[Tensor]: #feature map tensors of shape (locs x cell_anchors) * 4
        """
        anchors = []
        buffers = [x[1] for x in self.cell_anchors.named_buffers()]
        
        for size, stride, base_anchors in zip(grid_sizes, self.strides, buffers):
            shift_x, shift_y = self._create_grid_offsets(size, stride, self.offset, base_anchors.device)
            shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
            
            anchors.append((shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4))
            
        return anchors
    def _create_grid_offsets(size: List[int], stride: int, offset: float, device: torch.device):
        grid_height, grid_width = size
        shifts_x = torch.arange(
            offset * stride, grid_width * stride, step=stride, dtype=torch.float32, device=device
        )
        
        shifts_y = torch.arange(
            offset * stride, grid_height * stride, step=stride, dtype=torch.float32, device=device
        )
        
        shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
        shift_x, = shift_x.reshape(-1)
        shift_y = shift_y.reshape(-1)
        
        return shift_x, shift_y

In [52]:
anchors = AnchorBoxGenerator(
    sizes=[32., 64., 128., 256., 512.],
    aspect_ratios=[0.5, 1., 2.],
    scales=[1., 2 ** (1 / 3), 2 ** (2 / 3)],
    strides=[2, 2, 2, 2, 2]
)

sizes: [[32.0, 40.31747359663594, 50.79683366298238], [64.0, 80.63494719327188, 101.59366732596476], [128.0, 161.26989438654377, 203.18733465192952], [256.0, 322.53978877308754, 406.37466930385904], [512.0, 645.0795775461751, 812.7493386077181]]
aspect ratios: [[0.5, 1.0, 2.0], [0.5, 1.0, 2.0], [0.5, 1.0, 2.0], [0.5, 1.0, 2.0], [0.5, 1.0, 2.0]]


In [53]:
anchors.cell_anchors

[tensor([[-22.6274, -11.3137,  22.6274,  11.3137],
         [-16.0000, -16.0000,  16.0000,  16.0000],
         [-11.3137, -22.6274,  11.3137,  22.6274],
         [-28.5088, -14.2544,  28.5088,  14.2544],
         [-20.1587, -20.1587,  20.1587,  20.1587],
         [-14.2544, -28.5088,  14.2544,  28.5088],
         [-35.9188, -17.9594,  35.9188,  17.9594],
         [-25.3984, -25.3984,  25.3984,  25.3984],
         [-17.9594, -35.9188,  17.9594,  35.9188]]),
 tensor([[-45.2548, -22.6274,  45.2548,  22.6274],
         [-32.0000, -32.0000,  32.0000,  32.0000],
         [-22.6274, -45.2548,  22.6274,  45.2548],
         [-57.0175, -28.5088,  57.0175,  28.5088],
         [-40.3175, -40.3175,  40.3175,  40.3175],
         [-28.5088, -57.0175,  28.5088,  57.0175],
         [-71.8376, -35.9188,  71.8376,  35.9188],
         [-50.7968, -50.7968,  50.7968,  50.7968],
         [-35.9188, -71.8376,  35.9188,  71.8376]]),
 tensor([[ -90.5097,  -45.2548,   90.5097,   45.2548],
         [ -64.0000,  -