# Import libraries

In [1]:
%matplotlib inline
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import os
import time
from torchvision import tv_tensors
from torchvision.transforms import v2

In [2]:
class Timer:
    """Record multiple running times."""
    def __init__(self):
        """Defined in :numref:`sec_minibatch_sgd`"""
        self.times = []
        self.start()

    def start(self):
        """Start the timer."""
        self.tik = time.time()

    def stop(self):
        """Stop the timer and record the time in a list."""
        self.times.append(time.time() - self.tik)
        return self.times[-1]

    def avg(self):
        """Return the average time."""
        return sum(self.times) / len(self.times)

    def sum(self):
        """Return the sum of time."""
        return sum(self.times)

    def cumsum(self):
        """Return the accumulated time."""
        return np.array(self.times).cumsum().tolist()

# Dataset

In [3]:
voc_trainset = torchvision.datasets.VOCDetection(
    root='/kaggle/working/', year='2012', image_set='train', download=True)

Downloading http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar to /kaggle/working/VOCtrainval_11-May-2012.tar


100%|██████████| 1999639040/1999639040 [00:06<00:00, 296323674.68it/s]


Extracting /kaggle/working/VOCtrainval_11-May-2012.tar to /kaggle/working/


In [4]:
voc_valset = torchvision.datasets.VOCDetection(
    root='/kaggle/working/', year='2012', image_set='val', download=True)

Using downloaded and verified file: /kaggle/working/VOCtrainval_11-May-2012.tar
Extracting /kaggle/working/VOCtrainval_11-May-2012.tar to /kaggle/working/


In [5]:
VOC_I2N = {0:'aeroplane',
1:'bicycle',
2:'bird',
3:'boat',
4:'bottle',
5:'bus',
6:'car',
7:'cat',
8:'chair',
9:'cow',
10:'diningtable',
11:'dog',
12:'horse',
13:'motorbike',
14:'person',
15:'pottedplant',
16:'sheep',
17:'sofa',
18:'train',
19:'tvmonitor'}
VOC_N2I = {n: i for (i, n) in zip(VOC_I2N.keys(), VOC_I2N.values())}

In [6]:
from tqdm import tqdm
bbox = lambda x: [int(x['xmin']), int(x['ymin']), int(x['xmax']), int(x['ymax'])]
dic = {'file_name':[], 'bbox': [], 'name': []}
for img, annotation in tqdm(voc_trainset):
    for obj in annotation['annotation']['object']:
        dic['file_name'].append(annotation['annotation']['filename'])
        dic['bbox'].append(bbox(obj['bndbox']))
        dic['name'].append(VOC_N2I[obj['name']])

100%|██████████| 5717/5717 [00:16<00:00, 348.76it/s]


In [7]:
import pandas as pd
data = pd.DataFrame(dic)
data.head()

Unnamed: 0,file_name,bbox,name
0,2008_000008.jpg,"[53, 87, 471, 420]",12
1,2008_000008.jpg,"[158, 44, 289, 167]",14
2,2008_000015.jpg,"[270, 1, 378, 176]",4
3,2008_000015.jpg,"[57, 1, 164, 150]",4
4,2008_000019.jpg,"[139, 2, 372, 197]",11


In [8]:
def resize_transform(image, target, max_size):
    w, h = image.size
    w_delta = w - max_size
    h_delta = h - max_size
    if w_delta > 0:
        image, target = v2.Pad([w_delta//2, 0])(image, target)
        w = max_size
    else:
        image, target = v2.CenterCrop([h, max_size])(image, target)
    if h_delta > 0:
        image, target = v2.Pad([0, h_delta//2])(image, target)
        h = max_size
    else:
        image, target = v2.CenterCrop([max_size, h])(image, target)
    image, target = v2.Resize((max_size, max_size))(image, target)
    return image, target

In [9]:
class VOCDetection(torch.utils.data.Dataset):
    def __init__(self, root, data):
        super().__init__()
        self.root = root
        self.data = data
        # self.transforms = transforms
        # self.bbox = lambda x: [int(x['xmin']), int(x['ymin']), int(x['xmax']), int(x['ymax'])]
    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        file_name = row[0]
        img = Image.open(os.path.join(self.root, file_name))
        w, h = img.size
        bbox = row[1]
        bboxes = torch.Tensor(bbox)
        bboxes = tv_tensors.BoundingBoxes(bboxes, format = 'XYXY', canvas_size = (h, w))
        img, bboxes = resize_transform(img, bboxes, 500)
        img = transforms.ToTensor()(img)
        bbox_tensor = bboxes[0]
        bbox = bbox_tensor.tolist()
        name = row[2]
        label = torch.Tensor([name, bbox[0], bbox[1], bbox[2], bbox[3]])
        return img, label.unsqueeze(0)
    def __len__(self):
        return len(self.data)

In [10]:
train_dataset = VOCDetection(root='/kaggle/working/VOCdevkit/VOC2012/JPEGImages', data=data)

In [11]:
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)

In [12]:
imgs, labels = next(iter(train_dataloader))
print(imgs.shape)
print(labels.shape)

torch.Size([8, 3, 500, 500])
torch.Size([8, 1, 5])


  file_name = row[0]
  bbox = row[1]
  name = row[2]


# Model

In [13]:
sizes = [[0.2, 0.272], [0.37, 0.447], [0.54, 0.619], [0.71, 0.79], [0.88, 0.961]]
ratios = [[1, 2, 0.5]] * 5
num_anchors = len(sizes[0]) + len(ratios[0]) - 1 # n + m - 1
print(num_anchors)

4


In [14]:
def multibox_prior(data, sizes, ratios):
    """Generate anchor boxes with different shapes centered on each pixel."""
    in_height, in_width = data.shape[-2:]
    device, num_sizes, num_ratios = data.device, len(sizes), len(ratios)
    boxes_per_pixel = (num_sizes + num_ratios - 1)
    size_tensor = torch.tensor(sizes, device=device)
    ratio_tensor = torch.tensor(ratios, device=device)
    # Offsets are required to move the anchor to the center of a pixel. Since
    # a pixel has height=1 and width=1, we choose to offset our centers by 0.5
    offset_h, offset_w = 0.5, 0.5
    steps_h = 1.0 / in_height  # Scaled steps in y axis
    steps_w = 1.0 / in_width  # Scaled steps in x axis

    # Generate all center points for the anchor boxes
    center_h = (torch.arange(in_height, device=device) + offset_h) * steps_h
    center_w = (torch.arange(in_width, device=device) + offset_w) * steps_w
    shift_y, shift_x = torch.meshgrid(center_h, center_w, indexing='ij')
    shift_y, shift_x = shift_y.reshape(-1), shift_x.reshape(-1)

    # Generate `boxes_per_pixel` number of heights and widths that are later
    # used to create anchor box corner coordinates (xmin, xmax, ymin, ymax)
    w = torch.cat((size_tensor * torch.sqrt(ratio_tensor[0]),
                   sizes[0] * torch.sqrt(ratio_tensor[1:])))\
                   * in_height / in_width  # Handle rectangular inputs
    h = torch.cat((size_tensor / torch.sqrt(ratio_tensor[0]),
                   sizes[0] / torch.sqrt(ratio_tensor[1:])))
    # Divide by 2 to get half height and half width
    anchor_manipulations = torch.stack((-w, -h, w, h)).T.repeat(
                                        in_height * in_width, 1) / 2

    # Each center point will have `boxes_per_pixel` number of anchor boxes, so
    # generate a grid of all anchor box centers with `boxes_per_pixel` repeats
    out_grid = torch.stack([shift_x, shift_y, shift_x, shift_y],
                dim=1).repeat_interleave(boxes_per_pixel, dim=0)
    output = out_grid + anchor_manipulations
    return output.unsqueeze(0)

In [15]:
def cls_predictor(num_inputs, num_anchors, num_classes):
    return nn.Conv2d(num_inputs, num_anchors * (num_classes + 1),
                     kernel_size=3, padding=1)

In [16]:
def bbox_predictor(num_inputs, num_anchors):
    return nn.Conv2d(num_inputs, num_anchors * 4, kernel_size=3, padding=1)

In [17]:
class SSD_ResNet18(nn.Module):
    def __init__(self, num_classes):
        super(SSD_ResNet18, self).__init__()
        self.num_classes = num_classes
        self.resnet18 = torchvision.models.resnet18(pretrained=True)
        self.adaptivemaxpool = nn.AdaptiveMaxPool2d((1,1))
        self.cls_predictor1 = cls_predictor(64, num_anchors, num_classes)
        self.bbox_predictor1 = bbox_predictor(64, num_anchors)
        self.cls_predictor2 = cls_predictor(128, num_anchors, num_classes)
        self.bbox_predictor2 = bbox_predictor(128, num_anchors)
        self.cls_predictor3 = cls_predictor(256, num_anchors, num_classes)
        self.bbox_predictor3 = bbox_predictor(256, num_anchors)
        self.cls_predictor4 = cls_predictor(512, num_anchors, num_classes)
        self.bbox_predictor4 = bbox_predictor(512, num_anchors)
        self.cls_predictor5 = cls_predictor(512, num_anchors, num_classes)
        self.bbox_predictor5 = bbox_predictor(512, num_anchors)

    def forward(self, x):
        bs = x.shape[0]
        out1 = self.resnet18.layer1(self.resnet18.maxpool(self.resnet18.relu(self.resnet18.bn1(self.resnet18.conv1(x)))))
        anchors1 = multibox_prior(out1, sizes[0], ratios[0]).reshape((1, -1))
        cls_preds1 = self.cls_predictor1(out1).permute(0, 2, 3, 1).reshape(bs, -1)
        bbox_preds1 = self.bbox_predictor1(out1).permute(0, 2, 3, 1).reshape(bs, -1)

        out2 = self.resnet18.layer2(out1)
        anchors2 = multibox_prior(out2, sizes[1], ratios[1]).reshape((1, -1))
        cls_preds2 = self.cls_predictor2(out2).permute(0, 2, 3, 1).reshape(bs, -1)
        bbox_preds2 = self.bbox_predictor2(out2).permute(0, 2, 3, 1).reshape(bs, -1)

        out3 = self.resnet18.layer3(out2)
        anchors3 = multibox_prior(out3, sizes[2], ratios[2]).reshape((1, -1))
        cls_preds3 = self.cls_predictor3(out3).permute(0, 2, 3, 1).reshape(bs, -1)
        bbox_preds3 = self.bbox_predictor3(out3).permute(0, 2, 3, 1).reshape(bs, -1)

        out4 = self.resnet18.layer4(out3)
        anchors4 = multibox_prior(out4, sizes[3], ratios[3]).reshape((1, -1))
        cls_preds4 = self.cls_predictor4(out4).permute(0, 2, 3, 1).reshape(bs, -1)
        bbox_preds4 = self.bbox_predictor4(out4).permute(0, 2, 3, 1).reshape(bs, -1)

        out5 = self.adaptivemaxpool(out4)
        anchors5 = multibox_prior(out5, sizes[4], ratios[4]).reshape((1, -1))
        cls_preds5 = self.cls_predictor5(out5).permute(0, 2, 3, 1).reshape(bs, -1)
        bbox_preds5 = self.bbox_predictor5(out5).permute(0, 2, 3, 1).reshape(bs, -1)

        anchors = torch.cat([anchors1, anchors2, anchors3, anchors4, anchors5], dim = 1).reshape(1, -1, 4)
        cls_preds = torch.cat([cls_preds1, cls_preds2, cls_preds3, cls_preds4, cls_preds5], dim = 1).reshape(bs, -1, self.num_classes+1)
        bbox_preds = torch.cat([bbox_preds1, bbox_preds2, bbox_preds3, bbox_preds4, bbox_preds5], dim = 1).reshape(bs, -1)
        return anchors, cls_preds, bbox_preds

In [18]:
net = SSD_ResNet18(num_classes=20)
X = torch.zeros((8, 3, 500, 500))
anchors, cls_preds, bbox_preds = net(X)

print('output anchors:', anchors.shape)
print('output class preds:', cls_preds.shape)
print('output bbox preds:', bbox_preds.shape)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 206MB/s]


output anchors: torch.Size([1, 83500, 4])
output class preds: torch.Size([8, 83500, 21])
output bbox preds: torch.Size([8, 334000])


# Training

In [19]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
!pip install torchsummary
from torchsummary import summary
summary(net.to(device), (3, 500, 500))

Collecting torchsummary
  Downloading torchsummary-1.5.1-py3-none-any.whl.metadata (296 bytes)
Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 250, 250]           9,408
       BatchNorm2d-2         [-1, 64, 250, 250]             128
              ReLU-3         [-1, 64, 250, 250]               0
         MaxPool2d-4         [-1, 64, 125, 125]               0
            Conv2d-5         [-1, 64, 125, 125]          36,864
       BatchNorm2d-6         [-1, 64, 125, 125]             128
              ReLU-7         [-1, 64, 125, 125]               0
            Conv2d-8         [-1, 64, 125, 125]          36,864
       BatchNorm2d-9         [-1, 64, 125, 125]             128
             ReLU-10         [-1, 64, 125, 125]          

In [20]:
def box_iou(boxes1, boxes2):
    """Compute pairwise IoU across two lists of anchor or bounding boxes."""
    box_area = lambda boxes: ((boxes[:, 2] - boxes[:, 0]) *
                              (boxes[:, 3] - boxes[:, 1]))
    # Shape of `boxes1`, `boxes2`, `areas1`, `areas2`: (no. of boxes1, 4),
    # (no. of boxes2, 4), (no. of boxes1,), (no. of boxes2,)
    areas1 = box_area(boxes1)
    areas2 = box_area(boxes2)
    # Shape of `inter_upperlefts`, `inter_lowerrights`, `inters`: (no. of
    # boxes1, no. of boxes2, 2)
    inter_upperlefts = torch.max(boxes1[:, None, :2], boxes2[:, :2])
    inter_lowerrights = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])
    inters = (inter_lowerrights - inter_upperlefts).clamp(min=0)
    # Shape of `inter_areas` and `union_areas`: (no. of boxes1, no. of boxes2)
    inter_areas = inters[:, :, 0] * inters[:, :, 1]
    union_areas = areas1[:, None] + areas2 - inter_areas
    return inter_areas / union_areas

In [21]:
def assign_anchor_to_bbox(ground_truth, anchors, device, iou_threshold=0.5):
    """Assign closest ground-truth bounding boxes to anchor boxes."""
    num_anchors, num_gt_boxes = anchors.shape[0], ground_truth.shape[0]
    # Element x_ij in the i-th row and j-th column is the IoU of the anchor
    # box i and the ground-truth bounding box j
    jaccard = box_iou(anchors, ground_truth)
    # Initialize the tensor to hold the assigned ground-truth bounding box for
    # each anchor
    anchors_bbox_map = torch.full((num_anchors,), -1, dtype=torch.long,
                                  device=device)
    # Assign ground-truth bounding boxes according to the threshold
    max_ious, indices = torch.max(jaccard, dim=1)
    anc_i = torch.nonzero(max_ious >= iou_threshold).reshape(-1)
    box_j = indices[max_ious >= iou_threshold]
    anchors_bbox_map[anc_i] = box_j
    col_discard = torch.full((num_anchors,), -1)
    row_discard = torch.full((num_gt_boxes,), -1)
    for _ in range(num_gt_boxes):
        max_idx = torch.argmax(jaccard)  # Find the largest IoU
        box_idx = (max_idx % num_gt_boxes).long()
        anc_idx = (max_idx / num_gt_boxes).long()
        anchors_bbox_map[anc_idx] = box_idx
        jaccard[:, box_idx] = col_discard
        jaccard[anc_idx, :] = row_discard
    return anchors_bbox_map

In [22]:
def box_corner_to_center(boxes):
    """ Convert from (upper-left, lower-right) to (center, width, height)."""
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:,3]
    cx = (x1 + x2) / 2
    cy = (y1 + y2) / 2
    w = x2 - x1
    h = y2 - y1
    boxes = torch.stack((cx, cy, w, h), axis=-1)
    return boxes

In [23]:
def offset_boxes(anchors, assigned_bb, eps=1e-6):
    """Transform for anchor box offsets."""
    c_anc = box_corner_to_center(anchors)
    c_assigned_bb = box_corner_to_center(assigned_bb)
    offset_xy = 10 * (c_assigned_bb[:, :2] - c_anc[:, :2]) / c_anc[:, 2:]
    offset_wh = 5 * torch.log(eps + c_assigned_bb[:, 2:] / c_anc[:, 2:])
    offset = torch.cat([offset_xy, offset_wh], axis=1)
    return offset

In [24]:
def multibox_target(anchors, labels):
    """Label anchor boxes using ground-truth bounding boxes."""
    batch_size, anchors = labels.shape[0], anchors.squeeze(0)
    batch_offset, batch_mask, batch_class_labels = [], [], []
    device, num_anchors = anchors.device, anchors.shape[0]
    for i in range(batch_size):
        label = labels[i, :, :]
        anchors_bbox_map = assign_anchor_to_bbox(
            label[:, 1:], anchors, device)
        bbox_mask = ((anchors_bbox_map >= 0).float().unsqueeze(-1)).repeat(1, 4)
        # Initialize class labels and assigned bounding box coordinates with
        # zeros
        class_labels = torch.zeros(num_anchors, dtype=torch.long,
                                   device=device)
        assigned_bb = torch.zeros((num_anchors, 4), dtype=torch.float32,
                                  device=device)
        # Label classes of anchor boxes using their assigned ground-truth
        # bounding boxes. If an anchor box is not assigned any, we label its
        # class as background (the value remains zero)
        indices_true = torch.nonzero(anchors_bbox_map >= 0)
        bb_idx = anchors_bbox_map[indices_true]
        class_labels[indices_true] = label[bb_idx, 0].long() + 1
        assigned_bb[indices_true] = label[bb_idx, 1:]
        # Offset transformation
        offset = offset_boxes(anchors, assigned_bb) * bbox_mask
        batch_offset.append(offset.reshape(-1))
        batch_mask.append(bbox_mask.reshape(-1))
        batch_class_labels.append(class_labels)
    bbox_offset = torch.stack(batch_offset)
    bbox_mask = torch.stack(batch_mask)
    class_labels = torch.stack(batch_class_labels)
    return (bbox_offset, bbox_mask, class_labels)

In [25]:
trainer = torch.optim.NAdam(net.parameters(), lr=0.0001, weight_decay=5e-4)

In [26]:
cls_loss = nn.CrossEntropyLoss(reduction='sum')
bbox_loss = nn.L1Loss(reduction='sum')

def calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks):
    batch_size, num_classes = cls_preds.shape[0], cls_preds.shape[2]
    # bbox_loss
    bbox = bbox_loss(bbox_preds * bbox_masks,
                     bbox_labels * bbox_masks)
    cls_mask = bbox_masks.reshape(batch_size, -1, 4)
    cls_mask = cls_mask.sum(dim = -1).reshape(-1)
    foreground_idxs = torch.where(cls_mask > 0)[0]
    num_foreground = foreground_idxs.shape[0]

    background_idxs = torch.where(cls_mask == 0)[0]
    keep_bg_idxs = background_idxs[torch.rand(*background_idxs.shape)>0.95]
    cls_mask[keep_bg_idxs] = 4
    cls = cls_loss(cls_preds.reshape(-1, num_classes)[cls_mask>0],
                   cls_labels.reshape(-1)[cls_mask>0])
    return cls/num_foreground, bbox/num_foreground

In [27]:
def cls_eval(cls_preds, cls_labels):
    # Because the class prediction results are on the final dimension,
    # `argmax` needs to specify this dimension
    return float((cls_preds.argmax(dim=-1).type(
        cls_labels.dtype) == cls_labels).sum())

def bbox_eval(bbox_preds, bbox_labels, bbox_masks):
    return float((torch.abs((bbox_labels - bbox_preds) * bbox_masks)).sum())

In [28]:
num_epochs, timer = 10, Timer()
net = net.to(device)

In [30]:
metric = {'train_loss': [],
          'cls_loss': [],
          'bbox_loss': []}
for epoch in range(num_epochs):
    # Sum of training accuracy, no. of examples in sum of training accuracy,
    # Sum of absolute error, no. of examples in sum of absolute error
    net.train()
    for features, target in train_dataloader:
        timer.start()
        trainer.zero_grad()
        X, Y = features.to(device), target.to(device)
        # Generate multiscale anchor boxes and predict their classes and
        # offsets
        anchors, cls_preds, bbox_preds = net(X)
        # Label the classes and offsets of these anchor boxes
        bbox_labels, bbox_masks, cls_labels = multibox_target(anchors, Y)
        # Calculate the loss function using the predicted and labeled values
        # of the classes and offsets
        cls_l, bbox_l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels,
                      bbox_masks)
        l = 0.1*cls_l + 0.9*bbox_l
        l.backward()
        trainer.step()
        metric['train_loss'].append(l.item())
        metric['cls_loss'].append(cls_l.item())
        metric['bbox_loss'].append(bbox_l.item())
        # print('Loss: %.2f Cls_loss: %.2f Bbox_loss %.2f' %(l.item(), cls_l.item(), bbox_l.item()))
    print(f'Epoch {epoch}: Loss: %.2f Cls_loss: %.2f Bbox_loss %.2f'%(sum(metric['train_loss'][-len(train_dataloader):])/len(train_dataloader),
                                                                      sum(metric['cls_loss'][-len(train_dataloader):])/len(train_dataloader),
                                                                      sum(metric['bbox_loss'][-len(train_dataloader):])/len(train_dataloader)))

  file_name = row[0]
  bbox = row[1]
  name = row[2]


Epoch 0: Loss: 6887.14 Cls_loss: 2.76 Bbox_loss 7652.07
Epoch 1: Loss: 6888.20 Cls_loss: 2.89 Bbox_loss 7653.23
Epoch 2: Loss: 6888.06 Cls_loss: 3.03 Bbox_loss 7653.06
Epoch 3: Loss: 6887.45 Cls_loss: 3.10 Bbox_loss 7652.38
Epoch 4: Loss: 6888.34 Cls_loss: 3.07 Bbox_loss 7653.37
Epoch 5: Loss: 6886.87 Cls_loss: 3.07 Bbox_loss 7651.73
Epoch 6: Loss: 6886.46 Cls_loss: 3.05 Bbox_loss 7651.28
Epoch 7: Loss: 6886.37 Cls_loss: 3.06 Bbox_loss 7651.18
Epoch 8: Loss: 6884.96 Cls_loss: 3.06 Bbox_loss 7649.62
Epoch 9: Loss: 6886.11 Cls_loss: 3.09 Bbox_loss 7650.89
