# Aleket Faster R-CNN training notebook

In [None]:
%pip install pillow
%pip install numpy<2.0
%pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124
%pip install matplotlib
%pip install pycocotools
%pip install gdown

from IPython.display import clear_output

clear_output(wait=False)

print("ALL DEPENDENCIES INSTALLED")

In [1]:
# IMPORTS

# Standard Library
import os
import random

# Third-Party Libraries
import numpy as np

# Torch
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LinearLR

# Torchvision
import torchvision.models.detection as tv_detection
from torchvision.models.detection import FasterRCNN
import torchvision.transforms.v2 as v2

# Utils
from utils import AleketDataset, TrainingLogger, StatsTracker, download_dataset, split_dataset
from training_and_evaluation import evaluate, train_one_epoch, CocoEvaluator

In [None]:
# Helper Functions

# Device Selection: use CUDA when torch reports it available, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

# Model Builder
def get_model(num_classes: int, device: "torch.device | None" = None) -> FasterRCNN:
    """Build a Faster R-CNN (ResNet-50 FPN) with a freshly initialised box predictor.

    The network is created with torchvision's ``weights="DEFAULT"`` pretrained
    weights, and only ``roi_heads.box_predictor`` is replaced so that it predicts
    ``num_classes`` classes (per torchvision's convention this count includes the
    background class). Nothing is read from a checkpoint here.

    Args:
        num_classes: Number of output classes of the new box predictor.
        device: Device to move the model to. ``None`` (the default) means the
            module-level ``DEVICE``, looked up at call time.

    Returns:
        The Faster R-CNN model on the selected device.
    """
    model = tv_detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")

    # Swap the classification/regression head for one sized to our label set,
    # keeping the input feature width of the pretrained head.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = tv_detection.faster_rcnn.FastRCNNPredictor(
        in_features, num_classes
    )
    return model.to(DEVICE if device is None else device)

# Save training state
def save_checkpoint(model: FasterRCNN,
                    optimizer: SGD,
                    lr_scheduler: LinearLR,
                    epoch_trained: int,
                    checkpoint_path: str) -> None:
    """Serialize model, optimizer and LR-scheduler state to ``checkpoint_path``.

    The state is written to a temporary file next to the target and then moved
    into place with ``os.replace``. An interruption during ``torch.save``
    therefore cannot leave a truncated checkpoint at ``checkpoint_path``; the
    previous file (if any) stays intact until the new one is complete.

    Args:
        model: Model whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        lr_scheduler: Scheduler whose ``state_dict()`` is saved.
        epoch_trained: Number of epochs completed so far.
        checkpoint_path: Destination file path.
    """
    save_state = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "lr_scheduler_state_dict": lr_scheduler.state_dict(),
        "epoch_trained": epoch_trained,
    }
    tmp_path = f"{checkpoint_path}.tmp"
    try:
        torch.save(save_state, tmp_path)
        os.replace(tmp_path, checkpoint_path)
    except BaseException:
        # Also covers KeyboardInterrupt: drop the partial temp file, then re-raise.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

# Load training state
def load_checkpoint(
    checkpoint_path: str, num_classes: int = 3
) -> tuple[FasterRCNN, SGD, LinearLR, int]:
    """Rebuild the training state written by ``save_checkpoint``.

    Args:
        checkpoint_path: Path of the checkpoint file.
        num_classes: Class count of the saved box predictor. The default of 3
            matches the ``get_model(3)`` call in ``main``.

    Returns:
        ``(model, optimizer, lr_scheduler, epoch_trained)``.
    """
    # map_location lets a checkpoint saved on GPU be resumed on a CPU-only host.
    # NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
    save_state = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)

    model = get_model(num_classes)
    model.load_state_dict(save_state["model_state_dict"])

    # lr=0.0 is a placeholder: every hyperparameter (lr, momentum, weight_decay,
    # initial_lr) and the momentum buffers are restored from the checkpoint below.
    optimizer = SGD(model.parameters(), lr=0.0)

    # Build the scheduler BEFORE restoring the optimizer. An LRScheduler's
    # constructor performs an initial step() that rewrites param_group["lr"].
    # Restoring the optimizer afterwards puts the saved learning rate back.
    # Previously the scheduler was built last (with default factors and
    # last_epoch set), silently changing the resumed LR.
    lr_scheduler = LinearLR(optimizer)
    optimizer.load_state_dict(save_state["optimizer_state_dict"])
    lr_scheduler.load_state_dict(save_state["lr_scheduler_state_dict"])

    epoch_trained = save_state["epoch_trained"]
    return model, optimizer, lr_scheduler, epoch_trained


In [None]:
# Random Seed for Reproducibility: seed Python's, NumPy's and torch's RNGs.
SEED = 1
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

# Path variables
RESUME = False
RESULT_PATH = "result"
DATASET_ROOT = download_dataset("dataset_patched")

_RUN_DIR = os.path.join(RESULT_PATH, "run")
os.makedirs(_RUN_DIR, exist_ok=True)
LAST_CHECKPOINT = os.path.join(_RUN_DIR, "last.pth")
BEST_CHECKPOINT = os.path.join(_RUN_DIR, "best.pth")
STATS_GRAPH = os.path.join(RESULT_PATH, "stats_graph")
STATS_LOG = os.path.join(RESULT_PATH, "stats.csv")

# Dataset split
TEST_FRACTION = 0.2      # share of the sampled subset held out for validation
# NOTE(review): main() samples only this fraction of the dataset (1%);
# this looks like a debug setting, so confirm before a real training run.
DATASET_FRACTION = 0.01
DATALOADER_WORKERS = 16

# Training Hyperparameters
IMG_SIZE = 1024
BATCH_SIZE = 16
EPOCHS = 10
LR = 0.01
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-4
# NOTE(review): despite its name, main() passes this as `total_iters` of a
# LinearLR that decays the LR factor from 1 to 0.1 (a decay, not a warm-up).
# With EPOCHS = 10 the factor is still ~0.84 in the last epoch.
WARMUP_EPOCHS = 50

# Data Augmentation Transforms

# Evaluation pipeline: v2.Resize with size=IMG_SIZE, then float32 conversion
# with value scaling.
DEFAULT_TRANSFORMS = v2.Compose([
    v2.Resize(size=IMG_SIZE),
    v2.ToDtype(torch.float32, scale=True),
])

# Training pipeline: random geometric augmentations, then the evaluation pipeline.
_AUGMENTATIONS = [
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomVerticalFlip(p=0.5),
    v2.RandomPerspective(distortion_scale=0.2, p=0.2),
    v2.RandomRotation(degrees=(-10, 10), expand=True),
]
TRAINING_TRANSFORMS = v2.Compose([*_AUGMENTATIONS, DEFAULT_TRANSFORMS])

def make_split_indices(dataset_len: int,
                       dataset_fraction: float,
                       test_fraction: float) -> tuple[list[int], list[int]]:
    """Randomly sample a subset of ``range(dataset_len)`` and split it into train/val.

    Args:
        dataset_len: Number of items in the full dataset.
        dataset_fraction: Fraction of the dataset to sample.
        test_fraction: Fraction of the sampled subset used for validation.

    Returns:
        ``(train_indices, val_indices)``, disjoint and both non-empty.

    Raises:
        ValueError: if fewer than 2 samples would be drawn, so no split with a
            non-empty train and val part exists.
    """
    num_samples = int(dataset_len * dataset_fraction)
    if num_samples < 2:
        raise ValueError(
            f"Need at least 2 samples for a train/val split, got {num_samples} "
            f"(fraction {dataset_fraction} of {dataset_len} items)."
        )
    # Same RNG call as before, so a fixed torch seed yields the same subset.
    indices = torch.randperm(dataset_len).tolist()[:num_samples]
    # Clamp so neither split is empty. An unclamped test_size of 0 made
    # indices[:-0] == [] (no training data) and indices[-0:] == everything.
    test_size = min(max(int(num_samples * test_fraction), 1), num_samples - 1)
    split_at = num_samples - test_size
    return indices[:split_at], indices[split_at:]


def main():
    """Main training and evaluation loop.

    Samples a ``DATASET_FRACTION`` subset of the dataset and splits it into
    train/val loaders. It then trains until ``EPOCHS`` epochs are completed,
    optionally resuming from ``LAST_CHECKPOINT``. After every epoch it
    evaluates, logs, plots, steps the LR scheduler, and saves the latest (and,
    when improved, the best) checkpoint.
    """
    dataset = AleketDataset(DATASET_ROOT, DEFAULT_TRANSFORMS)
    train_indices, val_indices = make_split_indices(
        len(dataset), DATASET_FRACTION, TEST_FRACTION
    )

    train_dataloader, val_dataloader = split_dataset(
        dataset, train_indices, val_indices, BATCH_SIZE, DATALOADER_WORKERS
    )

    coco_eval = CocoEvaluator(val_dataloader.dataset)

    # NOTE(review): StatsTracker starts empty even when resuming, so plots and
    # best-model tracking only cover the epochs of this run. Confirm this is intended.
    stats_tracker = StatsTracker()
    logger = TrainingLogger(STATS_LOG)

    if RESUME:
        print(f"Resuming from  {LAST_CHECKPOINT}...")
        model, optimizer, lr_scheduler, epoch_trained = load_checkpoint(LAST_CHECKPOINT)
    else:
        # Build a fresh model only when not resuming. Previously a pretrained
        # model, optimizer and scheduler were built and then thrown away on resume.
        model = get_model(3)
        optimizer = SGD(model.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
        lr_scheduler = LinearLR(
            optimizer, start_factor=1, end_factor=0.1, total_iters=WARMUP_EPOCHS
        )
        epoch_trained = 0

    while epoch_trained < EPOCHS:
        epoch = epoch_trained + 1
        logger.log_epoch_start(epoch, EPOCHS, lr_scheduler.get_last_lr()[0])

        # Both loaders were built from the same `dataset`, so its transforms are
        # switched per phase. NOTE(review): worker processes only see the change
        # if split_dataset does not use persistent workers. Confirm.
        dataset.transforms = TRAINING_TRANSFORMS
        losses = train_one_epoch(model, optimizer, train_dataloader, DEVICE)

        dataset.transforms = DEFAULT_TRANSFORMS
        eval_stats = evaluate(model, val_dataloader, coco_eval, DEVICE)

        stats_tracker.update_train_loss(losses)
        is_best = stats_tracker.update_val_metrics(eval_stats)

        logger.log_epoch_end(epoch, losses, eval_stats)
        stats_tracker.plot_stats(STATS_GRAPH)

        lr_scheduler.step()

        epoch_trained = epoch

        save_checkpoint(model, optimizer, lr_scheduler, epoch_trained, LAST_CHECKPOINT)
        if is_best:
            save_checkpoint(model, optimizer, lr_scheduler, epoch_trained, BEST_CHECKPOINT)


if __name__ == "__main__":
    main()