# **Object Detection**

In [39]:
import torch
import torch.nn as nn

### **Utility functions**

In [3]:
def box_corner_to_center(boxes):
    """
    Convert boxes from corner format to center format.

    boxes: (num_boxes, 4) tensor of (x1, y1, x2, y2), where (x1, y1) is the
    upper-left corner and (x2, y2) is the lower-right corner.

    Returns a (num_boxes, 4) tensor of (cx, cy, w, h).
    """
    left, top, right, bottom = (boxes[:, k] for k in range(4))
    centers_and_sizes = (
        (left + right) / 2,   # cx
        (top + bottom) / 2,   # cy
        right - left,         # w
        bottom - top,         # h
    )
    return torch.stack(centers_and_sizes, dim=-1)

def box_center_to_corner(boxes):
    """
    Convert boxes from center format to corner format.

    boxes: (num_boxes, 4) tensor of (cx, cy, w, h).

    Returns a (num_boxes, 4) tensor of (x1, y1, x2, y2): upper-left corner
    followed by lower-right corner.
    """
    cx, cy, w, h = (boxes[:, k] for k in range(4))
    half_w, half_h = 0.5 * w, 0.5 * h
    corners = (cx - half_w, cy - half_h, cx + half_w, cy + half_h)
    return torch.stack(corners, dim=-1)

### **Create Anchor Boxes**

In [24]:
class CreateAnchorBoxes:
    """
    Generate the corner coordinates of every anchor box for one feature map.

    One anchor is produced for every (size, ratio) pair at every pixel, so
    ``mapping()`` returns ``height * width * len(sizes) * len(ratios)`` boxes.
    Rows are ordered pixel by pixel in row-major order (all x for the first y,
    then the next y, ...), and within a pixel size-major / ratio-minor. This is
    the same order obtained by permuting an (N, C, H, W) map to (N, H, W, C)
    and flattening it.

    Coordinates are normalized by the feature map's width (x) and height (y).
    Boxes are not clipped, so they may extend past [0, 1].
    NOTE: width = size * sqrt(ratio) and height = size / sqrt(ratio) are in
    normalized units of their own axis, so on a non-square map a ratio-1
    anchor is not square in pixels.

    image: tensor whose last two dims are (height, width), e.g.
           (batch_size, channel, height, width); only its shape is used.
    sizes: 1D list of anchor scales
    ratios: 1D list of aspect ratios (width / height)
    """
    def __init__(self, image, sizes, ratios):
        self.image_height, self.image_width = image.shape[-2], image.shape[-1]
        self.sizes = torch.tensor(sizes, dtype=torch.float32)
        self.ratios = torch.tensor(ratios, dtype=torch.float32)
        self.num_boxes = len(sizes) * len(ratios)
        self.create_grid_center_points()
        self.create_anchor_width_height()

    def create_grid_center_points(self):
        """Compute the normalized (x, y) center of every pixel, flattened row by row."""
        self.center_x_axis = (torch.arange(self.image_width) + 0.5) / self.image_width
        self.center_y_axis = (torch.arange(self.image_height) + 0.5) / self.image_height
        # Build (height, width) grids with y as the first axis so that reshape(-1)
        # walks pixels row by row, matching an (N, H, W, C)-flattened prediction
        # map. Putting x first would transpose the grid and pair anchors with
        # the wrong pixel whenever both height and width exceed 1.
        grid_y, grid_x = torch.meshgrid(self.center_y_axis, self.center_x_axis, indexing="ij")
        self.grid_x_coords = grid_x.reshape(-1)
        self.grid_y_coords = grid_y.reshape(-1)

    def create_anchor_width_height(self):
        """Build the (-w/2, -h/2, w/2, h/2) half-extents of every anchor, tiled per pixel."""
        sqrt_ratios = torch.sqrt(self.ratios)
        # Broadcast sizes x ratios; reshape(-1) gives index = size_idx * len(ratios) + ratio_idx.
        width_anchor = (self.sizes[:, None] * sqrt_ratios[None, :]).reshape(-1)
        height_anchor = (self.sizes[:, None] / sqrt_ratios[None, :]).reshape(-1)

        half_extents = torch.stack(
            (-width_anchor, -height_anchor, width_anchor, height_anchor), dim=1
        ) / 2
        # Repeat the per-pixel block of num_boxes anchors once for every pixel.
        self.grid_width_height = half_extents.repeat(self.image_height * self.image_width, 1)

    def mapping(self):
        """Return all anchors as a (height * width * num_boxes, 4) tensor of (x1, y1, x2, y2)."""
        # Each pixel center is repeated num_boxes times consecutively, lining up
        # with the tiled half-extents built in create_anchor_width_height.
        out_grid = torch.stack(
            [self.grid_x_coords, self.grid_y_coords,
             self.grid_x_coords, self.grid_y_coords],
            dim=1,
        ).repeat_interleave(self.num_boxes, dim=0)
        return out_grid + self.grid_width_height

### **Intersection Over Union**

In [7]:
def calculate_box_area(boxes):
    """Return a 1-D tensor with the area of each (x1, y1, x2, y2) box in ``boxes``."""
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    return widths * heights

def BoxIou(boxes1, boxes2):
    """
    Pairwise intersection-over-union between two sets of corner-format boxes.

    boxes1: (N, 4) tensor of (x1, y1, x2, y2)
    boxes2: (M, 4) tensor of (x1, y1, x2, y2)

    Returns an (N, M) tensor whose [i, j] entry is IoU(boxes1[i], boxes2[j]).
    """
    area_a = calculate_box_area(boxes1)
    area_b = calculate_box_area(boxes2)

    # boxes1 gets a middle axis so every box in boxes1 is compared with every box in boxes2.
    top_left = torch.maximum(boxes1[:, None, :2], boxes2[:, :2])
    bottom_right = torch.minimum(boxes1[:, None, 2:], boxes2[:, 2:])
    # Non-overlapping pairs would have negative extents; clamp them to zero.
    overlap_wh = (bottom_right - top_left).clamp(min=0)

    overlap = overlap_wh[..., 0] * overlap_wh[..., 1]
    return overlap / (area_a[:, None] + area_b - overlap)

In [5]:
def assign_anchor_to_bbox(ground_truth, anchors, iou_threshold=0.5):
    """
    Assign a ground-truth box index to every anchor.

    Two passes:
      1. Every anchor whose best IoU with any ground-truth box is at least
         ``iou_threshold`` is mapped to that box.
      2. Greedy matching: the highest remaining (anchor, box) IoU pair is
         matched, then that anchor's row and that box's column are discarded.
         This repeats min(num_anchors, num_gt_boxes) times, so every
         ground-truth box gets an anchor when there are enough anchors.

    ground_truth: (num_gt_boxes, 4) corner-format boxes
    anchors: (num_anchors, 4) corner-format anchors
    iou_threshold: minimum IoU for the first-pass assignment

    Returns:
        anchors_bbox_map: (num_anchors,) long tensor on ``anchors.device`` holding
            the ground-truth index for each anchor, or -1 when unassigned.
        jaccard: the (num_anchors, num_gt_boxes) IoU matrix, with the rows and
            columns consumed by pass 2 overwritten with -1.
    """
    num_anchors, num_gt_boxes = anchors.shape[0], ground_truth.shape[0]
    jaccard = BoxIou(anchors, ground_truth)

    anchors_bbox_map = torch.full((num_anchors,), -1, dtype=torch.long, device=anchors.device)
    # No objects (or no anchors): nothing to match, and max/argmax over an
    # empty dimension would raise.
    if num_gt_boxes == 0 or num_anchors == 0:
        return anchors_bbox_map, jaccard

    # Pass 1: threshold on each anchor's best IoU.
    max_ious, indices = torch.max(jaccard, dim=1)
    above_threshold = max_ious >= iou_threshold
    anchors_bbox_map[above_threshold] = indices[above_threshold]

    # Pass 2: greedy one-to-one matching. Bounded by the number of anchors too,
    # otherwise once every row is discarded argmax would hit an all -1 matrix
    # and overwrite a valid match with box 0.
    for _ in range(min(num_gt_boxes, num_anchors)):
        # Decode the flat argmax into (row, col) with exact integer arithmetic.
        anc_idx, box_idx = divmod(int(torch.argmax(jaccard)), num_gt_boxes)
        anchors_bbox_map[anc_idx] = box_idx
        jaccard[:, box_idx] = -1
        jaccard[anc_idx, :] = -1

    return anchors_bbox_map, jaccard

### **Post Prediction**

In [6]:
def offset_boxes(anchors, asigned_bboxes, sigma_xy=0.1, sigma_wh=0.2, mean=0, eps=1e-6):
    """
    Encode assigned boxes as normalized offsets relative to their anchors.

    With (cx, cy, w, h) the center form of each box:
        offset_xy = ((c_box - c_anchor) / wh_anchor - mean) / sigma_xy
        offset_wh = (log(eps + wh_box / wh_anchor) - mean) / sigma_wh

    anchors: (N, 4) corner-format anchors
    asigned_bboxes: (N, 4) corner-format boxes, row i assigned to anchors[i]
    eps: keeps the log finite when an assigned box has zero width/height
         (e.g. the all-zero placeholder rows of unassigned anchors)

    Returns an (N, 4) tensor of (dx, dy, dw, dh).
    """
    center_anchors = box_corner_to_center(anchors)
    center_asigned_bboxes = box_corner_to_center(asigned_bboxes)
    offset_xy = ((center_asigned_bboxes[:, :2] - center_anchors[:, :2]) / center_anchors[:, 2:] - mean) / sigma_xy
    # Subtract `mean` after the log so it normalizes the encoded value, the
    # same way it is applied to the x/y half above.
    offset_wh = (torch.log(eps + center_asigned_bboxes[:, 2:] / center_anchors[:, 2:]) - mean) / sigma_wh
    return torch.cat([offset_xy, offset_wh], dim=1)


def multibox_target(anchors, labels):
    """
    Build per-anchor training targets for a batch of labelled images.

    anchors: (num_anchors, 4) or (1, num_anchors, 4) corner-format anchors
    labels: (batch_size, num_objects, 5); column 0 is the class id and
            columns 1-4 the corner-format box

    Returns (bbox_offset, bbox_mask, class_labels):
        bbox_offset: (batch_size, num_anchors * 4) encoded offsets, zeroed for
            anchors without an assigned box
        bbox_mask: (batch_size, num_anchors * 4) float, 1.0 where the anchor
            has an assigned box and 0.0 otherwise
        class_labels: (batch_size, num_anchors) long, 0 for anchors with no
            assigned box, otherwise the box's class id + 1

    All targets are created on ``anchors.device``.
    """
    batch_size, anchors = labels.shape[0], anchors.squeeze(0)
    device, num_anchors = anchors.device, anchors.shape[0]
    batch_offset, batch_mask, batch_class_labels = [], [], []

    for i in range(batch_size):
        label = labels[i]
        anchors_bbox_map, _ = assign_anchor_to_bbox(label[:, 1:], anchors)
        anchors_bbox_map = anchors_bbox_map.to(device)
        positive = anchors_bbox_map >= 0
        bbox_mask = positive.float().unsqueeze(-1).repeat(1, 4)

        class_labels = torch.zeros(num_anchors, dtype=torch.long, device=device)
        assigned_bb = torch.zeros((num_anchors, 4), dtype=torch.float32, device=device)

        # Gather the matched ground-truth rows for the positive anchors only.
        bb_idx = anchors_bbox_map[positive]
        class_labels[positive] = label[bb_idx, 0].long() + 1
        assigned_bb[positive] = label[bb_idx, 1:]

        # Offsets of unassigned anchors are meaningless, so mask them to zero.
        offset = offset_boxes(anchors, assigned_bb) * bbox_mask
        batch_offset.append(offset.reshape(-1))
        batch_mask.append(bbox_mask.reshape(-1))
        batch_class_labels.append(class_labels)

    bbox_offset = torch.stack(batch_offset)
    bbox_mask = torch.stack(batch_mask)
    class_labels = torch.stack(batch_class_labels)
    return (bbox_offset, bbox_mask, class_labels)

## **Training**

### **Prepare dataset**

### **Configuration**

In [29]:
# We have 5 scales. Feature maps follow PyTorch's (batch, channels, height, width)
# layout, which is how CreateAnchorBoxes reads them (shape[-2] is height,
# shape[-1] is width).
depths = [64, 128, 128, 128, 128]
width = [32, 16, 8, 4, 1]
height = [32, 16, 8, 4, 1]

# Random stand-ins for the five feature maps; the anchor generator only reads
# their shapes.
X_multiscale = [
    torch.randn(1, depth, h, w) for depth, h, w in zip(depths, height, width)
]

# Per-scale anchor sizes (in the normalized units of the anchor coordinates);
# larger sizes go with the coarser maps.
sizes = [[0.2,  0.272],
         [0.37, 0.447],
         [0.54, 0.619],
         [0.71, 0.79],
         [0.88, 0.961]]

# Per-scale aspect ratios (width / height).
ratios = [[1, 2, 0.5],
          [1, 2, 0.5],
          [1, 2, 0.5],
          [1, 2, 0.5],
          [1, 2, 0.5]]

# Every (size, ratio) pair becomes one anchor per pixel.
num_anchors_per_pixel = len(sizes[0]) * len(ratios[0])
print(f"Each group (pixel) has {num_anchors_per_pixel} anchor boxes")

Each group (pixel) has 6 anchor boxes


### **Prepare the anchor boxes**

In [88]:
def prepare_anchor_boxes(X_multiscale, sizes, ratios, num_anchors_per_pixel):
    """
    Generate the anchors of every scale and stack them into one tensor.

    Scale ``i`` uses ``CreateAnchorBoxes(X_multiscale[i], sizes[i], ratios[i])``;
    the per-scale (N_i, 4) anchors are concatenated along dim 0 in scale order.
    ``num_anchors_per_pixel`` is only used for the printed expected count.
    """
    per_scale = []

    for idx, feature_map in enumerate(X_multiscale):
        anchors = CreateAnchorBoxes(
            image=feature_map,
            sizes=sizes[idx],
            ratios=ratios[idx],
        ).mapping()
        per_scale.append(anchors)

        # Expected count for this scale, printed next to the actual shape.
        total_anchors = feature_map.shape[2] * feature_map.shape[3] * num_anchors_per_pixel
        print(f"Total anchor boxes are: {total_anchors} -> anchor shape: {anchors.shape}")

    return torch.cat(per_scale, dim=0)


anchors_multiscale = prepare_anchor_boxes(X_multiscale, sizes, ratios, num_anchors_per_pixel)
print(anchors_multiscale.shape)

Total anchor boxes are: 6144 -> anchor shape: torch.Size([6144, 4])
Total anchor boxes are: 1536 -> anchor shape: torch.Size([1536, 4])
Total anchor boxes are: 384 -> anchor shape: torch.Size([384, 4])
Total anchor boxes are: 96 -> anchor shape: torch.Size([96, 4])
Total anchor boxes are: 6 -> anchor shape: torch.Size([6, 4])
torch.Size([8166, 4])


### **Prepare the Network**

**Base Network**

In [67]:
class Downsampling(nn.Module):
    """
    Conv(3x3, padding 1) -> BatchNorm -> ReLU -> MaxPool(2).

    Maps ``in_channels`` to ``out_channels`` and halves the spatial size
    (rounding down for odd sizes).
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        norm = nn.BatchNorm2d(out_channels)
        self.layers = nn.Sequential(conv, norm, nn.ReLU(), nn.MaxPool2d(2))

    def forward(self, x):
        out = self.layers(x)
        return out
        
class BaseNet(nn.Module):
    """
    Backbone made of three ``Downsampling`` stages: 3 -> 16 -> 32 -> 64 channels.

    Each stage halves the spatial size, so the output is 1/8 of the input
    resolution (e.g. 256x256 -> 32x32).
    """
    def __init__(self):
        super().__init__()
        channels = (3, 16, 32, 64)
        stages = [Downsampling(c_in, c_out) for c_in, c_out in zip(channels, channels[1:])]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        features = self.layers(x)
        return features

In [92]:
class ScaleModule(nn.Module):
    """
    One detection scale: a feature transform followed by class and box heads.

    The feature transform (``self.downsample``) is:
      * ``baseNet`` when one is given;
      * else, if ``avg_pool`` is true, global pooling to 1x1. Note that despite
        the flag's name this is ``nn.AdaptiveMaxPool2d`` (max pooling);
      * else a ``Downsampling(in_channel, out_channel)`` block.

    forward(x) returns ``(features, cls_pred, bbox_pred)``. ``cls_pred`` has
    ``num_anchors * (num_classes + 1)`` channels and ``bbox_pred`` has
    ``num_anchors * 4`` channels, both at the spatial size of ``features``
    (3x3 convs with padding 1).
    """
    def __init__(self, in_channel, out_channel, num_classes, num_anchors, baseNet=None, avg_pool=False):
        super().__init__()

        if baseNet is None:
            downsample = nn.AdaptiveMaxPool2d((1, 1)) if avg_pool else Downsampling(in_channel, out_channel)
        else:
            downsample = baseNet
        self.downsample = downsample

        head_kwargs = dict(kernel_size=3, padding=1)
        self.cls_head = nn.Conv2d(out_channel, num_anchors * (num_classes + 1), **head_kwargs)
        self.bbox_head = nn.Conv2d(out_channel, num_anchors * 4, **head_kwargs)

    def forward(self, x):
        features = self.downsample(x)
        return features, self.cls_head(features), self.bbox_head(features)
    

class SSD(nn.Module):
    """
    Single-shot detector over five feature-map scales.

    Scale 1 runs the BaseNet backbone (64 channels, 1/8 resolution). Scales 2-4
    each apply a Downsampling block (128 channels, halving the resolution), and
    scale 5 pools to 1x1 (global max pooling). Every scale has its own class
    and box heads.

    forward(x) returns ``(x_last, cls_preds, bbox_preds)``:
      * ``x_last``: feature map produced by the final scale;
      * ``cls_preds``: (batch, sum_i H_i * W_i * num_anchors * (num_classes + 1));
      * ``bbox_preds``: (batch, sum_i H_i * W_i * num_anchors * 4).
    Each per-scale (N, C, H, W) output is permuted to (N, H, W, C), flattened
    from dim 1, and the scales are concatenated in order. After a forward pass,
    ``self.cls_preds`` / ``self.bbox_preds`` hold the same tensors.
    """
    def __init__(self, num_classes, num_anchors):
        super().__init__()
        self.num_classes = num_classes
        self.cls_preds, self.bbox_preds = [], []
        self.base_net = BaseNet()
        # The backbone doubles as scale 1's feature extractor (weights are shared).
        self.scale_module_1 = ScaleModule(3, 64, num_classes, num_anchors, baseNet=self.base_net)
        self.scale_module_2 = ScaleModule(64, 128, num_classes, num_anchors)
        self.scale_module_3 = ScaleModule(128, 128, num_classes, num_anchors)
        self.scale_module_4 = ScaleModule(128, 128, num_classes, num_anchors)
        self.scale_module_5 = ScaleModule(128, 128, num_classes, num_anchors, avg_pool=True)

    def flatten_pred(self, pred):
        """(N, C, H, W) -> (N, H * W * C), pixel-major with channels innermost."""
        return torch.flatten(pred.permute(0, 2, 3, 1), start_dim=1)

    def concat_preds(self, preds):
        """Flatten each scale's prediction and concatenate them along dim 1."""
        return torch.cat([self.flatten_pred(pred) for pred in preds], dim=1)

    def post_process(self):
        """Replace the per-scale prediction lists with their concatenated tensors."""
        self.cls_preds = self.concat_preds(self.cls_preds)
        self.bbox_preds = self.concat_preds(self.bbox_preds)

    def forward(self, x):
        # Start from fresh lists on every call: post_process() replaces them with
        # tensors, so appending to the previous call's results would raise
        # AttributeError on the second forward pass.
        self.cls_preds, self.bbox_preds = [], []

        scale_modules = (
            self.scale_module_1,
            self.scale_module_2,
            self.scale_module_3,
            self.scale_module_4,
            self.scale_module_5,
        )
        for scale_module in scale_modules:
            x, cls_pred, bbox_pred = scale_module(x)
            self.cls_preds.append(cls_pred)
            self.bbox_preds.append(bbox_pred)

        self.post_process()

        return x, self.cls_preds, self.bbox_preds

In [None]:
# Smoke test: push one random 256x256 RGB image through an SSD with a single
# foreground class and 6 anchors per pixel (matching num_anchors_per_pixel).
X = torch.rand(1, 3, 256, 256)
model = SSD(num_classes=1, num_anchors=6)
# For this input the five scales produce 32x32, 16x16, 8x8, 4x4 and 1x1 maps:
# 1361 pixels x 6 anchors = 8166 anchors (the size of anchors_multiscale), so
# cls_preds should be (1, 8166 * 2) and bbox_preds (1, 8166 * 4).
Y, cls_preds, bbox_preds = model(X)
# Y is the last scale's feature map, expected to be (1, 128, 1, 1).
Y.shape