<a href="https://colab.research.google.com/github/baiyunming/I2DL_FinalProject/blob/main/TrainYolo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Import necessary packages**

In [None]:
import os
import torch
from google.colab import drive
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm.notebook import tqdm

# **Move to the directory containing model.py, loss.py, animal_dataset.py, utils.py**


In [None]:
drive.mount('/content/drive')
%cd drive/MyDrive/ObjectDetection/

Mounted at /content/drive
/content/drive/MyDrive/ObjectDetection


**import necessary classes**

In [None]:
from model import YOLO_Resnet
from loss import YoloLoss
from animal_dataset import AnimalDataset
from utils import non_max_suppression, mean_average_precision, get_list_boxes, calculate_map

# **Dataset and dataloader**

**To create an instance of class AnimalDataset, the path to the csv file generated by Generate_txt_csv.ipynb is needed.**

In [None]:
# Folder holding the CSV index files (generated by Generate_txt_csv.ipynb).
dataset_dir = "/content/drive/MyDrive/ObjectDetection/animal_dataset"

# Training data: shuffled every epoch.
train_dataset = AnimalDataset(dataset_dir + "/train_augmented.csv")
train_batch_size = 32
train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)

# Validation data: kept in a fixed order so that evaluation and the
# visualisation loop below iterate the images reproducibly.
valid_dataset = AnimalDataset(dataset_dir + "/test_augmented.csv")
valid_batch_size = 16
valid_dataloader = DataLoader(valid_dataset, batch_size=valid_batch_size, shuffle=False)

# **Model**

**Option 1: generate new model**

In [None]:
# Pick the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Freshly initialised network, moved onto the selected device.
model = YOLO_Resnet()
model = model.to(device)

**Option 2: load a pretrained model (modify file_path to point to the pretrained model file). You can then skip directly to the "Visualization of result on validation(test) dataset" section.**




In [None]:
# Load pretrained model
device = "cuda" if torch.cuda.is_available() else "cpu"
file_path = "./model1/best_model.pth"
# map_location remaps tensors that were saved from a GPU, so the checkpoint
# also loads on a CPU-only runtime.
checkpoint = torch.load(file_path, map_location=device)
model = YOLO_Resnet()
model.load_state_dict(checkpoint["state_dict"])
model = model.to(device)
# A freshly constructed module is in training mode; switch to evaluation mode
# because this option is meant for running inference directly.
# (train() switches back to training mode if training is resumed.)
model.eval()

# **Optimizer**

In [None]:
# Adam over all model parameters with a fixed learning rate.
learning_rate = 2e-5
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# **Tensorboard for visualization of training process**

In [None]:
# Load the TensorBoard notebook extension
%load_ext tensorboard
# `summary` creates the log writers below; `tf.summary.scalar` logs the
# per-epoch loss and mAP inside the training loop.
from tensorflow import summary
import tensorflow as tf

In [None]:
# Separate log directories so that training and validation curves are
# written by two distinct summary writers.
train_log_dir, val_log_dir = './run/train', './run/validate'

train_summary_writer = summary.create_file_writer(train_log_dir)
val_summary_writer = summary.create_file_writer(val_log_dir)

In [None]:
# Open the TensorBoard UI inline, reading the logs under the `run` directory
# (the parent of the train/validate log directories created above).
%tensorboard --logdir run

# **Define train and validate function**

In [None]:
def train(model, optimizer, image, label):
    """Run a single optimisation step on one batch.

    Args:
        model: network to train; it is switched to training mode.
        optimizer: optimizer whose gradients are reset and which is stepped once.
        image: batch of input images; its first dimension is used as the batch size.
        label: batch of targets passed to ``YoloLoss`` together with the predictions.

    Returns:
        The batch loss divided by the batch size, as a Python float.

    Note:
        Tensors are moved to the notebook-level ``device``.
    """
    model.train()
    loss_fn = YoloLoss()

    # Move the batch onto the same device as the model.
    inputs, targets = image.to(device), label.to(device)

    optimizer.zero_grad()
    # Normalise by the number of samples in this batch
    # (presumably YoloLoss returns a sum over the batch -- confirm in loss.py).
    batch_loss = loss_fn(model(inputs), targets) / image.shape[0]
    batch_loss.backward()
    optimizer.step()

    return batch_loss.item()

In [None]:
def validate(model, image, label):
    """Compute the loss on one validation batch without updating the model.

    Args:
        model: network to evaluate; it is switched to evaluation mode.
        image: batch of input images; its first dimension is used as the batch size.
        label: batch of targets passed to ``YoloLoss`` together with the predictions.

    Returns:
        The batch loss divided by the batch size, as a Python float.

    Note:
        Tensors are moved to the notebook-level ``device``. The function does
        not touch any optimizer: no gradients are computed under
        ``torch.no_grad()``, so there is nothing to reset, and depending on a
        global ``optimizer`` would make the function fail when none is defined.
    """
    model.eval()
    batch_size = image.shape[0]

    criterion = YoloLoss()

    with torch.no_grad():
        inputs = image.to(device)
        target = label.to(device)

        pred = model(inputs)
        # Normalise by the number of samples in this batch, as train() does.
        loss = criterion(pred, target) / batch_size

    return loss.item()

# **Start Training**

In [None]:
max_epoch = 50
max_map = -1
tmp_path = './checkpoint_model.pth'
# Offset added to the epoch index for the TensorBoard step axis
# (presumably continues the curves of an earlier 80-epoch run -- confirm).
step_offset = 80

for epoch in tqdm(range(max_epoch)):
    log_step = epoch + step_offset

    # ---- Training pass: accumulate the mean of the per-batch losses ----
    train_loss = 0
    num_train_batches = len(train_dataloader)
    for image, label, _ in tqdm(train_dataloader):
        train_loss += train(model, optimizer, image, label) / num_train_batches

    # Snapshot weights and optimizer state after every epoch.
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, tmp_path)

    train_map = calculate_map(model, train_dataloader, iou_threshold=0.4, confidence_threshold=0.5)

    with train_summary_writer.as_default():
        tf.summary.scalar('loss', train_loss, step=log_step)
        tf.summary.scalar('map', train_map, step=log_step)

    # ---- Validation pass: same running mean, no parameter updates ----
    valid_loss = 0
    num_valid_batches = len(valid_dataloader)
    for image, label, _ in tqdm(valid_dataloader):
        valid_loss += validate(model, image, label) / num_valid_batches

    valid_map = calculate_map(model, valid_dataloader, iou_threshold=0.4, confidence_threshold=0.5)

    with val_summary_writer.as_default():
        tf.summary.scalar('loss', valid_loss, step=log_step)
        tf.summary.scalar('map', valid_map, step=log_step)

    # Keep a copy of the checkpoint whenever the validation mAP matches or
    # beats the best value seen so far.
    max_map = max(valid_map, max_map)
    if max_map == valid_map:
        filename = './best_model.pth'
        print("=> Saving checkpoint")
        torch.save(checkpoint, filename)

    print(train_loss, valid_loss)
    print(train_map, valid_map)


# **Visualization of result on validation(test) dataset**

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

**define function for plotting the bounding boxes**

In [None]:
def plot_result(image_path, boxes):
    """Plot predicted bounding boxes on top of an image.

    Args:
        image_path: path of the image file to open and display.
        boxes: iterable of predictions, each laid out as
            ``[class_index, score, x_center, y_center, width, height]``.
            Box coordinates are fractions of the image size; they are scaled
            by the image width/height before drawing.

    Each box is drawn as a red rectangle, labelled ``"<class>(<score>)"`` with
    the score rounded to two decimals. Class indices 0, 1 and 2 map to
    buffalo, elephant and rhino; any other index is labelled zebra.
    """
    im = Image.open(image_path)
    width, height = im.size

    # Create figure and axes
    fig, ax = plt.subplots(1)
    # Display the image
    ax.imshow(im)

    # Iterate over the `boxes` argument (not a notebook-level variable), so
    # the function draws exactly what the caller passed in.
    for box in boxes:
        pred_class_num = box[0]
        score = box[1]
        score = round(score, 2)
        box = box[2:]

        if pred_class_num == 0:
            pred_class = "buffalo"
        elif pred_class_num == 1:
            pred_class = "elephant"
        elif pred_class_num == 2:
            pred_class = "rhino"
        else:
            pred_class = "zebra"

        text = pred_class + "(" + str(score) + ")"

        assert len(box) == 4, "Got more values than in x, y, w, h, in a box!"
        # Convert the box centre to the upper-left corner expected by Rectangle.
        upper_left_x = box[0] - box[2] / 2
        upper_left_y = box[1] - box[3] / 2
        rect = patches.Rectangle(
            (upper_left_x * width, upper_left_y * height),
            box[2] * width,
            box[3] * height,
            linewidth=1,
            edgecolor="r",
            facecolor="none",
        )
        # Add the label and the patch to the Axes
        ax.text(upper_left_x * width, upper_left_y * height, text, bbox=dict(facecolor='red', alpha=0.5))
        ax.add_patch(rect)
    plt.show()

**plot result on all validation(test) images**

In [None]:
# Inference only: put the model in evaluation mode before predicting.
model.eval()
with torch.no_grad():
    for x, y, path in valid_dataloader:
        x = x.to(device)
        # One forward pass per batch; the per-image loop below only indexes
        # into the decoded boxes instead of re-running the model for every image.
        batch_bboxes = get_list_boxes(model(x), S=5)
        for idx in range(x.shape[0]):
            # Positional 0.5, 0.5 are presumably the IoU and confidence
            # thresholds -- confirm against utils.non_max_suppression.
            bboxes = non_max_suppression(batch_bboxes[idx], 0.5, 0.5)
            plot_result(path[idx], bboxes)

# **Calculate MAP for different iou_thresholds**

In [None]:
# IoU thresholds at which the mAP is evaluated below.
iou_thresholds = [
    0.3,
    0.35,
    0.4,
    0.45,
    0.5,
    0.55,
    0.6,
]

In [None]:
# mAP on the validation set for each IoU threshold; result_map[i] is the
# value obtained with iou_thresholds[i].
result_map = []
for iou in iou_thresholds:
    result = calculate_map(model, valid_dataloader, iou_threshold=iou, confidence_threshold=0.5)
    print("iou_threshold:"+ str(iou) + " MAP: " + str(result))
    # Store the computed mAP (not the threshold itself).
    result_map.append(result)