In [None]:
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms

import random
import numpy as np
import cv2
import xml.etree.ElementTree as ET
import os


class VOCDataset(Dataset):
    """Detection dataset backed by Pascal-VOC style XML annotations.

    For every name in ``file_names`` the annotation
    ``{base_dir}/{train|valid}/{name}.xml`` is parsed at construction time and
    the matching ``.jpg`` path is recorded.  Files without any ``object``
    element are skipped.  ``__getitem__`` loads the image with OpenCV,
    optionally augments it, and returns the image tensor together with a
    ``[S, S, 5 * B + C]`` grid target produced by :meth:`encode`.
    """

    def __init__(self, is_train, file_names, base_dir, image_size=448, grid_size=7, num_bboxes=2, num_classes=20):
        """Index the annotation files.

        Args:
            is_train: selects the ``train`` sub-directory and enables
                augmentation when true; otherwise ``valid`` is used.
            file_names: iterable of file stems (no extension).
            base_dir: directory containing the ``train``/``valid`` folders.
            image_size: side length the image is resized to.
            grid_size: number of grid cells per side (S).
            num_bboxes: number of box slots per cell (B).
            num_classes: length of the one-hot class vector (C).
        """
        self.is_train = is_train
        self.image_size = image_size

        self.S = grid_size
        self.B = num_bboxes
        self.C = num_classes

        # Per-channel mean; reversed to BGR in random_shift to fill the
        # border of shifted images.
        mean = [122.67891434, 116.66876762, 104.00698793]
        self.mean = np.array(mean, dtype=np.float32)

        self.to_tensor = transforms.ToTensor()

        self.paths, self.boxes, self.labels = [], [], []
        self._load_annotations(file_names, base_dir)

        self.num_samples = len(self.paths)

    def _load_annotations(self, file_names, base_dir):
        """Fill ``paths``/``boxes``/``labels`` with one entry per annotated image."""
        split = 'train' if self.is_train else 'valid'
        for name in file_names:
            label_path = f"{base_dir}/{split}/{name}.xml"
            image_path = f"{base_dir}/{split}/{name}.jpg"

            box, label = self._parse_annotation(label_path)

            # Register the image once, after ALL of its objects were read.
            # (Appending inside the per-object loop would register the same
            # image once per object, each time with a partial box list.)
            if box:
                self.boxes.append(torch.tensor(box, dtype=torch.float32))
                self.labels.append(torch.tensor(label, dtype=torch.long))
                self.paths.append(image_path)

    @staticmethod
    def _parse_annotation(label_path):
        """Return ``(boxes, labels)`` lists read from one VOC XML file."""
        root = ET.parse(label_path).getroot()

        box, label = [], []
        for obj in root.findall('object'):
            name = obj.find('name').text
            # Two-class mapping: 'worker' -> 0, any other name -> 1.
            # Customize this when the dataset has different classes.
            class_idx = 0 if name == 'worker' else 1

            bbox = obj.find('bndbox')
            box.append([int(float(bbox.find(tag).text)) for tag in ('xmin', 'ymin', 'xmax', 'ymax')])
            label.append(class_idx)

        return box, label

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``.

        Raises:
            OSError: if OpenCV could not read the image file.
        """
        path = self.paths[idx]
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"Could not read image: {path}")
        boxes = self.boxes[idx].clone()  # [n, 4]
        labels = self.labels[idx].clone()  # [n,]

        if self.is_train:
            img, boxes = self.random_flip(img, boxes)
            img, boxes = self.random_scale(img, boxes)

            img = self.random_blur(img)
            img = self.random_brightness(img)
            img = self.random_hue(img)
            img = self.random_saturation(img)

            img, boxes, labels = self.random_shift(img, boxes, labels)
            img, boxes, labels = self.random_crop(img, boxes, labels)

        h, w, _ = img.shape
        # Normalize (x1, y1, x2, y2) w.r.t. image width/height.
        boxes = boxes / torch.tensor([w, h, w, h], dtype=boxes.dtype)
        target = self.encode(boxes, labels)  # [S, S, 5 x B + C]

        img = cv2.resize(img, dsize=(self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # assuming the model is pretrained with RGB images.
        img = np.ascontiguousarray(img, dtype=np.float32)
        img /= 255.0  # scale pixel values to [0, 1]
        img = self.to_tensor(img)

        return img, target

    def __len__(self):
        """Number of images that have at least one annotated object."""
        return self.num_samples

    def encode(self, boxes, labels):
        """Encode normalized boxes into the grid target.

        Args:
            boxes: ``[n, 4]`` tensor of (x1, y1, x2, y2) divided by image
                width/height.
            labels: ``[n]`` tensor of class indices.

        Returns:
            ``[S, S, 5 * B + C]`` tensor.  For the cell containing a box
            center, every one of the B slots holds (x, y, w, h, 1) with x/y
            relative to the cell, followed by a one-hot class vector.
        """
        S, B, C = self.S, self.B, self.C
        N = 5 * B + C

        target = torch.zeros(S, S, N)
        cell_size = 1.0 / float(S)
        boxes_wh = boxes[:, 2:] - boxes[:, :2]  # width and height for each box, [n, 2]
        boxes_xy = (boxes[:, 2:] + boxes[:, :2]) / 2.0  # center x & y for each box, [n, 2]
        for b in range(boxes.size(0)):
            xy, wh, label = boxes_xy[b], boxes_wh[b], int(labels[b])

            # ceil() - 1 maps a center lying exactly on 0 to -1, which would
            # silently wrap around to the last cell; keep the index in-grid.
            ij = ((xy / cell_size).ceil() - 1.0).clamp(min=0, max=S - 1)
            i, j = int(ij[0]), int(ij[1])  # i: column (x) index, j: row (y) index of the cell.
            x0y0 = ij * cell_size  # x & y of the cell left-top corner.
            xy_normalized = (xy - x0y0) / cell_size  # x & y of the box on the cell, normalized from 0.0 to 1.0.

            # TBM, remove redundant dimensions from target tensor.
            # To remove these, loss implementation also has to be modified.
            for k in range(B):
                offset = 5 * k
                target[j, i, offset:offset + 2] = xy_normalized
                target[j, i, offset + 2:offset + 4] = wh
                target[j, i, offset + 4] = 1.0
            target[j, i, 5 * B + label] = 1.0

        return target

    @staticmethod
    def random_flip(img, boxes):
        """Mirror image and boxes horizontally with probability 0.5.

        ``boxes`` is modified in place when the flip is applied.
        """
        if random.random() < 0.5:
            return img, boxes

        h, w, _ = img.shape

        img = np.fliplr(img)

        x1, x2 = boxes[:, 0], boxes[:, 2]
        x1_new = w - x2
        x2_new = w - x1
        boxes[:, 0], boxes[:, 2] = x1_new, x2_new

        return img, boxes

    @staticmethod
    def random_scale(img, boxes):
        """Stretch the width by a factor in [0.8, 1.2] with probability 0.5."""
        if random.random() < 0.5:
            return img, boxes

        scale = random.uniform(0.8, 1.2)
        h, w, _ = img.shape
        img = cv2.resize(img, dsize=(int(w * scale), h), interpolation=cv2.INTER_LINEAR)
        # Only the x coordinates change because the height is untouched.
        boxes = boxes * torch.tensor([scale, 1.0, scale, 1.0], dtype=boxes.dtype)

        return img, boxes

    @staticmethod
    def random_blur(bgr):
        """Apply a box blur with a random kernel size with probability 0.5."""
        if random.random() < 0.5:
            return bgr

        ksize = random.choice([2, 3, 4, 5])
        bgr = cv2.blur(bgr, (ksize, ksize))
        return bgr

    @staticmethod
    def _scale_hsv_channel(bgr, channel, low, high):
        """Multiply one HSV channel by a factor drawn from [low, high]."""
        hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
        planes = list(cv2.split(hsv))
        adjust = random.uniform(low, high)
        planes[channel] = np.clip(planes[channel] * adjust, 0, 255).astype(hsv.dtype)
        return cv2.cvtColor(cv2.merge(planes), cv2.COLOR_HSV2BGR)

    @staticmethod
    def random_brightness(bgr):
        """Scale the V channel by a factor in [0.5, 1.5] with probability 0.5."""
        if random.random() < 0.5:
            return bgr
        return VOCDataset._scale_hsv_channel(bgr, 2, 0.5, 1.5)

    @staticmethod
    def random_hue(bgr):
        """Scale the H channel by a factor in [0.8, 1.2] with probability 0.5."""
        if random.random() < 0.5:
            return bgr
        # NOTE(review): the clip range is 0..255 as in the other channels;
        # confirm this is intended for OpenCV's 8-bit hue representation.
        return VOCDataset._scale_hsv_channel(bgr, 0, 0.8, 1.2)

    @staticmethod
    def random_saturation(bgr):
        """Scale the S channel by a factor in [0.5, 1.5] with probability 0.5."""
        if random.random() < 0.5:
            return bgr
        return VOCDataset._scale_hsv_channel(bgr, 1, 0.5, 1.5)

    def random_shift(self, img, boxes, labels):
        """Translate the image by up to 20% per axis with probability 0.5.

        Uncovered pixels are filled with the dataset mean.  Boxes whose center
        leaves the image are dropped; if none would remain, the inputs are
        returned unchanged.
        """
        if random.random() < 0.5:
            return img, boxes, labels

        center = (boxes[:, 2:] + boxes[:, :2]) / 2.0

        h, w, c = img.shape
        img_out = np.zeros((h, w, c), dtype=img.dtype)
        mean_bgr = self.mean[::-1]
        img_out[:, :] = mean_bgr

        dx = random.uniform(-w * 0.2, w * 0.2)
        dy = random.uniform(-h * 0.2, h * 0.2)
        dx, dy = int(dx), int(dy)

        if dx >= 0 and dy >= 0:
            img_out[dy:, dx:] = img[:h - dy, :w - dx]
        elif dx >= 0 and dy < 0:
            img_out[:h + dy, dx:] = img[-dy:, :w - dx]
        elif dx < 0 and dy >= 0:
            img_out[dy:, :w + dx] = img[:h - dy, -dx:]
        elif dx < 0 and dy < 0:
            img_out[:h + dy, :w + dx] = img[-dy:, -dx:]

        center = center + torch.tensor([dx, dy], dtype=center.dtype)  # [n, 2]
        mask_x = (center[:, 0] >= 0) & (center[:, 0] < w)  # [n,]
        mask_y = (center[:, 1] >= 0) & (center[:, 1] < h)  # [n,]
        mask = mask_x & mask_y  # [n,], boxes whose center stays within the image after shift.

        boxes_out = boxes[mask]  # [m, 4]
        if len(boxes_out) == 0:
            return img, boxes, labels

        boxes_out = boxes_out + torch.tensor([dx, dy, dx, dy], dtype=boxes_out.dtype)
        boxes_out[:, 0::2].clamp_(min=0, max=w)  # x1, x2
        boxes_out[:, 1::2].clamp_(min=0, max=h)  # y1, y2

        labels_out = labels[mask]

        return img_out, boxes_out, labels_out

    def random_crop(self, img, boxes, labels):
        """Crop a random window of 60-100% per side with probability 0.5.

        Boxes whose center falls outside the window are dropped; if none
        would remain, the inputs are returned unchanged.
        """
        if random.random() < 0.5:
            return img, boxes, labels

        center = (boxes[:, 2:] + boxes[:, :2]) / 2.0

        h_orig, w_orig, _ = img.shape
        h = random.uniform(0.6 * h_orig, h_orig)
        w = random.uniform(0.6 * w_orig, w_orig)
        y = random.uniform(0, h_orig - h)
        x = random.uniform(0, w_orig - w)
        h, w, x, y = int(h), int(w), int(x), int(y)

        center = center - torch.tensor([x, y], dtype=center.dtype)  # [n, 2]
        mask_x = (center[:, 0] >= 0) & (center[:, 0] < w)  # [n,]
        mask_y = (center[:, 1] >= 0) & (center[:, 1] < h)  # [n,]
        mask = mask_x & mask_y  # [n,], boxes whose center stays within the crop window.

        boxes_out = boxes[mask]  # [m, 4]
        if len(boxes_out) == 0:
            return img, boxes, labels

        boxes_out = boxes_out - torch.tensor([x, y, x, y], dtype=boxes_out.dtype)
        boxes_out[:, 0::2].clamp_(min=0, max=w)  # x1, x2
        boxes_out[:, 1::2].clamp_(min=0, max=h)  # y1, y2

        labels_out = labels[mask]
        img_out = img[y:y + h, x:x + w, :]

        return img_out, boxes_out, labels_out


def test():
    """Smoke-test the dataset by printing the shapes of up to 100 samples.

    Reads the ``.jpg`` file names found in ``data/train``, builds a training
    ``VOCDataset`` and iterates a single-sample ``DataLoader`` over it.
    """
    from itertools import islice

    from torch.utils.data import DataLoader

    base_dir = 'data'

    train_names = [f.rsplit('.jpg', 1)[0] for f in os.listdir(base_dir + '/train') if f.endswith('.jpg')]

    dataset = VOCDataset(is_train=True, file_names=train_names, base_dir=base_dir)
    data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

    # islice stops cleanly when the dataset holds fewer than 100 samples,
    # where repeated next() calls would raise StopIteration.
    for img, target in islice(data_loader, 100):
        print(img.size(), target.size())


# Run the dataset smoke test when this module is executed directly
# (i.e. when __name__ is '__main__'), not when it is imported.
if __name__ == '__main__':
    test()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import config

class Loss(nn.Module):
    """YOLO-style sum-of-squares detection loss over an S x S grid.

    Both prediction and target are ``[batch, S, S, 5 * B + C]`` tensors laid
    out as B consecutive ``(x, y, w, h, conf)`` slots followed by C class
    scores.  A cell is treated as containing an object when the confidence of
    its first target slot is positive.
    """

    def __init__(self, fs=config.S, nb=config.B, nc=config.C, lambda_coord=5.0, lambda_noobj=0.5):
        """Store grid size, box/class counts and the two loss weights.

        Args:
            fs: grid size S.
            nb: number of box slots per cell, B.
            nc: number of classes, C.
            lambda_coord: weight of the xy and wh terms.
            lambda_noobj: weight of the no-object confidence term.
        """
        super().__init__()

        self.FS = fs
        self.NB = nb
        self.NC = nc
        self.lambda_coord = lambda_coord
        self.lambda_noobj = lambda_noobj

    def compute_iou(self, box1, box2):
        """Pairwise IoU between ``box1`` [n, 4] and ``box2`` [m, 4] (x1, y1, x2, y2).

        Returns:
            ``[n, m]`` tensor of intersection-over-union values.
        """

        def box_area(box):
            # box = 4xn
            return (box[2] - box[0]) * (box[3] - box[1])

        area1 = box_area(box1.T)
        area2 = box_area(box2.T)

        inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
        return inter / (area1[:, None] + area2 - inter)

    @staticmethod
    def _to_xyxy(boxes, grid_size):
        """Convert ``(x, y, w, h, ...)`` rows to corner form ``(x1, y1, x2, y2)``.

        x/y are divided by the grid size before the half extents are applied,
        matching how both predictions and targets are compared per cell.
        """
        centers = boxes[:, :2] / float(grid_size)
        half_wh = 0.5 * boxes[:, 2:4]
        return torch.cat((centers - half_wh, centers + half_wh), dim=1)

    def forward(self, prediction, target_tensor):
        """Return the batch-averaged loss.

        Args:
            prediction: ``[batch, S, S, 5 * B + C]`` network output.
            target_tensor: tensor of the same shape holding the ground truth.

        Returns:
            Scalar tensor: weighted sum of the xy, wh, object-confidence,
            no-object-confidence and class terms, divided by the batch size.
        """
        S, B, C = self.FS, self.NB, self.NC
        N = 5 * B + C  # 5 = len([x, y, w, h, conf])

        batch_size = prediction.size(0)
        coord_mask = target_tensor[:, :, :, 4] > 0
        noobj_mask = target_tensor[:, :, :, 4] == 0
        coord_mask = coord_mask.unsqueeze(-1).expand_as(target_tensor)
        noobj_mask = noobj_mask.unsqueeze(-1).expand_as(target_tensor)

        coord_pred = prediction[coord_mask].view(-1, N)
        bbox_pred = coord_pred[:, :5 * B].contiguous().view(-1, 5)
        class_pred = coord_pred[:, 5 * B:]

        coord_target = target_tensor[coord_mask].view(-1, N)
        bbox_target = coord_target[:, :5 * B].contiguous().view(-1, 5)
        class_target = coord_target[:, 5 * B:]

        # Compute loss for the cells with no object bbox: only the confidence
        # column of each of the B slots (indices 4, 9, ...) contributes.
        noobj_pred = prediction[noobj_mask].view(-1, N)
        noobj_target = target_tensor[noobj_mask].view(-1, N)
        loss_noobj = F.mse_loss(noobj_pred[:, 4:5 * B:5], noobj_target[:, 4:5 * B:5], reduction='sum')

        # Compute loss for the cells with objects.  The helper tensors are
        # created with *_like so they live on the same device (and, for the
        # IoU targets, have the same dtype) as the inputs.
        coord_response_mask = torch.zeros_like(bbox_target, dtype=torch.bool)
        bbox_target_iou = torch.zeros_like(bbox_target)

        # Choose the predicted bbox having the highest IoU for each target bbox.
        for i in range(0, bbox_target.size(0), B):
            # The IoU only serves as a constant confidence target, so it is
            # computed on detached predictions.
            pred_xyxy = self._to_xyxy(bbox_pred[i:i + B].detach(), S)
            target_xyxy = self._to_xyxy(bbox_target[i:i + 1], S)

            iou = self.compute_iou(pred_xyxy, target_xyxy)  # [B, 1]
            max_iou, max_index = iou.max(0)

            coord_response_mask[i + max_index] = True
            bbox_target_iou[i + max_index, 4] = max_iou

        bbox_pred_response = bbox_pred[coord_response_mask].view(-1, 5)
        bbox_target_response = bbox_target[coord_response_mask].view(-1, 5)
        target_iou = bbox_target_iou[coord_response_mask].view(-1, 5)

        loss_xy = F.mse_loss(bbox_pred_response[:, :2], bbox_target_response[:, :2], reduction='sum')

        # Square roots are compared so that errors on small boxes weigh more.
        loss_wh = F.mse_loss(torch.sqrt(bbox_pred_response[:, 2:4]),
                             torch.sqrt(bbox_target_response[:, 2:4]), reduction='sum')

        loss_obj = F.mse_loss(bbox_pred_response[:, 4], target_iou[:, 4], reduction='sum')

        loss_class = F.mse_loss(class_pred, class_target, reduction='sum')

        # Total loss
        loss = self.lambda_coord * (loss_xy + loss_wh) + loss_obj + self.lambda_noobj * loss_noobj + loss_class
        loss = loss / float(batch_size)

        return loss

In [None]:
import torch
import torch.nn as nn
import config

def pad(k, p):
    """Return the padding to use for a convolution with kernel size ``k``.

    Args:
        k: kernel size, either an int or a sequence of ints (one per
            spatial dimension).
        p: explicit padding; returned unchanged when it is not ``None``.

    Returns:
        ``p`` if given, otherwise half the kernel size (``k // 2``), computed
        per dimension as a tuple when ``k`` is a sequence.
    """
    if p is not None:
        return p
    try:
        return k // 2
    except TypeError:
        # Sequence kernel sizes such as (3, 5) do not support floor division.
        return tuple(side // 2 for side in k)


class Conv(nn.Module):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` and an activation.

    Args:
        c1: input channels.
        c2: output channels.
        k: kernel size.
        s: stride.
        p: padding; when ``None`` it is derived from ``k`` by ``pad``.
        d: dilation.
        g: groups.
        act: truthy selects ``LeakyReLU(0.01)``, falsy selects ``ReLU``.
    """

    def __init__(self, c1, c2, k, s=1, p=None, d=1, g=1, act=True):
        super().__init__()
        padding = pad(k, p)
        self.conv = nn.Conv2d(
            c1,
            c2,
            kernel_size=k,
            stride=s,
            padding=padding,
            dilation=d,
            groups=g,
            bias=False,
        )
        # TODO: check that it is in paper
        self.bn = nn.BatchNorm2d(c2, eps=1e-3, momentum=0.03)
        if act:
            self.act = nn.LeakyReLU(0.01, inplace=True)
        else:
            self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch norm and activation in that order."""
        out = self.conv(x)
        out = self.bn(out)
        return self.act(out)


class GlobalAvgPool2d(nn.Module):
    """Average every channel over all trailing dimensions."""

    def __init__(self):
        super().__init__()

    @staticmethod
    def forward(x):
        """Return a ``[x.size(0), x.size(1)]`` tensor of per-channel means."""
        batch, channels = x.size(0), x.size(1)
        flat = x.view(batch, channels, -1)
        return flat.mean(dim=2)


class Flatten(nn.Module):
    """Collapse every dimension after the first into one."""

    def __init__(self):
        super().__init__()

    @staticmethod
    def forward(x):
        """Return ``x`` viewed as ``[x.size(0), -1]``."""
        batch = x.shape[0]
        return x.view(batch, -1)



class Backbone(nn.Module):
    """Convolutional feature extractor with a classification head on top.

    ``features`` is the stack of ``Conv`` / ``MaxPool2d`` layers ending in
    1024 channels.  ``classifier`` re-uses those same layer objects and
    appends global average pooling and a linear layer with ``num_classes``
    outputs; ``forward`` runs ``classifier``.
    """

    def __init__(self, num_classes=config.C, init_weight=True):
        super().__init__()

        self.features = nn.Sequential(*self._feature_layers())

        # The classifier shares the feature modules instead of copying them.
        self.classifier = nn.Sequential(
            *self.features,
            GlobalAvgPool2d(),
            nn.Linear(1024, num_classes),
        )

        if init_weight:
            self._initialize_weights()

    @staticmethod
    def _feature_layers():
        """Build the feature layers from a compact specification.

        Each tuple is ``(out_channels, kernel_size, stride)`` for one ``Conv``
        whose input width is the previous layer's output width (3 for the
        first layer); ``'M'`` stands for ``nn.MaxPool2d(2, 2)``.
        """
        spec = [
            (64, 7, 2), 'M',
            (192, 3, 1), 'M',
            (128, 1, 1), (256, 3, 1), (256, 1, 1), (512, 3, 1), 'M',
            (256, 1, 1), (512, 3, 1),
            (256, 1, 1), (512, 3, 1),
            (256, 1, 1), (512, 3, 1),
            (256, 1, 1), (512, 3, 1),
            (512, 1, 1), (1024, 3, 1), 'M',
            (512, 1, 1), (1024, 3, 1),
            (512, 1, 1), (1024, 3, 1),
        ]

        layers = []
        in_channels = 3
        for entry in spec:
            if entry == 'M':
                layers.append(nn.MaxPool2d(2, 2))
                continue
            out_channels, kernel, stride = entry
            layers.append(Conv(in_channels, out_channels, kernel, stride))
            in_channels = out_channels
        return layers

    def forward(self, x):
        """Run the classifier (features, global pooling, linear layer)."""
        return self.classifier(x)

    def _initialize_weights(self):
        """Initialise conv, batch-norm and linear parameters in place."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='leaky_relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)

class Head(nn.Module):
    """Detection head: four 1024-channel convolutions, then two linear layers.

    The first linear layer expects ``7 * 7 * 1024`` flattened inputs; the
    last one emits ``feature_size * feature_size * (5 * num_boxes +
    num_classes)`` values squashed by a sigmoid.
    """

    def __init__(self, feature_size, num_boxes, num_classes):
        super().__init__()

        # Only the second convolution down-samples (stride 2).
        strides = (1, 2, 1, 1)
        self.conv = nn.Sequential(*[Conv(1024, 1024, 3, stride) for stride in strides])

        cell_width = 5 * num_boxes + num_classes
        out_features = feature_size * feature_size * cell_width
        self.detect = nn.Sequential(
            Flatten(),
            nn.Linear(7 * 7 * 1024, 4096),
            nn.Dropout(0.5),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Linear(4096, out_features),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the flat per-image detection vector."""
        return self.detect(self.conv(x))


class YOLOv1(nn.Module):
    """Backbone feature extractor plus detection head.

    ``forward`` returns a ``[batch, FS, FS, 5 * NB + NC]`` tensor.
    """

    def __init__(self, fs=config.S, nb=config.B, nc=config.C, pretrained_backbone=False,
                 backbone_weights='model_best.pth.tar'):
        """Build the network, optionally loading backbone weights.

        Args:
            fs: grid size of the output.
            nb: number of boxes per grid cell.
            nc: number of classes.
            pretrained_backbone: when true, load backbone weights from
                ``backbone_weights``.
            backbone_weights: path of a checkpoint holding a ``'state_dict'``
                entry for a ``DataParallel``-wrapped ``Backbone``.

        Raises:
            KeyError: if the checkpoint lacks a parameter the backbone needs.
        """
        super().__init__()

        self.FS = fs
        self.NB = nb
        self.NC = nc

        backbone = Backbone()
        if pretrained_backbone:
            # Wrapped in DataParallel so the expected keys match the ones
            # copied below (presumably the checkpoint was saved from a
            # DataParallel model — verify against the training script).
            darknet = nn.DataParallel(backbone)
            # map_location lets a checkpoint saved on GPU load on a host
            # without that device; load_state_dict copies into place anyway.
            checkpoint = torch.load(backbone_weights, map_location='cpu')
            src_state_dict = checkpoint['state_dict']
            dst_keys = list(darknet.state_dict().keys())

            missing = [k for k in dst_keys if k not in src_state_dict]
            if missing:
                raise KeyError(f'{backbone_weights} is missing backbone weights: {missing}')

            for k in dst_keys:
                print('Loading weight of', k)
            darknet.load_state_dict({k: src_state_dict[k] for k in dst_keys})
            backbone = darknet.module

        self.features = backbone.features
        self.head = Head(fs, nb, nc)

    def forward(self, x):
        """Return predictions reshaped to ``[-1, FS, FS, 5 * NB + NC]``."""
        x = self.features(x)
        x = self.head(x)

        x = x.view(-1, self.FS, self.FS, 5 * self.NB + self.NC)
        return x



In [None]:
import torch
from torch.utils.data import DataLoader

from loaders.data_loader import VOCDataset
from model.model import YOLOv1
from model.loss import Loss

import os
import math
import tqdm
import numpy as np
from collections import defaultdict
import glob

# %% cell 5 code
# Learning rate scheduling
def update_lr(optimizer, epoch, burning_base, burning_exp=4.0, init_lr=0.001, base_lr=0.01):
    """Set the learning rate of every parameter group for ``epoch``.

    Schedule:
        * epoch 0: warm-up from ``init_lr`` towards ``base_lr`` following
          ``burning_base ** burning_exp``;
        * epochs 1-74: ``base_lr``;
        * epochs 75-104: ``0.001``;
        * epoch 105 onwards: ``0.0001``.

    The rate is derived from ``epoch`` alone on every call, so a run that is
    resumed between two milestones still gets the rate of its current stage.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts).
        epoch: zero-based epoch index.
        burning_base: warm-up progress within epoch 0, in [0, 1].
        burning_exp: exponent of the warm-up curve.
        init_lr: learning rate at the very start of warm-up.
        base_lr: learning rate reached at the end of warm-up.
    """
    if epoch == 0:
        lr = init_lr + (base_lr - init_lr) * math.pow(burning_base, burning_exp)
    elif epoch < 75:
        lr = base_lr
    elif epoch < 105:
        lr = 0.001
    else:
        lr = 0.0001

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def get_lr(optimizer):
    """Return the ``'lr'`` of the first parameter group, or ``None`` if there is none."""
    rates = (group['lr'] for group in optimizer.param_groups)
    return next(rates, None)


# %% cell 6 code
# Set up parameters
base_dir = './data'      # root folder holding the train/ and valid/ splits
log_dir = './weights'    # where checkpoints are written
init_lr = 0.001          # learning rate at the start of warm-up
base_lr = 0.01           # learning rate after warm-up
momentum = 0.9
weight_decay = 5.0e-4
num_epochs = 135
batch_size = 64
seed = 42

# Create weights directory
os.makedirs(log_dir, exist_ok=True)

# Seed every random generator in use.  The dataset augmentations draw from
# Python's `random` module, so it is seeded alongside NumPy and torch.
import random
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Check if GPU devices are available
print(f'CUDA DEVICE COUNT: {torch.cuda.device_count()}')

# Check for MPS (Apple Silicon) support first, then CUDA, then fall back to CPU.
# getattr guards against torch builds that do not ship the MPS backend module.
mps_backend = getattr(torch.backends, 'mps', None)
if mps_backend is not None and mps_backend.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f'Using device: {device}')

# %% cell 7 code
# Load YOLO model
net = YOLOv1(pretrained_backbone=False)
net = net.to(device)
if torch.cuda.device_count() > 1:
    # Spread each batch over all visible GPUs.
    net = torch.nn.DataParallel(net)

# Weight decay is scaled by the effective batch size relative to 64.
accumulate = max(round(64 / batch_size), 1)

params = defaultdict()
params['weight_decay'] = weight_decay * (batch_size * accumulate / 64)

# Three parameter groups:
#   pg0 - BatchNorm2d weights,
#   pg1 - all other module weights (the only group given weight decay),
#   pg2 - biases.
pg0, pg1, pg2 = [], [], []
for module_name, module in net.named_modules():
    bias = getattr(module, 'bias', None)
    if isinstance(bias, torch.nn.Parameter):
        pg2.append(bias)

    if isinstance(module, torch.nn.BatchNorm2d):
        pg0.append(module.weight)
        continue
    weight = getattr(module, 'weight', None)
    if isinstance(weight, torch.nn.Parameter):
        pg1.append(weight)

optimizer = torch.optim.SGD(pg0, lr=init_lr, momentum=momentum, nesterov=True)
optimizer.add_param_group({'params': pg1, 'weight_decay': params['weight_decay']})
optimizer.add_param_group({'params': pg2})

# Setup loss
criterion = Loss()

# %% cell 8 code
# Prepare datasets
# File stems (name without the '.jpg' suffix) found in each split folder.
train_names = [
    entry[:-len('.jpg')]
    for entry in os.listdir(f'{base_dir}/train')
    if entry.endswith('.jpg')
]
val_names = [
    entry[:-len('.jpg')]
    for entry in os.listdir(f'{base_dir}/valid')
    if entry.endswith('.jpg')
]

# Create datasets with specific file lists
train_dataset = VOCDataset(is_train=True, base_dir=base_dir, file_names=train_names)

# The batch never exceeds the number of available training samples.
train_batch_size = min(batch_size, len(train_dataset))
train_loader = DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
    num_workers=4,
)

val_dataset = VOCDataset(is_train=False, base_dir=base_dir, file_names=val_names)

# Validation uses half the training batch size, capped by the dataset size
# but never below one.
val_batch_size = min(batch_size // 2, max(1, len(val_dataset)))
val_loader = DataLoader(
    val_dataset,
    batch_size=val_batch_size,
    shuffle=False,
    num_workers=4,
)

print('Number of training images: ', len(train_dataset))
print('Number of validation images: ', len(val_dataset))

# %% cell 9 code
# Training loop
best_val_loss = np.inf

for epoch in range(num_epochs):
    print('\n')
    print(f'Starting epoch {epoch} / {num_epochs}')

    # Training
    net.train()
    total_loss = 0.0   # sum of per-sample losses seen so far in this epoch
    total_batch = 0    # number of samples seen so far in this epoch
    num_train_batches = len(train_loader)
    # Denominator of the warm-up progress; at least 1 so that a loader with a
    # single batch does not divide by zero.
    warmup_steps = max(num_train_batches - 1, 1)
    print(('\n' + '%10s' * 3) % ('epoch', 'loss', 'gpu'))
    progress_bar = tqdm.tqdm(enumerate(train_loader), total=num_train_batches)
    for i, (images, targets) in progress_bar:
        # Update learning rate
        update_lr(optimizer, epoch, float(i) / float(warmup_steps),
                  init_lr=init_lr, base_lr=base_lr)
        lr = get_lr(optimizer)

        # Load data as a batch
        batch_size_this_iter = images.size(0)
        images, targets = images.to(device), targets.to(device)

        # Forward to compute loss
        predictions = net(images)
        loss = criterion(predictions, targets)
        loss_this_iter = loss.item()
        total_loss += loss_this_iter * batch_size_this_iter
        total_batch += batch_size_this_iter

        # Backward to update model weight
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)
        # total_loss is weighted by batch size, so the running mean must be
        # taken over the number of samples, not the number of batches.
        s = ('%10s' + '%10.4g' + '%10s') % ('%g/%g' % (epoch + 1, num_epochs), total_loss / total_batch, mem)
        progress_bar.set_description(s)

    # Validation
    net.eval()
    val_loss = 0.0
    total_batch = 0

    progress_bar = tqdm.tqdm(enumerate(val_loader), total=len(val_loader))
    for i, (images, targets) in progress_bar:
        # Load data as a batch
        batch_size_this_iter = images.size(0)
        images, targets = images.to(device), targets.to(device)

        # Forward to compute validation loss; nothing here needs gradients.
        with torch.no_grad():
            predictions = net(images)
            loss = criterion(predictions, targets)
        loss_this_iter = loss.item()
        val_loss += loss_this_iter * batch_size_this_iter
        total_batch += batch_size_this_iter

    if total_batch > 0:
        val_loss /= float(total_batch)
    else:
        # No validation samples: report an infinite loss so that this epoch
        # is never recorded as the best one.
        val_loss = float('inf')

    # Save results
    save = {'state_dict': net.state_dict()}
    torch.save(save, os.path.join(log_dir, 'final.pth'))
    if best_val_loss > val_loss:
        best_val_loss = val_loss
        torch.save(save, os.path.join(log_dir, 'best.pth'))

    # Print
    print(f'Epoch [{epoch + 1}/{num_epochs}], Val Loss: {val_loss:.4f}, Best Val Loss: {best_val_loss:.4f}')