# Training Pipeline Notebook

This notebook bundles configuration, dataset utilities, model definitions, training loop, and evaluation into a single self‑contained workflow.

## 1. Configuration & Hyperparameters

In [None]:
import os
import sys
import random
import json
import time
import glob
import numpy as np
import pandas as pd            
from pandas import json_normalize                  
import torch
import torch.nn as nn
from torch.utils.data import DataLoader 
from torch.amp import autocast
from torch.optim import AdamW, lr_scheduler
from torchvision import transforms
import timm
from PIL import Image
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from sklearn.metrics import average_precision_score, precision_recall_fscore_support, f1_score
from datapartition import DataPartition
from torch.nn.modules.transformer import _get_activation_fn
from torch import Tensor
import warnings # Suppress warnings that currently do not affect execution
warnings.filterwarnings("ignore", message="torch.meshgrid: in an upcoming release")
warnings.filterwarnings("ignore", message="Cannot set number of intraop threads after parallel work has started or after set_num_threads call")

# Hyperparameters
DEBUG_MODE = True  # When True, the main cell trains on a random sample of 200 rows
USE_GPU = False
MODEL_NAME = "swinv2_base_window12to24_192to384"
IMG_WIDTH = 384  # Images are resized to IMG_WIDTH x IMG_WIDTH
N_EPOCHS = 2
BATCH_SIZE = 2
LEARNING_RATE = 1e-5
PATIENCE = 8  # Early-stopping patience (epochs without validation-loss improvement)
DROPOUT_RATE = 0.5
SCHEDULER_T0 = 6
SCHEDULER_T_MULT = 1
MIN_LR = 1e-6
TRAIN_RATIO = 0.8
VAL_RATIO = 0.1
TEST_RATIO = 0.1
RANDOM_SEED = 42
THRESHOLD_MODE = 'per_label'  # choices: 'per_label', 'global'
GLOBAL_THRESHOLD = 0.5

# Reproducibility
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Device setup
if USE_GPU and torch.cuda.is_available():
    device = torch.device("cuda")
    torch.cuda.manual_seed_all(RANDOM_SEED)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may select different kernels from run to run, which
    # defeats the deterministic flag above, so it is switched off.
    torch.backends.cudnn.benchmark = False
    # os.cpu_count() can return None when the count is undetermined.
    optimal_num_workers = min(8, (os.cpu_count() or 1) // 2)
    pin_memory = True
    amp_dtype = torch.bfloat16
else:
    device = torch.device("cpu")
    optimal_num_workers = 0
    pin_memory = False
    amp_dtype = torch.float32
print(f"Using device: {device}")

## 2. Dataset Utilities

### For NDJSON

In [None]:
NDJSON_PATH = 'demo_360.ndjson'
IMG_DIR = "miml_dataset/images"

# First pass over the export: find the record that carries the most
# classifications and remember their names in order of appearance. That list
# ("deepest_order") defines the label columns used by the following cells.
max_count = 0
deepest_order = []  # stays an empty list (not None) when no record has classifications
with open(NDJSON_PATH, "r", encoding="utf-8") as file:  # JSON text is UTF-8; do not rely on the platform default
    for line in file:
        if not line.strip():
            continue  # tolerate blank lines (e.g. a trailing newline at end of file)
        record = json.loads(line)
        found_columns = []
        for project in record.get("projects", {}).values():
            for label in project.get("labels", []):
                classifications = label.get("annotations", {}).get("classifications", [])
                found_columns.extend(cls["name"] for cls in classifications)
        if len(found_columns) > max_count:
            max_count = len(found_columns)
            deepest_order = found_columns
print("Deepest order of classifications/labels:", deepest_order)

In [None]:
# Second pass: stream the NDJSON again and build one flat dict per image.
label_order = list(deepest_order or [])  # tolerate an empty/None result from the first pass
known_labels = set(label_order)          # O(1) membership test inside the inner loop
records = []
distance_columns = set()
with open(NDJSON_PATH, "r", encoding="utf-8") as file:
    for line in file:
        if not line.strip():
            continue  # skip blank lines instead of crashing in json.loads
        record = json.loads(line)
        ext_id = record["data_row"]["external_id"]
        # initialize the row with external_id + all label cols set to NaN
        row = {"external_id": ext_id, **dict.fromkeys(label_order, np.nan)}
        # fill in each classification
        for project in record.get("projects", {}).values():
            for label in project.get("labels", []):
                # .get() mirrors the first pass, which tolerates labels without "annotations"
                for cls in label.get("annotations", {}).get("classifications", []):
                    name = cls.get("name")
                    if name not in known_labels:
                        continue
                    # Case: Free-text fields (Extra Notes)
                    if cls.get("text_answer"):
                        row[name] = cls["text_answer"].get("content")
                    # Case: Checklist fields (confidence score and distance)
                    elif cls.get("checklist_answers"):
                        answers = [a.get("name") for a in cls["checklist_answers"]]
                        # the answer starting with a digit is the confidence score
                        conf = next((a for a in answers if a and a[0].isdigit()), None)
                        # the first answer NOT starting with a digit is the distance
                        dist = next((a for a in answers if not (a and a[0].isdigit())), None)

                        row[name] = conf
                        # create the distance column on the fly
                        distance_col = str(name) + "_distance"
                        distance_columns.add(distance_col)
                        row[distance_col] = dist
                    # Case: Radio fields (Lily, Problematic, Revisit)
                    elif cls.get("radio_answer"):
                        row[name] = cls["radio_answer"].get("name")
                    # Fallback: raw value string
                    else:
                        row[name] = cls.get("value")
        records.append(row)
df_all = pd.DataFrame(records)  # pandas unions in any distance columns that only some rows carry
nonlabel_columns = {"external_id", "Problematic", "Extra Notes", "Revisit"}
nonlabel_columns.update(distance_columns)
label_columns = [col for col in df_all.columns if col not in nonlabel_columns]
df_all.head()

In [None]:
# Arrange the columns in LabelBox order (deepest_order): each label is followed
# directly by its "<label>_distance" companion column when that column exists.
ordered_columns = ["external_id"]
for label in deepest_order:
    distance_col = f"{label}_distance"
    has_distance = distance_col in df_all.columns
    ordered_columns.extend([label, distance_col] if has_distance else [label])
df_all = df_all[ordered_columns]
df_all.head()

# ordered = (
#     ["external_id"] +
#     label_columns +
#     distance_cols 
# )
# df_all = df_all[ordered]
# df_all.head()

In [None]:
def create_img_path_mapping(root_img_dir):
    """
    Recursively scans the root image directory for JPEG images in any "split_jpg" folder
    and creates a mapping from each image's basename to its full local file path.

    The pattern uses "**" so that ``recursive=True`` actually takes effect and
    "split_jpg" folders are found at any depth below ``root_img_dir`` (including
    the two-level layout root/*/*/split_jpg).

    Args:
        root_img_dir: Directory to scan.

    Returns:
        dict: Mapping { img_filename : img_path }. When several files share a
        basename, the path that sorts last wins.
    """
    glob_pattern = os.path.join(root_img_dir, "**", "split_jpg", "*.jpg")
    # glob order depends on the filesystem; sorting makes the result reproducible
    img_paths = sorted(glob.glob(glob_pattern, recursive=True))
    return {os.path.basename(img_path): img_path for img_path in img_paths}

def get_base_filename(filename):
    """
    Return the group key for an image filename.

    A trailing "_left.jpg" or "_right.jpg" is removed as a whole; any other
    filename simply loses its extension.
    """
    matched_suffix = next(
        (candidate for candidate in ("_left.jpg", "_right.jpg") if filename.endswith(candidate)),
        None,
    )
    if matched_suffix is None:
        stem, _extension = os.path.splitext(filename)  # fallback: drop the extension only
        return stem
    return filename[:len(filename) - len(matched_suffix)]

path_map = create_img_path_mapping(IMG_DIR)  # filename -> full path on disk
group_id_series = df_all["external_id"].apply(get_base_filename)  # left/right views share one group_id
img_path_series = df_all["external_id"].map(path_map)  # NaN where no file matches the external_id
df_all.insert(0, "group_id", group_id_series)
df_all.insert(1, "image_path", img_path_series)
# Drop the bookkeeping columns. errors="ignore" keeps the cell working when an
# export lacks one of the hard-coded names (e.g. "Revisit"); the set is passed
# as a list because drop() expects a list-like selector.
df_all = df_all.drop(columns=list(nonlabel_columns), errors="ignore")
df_all.head()

In [None]:
# Work on a separate frame in which every missing entry is replaced by 0;
# fillna() returns a new DataFrame, so df_all itself stays untouched.
df_binary = df_all.fillna(value=0)
df_binary.head()

In [None]:
POSITIVE_THRESHOLD = 3  # scores 1..POSITIVE_THRESHOLD count as positive


def to_binary(entry):
    """
    Map an annotation cell to a 0/1 label.

    The first whitespace-separated token of ``str(entry)`` is read as an
    integer score; the result is 1 when 0 < score <= POSITIVE_THRESHOLD and 0
    otherwise. Integral floats such as 0.0 (which appear when an all-NaN
    column is filled with 0) are accepted as well.

    Raises:
        ValueError: if the leading token is not an integral number.
    """
    token = str(entry).split()[0]
    try:
        num = int(token)
    except ValueError:
        as_float = float(token)  # still raises ValueError for non-numeric text
        if not as_float.is_integer():
            raise
        num = int(as_float)
    return int(0 < num <= POSITIVE_THRESHOLD)

# Convert every label column to 0/1, one element at a time, via to_binary
for col in label_columns:
    binarized = df_binary[col].map(to_binary)
    df_binary[col] = binarized
df_binary.head()

### For CSV

In [None]:
# CSV Parsers and Helpers
def load_csv_to_df(filepath, img_dir):
    """
    Read the label CSV and add two derived columns.

    "group_id" is each entry of the "Filenames" column without its extension;
    "image_path" is the normalised join of ``img_dir`` and that filename.
    Returns the augmented DataFrame.
    """
    def _strip_extension(filename):
        return os.path.splitext(filename)[0]

    def _full_path(filename):
        return os.path.normpath(os.path.join(img_dir, filename))

    df = pd.read_csv(filepath)
    filenames = df["Filenames"].tolist()
    df["group_id"] = list(map(_strip_extension, filenames))
    df["image_path"] = list(map(_full_path, filenames))
    return df

In [None]:
# DataPartition MUST live in a separate importable module (this notebook imports it
# at the top via `from datapartition import DataPartition`)
# in order to run the notebook version with num_workers > 0 for GPU training.
# The string literal below only keeps a reference copy of the class for readers;
# it is a bare expression and is never executed as code.
'''
class DataPartition(Dataset):
    def __init__(self, df, label_columns, transform=None):
        self.label_columns = label_columns
        self.transform = transform
        self.img_paths = df["image_path"].tolist() # List of image paths
        self.labels = df[label_columns].to_numpy(dtype=np.float32) # 2-D Array of of shape (N_samples, N_labels)

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path  = self.img_paths[idx]
        img = Image.open(img_path).convert("RGB") # Retrieve image
        if self.transform:                        # Apply transformations to image
            img = self.transform(img)
        label_vector = torch.from_numpy(self.labels[idx]) # Retrieve label vector for the given sample
        return img, label_vector
'''
        
# Data Augmentation (Transforms)
# Both pipelines resize to a square of side IMG_WIDTH and finish with the same
# tensor conversion + per-channel normalisation; only training adds random
# augmentation in between.
_target_size = (IMG_WIDTH, IMG_WIDTH)
_channel_mean = [0.485,0.456,0.406]
_channel_std = [0.229,0.224,0.225]

_small_affine = transforms.RandomAffine(
    degrees=10,              # rotate within [-10, 10] degrees
    translate=(0.05, 0.05),  # shift by up to 5% of the image dimensions
    scale=(0.95, 1.05),      # slight zoom in or out
)

train_transforms = transforms.Compose([
    transforms.Resize(_target_size),
    transforms.RandomHorizontalFlip(p=0.5),                    # flip half of the images
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomApply([transforms.GaussianBlur(3)], p=0.2),  # blur 20% of the images
    transforms.RandomApply([_small_affine], p=0.5),            # affine jitter on half of the images
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
])

val_transforms = transforms.Compose([
    transforms.Resize(_target_size),
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
])

def group_stratified_split(df, label_columns, group_col, split_ratio, seed):
    """
    Split ``df`` into two parts so that all rows sharing a ``group_col`` value
    end up on the same side, while stratifying on the multi-label targets.

    Args:
        df: DataFrame holding ``group_col`` and all ``label_columns``.
        label_columns: List of label column names used for stratification.
        group_col: Column whose values identify a group.
        split_ratio: Fraction of groups assigned to the second split
            (passed as ``test_size`` to the splitter).
        seed: Random state for the splitter.

    Returns:
        (df_split_1, df_split_2): the two row subsets, each with a fresh index.
    """
    # One aggregated label vector per group in a single groupby pass. max() per
    # column acts as a logical OR for 0/1 labels; sort=False keeps groups in
    # order of first appearance.
    # NOTE(review): groupby drops rows whose group key is NaN — assumes group_col has no missing values.
    group_labels = df.groupby(group_col, sort=False)[label_columns].max()
    unique_groups_array = group_labels.index.to_numpy()
    aggregated_labels_array = group_labels.to_numpy()  # shape (n_groups, n_labels)
    splitter = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=split_ratio, random_state=seed)
    # n_splits=1, so take the single (first, second) index pair directly.
    first_split_idx, second_split_idx = next(
        splitter.split(unique_groups_array.reshape(-1, 1), aggregated_labels_array)
    )
    first_groups = unique_groups_array[first_split_idx]
    second_groups = unique_groups_array[second_split_idx]
    # Select the rows that belong to each side's groups.
    df_split_1 = df[df[group_col].isin(first_groups)].reset_index(drop=True)
    df_split_2 = df[df[group_col].isin(second_groups)].reset_index(drop=True)
    return df_split_1, df_split_2

## 3. Model Architecture

In [None]:
# 3. Model -- ML-Decoder head (official style)
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.modules.transformer import _get_activation_fn

def add_ml_decoder_head(model, num_classes: int = -1, num_of_groups: int = -1,
                        decoder_embedding: int = 768, initial_num_features: int = None):
    """
    Swap the classifier of ``model`` for an MLDecoder head, in place.

    ``num_classes == -1`` falls back to ``model.num_classes``; the input feature
    width is ``model.num_features`` unless ``initial_num_features`` is given.
    An existing ``global_pool`` attribute is replaced by ``nn.Identity()``.
    The head is installed on ``fc`` if present, otherwise on ``head``.

    Returns:
        The same ``model`` object.

    Raises:
        RuntimeError: if the model has neither an ``fc`` nor a ``head`` attribute.
    """
    if num_classes == -1:
        num_classes = model.num_classes
    if initial_num_features is None:
        in_features = model.num_features
    else:
        in_features = initial_num_features

    # remove existing pooling
    if hasattr(model, 'global_pool'):
        model.global_pool = nn.Identity()

    # replace the first classifier attribute that exists
    for classifier_attr in ('fc', 'head'):
        if not hasattr(model, classifier_attr):
            continue
        delattr(model, classifier_attr)
        setattr(model, classifier_attr,
                MLDecoder(num_classes, num_of_groups, decoder_embedding, in_features))
        return model
    raise RuntimeError("Model not suited for ML-Decoder")

class TransformerDecoderLayerOptimal(nn.Module):
    """
    Decoder layer in the style of the official ML-Decoder: only cross-attention + FFN
    (there is no self-attention step over the queries).

    The cross-attention module is additionally registered under the name
    ``self_attn``; per the original author's note this is so that PyTorch's
    TransformerDecoder will accept the layer.
    NOTE(review): confirm this alias is still required by the installed torch version.

    Args:
        d_model: Embedding width of queries and memory.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward block.
        dropout: Dropout probability used by every dropout module and the attention.
        activation: Name of the feed-forward activation, resolved by ``_get_activation_fn``.
        layer_norm_eps: Epsilon for all three LayerNorm modules.
    """
    def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1,
                 activation="relu", layer_norm_eps=1e-5):
        super().__init__()
        # norm + dropout applied to the queries before attention
        self.norm1    = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        # cross-attention (created with the default sequence-first layout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # alias: the same module object under a second attribute name
        self.self_attn = self.multihead_attn
        self.self_attn.batch_first = False
        # post-attn norm
        self.norm2    = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout2 = nn.Dropout(dropout)
        # feed-forward
        self.linear1   = nn.Linear(d_model, dim_feedforward)
        self.activation= _get_activation_fn(activation)
        self.dropout_fc= nn.Dropout(dropout)
        self.linear2   = nn.Linear(dim_feedforward, d_model)
        self.norm3    = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, tgt: Tensor, memory: Tensor,
                tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None,
                **kwargs) -> Tensor:
        """
        Attend from the queries ``tgt`` to ``memory`` and apply the feed-forward block.

        ``tgt_mask``, ``tgt_key_padding_mask`` and any extra keyword arguments
        are accepted but not used; only the memory masks reach the attention call.
        Returns a tensor produced by the final LayerNorm over the query stream.
        """
        # 1) add the queries to their own dropout-ed copy, then normalise
        #    (note: the norm comes after the sum, i.e. this is not a pre-norm block)
        t = self.norm1(tgt + self.dropout1(tgt))
        # 2) cross-attention: queries attend to memory (memory is both key and value)
        attn_out, _ = self.multihead_attn(
            t, memory, memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask
        )
        t = self.norm2(t + self.dropout2(attn_out))
        # 3) feed-forward + residual
        ff = self.linear2(self.dropout_fc(self.activation(self.linear1(t))))
        t = self.norm3(t + self.dropout3(ff))
        return t

class GroupFC(nn.Module):
    """
    Grouped fully-connected projection.

    h: [B, K, D], duplicate_pooling: [K, D, g] -> result written into
    out_extrap: [B, K, g]. Query i is multiplied by its own weight matrix
    duplicate_pooling[i].
    """
    def forward(self, h, duplicate_pooling, out_extrap: Tensor) -> Tensor:
        num_queries = h.shape[1]
        for query_idx in range(num_queries):
            query_features = h[:, query_idx]               # [B, D]
            query_weights = duplicate_pooling[query_idx]   # [D, g]
            # written in place into the caller-supplied buffer
            out_extrap[:, query_idx] = torch.matmul(query_features, query_weights)
        return out_extrap

class MLDecoder(nn.Module):
    """
    ML-Decoder head (official):
      1) embed_proj: C→D
      2) fixed query embeddings
      3) 1-layer TransformerDecoder (cross-attn+FFN)
      4) GroupFC → flatten → bias

    Args:
        num_classes: Number of output logits.
        num_of_groups: Number of queries K; a negative value means one query
            per class. K is capped at ``num_classes``.
        decoder_embedding: Decoder width D.
        initial_num_features: Channel width C of the incoming backbone features.
    """
    def __init__(self, num_classes, num_of_groups=-1,
                 decoder_embedding=768, initial_num_features=1024):
        super().__init__()
        # determine number of queries K
        K = num_classes if num_of_groups<0 else num_of_groups
        K = min(K, num_classes)
        D = decoder_embedding

        # 1) projection from backbone features → D
        self.feature_dim = initial_num_features  # also used in forward() to detect channels-last input
        self.embed_proj = nn.Linear(initial_num_features, D)

        # 2) fixed (non-learnable) queries: gradients are switched off for the embedding table
        self.query_embed = nn.Embedding(K, D)
        self.query_embed.weight.requires_grad_(False)

        # 3) one-layer TransformerDecoder
        layer = TransformerDecoderLayerOptimal(d_model=D, nhead=8, dim_feedforward=2048, dropout=0.1)
        self.decoder = nn.TransformerDecoder(layer, num_layers=1)

        # 4) grouped-FC params
        # each query emits ceil(num_classes / K) logits
        self.duplicate_factor = (num_classes + K - 1)//K
        self.duplicate_pooling = nn.Parameter(torch.randn(K, D, self.duplicate_factor))
        self.bias = nn.Parameter(torch.zeros(num_classes))
        nn.init.xavier_normal_(self.duplicate_pooling)
        self.group_fc = GroupFC()

    def forward(self, x: Tensor) -> Tensor:
        """
        Turn backbone features into per-class logits.

        Accepts a 4-D feature map (channels-last or channels-first) or a 3-D
        token sequence [B, S, C]; raises ValueError for any other rank.
        Returns a tensor with one logit per class for every batch element.
        """
        # handle SwinV2 NHWC or NCHW → flatten to [B, S, C]
        if x.ndim == 4:
            # detect NHWC by last dim
            # NOTE(review): ambiguous when the spatial width happens to equal feature_dim
            if x.size(-1) == self.feature_dim:
                x = x.permute(0, 3, 1, 2)       # NHWC → NCHW
            x = x.flatten(2).transpose(1, 2)   # [B, C, H, W] → [B, S, C]
        elif x.ndim == 3:
            # already [B, S, C]
            pass
        else:
            raise ValueError(f"Unexpected input shape: {x.shape}")

        # 1) project → [B, S, D]
        feat = self.embed_proj(x).relu()
        B = feat.size(0)

        # 2) prepare queries → [K, B, D] (the same query table is shared by every batch element)
        q = self.query_embed.weight
        tgt = q.unsqueeze(1).expand(-1, B, -1)

        # 3) cross-decode
        mem = feat.transpose(0, 1)           # [S, B, D]
        h = self.decoder(tgt, mem)           # [K, B, D]
        h = h.transpose(0, 1)                # [B, K, D]

        # 4) grouped FC + bias → logits
        out = feat.new_zeros(B, h.size(1), self.duplicate_factor)
        self.group_fc(h, self.duplicate_pooling, out)  # fills `out` in place
        # K * duplicate_factor may exceed num_classes; keep only the first num_classes logits
        logits = out.flatten(1)[:, :self.bias.numel()] + self.bias
        return logits


## 4. Training Monitor & Trainer

In [None]:
class TrainingMonitor:
    """Collects per-epoch metrics and measures wall-clock time since construction."""

    def __init__(self):
        self.start = time.time()
        self.train_losses, self.val_losses, self.val_mAPs = [], [], []

    def report_epoch(self, train_loss, val_loss, val_map):
        """Append one epoch's train loss, validation loss and validation mAP."""
        for history, value in (
            (self.train_losses, train_loss),
            (self.val_losses, val_loss),
            (self.val_mAPs, val_map),
        ):
            history.append(value)

    def finish(self):
        """Print the elapsed time as minutes/seconds and return it in seconds."""
        total_time = time.time() - self.start
        whole_minutes, leftover_seconds = divmod(total_time, 60)
        mins, secs = int(whole_minutes), int(leftover_seconds)
        print(f"Total Training Time: {mins} min {secs} sec")
        return total_time

class Trainer:
    def __init__(self, model, optimizer, scheduler_cos, scheduler_plateau, criterion, train_loader, val_loader, device, monitor, patience, warmup_epochs, amp_dtype, accumulation_steps):
        self.model = model
        self.optimizer = optimizer
        self.scheduler_cos = scheduler_cos
        self.scheduler_plateau = scheduler_plateau
        self.criterion = criterion
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.monitor = monitor
        self.patience = patience
        self.warmup_epochs = warmup_epochs
        self.amp_dtype = amp_dtype
        self.accumulation_steps = accumulation_steps
        self.best_val_loss = float('inf')
        self.epochs_no_improve = 0
        self.base_lr=optimizer.param_groups[0]['lr'] # store the base LR for warm‑up calculations
        self.best_state = None

    def train_epoch(self):
        self.model.train()
        running_loss = 0.0
        total_samples = 0
        self.optimizer.zero_grad()
        for batch_idx, (images, labels) in enumerate(self.train_loader):
            images = images.to(self.device)
            labels = labels.to(self.device)
            with autocast(device_type=self.device.type, dtype=self.amp_dtype): # GPU: forward + loss w/ BF16 Automatic Mixed Precision. Default: FP32 precision
                logits = self.model(images)
                loss = self.criterion(logits, labels)
                loss = loss / self.accumulation_steps
            loss.backward()# backward pass
            if (batch_idx + 1) % self.accumulation_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
            batch_size = images.size(0)  
            running_loss += loss.item() * batch_size * self.accumulation_steps
            total_samples += batch_size
        if (batch_idx + 1) % self.accumulation_steps != 0: # flush gradients if the last batch didn’t trigger a step
            self.optimizer.step()
            self.optimizer.zero_grad()
        epoch_loss = running_loss / total_samples
        return epoch_loss
    
    def validate_epoch(self):
        self.model.eval()
        running_loss = 0.0
        total_samples = 0
        all_probs = []
        all_labels = []
        with torch.no_grad():
            for imgs, labels in self.val_loader:
                imgs = imgs.to(self.device)
                labels = labels.to(self.device)
                logits = self.model(imgs)
                loss = self.criterion(logits, labels)
                batch_size = imgs.size(0)
                running_loss += loss.item() * batch_size
                total_samples += batch_size
                probabilities = torch.sigmoid(logits)
                all_probs.append(probabilities.cpu().numpy())
                all_labels.append(labels.cpu().numpy())
        val_loss = running_loss / total_samples
        all_probs = np.vstack(all_probs)
        all_labels = np.vstack(all_labels)
        per_label_AP = [average_precision_score(all_labels[:, i], all_probs[:, i]) for i in range(all_labels.shape[1])]
        val_mAP = float(np.mean(per_label_AP))
        return val_loss, per_label_AP, val_mAP

    def train(self, num_epochs):
        for epoch in range(1, num_epochs + 1):
            # warm‑up LR for first few epochs 
            if epoch < self.warmup_epochs:
                warmup_lr = self.base_lr * (epoch + 1) / self.warmup_epochs
                for pg in self.optimizer.param_groups:
                    pg['lr'] = warmup_lr
            start = time.time()
            train_loss = self.train_epoch()
            val_loss, val_per_label_AP, val_mAP = self.validate_epoch()
            total_time = time.time() - start
            mins = int(total_time // 60)
            secs = int(total_time % 60)
            # Scheduler steps 
            self.scheduler_cos.step()
            self.scheduler_plateau.step(val_loss) 
            # Print epoch summary
            print(f"\nEpoch {epoch}: Train Loss={train_loss:.4f} | Val Loss={val_loss:.4f} | Val mAP={val_mAP:.4f} ({mins} min {secs} sec)")
            # Early stopping & per‐class AP logging
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.epochs_no_improve = 0
                self.best_state = self.model.state_dict()
                torch.save(self.best_state, "best_model.pt")
                print(f"New best_model.pt saved at epoch {epoch} with val loss: {val_loss:.4f}")
                # Print a little table of per‐class AP
                print("   Validation per-class AP:")
                label_names = self.val_loader.dataset.label_columns
                for name, AP in zip(label_names, val_per_label_AP):
                    print(f"     {name:<15s} {AP:.4f}")
                print(f"   Validation mean AP: {val_mAP:.4f}")
            else:
                self.epochs_no_improve += 1
                if self.epochs_no_improve >= self.patience:
                    print("Early stopping triggered.")
                    break
            # Record in monitor
            self.monitor.report_epoch(train_loss, val_loss, val_mAP)
        if self.best_state is not None:
            self.model.load_state_dict(self.best_state) # Load best weights
        self.monitor.finish() 

class TeeOutput:
    """
    Minimal stdout replacement that duplicates everything written to it into a log file.

    The stream that was ``sys.stdout`` at construction time is kept as
    ``terminal``; the opened log file is ``log``. Can also be used as a context
    manager, which closes the log file on exit (it does not reassign ``sys.stdout``).

    Args:
        filename: Path of the log file.
        mode: File mode passed to ``open`` (default 'w').
    """
    def __init__(self, filename, mode='w'):
        self.terminal = sys.stdout
        # Explicit UTF-8 so non-ASCII output cannot fail on platforms whose
        # default encoding is narrower.
        self.log = open(filename, mode, encoding="utf-8")

    def write(self, message):
        """Write ``message`` to both targets and return the number of characters written."""
        self.terminal.write(message)
        self.log.write(message)
        return len(message)

    def flush(self):
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Close the log file only; the terminal stream is left open."""
        self.log.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # never swallow exceptions

def log_hyperparameters():
    """Print every module-level hyperparameter as an aligned "NAME: value" table."""
    settings = (
        ("DEBUG_MODE", DEBUG_MODE),
        ("USE_GPU", USE_GPU),
        ("MODEL_NAME", MODEL_NAME),
        ("IMG_WIDTH", IMG_WIDTH),
        ("N_EPOCHS", N_EPOCHS),
        ("BATCH_SIZE", BATCH_SIZE),
        ("LEARNING_RATE", LEARNING_RATE),
        ("PATIENCE", PATIENCE),
        ("DROPOUT_RATE", DROPOUT_RATE),
        ("SCHEDULER_T0", SCHEDULER_T0),
        ("SCHEDULER_T_MULT", SCHEDULER_T_MULT),
        ("MIN_LR", MIN_LR),
        ("TRAIN_RATIO", TRAIN_RATIO),
        ("VAL_RATIO", VAL_RATIO),
        ("TEST_RATIO", TEST_RATIO),
        ("RANDOM_SEED", RANDOM_SEED),
        ("THRESHOLD_MODE", THRESHOLD_MODE),
        ("GLOBAL_THRESHOLD", GLOBAL_THRESHOLD),
    )
    print("\n========== HYPERPARAMETERS ==========")
    for name, value in settings:
        heading = name + ":"
        # values start in column 19, directly after the padded "NAME:" heading
        print(f"{heading:<19}{value}")
    print("=====================================\n")



## 5. Classifier Wrapper & Prediction Thresholding

In [None]:
class Classifier:
    """
    Inference wrapper bundling a model with its preprocessing transform, label
    names and per-label decision thresholds.

    Args:
        model: Network returning one logit per label; moved to ``device`` and put in eval mode.
        transform: Callable applied to the input image; its result must support
            ``.unsqueeze(0)`` (i.e. return a tensor).
        device: torch.device used for inference.
        labels: Sequence of label names.
        thresholds: One probability threshold per label (array-like).
    """
    def __init__(self, model, transform, device, labels, thresholds):
        self.model = model.to(device).eval()
        self.transform = transform
        self.device = device
        self.labels = labels
        self.thresholds = thresholds

    def predict_probability(self, img):
        """Return the per-label sigmoid probabilities for a single image as a 1-D NumPy array."""
        tensor = self.transform(img).unsqueeze(0).to(self.device)  # add the batch dimension
        with torch.no_grad():
            logits = self.model(tensor)
            return torch.sigmoid(logits).cpu().numpy()[0]

    def predict_binary(self, img):
        """Return 0/1 predictions: 1 where the probability reaches the label's threshold."""
        probabilities = self.predict_probability(img)
        return (probabilities >= self.thresholds).astype(int)

    def save(self, base_filename):
        """Write the weights to ``<base_filename>.pt`` and labels + thresholds to ``<base_filename>.json``."""
        torch.save(self.model.state_dict(), base_filename + ".pt")  # Save model weights
        config = {  # JSON-safe metadata; works for list and ndarray inputs alike
            "thresholds": np.asarray(self.thresholds, dtype=float).tolist(),
            "labels": list(self.labels),
        }
        with open(base_filename + ".json", "w") as file:
            json.dump(config, file, indent=2)
        print(f"Model Weights saved as {base_filename}.pt | Classifier Metadata saved as {base_filename}.json")

    @staticmethod
    def load(config_filename, transform, device):
        """
        Rebuild a Classifier from a JSON config written by ``save``.

        The weights are read from the file that shares the config's name but
        has the extension ".pt". The architecture is rebuilt from the
        module-level MODEL_NAME with a hard-coded head configuration, which
        must match the one used at training time.
        """
        with open(config_filename, "r") as file:  # context manager closes the handle
            config = json.load(file)
        labels = config["labels"]
        thresholds = np.array(config["thresholds"], dtype=float)
        # Swap only the extension; str.replace would also rewrite ".json" elsewhere in the path.
        weights_file = os.path.splitext(config_filename)[0] + ".pt"
        # rebuild and load model
        backbone = timm.create_model(
            MODEL_NAME,
            pretrained=False,
            num_classes=0,
            global_pool=""
        )
        model = add_ml_decoder_head(
            backbone,
            num_classes=len(labels),
            num_of_groups=16,       # UPDATE TO MATCH whatever used at training
            decoder_embedding=768    # UPDATE TO MATCH whatever used at training
        ).to(device)
        state = torch.load(weights_file, map_location=device)
        model.load_state_dict(state)
        model.to(device).eval()
        return Classifier(model=model, transform=transform, device=device, labels=labels, thresholds=thresholds)

def find_optimal_thresholds(model, val_loader, device, num_classes, n_steps=101):
    """
    Pick, for every class, the probability threshold that maximises F1 on ``val_loader``.

    Candidates are ``n_steps`` evenly spaced values in [0, 1]. The first
    candidate reaching the highest F1 wins; a class for which no candidate
    yields F1 > 0 keeps the default threshold 0.5.

    Returns:
        np.ndarray of shape (num_classes,) with one threshold per class.
    """
    model.eval()
    prob_batches, label_batches = [], []
    with torch.no_grad():
        for images, labels in val_loader:
            batch_logits = model(images.to(device))
            prob_batches.append(torch.sigmoid(batch_logits).cpu().numpy())
            label_batches.append(labels.numpy())
    all_probs = np.vstack(prob_batches)
    all_labels = np.vstack(label_batches)

    candidate_taus = np.linspace(0, 1, n_steps)
    thresholds = np.zeros(num_classes, dtype=float)
    for k in range(num_classes):
        class_probs = all_probs[:, k]
        class_truth = all_labels[:, k]
        f1_per_tau = [
            f1_score(class_truth, (class_probs >= tau).astype(int), zero_division=0)
            for tau in candidate_taus
        ]
        best_f1, best_tau = 0.0, 0.5
        for tau, f1 in zip(candidate_taus, f1_per_tau):
            if f1 > best_f1:  # strict: ties keep the earlier candidate
                best_f1, best_tau = f1, tau
        thresholds[k] = best_tau
    return thresholds

## 6. Main Training & Evaluation

In [None]:
# Load labels to df
IMG_DIR = 'miml_dataset/images'
LABELS_PATH = 'miml_dataset/miml_labels_1.csv'
df = load_csv_to_df(LABELS_PATH,IMG_DIR)
if DEBUG_MODE:
    # min(): sample() raises when asked for more rows than the frame holds
    df = df.sample(n=min(200, len(df)), random_state=RANDOM_SEED).reset_index(drop=True)
nonlabel_cols = {"external_id", "Filenames", "group_id", "image_path", "Problematic", "Extra Notes", "Revisit"}
label_columns = [col for col in df.columns if col not in nonlabel_cols and not col.endswith("_distance")]
df[label_columns] = df[label_columns].fillna(0) # Fill NaN entries with 0

# Split train/val/test partitions & save to .csv
df_train_and_val, df_test = group_stratified_split(df, label_columns=label_columns, group_col="group_id", split_ratio=TEST_RATIO, seed=RANDOM_SEED)
# VAL_RATIO is relative to the full data set; rescale it to the remaining train+val part
relative_val_ratio = VAL_RATIO / (TRAIN_RATIO + VAL_RATIO)
df_train, df_val = group_stratified_split(df_train_and_val, label_columns=label_columns, group_col="group_id", split_ratio=relative_val_ratio, seed=RANDOM_SEED)
df_train.to_csv("train_partition.csv", index=False)
df_val.to_csv("val_partition.csv", index=False)
df_test.to_csv("test_partition.csv", index=False)
print("Partitions saved to .csv files.")
# Load partitions from .csv
df_train = pd.read_csv("train_partition.csv")
df_val   = pd.read_csv("val_partition.csv")
df_test  = pd.read_csv("test_partition.csv")
print("Partitions loaded from .csv files.")
# DataLoaders
train_dataset = DataPartition(df_train, label_columns, transform=train_transforms)
val_dataset   = DataPartition(df_val,   label_columns, transform=val_transforms)
test_dataset  = DataPartition(df_test,  label_columns, transform=val_transforms)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,  num_workers=optimal_num_workers, pin_memory=pin_memory)
val_loader   = DataLoader(val_dataset,   batch_size=BATCH_SIZE, shuffle=False, num_workers=optimal_num_workers, pin_memory=pin_memory)
test_loader  = DataLoader(test_dataset,  batch_size=BATCH_SIZE, shuffle=False, num_workers=optimal_num_workers, pin_memory=pin_memory)
print(f"Train samples:      {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples:       {len(test_dataset)}")
print(f"Using num_workers: {optimal_num_workers}")
# Model, optimizer, scheduler
backbone = timm.create_model(MODEL_NAME, pretrained=True, num_classes=0, global_pool="") # instantiate a pure backbone from timm
model = add_ml_decoder_head( # strip & attach ML-Decoder head for your number of labels
    backbone,
    num_classes = len(label_columns),
    num_of_groups= 16,          # e.g. 16 queries → groups of ~labels/16. Try 18 for 18 classes.
    decoder_embedding=768,      # same as ViT-base / DETR
).to(device)
loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.05, amsgrad=False)
cos_scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=SCHEDULER_T0, T_mult=SCHEDULER_T_MULT, eta_min=MIN_LR)
plateau_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, threshold=1e-4, cooldown=1, min_lr=MIN_LR)
# Train
monitor = TrainingMonitor()
trainer = Trainer(model=model,
                    optimizer=optimizer,
                    scheduler_cos=cos_scheduler,
                    scheduler_plateau=plateau_scheduler,
                    criterion=loss_fn,
                    train_loader=train_loader,
                    val_loader=val_loader,
                    device=device,
                    monitor=monitor,
                    patience=PATIENCE,
                    warmup_epochs=3,
                    amp_dtype=amp_dtype,
                    accumulation_steps=2
    )
# From here on everything printed is mirrored into training_log.txt. The
# try/finally guarantees that stdout is restored and the log file is closed
# even when training or evaluation raises.
tee_output = TeeOutput("training_log.txt")
sys.stdout = tee_output
try:
    log_hyperparameters()
    trainer.train(N_EPOCHS)
    # Pick Thresholds for Predictions
    if THRESHOLD_MODE == 'per_label':
        thresholds = find_optimal_thresholds(model, val_loader, device, num_classes=len(label_columns), n_steps=101) # on Validation
        print("\nOptimal per-class thresholds:", thresholds)
    else: # Single Global Threshold
        thresholds = np.full(len(label_columns), GLOBAL_THRESHOLD, dtype=float)
    print("Using thresholds:", thresholds)
    # Save in Classifier wrapper
    classifier = Classifier(model, val_transforms, device, labels=label_columns, thresholds=thresholds)
    classifier.save('best_classifier')

    # Test set evaluation
    print("\nTest Set performance:")
    model.eval()
    all_probs = []
    all_labels = []
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            logits = model(images)
            probabilities = torch.sigmoid(logits).cpu().numpy()
            all_probs.append(probabilities)
            all_labels.append(labels.numpy())
    all_probs  = np.vstack(all_probs)
    all_labels = np.vstack(all_labels)
    binary_predictions = (all_probs >= thresholds).astype(int)
    # Classification Report on Test Set
    per_class_AP = [average_precision_score(all_labels[:, i], all_probs[:, i]) for i in range(all_labels.shape[1])]
    mean_AP = float(np.mean(per_class_AP))
    precisions, recalls, f1s, supports = precision_recall_fscore_support(all_labels, binary_predictions, zero_division=0)
    for idx, label in enumerate(label_columns):
        precision = precisions[idx]
        recall = recalls[idx]
        f1 = f1s[idx]
        num_occurrences = supports[idx]
        print(f"{label:<15s} AP={per_class_AP[idx]:.4f}, Precision={precision:.4f}, Recall={recall:.4f}, F1={f1:.4f}, num_occurences={num_occurrences}")
    print(f"Test Set mean AP: {mean_AP:.4f}")
finally:
    sys.stdout = tee_output.terminal
    tee_output.close()

## Single Prediction Test

In [None]:
classifier = Classifier.load(config_filename="best_classifier.json", transform=val_transforms, device=device)
# Build the path the same way load_csv_to_df does (normpath of IMG_DIR + filename)
# so that it matches df_test["image_path"] on every platform, not only on Windows.
image_path = os.path.normpath(os.path.join(IMG_DIR, "1291.jpg"))
img = Image.open(image_path).convert("RGB")
probabilities = classifier.predict_probability(img)
predictions   = classifier.predict_binary(img)
labels = classifier.labels
thresholds = classifier.thresholds
print("Labels:        ", labels)
print("Thresholds:    ","[{}]".format(", ".join(f"{t:.3f}" for t in thresholds)))
print("Probabilities: ","[{}]".format(", ".join(f"{p:.3f}" for p in probabilities)))
print("Predictions:   ", predictions.tolist())
# Lookup ground-truth vector in df_test; the image may belong to another partition
matching_rows = df_test[df_test["image_path"] == image_path]
if matching_rows.empty:
    print("GroundTruth:    (image not found in df_test)")
else:
    row = matching_rows.iloc[0]
    truth_vector = row[label_columns].astype(int).to_numpy()
    print("GroundTruth:   ", truth_vector.tolist())
