In [None]:
import torch
import torchvision.transforms as transforms
import torch.optim as optim
import torchvision.transforms.functional as TF
from tqdm import tqdm
from torch.utils.data import DataLoader
from model import YOLOv1
from dataset import COCODataset, print_sample
from utils import (
    convert_cellboxes,
    plot_image,
    save_checkpoint,
    load_checkpoint,
)
from train import Compose
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
from loss import YoloLoss

# Parameters

In [None]:
hp = {
    # model config
    'S': 4,              # cells per grid side; label/output tensors hold S*S cells
    'B': 2,              # boxes per cell; each box is 5 values
    'dropout': 0.5,      # NOTE(review): not passed to YOLOv1() in this notebook — confirm it is used
    'image_size': 256,   # images are resized to image_size x image_size
    # training config
    'lr': 2e-5,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'batch_size': 8,
    'weight_decay': 0,
    'num_epochs': 10,
    'num_worker': 0,
    # Pinned host memory only helps host->GPU copies; on a CPU-only machine
    # DataLoader(pin_memory=True) merely warns, so tie it to CUDA availability.
    'Pin_memory': torch.cuda.is_available(),
    'load_model': True,  # NOTE(review): no cell in this notebook loads the checkpoint yet
    'load_model_file': 'overfit.pth.tar',  # path the best checkpoint is written to
    'max_training_samples': 100,
    'seed': 42,          # seeds torch's global RNG below
    # loss config
    'lambda_coord': 5,
    'lambda_noobj': 0.5,
    # validation config
    'threshold': 0.4,    # a predicted box is kept when its column 0 exceeds this
}

# Seed torch so DataLoader shuffling and other torch-based randomness are repeatable.
torch.manual_seed(hp['seed'])

# Util Functions

In [None]:
def non_max_suppression(bboxes, threshold):
    """Return the boxes whose first value is strictly greater than ``threshold``.

    NOTE(review): despite the name, no overlap-based suppression happens here;
    this is a pure score-threshold filter on column 0 of each box.

    Args:
        bboxes: a plain ``list`` of box sequences (anything else trips the assert).
        threshold: cut-off compared against ``box[0]``.

    Returns:
        A new list with the surviving boxes, in their original order.
    """
    assert type(bboxes) is list
    return list(filter(lambda candidate: candidate[0] > threshold, bboxes))

In [None]:
def cellboxes_to_boxes(out, S=hp['S']):
    """Decode a batch of cell predictions into per-image Python lists.

    Runs ``convert_cellboxes`` on ``out`` and regroups the result as
    ``(batch, S*S, values_per_cell)`` before converting to nested lists.

    Args:
        out: batched tensor; its first dimension is the batch size.
        S: cells per grid side (defaults to ``hp['S']`` at definition time).

    Returns:
        A list (one entry per image) of ``S*S`` lists of cell values.
    """
    decoded = convert_cellboxes(out, S)
    per_image = decoded.reshape(out.shape[0], S * S, -1)
    return per_image.tolist()

In [None]:
def intersection_over_union(boxes_preds, boxes_labels):
    """Element-wise intersection-over-union of two sets of midpoint-format boxes.

    Columns 0..3 of the last axis are read as (x_center, y_center, width, height);
    any further columns are ignored. The two inputs are combined with torch
    broadcasting, so matching shapes give a pairwise-by-position IoU.

    Args:
        boxes_preds: tensor of predicted boxes, shape (..., >=4).
        boxes_labels: tensor of reference boxes, broadcast-compatible with boxes_preds.

    Returns:
        Tensor of shape (..., 1) with IoU values in [0, 1]; 0 for disjoint boxes.
    """
    # Convert (cx, cy, w, h) to corner coordinates for both sets.
    box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
    box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
    box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
    box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2
    box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
    box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
    box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
    box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2

    # Corners of the overlap rectangle.
    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # .clamp(0) is for the case when they do not intersect
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)

    box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
    box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))

    # The epsilon keeps the ratio finite for degenerate (zero-area) boxes.
    return intersection / (box1_area + box2_area - intersection + 1e-6)

In [None]:
def get_mean_iou(
        loader,
        model,
        threshold,
        split_size=hp['S'],
        sample_batch_size = 10):
    """Average IoU between confident predicted boxes and their matching label slots.

    Each model output and label is reshaped to ``(split_size * split_size * hp['B'], 5)``
    rows. A predicted row counts as a detection when its column 0 exceeds
    ``threshold`` (the same rule as ``non_max_suppression``). Every kept prediction
    is compared with the label row at the same position, and the per-sample mean
    IoU is averaged over all evaluated samples.

    NOTE(review): rows are passed to ``intersection_over_union`` whole, so it reads
    columns 0..3 as (x, y, w, h). If column 0 is the confidence score (as the
    threshold rule suggests), the coordinates probably live in columns 1..4 — confirm
    the box layout and slice ``[..., 1:5]`` if so.

    Args:
        loader: iterable of ``(images, labels)`` batches.
        model: network producing one output per image in the batch.
        threshold: minimum column-0 value for a predicted row to be evaluated.
        split_size: cells per grid side used to reshape outputs and labels.
        sample_batch_size: maximum number of batches to evaluate.

    Returns:
        Mean IoU as a float, or 0.0 when no prediction passed the threshold.
    """
    num_boxes = split_size * split_size * hp['B']
    ious = []
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch_idx, (x, y) in enumerate(loader):
                if batch_idx >= sample_batch_size:
                    break
                x = x.to(hp['device'])
                y = y.to(hp['device'])
                preds = model(x)

                for pred, target in zip(preds, y):
                    pred_boxes = pred.reshape(num_boxes, 5)
                    label_boxes = target.reshape(num_boxes, 5)
                    # Boolean mask instead of a filtered list: each kept prediction stays
                    # paired with its own label row, and everything stays on one device.
                    keep = pred_boxes[:, 0] > threshold
                    if not keep.any():
                        continue
                    iou = intersection_over_union(pred_boxes[keep], label_boxes[keep])
                    ious.append(iou.mean().item())
    finally:
        # Do not leave the model in eval mode for a subsequent training step.
        model.train(was_training)

    return sum(ious) / len(ious) if ious else 0.0

## Load Dataset

In [None]:
# Shared preprocessing for both splits: resize to an image_size x image_size square,
# then convert to a tensor.
transform = Compose(
    [
        transforms.Resize((hp['image_size'], hp['image_size'])),
        transforms.ToTensor(),
    ]
)

### Load Training Data

In [None]:
# Build the training dataset with the shared resize + ToTensor transform, then populate it.
# NOTE(review): load_dataset() is called without a split name here while the validation
# cell passes "validation"; presumably the default selects the training split — confirm.
train_dataset_100 = COCODataset(transform=transform)
train_dataset_100.load_dataset()

In [None]:
# Shuffled training batches; a final partial batch is kept (drop_last=False).
train_loader = DataLoader(
    dataset=train_dataset_100,
    batch_size=hp["batch_size"],
    shuffle=True,
    drop_last=False,
    num_workers=hp["num_worker"],
    pin_memory=hp["Pin_memory"],
)
print(f"Train loader initialized with: batch_size={hp['batch_size']} on device: {hp['device']}")

#### Print Stats

In [None]:
print(f"Training samples: {len(train_dataset_100)}")
# Display the first batch samples with true boxes
for x, y in train_loader:
    # Decode every label in the batch once, instead of once per displayed image.
    real_boxes = cellboxes_to_boxes(y.flatten(start_dim=1), hp['S'])
    # Iterate over the real batch length: with drop_last=False a batch can hold fewer
    # than hp['batch_size'] samples (e.g. a dataset smaller than one batch).
    for idx in range(len(x)):
        print(f"Image shape: {x[idx].shape}")
        print(f"Box shape: {len(real_boxes[idx])}")
        print(f"The first Boxes are: {real_boxes[idx]}")
        plot_image(x[idx].permute(1, 2, 0).to("cpu"), real_boxes[idx])
    break  # only the first batch is shown

### Load Validation Data

In [None]:
# Build the validation dataset with the same resize + ToTensor transform as training.
val_dataset_100 = COCODataset(transform=transform)
val_dataset_100.load_dataset("validation")

In [None]:
# Validation batches come in a fixed order. Shuffling adds nothing to the loss
# estimate and made the "first validation batch" plots change on every run.
val_loader = DataLoader(
    dataset=val_dataset_100,
    batch_size=hp["batch_size"],
    shuffle=False,
    drop_last=False,
    num_workers=hp["num_worker"],
    pin_memory=hp["Pin_memory"],
)

#### Print Stats

In [None]:
num_val_samples = len(val_dataset_100)
print(f"Validation samples: {num_val_samples}")

In [None]:
# Display the first validation batch with its ground-truth boxes.
for x, y in val_loader:
    # Decode every label in the batch once, instead of once per displayed image.
    real_boxes = cellboxes_to_boxes(y.flatten(start_dim=1), hp['S'])
    # Iterate over the real batch length: with drop_last=False a batch can hold fewer
    # than hp['batch_size'] samples (e.g. a dataset smaller than one batch).
    for idx in range(len(x)):
        print(f"Image shape: {x[idx].shape}")
        print(f"Box shape: {len(real_boxes[idx])}")
        print(f"The first Boxes are: {real_boxes[idx]}")
        plot_image(x[idx].permute(1, 2, 0).to("cpu"), real_boxes[idx])
    break  # only the first batch is shown

# Train Model

In [None]:
# Instantiate the detector and move it to hp["device"].
# NOTE(review): the model-config entries in hp (S, B, dropout) are not passed to
# YOLOv1(), so the network uses its own defaults — confirm they match hp.
model = YOLOv1()
model = model.to(hp["device"])

In [None]:
def test():
    """Smoke-test one forward pass on a random batch and print the output shape.

    The input is built from hp (batch size, image size) on hp['device'], so it
    matches where `model` lives. The pass runs in eval mode under no_grad, so
    normalization statistics (if the model has any) are not updated with noise
    and no autograd graph is built. The model's previous mode is restored afterwards.
    """
    x = torch.randn(
        (hp['batch_size'], 3, hp['image_size'], hp['image_size']),
        device=hp['device'],
    )
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            print(model(x).shape)
    finally:
        model.train(was_training)
test()
print(f"Shape should be: [{hp['batch_size']}, {hp['S']*hp['S']*hp['B']*5}]")

In [None]:
# Adam over all model parameters; learning rate and weight decay come from hp.
optimizer = optim.Adam(
    params=model.parameters(),
    lr=hp["lr"],
    weight_decay=hp["weight_decay"],
)

In [None]:
# Detection loss; train_fn unpacks its result as (loss, box_loss, obj_loss, noobj_loss).
# NOTE(review): hp['lambda_coord'], hp['lambda_noobj'], hp['S'] and hp['B'] are not passed
# here, so YoloLoss runs with its own defaults — confirm they match hp.
loss_fn = YoloLoss()

In [None]:
def intermediate_print(out, images=None, max_images=8):
    """Plot decoded (unfiltered) predicted boxes over the images they were computed from.

    Args:
        out: model output for one batch; decoded with cellboxes_to_boxes.
        images: the input batch that produced ``out``. Defaults to the notebook-global
            ``x`` for backward compatibility; pass it explicitly to avoid hidden state.
        max_images: maximum number of images to plot (previously hard-coded to 8).
    """
    if images is None:
        images = x  # legacy fallback to the global batch variable
    # Decode once for the whole batch; detach so a training-time output can be used.
    bboxes = cellboxes_to_boxes(out.detach(), S=hp["S"])
    # Bound by the actual batch length so a short final batch cannot raise IndexError.
    for idx in range(min(max_images, len(bboxes), len(images))):
        plot_image(images[idx].permute(1, 2, 0).to("cpu"), bboxes[idx])

In [None]:
def _mean(values):
    """Arithmetic mean of a list of floats; NaN for an empty list (e.g. an empty loader)."""
    return sum(values) / len(values) if values else float("nan")


def _print_loss_summary(split, losses):
    """Print the mean total/Box/Obj/Noobj loss for one split ("train" or "validation")."""
    total, box, obj, noobj = (_mean(bucket) for bucket in losses)
    print(f"Mean {split} loss was {total}")
    print(f"Mean {split} Box loss was {box}")
    print(f"Mean {split} Obj loss was {obj}")
    print(f"Mean {split} Noobj loss was {noobj}")


def train_fn(train_loader,val_loader, model, optimizer, loss_fn):
    """Run one training epoch followed by one validation pass.

    Args:
        train_loader: iterable of ``(images, labels)`` batches used for optimisation.
        val_loader: iterable of ``(images, labels)`` batches used for evaluation only.
        model: network being trained.
        optimizer: optimizer stepping ``model``'s parameters.
        loss_fn: callable returning ``(loss, box_loss, obj_loss, noobj_loss)`` tensors.

    Returns:
        Mean validation loss over the validation batches (NaN if val_loader is empty).
    """
    #------------------- Training -------------------#
    # Make sure training runs in train mode even if a previous cell left the model in eval.
    model.train()
    train_losses = ([], [], [], [])  # total, box, obj, noobj
    loop = tqdm(train_loader, leave=True)
    for x, y in loop:
        x, y = x.to(hp["device"]), y.to(hp["device"])
        out = model(x)
        loss, box_loss, obj_loss, noobj_loss = loss_fn(out, y)
        for bucket, value in zip(train_losses, (loss, box_loss, obj_loss, noobj_loss)):
            bucket.append(value.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # update progress bar
        loop.set_postfix(loss=train_losses[0][-1])
    _print_loss_summary("train", train_losses)

    #------------------- Validation -------------------#
    model.eval()
    val_losses = ([], [], [], [])  # total, box, obj, noobj
    try:
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(hp["device"]), y.to(hp["device"])
                out = model(x)
                components = loss_fn(out, y)
                for bucket, value in zip(val_losses, components):
                    bucket.append(value.item())
    finally:
        model.train()
    _print_loss_summary("validation", val_losses)
    return _mean(val_losses[0])

In [None]:
# Train for hp["num_epochs"] epochs and keep the checkpoint with the LOWEST validation loss.
# (IoU-based evaluation is available via get_mean_iou but is not run here.)
best_loss = float("inf")
for epoch in range(hp["num_epochs"]):
    print(f"Epoch {epoch + 1}/{hp['num_epochs']}")
    #------------------- Training -------------------#
    val_loss = train_fn(train_loader, val_loader, model, optimizer, loss_fn)
    #------------------- Checkpointing -------------------#
    # Lower is better: the previous `>=` tracked the worst epoch, and with a start
    # value of 1000 it effectively never saved anything.
    if val_loss < best_loss:
        best_loss = val_loss
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        print("=> Saving checkpoint")
        torch.save(checkpoint, hp["load_model_file"])

# Validation

In [None]:
def validation_print(out, images=None, max_images=8):
    """Plot thresholded predicted boxes over the images they were computed from.

    Boxes are decoded with cellboxes_to_boxes and filtered with
    non_max_suppression(…, hp['threshold']) before plotting.

    Args:
        out: model output for one batch.
        images: the input batch that produced ``out``. Defaults to the notebook-global
            ``x`` for backward compatibility; pass it explicitly to avoid hidden state.
        max_images: maximum number of images to plot (previously hard-coded to 8).
    """
    if images is None:
        images = x  # legacy fallback to the global batch variable
    bboxes = cellboxes_to_boxes(out, S=hp["S"])
    # Bound by the actual batch length so a short batch cannot raise IndexError.
    for idx in range(min(max_images, len(bboxes), len(images))):
        best_boxes = non_max_suppression(bboxes[idx], hp['threshold'])
        plot_image(images[idx].permute(1, 2, 0).to("cpu"), best_boxes)

In [None]:
# Qualitative check: plot thresholded predictions for the first validation batch.
model.eval()
try:
    with torch.no_grad():
        for (x, y) in val_loader:
            # The model lives on hp["device"]; move only its input there and keep the
            # global `x` on the CPU for plotting.
            out = model(x.to(hp["device"]))
            validation_print(out)
            break
finally:
    # Restore train mode even if plotting fails.
    model.train()