### 1. Import Model Architecture

**Load dependencies**: Import the FairFaceMultiTaskModel, dataloaders, and encoders from notebook 2.2.

In [None]:
from IPython.utils.io import capture_output

# Execute notebook 2.2 via the IPython `%run` magic while capturing whatever it
# prints into `cap`, so this notebook's output stays clean.
# NOTE(review): later cells rely on names that are not defined in this notebook
# (e.g. FairFaceMultiTaskModel, train_loader, valid_loader, test_loader);
# presumably they are provided by this run -- confirm against notebook 2.2.
with capture_output() as cap:
    %run "2.2-Multi-TaskModelArchitecture.ipynb"

### 2. Core Library Imports

**Import libraries**: PyTorch, NumPy, Pandas, scikit-learn for model training, loss computation, and class weight calculation.

In [None]:
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Literal, Optional

import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.utils.class_weight import compute_class_weight

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torchvision.transforms import InterpolationMode

### 3. Helper Functions: Target Extraction

**Purpose**: Normalize batch format handling - supports both dict and nested dict target structures.

**Why needed**: DataLoader may return targets as `{"age": t, "gender": t, "race": t}` OR `{"y": {"age": t, ...}}`.

In [None]:
def _extract_targets(targets: dict) -> dict:
    """Return the age/gender/race label tensors as flat ``long`` tensors.

    Two layouts are accepted:
        targets = {"age": t, "gender": t, "race": t}
    or
        targets = {"y": {"age": t, "gender": t, "race": t}}

    The nested layout is used only when ``targets["y"]`` is itself a dict.
    Every label tensor is cast with ``.long()`` and flattened with ``.view(-1)``.
    """
    is_nested = "y" in targets and isinstance(targets["y"], dict)
    labels = targets["y"] if is_nested else targets

    return {task: labels[task].long().view(-1) for task in ("age", "gender", "race")}

### 3.1 Focal Loss Implementation

**Focal Loss Formula**: 

$$FL(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t)$$

**Components**:
- **alpha_t (class weights)**: Handle statistical imbalance (rare vs frequent classes)
- **focal term (1-p_t)^gamma**: Down-weight easy examples, focus on hard examples
- **gamma=2.0**: Standard value from original paper (Lin et al., ICCV 2017)

**Implementation notes**:
- Applies label smoothing first via `F.cross_entropy`
- Then applies focal modulation
- Finally applies per-class weights if provided

In [None]:
def focal_loss_ce(
    logits: torch.Tensor,
    targets: torch.Tensor,
    *,
    gamma: float = 2.0,
    class_weights: Optional[torch.Tensor] = None,
    label_smoothing: float = 0.0,
) -> torch.Tensor:
    """
    Mean focal loss (Lin et al., ICCV 2017) over a batch of class logits.

        FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)

    Args:
        logits: Class scores; softmax is taken over dim 1.
        targets: Class indices; cast to long and flattened before use.
        gamma: Exponent of the focal factor ``(1 - p_t)``.
        class_weights: Optional 1-D tensor indexed by class id (alpha_t). It is
            moved to the device/dtype of ``logits`` before being gathered.
        label_smoothing: Forwarded to ``F.cross_entropy``.

    Returns:
        Scalar tensor: the plain mean of the per-sample losses.

    Notes:
      - The log term is the per-sample cross entropy, so with
        ``label_smoothing > 0`` it is the smoothed CE rather than exactly
        ``-log(p_t)``; the focal factor always uses the unsmoothed p_t.
    """
    labels = targets.long().view(-1)

    # Per-sample (optionally smoothed) cross entropy, shape (B,).
    per_sample_ce = F.cross_entropy(
        logits,
        labels,
        reduction="none",
        label_smoothing=label_smoothing,
    )

    # Probability assigned to the true class, floored at 1e-8, shape (B,).
    true_class_prob = (
        F.softmax(logits, dim=1)
        .gather(1, labels.unsqueeze(1))
        .squeeze(1)
        .clamp_min(1e-8)
    )
    modulating = (1.0 - true_class_prob) ** gamma

    if class_weights is None:
        return (modulating * per_sample_ce).mean()

    weights = class_weights.to(device=logits.device, dtype=logits.dtype)
    alpha = torch.gather(weights, 0, labels)  # (B,)
    return (alpha * modulating * per_sample_ce).mean()

### 3.2 Multi-Task Focal Loss

**Strategy**: Apply focal loss independently to each task (age/gender/race), then combine the three losses as a weighted sum using `task_loss_weights` (each weight defaults to 1.0).

**Class weights**: Optional per-task weights dict with keys: age, gender, race

**Returns**: Total loss (for backward) + per-task loss breakdown (for logging)

In [None]:
def multitask_loss_focal(
    preds: dict,
    targets: dict,
    *,
    gamma: float = 2.0,
    class_weights: Optional[dict] = None,
    label_smoothing: float = 0.0,
    task_loss_weights: Optional[dict] = None,
) -> tuple[torch.Tensor, dict]:
    """Weighted sum of per-task focal losses for the age/gender/race heads.

    Args:
        preds: ``{"age": (B, A), "gender": (B, G), "race": (B, R)}`` logits.
        targets: ``{"age": (B,), "gender": (B,), "race": (B,)}`` labels, or the
            same mapping wrapped under ``targets["y"]``.
        gamma: Focal exponent shared by all tasks.
        class_weights: Optional dict of per-class weight tensors keyed by task;
            a missing task key means no class weighting for that task.
        label_smoothing: Label smoothing shared by all tasks.
        task_loss_weights: Optional dict of scalar multipliers keyed by task;
            a missing key defaults to 1.0.

    Returns:
        ``(total, loss_parts)`` where ``total`` keeps its graph for backward and
        ``loss_parts`` holds detached tensors for "age", "gender", "race"
        (unweighted per-task losses) and "total".
    """
    tasks = ("age", "gender", "race")
    labels = _extract_targets(targets)
    per_class = {} if class_weights is None else class_weights

    losses = {
        task: focal_loss_ce(
            preds[task],
            labels[task],
            gamma=gamma,
            class_weights=per_class.get(task),
            label_smoothing=label_smoothing,
        )
        for task in tasks
    }

    scale = task_loss_weights or {}
    total = (
        scale.get("age", 1.0) * losses["age"]
        + scale.get("gender", 1.0) * losses["gender"]
        + scale.get("race", 1.0) * losses["race"]
    )

    # Detached copies are for logging only.
    loss_parts = {task: losses[task].detach() for task in tasks}
    loss_parts["total"] = total.detach()
    return total, loss_parts

### 3.3 Multi-Task Accuracy Computation

**Computes**: Per-task classification accuracy + mean accuracy across all tasks.

**Used for**: Training/validation monitoring without affecting gradients.

In [None]:
@torch.inference_mode()
def multitask_accuracies(preds: dict, targets: dict) -> dict:
    """Compute per-task top-1 accuracy and their unweighted mean.

    Predictions are the argmax over dim 1 of each task's logits. The result
    maps "age", "gender", "race" and "mean" to scalar tensors, where "mean" is
    the average of the three task accuracies.
    """
    labels = _extract_targets(targets)

    accs = {
        task: (preds[task].argmax(dim=1) == labels[task]).float().mean()
        for task in ("age", "gender", "race")
    }
    accs["mean"] = (accs["age"] + accs["gender"] + accs["race"]) / 3.0
    return accs

### 4. Training Loop: Batch Unpacking

**Handles two batch formats**:
1. Tuple: `(imgs, targets, meta)` from custom collate functions
2. Dict: `{"img_t": imgs, "y": targets, "meta": meta}` from FairFaceDataset

**Returns**: Images and targets on device, metadata stays on CPU (not needed for training).

In [None]:
from contextlib import nullcontext


def _unpack_batch(batch, device):
    """Split a DataLoader batch into images, label tensors and metadata.

    Supported batch layouts:
        A) ``(imgs, targets, meta)`` tuple/list of length 3 (custom collate_fn)
        B) ``{"img_t": imgs, "y": targets, ...}`` dict; when "y" is absent the
           batch dict itself is used as the target mapping, and "meta" is
           optional (``None`` if missing)

    Returns:
        imgs:    image tensor moved to ``device``
        targets: dict of age/gender/race label tensors moved to ``device``
        meta:    returned untouched (not moved to ``device``)

    Raises:
        TypeError: for any other batch type or sequence length.
    """
    if isinstance(batch, dict):
        imgs = batch["img_t"]
        targets = batch.get("y", batch)
        meta = batch.get("meta", None)
    elif isinstance(batch, (tuple, list)) and len(batch) == 3:
        imgs, targets, meta = batch
    else:
        raise TypeError(f"Unsupported batch type: {type(batch)}")

    imgs = imgs.to(device, non_blocking=True)
    labels = {
        task: tensor.to(device, non_blocking=True)
        for task, tensor in _extract_targets(targets).items()
    }
    return imgs, labels, meta

### 4.1 Training Loop: One Epoch

**Key features**:
- **Mixed precision (AMP)**: Uses `torch.amp.autocast` for faster training on modern GPUs
- **Focal loss**: Applies gamma focusing + optional class weights
- **Label smoothing**: Regularization technique (the function default is 0.0, i.e. off; this notebook passes 0.1 for training)
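- **Efficient gradient reset**: `optimizer.zero_grad(set_to_none=True)` clears gradients by setting them to `None` before every batch (no gradient accumulation across batches is performed)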

**Returns**: Epoch-averaged losses and accuracies (per-task + total).

In [None]:
def train_one_epoch_fairface(
    model,
    optimizer,
    train_dl,
    *,
    gamma: float = 2.0,
    class_weights: dict | None = None,
    label_smoothing: float = 0.0,
    task_loss_weights: dict | None = None,
    grad_clip: float | None = 1.0,
    device: torch.device,
    scaler=None,
    amp: bool = True,
):
    """Train ``model`` for one pass over ``train_dl`` with the multi-task focal loss.

    Args:
        model: Module returning a dict of age/gender/race logits.
        optimizer: Optimizer stepped once per batch.
        train_dl: Iterable of batches in a layout ``_unpack_batch`` accepts.
        gamma, class_weights, label_smoothing, task_loss_weights: Forwarded to
            ``multitask_loss_focal``.
        grad_clip: Max gradient norm for ``clip_grad_norm_``; ``None`` disables
            clipping.
        device: Device the batches are moved to.
        scaler: Optional gradient scaler; used only when autocast is active.
        amp: Request float16 autocast; it is enabled only on CUDA devices.

    Returns:
        ``(train_losses, train_accs)``: sample-weighted epoch averages, keyed
        "age"/"gender"/"race"/"total" and "age"/"gender"/"race"/"mean".
    """
    model.train()

    mixed_precision = bool(amp and device.type == "cuda")
    if mixed_precision:
        autocast_cm = torch.amp.autocast(device_type="cuda", dtype=torch.float16)
    else:
        autocast_cm = nullcontext()
    scaled_step = mixed_precision and scaler is not None

    task_names = ("age", "gender", "race")
    loss_sums = dict.fromkeys((*task_names, "total"), 0.0)
    acc_sums = dict.fromkeys((*task_names, "mean"), 0.0)
    n_seen = 0

    for batch in train_dl:
        optimizer.zero_grad(set_to_none=True)

        imgs, targets, _ = _unpack_batch(batch, device)

        with autocast_cm:
            preds = model(imgs)
            total_loss, loss_parts = multitask_loss_focal(
                preds,
                targets,
                gamma=gamma,
                class_weights=class_weights,
                label_smoothing=label_smoothing,
                task_loss_weights=task_loss_weights,
            )

        if scaled_step:
            scaler.scale(total_loss).backward()
            if grad_clip is not None:
                # Gradients must be unscaled before their norm is clipped.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            total_loss.backward()
            if grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

        batch_size = imgs.size(0)
        n_seen += batch_size

        accs = multitask_accuracies(preds, targets)

        # Weight each batch mean by its size so the epoch value is a per-sample mean.
        loss_sums["total"] += float(total_loss.detach().item()) * batch_size
        for task in task_names:
            loss_sums[task] += float(loss_parts[task].item()) * batch_size
        for key in acc_sums:
            acc_sums[key] += float(accs[key].item()) * batch_size

    train_losses = {key: value / n_seen for key, value in loss_sums.items()}
    train_accs = {key: value / n_seen for key, value in acc_sums.items()}
    return train_losses, train_accs

### 4.2 Evaluation Loop: One Epoch

**Differences from training**:
- No gradient computation
- Model in eval mode (disables dropout, batchnorm updates)
- Typically uses `label_smoothing=0.0` for true performance measurement

**Used for**: Validation during training, final test evaluation.

In [None]:
@torch.inference_mode()
def eval_one_epoch_fairface(
    model,
    eval_dl,
    *,
    gamma: float = 2.0,
    class_weights: dict | None = None,
    label_smoothing: float = 0.0,
    task_loss_weights: dict | None = None,
    device: torch.device,
    amp: bool = True,
):
    """Evaluate ``model`` on ``eval_dl`` without computing gradients.

    The model is put in eval mode and the whole function runs under
    ``torch.inference_mode``. Float16 autocast is used only when ``amp`` is
    true and ``device`` is CUDA. The loss arguments are forwarded to
    ``multitask_loss_focal``.

    Returns:
        ``(valid_losses, valid_accs)``: sample-weighted averages over the
        loader, keyed "age"/"gender"/"race"/"total" and
        "age"/"gender"/"race"/"mean".
    """
    model.eval()

    mixed_precision = bool(amp and device.type == "cuda")
    if mixed_precision:
        autocast_cm = torch.amp.autocast(device_type="cuda", dtype=torch.float16)
    else:
        autocast_cm = nullcontext()

    task_names = ("age", "gender", "race")
    loss_sums = dict.fromkeys((*task_names, "total"), 0.0)
    acc_sums = dict.fromkeys((*task_names, "mean"), 0.0)
    n_seen = 0

    for batch in eval_dl:
        imgs, targets, _ = _unpack_batch(batch, device)

        with autocast_cm:
            preds = model(imgs)
            total_loss, loss_parts = multitask_loss_focal(
                preds,
                targets,
                gamma=gamma,
                class_weights=class_weights,
                label_smoothing=label_smoothing,
                task_loss_weights=task_loss_weights,
            )
            accs = multitask_accuracies(preds, targets)

        batch_size = imgs.size(0)
        n_seen += batch_size

        # Weight each batch mean by its size so the result is a per-sample mean.
        loss_sums["total"] += float(total_loss.item()) * batch_size
        for task in task_names:
            loss_sums[task] += float(loss_parts[task].item()) * batch_size
        for key in acc_sums:
            acc_sums[key] += float(accs[key].item()) * batch_size

    valid_losses = {key: value / n_seen for key, value in loss_sums.items()}
    valid_accs = {key: value / n_seen for key, value in acc_sums.items()}
    return valid_losses, valid_accs

### 4.3 Full Training Loop (fit_fairface)

**Orchestrates**:
1. Train one epoch and compute train metrics
2. Evaluate on validation set and compute val metrics
3. Step learning rate scheduler (if provided)
4. Log metrics and learning rate
5. Store history for plotting

**Scheduler handling**: Detects ReduceLROnPlateau automatically and passes validation loss.

In [None]:
def fmt_metrics(d, factor=1.0, precision=4):
    """Render a metrics dict as ``"key:value, key:value"`` for log lines.

    Each value is converted to float, multiplied by ``factor`` and formatted as
    fixed-point with a minimum width of 3 and ``precision`` decimals.
    """
    spec = f"3.{precision}f"
    rendered = []
    for name, value in d.items():
        rendered.append(f"{name}:{format(factor * float(value), spec)}")
    return ", ".join(rendered)


def fit_fairface(
    model,
    optimizer,
    *,
    train_dl,
    valid_dl,
    epochs: int,
    sched=None,
    gamma: float = 2.0,
    class_weights: dict | None = None,
    label_smoothing: float = 0.0,
    task_loss_weights: dict | None = None,
    grad_clip: float | None = 1.0,
    device: torch.device,
    amp: bool = True,
):
    """Run the train/validate loop for ``epochs`` epochs and return the history.

    Per epoch: train one epoch, evaluate on ``valid_dl`` (always with
    ``label_smoothing=0.0``), step the scheduler, record metrics and print a
    summary. A scheduler whose class name contains "plateau" (case-insensitive)
    is stepped with the validation total loss; any other scheduler is stepped
    without arguments.

    Returns:
        Dict with lists "train_loss", "train_acc", "valid_loss", "valid_acc"
        (one metrics dict per epoch) and "lr" (learning rate of the first
        parameter group after the scheduler step).
    """
    scaler = torch.amp.GradScaler(enabled=(amp and device.type == "cuda"))

    history = {
        name: []
        for name in ("train_loss", "train_acc", "valid_loss", "valid_acc", "lr")
    }

    # Arguments shared verbatim by the training and evaluation passes.
    shared = dict(
        gamma=gamma,
        class_weights=class_weights,
        task_loss_weights=task_loss_weights,
        device=device,
        amp=amp,
    )

    for ep in range(1, epochs + 1):
        train_losses, train_accs = train_one_epoch_fairface(
            model,
            optimizer,
            train_dl,
            label_smoothing=label_smoothing,
            grad_clip=grad_clip,
            scaler=scaler,
            **shared,
        )

        # Validation never uses label smoothing, so its loss stays comparable.
        valid_losses, valid_accs = eval_one_epoch_fairface(
            model,
            valid_dl,
            label_smoothing=0.0,
            **shared,
        )

        if sched is not None:
            wants_metric = "plateau" in sched.__class__.__name__.lower()
            if wants_metric:
                sched.step(valid_losses["total"])
            else:
                sched.step()

        curr_lr = float(optimizer.param_groups[0]["lr"])

        epoch_record = (
            ("train_loss", train_losses),
            ("train_acc", train_accs),
            ("valid_loss", valid_losses),
            ("valid_acc", valid_accs),
            ("lr", curr_lr),
        )
        for name, value in epoch_record:
            history[name].append(value)

        header = f"[Epoch {ep:02d}/{epochs:02d}]:"
        train_line = (
            f"Train loss: {fmt_metrics(train_losses)} | "
            f"Train acc: {fmt_metrics(train_accs, factor=100, precision=2)}"
        )
        valid_line = (
            f"Valid loss: {fmt_metrics(valid_losses)} | "
            f"Valid acc: {fmt_metrics(valid_accs, factor=100, precision=2)} | "
            f"lr: {curr_lr:.8f}"
        )
        print("\n".join((header, train_line, valid_line)))

    return history

### 5. Model Setup

**Architecture**: ResNet-34 backbone + 3 task-specific heads (age/gender/race)

**Optimizer**: AdamW with lr=1e-4 (good default for fine-tuning pretrained models)

**Scheduler**: ReduceLROnPlateau - reduces LR when validation loss plateaus

**Loss config**:
- `gamma=2.0`: Focal loss focusing parameter
- `label_smoothing=0.1`: Prevents overconfident predictions (applied to the training loss only; validation and test use 0.0)

In [None]:
# --- device (single source of truth) ---
# Use the GPU when one is available; the model, class weights and batches are
# all placed on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- model (create once, then move once) ---
# NOTE(review): FairFaceMultiTaskModel is not defined in this notebook;
# presumably it comes from the `%run` of notebook 2.2 -- confirm.
model = FairFaceMultiTaskModel(
    pretrained=True, 
    freeze_backbone=False,
    dropout_p=0.3  # raised from 0.2 in an earlier configuration
).to(device)

# --- loss config ---
gamma = 2.0  # focal-loss exponent; raised from 1.5 (2.0 is the standard focal loss value)
label_smoothing = 0.1  # training loss only (fit_fairface evaluates with 0.0); raised from 0.05
task_loss_weights = {"age": 1.2, "gender": 1.0, "race": 1.5}  # per-task multipliers in the total loss; race raised from 1.2
grad_clip = 1.0  # max gradient norm used for clipping in the training loop

# --- optimizer ---
optimizer = torch.optim.AdamW(
    model.parameters(), 
    lr=1e-4,
    weight_decay=1e-4  # added as regularization (AdamW applies it as decoupled weight decay)
)

# --- scheduler ---
# Halves the learning rate (factor=0.5) once the monitored value has stopped
# improving for more than `patience` scheduler steps, never going below `min_lr`.
# fit_fairface steps it with the validation total loss, hence mode="min".
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=2,
    threshold=1e-3,
    min_lr=1e-6,
)

# Global backend switches (they affect the whole process, not just this model):
# let cuDNN benchmark convolution algorithms, and allow reduced-precision
# float32 matmuls where the hardware supports them.
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision("high")

### 5.1 Class Weight Computation Strategy

**Weight type:** Inverse-frequency weights using sklearn's `class_weight="balanced"` (higher weight for rarer classes).

**Why from DataLoader (not DataFrame):**
- Ensures weights match the encoded integer labels actually seen by the model
- Handles any custom encoding/preprocessing automatically
- Validates against the true batch distribution

**Algorithm:**
1. Infer number of classes from the model heads (no hardcoding)
2. Collect all labels from the training DataLoader
3. Compute balanced weights per task using sklearn's balanced mode
4. Assign weight = 1.0 to any class not present in the sample
5. Create each weight tensor directly on the target device (the GPU when available) once, to avoid repeated transfers during training

**Formula (per class $i$, per task):**

$$
w_i = \frac{N_{\text{total}}}{C \cdot N_i}
$$

where $N_{\text{total}}$ is the total number of samples for the task, $C$ is the number of classes that actually occur in the training labels (equal to the head size when every class is present), and $N_i$ is the sample count of class $i$.

In [None]:
from sklearn.utils.class_weight import compute_class_weight


def _head_out_features(head) -> int:
    """Return the output width of one classification head.

    A head that exposes ``out_features`` itself (e.g. ``nn.Linear``) is read
    directly. Otherwise the head's sub-modules are searched from last to first
    and the first one exposing ``out_features`` is used, which covers heads
    wrapped in a container such as ``nn.Sequential(nn.Dropout, nn.Linear)``.

    Raises:
        AttributeError: if no layer in the head exposes ``out_features``.
    """
    width = getattr(head, "out_features", None)
    if width is not None:
        return width

    list_modules = getattr(head, "modules", None)
    if callable(list_modules):
        # modules() walks the head in registration order, so the last match is
        # the layer closest to the output.
        for sub in reversed(list(list_modules())):
            width = getattr(sub, "out_features", None)
            if width is not None:
                return width

    raise AttributeError(
        f"Cannot infer the number of classes from a head of type "
        f"{type(head).__name__}: no layer exposes 'out_features'."
    )


def _infer_num_classes_from_model(model) -> dict:
    """Infer the number of classes per task from the model's head modules.

    Reads ``model.age_head``, ``model.gender_head`` and ``model.race_head``.
    No tensor computation happens here, so no autograd/inference context is
    needed.

    Returns:
        ``{"age": int, "gender": int, "race": int}``.

    Raises:
        AttributeError: if a head is missing or its output size cannot be found.
    """
    return {
        "age": _head_out_features(model.age_head),
        "gender": _head_out_features(model.gender_head),
        "race": _head_out_features(model.race_head),
    }


def compute_class_weights_from_loader(train_dl, *, model, device: torch.device) -> dict:
    """
    Compute balanced (inverse-frequency) class weights from a training DataLoader.

    The loader is iterated once to collect every age/gender/race label. For each
    task, sklearn's ``compute_class_weight(class_weight="balanced")`` is applied
    to the classes that actually occur; classes that never occur keep the
    neutral weight 1.0.

    Args:
        train_dl: Training DataLoader. Batches may be dicts (targets under "y",
            or the dict itself) or tuples/lists whose second item is the
            target mapping. Batches of any other shape are skipped with a
            warning.
        model: Model instance (used to infer the number of classes per task).
        device: Target device for the weight tensors.

    Returns:
        Dict mapping task -> float32 torch.Tensor of weights (shape: [num_classes]).

    Raises:
        IndexError: if a label lies outside ``[0, num_classes)`` for its task.
            Without this check a negative label would silently wrap around and
            put its weight on the wrong class.
    """
    import warnings

    num_classes = _infer_num_classes_from_model(model)
    tasks = ("age", "gender", "race")

    # Collect all labels from training data
    ys = {task: [] for task in tasks}
    skipped = 0

    print("Collecting labels from training data...")
    for batch in train_dl:
        # Handle both dict and tuple batch formats
        if isinstance(batch, dict):
            targets = batch.get("y", batch)
        elif isinstance(batch, (tuple, list)) and len(batch) >= 2:
            targets = batch[1]
        else:
            skipped += 1
            continue

        t = _extract_targets(targets)
        for task in tasks:
            ys[task].append(t[task].cpu().numpy())

    if skipped:
        warnings.warn(
            f"Skipped {skipped} batch(es) of unsupported type while collecting "
            "labels; class weights may not reflect the full training set.",
            RuntimeWarning,
            stacklevel=2,
        )

    # Compute weights per task
    out = {}
    for task, chunks in ys.items():
        y = np.concatenate(chunks, axis=0) if chunks else np.array([], dtype=np.int64)
        C = num_classes[task]

        # Initialize all weights to 1.0 (neutral)
        w = np.ones((C,), dtype=np.float32)

        if y.size > 0:
            present_classes = np.unique(y)

            # np.unique returns sorted values, so the two ends bound every label.
            lowest, highest = int(present_classes[0]), int(present_classes[-1])
            if lowest < 0 or highest >= C:
                raise IndexError(
                    f"Task '{task}' has labels in [{lowest}, {highest}] but the model "
                    f"head has {C} classes; expected labels in [0, {C - 1}]."
                )

            # Compute balanced weights only for present classes
            w_present = compute_class_weight(
                class_weight="balanced",
                classes=present_classes,
                y=y
            ).astype(np.float32)

            w[present_classes] = w_present

            print(f"{task:6s}: {C} classes | present: {len(present_classes)} | "
                  f"weights range [{w.min():.3f}, {w.max():.3f}]")

        out[task] = torch.tensor(w, dtype=torch.float32, device=device)

    return out


# Compute the per-task class weights once from the training loader; the weight
# tensors are created directly on `device` and later passed to the focal loss.
print("=" * 60)
print("COMPUTING CLASS WEIGHTS FROM TRAINING DATA")
print("=" * 60)

class_weights = compute_class_weights_from_loader(train_loader, model=model, device=device)

# The check mark below replaces a mis-encoded byte sequence in the original message.
print("\n✓ Class weights computed and moved to device!")
print(f"  Device: {device}")
print(f"  Tasks: {list(class_weights.keys())}")
print("\nWeights will be used in focal loss to handle class imbalance.")

In [None]:
# --- train (REDUCED epochs to stop overfitting) ---
# Runs the full train/validate loop with the loss, optimizer and scheduler
# configured above. `amp` is not passed, so fit_fairface's default (True)
# applies; autocast is only enabled when `device` is CUDA.
# `history` holds one metrics dict per epoch and is consumed by the plotting cell.
epochs = 10  # REDUCED from 15
history = fit_fairface(
    model,
    optimizer,
    sched=sched,
    train_dl=train_loader,
    valid_dl=valid_loader,
    epochs=epochs,
    gamma=gamma,
    class_weights=class_weights,
    label_smoothing=label_smoothing,
    task_loss_weights=task_loss_weights,
    grad_clip=grad_clip,
    device=device,
)

### 7. Training Visualization

**Plots**:
- **Row 1**: Per-task loss curves (age, gender, race)
- **Row 2**: Per-task accuracy curves

**What to look for**:
- Train/val loss converging means good learning
- Val loss increasing while train decreasing means overfitting
- Consistent gap between train/val may need more regularization

In [None]:
def plot_fairface_target_metrics(history, keys=("age", "gender", "race")):
    """Plot train/validation loss (row 1) and accuracy in % (row 2) per task.

    Expects:
        history["train_loss"], history["valid_loss"] : list[dict]
        history["train_acc"],  history["valid_acc"]  : list[dict]
    where each dict has per-task keys like "age", "gender", "race".

    Only keys found in the first train-loss and first valid-loss dict are
    plotted, one column per key.

    Raises:
        KeyError: if none of ``keys`` is present in the history dicts.
    """
    epochs = range(1, len(history["train_loss"]) + 1)

    # Keep only the keys that actually exist in the recorded losses.
    existing = [
        k
        for k in keys
        if k in history["train_loss"][0] and k in history["valid_loss"][0]
    ]
    if not existing:
        raise KeyError(f"None of these keys found in history dicts: {keys}")

    n = len(existing)
    fig, axes = plt.subplots(2, n, figsize=(6 * n, 8), constrained_layout=True)

    # With a single column plt.subplots returns a 1-D array; make it indexable
    # as axes[row][col] like the multi-column case.
    if n == 1:
        axes = [axes[0:1], axes[1:2]]

    def _draw(ax, train_vals, valid_vals, short, title, ylabel):
        # One panel: train and validation curves over the epochs.
        ax.plot(epochs, train_vals, label=f"Train {short}")
        ax.plot(epochs, valid_vals, label=f"Valid {short}")
        ax.set_title(title)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(ylabel)
        ax.grid(True)
        ax.legend()

    # --- Row 1: Loss ---
    for col, k in enumerate(existing):
        train_vals = [entry[k] for entry in history["train_loss"]]
        valid_vals = [entry[k] for entry in history["valid_loss"]]
        _draw(axes[0][col], train_vals, valid_vals, "loss", f"{k} loss", "Loss")

    # --- Row 2: Accuracy (as %) ---
    for col, k in enumerate(existing):
        train_vals = [entry[k] * 100 for entry in history["train_acc"]]
        valid_vals = [entry[k] * 100 for entry in history["valid_acc"]]
        _draw(axes[1][col], train_vals, valid_vals, "acc", f"{k} accuracy", "Accuracy (%)")

    plt.show()

# usage: draw the per-task loss/accuracy curves from the `history` dict
# returned by fit_fairface in the training cell
plot_fairface_target_metrics(history)

### 8. Test Set Evaluation

**Final performance measurement**:
- Runs on held-out test set (never seen during training/validation)
- Uses `label_smoothing=0.0` so the reported loss is unsmoothed (label smoothing changes only the loss value, never the accuracy)
- Reports per-task accuracies + mean accuracy

**This is the number you report**: Test accuracy represents real-world generalization.

In [None]:
# Final evaluation of the model on the test loader.
# NOTE: eval_one_epoch_fairface calls model.eval() itself, so the explicit call
# below is redundant but harmless; likewise model.to(device) returns the same
# module that was already moved during setup.
model.eval()
test_losses, test_accs = eval_one_epoch_fairface(
    model=model.to(device),
    eval_dl=test_loader,
    gamma=gamma,
    class_weights=class_weights,  # the reported loss is the class-weighted focal loss, not plain CE
    task_loss_weights=task_loss_weights,
    device=device,
    label_smoothing=0.0,  # report the unsmoothed loss
)

# Headline numbers: total loss and accuracies (fractions converted to percent).
print(f"Test total loss : {test_losses['total']:.4f}")
print(f"Test mean acc   : {test_accs['mean']*100:.2f}%")

print(f"Age acc         : {test_accs['age']*100:.2f}%")
print(f"Gender acc      : {test_accs['gender']*100:.2f}%")
print(f"Race acc        : {test_accs['race']*100:.2f}%")

# Full dicts (all keys), rounded, for copy/paste into reports.
print("Losses:", {k: round(v, 4) for k, v in test_losses.items()})
print("Accs  :", {k: round(v*100, 2) for k, v in test_accs.items()})
