# Training Pipeline Notebook

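This notebook bundles configuration, dataset utilities, model definitions, the training loop, and evaluation into a single self-contained workflow. The `DataPartition` dataset class is imported from `data_utils.py`, which is required for running the DataLoaders with `num_workers > 0`.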

## 1. Configuration & Hyperparameters

In [None]:
import os
import random
import json
import time
import glob
import numpy as np
import pandas as pd                              
import torch
import torch.nn as nn
from torch.utils.data import DataLoader 
from torch.amp import autocast
from torch.optim import AdamW, lr_scheduler
from torchvision import transforms
import timm
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from sklearn.metrics import average_precision_score, precision_recall_fscore_support, f1_score
from data_utils import DataPartition
import warnings # Suppress warnings that currently do not affect execution
warnings.filterwarnings("ignore", message="torch.meshgrid: in an upcoming release")
warnings.filterwarnings("ignore", message="Cannot set number of intraop threads after parallel work has started or after set_num_threads call")

# Hyperparameters
DEBUG_MODE = True  # Uses a random sample of 200 rows for quick smoke runs
USE_GPU = True
MODEL_NAME = "swinv2_base_window12to24_192to384"  # timm model identifier
IMG_WIDTH = 384  # Images are resized to IMG_WIDTH x IMG_WIDTH
N_EPOCHS = 2
BATCH_SIZE = 2
LEARNING_RATE = 1e-5
PATIENCE = 8  # Epochs without val-loss improvement before early stopping
DROPOUT_RATE = 0.5
SCHEDULER_T0 = 6  # CosineAnnealingWarmRestarts: epochs in the first cycle
SCHEDULER_T_MULT = 1  # CosineAnnealingWarmRestarts: cycle-length multiplier
MIN_LR = 1e-6
TRAIN_RATIO = 0.8
VAL_RATIO = 0.1
TEST_RATIO = 0.1
RANDOM_SEED = 42
THRESHOLD_MODE = 'per_label'  # choices: 'per_label', 'global'
GLOBAL_THRESHOLD = 0.5  # Used when THRESHOLD_MODE is not 'per_label'

# Reproducibility
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Device setup
if USE_GPU and torch.cuda.is_available():
    device = torch.device("cuda")
    torch.cuda.manual_seed_all(RANDOM_SEED)
    # benchmark=True times candidate cuDNN algorithms at runtime and may pick different ones
    # from run to run, so results can differ even with deterministic=True. The PyTorch
    # reproducibility notes recommend benchmark=False for seeded, repeatable runs.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    pin_memory = True
    amp_dtype = torch.bfloat16  # bf16 autocast for the GPU training forward pass
else:
    device = torch.device("cpu")
    pin_memory = False
    amp_dtype = torch.float32  # full precision on CPU
print(f"Using device: {device}")

Using device: cpu


## 2. Dataset Utilities

In [None]:
# JSON Parsers and Helpers
def get_confidence_score(response_str):
    """Parse the leading whitespace-separated integer token of an answer string.

    Example: ``"3 - likely"`` -> ``3``.

    Returns:
        The integer, or ``None`` when the string is blank, its first token is not an
        integer, or the argument is not a string (e.g. ``None``).
    """
    try:
        return int(response_str.split()[0])
    except (AttributeError, IndexError, ValueError):
        # AttributeError: not a string; IndexError: blank string; ValueError: non-numeric token.
        return None

def parse_classifications(record, threshold):
    """Collapse a record's classifications into a ``{classification name: 0 or 1}`` dict.

    A classification is positive when any of its checklist answers has a leading integer
    score <= ``threshold`` (see ``get_confidence_score``). Failing that, it is positive when
    the first character of its ``value`` is a digit <= ``threshold``. Results for the same
    name across all projects and labels are OR-ed together. Classifications without a
    name are skipped.
    """
    result = {}
    for project in record.get("projects", {}).values():
        for label in project.get("labels", []):
            classifications = label.get("annotations", {}).get("classifications", [])
            for clf in classifications:
                clf_name = clf.get("name")
                # Lazily scored, so any() stops at the first qualifying answer.
                scores = (get_confidence_score(ans.get("name", "")) for ans in clf.get("checklist_answers", []))
                hit = any(s is not None and s <= threshold for s in scores)
                raw = clf.get("value", "")
                if not hit:
                    # Fall back to the first character of the free-form value.
                    hit = bool(raw) and raw[0].isdigit() and int(raw[0]) <= threshold
                if clf_name:
                    result[clf_name] = int(result.get(clf_name, 0) or hit)
    return result

def get_base_filename(fn):
    """Strip a trailing ``_left.jpg``/``_right.jpg`` from *fn*; otherwise strip only its extension."""
    matched = next((s for s in ("_left.jpg", "_right.jpg") if fn.endswith(s)), None)
    if matched is None:
        return os.path.splitext(fn)[0]
    return fn[: len(fn) - len(matched)]

def create_ndjson_image_path_mapping(base_dir):
    """Map file name -> full path for every ``.jpg`` at ``<base_dir>/*/*/split_jpg/*.jpg``.

    If the same file name appears in several directories, the path ``glob`` returns last wins.
    """
    search = os.path.join(base_dir, "*", "*", "split_jpg", "*.jpg")
    mapping = {}
    for path in glob.glob(search, recursive=True):
        mapping[os.path.basename(path)] = path
    return mapping

# CSV Parsers and Helpers
def load_csv_to_df(filepath, img_dir):
    """Read the label CSV at *filepath* and add two columns derived from ``Filenames``.

    - ``group_id``: the file name without its extension.
    - ``image_path``: ``img_dir`` joined with the file name.
    """
    df = pd.read_csv(filepath)
    filenames = df["Filenames"].tolist()
    df["group_id"] = [os.path.splitext(name)[0] for name in filenames]
    df["image_path"] = [os.path.join(img_dir, name) for name in filenames]
    return df

def group_stratified_split(df, label_columns, group_col, split_ratio, seed):
    """Split *df* in two so that no value of *group_col* appears in both parts.

    Each group gets one label vector, the column-wise max over its rows (a logical OR
    for 0/1 labels). The groups are then split with ``MultilabelStratifiedShuffleSplit``
    using ``test_size=split_ratio`` and ``random_state=seed``.

    Returns:
        ``(df_split_1, df_split_2)`` with reset indices. ``df_split_2`` holds the rows of
        the groups selected as the splitter's test share.
    """
    unique_groups_array = df[group_col].unique()
    # One groupby pass instead of re-filtering the whole frame for every group
    # (O(groups * rows)). sort=False, dropna=False and the reindex keep the rows aligned
    # with unique_groups_array.
    aggregated_labels_array = (
        df.groupby(group_col, sort=False, dropna=False)[label_columns]
        .max()
        .reindex(unique_groups_array)
        .to_numpy()
    )
    splitter = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=split_ratio, random_state=seed)
    # n_splits=1, so the generator yields exactly one (first, second) index pair.
    first_split_idx, second_split_idx = next(
        splitter.split(unique_groups_array.reshape(-1, 1), aggregated_labels_array)
    )
    first_groups = unique_groups_array[first_split_idx]
    second_groups = unique_groups_array[second_split_idx]
    # Select whole groups so that related rows never leak across the two partitions.
    df_split_1 = df[df[group_col].isin(first_groups)].reset_index(drop=True)
    df_split_2 = df[df[group_col].isin(second_groups)].reset_index(drop=True)
    return df_split_1, df_split_2

In [None]:
# MUST move DataPartition into a separate file (data_utils.py)
# in order to run the notebook version with num_workers > 0 for GPU training

# class DataPartition(Dataset):
#     def __init__(self, df, label_columns, transform=None):
#         self.label_columns = label_columns
#         self.transform = transform
#         self.img_paths = df["image_path"].tolist() # List of image paths
#         self.labels = df[label_columns].to_numpy(dtype=np.float32) # 2-D Array of of shape (N_samples, N_labels)

#     def __len__(self):
#         return len(self.img_paths)

#     def __getitem__(self, idx):
#         img_path  = self.img_paths[idx]
#         img = Image.open(img_path).convert("RGB") # Retrieve image
#         if self.transform:                        # Apply transformations to image
#             img = self.transform(img)
#         label_vector = torch.from_numpy(self.labels[idx]) # Retrieve label vector for the given sample
#         return img, label_vector
        
# Data Augmentation (Transforms)
# Per-channel ImageNet mean/std, the usual normalisation for ImageNet-pretrained timm
# backbones (assumed to match MODEL_NAME's pretraining; confirm if the model changes).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training: square resize to IMG_WIDTH, random horizontal flip and colour jitter
# (brightness/contrast/saturation 0.2, hue 0.1), then tensor conversion + normalisation.
train_transforms = transforms.Compose(
    [
        transforms.Resize((IMG_WIDTH, IMG_WIDTH)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

# Validation / test / inference: same resize and normalisation, no augmentation.
val_transforms = transforms.Compose(
    [
        transforms.Resize((IMG_WIDTH, IMG_WIDTH)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

## 3. Model & Classifier

In [None]:
class AttentionPool2d(nn.Module):
    """Pool a [B, C, H, W] feature map to [B, C] with one learned attention query.

    Keys and values come from bias-free 1x1 convolutions of the input. A single
    learnable query attends over all H*W positions with scaled dot-product attention.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Keep the creation order (query, then key/value convs) so that seeded initialisation
        # and state_dict keys stay the same.
        self.query = nn.Parameter(torch.randn(1, in_channels))  # [1, C]
        self.to_key = nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False)
        self.to_value = nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False)
        self.scale = in_channels ** -0.5  # 1 / sqrt(C)

    def forward(self, x):
        batch, channels, height, width = x.shape
        # [B, C, H, W] -> [B, C, N] -> [B, N, C], with N = H * W
        keys = self.to_key(x).flatten(2).transpose(1, 2)
        values = self.to_value(x).flatten(2).transpose(1, 2)
        # Share the single query across the batch: [1, C] -> [B, 1, C]
        q = self.query.unsqueeze(0).expand(batch, -1, -1)
        scores = (q @ keys.transpose(1, 2)) * self.scale  # [B, 1, N]
        weights = scores.softmax(dim=-1)
        pooled = weights @ values  # [B, 1, C]
        return pooled[:, 0, :]  # [B, C]

class SwinTransformerMultiLabel(nn.Module):
    """timm backbone (``MODEL_NAME``) + attention pooling + dropout + linear multi-label head.

    ``forward`` returns raw per-class logits of shape [B, num_classes].
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        # num_classes=0 and global_pool="" build the backbone without timm's own
        # classifier head and pooling, so the feature map reaches the attention pool.
        self.model = timm.create_model(MODEL_NAME, pretrained=pretrained, num_classes=0, global_pool="")
        width = self.model.num_features
        self.attn_pool = AttentionPool2d(width)
        self.dropout = nn.Dropout(DROPOUT_RATE)
        self.fc = nn.Linear(width, num_classes)

    def forward(self, imgs):
        # NOTE(review): assumes forward_features returns channels-last [B, H', W', C], as the
        # permute below requires. Confirm this if MODEL_NAME changes.
        channels_first = self.model.forward_features(imgs).permute(0, 3, 1, 2)  # [B, C, H', W']
        pooled = self.attn_pool(channels_first)  # [B, C]
        return self.fc(self.dropout(pooled))  # [B, num_classes]

class Classifier:
    """Inference wrapper bundling a model with its transform, label names and thresholds.

    Args:
        model: module mapping a single-image batch to per-label logits. It is moved to
            ``device`` and switched to eval mode.
        transform: callable turning one input image into a tensor (a batch dim is added).
        device: torch device used for inference.
        labels: label names, in the same order as the model's logits.
        thresholds: per-label probability cut-offs (numpy array, list, or scalar).
    """

    def __init__(self, model, transform, device, labels, thresholds):
        self.model = model.to(device).eval()
        self.transform = transform
        self.device = device
        self.labels = labels
        self.thresholds = thresholds

    def predict(self, img):
        """Return a 0/1 numpy vector: 1 where sigmoid(logit) >= that label's threshold."""
        tensor = self.transform(img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.model(tensor)
            probabilities = torch.sigmoid(logits).cpu().numpy()[0]  # [K]
        return (probabilities >= self.thresholds).astype(int)

    def save(self, filename):
        """Write weights to ``<filename>.pt`` and ``{"thresholds", "labels"}`` to ``<filename>.json``."""
        torch.save(self.model.state_dict(), filename + ".pt")
        # np.asarray accepts ndarrays, plain lists and scalars alike; .tolist() then yields
        # JSON-serialisable Python numbers. list() likewise normalises array/Index label containers.
        config = {
            "thresholds": np.asarray(self.thresholds).tolist(),
            "labels": list(self.labels),
        }
        with open(filename + ".json", "w", encoding="utf-8") as file:
            json.dump(config, file, indent=2)
        print(f"Model Weights saved as {filename}.pt | Classifier Metadata saved as {filename}.json")

## 4. Training Monitor & Trainer

In [None]:
class TrainingMonitor:
    """Keep per-epoch train/val losses and val mAP, and report total wall-clock time."""

    def __init__(self):
        self.train_losses, self.val_losses, self.val_mAPs = [], [], []
        self.start = time.time()  # reference point for finish()

    def report_epoch(self, train_loss, val_loss, val_map):
        """Append one epoch's metrics to the history lists."""
        pairs = ((self.train_losses, train_loss), (self.val_losses, val_loss), (self.val_mAPs, val_map))
        for history, value in pairs:
            history.append(value)

    def finish(self):
        """Print the time elapsed since construction as minutes/seconds and return it in seconds."""
        total_time = time.time() - self.start
        mins, secs = divmod(total_time, 60)
        print(f"Total Training Time: {int(mins)} min {int(secs)} sec")
        return total_time

class Trainer:
    """Epoch loop: train with gradient accumulation, validate, step LR schedulers,
    checkpoint the lowest-validation-loss weights and stop early on stagnation.

    Args:
        model: network to optimise; expected to already live on ``device``.
        optimizer: optimiser over ``model``'s parameters. The LR of its first param group
            is taken as the base LR for the linear warm-up.
        scheduler_cos: scheduler stepped once per epoch with no argument.
        scheduler_plateau: scheduler stepped once per epoch with the validation loss.
        criterion: loss called as ``criterion(logits, labels)``; its value is treated as a
            per-batch mean when averaging (assumption; holds for the default BCEWithLogitsLoss).
        train_loader, val_loader: iterables of ``(images, labels)`` batches.
            ``val_loader.dataset`` must expose ``label_columns`` (used for the AP table).
        device: torch.device the batches are moved to.
        monitor: object with ``report_epoch(train_loss, val_loss, val_map)`` and ``finish()``.
        patience: consecutive epochs without val-loss improvement before stopping.
        warmup_epochs: leading epochs whose LR ramps linearly up to the base LR.
        amp_dtype: autocast dtype for the training forward pass; ``torch.float32``
            disables autocast.
        accumulation_steps: batches whose gradients are accumulated per optimiser step.
    """

    def __init__(self, model, optimizer, scheduler_cos, scheduler_plateau, criterion, train_loader, val_loader, device, monitor, patience, warmup_epochs, amp_dtype, accumulation_steps):
        self.model = model
        self.optimizer = optimizer
        self.scheduler_cos = scheduler_cos
        self.scheduler_plateau = scheduler_plateau
        self.criterion = criterion
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.monitor = monitor
        self.patience = patience
        self.warmup_epochs = warmup_epochs
        self.amp_dtype = amp_dtype
        self.accumulation_steps = accumulation_steps
        self.best_val_loss = float('inf')
        self.epochs_no_improve = 0
        self.base_lr = optimizer.param_groups[0]['lr']  # base LR for warm-up calculations
        self.best_state = None  # CPU copy of the best weights seen so far

    def train_epoch(self):
        """Run one training pass and return the mean per-sample training loss.

        Raises:
            ValueError: if ``train_loader`` yields no samples.
        """
        self.model.train()
        running_loss = 0.0
        total_samples = 0
        num_batches = 0
        # Autocast only supports reduced-precision dtypes; with float32 (the CPU path) it
        # would just warn and disable itself, so switch it off explicitly.
        use_amp = self.amp_dtype != torch.float32
        self.optimizer.zero_grad()
        for batch_idx, (images, labels) in enumerate(self.train_loader):
            images = images.to(self.device)
            labels = labels.to(self.device)
            with autocast(device_type=self.device.type, dtype=self.amp_dtype, enabled=use_amp):
                logits = self.model(images)
                loss = self.criterion(logits, labels)
            # Divide so that the accumulated gradients average over the accumulation window.
            (loss / self.accumulation_steps).backward()
            if (batch_idx + 1) % self.accumulation_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
            batch_size = images.size(0)
            running_loss += loss.item() * batch_size
            total_samples += batch_size
            num_batches = batch_idx + 1
        # Flush gradients left over when the last batch did not complete a window.
        if num_batches % self.accumulation_steps != 0:
            self.optimizer.step()
            self.optimizer.zero_grad()
        if total_samples == 0:
            raise ValueError("train_loader yielded no samples")
        return running_loss / total_samples

    def validate_epoch(self):
        """Evaluate on ``val_loader`` without gradients (full precision).

        Returns:
            ``(val_loss, per_label_AP, val_mAP)``: mean per-sample loss, a list with one
            average precision per label column, and the mean of that list.

        Raises:
            ValueError: if ``val_loader`` yields no samples.
        """
        self.model.eval()
        running_loss = 0.0
        total_samples = 0
        all_probs = []
        all_labels = []
        with torch.no_grad():
            for imgs, labels in self.val_loader:
                imgs = imgs.to(self.device)
                labels = labels.to(self.device)
                logits = self.model(imgs)
                loss = self.criterion(logits, labels)
                batch_size = imgs.size(0)
                running_loss += loss.item() * batch_size
                total_samples += batch_size
                all_probs.append(torch.sigmoid(logits).cpu().numpy())
                all_labels.append(labels.cpu().numpy())
        if total_samples == 0:
            raise ValueError("val_loader yielded no samples")
        val_loss = running_loss / total_samples
        all_probs = np.vstack(all_probs)
        all_labels = np.vstack(all_labels)
        per_label_AP = [average_precision_score(all_labels[:, i], all_probs[:, i]) for i in range(all_labels.shape[1])]
        val_mAP = float(np.mean(per_label_AP))
        return val_loss, per_label_AP, val_mAP

    def train(self, num_epochs):
        """Train for up to ``num_epochs`` epochs, then restore the lowest-val-loss weights.

        Each epoch runs an optional warm-up LR, one training pass, one validation pass,
        both scheduler steps and metric logging. It checkpoints to ``best_model.pt`` on
        improvement and stops early after ``patience`` epochs without improvement.
        """
        for epoch in range(1, num_epochs + 1):
            # Linear warm-up over epochs 1..warmup_epochs: base_lr * 1/W, 2/W, ..., W/W.
            if epoch <= self.warmup_epochs:
                warmup_lr = self.base_lr * epoch / self.warmup_epochs
                for pg in self.optimizer.param_groups:
                    pg['lr'] = warmup_lr
            start = time.time()
            train_loss = self.train_epoch()
            val_loss, val_per_label_AP, val_mAP = self.validate_epoch()
            total_time = time.time() - start
            mins = int(total_time // 60)
            secs = int(total_time % 60)
            # Scheduler steps
            # NOTE(review): CosineAnnealingWarmRestarts (used in the main cell) recomputes the LR
            # from its own base LRs on every step(), so a cut made by the plateau scheduler
            # only lasts until the next cosine step. Confirm this interplay is intended.
            self.scheduler_cos.step()
            self.scheduler_plateau.step(val_loss)
            # Print epoch summary
            print(f"\nEpoch {epoch}: Train Loss={train_loss:.4f} | Val Loss={val_loss:.4f} | Val mAP={val_mAP:.4f} ({mins} min {secs} sec)")
            # Record before the early-stopping check so the final epoch is not dropped.
            self.monitor.report_epoch(train_loss, val_loss, val_mAP)
            # Early stopping & per-class AP logging
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.epochs_no_improve = 0
                # state_dict() tensors share storage with the live parameters, so without a
                # copy later optimizer steps would overwrite the "best" snapshot. Copy to CPU
                # to avoid holding a second model on the GPU; the checkpoint file therefore
                # holds CPU tensors.
                self.best_state = {k: v.detach().to("cpu", copy=True) for k, v in self.model.state_dict().items()}
                torch.save(self.best_state, "best_model.pt")
                print(f"New best_model.pt saved at epoch {epoch} with val loss: {val_loss:.4f}")
                # Print a little table of per-class AP
                print("   Validation per-class AP:")
                label_names = self.val_loader.dataset.label_columns
                for name, AP in zip(label_names, val_per_label_AP):
                    print(f"     {name:<15s} {AP:.4f}")
                print(f"   Validation mean AP: {val_mAP:.4f}")
            else:
                self.epochs_no_improve += 1
                if self.epochs_no_improve >= self.patience:
                    print("Early stopping triggered.")
                    break
        if self.best_state is not None:
            self.model.load_state_dict(self.best_state)  # restore best weights (copied across devices)
        self.monitor.finish()


## 5. Main Training & Evaluation

In [None]:
# Load labels to df
df = load_csv_to_df('miml_dataset/miml_labels_1.csv','miml_dataset/images')
if DEBUG_MODE:
    # Quick smoke-test subset; min() avoids a ValueError when the CSV has fewer than 200 rows.
    df = df.sample(n=min(200, len(df)), random_state=RANDOM_SEED).reset_index(drop=True)
# Every column not listed here is treated as a label column.
nonlabel_cols = {"external_id", "Filenames", "group_id", "image_path","Problematic", "Extra Notes", "Revisit"}
label_columns = [col for col in df.columns if col not in nonlabel_cols]
df[label_columns] = df[label_columns].fillna(0) # Fill NaN entries with 0
# Split train/val/test partitions (group-aware). Carve off the test share first, then split
# the remainder so that val ends up as VAL_RATIO of the full dataset.
df_train_and_val, df_test = group_stratified_split(df, label_columns=label_columns, group_col="group_id", split_ratio=TEST_RATIO, seed=RANDOM_SEED)
relative_val_ratio = VAL_RATIO / (TRAIN_RATIO + VAL_RATIO)
df_train, df_val = group_stratified_split(df_train_and_val, label_columns=label_columns, group_col="group_id", split_ratio=relative_val_ratio, seed=RANDOM_SEED)
# Save partitions to .csv
df_train.to_csv("train_partition.csv", index=False)
df_val.to_csv("val_partition.csv", index=False)
df_test.to_csv("test_partition.csv", index=False)
print("Partitions saved to .csv files.")
# Load partitions from .csv
df_train = pd.read_csv("train_partition.csv")
df_val   = pd.read_csv("val_partition.csv")
df_test  = pd.read_csv("test_partition.csv")
# DataLoaders
train_dataset = DataPartition(df_train, label_columns, transform=train_transforms)
val_dataset   = DataPartition(df_val,   label_columns, transform=val_transforms)
test_dataset  = DataPartition(df_test,  label_columns, transform=val_transforms)
if USE_GPU and torch.cuda.is_available(): # Set num_workers for GPU
    # os.cpu_count() returns None when the count cannot be determined; fall back to 1 worker.
    optimal_num_workers = min(8, (os.cpu_count() or 2) // 2)
else:
    optimal_num_workers = 0  # CPU: load batches in the main process
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=optimal_num_workers, pin_memory=pin_memory)
val_loader   = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=optimal_num_workers, pin_memory=pin_memory)
test_loader  = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=optimal_num_workers,pin_memory=pin_memory)
print(f"Train samples:      {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples:       {len(test_dataset)}")
print(f"Using num_workers: {optimal_num_workers}")
# Model, optimizer, scheduler
model = SwinTransformerMultiLabel(num_classes=len(label_columns),pretrained=True).to(device)
loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.05, amsgrad=False)
cos_scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=SCHEDULER_T0, T_mult=SCHEDULER_T_MULT, eta_min=MIN_LR)
plateau_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, threshold=1e-4, cooldown=1, min_lr=MIN_LR)
# Train 
monitor = TrainingMonitor()
trainer = Trainer(model=model, 
                    optimizer=optimizer,
                    scheduler_cos=cos_scheduler, 
                    scheduler_plateau=plateau_scheduler,
                    criterion=loss_fn,
                    train_loader=train_loader,
                    val_loader=val_loader,
                    device=device, 
                    monitor=monitor,
                    patience=PATIENCE, 
                    warmup_epochs=3, 
                    amp_dtype=amp_dtype,
                    accumulation_steps=2
    )
trainer.train(N_EPOCHS)
# Save in Classifier wrapper w/ Prediction Threshold Settings
def _collect_val_outputs(model, loader, device):
    """Run *model* over *loader* without gradients; return (sigmoid probs, labels) as stacked numpy arrays."""
    model.eval()
    all_probs = []
    all_labels = []
    with torch.no_grad():
        for images, labels in loader:
            logits = model(images.to(device))
            all_probs.append(torch.sigmoid(logits).cpu().numpy())
            all_labels.append(labels.numpy())
    return np.vstack(all_probs), np.vstack(all_labels)


def _f1_per_threshold(probs, targets, taus):
    """Binary F1 of ``probs >= tau`` against ``targets == 1`` for each tau (0 where undefined)."""
    preds = probs[None, :] >= taus[:, None]  # [n_taus, n_samples]
    positives = targets == 1
    tp = (preds & positives).sum(axis=1)
    # F1 = 2TP / (2TP + FP + FN) = 2TP / (#predicted positive + #actual positive);
    # a zero denominator gives 0, like f1_score(..., zero_division=0).
    denom = preds.sum(axis=1) + positives.sum()
    return np.divide(2.0 * tp, denom, out=np.zeros(len(taus)), where=denom > 0)


def find_optimal_thresholds(model, val_loader, device, num_classes, n_steps=101):
    """For each label, pick the threshold on an ``n_steps`` grid over [0, 1] that maximises validation F1.

    Ties go to the smallest threshold. A label whose best F1 is 0 (e.g. no positives in
    the validation set) keeps the default 0.5.

    Returns:
        numpy float array of shape [num_classes].
    """
    all_probs, all_labels = _collect_val_outputs(model, val_loader, device)
    taus = np.linspace(0, 1, n_steps)
    thresholds = np.full(num_classes, 0.5, dtype=float)
    for k in range(num_classes):
        f1s = _f1_per_threshold(all_probs[:, k], all_labels[:, k], taus)
        best = int(np.argmax(f1s))  # first index of the maximum -> smallest tau among ties
        if f1s[best] > 0.0:
            thresholds[k] = taus[best]
    return thresholds
# Per-label decision thresholds: tuned on the validation set when THRESHOLD_MODE is
# 'per_label'; any other mode uses GLOBAL_THRESHOLD for every label.
if THRESHOLD_MODE == 'per_label': # Pick Thresholds on Validation
    thresholds = find_optimal_thresholds(model, val_loader, device, num_classes=len(label_columns), n_steps=101)
    print("\nOptimal per-class thresholds:", thresholds)
else: # Single Global Threshold
    thresholds = np.full(len(label_columns), GLOBAL_THRESHOLD, dtype=float)
print("Using thresholds:", thresholds)
# Bundle the trained model with the eval transform, label names and thresholds, then write
# best_classifier.pt (weights) and best_classifier.json (thresholds + labels).
classifier = Classifier(model, val_transforms, device, labels=label_columns, thresholds=thresholds)
classifier.save('best_classifier')
# Test set evaluation: collect sigmoid probabilities and ground-truth labels for every test batch.
print("\nTest Set performance:")
model.eval()
all_probs = []
all_labels = []
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        logits = model(images)
        probabilities = torch.sigmoid(logits).cpu().numpy()
        all_probs.append(probabilities)
        all_labels.append(labels.numpy())
all_probs  = np.vstack(all_probs)
all_labels = np.vstack(all_labels)
# Apply each label's threshold column-wise (broadcast over rows) to get 0/1 predictions.
binary_predictions = (all_probs >= thresholds).astype(int)
# Classification Report on Test Set: per-label precision, recall, F1 and support
# (number of occurrences of the label in the test ground truth).
precisions, recalls, f1s, supports = precision_recall_fscore_support(all_labels, binary_predictions, zero_division=0)
for idx, label in enumerate(label_columns):
    precision = precisions[idx]
    recall = recalls[idx]
    f1 = f1s[idx]
    num_occurrences = supports[idx]
    print(f"{label:<15s} Precision={precision:.4f}, Recall={recall:.4f}, F1={f1:.4f}, num_occurences={num_occurrences}")


Partitions saved to .csv files.
Train samples:      162
Validation samples: 21
Test samples:       17
Using num_workers: 0


