In [4]:
# import essential packages
import torch 
import numpy as np 
import torch.nn as nn
import torch.nn.functional as F 

## Equation (7) in the paper

In [2]:
def anchorBboxSize(ah_i, aw_i, level, minimum_size=20):
    """Encode an anchor box size relative to the level's standard box.

    The standard box is square with side ``minimum_size * 2 ** (level - 1)``.

    Args:
        ah_i: Anchor box height (scalar or NumPy array).
        aw_i: Anchor box width (scalar or NumPy array).
        level: Pyramid level; each level doubles the standard box side.
        minimum_size: Side of the standard box at ``level == 1``.

    Returns:
        Tuple ``(log(ah_i / AH), log(aw_i / AW))``.
    """
    # The standard box is square, so height and width share a single value.
    # (The previous code tried to unpack one scalar into two names.)
    # A float base keeps levels below 1 valid (no integer negative-power error).
    AH = AW = minimum_size * 2.0 ** (level - 1)
    b_i = (np.log(ah_i / AH), np.log(aw_i / AW))

    return b_i

In [3]:
def anchorBboxGenerator(b_i, level=1):
    """Map an anchor size encoding ``b_i`` to a parameter vector ``theta``.

    Computes ``theta_standard + W2 @ relu(W1 @ b_i)`` with a two-layer,
    bias-free residual branch.

    NOTE: ``theta_standard`` and both weight matrices are freshly sampled
    from a standard normal on every call, so the output is random and
    nothing here is trainable. TODO: move these into ``nn.Parameter``s of a
    module once the design is settled.

    Args:
        b_i: Tensor (or tuple/list/array) whose last dimension has size 2.
        level: Currently unused; kept for interface compatibility.

    Returns:
        Float tensor shaped ``b_i.shape[:-1] + (10,)``.
    """
    hidden_dim = 5
    theta_dim = 10

    # Accept a plain (log-h, log-w) pair as well as an existing tensor.
    b_i = torch.as_tensor(b_i, dtype=torch.float32)

    theta_standard = torch.randn(theta_dim)

    # F.linear needs weight *tensors* laid out as (out_features, in_features);
    # the previous code passed shape tuples, which fails at call time.
    w_hidden = torch.randn(hidden_dim, 2)
    w_out = torch.randn(theta_dim, hidden_dim)

    # two layers
    residual_theta = F.linear(F.relu(F.linear(b_i, w_hidden)), w_out)
    theta_b_i = theta_standard + residual_theta

    return theta_b_i

# Original RetinaNet

In [7]:
class OriginalRetinaNet(nn.Module):
    """RetinaNet-style detection heads applied to every FPN feature map.

    Each spatial location predicts ``num_anchors`` boxes: the regression head
    emits ``num_anchors * 4`` channels and the classification head emits
    ``num_anchors * num_classes`` channels.

    Args:
        num_classes: Number of object classes.
        fpn: Callable/module mapping an input batch to an iterable of feature
            maps with 256 channels each (the heads' first conv expects 256
            input channels). No backbone is defined in this notebook, so it
            must be injected here or assigned to ``self.fpn`` before calling
            ``forward``.
    """

    num_anchors = 9 # 3 scales x 3 ratios

    def __init__(self, num_classes=20, fpn=None):
        super(OriginalRetinaNet, self).__init__()
        # Previously a bare `self.fpn` expression, which raises AttributeError
        # on nn.Module because the attribute was never assigned.
        self.fpn = fpn # e.g. FPN50()
        self.num_classes = num_classes
        self.reg_head = self._make_head(self.num_anchors * 4) # 9 * 4 = 36
        self.cls_head = self._make_head(self.num_anchors * self.num_classes) # 9 * 20 = 180

    def forward(self, x):
        """Return ``(loc_preds, cls_preds)`` concatenated over all feature maps.

        ``loc_preds`` is ``[N, total_anchors, 4]`` and ``cls_preds`` is
        ``[N, total_anchors, num_classes]``, where ``N = x.size(0)``.

        Raises:
            RuntimeError: If no FPN backbone has been provided.
        """
        if self.fpn is None:
            raise RuntimeError(
                "OriginalRetinaNet has no FPN backbone; pass `fpn=` to the "
                "constructor or assign `model.fpn` before calling forward()."
            )

        fms = self.fpn(x)

        loc_preds = []
        cls_preds = []
        for fm in fms:
            # The heads are registered as reg_head / cls_head in __init__.
            loc_pred = self.reg_head(fm)
            cls_pred = self.cls_head(fm)
            loc_pred = loc_pred.permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, 4) # [N, 9*4, H, W] -> [N, H, W, 9*4] -> [N, H*W*9, 4]
            cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, self.num_classes) # [N, 9*C, H, W] -> [N, H, W, 9*C] -> [N, H*W*9, C]

            loc_preds.append(loc_pred)
            cls_preds.append(cls_pred)

        return torch.cat(loc_preds, 1), torch.cat(cls_preds, 1)

    def _make_head(self, out_planes):
        """Build four 3x3 conv+ReLU layers (256 channels) and a final 3x3 conv to ``out_planes``."""
        layers = []
        for _ in range(4):
            layers.append(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1))
            layers.append(nn.ReLU(True))

        layers.append(nn.Conv2d(256, out_planes, kernel_size=3, stride=1, padding=1))

        return nn.Sequential(*layers)


# MetaAnchor 

In [None]:
class MetaRetinaNet(nn.Module):
    """Detection heads that predict one box and one class vector per location.

    Unlike a head with a fixed anchor multiplier, the regression head emits
    exactly 4 channels and the classification head ``num_classes`` channels
    per spatial location.

    Args:
        num_classes: Number of object classes.
        fpn: Callable/module mapping an input batch to an iterable of feature
            maps with 256 channels each (the heads' first conv expects 256
            input channels). No backbone is defined in this notebook, so it
            must be injected here or assigned to ``self.fpn`` before calling
            ``forward``.
    """

    def __init__(self, num_classes=20, fpn=None):
        super(MetaRetinaNet, self).__init__()
        # Previously a bare `self.fpn` expression, which raises AttributeError
        # on nn.Module because the attribute was never assigned.
        self.fpn = fpn # e.g. FPN50()
        self.num_classes = num_classes
        self.reg_head = self._make_head(4) # 4
        self.cls_head = self._make_head(self.num_classes) # 20

    def forward(self, x):
        """Return ``(loc_preds, cls_preds)`` concatenated over all feature maps.

        ``loc_preds`` is ``[N, total_locations, 4]`` and ``cls_preds`` is
        ``[N, total_locations, num_classes]``, where ``N = x.size(0)``.

        Raises:
            RuntimeError: If no FPN backbone has been provided.
        """
        if self.fpn is None:
            raise RuntimeError(
                "MetaRetinaNet has no FPN backbone; pass `fpn=` to the "
                "constructor or assign `model.fpn` before calling forward()."
            )

        fms = self.fpn(x)

        loc_preds = []
        cls_preds = []
        for fm in fms:
            # The heads are registered as reg_head / cls_head in __init__.
            loc_pred = self.reg_head(fm)
            cls_pred = self.cls_head(fm)
            loc_pred = loc_pred.permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, 4) # [N, 4, H, W] -> [N, H, W, 4] -> [N, H*W, 4]
            cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, self.num_classes) # [N, C, H, W] -> [N, H, W, C] -> [N, H*W, C]

            loc_preds.append(loc_pred)
            cls_preds.append(cls_pred)

        return torch.cat(loc_preds, 1), torch.cat(cls_preds, 1)

    def _make_head(self, out_planes):
        """Build four 3x3 conv+ReLU layers (256 channels) and a final 3x3 conv to ``out_planes``."""
        layers = []
        for _ in range(4):
            layers.append(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1))
            layers.append(nn.ReLU(True))

        layers.append(nn.Conv2d(256, out_planes, kernel_size=3, stride=1, padding=1))

        return nn.Sequential(*layers)

def focal_loss_meta(b_i, cls_pred, cls_label, loc_pred, loc_label):
    '''
    Focal classification loss, summed over all samples and classes.

    b_i = [N,2]        (accepted but not used yet)
    cls_pred = [N,C]   raw logits; C = 20 in the original setup
    cls_label = [N,]   integer labels: 0 = background (all-zero target),
                       k >= 1 selects target column k-1
    loc_pred = [N,4]   (accepted but not used yet)
    loc_label = [N,4]  (accepted but not used yet)

    Returns a scalar tensor:
        sum( w * BCE_with_logits(cls_pred, t) ),
        w = (alpha if t == 1 else 1 - alpha) * (1 - pt) ** gamma
    '''

    alpha = 0.25
    gamma = 2
    # Take the class count from the logits rather than hard-coding 20.
    num_classes = cls_pred.size(-1)

    # One-hot encode by *indexing* an identity matrix with the class labels
    # (the previous code called the tensor and used the regression labels).
    labels = torch.as_tensor(cls_label, device=cls_pred.device).long()
    t = torch.eye(num_classes + 1, dtype=cls_pred.dtype, device=cls_pred.device)[labels]
    t = t[:, 1:] # drop the background column; t is a one-hot vector

    # pt needs probabilities, so use sigmoid (not log-sigmoid).
    p = torch.sigmoid(cls_pred)
    pt = p*t + (1-p)*(1-t) # pt = p if t > 0 else 1 - p

    m = alpha*t + (1-alpha)*(1-t) # m = alpha if t > 0 else 1 - alpha
    m = m * (1-pt).pow(gamma)

    # TODO: anchor-conditioned weights derived from b_i and a regression term
    # on loc_pred / loc_label are not implemented here yet.

    # `reduction='sum'` replaces the deprecated `size_average=False`.
    cls_loss = F.binary_cross_entropy_with_logits(cls_pred, t, m, reduction='sum')

    return cls_loss
