# Aleket Faster R-CNN training notebook

In [None]:
%pip install pillow
%pip install numpy<2.0
%pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124
%pip install matplotlib
%pip install pycocotools
%pip install gdown

from IPython.display import clear_output

clear_output(wait=False)

print("ALL DEPENDENCIES INSTALLED")

In [2]:
# IMPORTS

# Standard Library
import os
import random
import shutil

# Third-Party Libraries
import gdown
import numpy as np

# Torch
import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset

# Torchvision
import torchvision.models.detection as tv_detection
import torchvision.transforms.v2 as v2

# Utils
from utils import AleketDataset, TrainingLogger, StatsTracker
from training_and_evaluation import evaluate, train_one_epoch

In [None]:
# PATH VARIABLES
def download_dataset(save_dir: str, gdrive_id: str = "") -> str:
    """Download and extract the dataset archive unless ``save_dir`` already exists.

    Args:
        save_dir: Directory the dataset lives in (or is extracted into).
        gdrive_id: Google Drive file ID of the zipped dataset. It is only
            needed when ``save_dir`` does not exist yet.

    Returns:
        ``save_dir``, the path to the dataset directory.

    Raises:
        ValueError: If ``save_dir`` is missing and no ``gdrive_id`` was given.
        RuntimeError: If gdown reports a failed download by returning None.
    """
    # FIXME: Replace the empty default of ``gdrive_id`` with your actual
    # Google Drive file ID, or pass it explicitly.
    if not os.path.exists(save_dir):
        if not gdrive_id:
            raise ValueError(
                f"Dataset directory {save_dir!r} does not exist and no Google "
                "Drive file ID was provided to download it from."
            )
        archive_path = "_temp_.zip"
        try:
            # Depending on the gdown version, a failure is signalled either by
            # an exception or by a None return value; handle both.
            if gdown.download(id=gdrive_id, output=archive_path) is None:
                raise RuntimeError(
                    f"Failed to download dataset (Google Drive id={gdrive_id!r})."
                )
            shutil.unpack_archive(archive_path, save_dir)
        except BaseException:
            # A half-extracted directory would make the next call skip the
            # download (it only checks that save_dir exists), so remove it.
            shutil.rmtree(save_dir, ignore_errors=True)
            raise
        finally:
            # Never leave the temporary archive behind, even on failure.
            if os.path.exists(archive_path):
                os.remove(archive_path)
    print(f"Dataset loaded from {save_dir}")
    return save_dir


# Dataset and Model Paths
DATASET_ROOT = download_dataset("dataset_patched")
MODEL_DIR = "result"
# Checkpoint files: the most recent epoch's model and the best model so far.
LAST_MODEL_PATH, BEST_MODEL_PATH = (
    os.path.join(MODEL_DIR, checkpoint_name)
    for checkpoint_name in ("last_model.pth", "best_model.pth")
)

In [None]:
# TRAIN SETTINGS

# Device Selection
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Random Seed for Reproducibility
SEED = 1

# Training Hyperparameters
LOAD_BEST = False  # True: get_model() restores BEST_MODEL_PATH if it exists
BATCH_SIZE = 15
EPOCHS = 500
TEST_FRACTION = 0.2  # share of the used samples held out for validation
DATASET_FRACTION = 1  # share of the full dataset to use (1 = everything)
# Cap the number of loader processes at the CPU count: more workers than cores
# only adds process overhead, and DataLoader warns about it.
DATALOADER_WORKERS = min(20, os.cpu_count() or 1)
LR = 0.003

# Data Augmentation Transforms
# NOTE(review): RandomPerspective/RandomRotation can push boxes (partly) off
# the canvas. If that yields zero-area boxes, consider appending
# v2.SanitizeBoundingBoxes(). Confirm the target format AleketDataset returns
# (it needs a "labels" entry) first.
TRANSFORMS = v2.Compose([
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomVerticalFlip(p=0.5),
    v2.RandomPerspective(distortion_scale=0.2, p=0.2),
    v2.RandomRotation(degrees=(-10, 10), expand=True),
])


def get_model(num_classes: int) -> tv_detection.FasterRCNN:
    """Load the best checkpoint or build a fresh Faster R-CNN (ResNet-50 FPN v2).

    When ``LOAD_BEST`` is set and ``BEST_MODEL_PATH`` exists, the whole pickled
    model written by ``save_model`` is restored. Otherwise torchvision's
    ``fasterrcnn_resnet50_fpn_v2`` with its default pretrained weights is built.
    Its box predictor is then replaced with a fresh ``FastRCNNPredictor`` that
    has ``num_classes`` outputs.

    Args:
        num_classes: Number of output classes of the box predictor (torchvision
            counts the background as a class).

    Returns:
        The Faster R-CNN model, moved to ``DEVICE``.

    Raises:
        ValueError: If a restored checkpoint predicts a different number of
            classes than ``num_classes``.
    """
    if LOAD_BEST and os.path.exists(BEST_MODEL_PATH):
        # map_location lets a checkpoint saved on a GPU be restored on a
        # CPU-only machine (and vice versa).
        # SECURITY: weights_only=False unpickles arbitrary objects. Only load
        # checkpoints produced by this script, never untrusted files.
        model = torch.load(BEST_MODEL_PATH, map_location=DEVICE, weights_only=False)
        loaded_classes = model.roi_heads.box_predictor.cls_score.out_features
        if loaded_classes != num_classes:
            raise ValueError(
                f"Checkpoint {BEST_MODEL_PATH!r} predicts {loaded_classes} "
                f"classes, but {num_classes} were requested."
            )
    else:
        model = tv_detection.fasterrcnn_resnet50_fpn_v2(weights="DEFAULT")
        # Swap the pretrained classification/regression head for one sized
        # to this dataset's classes.
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = tv_detection.faster_rcnn.FastRCNNPredictor(
            in_features, num_classes
        )
    return model.to(DEVICE)

In [None]:
# main

def set_seed(seed=None):
    """Seed Python's, NumPy's and PyTorch's global random number generators.

    Args:
        seed: Seed to use. When omitted (or None), the module-level ``SEED``
            is used, so existing ``set_seed()`` calls behave as before.
    """
    if seed is None:
        seed = SEED
    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed seeds the generators of all devices (CPU and CUDA).
    torch.manual_seed(seed)


def save_model(model, is_best):
    """Save the model as the latest checkpoint, and as the best one if requested.

    Both files are first written under a temporary name and then atomically
    moved into place. A crash or interruption mid-save therefore never leaves
    a truncated checkpoint behind.

    Args:
        model: The model to save (pickled whole with ``torch.save``).
        is_best: Whether this is the best model so far. If True, the
            checkpoint is also copied to ``BEST_MODEL_PATH``.
    """
    os.makedirs(MODEL_DIR, exist_ok=True)

    last_tmp = LAST_MODEL_PATH + ".tmp"
    torch.save(model, last_tmp)
    os.replace(last_tmp, LAST_MODEL_PATH)

    if is_best:
        # The best checkpoint is byte-identical to the one just written, so
        # copy the file rather than serialising the model a second time.
        best_tmp = BEST_MODEL_PATH + ".tmp"
        shutil.copyfile(LAST_MODEL_PATH, best_tmp)
        os.replace(best_tmp, BEST_MODEL_PATH)


def divide_dataset(
    dataset: AleketDataset,
    dataset_fraction: float,
    test_fraction: float,
    batch_size: int,
    num_workers: int,
) -> tuple[DataLoader, DataLoader]:
    """Divides the dataset into training and validation sets and creates DataLoaders.
    Args:
        dataset: The AleketDataset to divide.
        dataset_fraction: The fraction of the dataset to use.
        test_fraction: The fraction of the used dataset to allocate for validation.
        batch_size: The batch size for the DataLoaders.
        num_workers: The number of worker processes for data loading.
    Returns:
        A tuple containing the training DataLoader and the validation DataLoader.
    """

    def collate_fn(batch):
        """Collates data samples into batches for the dataloader."""
        return tuple(zip(*batch))

    # Calculate the number of samples to use based on the dataset_fraction
    num_samples = int(len(dataset) * dataset_fraction)

    # Randomly shuffle indices and select the desired number of samples
    indices = torch.randperm(len(dataset)).tolist()[:num_samples]

    # Calculate the number of samples for the validation set
    test_size = int(num_samples * test_fraction)

    # Create training and validation subsets
    train_dataset = Subset(dataset, indices[:-test_size])
    val_dataset = Subset(dataset, indices[-test_size:])

    # Create DataLoaders
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=num_workers,
    )

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        num_workers=num_workers,
    )

    return train_dataloader, val_dataloader


def main():
    """Train and validate the detector, checkpointing after every epoch.

    Seeds the RNGs, builds the model and the train/validation loaders, then
    runs ``EPOCHS`` epochs. Each epoch trains, evaluates, records statistics,
    saves the latest (and, when improved, the best) model, steps the LR
    schedule and refreshes the statistics plots.
    """
    set_seed()
    model = get_model(3)  # number of classes handed to the box predictor

    dataset = AleketDataset(DATASET_ROOT, transforms=TRANSFORMS)
    loaders = divide_dataset(
        dataset, DATASET_FRACTION, TEST_FRACTION, BATCH_SIZE, DATALOADER_WORKERS
    )
    train_loader, val_loader = loaders

    # Hand only the parameters that still require gradients to the optimizer.
    trainable = list(filter(lambda param: param.requires_grad, model.parameters()))
    optimizer = optim.SGD(trainable, lr=LR, momentum=0.9, weight_decay=0.0005)
    # Stepped once per epoch below: the LR decays linearly from LR to 0.1 * LR
    # over the first 15 epochs and then stays there.
    scheduler = optim.lr_scheduler.LinearLR(
        optimizer, start_factor=1, end_factor=0.1, total_iters=15
    )

    log_path = "training.log"
    if os.path.exists(log_path):  # every run starts with a fresh log file
        os.remove(log_path)

    logger = TrainingLogger("resnet50_v2_backbone training", log_path, batch_print=False)
    tracker = StatsTracker()
    confidence_threshold = 0

    for epoch in range(EPOCHS):
        logger.log_epoch_start(epoch, EPOCHS)

        # Both loaders wrap Subsets of this one dataset object, so the flag is
        # flipped before each phase. NOTE(review): presumably it toggles
        # augmentation inside AleketDataset; confirm.
        dataset.train = True
        train_losses = train_one_epoch(
            model, optimizer, train_loader, DEVICE, epoch, logger
        )

        dataset.train = False
        val_metrics = evaluate(model, val_loader, DEVICE, logger, confidence_threshold)

        tracker.update_train_loss(train_losses)
        improved = tracker.update_val_metrics(val_metrics)

        logger.log_epoch_end(epoch, train_losses["loss"], val_metrics)
        save_model(model, improved)

        scheduler.step()
        tracker.plot_stats("training_stats")

    tracker.plot_stats("training_stats")


if __name__ == "__main__":
    main()