# 1/ Env Setup
Load necessary libraries to run this notebook. <br>
All libraries are cited in ```requirements.txt```. <br>
Documentation: https://docs.pytorch.org/vision/main/models/generated/torchvision.models.detection.retinanet_resnet50_fpn_v2.html

## 1.1/ Import dependencies
Load libraries:

In [1]:
import sys  
import os

current_dir = os.getcwd() # path to the current working directory (notebook location)
project_root = os.path.abspath(os.path.join(current_dir, "..")) # path to project root

if project_root not in sys.path: # add project root to sys.path
    sys.path.insert(0, project_root)
print(f"Project root added to sys.path: {project_root}")

Project root added to sys.path: /Users/litani/Documents/myCode/steel-defects


In [None]:
from pathlib import Path
import torch 
from torchvision.models.detection import retinanet_resnet50_fpn_v2
from torchvision.models.detection.retinanet import RetinaNetClassificationHead
import numpy as np

## 1.2/ Set reproducibility
Device and seed:

In [3]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
torch.manual_seed(42)
np.random.seed(42)

# 2/ Configuration Management
Define:
- image path
- model hyperparameters
- hardware


In [4]:
class Config:
    # Paths
    DATA_ROOT = Path(project_root) / "data" / "raw"
    TRAIN_IMG = DATA_ROOT / "train_images"
    TRAIN_ANN = DATA_ROOT / "train_annotations"
    VAL_IMG = DATA_ROOT / "valid_images"
    VAL_ANN = DATA_ROOT / "valid_annotations"

    # Model parameters
    NUM_CLASSES = 7 # 6 defects + 1 background
    BACKBONE_PRETRAINED = True 
    
    # Training hyperparameters
    BATCH_SIZE = 5  # no mention of batch size in the paper (go for > 10 when you are sure that training works)
    NUM_EPOCHS = 3 # 24 epochs based on paper. Reduced for quicker testing
    LEARNING_RATE = 0.0025 # 0.0025 based on paper
    MOMENTUM = 0.9 # 0.9 based on paper
    WEIGHT_DECAY = 0.0005 # double check this value <<<<<<<

    # Hardware
    DEVICE = device
    NUM_WORKERS = 8
    PIN_MEMORY = True if torch.cuda.is_available() else False

config = Config()

In [5]:
config

<__main__.Config at 0x12de812b0>

# 3/ Dataset Class
- Load images and annotations into PyTorch format. 
- This is necessary since RetineNet excepts a dictionary format. The latter requires XML parsing.

In [6]:
from src.utils.dataset import SteelDefectDataset, collate_func

# 4/ Data Augmentation
- We have 1800 images, resorting to image augmentation is mandatory to avoid overfitting. 
- Geometric transformation, simplist form, will be applied as a quick fix:
    - Horizental/Vertical flips
    - Rotate by 90
    - Others: brightness, contrast, adding random noise
- **NB:** OpenCV stores images as [Height in pixels, Width in pixels, RGB] while PyTorch expects [channel, height, width]

In [7]:
from src.utils.transforms_pipeline import get_train_transforms, get_val_transforms

# 5/ Model Initilization
- Apply transfer learning where pretrained RetineNet is loaded the changes are applied based on the dataset

In [8]:
def create_model(num_classes, pretrained = True):
    # Load pretrained RetineNet w/ ResNet50 backbone,
    model = retinanet_resnet50_fpn_v2(weights = "DEFAULT" if pretrained else None)  # DEFAULT loads ImageNet pretrained weights for transfer learning
    
    # Replace head so that model learns defect-specific patterns
    num_anchors = model.head.classification_head.num_anchors # default is 9 anchors per location >> 3 scales x 3 aspect ratios
    model.head.classification_head = RetinaNetClassificationHead(
        in_channels = 256,          # Input: 256 features from FPN   
        num_anchors = num_anchors,  # Process: 9 anchors per location
        num_classes = num_classes   # Output: 7 classes scores per anchor
    )
    return model

model = create_model(config.NUM_CLASSES).to(device) # Create model instance and move to device CPU/GPU, config.NUM_CLASSES = 7 includes background

# 6/ Data Loaders
- collate_func is a fucntion that works on the collation process of RetinaNet since images have variable bbox counts.

In [9]:
# Create train dataset with augmentations and val dataset without augmentations, only format conversion
train_dataset = SteelDefectDataset(
    config.TRAIN_IMG,
    config.TRAIN_ANN,
    transforms = get_train_transforms()
)

val_dataset = SteelDefectDataset(
    config.VAL_IMG,
    config.VAL_ANN,
    transforms = get_val_transforms()
)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size = config.BATCH_SIZE,
    shuffle = True, # shuffle training data for better generalization
    num_workers = config.NUM_WORKERS,
    pin_memory = config.PIN_MEMORY, # only useful if using GPU
    collate_fn = collate_func 
)

val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size = config.BATCH_SIZE,
    shuffle = False, # validation data should be consistent
    num_workers = config.NUM_WORKERS,
    pin_memory = config.PIN_MEMORY,
    collate_fn = collate_func
)

  self._set_keys()


# 7/ Training Loop and Evaluation Metrics
- Quick and dirty: use SGD as an optimizer for an initial model training, won't be launching/tracking experiments in the beginning
- this is standard supervised learning using RetinaNet loss
- For evaluation, assess model perf without retraining, aim for .5 (50% overlap) with validation images

In [10]:
from src.utils.trainEval_pipeline import train_one_epoch, evaluate

In [None]:
# Optimizer
optimizer = torch.optim.SGD(
    model.parameters(),
    lr = config.LEARNING_RATE,
    momentum = config.MOMENTUM,
    weight_decay = config.WEIGHT_DECAY # >>> double check this value <<<
)

# Scheduler
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size = 5, 
    gamma = 0.8   # reduce to 80% every 5 epochs
)

# Checkpoint directory
Path("models").mkdir(exist_ok = True)

# Training Loop
for epoch in range(config.NUM_EPOCHS):
    print(f"\n{'='*60}")
    print(f"Epoch {epoch + 1}/{config.NUM_EPOCHS}")
    print(f"{'='*60}")
    
    # Training
    train_loss = train_one_epoch(model, train_loader, optimizer, device)
    scheduler.step()
    
    # Validation
    results = evaluate(model, val_loader, device)
    
    # Display metrics
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Val mAP@0.5: {results['map_50']:.4f}")
    print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
    
    # Save checkpoint
    save_interval = max(1, config.NUM_EPOCHS // 3)
    if (epoch + 1) % save_interval == 0 or (epoch + 1) == config.NUM_EPOCHS:
        checkpoint_path = f"models/retinanet_epoch_{epoch+1}.pth"
        torch.save(model.state_dict(), checkpoint_path)
        print(f"✓ Checkpoint saved: {checkpoint_path}")


Epoch 1/3


Train Loss: 1.3562
Val mAP@0.5: 0.0000
Learning Rate: 0.002500
✓ Checkpoint saved: models/retinanet_epoch_1.pth

Epoch 2/3




Train Loss: 1.0409
Val mAP@0.5: 0.2362
Learning Rate: 0.002500
✓ Checkpoint saved: models/retinanet_epoch_2.pth

Epoch 3/3
Train Loss: 0.8673
Val mAP@0.5: 0.3986
Learning Rate: 0.002500
✓ Checkpoint saved: models/retinanet_epoch_3.pth
