# Fine Tuning Tiny Yolo

The Tiny YOLO network is fine-tuned to detect only people.

## Prepare Workspace

### Google Drive Flag

In [8]:
# Switch for the Google Colab / Google Drive workflow: the setup cells below
# (mounting Drive, copying files, extending sys.path) only run when this is True,
# and the weights directory is chosen based on it.
GOOGLE_DRIVE = True

### Mount Google Drive

In [None]:
if GOOGLE_DRIVE:
    # google.colab is imported only when the flag is set.
    from google.colab import drive
    # Make the Drive contents available under /content/drive.
    drive.mount('/content/drive')

### Create the directories needed and place uploaded files inside them

In [None]:
if GOOGLE_DRIVE:
    # Install torchinfo, torchvision and pillow into the runtime.
    !pip install torchinfo
    !pip install torchvision pillow

    # Local directory that is later used as the dataset root.
    !mkdir /content/data

    # Copy the person index file, the utils package and the model definition from Drive.
    !cp /content/drive/MyDrive/eml_challenge/data/person_indices.json /content/data
    !cp -r /content/drive/MyDrive/eml_challenge/utils /content
    !cp /content/drive/MyDrive/eml_challenge/tinyyolov2.py /content

### Append directory paths to system path

In [11]:
if GOOGLE_DRIVE:
    import sys

    # Make the copied project files and the Drive weights folder importable.
    sys.path.extend([
        '/content',
        '/content/data',
        '/content/utils',
        '/content/drive/MyDrive/eml_challenge/weights',
    ])

## Executing Workspace

### Importing essential libraries

In [12]:
import torch
import torchinfo
import torch.nn as nn
from torchvision.transforms import v2
import numpy as np
import tqdm
from torch.optim.lr_scheduler import ReduceLROnPlateau # Import ReduceLROnPlateau to reduce learning rate of optimizer after Plateau

from tinyyolov2 import TinyYoloV2
from utils.loss import YoloLoss
from utils.dataloaderv2 import VOCDataset
from utils.ap import precision_recall_levels, ap, display_roc
from utils.yolo import nms, filter_boxes
from utils.viz import display_result

### Define data augmentation pipeline

In [13]:
# Training-time augmentation: each transform is applied with probability 0.5.
pipeline = v2.Compose(
    transforms=[
        v2.RandomPhotometricDistort(p=0.5),
        v2.RandomHorizontalFlip(p=0.5),
    ]
)

### Define train and test data loaders

In [None]:
# Training split (augmented); the original notebook notes it contains 2142 pictures.
train_dataset = VOCDataset(
    root="/content/data",
    year="2012",
    image_set='train',
    transform=pipeline,
    only_person=True,
)
# Validation split (no augmentation); the original notebook notes it contains 2232 pictures.
test_dataset = VOCDataset(
    root="/content/data",
    year="2012",
    image_set='val',
    transform=None,
    only_person=True,
)

# Training batches are shuffled; the test loader yields single images in a fixed
# order, which validate() and test() rely on to split the set by index.
data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

### Define Early Stopping class

In [15]:
class EarlyStopping:
    """Stop training once the monitored average precision (AP) stops improving.

    The instance is called once per epoch with the current AP and the model.
    Every improvement is written to ``path``; when ``patience`` consecutive
    calls bring no improvement, ``early_stop`` is set and the best checkpoint
    is exported to ``best_model_path``.
    """

    def __init__(self, patience=7, verbose=False, delta=0,
                 path='/content/drive/MyDrive/eml_challenge/weights/checkpoint.pt',
                 best_model_path='/content/drive/MyDrive/eml_challenge/weights/voc_fine_tuned.pt'):
        """
        Args:
            patience (int): How long to wait after last improvement.
            verbose (bool): If True, prints a message for each validation metric improvement.
            delta (float): Minimum change in the monitored metric to qualify as an improvement.
            path (str): Path to save the best model checkpoint.
            best_model_path (str): Path the best checkpoint is exported to once
                early stopping triggers.
        """
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path
        self.best_model_path = best_model_path
        self.counter = 0  # Calls since the last improvement
        self.best_score = None
        self.early_stop = False
        # AP of the most recently saved checkpoint (name kept for compatibility;
        # it holds the best AP seen so far, not a minimum).
        self.avg_precision_min = 0

    def __call__(self, avg_precision, model):
        """Record the AP of the current epoch and update the stopping state.

        Args:
            avg_precision (float): Validation average precision of this epoch.
            model: Model whose ``state_dict()`` is checkpointed on improvement.
        """
        score = avg_precision  # Positive because we maximize AP

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(avg_precision, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
                # The model passed in here has gone `patience` epochs without
                # improving, so export the best checkpoint, not the current weights.
                self._export_best_checkpoint()
        else:
            self.best_score = score
            self.save_checkpoint(avg_precision, model)
            self.counter = 0

    def save_checkpoint(self, avg_precision, model):
        """Save model when average precision increases."""
        if self.verbose:
            print(f"Average Precision increased ({self.avg_precision_min:.6f} --> {avg_precision:.6f}). Saving model...")
        torch.save(model.state_dict(), self.path)
        self.avg_precision_min = avg_precision

    def _export_best_checkpoint(self):
        """Copy the checkpoint at ``path`` to ``best_model_path``."""
        import shutil

        try:
            shutil.copyfile(self.path, self.best_model_path)
        except shutil.SameFileError:
            # Both paths point to the same file: the best weights are already there.
            pass


### Define train, validate and test functions

In [16]:
def train(net: nn.Module, data_loader: torch.utils.data.DataLoader, optimizer, criterion, device):
    """
    Description:
    Runs one training epoch over "data_loader", updating the network in place.

    Args:
    net: the network to train
    data_loader: the data loader for the training set
    optimizer: the optimizer that applies the parameter updates
    criterion: the loss function; its first return value is back-propagated
    device: the device the network and the batches are moved to
    """

    # Training mode, parameters on the target device
    net.train()
    net.to(device)

    progress = tqdm.tqdm(enumerate(data_loader), total=len(data_loader))
    for _step, (images, labels) in progress:
        # Bring the batch onto the same device as the network
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        # The YOLO head is part of the loss during training, hence yolo=False
        predictions = net(images, yolo=False)
        batch_loss, _ = criterion(predictions, labels)
        batch_loss.backward()
        optimizer.step()

def validate(net: nn.Module, data_loader: torch.utils.data.DataLoader, device, num_validation_samples):
    """
    Description:
    This function uses the first "num_validation_samples" images yielded by "data_loader"
    to validate the network and returns the resulting average precision.
    Keep in mind that this function only works for batch_size=1 and shuffle=False.
    The thresholds are read from the module-level CONFIDENCE_THRESHOLD and NMS_THRESHOLD.

    Args:
    net: the network to validate
    data_loader: the data loader to draw the validation samples from
    device: the device to use for inference
    num_validation_samples: the number of samples to use for validation (must be positive)

    Returns:
    The average precision computed over the evaluated samples.

    Raises:
    ValueError: if num_validation_samples is not positive.
    """
    from itertools import islice

    if num_validation_samples <= 0:
        raise ValueError("num_validation_samples must be a positive integer")

    eval_precision = []
    eval_recall = []

    net.eval()
    # Move weights to device
    net.to(device)

    # The loader may hold fewer samples than requested
    num_batches = min(num_validation_samples, len(data_loader))

    with torch.no_grad():
        # islice stops after exactly num_validation_samples items, so the sample at
        # index num_validation_samples is left for the test split.
        batches = islice(data_loader, num_validation_samples)
        for images, target in tqdm.tqdm(batches, total=num_batches):
            images = images.to(device)
            target = target.to(device)
            output = net(images, yolo=True)
            # The right threshold values can be adjusted for the target application
            output = filter_boxes(output, CONFIDENCE_THRESHOLD)
            output = nms(output, NMS_THRESHOLD)
            # Calculate precision and recall for each sample (batch_size=1, hence index 0)
            precision, recall = precision_recall_levels(target[0], output[0])
            eval_precision.append(precision)
            eval_recall.append(recall)

    # Calculate average precision with collected samples
    average_precision = ap(eval_precision, eval_recall)
    # Plot ROC
    display_roc(eval_precision, eval_recall)

    return average_precision


def test(net: nn.Module, data_loader: torch.utils.data.DataLoader, device, num_validation_samples, best_model_path):
    """
    Description:
    This function loads the weights stored at "best_model_path", skips the first
    "num_validation_samples" images (the ones used for validation) and evaluates the
    network on the remaining images of the data loader. Keep in mind that this function
    only works for batch_size=1 and shuffle=False.
    The thresholds are read from the module-level CONFIDENCE_THRESHOLD and NMS_THRESHOLD.

    Args:
    net: the network to test
    data_loader: the data loader for the test set
    device: the device to use for inference
    num_validation_samples: the number of leading images to skip
    best_model_path: path of the state dict to load into the network before testing

    Returns:
    The average precision computed over the evaluated samples.
    """

    test_precision = []
    test_recall = []

    # Load weights and move them to device. map_location lets a checkpoint that was
    # saved from a GPU be loaded on a CPU-only runtime as well.
    sd = torch.load(best_model_path, weights_only=True, map_location=device)
    net.load_state_dict(sd)
    net.to(device)
    net.eval()

    with torch.no_grad():
        for idx, (images, target) in tqdm.tqdm(enumerate(data_loader), total=len(data_loader)):
            # Samples before this index are treated as the validation split
            if idx < num_validation_samples:
                continue
            images = images.to(device)
            target = target.to(device)
            output = net(images, yolo=True)
            # The right threshold values can be adjusted for the target application
            output = filter_boxes(output, CONFIDENCE_THRESHOLD)
            output = nms(output, NMS_THRESHOLD)
            # Calculate precision and recall for each sample (batch_size=1, hence index 0)
            precision, recall = precision_recall_levels(target[0], output[0])
            test_precision.append(precision)
            test_recall.append(recall)

    # Calculate average precision with collected samples
    average_precision = ap(test_precision, test_recall)
    # Plot ROC
    display_roc(test_precision, test_recall)

    return average_precision


### Define fine-tuning function

In [17]:
def fine_tune(net: nn.Module, sd,
              data_loader: torch.utils.data.DataLoader, test_loader: torch.utils.data.DataLoader,
              num_eval_samples: int=0):
    """
    Description:
    Fine-tunes the network starting from the pretrained state dict "sd", validating
    after every epoch, and finally tests the best weights found. The learning rate is
    halved when the validation AP plateaus and training stops early when it no longer
    improves. Reads the module-level NUM_EPOCHS and WEIGHTS_PATH; the best weights are
    written to WEIGHTS_PATH + "voc_fine_tuned.pt".

    Args:
    net: the network to fine-tune
    sd: the pretrained state dict; entries whose key contains '9' are not loaded
    data_loader: the data loader for the training set
    test_loader: the data loader that is split into validation and test samples
    num_eval_samples: the number of leading test_loader samples used for validation;
                      the remaining samples are used for the final test
    """
    import shutil

    if torch.cuda.is_available():
        torch_device = torch.device("cuda")
        print("Using GPU")
    else:
        torch_device = torch.device("cpu")
        print("Using CPU")

    checkpoint_path = WEIGHTS_PATH + "checkpoint.pt"
    best_model_path = WEIGHTS_PATH + "voc_fine_tuned.pt"

    eval_AP = []

    # We load all parameters from the pretrained dict except for the last layer
    net.load_state_dict({k: v for k, v in sd.items() if '9' not in k}, strict=False)

    # Freeze every parameter whose name contains one of the digits 1-7.
    # NOTE(review): parameters whose name only contains '8' or '9' stay trainable,
    # so more than the last layer may be trained - confirm this is intended.
    for key, param in net.named_parameters():
        if any(x in key for x in ['1', '2', '3', '4', '5', '6', '7']):
            param.requires_grad = False

    # Definition of the loss
    criterion = YoloLoss(anchors=net.anchors)

    # Definition of the optimizer (only over the parameters left trainable)
    learning_rate = 0.001
    optimizer = torch.optim.Adam(filter(lambda x: x.requires_grad, net.parameters()), lr=learning_rate)

    # Define the ReduceLROnPlateau scheduler (mode='max' because AP is maximized)
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)

    # Initialize EarlyStopping
    early_stopping = EarlyStopping(patience=7, verbose=True, path=checkpoint_path, best_model_path=best_model_path)

    for epoch in range(NUM_EPOCHS):
        print(f"Epoch: {epoch}")

        # Train the network; epoch 0 is skipped so the first validation measures
        # the network before any fine-tuning step
        if epoch != 0:
            train(net, data_loader, optimizer, criterion, torch_device)

        # Validate the network (validate() names this parameter num_validation_samples)
        average_precision = validate(net, test_loader, torch_device, num_validation_samples=num_eval_samples)
        eval_AP.append(average_precision)
        print('average precision', eval_AP)
        # Adjust learning rate in case of a Plateau of AP
        scheduler.step(average_precision)
        print(f"learning rate: {scheduler.get_last_lr()}")
        # Stop training in case there is no further improvement of AP
        early_stopping(average_precision, net)
        if early_stopping.early_stop:
            print("Early stopping triggered. Stopping training.")
            break

    if not early_stopping.early_stop:
        print("No early stopping triggered. Training completed.")

    # Export the best weights. The checkpoint holds the best-AP epoch, whereas `net`
    # holds the weights of the last epoch, which may be worse.
    if early_stopping.best_score is None:
        # No epoch ran (NUM_EPOCHS == 0), so there is no checkpoint to export
        torch.save(net.state_dict(), best_model_path)
    else:
        shutil.copyfile(checkpoint_path, best_model_path)

    # Test the network (test() names this parameter num_validation_samples)
    final_average_precision = test(net, test_loader, torch_device, num_validation_samples=num_eval_samples, best_model_path=best_model_path)
    print('final average precision: ', final_average_precision)

    torch.cuda.empty_cache()

## Testing Workspace

In [None]:
# Training configuration
NUM_EPOCHS = 50
NUM_VALIDATION_SAMPLES = 350
# Detection thresholds used during validation and testing
CONFIDENCE_THRESHOLD = 0.0
NMS_THRESHOLD = 0.5

# Weights live on Google Drive when running with the Drive workflow, otherwise locally
WEIGHTS_PATH = "/content/drive/MyDrive/eml_challenge/weights/" if GOOGLE_DRIVE else "./"

# Tiny YOLO network with a single output class
net = TinyYoloV2(num_classes=1)

# Pretrained weights used as the starting point for fine-tuning
sd = torch.load(WEIGHTS_PATH + "voc_pretrained.pt", weights_only=True)

# Fine-tune the network
fine_tune(net, sd, data_loader, test_loader, num_eval_samples=NUM_VALIDATION_SAMPLES)