In [1]:
# BLOCK A: imports, paths, device

import os
from pathlib import Path
import json
import random

import numpy as np
import pandas as pd
import cv2
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models

import albumentations as A
from albumentations.pytorch import ToTensorV2

from tqdm.auto import tqdm
import matplotlib.pyplot as plt

from ultralytics import YOLO  # YOLOv11 via ultralytics

# Reproducibility: seed Python, NumPy and torch (CPU and every CUDA device).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # silently ignored when CUDA is unavailable

# Paths: set the FACE_KPTS_DATA_ROOT environment variable to use another
# dataset root instead of editing the local default below.
DATA_ROOT = Path(
    os.environ.get(
        "FACE_KPTS_DATA_ROOT",
        r"D:/Work/Projects/Facepoint Recognizer/Face-Orientation-Detector/data",
    )
)
TRAIN_DIR = DATA_ROOT / "train"
TEST_DIR  = DATA_ROOT / "test"

COCO_TRAIN_JSON = TRAIN_DIR / "_annotations.coco.json"
COCO_TEST_JSON  = TEST_DIR  / "_annotations.coco.json"

# Processed output (YOLO-masked, letterboxed images and their label tables)
PROC_ROOT = DATA_ROOT / "processed_yolo"
PROC_TRAIN_DIR = PROC_ROOT / "train"
PROC_VAL_DIR   = PROC_ROOT / "val"

PROC_TRAIN_DIR.mkdir(parents=True, exist_ok=True)
PROC_VAL_DIR.mkdir(parents=True, exist_ok=True)

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


In [2]:
# BLOCK B: load COCO annotations into a table (for train)

def load_coco_annotations(coco_path: Path, category_name: str = "head"):
    """Load the keypoint annotations of one COCO category into a DataFrame.

    Parameters
    ----------
    coco_path : Path
        COCO-format JSON file; its ``images``, ``annotations`` and
        ``categories`` sections are read.
    category_name : str, default "head"
        Name of the category whose annotations carry the keypoints.

    Returns
    -------
    df : pandas.DataFrame
        One row per valid annotation of that category, with columns
        ``image_id``, ``file_name``, ``width``, ``height``, ``bbox`` (COCO
        ``[x, y, w, h]``), ``keypoints_x`` and ``keypoints_y``. Keypoints with
        visibility 0 are stored as ``(0.0, 0.0)``, which later cells treat as
        a "not labelled" sentinel.
    category : dict
        The raw COCO category entry (including its ``keypoints`` name list).

    Raises
    ------
    ValueError
        If no category named ``category_name`` exists.
    """
    with open(coco_path, "r", encoding="utf-8") as f:
        coco = json.load(f)

    images = {img["id"]: img for img in coco["images"]}

    head_category = next(
        (cat for cat in coco["categories"] if cat["name"] == category_name), None
    )
    if head_category is None:
        raise ValueError(f"Could not find '{category_name}' category in COCO.")
    head_cat_id = head_category["id"]

    num_keypoints = len(head_category["keypoints"])
    print(f"Found '{category_name}' category with {num_keypoints} keypoints.")

    columns = ["image_id", "file_name", "width", "height", "bbox",
               "keypoints_x", "keypoints_y"]
    rows = []
    n_skipped = 0
    for ann in coco["annotations"]:
        if ann["category_id"] != head_cat_id:
            continue

        img_info = images.get(ann["image_id"])
        kpts_raw = ann.get("keypoints", [])
        # Skip broken annotations: dangling image reference or wrong keypoint count.
        if img_info is None or len(kpts_raw) != num_keypoints * 3:
            n_skipped += 1
            continue

        # COCO keypoints are flat triplets [x1, y1, v1, x2, y2, v2, ...];
        # unlabelled points (v == 0) become the (0.0, 0.0) sentinel.
        xs, ys = [], []
        for x, y, v in zip(kpts_raw[0::3], kpts_raw[1::3], kpts_raw[2::3]):
            visible = v > 0
            xs.append(x if visible else 0.0)
            ys.append(y if visible else 0.0)

        rows.append({
            "image_id": img_info["id"],
            "file_name": img_info["file_name"],
            "width": img_info["width"],
            "height": img_info["height"],
            "bbox": ann["bbox"],  # [x, y, w, h] in original image
            "keypoints_x": xs,
            "keypoints_y": ys,
        })

    if n_skipped:
        print(f"Skipped {n_skipped} malformed '{category_name}' annotations.")

    # Explicit columns keep the schema even when no annotation survives.
    df = pd.DataFrame(rows, columns=columns)
    print("Loaded", len(df), f"{category_name} annotations")
    return df, head_category

# Parse the training COCO file; `head_category` is reused later for NUM_KEYPOINTS.
train_df, head_category = load_coco_annotations(coco_path=COCO_TRAIN_JSON)
train_df.head()  # preview of the parsed annotation table


Found 'head' category with 26 keypoints.
Loaded 515 head annotations


Unnamed: 0,image_id,file_name,width,height,bbox,keypoints_x,keypoints_y
0,0,image_f1f0e67c_jpg.rf.d3a8d0305ac8db19cb9e1eb0...,1280,1280,"[0, 0, 1212.475, 793.518]","[1093.003, 1080.399, 1163.084, 1120.685, 1135....","[298.578, 359.332, 496.351, 558.768, 574.951, ..."
1,1,image_9ddc2611_jpg.rf.c3c6a6e39aff4c63cec65c7d...,1280,1280,"[0, 0, 961.884, 733.776]","[852.885, 846.384, 890.255, 857.758, 866.966, ...","[318.805, 369.856, 478.678, 504.124, 519.825, ..."
2,2,image_b2315e84_jpg.rf.a7e085afb96272ebbd3dd489...,1280,1280,"[2, 0, 735, 783.333]","[674.483, 648.902, 686.697, 623.481, 624.131, ...","[346.61, 390.304, 544.489, 569.404, 604.225, 6..."
3,3,image_b058cc7d_jpg.rf.d73d0cbad2755b76d304d814...,1280,1280,"[0, 2, 1228.773, 1018.869]","[936.555, 953.546, 1097.31, 1039.868, 1056.728...","[375.658, 489.655, 619.883, 713.567, 739.78, 8..."
4,4,image_fb8775d0_jpg.rf.010835876f3589c0aded7a4b...,1280,1280,"[0, 0, 1089.357, 1280]","[836.936, 827.941, 977.961, 901.131, 902.892, ...","[429.582, 565.491, 798.858, 893.165, 958.754, ..."


In [4]:
# BLOCK C: train/val split on original train set

from sklearn.model_selection import train_test_split

# Split by *image*, not by annotation row. If one image carries several
# "head" annotations, a row-level split could place the same picture in both
# train and val (leakage). With exactly one annotation per image this selects
# the same rows, in the same order and with the same index labels, as
# splitting the DataFrame directly.
unique_image_ids = train_df["image_id"].unique()
train_ids, val_ids = train_test_split(
    unique_image_ids,
    test_size=0.2,
    random_state=SEED,
    shuffle=True,
    stratify=None  # dataset probably small, so no stratify
)
rows_by_image = train_df.groupby("image_id", sort=False)
train_df_split = pd.concat([rows_by_image.get_group(i) for i in train_ids])
val_df_split = pd.concat([rows_by_image.get_group(i) for i in val_ids])

print("Train samples:", len(train_df_split))
print("Val samples  :", len(val_df_split))


Train samples: 412
Val samples  : 103


In [5]:
# BLOCK D: YOLOv11 Option A preprocessing (run once)

IMG_SIZE = 512  # letterbox target: every processed image is IMG_SIZE x IMG_SIZE

# Standard COCO-trained YOLO models use class id 0 for "person".
PERSON_CLASS_ID = 0

# Nano YOLOv11 detector via ultralytics, chosen for speed;
# make sure the "yolo11n.pt" weights are downloaded once.
yolo_model = YOLO("yolo11n.pt")

def letterbox_pad_to_square(img, keypoints_xy, size=None):
    """Fit ``img`` into a ``size`` x ``size`` canvas while keeping its aspect ratio.

    The longer side is scaled to ``size``; the shorter side is centred with
    black (zero) padding. Keypoints go through the same scale and offset so
    they stay aligned with the letterboxed image.

    Parameters
    ----------
    img : np.ndarray
        Image of shape (H, W) or (H, W, C); dtype and channel count are kept.
    keypoints_xy : iterable of (x, y)
        Keypoints in original pixel coordinates. ``(0, 0)`` is the
        "not labelled" sentinel and is passed through unchanged.
    size : int, optional
        Output side length; defaults to the notebook-level ``IMG_SIZE``.

    Returns
    -------
    img_sq : np.ndarray
        Letterboxed image of shape (size, size) or (size, size, C).
    remapped : list of (float, float)
        Keypoints in letterboxed pixel coordinates.
    """
    if size is None:
        size = IMG_SIZE

    h, w = img.shape[:2]
    scale = size / max(h, w)
    # int() truncates, so the resized side may be a pixel short of `size`;
    # the padding absorbs that. max(1, ...) keeps cv2.resize valid for
    # extremely thin images.
    new_w, new_h = max(1, int(w * scale)), max(1, int(h * scale))

    resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    # cv2.resize drops a trailing singleton channel axis; restore the input layout.
    resized = resized.reshape((new_h, new_w) + img.shape[2:])

    pad_x = (size - new_w) // 2
    pad_y = (size - new_h) // 2

    canvas = np.zeros((size, size) + img.shape[2:], dtype=img.dtype)
    canvas[pad_y:pad_y + new_h, pad_x:pad_x + new_w] = resized

    remapped = [
        (0.0, 0.0) if (x == 0 and y == 0) else (x * scale + pad_x, y * scale + pad_y)
        for (x, y) in keypoints_xy
    ]
    return canvas, remapped


def apply_yolo_mask(img_bgr, bbox_from_coco=None, model=None, person_class_id=None):
    """Black out everything outside the most confident detected "person" box.

    If the detector finds no person, the COCO ``bbox_from_coco`` is used
    instead. If neither is available, the input is returned unchanged.

    Parameters
    ----------
    img_bgr : np.ndarray
        (H, W, 3) image in BGR channel order, as returned by ``cv2.imread``.
    bbox_from_coco : sequence of 4 numbers, optional
        COCO ``[x, y, w, h]`` fallback box.
    model : callable, optional
        Ultralytics-style detector called as ``model(img, verbose=False)``;
        defaults to the notebook-level ``yolo_model``.
    person_class_id : int, optional
        Class index treated as "person"; defaults to ``PERSON_CLASS_ID``.

    Returns
    -------
    np.ndarray
        A new image with pixels outside the chosen box set to 0, or
        ``img_bgr`` itself (not a copy) when no box is available.
    """
    if model is None:
        model = yolo_model
    if person_class_id is None:
        person_class_id = PERSON_CLASS_ID

    h, w = img_bgr.shape[:2]

    # Ultralytics interprets NumPy input as BGR (the cv2.imread convention) and
    # converts to RGB internally, so pass the image through as-is. Converting to
    # RGB first would give the detector swapped channels.
    results = model(img_bgr, verbose=False)[0]

    box = None
    boxes = results.boxes
    if boxes is not None and len(boxes) > 0:
        xyxy = boxes.xyxy.cpu().numpy()
        classes = boxes.cls.cpu().numpy()
        confs = boxes.conf.cpu().numpy()

        person_indices = [i for i, c in enumerate(classes) if int(c) == person_class_id]
        if person_indices:
            # highest-confidence person (first one on ties)
            best_i = max(person_indices, key=lambda i: confs[i])
            box = [int(v) for v in xyxy[best_i]]

    # Fallback: COCO bbox (x, y, w, h) -> (x1, y1, x2, y2)
    if box is None and bbox_from_coco is not None:
        bx, by, bw, bh = bbox_from_coco
        box = [int(bx), int(by), int(bx + bw), int(by + bh)]

    # Still nothing: leave the image unmasked.
    if box is None:
        return img_bgr

    # Clamp both ends into the image, so negative coordinates cannot become
    # Python's from-the-end slice indices.
    x1, y1, x2, y2 = box
    x1, x2 = min(max(x1, 0), w), min(max(x2, 0), w)
    y1, y2 = min(max(y1, 0), h), min(max(y2, 0), h)

    out = np.zeros_like(img_bgr)  # background to black
    out[y1:y2, x1:x2] = img_bgr[y1:y2, x1:x2]
    return out


def preprocess_split(df_split, split_name, out_dir, src_dir=None):
    """Mask, letterbox and save every image of one split; return the new labels.

    For each annotation row, the source image is read and everything outside
    the detected person box is blacked out (``apply_yolo_mask``). The result is
    letterboxed to ``IMG_SIZE`` x ``IMG_SIZE`` and written as PNG, and the
    keypoints are remapped into the new coordinates. Rows whose image is
    missing, unreadable or cannot be written are reported and skipped.

    Parameters
    ----------
    df_split : pandas.DataFrame
        Rows as produced by ``load_coco_annotations`` (``image_id``,
        ``file_name``, ``bbox``, ``keypoints_x``, ``keypoints_y``).
    split_name : str
        Prefix for output file names, e.g. "train" or "val".
    out_dir : Path
        Directory the PNGs are written to (created if missing).
    src_dir : Path, optional
        Directory with the original images; defaults to ``TRAIN_DIR`` because
        both splits are carved out of the original train folder.

    Returns
    -------
    pandas.DataFrame
        One row per saved image with ``orig_file_name``, ``file_name_512``,
        ``width_512``, ``height_512``, ``keypoints_x_512``,
        ``keypoints_y_512`` and ``image_id``.
    """
    if src_dir is None:
        src_dir = TRAIN_DIR
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    rows_out = []
    saved_per_image = {}  # image_id -> annotations already saved for that image

    for _, row in tqdm(df_split.iterrows(), total=len(df_split), desc=f"Preprocess {split_name}"):
        img_path = Path(src_dir) / row["file_name"]
        if not img_path.exists():
            print("Missing image:", img_path)
            continue

        img_bgr = cv2.imread(str(img_path))
        if img_bgr is None:
            print("Failed to read:", img_path)
            continue

        # 1) YOLO mask (falls back to the COCO bbox when no person is detected)
        masked = apply_yolo_mask(img_bgr, bbox_from_coco=row["bbox"])

        # 2) letterbox to IMG_SIZE x IMG_SIZE, remap kpts
        keypoints_xy = list(zip(row["keypoints_x"], row["keypoints_y"]))
        img_512, kpts_512 = letterbox_pad_to_square(masked, keypoints_xy)

        # 3) save; a further annotation on the same image gets a numeric suffix
        #    instead of silently overwriting the first annotation's PNG.
        image_id = row["image_id"]
        dup_idx = saved_per_image.get(image_id, 0)
        suffix = "" if dup_idx == 0 else f"_{dup_idx}"
        new_fname = f"{split_name}_{image_id}{suffix}.png"
        save_path = out_dir / new_fname
        # cv2.imwrite reports failure via its return value, not an exception.
        if not cv2.imwrite(str(save_path), img_512):
            print("Failed to write:", save_path)
            continue
        saved_per_image[image_id] = dup_idx + 1

        rows_out.append({
            "orig_file_name": row["file_name"],
            "file_name_512": new_fname,
            "width_512": IMG_SIZE,
            "height_512": IMG_SIZE,
            "keypoints_x_512": [kp[0] for kp in kpts_512],
            "keypoints_y_512": [kp[1] for kp in kpts_512],
            "image_id": image_id,
        })

    return pd.DataFrame(rows_out)


# YOLO masking + letterboxing for both splits (one detector pass per row, so slow).
proc_train_df = preprocess_split(df_split=train_df_split, split_name="train", out_dir=PROC_TRAIN_DIR)
proc_val_df = preprocess_split(df_split=val_df_split, split_name="val", out_dir=PROC_VAL_DIR)

for label, frame in (("Processed train:", proc_train_df), ("Processed val  :", proc_val_df)):
    print(label, len(frame), "images")


Preprocess train: 100%|██████████| 412/412 [00:19<00:00, 20.99it/s]
Preprocess val: 100%|██████████| 103/103 [00:04<00:00, 22.05it/s]

Processed train: 412 images
Processed val  : 103 images





In [6]:
# (optional) save preprocessed annotations
# Persist the processed label tables next to the processed images.
for frame, pkl_name in ((proc_train_df, "proc_train_df.pkl"), (proc_val_df, "proc_val_df.pkl")):
    frame.to_pickle(PROC_ROOT / pkl_name)


In [7]:
# BLOCK E: Albumentations transforms and Dataset class

NUM_KEYPOINTS = len(head_category["keypoints"])  # 26 in this COCO export
IMG_SIZE = 512  # keep constant; must match the preprocessing letterbox size

# Normalize() maps uint8 pixels to (x / 255 - mean) / std, i.e. roughly [-1, 1];
# any visualisation must invert these same statistics.
mean = (0.5, 0.5, 0.5)
std  = (0.5, 0.5, 0.5)

# Intended GaussNoise variance range, in squared uint8 intensity units.
NOISE_VAR_LIMIT = (5.0, 20.0)


def _gauss_noise(var_limit, p):
    """Build ``A.GaussNoise`` with the given variance range across albumentations versions.

    Newer releases take ``std_range`` (standard deviation as a fraction of the
    max pixel value) instead of ``var_limit`` (variance in pixel units).
    NOTE(review): this cell's output shows a warning at the old
    ``var_limit=`` call; in albumentations 2.x that argument is no longer
    accepted, which would leave the default (much stronger) ``std_range`` in
    effect. Confirm against the installed version. Older releases do not know
    ``std_range`` and raise TypeError, so fall back to ``var_limit`` there.
    """
    std_range = tuple(v ** 0.5 / 255.0 for v in var_limit)
    try:
        return A.GaussNoise(std_range=std_range, p=p)
    except TypeError:
        return A.GaussNoise(var_limit=var_limit, p=p)


train_transform = A.Compose(
    [
        # NOTE(review): HorizontalFlip mirrors x coordinates but keeps the
        # keypoint order. Any side-specific landmarks (left/right eye, ear, ...)
        # therefore end up under the other side's index. If the 26 keypoints are
        # side-specific, swap the paired indices after a flip or drop this transform.
        A.HorizontalFlip(p=0.5),
        # NOTE(review): the output shows a warning at this call; newer
        # albumentations deprecate ShiftScaleRotate in favour of Affine (and
        # rename `value` to `fill`). Verify against the installed version.
        A.ShiftScaleRotate(
            shift_limit=0.05,
            scale_limit=0.10,
            rotate_limit=20,
            border_mode=cv2.BORDER_CONSTANT,
            value=(0, 0, 0),
            p=0.8
        ),
        A.RandomBrightnessContrast(
            brightness_limit=0.2, contrast_limit=0.2, p=0.7
        ),
        _gauss_noise(NOISE_VAR_LIMIT, p=0.4),
        A.MotionBlur(blur_limit=5, p=0.3),
        A.Normalize(mean=mean, std=std),
        ToTensorV2()
    ],
    # remove_invisible=False keeps all K keypoints (even ones pushed off-canvas),
    # so every target vector has the same length.
    keypoint_params=A.KeypointParams(format="xy", remove_invisible=False)
)

val_transform = A.Compose(
    [
        A.Normalize(mean=mean, std=std),
        ToTensorV2()
    ],
    keypoint_params=A.KeypointParams(format="xy", remove_invisible=False)
)


class FaceKptsDataset(Dataset):
    """Preprocessed face images with keypoints normalised by ``IMG_SIZE``.

    Each item is a dict with:
      * ``"image"``: the transformed image. It is a (3, H, W) tensor when the
        transform ends in ``ToTensorV2``. Without a transform, the RGB uint8
        image is returned as a (3, H, W) uint8 tensor.
      * ``"keypoints"``: float32 tensor of shape (2*K,), ordered
        x1, y1, x2, y2, ..., divided by ``IMG_SIZE``.
      * ``"file_name"``: the preprocessed PNG's file name.

    Parameters
    ----------
    df : pandas.DataFrame
        Output of ``preprocess_split`` (``file_name_512``, ``keypoints_x_512``,
        ``keypoints_y_512``).
    root_dir : str or Path
        Directory holding the preprocessed PNGs.
    transform : callable, optional
        Albumentations-style transform called as
        ``transform(image=..., keypoints=...)``.
    """

    def __init__(self, df, root_dir, transform=None):
        self.df = df.reset_index(drop=True)
        self.root_dir = Path(root_dir)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = self.root_dir / row["file_name_512"]
        # cv2.imread returns None rather than raising; without these checks the
        # failure surfaces as a cryptic cvtColor error (possibly in a worker).
        if not img_path.is_file():
            raise FileNotFoundError(f"Preprocessed image not found: {img_path}")
        img = cv2.imread(str(img_path))
        if img is None:
            raise OSError(f"cv2 could not decode image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        keypoints = [
            (float(x), float(y))
            for x, y in zip(row["keypoints_x_512"], row["keypoints_y_512"])
        ]

        # Albumentations
        if self.transform is not None:
            transformed = self.transform(image=img, keypoints=keypoints)
            img_t = transformed["image"]
            kpts_t = transformed["keypoints"]
        else:
            # Keep the (3, H, W) tensor layout even without a transform.
            img_t = torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1)))
            kpts_t = keypoints

        # (K, 2) pixel coordinates -> flat (2*K,) divided by IMG_SIZE. Points
        # moved off-canvas by augmentation fall outside [0, 1].
        kpts_arr = np.asarray(kpts_t, dtype=np.float32) / IMG_SIZE
        kpts_flat = torch.from_numpy(kpts_arr.reshape(-1))

        return {
            "image": img_t,          # (3,H,W)
            "keypoints": kpts_flat,  # (2*K,)
            "file_name": row["file_name_512"]
        }


# Augmented training set; the validation set only gets normalisation.
train_dataset = FaceKptsDataset(df=proc_train_df, root_dir=PROC_TRAIN_DIR, transform=train_transform)
val_dataset = FaceKptsDataset(df=proc_val_df, root_dir=PROC_VAL_DIR, transform=val_transform)

for label, ds in (("Train dataset size:", train_dataset), ("Val   dataset size:", val_dataset)):
    print(label, len(ds))


Train dataset size: 412
Val   dataset size: 103


  original_init(self, **validated_kwargs)
  A.ShiftScaleRotate(
  A.GaussNoise(var_limit=(5.0, 20.0), p=0.4),
  self._set_keys()


In [8]:
# BLOCK F: DataLoaders

BATCH_SIZE = 8

# DataLoader worker processes must re-import the Dataset class. Under the
# "spawn" start method (the Windows default) a class defined in a notebook's
# __main__ typically cannot be imported by the workers, so iteration errors or
# hangs. That is the likely reason both training cells stalled at 0 batches.
# Load in the main process on Windows.
NUM_WORKERS = 0 if os.name == "nt" else 2
# Pinned host memory only helps when batches are copied to a CUDA device.
PIN_MEMORY = device.type == "cuda"

loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    drop_last=False,
)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

print("DataLoaders ready.")


DataLoaders ready.


In [9]:
# BLOCK G: model definition with cosine-aux loss

class ResNetKpts(nn.Module):
    """ResNet-18 backbone with an MLP head regressing 2*K keypoint coordinates.

    Parameters
    ----------
    num_keypoints : int
        Number of keypoints K; ``forward`` returns shape (B, 2*K).
    pretrained : bool, default True
        Initialise the backbone with torchvision's ImageNet weights (fetched by
        torchvision if not cached). Pass False to build the same architecture
        with random weights, e.g. offline.
    """

    def __init__(self, num_keypoints, pretrained=True):
        super().__init__()
        weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        self.backbone = models.resnet18(weights=weights)
        in_feats = self.backbone.fc.in_features
        # Drop the ImageNet classifier so the backbone returns pooled features.
        self.backbone.fc = nn.Identity()
        self.head = nn.Sequential(
            nn.Linear(in_feats, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(512, num_keypoints * 2)
        )

    def forward(self, x):
        """Map images (B, 3, H, W) to unconstrained coordinates (B, 2*K)."""
        return self.head(self.backbone(x))


model = ResNetKpts(NUM_KEYPOINTS)
model = model.to(device)
print(model)

# Building blocks of combined_loss: SmoothL1 regression plus a lightly
# weighted cosine-similarity term over the coordinate vectors.
reg_crit = nn.SmoothL1Loss()
cosine = nn.CosineSimilarity(dim=1, eps=1e-6)
COSINE_WEIGHT = 0.1  # small auxiliary term


def combined_loss(pred, target):
    """SmoothL1 regression loss plus a weighted cosine "layout" term.

    ``pred`` and ``target`` are (B, 2*K) tensors of normalised coordinates.
    Each row is centred on its own mean before the cosine similarity, so the
    auxiliary term ignores an offset added equally to every coordinate.

    Returns
    -------
    tuple
        ``(total_loss, reg_loss_value, cos_loss_value)``: ``total_loss`` is a
        differentiable tensor, the other two are Python floats for logging.
    """
    reg_loss = reg_crit(pred, target)

    centred_pred = pred - pred.mean(dim=1, keepdim=True)
    centred_target = target - target.mean(dim=1, keepdim=True)
    cos_loss = torch.mean(1.0 - cosine(centred_pred, centred_target))  # scalar

    total = reg_loss + COSINE_WEIGHT * cos_loss
    return total, reg_loss.item(), cos_loss.item()


ResNetKpts(
  (backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track

In [10]:
# BLOCK H: optimizer and scheduler

# Optimisation hyper-parameters
BASE_LR, WEIGHT_DECAY = 1e-3, 1e-4
MAX_EPOCHS = 120
PATIENCE = 15   # early stopping patience (epochs without val-loss improvement)

optimizer = torch.optim.AdamW(
    params=model.parameters(), lr=BASE_LR, weight_decay=WEIGHT_DECAY
)
# Cosine decay from BASE_LR towards eta_min over T_max scheduler steps;
# the BLOCK I loop steps it once per epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer=optimizer, T_max=MAX_EPOCHS, eta_min=1e-5
)

print("Optimizer & scheduler ready.")


Optimizer & scheduler ready.


In [None]:
# BLOCK I: metrics and training loop with checkpointing & early stopping

def compute_l2_dist_px(pred, target, img_size=None):
    """Mean Euclidean keypoint error in pixels.

    Parameters
    ----------
    pred, target : torch.Tensor
        (B, 2*K) coordinates normalised to [0, 1], ordered x1, y1, x2, y2, ...
    img_size : float, optional
        Pixel length of the normalised unit; defaults to ``IMG_SIZE``.

    Returns
    -------
    float
        Distance averaged over all B*K keypoints.
    """
    if img_size is None:
        img_size = IMG_SIZE
    batch = pred.shape[0]
    # reshape works for any memory layout (view needs compatible strides).
    pred_xy = pred.reshape(batch, -1, 2) * img_size
    tgt_xy = target.reshape(batch, -1, 2) * img_size

    dists = torch.linalg.norm(pred_xy - tgt_xy, dim=-1)  # (B, K)
    return dists.mean().item()


def train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` and return averaged metrics.

    Batches are dicts with ``"image"`` and ``"keypoints"`` tensors (what
    FaceKptsDataset yields). Metrics are averaged per *sample*, so a smaller
    final batch is weighted by its actual size.

    Returns
    -------
    dict
        ``loss`` (combined), ``reg_loss``, ``cos_loss`` and ``dist_px``.

    Raises
    ------
    ValueError
        If ``loader`` yields no samples.
    """
    model.train()
    sums = {"loss": 0.0, "reg_loss": 0.0, "cos_loss": 0.0, "dist_px": 0.0}
    n_samples = 0

    pbar = tqdm(loader, desc="Train", leave=False)
    for batch in pbar:
        imgs = batch["image"].to(device, non_blocking=True)
        targets = batch["keypoints"].to(device, non_blocking=True)
        bs = imgs.shape[0]

        optimizer.zero_grad()
        preds = model(imgs)

        loss, reg_l, cos_l = combined_loss(preds, targets)
        loss.backward()
        optimizer.step()

        # Each term is a batch mean; weight by batch size for a per-sample average.
        sums["loss"] += loss.item() * bs
        sums["reg_loss"] += reg_l * bs
        sums["cos_loss"] += cos_l * bs
        sums["dist_px"] += compute_l2_dist_px(preds.detach(), targets) * bs
        n_samples += bs

        pbar.set_postfix(loss=sums["loss"] / n_samples, dist=sums["dist_px"] / n_samples)

    if n_samples == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")
    return {key: total / n_samples for key, total in sums.items()}


@torch.no_grad()
def eval_one_epoch(model, loader, device):
    """Evaluate ``model`` on ``loader`` without gradients.

    Batches are dicts with ``"image"`` and ``"keypoints"`` tensors. Metrics are
    averaged per sample, so a smaller final batch is weighted by its size.

    Returns
    -------
    dict
        ``loss`` (combined), ``reg_loss``, ``cos_loss`` and ``dist_px``.

    Raises
    ------
    ValueError
        If ``loader`` yields no samples.
    """
    model.eval()
    sums = {"loss": 0.0, "reg_loss": 0.0, "cos_loss": 0.0, "dist_px": 0.0}
    n_samples = 0

    pbar = tqdm(loader, desc="Val", leave=False)
    for batch in pbar:
        imgs = batch["image"].to(device, non_blocking=True)
        targets = batch["keypoints"].to(device, non_blocking=True)
        bs = imgs.shape[0]

        preds = model(imgs)
        loss, reg_l, cos_l = combined_loss(preds, targets)

        sums["loss"] += loss.item() * bs
        sums["reg_loss"] += reg_l * bs
        sums["cos_loss"] += cos_l * bs
        sums["dist_px"] += compute_l2_dist_px(preds, targets) * bs
        n_samples += bs

        pbar.set_postfix(loss=sums["loss"] / n_samples, dist=sums["dist_px"] / n_samples)

    if n_samples == 0:
        raise ValueError("eval_one_epoch: loader yielded no samples")
    return {key: total / n_samples for key, total in sums.items()}


# Per-epoch metric curves; the plotting cell further down reads this dict.
history = {
    "train_loss": [],
    "val_loss": [],
    "train_dist": [],
    "val_dist": [],
    "train_reg_loss": [],
    "val_reg_loss": [],
    "train_cos_loss": [],
    "val_cos_loss": []
}

CHECKPOINT_DIR = DATA_ROOT / "checkpoints_resnet_yolo"
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
best_ckpt_path = CHECKPOINT_DIR / "best_model.pth"

best_val_loss = float("inf")
epochs_no_improve = 0

for epoch in range(1, MAX_EPOCHS + 1):
    print(f"\nEpoch {epoch}/{MAX_EPOCHS} (lr={optimizer.param_groups[0]['lr']:.6f})")

    train_metrics = train_one_epoch(model, train_loader, optimizer, device)
    val_metrics   = eval_one_epoch(model, val_loader, device)

    # Epoch-level LR schedule
    scheduler.step()

    history["train_loss"].append(train_metrics["loss"])
    history["val_loss"].append(val_metrics["loss"])
    history["train_dist"].append(train_metrics["dist_px"])
    history["val_dist"].append(val_metrics["dist_px"])
    history["train_reg_loss"].append(train_metrics["reg_loss"])
    history["val_reg_loss"].append(val_metrics["reg_loss"])
    history["train_cos_loss"].append(train_metrics["cos_loss"])
    history["val_cos_loss"].append(val_metrics["cos_loss"])

    print(
        f"Train - loss: {train_metrics['loss']:.4f} | "
        f"reg: {train_metrics['reg_loss']:.4f} | "
        f"cos: {train_metrics['cos_loss']:.4f} | "
        f"L2 dist: {train_metrics['dist_px']:.2f} px"
    )
    print(
        f"Val   - loss: {val_metrics['loss']:.4f} | "
        f"reg: {val_metrics['reg_loss']:.4f} | "
        f"cos: {val_metrics['cos_loss']:.4f} | "
        f"L2 dist: {val_metrics['dist_px']:.2f} px"
    )

    # Save epoch checkpoint (full state, so training can be resumed)
    ckpt_path = CHECKPOINT_DIR / f"epoch_{epoch:03d}.pth"
    torch.save(
        {
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "history": history,
        },
        ckpt_path
    )

    # Early stopping: only a val-loss drop larger than 1e-4 counts as improvement.
    if val_metrics["loss"] < best_val_loss - 1e-4:
        best_val_loss = val_metrics["loss"]
        torch.save(model.state_dict(), best_ckpt_path)
        print(f"‚úÖ New best model (val loss {best_val_loss:.4f}) saved to {best_ckpt_path}")
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        print(f"No improvement for {epochs_no_improve} epoch(s).")

    if epochs_no_improve >= PATIENCE:
        print(f"‚èπ Early stopping triggered (patience = {PATIENCE}).")
        break

# The loop leaves the *last* epoch's weights in `model`. Restore the best
# checkpoint written by this run so later cells use the early-stopping winner.
if best_val_loss < float("inf") and best_ckpt_path.exists():
    model.load_state_dict(torch.load(best_ckpt_path, map_location=device))
    print(f"Restored best weights (val loss {best_val_loss:.4f}) from {best_ckpt_path}")



Epoch 1/120 (lr=0.001000)


Train:   0%|          | 0/52 [00:00<?, ?it/s]

In [None]:
# BLOCK F — UPDATED TRAINING LOOP FOR FAST DATASET (NO SEGMENTATION IN LOOP)
# NOTE(review): this cell redefines train_one_epoch / eval_one_epoch from
# BLOCK I with tuple return values; re-run BLOCK I before reusing its loop.

def compute_l2_dist(pred, target, scale=1.0):
    """Mean Euclidean distance per keypoint, optionally rescaled.

    pred, target: (B, 2*K) tensors flattened as (x1, y1, x2, y2, ...).
    scale: multiplier applied to the distance. For example, pass ``IMG_SIZE``
        to turn [0, 1]-normalised coordinates into pixels. The default 1.0
        reports distances in the coordinates' own units.
    Returns the mean over all B*K keypoints as a Python float.
    """
    batch = pred.shape[0]
    pred_xy = pred.reshape(batch, -1, 2)   # (B, K, 2)
    tgt_xy  = target.reshape(batch, -1, 2)

    dists = torch.linalg.norm(pred_xy - tgt_xy, dim=-1)  # (B, K)
    return dists.mean().item() * scale


def train_one_epoch(model, loader, optimizer, device):
    """One optimisation pass; returns ``(mean_loss, mean_l2_dist_px)`` averaged per batch.

    Batches may be ``(images, targets)`` pairs or dicts with ``"image"`` and
    ``"keypoints"`` entries (the latter is what FaceKptsDataset yields).
    Targets are assumed to be normalised to [0, 1], as FaceKptsDataset
    produces, so the distance is multiplied by ``IMG_SIZE`` to report pixels.

    Raises ValueError if ``loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_dist = 0.0
    n_batches = 0

    for batch in tqdm(loader, desc="Train", leave=False):
        # Unpacking a dict batch as `imgs, targets` would iterate over its keys.
        if isinstance(batch, dict):
            imgs, targets = batch["image"], batch["keypoints"]
        else:
            imgs, targets = batch[0], batch[1]
        imgs    = imgs.float().to(device)      # (B, 3, H, W)
        targets = targets.float().to(device)   # (B, 2*K)

        optimizer.zero_grad()

        preds = model(imgs)
        # `crit` is not defined in any visible cell; use the SmoothL1 criterion
        # `reg_crit` from the model cell (matches the "SmoothL1 loss" plot label).
        loss = reg_crit(preds, targets)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_dist += compute_l2_dist(preds.detach().cpu(), targets.detach().cpu()) * IMG_SIZE
        n_batches += 1

    if n_batches == 0:
        raise ValueError("train_one_epoch: loader yielded no batches")
    return total_loss / n_batches, total_dist / n_batches


@torch.no_grad()
def eval_one_epoch(model, loader, device):
    """Validation pass without gradients; returns ``(mean_loss, mean_l2_dist_px)`` per batch.

    Batches may be ``(images, targets)`` pairs or dicts with ``"image"`` and
    ``"keypoints"`` entries. The loss is the SmoothL1 criterion ``reg_crit``.
    Distances assume [0, 1]-normalised targets and are scaled by ``IMG_SIZE``
    to pixels.

    Raises ValueError if ``loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    total_dist = 0.0
    n_batches = 0

    for batch in tqdm(loader, desc="Val", leave=False):
        # Unpacking a dict batch as `imgs, targets` would iterate over its keys.
        if isinstance(batch, dict):
            imgs, targets = batch["image"], batch["keypoints"]
        else:
            imgs, targets = batch[0], batch[1]
        imgs    = imgs.float().to(device)
        targets = targets.float().to(device)

        preds = model(imgs)
        # `crit` is not defined in any visible cell; use the SmoothL1 `reg_crit`.
        loss = reg_crit(preds, targets)

        total_loss += loss.item()
        total_dist += compute_l2_dist(preds.detach().cpu(), targets.detach().cpu()) * IMG_SIZE
        n_batches += 1

    if n_batches == 0:
        raise ValueError("eval_one_epoch: loader yielded no batches")
    return total_loss / n_batches, total_dist / n_batches


# ---------- MAIN TRAINING LOOP ----------
# NOTE(review): this continues training whatever `model` / `optimizer` currently
# hold (e.g. after the BLOCK I loop) and replaces the global `history`.
num_epochs = 30

CHECKPOINT_DIR = DATA_ROOT / "checkpoints"
CHECKPOINT_DIR.mkdir(exist_ok=True)

history = {key: [] for key in ("train_loss", "val_loss", "train_dist", "val_dist")}
best_val_loss = float("inf")

print("üöÄ Training Started...")

for epoch in range(1, num_epochs + 1):
    print(f"\nüîµ Epoch {epoch}/{num_epochs}")

    train_loss, train_dist = train_one_epoch(model, train_loader, optimizer, device)
    val_loss, val_dist = eval_one_epoch(model, val_loader, device)

    for key, value in (
        ("train_loss", train_loss),
        ("val_loss", val_loss),
        ("train_dist", train_dist),
        ("val_dist", val_dist),
    ):
        history[key].append(value)

    print(f"Train ‚Üí loss: {train_loss:.4f}, L2 dist: {train_dist:.2f}px")
    print(f"Val   ‚Üí loss: {val_loss:.4f}, L2 dist: {val_dist:.2f}px")

    # Full per-epoch checkpoint (model + optimizer + curves so far)
    ckpt_path = CHECKPOINT_DIR / f"epoch_{epoch:03d}.pth"
    epoch_state = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "history": history,
    }
    torch.save(epoch_state, ckpt_path)

    # Keep only the weights of the lowest-val-loss epoch as "best".
    if not (val_loss < best_val_loss):
        continue
    best_val_loss = val_loss
    best_ckpt = CHECKPOINT_DIR / "best_model.pth"
    torch.save(model.state_dict(), best_ckpt)
    print(f"‚úÖ Best model updated ‚Üí {best_ckpt}")

print("\nüéâ Training Complete!")


üöÄ Training Started...

üîµ Epoch 1/30


Train:   0%|          | 0/33 [00:00<?, ?it/s]

In [None]:
# Metrics curves

epochs = range(1, len(history["train_loss"]) + 1)

fig, (ax_loss, ax_dist) = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: loss curves
ax_loss.plot(epochs, history["train_loss"], label="Train loss")
ax_loss.plot(epochs, history["val_loss"], label="Val loss")
ax_loss.set_xlabel("Epoch")
ax_loss.set_ylabel("SmoothL1 loss")
ax_loss.legend()
ax_loss.grid(True)

# Right panel: mean keypoint distance
ax_dist.plot(epochs, history["train_dist"], label="Train L2 dist")
ax_dist.plot(epochs, history["val_dist"], label="Val L2 dist")
ax_dist.set_xlabel("Epoch")
ax_dist.set_ylabel("Mean L2 distance (px)")
ax_dist.legend()
ax_dist.grid(True)

fig.tight_layout()
plt.show()


In [None]:
# BLOCK G: visualize predictions vs ground truth

@torch.no_grad()
def visualize_predictions(model, dataset, num_samples=6):
    """Plot ground truth (green dots) against predictions (red crosses).

    Shows ``num_samples`` random items of ``dataset`` in a 3-column grid.
    Images are de-normalised with the notebook-level ``mean`` / ``std`` used by
    the Albumentations ``Normalize`` step. The [0, 1] keypoints are scaled by
    ``IMG_SIZE`` so they land on image pixels. Ground-truth points at the
    (0, 0) "not labelled" sentinel are not drawn.
    """
    model.eval()
    if len(dataset) == 0:
        return
    idxs = np.random.choice(len(dataset), size=min(num_samples, len(dataset)), replace=False)

    num_cols = 3
    num_rows = int(np.ceil(len(idxs) / num_cols))
    fig, axes = plt.subplots(
        num_rows, num_cols, figsize=(5 * num_cols, 5 * num_rows), squeeze=False
    )

    # FaceKptsDataset stores no normalisation stats; reuse the transform's
    # mean/std, shaped to broadcast over an HxWx3 image.
    img_mean = np.asarray(mean, dtype=np.float32).reshape(1, 1, -1)
    img_std = np.asarray(std, dtype=np.float32).reshape(1, 1, -1)

    for ax, idx in zip(axes.flat, idxs):
        item = dataset[idx]
        img = item["image"].unsqueeze(0).to(device)
        # normalised [0, 1] -> pixel coordinates of the IMG_SIZE canvas
        gt_kpts = item["keypoints"].numpy().reshape(-1, 2) * IMG_SIZE
        pred = model(img).cpu().numpy().reshape(-1, 2) * IMG_SIZE

        # de-normalize image
        img_np = np.transpose(item["image"].numpy(), (1, 2, 0))
        img_np = np.clip(img_np * img_std + img_mean, 0, 1)

        labelled = np.any(gt_kpts != 0, axis=1)
        ax.imshow(img_np)
        ax.scatter(gt_kpts[labelled, 0], gt_kpts[labelled, 1], s=60, c="lime", label="GT")
        ax.scatter(pred[:, 0], pred[:, 1], s=60, c="red", marker="x", label="Pred")
        ax.set_title(f"idx={idx}")

    # hide axes everywhere, including unused grid cells
    for ax in axes.flat:
        ax.axis("off")

    # Only one legend for all subplots
    handles, labels = axes.flat[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="upper right")
    fig.tight_layout()
    plt.show()


# Example usage:
visualize_predictions(model, val_dataset, num_samples=9)



Epoch 1/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:32<00:00,  1.01it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:04<00:00,  1.25it/s]


Train Loss: 0.071767 | Val Loss: 0.025882
‚úî Saved Best Model

Epoch 2/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:20<00:00,  1.57it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:03<00:00,  1.92it/s]


Train Loss: 0.005927 | Val Loss: 0.004911
‚úî Saved Best Model

Epoch 3/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:20<00:00,  1.61it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:03<00:00,  1.89it/s]


Train Loss: 0.003482 | Val Loss: 0.002139
‚úî Saved Best Model

Epoch 4/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:20<00:00,  1.64it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:02<00:00,  2.01it/s]


Train Loss: 0.002995 | Val Loss: 0.002258

Epoch 5/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:20<00:00,  1.61it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:02<00:00,  2.09it/s]


Train Loss: 0.002508 | Val Loss: 0.001810
‚úî Saved Best Model

Epoch 6/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:19<00:00,  1.67it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:02<00:00,  2.13it/s]


Train Loss: 0.002799 | Val Loss: 0.003610

Epoch 7/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:19<00:00,  1.68it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:02<00:00,  2.13it/s]


Train Loss: 0.004267 | Val Loss: 0.001447
‚úî Saved Best Model

Epoch 8/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:19<00:00,  1.74it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:02<00:00,  2.12it/s]


Train Loss: 0.001663 | Val Loss: 0.001266
‚úî Saved Best Model

Epoch 9/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:19<00:00,  1.73it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:02<00:00,  2.12it/s]


Train Loss: 0.002088 | Val Loss: 0.001693

Epoch 10/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:18<00:00,  1.76it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:01<00:00,  3.45it/s]


Train Loss: 0.002363 | Val Loss: 0.001901

Epoch 11/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:11<00:00,  2.83it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:01<00:00,  3.59it/s]


Train Loss: 0.001508 | Val Loss: 0.000714
‚úî Saved Best Model

Epoch 12/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:11<00:00,  2.82it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:01<00:00,  3.56it/s]


Train Loss: 0.001022 | Val Loss: 0.000559
‚úî Saved Best Model

Epoch 13/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:11<00:00,  2.82it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:01<00:00,  3.62it/s]


Train Loss: 0.000840 | Val Loss: 0.000470
‚úî Saved Best Model

Epoch 14/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:11<00:00,  2.84it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:01<00:00,  3.60it/s]


Train Loss: 0.001033 | Val Loss: 0.000524

Epoch 15/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:11<00:00,  2.83it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:01<00:00,  3.62it/s]


Train Loss: 0.000912 | Val Loss: 0.000635

Epoch 16/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:11<00:00,  2.83it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:01<00:00,  3.58it/s]


Train Loss: 0.000967 | Val Loss: 0.000462
‚úî Saved Best Model

Epoch 17/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:11<00:00,  2.80it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:01<00:00,  3.55it/s]


Train Loss: 0.001071 | Val Loss: 0.000496

Epoch 18/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:11<00:00,  2.81it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:01<00:00,  3.58it/s]


Train Loss: 0.000877 | Val Loss: 0.000477

Epoch 19/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:11<00:00,  2.82it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:01<00:00,  3.53it/s]


Train Loss: 0.000831 | Val Loss: 0.000483

Epoch 20/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:11<00:00,  2.82it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:01<00:00,  3.53it/s]


Train Loss: 0.001067 | Val Loss: 0.000538

Epoch 21/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:11<00:00,  2.82it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:01<00:00,  3.60it/s]


Train Loss: 0.000827 | Val Loss: 0.000473

Epoch 22/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:11<00:00,  2.81it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:01<00:00,  3.58it/s]


Train Loss: 0.000816 | Val Loss: 0.000496

Epoch 23/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:11<00:00,  2.81it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:01<00:00,  3.65it/s]


Train Loss: 0.000788 | Val Loss: 0.000487

Epoch 24/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:11<00:00,  2.86it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:01<00:00,  3.65it/s]


Train Loss: 0.000812 | Val Loss: 0.000483

Epoch 25/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:11<00:00,  2.85it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:01<00:00,  3.65it/s]


Train Loss: 0.000778 | Val Loss: 0.000418
‚úî Saved Best Model

Epoch 26/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:11<00:00,  2.83it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:01<00:00,  3.57it/s]


Train Loss: 0.000666 | Val Loss: 0.000455

Epoch 27/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:11<00:00,  2.83it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:01<00:00,  3.63it/s]


Train Loss: 0.000724 | Val Loss: 0.000451

Epoch 28/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:12<00:00,  2.64it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:01<00:00,  3.34it/s]


Train Loss: 0.000906 | Val Loss: 0.000418

Epoch 29/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:12<00:00,  2.57it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:01<00:00,  3.26it/s]


Train Loss: 0.000850 | Val Loss: 0.000438

Epoch 30/30


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:11<00:00,  2.83it/s]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:01<00:00,  3.61it/s]

Train Loss: 0.000601 | Val Loss: 0.000430



