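# Introduction
This section will cover what knowledge distillation (KD) is, why it matters, what we can learn from it, and what members are expected to take away from this tutorial.
It will also list the requirements needed to follow along.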

## Dataset

In this tutorial, we will use the <a href="https://cocodataset.org/#keypoints-2017">COCO 2017 keypoint dataset</a> and use <a href="https://docs.voxel51.com/#where-to-begin">FiftyOne</a> to access and visualize the dataset.

In [55]:
import fiftyone as fo
import fiftyone.zoo as foz

In [56]:
print(fo)

<module 'fiftyone' from '/home/clinton-mwangi/Desktop/projects/2d-human-pose-estimation-using-kd/kd_hpe_env/lib/python3.12/site-packages/fiftyone/__init__.py'>


In [None]:
foz.list_zoo_datasets()

In [57]:
coco2017_train = foz.load_zoo_dataset("coco-2017", split='train',max_samples=30000,  label_types=["person_keypoints"])

2026-01-29 11:41:16,871 - INFO - Downloading split 'train' to '/home/clinton-mwangi/fiftyone/coco-2017/train' if necessary
2026-01-29 11:41:16,876 - INFO - Found annotations at '/home/clinton-mwangi/fiftyone/coco-2017/raw/instances_train2017.json'
2026-01-29 11:41:39,660 - INFO - Sufficient images already downloaded
2026-01-29 11:41:42,465 - INFO - Existing download of split 'train' is sufficient
2026-01-29 11:41:42,470 - INFO - Loading existing dataset 'coco-2017-train-30000'. To reload from disk, either delete the existing dataset or provide a custom `dataset_name` to use


In [5]:
coco2017_val = foz.load_zoo_dataset("coco-2017", split='validation')

Downloading split 'validation' to '/home/clinton-mwangi/fiftyone/coco-2017/validation' if necessary
Found annotations at '/home/clinton-mwangi/fiftyone/coco-2017/raw/instances_val2017.json'
Found 1000 (< 5000) downloaded images; must download full image zip
Downloading images to '/home/clinton-mwangi/fiftyone/coco-2017/tmp-download/val2017.zip'
 100% |██████|    6.1Gb/6.1Gb [49.1m elapsed, 0s remaining, 2.3Mb/s]       
Extracting images to '/home/clinton-mwangi/fiftyone/coco-2017/validation/data'
Writing annotations to '/home/clinton-mwangi/fiftyone/coco-2017/validation/labels.json'
Dataset info written to '/home/clinton-mwangi/fiftyone/coco-2017/info.json'
Loading 'coco-2017' split 'validation'
 100% |███████████████| 5000/5000 [23.5s elapsed, 0s remaining, 98.4 samples/s]       
Dataset 'coco-2017-validation' created


In [10]:
coco2017_val

Name:        coco-2017-validation-1000
Media type:  image
Num samples: 1000
Persistent:  False
Tags:        []
Sample fields:
    id:               fiftyone.core.fields.ObjectIdField
    filepath:         fiftyone.core.fields.StringField
    tags:             fiftyone.core.fields.ListField(fiftyone.core.fields.StringField)
    metadata:         fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.metadata.ImageMetadata)
    created_at:       fiftyone.core.fields.DateTimeField
    last_modified_at: fiftyone.core.fields.DateTimeField
    ground_truth:     fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.labels.Detections)

In [None]:
session = fo.launch_app(coco2017_train, auto=False)  # auto=False keeps the App from opening automatically in the cell output

In [None]:
session

In [None]:
session.url
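
COCO person annotations store keypoints as one flat list of `[x, y, v]` triplets, where `v` is 0 (not labeled), 1 (labeled but not visible), or 2 (labeled and visible). As a minimal sketch of that layout, the snippet below splits such a list into coordinates and visibility flags. The function name `split_keypoints` and the annotation values are made up for illustration and do not come from the dataset.

```python
from typing import List, Tuple


def split_keypoints(flat: List[float]) -> Tuple[List[Tuple[float, float]], List[int]]:
    """Split a flat COCO-style [x1, y1, v1, x2, y2, v2, ...] list."""
    if len(flat) % 3 != 0:
        raise ValueError("expected a multiple of 3 values (x, y, v per keypoint)")
    coords = [(float(flat[i]), float(flat[i + 1])) for i in range(0, len(flat), 3)]
    visibility = [int(flat[i + 2]) for i in range(0, len(flat), 3)]
    return coords, visibility


# Hypothetical three-keypoint annotation (real COCO person annotations have 17).
example = [120, 80, 2, 0, 0, 0, 131, 95, 1]
coords, visibility = split_keypoints(example)
num_visible = sum(v == 2 for v in visibility)
print(coords, visibility, num_visible)
```

Counting the entries with `v == 2` is the same test the filtering helper later in this notebook uses to decide whether a person is fully visible.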

## Distillation pipeline

### Teacher model

In this tutorial we will use the <a href="https://github.com/HRNet/HRNet-Human-Pose-Estimation?tab=readme-ov-file">HRNet pose model</a> as the teacher.

In [1]:
import torch
from hrnetpose_model import get_pose_net
from configs import load_configs

# 1. Load the YAML
cfg = load_configs('hrnet_w48_model_configs.yaml')

# 2. Initialize Model
teacher_model = get_pose_net(cfg, is_train=False)

# 3. Load Weights
checkpoint = torch.load('hrnet_pose_models/pose_hrnet_w48_384x288.pth', map_location='cpu')
teacher_model.load_state_dict(checkpoint)
teacher_model.eval()

print("Model successfully initialized from YAML!")


Model successfully initialized from YAML!


### Student model

In [2]:
import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
from collections import namedtuple

# Define the output structure
KeypointOutput = namedtuple('KeypointOutput', ['heatmaps'])

class SqueezeNetHPE(nn.Module):
    def __init__(self, num_keypoints=17):
        super().__init__()

        # Load SqueezeNet backbone (stride 16)
        squeezenet = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1)
        self.backbone = squeezenet.features

        # Simple decoder
        self.relu = nn.ReLU(inplace=True)
        
        # SqueezeNet 1.1 features output 512 channels
        self.conv_heatmap = nn.Conv2d(
            in_channels=512,
            out_channels=num_keypoints,
            kernel_size=3,
            padding=1
        )

    def forward(self, x):
        # 1. Feature extraction
        # Input: [B, 3, 384, 288] -> Output: [B, 512, 23, 17]
        x = self.backbone(x)

        # 2. Decoder
        x = self.relu(x)
        
        # 3. Explicit Upsampling 
        # Instead of scale_factor=4, we provide the exact target dimensions.
        # This fixes the "92x68" issue by forcing the output to 96x72.
        x = F.interpolate(
            x,
            size=(96, 72), 
            mode="bilinear",
            align_corners=False
        )

        # 4. Final heatmap convolution
        heatmaps = self.conv_heatmap(x)
        
        # Return as namedtuple
        return KeypointOutput(heatmaps=heatmaps)

# --- Initialization and Debug ---

# 1. Initialize the model
num_keypoints = 17
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
student_model = SqueezeNetHPE(num_keypoints=num_keypoints).to(device)


### Data processing

#### Helper functions

In [3]:

import os
import cv2
import logging
import numpy as np
import pandas as pd
from PIL import Image
import torch
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from pycocotools.coco import COCO  # type: ignore[import-untyped]
from typing import Any, Union, List, Tuple, Optional, Callable, cast, Dict, Iterator

# Module-level logger used by the helper functions below
logger = logging.getLogger(__name__)

def find_images_with_coco_keypoints(
    ann_file: str, img_root: str, verbose: bool = True
) -> Tuple[List[str], List[str]]:
    """
    Find images whose person annotations meet a keypoint-visibility criterion.

    Two lists of image file paths are built:
    - images with at least one person whose 17 keypoints are all visible (v == 2);
    - images with more than one person, each having at least 12 visible keypoints.

    Args:
        ann_file (str): Path to COCO annotation file.
        img_root (str): Path to image folder.
        verbose (bool): If True, prints the number of matches in each list.

    Returns:
        Tuple[List[str], List[str]]: The two lists of image file paths, in the
        order described above.
    """
    coco = COCO(ann_file)
    selected_with_17_images = []
    selected_with_12_more_images = []
    for img_id in coco.imgs.keys():
        ann_ids = coco.getAnnIds(imgIds=img_id, catIds=[1], iscrowd=None)
        anns = coco.loadAnns(ann_ids)
        keypoints_per_person = []
        for ann in anns:
            keypoints = ann["keypoints"]
            visibility_flags = keypoints[2::3]
            num_visible = sum([v == 2 for v in visibility_flags])
            keypoints_per_person.append(num_visible)
        # Condition 1: at least one person with all 17 keypoints visible
        cond1 = any(k == 17 for k in keypoints_per_person)
        # Condition 2: more than one person, AND each has at least 12 keypoints visible
        cond2 = len(keypoints_per_person) > 1 and all(
            k >= 12 for k in keypoints_per_person
        )
        if cond1:
            imginfo = coco.loadImgs(img_id)[0]
            file_name = imginfo["file_name"]
            img_path = os.path.join(img_root, file_name)
            selected_with_17_images.append(img_path)
        if cond2:
            imginfo = coco.loadImgs(img_id)[0]
            file_name = imginfo["file_name"]
            img_path = os.path.join(img_root, file_name)
            selected_with_12_more_images.append(img_path)
    if verbose:
        print(
            f"\nTotal images found: {len(selected_with_17_images)} with all 17 "
            f"keypoints visible, {len(selected_with_12_more_images)} with >= 12 "
            "visible keypoints per person"
        )
    return selected_with_17_images, selected_with_12_more_images


def extract_list_keypoints_and_visibility(
    keypoints: Union[List[float], List[List[float]]]
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """
    Extract keypoint coordinates and visibility flags from a
    single list or a list of keypoints lists.
    """

    def _process_single_keypoints(kps: List[float]) -> Tuple[np.ndarray, np.ndarray]:
        if len(kps) % 3 != 0:
            raise ValueError(f"Keypoints length must be multiple of 3, got {len(kps)}")
        coords = np.array(
            [[kps[i], kps[i + 1]] for i in range(0, len(kps), 3)], dtype=np.float32
        )
        visibility = np.array(
            [kps[i + 2] for i in range(0, len(kps), 3)], dtype=np.float32
        )
        return coords, visibility

    coords_list, visibility_list = [], []

    if not keypoints:
        raise ValueError("Input keypoints list is empty.")

    if isinstance(keypoints[0], (float, int)):
        kps = cast(List[float], keypoints)
        coords, visibility = _process_single_keypoints(kps)
        coords_list.append(coords)
        visibility_list.append(visibility)

    elif isinstance(keypoints[0], (list, tuple)):
        keypoints_list = cast(List[List[float]], keypoints)
        for kp in keypoints_list:
            coords, visibility = _process_single_keypoints(kp)
            coords_list.append(coords)
            visibility_list.append(visibility)
    else:
        raise TypeError("Input must be a list of floats or a list of lists of floats.")

    return coords_list, visibility_list


def extract_keypoints_and_visibility(
    keypoints: List[float],
) -> Tuple[List[List[float]], List[int]]:
    """
    Extracts (x, y) coordinates and visibility flags from a flat keypoints list.

    Args:
        keypoints (List[float]): A flat list structured as [x1, y1, v1, x2, y2, v2, ...].

    Returns:
        Tuple[List[List[float]], List[int]]:
            - coords: list of [x, y] pairs
            - visibility: list of visibility flags (0, 1, or 2)

    Raises:
        ValueError: If the keypoints list length is not a multiple of 3.
    """
    try:
        if len(keypoints) % 3 != 0:
            raise ValueError(
                "Keypoints list length must be a multiple of 3 (x, y, v per keypoint)."
            )

        coords: List[List[float]] = []
        visibility: List[int] = []

        for i in range(0, len(keypoints), 3):
            x = float(keypoints[i])
            y = float(keypoints[i + 1])
            v = int(keypoints[i + 2])
            coords.append([x, y])
            visibility.append(v)

        logger.debug(f"Extracted {len(coords)} keypoints successfully.")
        return coords, visibility

    except Exception as e:
        logger.error(f"Error extracting keypoints and visibility: {e}")
        raise


def process_image_and_keypoints(
    image: np.ndarray,
    keypoints: Union[List[List[float]], np.ndarray],
    bbox: Union[List[float], Tuple[float, float, float, float]],
    target_size: Tuple[int, int],
    angle: float = 0,
    flip: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Crop, resize, rotate, and optionally flip an image and its corresponding keypoints
    for pose estimation tasks.

    Steps:
        1. Crop the image using the bounding box and clamp to image bounds.
        2. Shift keypoints relative to the crop coordinates.
        3. Mask keypoints outside the crop.
        4. Resize the image and scale keypoints to the target size.
        5. Rotate image and keypoints if angle != 0.
        6. Flip image horizontally if flip=True.
        7. Restore [0, 0] for originally invisible or out-of-bounds keypoints.

    Args:
        image (np.ndarray): Input image of shape (H, W, C).
        keypoints (list or np.ndarray): Keypoints of shape (N, 2) or single keypoint (2,).
        bbox (list or tuple): Bounding box [x, y, w, h] for cropping.
        target_size (tuple): Desired output size (width, height) after resizing.
        angle (float, optional): Rotation angle in degrees. Default is 0.
        flip (bool, optional): Whether to horizontally flip the image. Default is False.

    Returns:
        Tuple[np.ndarray, np.ndarray]:
            - resized_image: Processed image of shape (target_height, target_width, C)
            - keypoints: Transformed keypoints of shape (N, 2)

    Raises:
        ValueError: If keypoints array is not in expected shape.
    """
    try:
        x1, y1, w, h = map(lambda x: int(round(x)), bbox)
        x2, y2 = x1 + w, y1 + h
        x1 = max(0, x1)
        y1 = max(0, y1)
        x2 = min(image.shape[1], x2)
        y2 = min(image.shape[0], y2)

        # Step 1: Cropping image
        cropped_image = image[y1:y2, x1:x2]

        # Step 2: Shifting keypoints
        kps: np.ndarray = np.asarray(keypoints, dtype=np.float32)
        if kps.ndim == 1 and kps.size == 2:
            kps = kps.reshape(1, 2)
        elif kps.ndim != 2 or kps.shape[1] != 2:
            raise ValueError(f"Keypoints must be of shape [N,2], got shape {kps.shape}")

        zero_mask = np.all(kps == 0, axis=1)
        kps -= np.array([x1, y1])

        # Step 3: Masking keypoints outside crop
        crop_h, crop_w = cropped_image.shape[:2]
        invalid_mask = (
            (kps[:, 0] < 0)
            | (kps[:, 1] < 0)
            | (kps[:, 0] >= crop_w)
            | (kps[:, 1] >= crop_h)
        )

        # Step 4: Resizing the image and scale keypoints
        resized_image = cv2.resize(cropped_image, target_size)
        scale_x = target_size[0] / crop_w
        scale_y = target_size[1] / crop_h
        kps *= np.array([scale_x, scale_y])

        # Step 5: Rotating image and keypoints
        if angle != 0:
            center = (target_size[0] // 2, target_size[1] // 2)
            rot_mat = cv2.getRotationMatrix2D(center, angle, 1.0)
            resized_image = cv2.warpAffine(resized_image, rot_mat, target_size)
            kps = (rot_mat[:, :2] @ kps.T + rot_mat[:, 2:]).T

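        # Step 6: Flipping horizontally when required
        # Note: left/right keypoint indices are not swapped here.
        if flip:
            resized_image = cv2.flip(resized_image, 1)
            # Pixel column c maps to (width - 1 - c) under a horizontal flip
            kps[:, 0] = (target_size[0] - 1) - kps[:, 0]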

        # Step 7: Restoring [0,0] for invisible or out-of-bounds keypoints
        kps[zero_mask | invalid_mask] = [0, 0]

        logger.debug(
            f"Processed image shape: {resized_image.shape}, keypoints shape: {kps.shape}"
        )
        return resized_image, kps

    except Exception as e:
        logger.error(f"Error in process_image_and_keypoints: {e}")
        raise e


def process_image(
    image: np.ndarray,
    bbox: list[float] | tuple[float, float, float, float],
    target_size: tuple[int, int],
    angle: float = 0,
    flip: bool = False,
) -> tuple[np.ndarray, bool]:
    """
    Processes an image for pose estimation inference using a bounding box.

    This function crops, resizes, and optionally rotates or flips an image
    based on a given bounding box, preparing it for pose estimation input.

    Args:
        image (np.ndarray): Input image in BGR format (as read by OpenCV).
        bbox (list[float] | tuple[float, float, float, float]): Bounding box given
            as corner coordinates [x1, y1, x2, y2]. A COCO-format
            [x, y, width, height] box must be converted before calling this function.
        target_size (tuple[int, int]): Target output size as (width, height).
        angle (float, optional): Rotation angle in degrees. Defaults to 0.
        flip (bool, optional): If True, flip the image horizontally. Defaults to False.

    Returns:
        tuple[np.ndarray, bool]:
            - Processed image (np.ndarray), or a blank image of the target size
              if processing failed.
            - Success flag (bool), True if processing succeeded, False otherwise.

    Note:
        Errors (for example an invalid bounding box) are caught and reported
        through the success flag rather than raised.
    """
    try:
        if image is None or not isinstance(image, np.ndarray):
            raise ValueError("Invalid image input: must be a numpy array.")

        # Convert bbox and clamp to image bounds
        x1, y1, x2, y2 = map(lambda x: int(round(x)), bbox)
        w = x2 - x1
        h = y2 - y1

        # Ensure bbox is within image dimensions
        if w <= 0 or h <= 0:
            raise ValueError(f"Invalid bbox dimensions: {bbox}")

        x1 = max(0, x1)
        y1 = max(0, y1)
        x2 = min(image.shape[1], x2)
        y2 = min(image.shape[0], y2)

        if x2 <= x1 or y2 <= y1:
            raise ValueError(
                f"Invalid bbox coordinates after clamping: {(x1, y1, x2, y2)}"
            )

        # Step 1: Crop image
        cropped_image = image[y1:y2, x1:x2]

        # Step 2: Resize image
        resized_image = cv2.resize(cropped_image, target_size)

        # Step 3: Rotate if needed
        if angle != 0:
            center = (target_size[0] // 2, target_size[1] // 2)
            rot_mat = cv2.getRotationMatrix2D(center, angle, 1.0)
            resized_image = cv2.warpAffine(resized_image, rot_mat, target_size)

        # Step 4: Flip horizontally if needed
        if flip:
            resized_image = cv2.flip(resized_image, 1)

        return resized_image, True

    except Exception as e:
        print(f"[process_image] Error processing image: {e}")
        blank_image = np.zeros((target_size[1], target_size[0], 3), dtype=np.uint8)
        return blank_image, False


def generate_heatmaps(
    keypoints: np.ndarray,
    keypoints_visible: np.ndarray,
    input_size: Tuple[int, int],
    heatmap_size: Tuple[int, int] = (64, 48),
    sigma: float = 2.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Generate Gaussian heatmaps centered at keypoints for pose estimation,
    and return a mask for visible/labeled keypoints.

    Args:
        keypoints (np.ndarray): Keypoint coordinates in input image space, shape (K, 2)
        keypoints_visible (np.ndarray): Visibility flags for each keypoint, shape (K,)
        input_size (Tuple[int, int]): Input image size (width, height)
        heatmap_size (Tuple[int, int], optional): Output heatmap size (width, height).
        Default is (64, 48)
        sigma (float, optional): Gaussian sigma controlling spread. Default is 2.0

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            - Heatmaps tensor of shape (K, H, W)
            - Mask tensor of shape (K,) with 1 for valid keypoints, 0 for missing/invisible
    """
    try:
        K, D = keypoints.shape
        if D != 2:
            raise ValueError(
                f"Keypoints should have shape (K, 2), got {keypoints.shape}"
            )

        W, H = heatmap_size
        w, h = input_size

        heatmaps = np.zeros((K, H, W), dtype=np.float32)
        mask = np.zeros((K,), dtype=np.float32)

        # Scale factor from input image to heatmap
        scale_factor = np.array([(W - 1) / (w - 1), (H - 1) / (h - 1)])

        for k in range(K):
            if keypoints_visible[k] < 0.5:
                continue

            x, y = keypoints[k]
            x_hm = x * scale_factor[0]
            y_hm = y * scale_factor[1]

            if not (0 <= x_hm < W and 0 <= y_hm < H):
                continue

            # Creating Gaussian heatmap
            x_grid = np.arange(W)
            y_grid = np.arange(H)
            xx, yy = np.meshgrid(x_grid, y_grid)
            gaussian = np.exp(-((xx - x_hm) ** 2 + (yy - y_hm) ** 2) / (2 * sigma**2))
            heatmaps[k] = gaussian
            mask[k] = 1.0

        heatmaps_tensor = torch.from_numpy(heatmaps)
        mask_tensor = torch.from_numpy(mask)

        return heatmaps_tensor, mask_tensor

    except Exception as e:
        print(f"Error in generate_heatmaps: {e}")
        raise e


#### Dataloader

In [4]:
import os
import torch
import logging
import numpy as np
import json
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from pycocotools.coco import COCO
from typing import Optional, Callable, Any, List, Tuple
from torch import Tensor

# Setting up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

def get_downloaded_image_ids(root_dir: str, ann_file: str) -> List[int]:
    """
    Checks the local directory and returns IDs of images that actually exist.
    """
    coco = COCO(ann_file)
    all_img_ids = coco.getImgIds()
    existing_ids = []
    
    # Fast check: see if the file exists on disk
    for img_id in all_img_ids:
        filename = coco.loadImgs(img_id)[0]['file_name']
        if os.path.exists(os.path.join(root_dir, filename)):
            existing_ids.append(img_id)
            
    return existing_ids

class CustomCocoKeypoints(Dataset):
    """
    Custom COCO keypoints dataset where each item is a single person annotation.
    Instead of 1 image = 1 sample, we use 1 person = 1 sample.
    """
    def __init__(
        self, root: str, annFile: str, transform: Optional[Callable] = None
    ) -> None:
        self.root = root
        self.transform = transform
        
        # Load COCO API
        self.coco = COCO(annFile)
        
        # Keep only images that exist on disk; fall back to every image ID in the annotation file
        try:
            downloaded_ids = get_downloaded_image_ids(self.root, annFile)
        except Exception as e:
            logger.warning(f"Could not filter downloaded images: {e}. Using all IDs from JSON.")
            downloaded_ids = list(self.coco.imgs.keys())

        logger.info("Mapping all valid person annotations...")
        
        # Store pairs of (image_id, annotation_dict)
        self.valid_anns = []
        for img_id in downloaded_ids:
            ann_ids = self.coco.getAnnIds(imgIds=img_id)
            anns = self.coco.loadAnns(ann_ids)
            
            # Filter for annotations that actually have keypoints
            for ann in anns:
                if ann.get("num_keypoints", 0) > 0:
                    # Every person becomes a distinct entry in the dataset
                    self.valid_anns.append((img_id, ann))
        
        logger.info(
            f"Dataset initialized: {len(self.valid_anns)} people found "
            f"across {len(downloaded_ids)} images."
        )

    def __len__(self) -> int:
        return len(self.valid_anns)

    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor, Tensor]:
        try:
            # 1. Get image ID and specific annotation for this index
            img_id, target = self.valid_anns[index]
            
            # 2. Load Image
            img_info = self.coco.loadImgs(img_id)[0]
            img_path = os.path.join(self.root, img_info["file_name"])
            img = Image.open(img_path).convert("RGB")
            num_image = np.array(img)

            # 3. Extract keypoints and bounding box for THIS specific person
            bbox = target["bbox"]
        
            
            keyp, visibility = extract_keypoints_and_visibility(target["keypoints"])
            
            # 4. Process (Crop to person bbox, resize, etc.)
            processed_img_np, processed_keypoints = process_image_and_keypoints(
                image=num_image,
                keypoints=keyp,
                bbox=bbox,
                target_size=(288, 384),
                angle=0,
                flip=False,
            )

            # 5. Convert Image to Tensor [3, H, W]
            processed_img: Tensor = (
                torch.from_numpy(processed_img_np).permute(2, 0, 1).float() / 255.0
            )

            # 6. Generate Heatmaps
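            # generate_heatmaps expects sizes as (width, height); the crop is
            # 288 wide and 384 high, so the heatmaps come out as (K, 96, 72)
            input_size = (288, 384)
            heatmap_size = (72, 96)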
            visibility_arr = np.asarray(visibility, dtype=np.float32)
            
            heatmaps_tensor, _mask_tensor = generate_heatmaps(
                keypoints=processed_keypoints,
                keypoints_visible=visibility_arr,
                input_size=input_size,
                heatmap_size=heatmap_size,
                sigma=2.0,
            )

            return processed_img, heatmaps_tensor, _mask_tensor

        except Exception as e:
            logger.error(f"Error processing index {index}: {e}")
            raise e

def collate_fn(
    batch: List[Tuple[Tensor, Tensor, Tensor]]
) -> Tuple[Tensor, Tensor, Tensor]:
    processed_img, heatmap_tensor, _mask_tensor = zip(*batch)
    return (
        torch.stack(processed_img),
        torch.stack(heatmap_tensor),
        torch.stack(_mask_tensor)
    )

def get_coco_dataloaders(
    root_train: str, ann_train: str,
    root_val: str, ann_val: str,
    batch_size: int = 32, num_workers: int = 4,
    pin_memory: bool = True, persistent_workers: bool = True,
) -> Tuple[DataLoader, DataLoader]:
    
    try:
        train_dataset = CustomCocoKeypoints(root=root_train, annFile=ann_train)
        val_dataset = CustomCocoKeypoints(root=root_val, annFile=ann_val)
        
        train_loader = DataLoader(
            dataset=train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
            collate_fn=collate_fn,
        )
        
        val_loader = DataLoader(
            dataset=val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
            collate_fn=collate_fn,
            drop_last=True,
        )
        
        logger.info("COCO DataLoaders created successfully.")
        return train_loader, val_loader
    except Exception as e:
        logger.error(f"Error creating COCO dataloaders: {e}")
        raise e

# --- Execution ---
root_train = "/home/clinton-mwangi/fiftyone/coco-2017/train/data"
root_val = "/home/clinton-mwangi/fiftyone/coco-2017/validation/data"
ann_train = "/home/clinton-mwangi/fiftyone/coco-2017/raw/person_keypoints_train2017.json"
ann_val = "/home/clinton-mwangi/fiftyone/coco-2017/raw/person_keypoints_val2017.json"

train_loader, val_loader = get_coco_dataloaders(
    root_train=root_train,
    ann_train=ann_train,
    root_val=root_val,
    ann_val=ann_val,
    batch_size=8,
    num_workers=4
)

loading annotations into memory...
Done (t=5.76s)
creating index...
index created!
loading annotations into memory...
Done (t=5.68s)
creating index...
index created!


2026-01-29 13:25:16,380 - INFO - Mapping all valid person annotations...
2026-01-29 13:25:16,500 - INFO - Dataset initialized: 38151 people found across 30000 images.


loading annotations into memory...
Done (t=0.23s)
creating index...
index created!
loading annotations into memory...


2026-01-29 13:25:17,031 - INFO - Mapping all valid person annotations...
2026-01-29 13:25:17,066 - INFO - Dataset initialized: 6352 people found across 5000 images.
2026-01-29 13:25:17,067 - INFO - COCO DataLoaders created successfully.


Done (t=0.22s)
creating index...
index created!


#### Load the data

In [5]:
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Module
class AdaptiveWingLoss(nn.Module):
    """
    Adaptive Wing Loss for keypoint heatmap regression.

    This loss is designed for heatmap-based keypoint estimation tasks,
    giving higher weight to small but important errors, and less weight
    to large errors that might correspond to outliers.

    Args:
        omega (float): Maximum weight for the loss. Default: 14
        theta (float): Threshold separating small and large errors. Default: 0.5
        epsilon (float): Small constant for numerical stability. Default: 1
        alpha (float): Exponent controlling sensitivity to small errors. Default: 2.1
        weight (float): Additional weight for keypoints above a threshold. Default: 10
    """

    def __init__(
        self,
        omega: float = 14,
        theta: float = 0.5,
        epsilon: float = 1.0,
        alpha: float = 2.1,
        weight: float = 10.0,
    ):
        super(AdaptiveWingLoss, self).__init__()
        self.omega = omega
        self.theta = theta
        self.epsilon = epsilon
        self.alpha = alpha
        self.weight = weight
        logger.info(
            f"AdaptiveWingLoss initialized with omega={omega}, "
            f"theta={theta}, epsilon={epsilon}, alpha={alpha}, weight={weight}"
        )

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute the Adaptive Wing Loss between predicted and target heatmaps.

        Args:
            pred (torch.Tensor): Predicted heatmaps of shape (B, N, H, W)
            target (torch.Tensor): Ground-truth heatmaps of shape (B, N, H, W)

        Returns:
            torch.Tensor: Scalar loss value
        """
        try:
            delta = torch.abs(pred - target)

            # Dilated heatmap using 3x3 max pooling
            Hd = F.max_pool2d(target, kernel_size=3, stride=1, padding=1)

            # Mask where dilated heatmap exceeds threshold
            M = (Hd >= 0.2).float()

            # Constants for the adaptive wing function
            A = (
                self.omega
                * (
                    1.0
                    / (1.0 + torch.pow(self.theta / self.epsilon, self.alpha - target))
                )
                * (self.alpha - target)
                * (torch.pow(self.theta / self.epsilon, self.alpha - target - 1))
                * (1.0 / self.epsilon)
            )

            C = self.theta * A - self.omega * torch.log(
                1.0 + torch.pow(self.theta / self.epsilon, self.alpha - target)
            )

            # Apply piecewise loss
            losses = torch.where(
                delta < self.theta,
                self.omega
                * torch.log(1.0 + torch.pow(delta / self.epsilon, self.alpha - target)),
                A * delta - C,
            )

            # Apply weighted loss map
            weighted_losses = losses * (self.weight * M + 1)

            loss_mean = weighted_losses.mean()
            logger.debug(
                f"AdaptiveWingLoss computed, mean loss: {loss_mean.item():.6f}"
            )

            return loss_mean

        except Exception as e:
            logger.error(f"Error in AdaptiveWingLoss forward pass: {e}")
            raise e
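
For reference, the piecewise function computed in `forward` can be written as follows, where $y$ is the target heatmap value, $\hat{y}$ the prediction, and $\delta = |y - \hat{y}|$. The symbols mirror the constructor arguments (`omega`, `theta`, `epsilon`, `alpha`). This is a transcription of the code above rather than an independent derivation.

$$
\mathrm{AWing}(y, \hat{y}) =
\begin{cases}
\omega \ln\!\left(1 + \left(\dfrac{\delta}{\epsilon}\right)^{\alpha - y}\right), & \delta < \theta \\[6pt]
A\,\delta - C, & \text{otherwise}
\end{cases}
$$

$$
A = \omega \,\frac{1}{1 + \left(\theta/\epsilon\right)^{\alpha - y}}\,(\alpha - y)\left(\frac{\theta}{\epsilon}\right)^{\alpha - y - 1}\frac{1}{\epsilon},
\qquad
C = \theta A - \omega \ln\!\left(1 + \left(\frac{\theta}{\epsilon}\right)^{\alpha - y}\right)
$$

With these choices, $A\theta - C = \omega \ln\!\left(1 + (\theta/\epsilon)^{\alpha - y}\right)$, so the two branches meet at $\delta = \theta$. The code then multiplies each pixel's loss by $(\texttt{weight} \cdot M + 1)$, where $M$ marks pixels whose 3x3-dilated target value is at least 0.2, and averages the result.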

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Training

In [28]:
# Grab a single batch to inspect tensor shapes
imgs, gt_heatmaps, mask = next(iter(train_loader))

In [29]:
with torch.no_grad():  # inference only; the teacher needs no gradients
    ht = teacher_model(imgs)

In [30]:
ht.shape

torch.Size([8, 17, 96, 72])

In [31]:
imgs.shape

torch.Size([8, 3, 384, 288])

In [32]:
st = student_model(imgs)

In [33]:
st.heatmaps.shape

torch.Size([8, 17, 96, 72])
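
Because the teacher and student emit heatmaps of identical shape, they can be compared element-wise. The sketch below shows one common way to blend a distillation term with a ground-truth term using a weight `alpha`. It uses NumPy and a masked mean-squared error purely for illustration. It is an assumption about the general idea, not the loss used by this tutorial's training function (which relies on `AdaptiveWingLoss`), and the function names are hypothetical.

```python
import numpy as np


def masked_mse(pred, target, mask):
    """Mean squared error over keypoints whose mask entry is 1.

    pred, target: arrays of shape (B, K, H, W); mask: array of shape (B, K).
    """
    per_keypoint = ((pred - target) ** 2).mean(axis=(2, 3))  # shape (B, K)
    valid = mask.sum()
    if valid == 0:
        return 0.0
    return float((per_keypoint * mask).sum() / valid)


def blended_distillation_loss(student, teacher, gt, mask, alpha=0.7):
    """alpha weights the teacher term, (1 - alpha) the ground-truth term."""
    # The teacher predicts every keypoint, so the distillation term is unmasked.
    all_keypoints = np.ones_like(mask)
    distill = masked_mse(student, teacher, all_keypoints)
    # The ground-truth term only counts keypoints that are actually labeled.
    supervised = masked_mse(student, gt, mask)
    return alpha * distill + (1.0 - alpha) * supervised


rng = np.random.default_rng(0)
student = rng.random((2, 17, 96, 72), dtype=np.float32)
teacher = rng.random((2, 17, 96, 72), dtype=np.float32)
gt = rng.random((2, 17, 96, 72), dtype=np.float32)
mask = np.ones((2, 17), dtype=np.float32)
print(blended_distillation_loss(student, teacher, gt, mask))
```

Masking only the ground-truth term is a design choice in this sketch: unlabeled keypoints have no target heatmap, while the teacher still provides a signal for them.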

In [None]:
import torch.optim as optim
from tqdm import tqdm  # type: ignore[import-untyped]
from torch.amp import GradScaler
from collections import namedtuple
from copy import deepcopy
import psutil
epochs = 1
epoch = 0  # single pass; defined so the progress-bar label below resolves
loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

for batch_idx, (imgs, gt_heatmaps, masks) in enumerate(loop):
    imgs = imgs.to(device, non_blocking=True)
    heatmaps = gt_heatmaps.to(device, non_blocking=True)
    masks = masks.to(device, non_blocking=True)

In [7]:
import torch.optim as optim
from tqdm import tqdm  # type: ignore[import-untyped]
from torch.amp import GradScaler
from collections import namedtuple
from copy import deepcopy
import psutil  # type: ignore[import-untyped]
def train_student_with_teacher(
    student_model: nn.Module,
    teacher_model: nn.Module,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    epochs: int = 3,
    lr: float = 5e-4,
    alpha: float = 0.7,
    weight_decay: float = 1e-5,
    val_frequency: int = 5,
    save_dir: str = "./models",
) -> tuple[nn.Module, dict]:
    """
    Train a student model using a teacher model for knowledge distillation.

    Args:
        student_model (nn.Module): Student model to train.
        teacher_model (nn.Module): Teacher model providing target heatmaps.
        train_loader (DataLoader): Training dataset loader.
        val_loader (DataLoader): Validation dataset loader.
        epochs (int, optional): Number of training epochs. Default: 3
        lr (float, optional): Learning rate. Default: 5e-4
        alpha (float, optional): Weight for distillation loss. Default: 0.7
        weight_decay (float, optional): Weight decay for optimizer. Default: 1e-5
        val_frequency (int, optional): Frequency (in epochs) to run validation. Default: 5
        save_dir (str, optional): Directory to save best model. Default: "./models"

    Returns:
        Tuple[nn.Module, dict]: Trained student model and training history
    """
    try:
        os.makedirs(save_dir, exist_ok=True)

        student_model.to(device)
        teacher_model.to(device)
        teacher_model.eval()

        for param in teacher_model.parameters():
            param.requires_grad = False

        optimizer = optim.AdamW(
            student_model.parameters(), lr=lr, weight_decay=weight_decay
        )
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[20, 27], gamma=0.1
        )
        criterion = AdaptiveWingLoss()

        scaler = GradScaler("cuda" if torch.cuda.is_available() else "cpu")

        history: Dict[str, List[float]] = {"train_loss": [], "val_loss": [], "lr": []}
        best_val_loss = float("inf")
        best_model_state = None

        for epoch in range(epochs):
            student_model.train()
            train_loss = 0.0
            loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

            for batch_idx, (imgs, gt_heatmaps, masks) in enumerate(loop):
                imgs = imgs.to(device, non_blocking=True)
                heatmaps = gt_heatmaps.to(device, non_blocking=True)
                masks = masks.to(device, non_blocking=True)
                optimizer.zero_grad(set_to_none=True)

                # Teacher forward pass: no gradients, same autocast context as the student
                with torch.no_grad(), torch.amp.autocast(device_type=device.type):
                    teacher_heatmaps = teacher_model(imgs)

                with torch.amp.autocast(device_type=device.type):
                    student_heatmaps = student_model(imgs).heatmaps
                    mask_expanded = masks.unsqueeze(-1).unsqueeze(-1)
                    student_loss = criterion(
                        student_heatmaps * mask_expanded, heatmaps * mask_expanded
                    )
                    distillation_loss = criterion(
                        student_heatmaps * mask_expanded,
                        teacher_heatmaps * mask_expanded,
                    )
                    total_loss = alpha * distillation_loss + (1 - alpha) * student_loss

                scaler.scale(total_loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(student_model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()

                train_loss += total_loss.item()
                loop.set_postfix({"loss": train_loss / (batch_idx + 1)})

            avg_train_loss = train_loss / len(train_loader)
            history["train_loss"].append(avg_train_loss)
            history["lr"].append(optimizer.param_groups[0]["lr"])
            scheduler.step()
            logger.info(f"Epoch {epoch+1} - Avg training loss: {avg_train_loss:.4f}")

            if (epoch + 1) % val_frequency == 0 or (epoch + 1) == epochs:
                val_loss = validate_student(
                    student_model, teacher_model, val_loader, criterion, alpha
                )
                history["val_loss"].append(val_loss)

                logger.info(f"Epoch {epoch+1} - Validation loss: {val_loss:.4f}")

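                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    # Unwrap DataParallel so the saved keys match the loading code below
                    model_to_save = (
                        student_model.module
                        if isinstance(student_model, nn.DataParallel)
                        else student_model
                    )
                    # state_dict() returns references to the live tensors, so copy it;
                    # otherwise later epochs would overwrite the "best" weights
                    best_model_state = deepcopy(model_to_save.state_dict())
                    torch.save(
                        best_model_state,
                        os.path.join(save_dir, "best_student_model.pth"),
                    )
                    logger.info(f"Saved new best model (val_loss={val_loss:.4f})")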

            # Save a checkpoint every 5 epochs and after the final epoch
            if (epoch + 1) % 5 == 0 or (epoch + 1) == epochs:
                model_to_save = (
                    student_model.module
                    if isinstance(student_model, nn.DataParallel)
                    else student_model
                )
                torch.save(
                    {
                        "epoch": epoch + 1,
                        "model_state_dict": model_to_save.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "loss": avg_train_loss,
                    },
                    os.path.join(save_dir, f"checkpoint_epoch_{epoch+1}.pth"),
                )

            torch.cuda.empty_cache()
            gc.collect()

        # Load best model
        if best_model_state is not None:
            model_to_load = (
                student_model.module
                if isinstance(student_model, nn.DataParallel)
                else student_model
            )
            model_to_load.load_state_dict(best_model_state)
            logger.info(f"Loaded best model with validation loss: {best_val_loss:.4f}")

        return student_model, history

    except Exception as e:
        logger.error(f"Error during training: {e}")
        raise
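
With the default `epochs=3`, the `MultiStepLR` milestones at 20 and 27 are never reached, so the learning rate stays at `lr` for the whole run. The sketch below reproduces the schedule in plain Python, assuming the scheduler is stepped once per epoch as in the loop above. `multistep_lr` is a hypothetical helper written for illustration, not a PyTorch function.

```python
from bisect import bisect_right


def multistep_lr(base_lr: float, milestones: list[int], gamma: float, epoch: int) -> float:
    """Learning rate in effect during 0-indexed `epoch` under a multi-step decay."""
    # Each milestone that has been reached multiplies the rate by gamma once.
    return base_lr * gamma ** bisect_right(sorted(milestones), epoch)


# Values taken from train_student_with_teacher: lr=5e-4, milestones=[20, 27], gamma=0.1
schedule = [multistep_lr(5e-4, [20, 27], 0.1, e) for e in range(30)]

for e in (0, 2, 19, 20, 27):
    print(f"epoch {e + 1:2d}: lr = {schedule[e]:.1e}")
```

The first decay only takes effect in the 21st epoch, so a run meant to use the schedule needs `epochs` above 20, or milestones scaled down to the planned run length.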

In [9]:
train_student_with_teacher(
    student_model=student_model,
    teacher_model=teacher_model,
    train_loader=train_loader,
    val_loader=val_loader,
)


2026-01-29 13:37:07,685 - INFO - AdaptiveWingLoss initialized with omega=14,              theta=0.5, epsilon=1.0, alpha=2.1, weight=10.0
Epoch 1/3:   0%|                                       | 0/4769 [00:00<?, ?it/s]
2026-01-29 13:37:07,697 - ERROR - Error during training: DataLoader worker (pid(s) 53512, 53513, 53514, 53515) exited unexpectedly


RuntimeError: DataLoader worker (pid(s) 53512, 53513, 53514, 53515) exited unexpectedly
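
The `RuntimeError` above only reports that the worker processes died, not why. A common first step is to take the workers out of the picture, either by building the loader with `num_workers=0` or by indexing the dataset directly in the main process, so that any exception raised in `__getitem__` shows up with its own message. A minimal sketch of the second approach follows; `find_failing_samples` and `ToyDataset` are hypothetical names used only for illustration.

```python
from typing import Any


def find_failing_samples(dataset: Any, max_checks: int = 100) -> list[tuple[int, str]]:
    """Index the dataset in the main process and collect the samples that raise."""
    failures = []
    for idx in range(min(len(dataset), max_checks)):
        try:
            dataset[idx]
        except Exception as exc:  # keep going so every bad sample is reported
            failures.append((idx, f"{type(exc).__name__}: {exc}"))
    return failures


class ToyDataset:
    """Stand-in dataset for illustration: sample 2 is deliberately broken."""

    def __len__(self) -> int:
        return 4

    def __getitem__(self, idx: int) -> int:
        if idx == 2:
            raise ValueError("corrupt annotation")
        return idx


print(find_failing_samples(ToyDataset()))
```

In this notebook the call would be `find_failing_samples(train_loader.dataset)`. If no sample fails, the cause may lie outside `__getitem__` (for example memory or shared-memory limits), which this check does not cover.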