# Aleket Faster R-CNN training notebook

In [None]:
%pip install pillow
%pip install numpy<2.0
%pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124
%pip install matplotlib
%pip install pycocotools
%pip install gdown

from IPython.display import clear_output

clear_output(wait=False)

print("ALL DEPENDENCIES INSTALLED")

In [3]:
# IMPORTS

# Standard Library
import os
import random
import shutil
from typing import Optional

# Third-Party Libraries
import gdown
import numpy as np

# Torch
import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset

# Torchvision
import torchvision.models.detection as tv_detection
import torchvision.transforms.v2 as v2

# Utils
from utils import AleketDataset, TrainingLogger, StatsTracker
from training_and_evaluation import evaluate, train_one_epoch

In [None]:
# Trainer

# Device Selection: use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

# Dataset split
def _collate_fn(batch):
    """Collates data samples into batches for the dataloader.

    Transposes a list of per-sample tuples into a tuple of per-field tuples
    instead of stacking them into tensors.

    Defined at module level (not nested inside ``split_dataset``) so that it can
    be pickled when DataLoader worker processes are started with the ``spawn``
    method; a nested function cannot be pickled.
    """
    return tuple(zip(*batch))


def split_dataset(
    dataset: AleketDataset,
    train_indicies: list[int],
    val_indicies: list[int],
    batch_size: int,
    num_workers: int,
) -> tuple[DataLoader, DataLoader]:
    """Divides the dataset into training and validation sets and creates DataLoaders.
    Args:
        dataset: The AleketDataset to divide.
        train_indicies: Dataset indices that make up the training subset.
        val_indicies: Dataset indices that make up the validation subset.
        batch_size: The batch size for the DataLoaders.
        num_workers: The number of worker processes for data loading.
    Returns:
        A tuple containing the training DataLoader and the validation DataLoader.
    """
    # Both subsets are views over the same underlying dataset object
    train_dataset = Subset(dataset, train_indicies)
    val_dataset = Subset(dataset, val_indicies)

    # Only the training loader shuffles; validation keeps the given index order
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=_collate_fn,
        num_workers=num_workers,
    )

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        collate_fn=_collate_fn,
        num_workers=num_workers,
    )

    return train_dataloader, val_dataloader

def get_model(num_classes) -> tv_detection.FasterRCNN:
    """Builds a Faster R-CNN (ResNet-50 FPN) with a new box-prediction head.

    The detector is created with torchvision's "DEFAULT" weights and its box
    predictor is swapped for one sized to ``num_classes``.

    Args:
        num_classes: The number of classes in the dataset.

    Returns:
        The Faster R-CNN model on the specified device.
    """
    detector = tv_detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")

    # The replacement head must accept the same feature size as the stock head
    head_input_size = detector.roi_heads.box_predictor.cls_score.in_features
    new_head = tv_detection.faster_rcnn.FastRCNNPredictor(
        head_input_size, num_classes
    )
    detector.roi_heads.box_predictor = new_head

    detector = detector.to(DEVICE)
    return detector

class Trainer:
    """
    A class to manage the training and evaluation of a FasterRCNN model.

    A checkpoint is written after every epoch to ``<result_path>/run/last.pth``
    and, when the stats tracker reports a new best, also to
    ``<result_path>/run/best.pth``. ``Trainer.load_state`` rebuilds a trainer
    from either file.

    Args:
        dataset (AleketDataset): The dataset to use for training and evaluation.
        train_indicies (list[int]): List of indices for the training set.
        val_inidices (list[int]): List of indices for the validation set
            (parameter spelling kept as-is for existing keyword callers).
        batch_size (int): The batch size to use for training and evaluation.
        num_workers (int): The number of worker processes to use for data loading.
        model (FasterRCNN): The FasterRCNN model to train.
        optimizer (optim.Optimizer): The optimizer to use for training.
        lr_scheduler (optim.lr_scheduler.LRScheduler): The learning rate scheduler to use for training.
        result_path (str): The path to save training results and checkpoints.
        total_epochs (int): The total number of epochs to train for.
        stats_tracker (StatsTracker): Accumulates training losses and validation metrics.
        last_epoch (int, optional): The last epoch completed. Defaults to 0.
    """

    def __init__(self,
                 dataset: AleketDataset,
                 train_indicies: list[int],
                 val_inidices: list[int],
                 batch_size: int,
                 num_workers: int,
                 model: tv_detection.FasterRCNN,
                 optimizer: optim.Optimizer,
                 lr_scheduler: optim.lr_scheduler.LRScheduler,
                 result_path: str,
                 total_epochs: int,
                 stats_tracker: StatsTracker,
                 last_epoch: int = 0
                 ) -> None:
        # Kept so that save_state() can persist everything needed to rebuild
        # the dataloaders on resume.
        self.dataset = dataset
        self.train_indicies = train_indicies
        self.val_indicies = val_inidices
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.train_dataloader, self.val_dataloader = split_dataset(
            dataset, train_indicies, val_inidices, batch_size, num_workers
        )

        # Output locations: checkpoints under <result_path>/run, CSV log next to it
        self.result_path = result_path
        os.makedirs(os.path.join(result_path, "run"), exist_ok=True)
        self.last = os.path.join(result_path, "run", "last.pth")
        self.best = os.path.join(result_path, "run", "best.pth")
        self.stats_graph = os.path.join(result_path, "stats_graph")
        log_file = os.path.join(result_path, "stats.csv")
        self.logger = TrainingLogger(log_file)

        self.stats_tracker = stats_tracker

        self.model = model
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        self.last_epoch = last_epoch
        self.total_epochs = total_epochs

    def train(self):
        """
        Trains the FasterRCNN model for the specified number of epochs.

        Continues from ``self.last_epoch`` (0 for a fresh run) up to
        ``self.total_epochs``; each epoch trains, evaluates, logs, steps the
        learning-rate scheduler and writes a checkpoint.
        """
        while self.last_epoch < self.total_epochs:
            epoch = self.last_epoch + 1
            self.logger.log_epoch_start(
                epoch, self.total_epochs, self.lr_scheduler.get_last_lr()[0]
            )

            # The dataset's `train` flag is switched around each phase
            # (presumably it gates the augmentations - confirm in AleketDataset).
            self.dataset.train = True
            losses = train_one_epoch(
                self.model, self.optimizer, self.train_dataloader, DEVICE
            )

            self.dataset.train = False
            eval_stats = evaluate(
                self.model, self.val_dataloader, DEVICE
            )

            self.stats_tracker.update_train_loss(losses)
            is_best = self.stats_tracker.update_val_metrics(eval_stats)

            self.logger.log_epoch_end(epoch, losses, eval_stats)

            # Advance the scheduler and the epoch counter BEFORE checkpointing.
            # Otherwise the saved state pairs post-epoch weights and stats with
            # the pre-epoch counter/LR, and a resumed run repeats this epoch.
            self.lr_scheduler.step()
            self.last_epoch = epoch

            self.save_state(is_best)

            # NOTE(review): self.stats_graph (<result_path>/stats_graph) is never
            # used; this literal may have been meant to be that path - confirm
            # against StatsTracker.plot_stats before changing it.
            self.stats_tracker.plot_stats("training_stats")

    def save_state(self, is_best: bool):
        """
        Saves the current training state to a file.

        Always writes ``self.last``; additionally writes ``self.best`` when
        ``is_best`` is true. The dataset and stats tracker are pickled as whole
        objects, so loading requires their classes to be importable.

        Args:
            is_best (bool): Whether the current state is the best so far.
        """
        save_state = {
            "dataset": self.dataset,
            "train_indicies": self.train_indicies,
            "val_indicies": self.val_indicies,
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "lr_scheduler_state_dict": self.lr_scheduler.state_dict(),
            "stats_tracker": self.stats_tracker,
            "last_epoch": self.last_epoch,
            "total_epochs": self.total_epochs,
            "result_path": self.result_path,
            "last": self.last,
            "best": self.best,
        }
        torch.save(save_state, self.last)
        if is_best:
            torch.save(save_state, self.best)

    @staticmethod
    def load_state(file: str) -> "Trainer":
        """
        Loads the training state from a saved file.

        Args:
            file: The path to the saved state file (as written by ``save_state``).

        Returns:
            A Trainer restored to the saved epoch, ready for ``train()``.
        """
        # NOTE(security): weights_only=False unpickles arbitrary objects (needed
        # for the stored dataset and stats tracker). Only load checkpoints that
        # this notebook produced itself.
        # map_location lets a checkpoint written on a GPU machine load on CPU.
        save_state = torch.load(file, map_location=DEVICE, weights_only=False)

        train_indicies = save_state["train_indicies"]
        val_indicies = save_state["val_indicies"]
        batch_size = save_state["batch_size"]
        num_workers = save_state["num_workers"]
        dataset = save_state["dataset"]

        # The class count is not stored in the checkpoint; 3 matches main().
        model = get_model(3)
        model.load_state_dict(save_state["model_state_dict"])
        params = [p for p in model.parameters() if p.requires_grad]

        # Placeholder hyperparameters: the real ones come from the checkpoint.
        optimizer = optim.SGD(params, lr=1e-3)
        # Build the scheduler BEFORE restoring the optimizer state. Constructing
        # LinearLR performs an initial step that multiplies the optimizer's
        # current lr by its default start factor; done after the restore, that
        # would permanently scale down the checkpointed learning rate.
        lr_scheduler = optim.lr_scheduler.LinearLR(optimizer)
        optimizer.load_state_dict(save_state["optimizer_state_dict"])
        lr_scheduler.load_state_dict(save_state["lr_scheduler_state_dict"])

        stats_tracker = save_state["stats_tracker"]
        last_epoch = save_state["last_epoch"]
        total_epochs = save_state["total_epochs"]
        result_path = save_state["result_path"]

        print(f"Loaded state from {file}")
        return Trainer(dataset,
                       train_indicies,
                       val_indicies,
                       batch_size,
                       num_workers,
                       model,
                       optimizer,
                       lr_scheduler,
                       result_path,
                       total_epochs,
                       stats_tracker,
                       last_epoch)
        

In [None]:
# Fixed seed shared by every random number generator used in this notebook,
# so that repeated runs draw the same random values.
SEED = 1
torch.manual_seed(SEED)  # PyTorch
np.random.seed(SEED)  # NumPy (legacy global generator)
random.seed(SEED)  # Python standard library

# Downloads and extracts the dataset if it doesn't exist locally.
def download_dataset(save_dir: str) -> str:
    """Downloads and extracts the dataset if it doesn't exist locally.

    When ``save_dir`` already exists nothing is downloaded. Otherwise the
    archive is fetched from Google Drive into a temporary zip in the current
    working directory, unpacked into ``save_dir`` and the zip is removed again
    (also when the download or extraction fails).

    Args:
        save_dir: The directory to save the dataset.

    Returns:
        The path to the saved dataset directory.

    Raises:
        ValueError: If the dataset is missing and no Google Drive ID is configured.
        RuntimeError: If gdown reports a failed download.
    """
    patched_dataset_gdrive_id = ""  # FIXME: Replace with actual Google Drive ID for the patched dataset
    if not os.path.exists(save_dir):
        # Fail with a clear message instead of passing an empty ID to gdown
        if not patched_dataset_gdrive_id:
            raise ValueError(
                f"Dataset directory '{save_dir}' does not exist and no Google "
                "Drive ID is configured for downloading it."
            )

        archive_path = "_temp_.zip"
        try:
            downloaded = gdown.download(id=patched_dataset_gdrive_id, output=archive_path)
            # gdown signals some failures by returning None instead of raising
            if downloaded is None:
                raise RuntimeError(
                    f"Failed to download dataset archive (id={patched_dataset_gdrive_id!r})."
                )
            shutil.unpack_archive(archive_path, save_dir)
        finally:
            # Never leave the temporary archive behind, even on failure
            if os.path.exists(archive_path):
                os.remove(archive_path)
    print(f"Dataset loaded from {save_dir}")
    return save_dir


# Path variables
# RESUME: when True, main() restores a Trainer from a saved checkpoint instead
# of building a fresh dataset split, model and optimizer.
RESUME = False
# RESULT_PATH: handed to Trainer as result_path (checkpoints and stats go there).
RESULT_PATH = "result"
# NOTE: this call runs when the cell is executed and may download the dataset.
DATASET_ROOT = download_dataset("dataset_patched")

# Dataset split
# TEST_FRACTION: share of the selected samples held out for validation in main().
TEST_FRACTION = 0.2
# DATASET_FRACTION: share of the whole dataset that main() samples for this run.
DATASET_FRACTION = 1
# DATALOADER_WORKERS: passed to both DataLoaders as num_workers.
DATALOADER_WORKERS = 16

# Training Hyperparameters
BATCH_SIZE = 16
EPOCHS = 500
# LR, MOMENTUM and WEIGHT_DECAY are passed to optim.SGD in main().
LR = 0.005
MOMENTUM = 0.997
WEIGHT_DECAY = 0.0001
# NOTE(review): main() passes this as total_iters to LinearLR with
# start_factor=1 and end_factor=0.1, i.e. the learning rate is scaled DOWN over
# these epochs rather than warmed up - confirm the name matches the intent.
WARMUP_EPOCHS = 15
# Data Augmentation Transforms
# Given to AleketDataset as `transforms` in main().
TRANSFORMS = v2.Compose(
    [
        v2.RandomHorizontalFlip(p=0.5),
        v2.RandomVerticalFlip(p=0.5),
        v2.RandomPerspective(distortion_scale=0.2, p=0.2),
        v2.RandomRotation(degrees=(-10, 10), expand=True),
    ]
)

# Module-level handle to the active Trainer; set by main() so that it can be
# inspected from later cells (e.g. after training is interrupted).
trainer = None


def main():
    """Main training and evaluation loop.

    With ``RESUME`` false, samples ``DATASET_FRACTION`` of the dataset at
    random, holds out ``TEST_FRACTION`` of that sample for validation, builds
    the model, SGD optimizer and LinearLR scheduler, and trains from scratch.
    With ``RESUME`` true, restores the trainer from the ``best.pth`` checkpoint
    under ``RESULT_PATH`` and continues training from there.
    """
    # Without this declaration the assignments below would create a local
    # variable and the module-level `trainer` would stay None.
    global trainer

    if not RESUME:
        dataset = AleketDataset(DATASET_ROOT, transforms=TRANSFORMS)
        # Calculate the number of samples to use based on the dataset_fraction
        num_samples = int(len(dataset) * DATASET_FRACTION)
        indices = torch.randperm(len(dataset)).tolist()[:num_samples]
        test_size = int(num_samples * TEST_FRACTION)
        # Split at an explicit position: slicing with -test_size breaks when
        # test_size is 0 (indices[:-0] is empty and indices[-0:] is everything).
        split_at = max(len(indices) - test_size, 0)
        train_indicies, val_indicies = indices[:split_at], indices[split_at:]

        model = get_model(3)
        params = [
            p for p in model.parameters() if p.requires_grad
        ]  # Optimize only trainable parameters
        optimizer = optim.SGD(params, lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
        lr_scheduler = optim.lr_scheduler.LinearLR(
            optimizer, start_factor=1, end_factor=0.1, total_iters=WARMUP_EPOCHS
        )

        stats_tracker = StatsTracker()

        trainer = Trainer(
            dataset,
            train_indicies,
            val_indicies,
            BATCH_SIZE,
            DATALOADER_WORKERS,
            model,
            optimizer,
            lr_scheduler,
            RESULT_PATH,
            EPOCHS,
            stats_tracker,
        )
    else:
        # Same location Trainer writes its best checkpoint to, derived from
        # RESULT_PATH instead of a hard-coded string.
        # NOTE(review): this resumes from the best checkpoint, not the most
        # recent one (last.pth) - confirm that is the intended resume point.
        trainer = Trainer.load_state(os.path.join(RESULT_PATH, "run", "best.pth"))
    trainer.train()


if __name__ == "__main__":
    main()