In [13]:
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
import os
from pycocotools.coco import COCO
import cv2
import torch

In [14]:
import random
import numpy as np
import matplotlib.pyplot as plt 
%matplotlib inline

# data loader

In [15]:
class SubtractMeans(object):
    """Subtract a fixed mean from an image.

    Args:
        mean: value(s) to subtract; stored as a float32 numpy array and
            broadcast against the image.
    """
    def __init__(self, mean):
        self.mean = np.array(mean, dtype=np.float32)

    def __call__(self, image, boxes = None, labels = None):
        # astype() returns a new array, so the caller's image is not modified.
        pixels = image.astype(np.float32)
        np.subtract(pixels, self.mean, out=pixels)
        # boxes and labels are passed through unchanged.
        return pixels.astype(np.float32), boxes, labels

In [16]:
class Resize(object):
    """Resize an image to a square of ``Resize`` x ``Resize`` pixels.

    Args:
        Resize (int): side length of the output image (default 512).
    """
    def __init__(self, Resize = 512):
        self.Resize = Resize

    def __call__(self, image, boxes = None, labels = None):
        side = self.Resize
        resized = cv2.resize(image, (side, side))
        # boxes and labels are returned exactly as received.
        return resized, boxes, labels

In [17]:
class ConvertToFloat(object):
    """Cast an image to float32.

    Args:
        image: numpy array
    Returns:
        (float32 image, boxes, labels) -- boxes and labels are passed through.
    """
    def __call__(self, image, boxes = None, labels = None):
        as_float = image.astype(np.float32)
        return as_float, boxes, labels


In [18]:
class Compose(object):
    """Chain several augmentations into one callable.

    Each transform is called as ``transform(img, boxes, labels)`` and must
    return a 3-tuple in the same order; the result of one step is fed to the
    next.

    Args:
        transforms (List[Transform]): list of transforms to compose.
    Example:
        >>> augmentations.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, boxes = None, labels = None):
        image = img
        for apply_step in self.transforms:
            image, boxes, labels = apply_step(image, boxes, labels)
        return image, boxes, labels

In [19]:
class SSDAugmentation(object):
    """Preprocessing pipeline: float conversion, square resize, mean subtraction.

    Args:
        imgSize (int): side length passed to ``Resize``.
        mean: value(s) passed to ``SubtractMeans``.
    """
    def __init__(self, imgSize = 512, mean = (0, 0, 0)):
        self.mean = mean
        self.imgSize = imgSize
        pipeline = [
            ConvertToFloat(),
            Resize(self.imgSize),
            SubtractMeans(self.mean),
        ]
        self.augment = Compose(pipeline)

    def __call__(self, img, boxes, labels):
        # Delegate to the composed pipeline; returns (img, boxes, labels).
        return self.augment(img, boxes, labels)

In [20]:
class COCOAnnotationTransform(object):
    """Transforms a COCO annotation into a list of bbox coords and label index.

    Initialized with a lookup from COCO category ids to label indexes, read
    from a comma-separated text file (``<category_id>,<label>[,...]`` per line).

    Args:
        label_file (str): path of the label-map file. Defaults to the path the
            notebook was originally written against.
    """
    def __init__(self, label_file = '/Users/kehwaweng/Documents/ObjectDetection/torch_ssd_mobilenet/coco_labels.txt'):
        self.label_map = self._get_label_map(label_file)

    def __call__(self, target, width, height):
        """
        Args:
            target (list[dict]): COCO annotations for one image; each entry is
                expected to carry 'bbox' ([x, y, w, h]) and 'category_id'.
            width (int): image width used to normalise x coordinates
            height (int): image height used to normalise y coordinates
        Returns:
            a list containing lists of bounding boxes
            [xmin, ymin, xmax, ymax, label_idx], coordinates scaled to [0, 1]
        """
        scale = np.array([width, height, width, height])
        res = []
        for obj in target:
            if 'bbox' in obj:
                # Work on unpacked values: the annotation dicts are shared with
                # the caller, so converting [x, y, w, h] -> corners in place
                # would grow the stored box on every repeated access.
                x, y, w, h = obj['bbox']
                corners = np.array([x, y, x + w, y + h])
                label_idx = self.label_map[obj['category_id']] - 1
                final_box = list(corners / scale)
                final_box.append(label_idx)
                res += [final_box]  # [xmin, ymin, xmax, ymax, label_idx]
            else:
                print("no bbox problem!")
        return res

    def _get_label_map(self, label_file):
        """Read ``label_file`` and return {int(first field): int(second field)}."""
        label_map = {}
        # Context manager guarantees the file handle is closed.
        with open(label_file, 'r') as labels:
            for line in labels:
                if not line.strip():
                    continue  # tolerate blank / trailing empty lines
                ids = line.split(',')
                label_map[int(ids[0])] = int(ids[1])
        return label_map

In [21]:
class COCODetection(Dataset):
    """COCO detection dataset yielding (image tensor, target array) pairs.

    Args:
        image_set (str): name of the image folder / annotation split,
            e.g. "train2017".
        transform (callable or None): called as ``transform(img, boxes, labels)``
            and expected to return ``(img, boxes, labels)``.
        target_transform (callable or None): called as
            ``target_transform(annotations, width, height)``.

    NOTE(review): the dataset location is hard-coded to a local volume; the
    default transform objects are created once, when the class is defined.
    """
    def __init__(self, image_set = "train2017", transform = SSDAugmentation(), target_transform = COCOAnnotationTransform()):
        self.root = os.path.join("/Volumes/IPEVO_X0244/coco_dataset/",image_set)
        self.coco = COCO(annotation_file= os.path.join("/Volumes/IPEVO_X0244/coco_dataset/annotations_2017",
                                                       "instances_{}.json".format(image_set)))
        # Image ids are the keys of coco.imgToAnns.
        self.ids = list(self.coco.imgToAnns.keys())
        # `transforms` is kept as an alias of `transform` for existing users.
        self.transforms = transform
        self.target_transform = target_transform
        self.transform = transform

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Index
        Returns:
            tuple: Tuple (image, target) -- the first two items of pull_item.
        """
        img, gt, h, w = self.pull_item(idx)
        return img, gt

    def __len__(self):
        return len(self.ids)

    def pull_item(self, idx):
        """
        Args:
            idx (int): Index
        Returns:
            tuple: tuple (image, target, height, width)
                   image is a CHW torch tensor in RGB channel order; target is
                   an array of [xmin, ymin, xmax, ymax, label] rows when both
                   transforms are set.
        """
        img_id = self.ids[idx]
        target = self.coco.imgToAnns[img_id]
        # Resolve the file inside the folder of the selected image set instead
        # of a fixed "train2017" directory, so e.g. "val2017" works too.
        img_path = os.path.join(self.root, self.coco.loadImgs(img_id)[0]['file_name'])
        assert os.path.exists(img_path), "loading image error"
        img = cv2.imread(img_path)
        assert img is not None, "loading image error"
        # OpenCV loads BGR; convert once to RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        height, width, _ = img.shape

        if self.target_transform is not None:
            target = self.target_transform(target, width, height)
        if self.transform is not None:
            # reshape keeps an empty annotation list sliceable as (0, 5).
            target = np.array(target).reshape(-1, 5)
            img, boxes, labels = self.transform(img,
                                                target[:, :4],
                                                target[:, 4])
            # The image is already RGB here; the former extra (2, 1, 0)
            # channel swap turned it back into BGR and has been removed.
            target = np.hstack((boxes, np.expand_dims(labels, axis=1)))

        return torch.from_numpy(img.transpose(2, 0, 1)), target, height, width

In [22]:
# Build the dataset with the constructor defaults (image_set "train2017",
# default augmentation and annotation transform). Loading the annotation
# index takes a while (see the timing in the output below).
dataset = COCODetection()


loading annotations into memory...
Done (t=36.62s)
creating index...
index created!


In [23]:
def detection_collate(batch):
    """Collate samples whose annotation counts differ from image to image.

    Arguments:
        batch: (tuple) A tuple of (tensor image, annotations) samples
    Return:
        A tuple containing:
            1) (tensor) batch of images stacked on their 0 dim
            2) (list of tensors) one FloatTensor of annotations per image;
               kept as a list because the row counts are not equal
    """
    images = [sample[0] for sample in batch]
    annotations = [torch.FloatTensor(sample[1]) for sample in batch]
    return torch.stack(images, 0), annotations

In [24]:
# Batches of 6 shuffled samples; detection_collate keeps the per-image
# annotations as a list because their lengths differ.
loader = DataLoader(dataset= dataset, batch_size = 6, shuffle= True, collate_fn= detection_collate)

# Pull a single batch for inspection; `img` and `target` stay bound after the
# break and are reused by later cells.
for img, target in loader:
    break

In [25]:
# Sanity check: stacked image batch shape, then one annotation tensor per image.
print(img.shape)
for i in target:
    print(i.shape)

torch.Size([6, 3, 512, 512])
torch.Size([2, 5])
torch.Size([37, 5])
torch.Size([3, 5])
torch.Size([3, 5])
torch.Size([18, 5])
torch.Size([14, 5])


# mobilenet v2

In [26]:
import torch
import torch.nn as nn
from torchsummary import summary

In [27]:
def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

In [28]:
def ConvbBNReLU(in_channels, out_channels, stride, use_batch_norm = True):
    """3x3 convolution (padding 1, no bias) followed by optional BatchNorm and ReLU.

    Returns an ``nn.Sequential`` of [Conv2d, BatchNorm2d, ReLU] when
    ``use_batch_norm`` is true, otherwise [Conv2d, ReLU].
    """
    layers = [
        nn.Conv2d(in_channels= in_channels,
                  out_channels= out_channels,
                  kernel_size= 3,
                  stride= stride,
                  padding= 1,
                  bias= False)
    ]
    if use_batch_norm:
        layers.append(nn.BatchNorm2d(out_channels))
    layers.append(nn.ReLU(inplace = False))
    return nn.Sequential(*layers)
    
        

In [29]:
def conv_1x1_bn(in_channels, out_channels, use_batch_norm = True):
    """1x1 convolution (stride 1, no bias) followed by optional BatchNorm and ReLU.

    Returns an ``nn.Sequential`` of [Conv2d, BatchNorm2d, ReLU] when
    ``use_batch_norm`` is true, otherwise [Conv2d, ReLU].
    """
    pointwise = nn.Conv2d(in_channels,
                          out_channels,
                          kernel_size = 1,
                          stride = 1,
                          padding = 0,
                          bias = False)
    activation_stack = [pointwise]
    if use_batch_norm:
        activation_stack.append(nn.BatchNorm2d(out_channels))
    activation_stack.append(nn.ReLU(inplace = False))
    return nn.Sequential(*activation_stack)

In [30]:
class InvertedResidual(nn.Module):
    """MobileNetV2-style block: optional 1x1 expansion, 3x3 depthwise, 1x1 linear projection.

    Args:
        in_channels (int): input channels.
        out_channels (int): output channels.
        stride (int): stride of the depthwise convolution, 1 or 2.
        expand_ratio: hidden width is ``round(in_channels * expand_ratio)``;
            when it equals 1 the expansion convolution is omitted.
        use_batch_norm (bool): insert BatchNorm2d after every convolution.

    The input is added to the output only when stride is 1 and the channel
    counts match.
    """
    def __init__(self, in_channels, out_channels, stride, expand_ratio, use_batch_norm = True):
        super(InvertedResidual, self).__init__()

        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(in_channels * expand_ratio)
        self.use_res_connect = self.stride == 1 and in_channels == out_channels

        # Layers are appended in the same order in every configuration so the
        # positions inside ``self.conv`` stay stable.
        stack = []
        if expand_ratio != 1:
            # pw
            stack.append(nn.Conv2d(in_channels, hidden_dim, 1, 1, 0, bias = False))
            if use_batch_norm:
                stack.append(nn.BatchNorm2d(hidden_dim))
            stack.append(nn.ReLU(inplace = False))

        # dw
        stack.append(nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups = hidden_dim, bias = False))
        if use_batch_norm:
            stack.append(nn.BatchNorm2d(hidden_dim))
        stack.append(nn.ReLU(inplace = False))

        # pw-linear (no activation afterwards)
        stack.append(nn.Conv2d(hidden_dim, out_channels, 1, 1, 0, bias = False))
        if use_batch_norm:
            stack.append(nn.BatchNorm2d(out_channels))

        self.conv = nn.Sequential(*stack)

    def forward(self, x):
        if not self.use_res_connect:
            return self.conv(x)
        return x + self.conv(x)


In [31]:
class mobilenetv2(nn.Module):
    """MobileNetV2 feature extractor with a linear classification head.

    Args:
        n_classes (int): number of outputs of the classifier.
        image_size (int): expected input size; must be a multiple of 32.
        width_mult (float): multiplier applied to the channel counts.
        round_nearest (int): channel counts are rounded to a multiple of this.
        dropout_ratio (float): dropout probability used in the classifier.
        use_batch_norm (bool): forwarded to the inverted-residual blocks and
            the final 1x1 convolution.

    Attributes:
        extract_feature (nn.Sequential): stem conv, inverted-residual blocks,
            final 1x1 conv.
        classifier (nn.Sequential): Dropout + Linear.
        last_channel (int): channel count produced by ``extract_feature``.
    """
    def __init__(self, 
                 n_classes = 80,
                 image_size = 512,
                 width_mult = 1.,
                 round_nearest = 8, 
                 dropout_ratio = 0.2,
                 use_batch_norm = True,):
        super(mobilenetv2, self).__init__()

        assert image_size % 32 == 0
        image_channels = 3

        last_channel = 1280
        input_channel = 32
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        # The last layer is never narrowed below 1280 channels.
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)

        inverted_residual_setting = [
            #expand_ratio, channel, number of residual bloc, stride
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]
        # only check the first element, assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError("inverted_residual_setting should be non-empty "
                             "or a 4-element list, got {}".format(inverted_residual_setting))

        # NOTE(review): the stem always uses batch norm, regardless of
        # use_batch_norm -- kept as in the original.
        features = [ConvbBNReLU(image_channels, input_channel, stride = 2)]

        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                # Only the first block of each stage applies the stage stride.
                stride = s if i == 0 else 1
                features.append(InvertedResidual(in_channels = input_channel,
                                                 out_channels = output_channel,
                                                 stride = stride,
                                                 expand_ratio = t,
                                                 use_batch_norm = use_batch_norm))
                input_channel = output_channel

        features.append(conv_1x1_bn(in_channels = input_channel,
                                    out_channels = self.last_channel,
                                    use_batch_norm = use_batch_norm))
        # make it nn.Sequential
        self.extract_feature = nn.Sequential(*features)
        self.classifier = nn.Sequential(
            # Honour the dropout_ratio argument (previously a fixed 0.2).
            nn.Dropout(dropout_ratio),
            nn.Linear(self.last_channel, n_classes),
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.extract_feature(x)
        # Global average pooling over width, then height.
        x = x.mean(3).mean(2)
        x = self.classifier(x)
        return x
        

In [32]:
# Backbone with default settings; its `extract_feature` part is reused as the
# SSD base further down.
model = mobilenetv2()
# summary(model, (3,512,512))

In [33]:
def SeperableConv2d(in_channels, out_channels, kernel_size = 1, stride = 1, padding = 0):
    """Depthwise-separable replacement for a single Conv2d.

    Returns an ``nn.Sequential`` of: depthwise Conv2d (groups = in_channels),
    BatchNorm2d, ReLU, pointwise 1x1 Conv2d mapping to ``out_channels``.
    """
    depthwise = nn.Conv2d(in_channels = in_channels,
                          out_channels = in_channels,
                          kernel_size = kernel_size,
                          groups = in_channels,
                          stride = stride,
                          padding = padding)
    normalise = nn.BatchNorm2d(in_channels)
    activate = nn.ReLU(inplace = False)
    pointwise = nn.Conv2d(in_channels = in_channels,
                          out_channels = out_channels,
                          kernel_size = 1)
    return nn.Sequential(depthwise, normalise, activate, pointwise)

In [34]:
class Conv(nn.Module):
    """1x1 convolution (no bias) with the given stride, then BatchNorm2d and ReLU."""
    def __init__(self, in_channels, out_channels, stride,):
        super(Conv, self).__init__()
        pointwise = nn.Conv2d(in_channels, out_channels, kernel_size= 1, stride= stride, bias= False)
        stages = [pointwise, nn.BatchNorm2d(out_channels), nn.ReLU(inplace= False)]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        out = self.conv(x)
        return out

In [35]:
class liteConv(nn.Module):
    """Bottleneck block: 1x1 reduce -> 3x3 depthwise (strided) -> 1x1 expand.

    The hidden width is ``out_channels // 2``; every convolution is bias-free
    and followed by BatchNorm2d and ReLU.
    """
    def __init__(self, in_channels, out_channels, stride):
        super(liteConv, self).__init__()

        hidden_dim = out_channels // 2

        def bn_relu(channels):
            # BatchNorm + ReLU pair appended after each convolution.
            return [nn.BatchNorm2d(channels), nn.ReLU(inplace = False)]

        stages = []
        # pw
        stages.append(nn.Conv2d(in_channels, hidden_dim, 1, 1, 0, bias = False))
        stages.extend(bn_relu(hidden_dim))
        # dw
        stages.append(nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups = hidden_dim, bias = False))
        stages.extend(bn_relu(hidden_dim))
        # pw-linear
        stages.append(nn.Conv2d(hidden_dim, out_channels, 1, 1, 0, bias = False))
        stages.extend(bn_relu(out_channels))

        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        out = self.conv(x)
        return out
        

In [36]:
class liteConv_2(nn.Module):
    """Debug variant of ``liteConv`` that prints the tensor shape after every layer.

    Same layer layout (1x1 reduce, 3x3 depthwise, 1x1 expand, each followed by
    BatchNorm2d and ReLU), but with every layer held as a separate attribute.

    Args:
        in_channels (int): input channels.
        out_channels (int): output channels; the hidden width is half of it.
        stride (int): stride of the depthwise convolution.
    """
    def __init__(self, in_channels, out_channels, stride):
        super(liteConv_2, self).__init__()

        hidden_dim = out_channels // 2

        # pw
        self.conv1 = nn.Conv2d(in_channels, hidden_dim, 1, 1, 0, bias = False)
        self.bn1 = nn.BatchNorm2d(hidden_dim)
        self.act1 = nn.ReLU(inplace = False)
        # dw -- uses the `stride` argument (it was previously ignored and
        # the stride fixed at 2).
        self.conv2 = nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups = hidden_dim, bias = False)
        self.bn2 = nn.BatchNorm2d(hidden_dim)
        self.act2 = nn.ReLU(inplace = False)
        # pw-linear
        self.conv3 = nn.Conv2d(hidden_dim, out_channels, 1, 1, 0, bias = False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.act3 = nn.ReLU(inplace = False)

    def forward(self, x):
        stages = (
            self.conv1, self.bn1, self.act1,
            self.conv2, self.bn2, self.act2,
            self.conv3, self.bn3, self.act3,
        )
        for stage in stages:
            x = stage(x)
            # Shape trace is the purpose of this debug variant.
            print(x.shape)
        return x

In [37]:
def add_extras():
    """Build the extra stride-2 ``liteConv`` blocks appended after the backbone.

    Returns:
        nn.ModuleList of four blocks with (in, out) channels
        (1280, 512), (512, 256), (256, 256), (256, 64).
    """
    channel_plan = [(1280, 512), (512, 256), (256, 256), (256, 64)]
    blocks = [liteConv(c_in, c_out, stride = 2) for c_in, c_out in channel_plan]
    return nn.ModuleList(blocks)
# add_extras()

In [38]:
def multibox(n_classes, width_mult = 1.0, ):
    """Build the localisation and confidence heads, one per source feature map.

    Every head emits 6 predictions per location: 6 * 4 channels for box
    offsets and 6 * n_classes channels for class scores. The first five heads
    are separable 3x3 convolutions; the last is a plain 1x1 convolution on 64
    input channels.

    Args:
        n_classes (int): number of classes scored per prediction.
        width_mult (float): scales the input width of the first head only
            (``round(576 * width_mult)``).
    Returns:
        (loc_layers, conf_layers): two ``nn.ModuleList`` of six heads each.
    """
    separable_inputs = [round(576 * width_mult), 1280, 512, 256, 256]

    def build_heads(out_channels):
        heads = [
            SeperableConv2d(in_channels = channels,
                            out_channels = out_channels,
                            kernel_size = 3,
                            padding = 1)
            for channels in separable_inputs
        ]
        heads.append(nn.Conv2d(in_channels = 64, out_channels = out_channels, kernel_size = 1))
        return nn.ModuleList(heads)

    loc_layers = build_heads(6 * 4)
    conf_layers = build_heads(6 * n_classes)
    return loc_layers, conf_layers
    
# multibox(80, 1.0)
    

In [39]:
from collections import namedtuple
# Describes a feature tap inside a backbone layer, as consumed by ssd.forward:
#   s0   -- index of the layer in the base network
#   name -- attribute of that layer holding its sub-layer sequence
#   s1   -- the feature map is taken after the first s1 sub-layers
GraphPath = namedtuple("GraphPath", ['s0', 'name', 's1'])  #


In [40]:
# Where the SSD head taps the backbone: first inside base layer 14 (after the
# first 3 sub-layers of its `conv` attribute), then after running the base up
# to index 19.
source_layer_indexes = [
    GraphPath(14, 'conv', 3),
    19,
]
source_layer_indexes

[GraphPath(s0=14, name='conv', s1=3), 19]

In [41]:
class ssd(nn.Module):
    '''single shot multibox architecture
       backbone network is mobilenetv2

    Args:
        image_size (int): size of input image size
        base (nn.Module): backbone network - MobileNetV2 extract feature part;
            must support integer indexing and slicing (e.g. nn.Sequential)
        extras_layers (nn.ModuleList): extra layers that feed to multibox loc and conf layers
        loc_layers (nn.ModuleList): bounding box output layer
        conf_layers (nn.ModuleList): class confidence output layer
        source_layer_index (list): where to tap the base network; each entry is
            an int, a GraphPath, or an (index, layer) tuple
        n_classes (int): number of class need to detect
        dropout_ratio (float): percentage of dropout ratio
    '''
    def __init__(self,
                 image_size,
                 base,
                 extras_layers,
                 loc_layers,
                 conf_layers,
                 source_layer_index,
                 n_classes,
                 dropout_ratio = 0.1):
        super(ssd, self).__init__()
        self.n_classes = n_classes
        self.image_size = image_size
        self.base = base
        self.source_layer_indexes = source_layer_index
        self.extras_layers = extras_layers
        # Not applied in forward(); kept so the attribute remains available.
        self.dropout = nn.Dropout(p = dropout_ratio, inplace = False)
        self.loc_layers = loc_layers
        self.conf_layers = conf_layers

        # Normalise over the class axis (last dim). Without an explicit dim,
        # PyTorch's implicit choice for a 3-D (batch, priors, classes) tensor
        # is dim 0, which would normalise across the batch.
        self.softmax = nn.Softmax(dim = -1)

    def forward(self, x):
        """Run the detector.

        Returns:
            tuple (loc, conf): loc reshaped to (batch, -1, 4) and conf (raw,
            un-normalised scores) reshaped to (batch, -1, n_classes).
        """
        loc = []
        conf = []
        source = []
        start_layer_index = 0
        for end_layer_index in self.source_layer_indexes:

            # Decode the three supported kinds of tap description. GraphPath
            # is checked first because a namedtuple is also a tuple.
            if isinstance(end_layer_index, GraphPath):
                extras_from_base = end_layer_index
                end_layer_index = end_layer_index.s0
                added_layer = None
            elif isinstance(end_layer_index, tuple):
                added_layer = end_layer_index[1]
                end_layer_index = end_layer_index[0]
                extras_from_base = None
            else:
                added_layer = None
                extras_from_base = None

            # Advance the backbone up to (not including) the tap index.
            for layer in self.base[start_layer_index: end_layer_index]:
                x = layer(x)

            if added_layer:
                y = added_layer(x)
            else:
                y = x

            if extras_from_base:
                # Tap in the middle of base[end_layer_index]: run its first
                # s1 sub-layers, record the feature, then finish the layer.
                sub = getattr(self.base[end_layer_index], extras_from_base.name)

                for layer in sub[:extras_from_base.s1]:
                    x = layer(x)

                y = x

                for layer in sub[extras_from_base.s1:]:
                    x = layer(x)

                # The whole layer has been consumed by the manual loop above.
                end_layer_index += 1

            start_layer_index = end_layer_index
            source.append(y)

        # Finish the rest of the backbone. start_layer_index equals the last
        # end_layer_index, and is 0 when there were no taps at all.
        for layer in self.base[start_layer_index:]:
            x = layer(x)

        # Every extra layer contributes one more source feature map.
        for layer in self.extras_layers:
            x = layer(x)
            source.append(x)

        # Apply the heads; move channels last so that flattening keeps the
        # values belonging to one location contiguous.
        for (x, l, c) in zip(source, self.loc_layers, self.conf_layers):
            loc.append( l(x).permute(0, 2, 3, 1).contiguous())
            conf.append( c(x).permute(0, 2, 3, 1).contiguous())

        loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
        conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)

        output = (
            loc.view(loc.size(0), - 1, 4),
            conf.view(conf.size(0), -1, self.n_classes)
        )
        return output
        

In [43]:
# Assemble the detector from the backbone's feature extractor, the extra
# layers and the two multibox heads (80 classes, width multiplier 1.0).
image_size = 512
extras_layer = add_extras()
loc_layers, conf_layers = multibox(80, 1.0)
n_classes = 80

ssd_module = ssd(image_size, model.extract_feature, extras_layer, loc_layers, conf_layers, source_layer_indexes, n_classes)
# torch.save(ssd_module, "./test_ssd.h5")

In [44]:
# x = torch.rand((2,3,512,512))
# o = ssd_module.forward(x)
# summary(ssd_module, (3,512,512))

# Forward the batch pulled from the data loader earlier; the result is the
# (loc, conf) tuple returned by ssd.forward.
inference = ssd_module(img)

torch.Size([6, 576, 32, 32])
torch.Size([6, 1280, 16, 16])
torch.Size([6, 512, 8, 8])
torch.Size([6, 256, 4, 4])
torch.Size([6, 256, 2, 2])
torch.Size([6, 64, 1, 1])


In [46]:
# inference[0] holds the box offsets, inference[1] the class scores.
print(inference[0].shape)
print(inference[1].shape)

torch.Size([6, 8190, 4])
torch.Size([6, 8190, 80])


# prior bbox

In [None]:
from itertools import product as product
from math import sqrt
class PriorBox(object):
    """Generate prior (default) boxes in center-size form for every source
    feature map.

    The configuration dict must provide: 'min_dim', 'aspect_ratios',
    'variance', 'feature_maps', 'min_sizes', 'max_sizes', 'steps', 'clip'.

    Note:
    This 'layer' has changed between versions of the original SSD
    paper, so we include both versions, but note v2 is the most tested and most
    recent version of the paper.
    """

    def __init__(self, cfg):
        super(PriorBox, self).__init__()
        self.image_size = cfg['min_dim']
        # Stored as the length of the aspect-ratio list in the config.
        self.num_priors = len(cfg['aspect_ratios'])
        self.variance = cfg['variance'] or [0.1]
        self.feature_maps = cfg['feature_maps']
        self.min_sizes = cfg['min_sizes']
        self.max_sizes = cfg['max_sizes']
        self.steps = cfg['steps']
        self.aspect_ratios = cfg['aspect_ratios']
        self.clip = cfg['clip']
        if any(v <= 0 for v in self.variance):
            raise ValueError('Variances must be greater than 0')

    def forward(self):
        """Return a tensor of shape [num_priors, 4] holding (cx, cy, w, h)
        relative to the image size, clamped to [0, 1] when ``clip`` is set."""
        coords = []
        for level, fmap in enumerate(self.feature_maps):
            for row, col in product(range(fmap), repeat=2):
                # Number of cells along one side at this level.
                cells = self.image_size / self.steps[level]
                cx = (col + 0.5) / cells
                cy = (row + 0.5) / cells

                # Square box at the level's minimum size.
                small = self.min_sizes[level] / self.image_size
                coords.extend((cx, cy, small, small))

                # aspect_ratio: 1
                # rel size: sqrt(s_k * s_(k+1))
                if self.max_sizes:
                    large = sqrt(small * (self.max_sizes[level] / self.image_size))
                    coords.extend((cx, cy, large, large))

                # Each remaining ratio yields a wide box and a tall box.
                for ar in self.aspect_ratios[level]:
                    coords.extend((cx, cy, small * sqrt(ar), small / sqrt(ar)))
                    coords.extend((cx, cy, small / sqrt(ar), small * sqrt(ar)))

        # back to torch land
        priors = torch.Tensor(coords).view(-1, 4)
        if self.clip:
            priors.clamp_(max=1, min=0)
        return priors

In [None]:
# Prior-box configuration for 512x512 inputs. One entry per source feature
# map; the feature-map sides match the 32/16/8/4/2/1 maps printed above.
MOBILEV2_512 = {
    "feature_maps": [32, 16, 8, 4, 2, 1],
    "min_dim": 512,
    "steps": [16, 32, 64, 128, 256, 512],
    "min_sizes": [102.4,  174.08, 245.76, 317.44, 389.12, 460.8],
    "max_sizes": [174.08, 245.76, 317.44, 389.12, 460.8,  512  ],
    "aspect_ratios": [[2,3], [2, 3], [2, 3], [2, 3], [2, 3], [2, 3]],
    "variance": [0.1, 0.2],
    "clip": True,
}    

# Generate all priors once and check their count.
priorbox = PriorBox(MOBILEV2_512)
priors = priorbox.forward()
print(priors.shape)

# multibox loss

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

In [None]:
def point_form(boxes):
    """ Convert center-size boxes (cx, cy, w, h) to corner form
    (xmin, ymin, xmax, ymax) for comparison to point form ground truth data.
    Args:
        boxes: (tensor) center-size default boxes from priorbox layers.
    Return:
        boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
    """
    centers = boxes[:, :2]
    half_sizes = boxes[:, 2:]/2
    top_left = centers - half_sizes       # xmin, ymin
    bottom_right = centers + half_sizes   # xmax, ymax
    return torch.cat((top_left, bottom_right), 1)


def center_size(boxes):
    """ Convert corner-form boxes (xmin, ymin, xmax, ymax) to center-size
    form (cx, cy, w, h) for comparison to center-size form ground truth data.
    Args:
        boxes: (tensor) point_form boxes, Shape: [n, 4].
    Return:
        boxes: (tensor) Converted cx, cy, w, h form of boxes, Shape: [n, 4].
    """
    # Both halves must be passed to torch.cat as ONE sequence; the earlier
    # parenthesisation handed it three positional arguments and raised.
    return torch.cat(((boxes[:, 2:] + boxes[:, :2])/2,  # cx, cy
                      boxes[:, 2:] - boxes[:, :2]), 1)  # w, h


def intersect(box_a, box_b):
    """ Pairwise intersection area between two sets of corner-form boxes.

    Both tensors are viewed as [A,B,2] without new malloc:
    [A,2] -> [A,1,2] -> [A,B,2]
    [B,2] -> [1,B,2] -> [A,B,2]
    Args:
      box_a: (tensor) bounding boxes, Shape: [A,4].
      box_b: (tensor) bounding boxes, Shape: [B,4].
    Return:
      (tensor) intersection area, Shape: [A,B].
    """
    num_a, num_b = box_a.size(0), box_b.size(0)
    grid = (num_a, num_b, 2)

    # Lower-right corner of the overlap: the smaller of the two max corners.
    a_max = box_a[:, 2:].unsqueeze(1).expand(*grid)
    b_max = box_b[:, 2:].unsqueeze(0).expand(*grid)
    overlap_max = torch.min(a_max, b_max)

    # Upper-left corner of the overlap: the larger of the two min corners.
    a_min = box_a[:, :2].unsqueeze(1).expand(*grid)
    b_min = box_b[:, :2].unsqueeze(0).expand(*grid)
    overlap_min = torch.max(a_min, b_min)

    # Negative extents mean the boxes do not overlap -> area 0.
    extent = torch.clamp((overlap_max - overlap_min), min=0)
    return extent[:, :, 0] * extent[:, :, 1]


def jaccard(box_a, box_b):
    """Pairwise jaccard overlap (intersection over union) of two box sets.
    Here we operate on ground truth boxes and default boxes.
    E.g.:
        A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
    Args:
        box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
        box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
    Return:
        jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
    """
    inter = intersect(box_a, box_b)

    width_a = box_a[:, 2]-box_a[:, 0]
    height_a = box_a[:, 3]-box_a[:, 1]
    area_a = (width_a * height_a).unsqueeze(1).expand_as(inter)  # [A,B]

    width_b = box_b[:, 2]-box_b[:, 0]
    height_b = box_b[:, 3]-box_b[:, 1]
    area_b = (width_b * height_b).unsqueeze(0).expand_as(inter)  # [A,B]

    union = area_a + area_b - inter
    return inter / union  # [A,B]



def match(threshold, truths, priors, variances, labels, loc_t, conf_t, idx):
    """Match each prior box with the ground truth box of the highest jaccard
    overlap, encode the bounding boxes, and store the resulting location and
    confidence targets for one image of the batch.
    Args:
        threshold: (float) The overlap threshold used when matching boxes.
        truths: (tensor) Ground truth boxes in point form, Shape: [num_obj, 4].
        priors: (tensor) Prior boxes from priorbox layers in center-size form,
            Shape: [n_priors,4].
        variances: (list[float]) Variances passed through to ``encode``.
        labels: (tensor) All the class labels for the image, Shape: [num_obj].
        loc_t: (tensor) Tensor to be filled w/ encoded location targets.
        conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
        idx: (int) current batch index
    Return:
        None. Results are written in place to ``loc_t[idx]`` and
        ``conf_t[idx]``.
    """
    # jaccard index, Shape: [num_obj, num_priors]
    overlaps = jaccard(
        truths,
        point_form(priors)
    )
    # (Bipartite Matching)
    # [num_objects,1] best prior for each ground truth
    best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)

    # [1,num_priors] best ground truth for each prior
    best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)

    # Drop the kept singleton dimensions so all four are 1-D.
    best_truth_idx.squeeze_(0)
    best_truth_overlap.squeeze_(0)
    best_prior_idx.squeeze_(1)
    best_prior_overlap.squeeze_(1)
    # Overlap 2 is above any real IoU, so these priors always pass the threshold.
    best_truth_overlap.index_fill_(0, best_prior_idx, 2)  # ensure best prior

    # TODO refactor: index  best_prior_idx with long tensor
    # ensure every gt matches with its prior of max overlap
    # (sequential on purpose: if two ground truths share a prior, the later one wins)
    for j in range(best_prior_idx.size(0)):
        best_truth_idx[best_prior_idx[j]] = j

    matches = truths[best_truth_idx]          # Shape: [num_priors,4]
    # +1 shifts the class labels so that 0 can denote background.
    conf = labels[best_truth_idx] + 1         # Shape: [num_priors]
    conf[best_truth_overlap < threshold] = 0  # label as background
    loc = encode(matches, priors, variances)
    loc_t[idx] = loc    # [num_priors,4] encoded offsets to learn
    conf_t[idx] = conf  # [num_priors] top class label for each prior



def encode(matched, priors, variances):
    """Convert matched ground-truth boxes into regression targets relative to priors.

    Each target row holds the centre offset between the matched box and its
    prior (divided by ``variances[0]`` times the prior size) followed by the
    log of the size ratio (divided by ``variances[1]``).

    Args:
        matched: (tensor) Ground truth box matched to each prior, point-form
            (xmin, ymin, xmax, ymax), Shape: [num_priors, 4].
        priors: (tensor) Prior boxes in center-size form (cx, cy, w, h),
            Shape: [num_priors, 4].
        variances: (list[float]) Two scaling factors; index 0 is applied to
            the centre offsets, index 1 to the log size ratios.
    Return:
        encoded boxes (tensor), Shape: [num_priors, 4]
    """
    box_min, box_max = matched[:, :2], matched[:, 2:]
    prior_center, prior_size = priors[:, :2], priors[:, 2:]

    # Centre of the matched box relative to the prior centre ...
    center_offset = (box_min + box_max) / 2 - prior_center
    # ... expressed in units of the prior size and scaled by the first variance.
    center_offset.div_(variances[0] * prior_size)

    # Width/height ratio between matched box and prior, in log space.
    size_ratio = (box_max - box_min) / prior_size
    log_size = torch.log(size_ratio) / variances[1]

    # Regression target consumed by the smooth L1 loss: [num_priors, 4]
    return torch.cat([center_offset, log_size], 1)

In [None]:
def log_sum_exp(x):
    """Numerically stabilised log-sum-exp along dim 1.

    The global maximum of ``x`` is subtracted before exponentiating and added
    back after the log. MultiBoxLoss uses the result to compute the
    unaveraged confidence loss across all examples in a batch.

    Args:
        x (tensor): conf_preds from conf layers; reduced along dim 1.
    Return:
        tensor with dim 1 kept at size 1.
    """
    shift = x.data.max()
    shifted_exp = torch.exp(x - shift)
    row_totals = torch.sum(shifted_exp, 1, keepdim=True)
    return torch.log(row_totals) + shift

In [None]:
from torch.autograd import Variable
import torch.nn.functional as F
class MultiBoxLoss(nn.Module):
    """SSD Weighted Loss Function
    Compute Targets:
        1) Produce Confidence Target Indices by matching ground truth boxes
           with (default) 'priorboxes' that have jaccard index > threshold parameter
           (default threshold: 0.5).
        2) Produce localization target by 'encoding' variance into offsets of ground
           truth boxes and their matched 'priorboxes'.
        3) Hard negative mining to filter the excessive number of negative examples
           that comes with using a large number of default bounding boxes.
           (default negative:positive ratio 3:1)
    Objective Loss:
        L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
        Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
        weighted by α which is set to 1 by cross val.
        Args:
            c: class confidences,
            l: predicted boxes,
            g: ground truth boxes
            N: number of matched default boxes
        See: https://arxiv.org/pdf/1512.02325.pdf for more details.

    Constructor args:
        n_classes: width of the class dimension; conf predictions are viewed
            as (-1, n_classes). match() shifts labels by +1 so that 0 means
            background, so this presumably has to count the background
            class as well -- confirm against the model's conf head.
        overlap_thresh: jaccard threshold handed to match().
        neg_pos: negatives kept per positive during hard negative mining.
        prior_for_matching, bkg_label, neg_mining, neg_overlap, encode_target:
            stored on the instance but not read by forward().
    """
    def __init__(self,
                 n_classes,
                 overlap_thresh = 0.5,
                 prior_for_matching = True,
                 bkg_label = 0,
                 neg_mining = True,
                 neg_pos = 3,
                 neg_overlap = 0.5,
                 encode_target = False):
        super(MultiBoxLoss, self).__init__()
        self.n_classes = n_classes
        self.threshold = overlap_thresh
        self.background_label = bkg_label
        self.encode_target = encode_target
        self.use_prior_for_matching = prior_for_matching
        self.do_neg_mining = neg_mining
        self.negpos_ratio = neg_pos
        self.neg_overlap = neg_overlap
        # [centre variance, size variance] passed to match() for box encoding
        self.variance = [0.1, 0.2]

    def forward(self, predictions, priors, targets):
        """Multibox Loss
        Args:
            predictions (tuple): (loc, conf) predictions from the SSD net.
                loc  (tensor): with shape (batch_size, num_priors, 4)
                conf (tensor): with shape (batch_size, num_priors, num_classes)
            priors  (tensor): with shape (num_priors, 4)
            targets: Ground truth boxes and labels, indexable by batch index;
                     each entry has shape (num_objects, 5) and the last
                     dimension stores [xmin, ymin, xmax, ymax, label]
        Return:
            (loss_l, loss_c): summed localization and confidence losses, each
            divided by the number of positive priors in the batch (the
            divisor is clamped to at least 1, so a batch without any
            positive match yields finite losses instead of NaN).
        """
        loc_data, conf_data = predictions
        num_batch = loc_data.size(0)
        num_priors = priors.size(0)

        # Match priors (default boxes) and ground truth boxes. The target
        # buffers live on the same device as the predictions so the losses
        # below never mix devices; match() fills one row per image.
        loc_t = torch.empty(num_batch, num_priors, 4, device=loc_data.device)
        conf_t = torch.empty(num_batch, num_priors,
                             dtype=torch.long, device=loc_data.device)
        defaults = priors.data
        for idx in range(num_batch):
            ground_true_bboxes = targets[idx][:, :-1].data
            labels = targets[idx][:, -1].data
            match(self.threshold,
                  ground_true_bboxes,
                  defaults,
                  self.variance,
                  labels,
                  loc_t,
                  conf_t,
                  idx)

        # Priors matched to an object (class index 0 is background).
        pos = conf_t > 0
        num_pos = pos.long().sum(1, keepdim=True)

        # Localization Loss (Smooth L1), positives only.
        # Shape: [batch,num_priors,4]
        pos_idx = pos.unsqueeze(2).expand_as(loc_data)
        loc_p = loc_data[pos_idx].view(-1, 4)
        loc_t = loc_t[pos_idx].view(-1, 4)
        loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')

        # Hard Negative Mining: rank priors by their per-prior confidence
        # loss and keep the worst `negpos_ratio * num_pos` negatives per image.
        # Only the ranking is needed, so no autograd graph is built here.
        with torch.no_grad():
            batch_conf = conf_data.view(-1, self.n_classes)
            mining_loss = (log_sum_exp(batch_conf)
                           - batch_conf.gather(1, conf_t.view(-1, 1)))
            mining_loss = mining_loss.view(pos.size())
            mining_loss[pos] = 0  # positives must not compete with negatives
            _, loss_idx = mining_loss.sort(1, descending=True)
            _, idx_rank = loss_idx.sort(1)  # rank of each prior in its image
            num_neg = torch.clamp(self.negpos_ratio * num_pos,
                                  max=pos.size(1) - 1)
            neg = idx_rank < num_neg.expand_as(idx_rank)

        # Confidence Loss Including Positive and Negative Examples
        selected = pos | neg
        selected_idx = selected.unsqueeze(2).expand_as(conf_data)
        conf_p = conf_data[selected_idx].view(-1, self.n_classes)
        targets_weighted = conf_t[selected]
        loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')

        # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
        # Guard against N == 0 (no positive prior in the whole batch).
        N = num_pos.sum().clamp(min=1).float()
        loss_l = loss_l / N
        loss_c = loss_c / N
        return loss_l, loss_c
    
    

In [None]:
# Instantiate the SSD loss. MultiBoxLoss views the conf predictions as
# (-1, n_classes), so this value must equal the width of the model's conf head.
# NOTE(review): match() shifts labels by +1 to reserve 0 for background, so 80
# presumably has to include the background class -- confirm.
criterion = MultiBoxLoss(n_classes = 80)

In [None]:
# Training hyper-parameters (epoch is presumably the number of training epochs).
epoch = 2
lr = 0.01
# NOTE(review): the optimizer is built from `model`, while the forward pass
# further down calls `ssd_module` -- confirm both refer to the same network.
optimizer = torch.optim.Adam(model.parameters(), lr= lr,)

In [None]:
# Pull a single batch from the loader to inspect its shapes; the loop exits
# after the first iteration. `img` and `target` stay bound afterwards and are
# reused by the cells below.
for img, target in loader:
    print("image shape: {}".format(img.shape))
    # Shapes are printed per entry of `target` rather than for `target` as a
    # whole (presumably a list of per-image tensors -- verify against the loader).
    for t in target:
        print("sub target shape:", t.shape)
    break
    

In [None]:
# Forward pass on the sample batch. The name `infernece` (sic) is referenced by
# the following cells, so the spelling is kept.
infernece = ssd_module(img)

In [None]:
# Shapes of the two network outputs; MultiBoxLoss.forward unpacks them as
# (loc, conf) predictions.
print(infernece[0].shape)
print(infernece[1].shape)

In [None]:
# Evaluate the loss on the sample batch; the result is the tuple (loss_l, loss_c).
criterion(infernece, priors, target)

In [None]:
# NumPy copy-free view of the prior boxes; detach() drops autograd tracking,
# which .numpy() requires.
np_priors = priors.detach().numpy()

In [248]:
# Sanity check: one row of 4 values per prior box.
np_priors.shape

(8190, 4)

In [None]:
# Dump the raw array bytes to disk. No shape or dtype header is written, so a
# reader must know both: the shape was (8190, 4) when this notebook was last
# run; the dtype is presumably float32 -- TODO confirm.
with open("./priors.dat", "wb") as binfile:
    binfile.write(np_priors.tobytes())