In [None]:
%%capture
# @title Google colab install requirements.
# NOTE: `%%capture` is a cell magic, so it has to be the very first line of the cell;
# placed after a comment IPython rejects it with a UsageError.
import os
import sys

# Check if we are in Google Colab
if 'google.colab' in str(get_ipython()):
    print("Detected Google Colab. Cloning repository...")

    REPO_URL = "https://github.com/mwangi-clinton/2026-02-02-knowledge-distillation-for-2d-hpe.git"
    REPO_NAME = "2026-02-02-knowledge-distillation-for-2d-hpe"

    # Only clone / change directory when we are not already inside the repository:
    #   * re-running this cell must not create a nested clone, and
    #   * a runtime that already holds the clone must still end up inside it,
    #     otherwise the project-local imports below would not resolve.
    if os.path.basename(os.getcwd()) != REPO_NAME:
        if not os.path.exists(REPO_NAME):
            !git clone $REPO_URL
        %cd $REPO_NAME
        # %pip (rather than !pip) installs into the environment of the running kernel.
        %pip install -r requirements.txt

    # Make the repository root importable exactly once.
    repo_root = os.path.abspath(".")
    if repo_root not in sys.path:
        sys.path.append(repo_root)

# Knowledge Distillation for 2D Human Pose Estimation

## 1. Overview
Knowledge Distillation (KD) is a deep learning technique where a small, compact model (the **Student**) is trained to reproduce the behavior of a large, complex model (the **Teacher**). The core idea was popularized by Hinton et al. (2015) to transfer "knowledge" from a heavy ensemble or high-parameter model to a lightweight one suitable for edge devices.

In Human Pose Estimation (HPE):
* **Teacher:** High-performance models (e.g., HRNet-W48) that are accurate but computationally expensive.
* **Student:** Lightweight models (e.g., SqueezeNet or MobileNet) designed for real-time inference.



---

## 2. Representations in HPE: Gaussian Heatmaps
Modern HPE models typically predict **Gaussian Heatmaps** rather than direct $(x, y)$ coordinates. Each keypoint $k$ at position $(x_k, y_k)$ is represented as a probability map where the value at pixel $(i, j)$ is calculated as:

$$H_k(i, j) = \exp\left( -\frac{(i - x_k)^2 + (j - y_k)^2}{2\sigma^2} \right)$$

* **$\sigma$:** Controls the spread of the Gaussian peak.
* **Masks (Visibility):** A binary mask $M$ is used ($1$ if visible, $0$ if not). We multiply the loss by this mask to "switch off" learning for unlabelled or invisible joints.

---

## 3. The Role of Temperature ($T$)
The standard Softmax function often produces "sharp" distributions where one value nears $1.0$ and others near $0.0$:

$$\sigma(z_i) = \frac{\exp(z_i)}{\sum_j \exp(z_j)}$$

**The Problem:** If a Teacher is too confident (1.0 at peak), it provides no more information than the Ground Truth. We lose the **"Dark Knowledge"**—the subtle spatial relationships in the pixels surrounding the peak.

**The Solution:** By introducing a Temperature $T > 1$, we "soften" the distribution:

$$q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$

* **High T (> 1):** Flattens the peaks, spreading probability mass to neighboring pixels and revealing the Teacher's spatial uncertainty.
* **Low T (= 1):** Returns to the standard sharp Softmax.



---

## 4. Distillation Loss Formulation
To distill heatmaps, we treat them as **Spatial Probability Distributions**. We flatten the $H \times W$ dimensions and apply a Spatial Softmax with temperature.

### The KD Loss
The Kullback-Leibler (KL) Divergence between the softened student logits $z_s$ and teacher logits $z_t$ is scaled by $T^2$:

$$L_{KD} = T^2 \cdot KL\left( \sigma\left(\frac{z_t}{T}\right) \,\Big\|\, \sigma\left(\frac{z_s}{T}\right) \right)$$

### Total Training Objective
The student is trained using a weighted sum of the Ground Truth (GT) loss and the Distillation loss:

$$L_{total} = L_{GT} + \alpha \cdot L_{KD}$$

---

## 5. Implementation: Spatial KL Divergence
In practice, the loss is calculated per channel (keypoint) and then averaged.

**For a single channel $k$:**
$$D_{KL}(P_k || \hat{P}_k) = \sum_{x=1}^{W} \sum_{y=1}^{H} P_k(x, y) \log \left( \frac{P_k(x, y)}{\hat{P}_k(x, y) + \epsilon} \right)$$

**Variable Definitions:**
* $P_k(x, y)$: The teacher's probability at pixel $(x, y)$.
* $\hat{P}_k(x, y)$: The student's predicted probability at pixel $(x, y)$.
* $\epsilon$: A tiny constant ($10^{-8}$) to prevent $\log(0)$.

**Final Averaged Loss:**
$$L_{KL} = \frac{1}{K} \sum_{k=1}^{K} D_{KL}(P_k || \hat{P}_k)$$

In [7]:
import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
from collections import namedtuple
import numpy as np
from typing import List, Tuple, Union, Dict
import cv2
import logging
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from pycocotools.coco import COCO
from torch import Tensor
from typing import Optional, Callable
import os
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from tqdm.auto import tqdm
import math
import os
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import numpy as np
import cv2
import torch
from typing import List, Union, Tuple

#### Student model

We design a lightweight model with a <a href="https://arxiv.org/pdf/1602.07360">SqueezeNet backbone</a> pretrained on ImageNet to extract features from the input images. We simply remove the classification tail of the model and add a custom head for generating heatmaps.

In [26]:
# Lightweight container so that model outputs can be accessed as `output.heatmaps`.
KeypointOutput = namedtuple('KeypointOutput', ['heatmaps'])


class SqueezeNetHPE(nn.Module):
    """SqueezeNet 1.1 feature extractor followed by a convolutional heatmap head.

    Args:
        num_keypoints: Number of heatmap channels to predict (one per joint).
        heatmap_size: ``(height, width)`` of the predicted heatmaps. The backbone
            features are bilinearly resized to this size before the head is
            applied. The default ``(96, 72)`` is the value that used to be
            hard-coded in ``forward``.

    The attribute names ``backbone`` and ``conv_heatmap`` are unchanged, so
    previously saved ``state_dict`` checkpoints still load.
    """

    def __init__(self, num_keypoints=17, heatmap_size=(96, 72)):
        super().__init__()
        if len(heatmap_size) != 2:
            raise ValueError(f"heatmap_size must be (height, width), got {heatmap_size!r}")

        # ImageNet-pretrained SqueezeNet 1.1; only its convolutional features are kept.
        squeezenet = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1)
        self.backbone = squeezenet.features
        self.relu = nn.ReLU(inplace=True)
        self.heatmap_size = tuple(int(side) for side in heatmap_size)
        # 512 input channels match the channel count produced by `squeezenet.features`.
        self.conv_heatmap = nn.Conv2d(
            in_channels=512,
            out_channels=num_keypoints,
            kernel_size=3,
            padding=1
        )

    def forward(self, x):
        """Return ``KeypointOutput(heatmaps)`` with heatmaps of shape (B, K, H, W)."""
        x = self.backbone(x)
        x = self.relu(x)
        # Upsample the coarse backbone features to the heatmap resolution
        # before predicting one channel per keypoint.
        x = F.interpolate(
            x,
            size=self.heatmap_size,
            mode="bilinear",
            align_corners=False
        )
        heatmaps = self.conv_heatmap(x)

        return KeypointOutput(heatmaps=heatmaps)

1.  **Backbone (`self.backbone`)**: We use the feature extractor from `squeezenet1_1` pretrained on ImageNet. <br>
2.  **Upsampling**: SqueezeNet significantly downsamples the input (by roughly 16x for `squeezenet1_1`). To generate high-quality heatmaps (which require spatial precision), we use `F.interpolate` to upsample the features to **96x72** (height x width).
3.  **Heatmap Head (`self.conv_heatmap`)**: Instead of a Fully Connected layer for classification, we use a final **Convolutional Layer** (kernel size 3x3) to produce **17 output channels**, each corresponding to a keypoint heatmap.



#### Teacher model


In this tutorial we will use <a href="https://github.com/HRNet/HRNet-Human-Pose-Estimation?tab=readme-ov-file">HRNet-pose as the teacher model</a>.

**HRNet** (High-Resolution Net) maintains high-resolution representations throughout the network. Unlike traditional architectures (ResNet, VGG) that downsample the image and then upsample, HRNet connects high-to-low resolution subnetworks in parallel.


**Role as Teacher:**
HRNet-pose is a state-of-the-art CNN based model for 2D Pose Estimation. 

In [22]:
from hrnet_pose.hrnetpose_model import get_pose_net
from hrnet_pose.configs import load_configs

# Build the HRNet-W48 teacher network from its YAML configuration.
cfg = load_configs('hrnet_pose/hrnet_w48_model_configs.yaml')
teacher_model = get_pose_net(cfg, is_train=False)
# The checkpoint is deserialized onto the CPU and passed straight to
# `load_state_dict`, i.e. the file is expected to contain a plain state dict.
# NOTE(review): `torch.load` unpickles the file, so only load checkpoints from a trusted source.
checkpoint = torch.load('hrnet_pose/hrnet_pose_models/pose_hrnet_w48_384x288.pth', map_location='cpu')
teacher_model.load_state_dict(checkpoint)


<All keys matched successfully>

# 2. Dataset

## 2.1 COCO Keypoint Dataset
The COCO (Common Objects in Context) dataset is a large-scale object detection, segmentation, and captioning dataset. <br>
For Human Pose Estimation, we use the **Keypoints** task, which annotates **17 keypoints** (joints) on human bodies.

The 17 Keypoints are:
0: Nose, 1: Left Eye, 2: Right Eye, 3: Left Ear, 4: Right Ear, 5: Left Shoulder, 6: Right Shoulder, 7: Left Elbow, 8: Right Elbow, 9: Left Wrist, 10: Right Wrist, 11: Left Hip, 12: Right Hip, 13: Left Knee, 14: Right Knee, 15: Left Ankle, 16: Right Ankle. <br>

The keypoints are connected to form a skeleton structure, helping in visualizing pose.

![COCO Skeleton](images/keypoints-skeleton.png) ![COCO Examples](images/keypoints-examples.png)

*Example of COCO Keypoint Annotations (Source: COCO Dataset)*

Read more about COCO dataset <a href="https://cocodataset.org/#keypoints-2017">here</a>.

### 2.1.1 Get data using fiftyone library

[FiftyOne](https://voxel51.com/fiftyone/) is an open-source toolset by Voxel51 designed for building high-quality datasets and computer vision models. It acts as a visual interface for your datasets, allowing you to:

- **Visualize** raw images and their annotations (bounding boxes, keypoints, etc.) instantly.
- **Query** specific samples (e.g., "show me all images with > 10 people").
- **Debug** model predictions by overlaying them on ground truth.
- **Manage** data splits and export to various formats (COCO, YOLO, TFRecord).

In [1]:
import fiftyone as fo
import fiftyone.zoo as foz

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
# Download (if necessary) and load at most 3000 training samples of COCO-2017
# with keypoint labels through the FiftyOne dataset zoo.
train_data = foz.load_zoo_dataset("coco-2017", split='train',max_samples=3000,  label_types=["keypoints"])

Downloading split 'train' to '/home/clinton-mwangi/fiftyone/coco-2017/train' if necessary


2026-02-01 17:58:59,125 - INFO - Downloading split 'train' to '/home/clinton-mwangi/fiftyone/coco-2017/train' if necessary


Downloading annotations to '/home/clinton-mwangi/fiftyone/coco-2017/tmp-download/annotations_trainval2017.zip'


2026-02-01 17:58:59,127 - INFO - Downloading annotations to '/home/clinton-mwangi/fiftyone/coco-2017/tmp-download/annotations_trainval2017.zip'


 100% |██████|    1.9Gb/1.9Gb [12.7m elapsed, 0s remaining, 6.9Mb/s]        


2026-02-01 18:11:39,733 - INFO -  100% |██████|    1.9Gb/1.9Gb [12.7m elapsed, 0s remaining, 6.9Mb/s]        


Extracting annotations to '/home/clinton-mwangi/fiftyone/coco-2017/raw/instances_train2017.json'


2026-02-01 18:11:39,734 - INFO - Extracting annotations to '/home/clinton-mwangi/fiftyone/coco-2017/raw/instances_train2017.json'


Downloading 3000 images


2026-02-01 18:11:58,984 - INFO - Downloading 3000 images


 100% |████████████████| 3000/3000 [14.4m elapsed, 0s remaining, 3.7 images/s]      


2026-02-01 18:26:22,450 - INFO -  100% |████████████████| 3000/3000 [14.4m elapsed, 0s remaining, 3.7 images/s]      


Writing annotations for 3000 downloaded samples to '/home/clinton-mwangi/fiftyone/coco-2017/train/labels.json'


2026-02-01 18:26:22,467 - INFO - Writing annotations for 3000 downloaded samples to '/home/clinton-mwangi/fiftyone/coco-2017/train/labels.json'


Dataset info written to '/home/clinton-mwangi/fiftyone/coco-2017/info.json'


2026-02-01 18:26:25,342 - INFO - Dataset info written to '/home/clinton-mwangi/fiftyone/coco-2017/info.json'


Loading 'coco-2017' split 'train'


2026-02-01 18:26:25,396 - INFO - Loading 'coco-2017' split 'train'


 100% |███████████████| 3000/3000 [1.1s elapsed, 0s remaining, 2.8K samples/s]         


2026-02-01 18:26:27,051 - INFO -  100% |███████████████| 3000/3000 [1.1s elapsed, 0s remaining, 2.8K samples/s]         


Dataset 'coco-2017-train-3000' created


2026-02-01 18:26:27,060 - INFO - Dataset 'coco-2017-train-3000' created


In [11]:
# Same as above for the validation split, limited to 1000 samples.
val_data = foz.load_zoo_dataset("coco-2017", split="validation",  max_samples=1000, label_types=["keypoints"])

Downloading split 'validation' to '/home/clinton-mwangi/fiftyone/coco-2017/validation' if necessary


2026-02-01 18:26:27,091 - INFO - Downloading split 'validation' to '/home/clinton-mwangi/fiftyone/coco-2017/validation' if necessary


Found annotations at '/home/clinton-mwangi/fiftyone/coco-2017/raw/instances_val2017.json'


2026-02-01 18:26:27,093 - INFO - Found annotations at '/home/clinton-mwangi/fiftyone/coco-2017/raw/instances_val2017.json'


Downloading 1000 images


2026-02-01 18:26:27,911 - INFO - Downloading 1000 images


 100% |████████████████| 1000/1000 [4.4m elapsed, 0s remaining, 5.0 images/s]      


2026-02-01 18:30:54,124 - INFO -  100% |████████████████| 1000/1000 [4.4m elapsed, 0s remaining, 5.0 images/s]      


Writing annotations for 1000 downloaded samples to '/home/clinton-mwangi/fiftyone/coco-2017/validation/labels.json'


2026-02-01 18:30:54,142 - INFO - Writing annotations for 1000 downloaded samples to '/home/clinton-mwangi/fiftyone/coco-2017/validation/labels.json'


Dataset info written to '/home/clinton-mwangi/fiftyone/coco-2017/info.json'


2026-02-01 18:30:54,361 - INFO - Dataset info written to '/home/clinton-mwangi/fiftyone/coco-2017/info.json'


Loading existing dataset 'coco-2017-validation-1000'. To reload from disk, either delete the existing dataset or provide a custom `dataset_name` to use


2026-02-01 18:30:54,364 - INFO - Loading existing dataset 'coco-2017-validation-1000'. To reload from disk, either delete the existing dataset or provide a custom `dataset_name` to use


In [13]:
# Display the dataset summary (name, number of samples, sample fields).
train_data

Name:        coco-2017-train-3000
Media type:  image
Num samples: 3000
Persistent:  False
Tags:        []
Sample fields:
    id:               fiftyone.core.fields.ObjectIdField
    filepath:         fiftyone.core.fields.StringField
    tags:             fiftyone.core.fields.ListField(fiftyone.core.fields.StringField)
    metadata:         fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.metadata.ImageMetadata)
    created_at:       fiftyone.core.fields.DateTimeField
    last_modified_at: fiftyone.core.fields.DateTimeField
    ground_truth:     fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.labels.Keypoints)

In [14]:
# Launch the FiftyOne App on the validation samples. With auto=False the App is
# not displayed yet; it is opened explicitly with `session.show()`.
session = fo.launch_app(val_data, auto=False)


Session launched. Run `session.show()` to open the App in a cell output.


2026-02-01 18:37:10,064 - INFO - Session launched. Run `session.show()` to open the App in a cell output.


In [15]:
# Open the App launched above inside this cell's output.
session.show()

### 2.1.2. Get data by direct download

In [None]:
# Path to store the data
# NOTE: this is a placeholder; set it to a real, writable directory before running.
DATA_DIR="/path/to/store/data"

# Create the target directory (no error if it already exists).
!mkdir -p $DATA_DIR

# Download the annotations and both image archives into DATA_DIR.
# `-c` resumes partially downloaded files, `-P` selects the output directory.
!wget -c http://images.cocodataset.org/annotations/annotations_trainval2017.zip -P $DATA_DIR
!wget -c http://images.cocodataset.org/zips/val2017.zip -P $DATA_DIR
!wget -c http://images.cocodataset.org/zips/train2017.zip -P $DATA_DIR

In [None]:
# Extracting the files
# `-q` keeps unzip quiet, `-d` extracts into DATA_DIR.
!unzip -q $DATA_DIR/annotations_trainval2017.zip -d $DATA_DIR
!unzip -q $DATA_DIR/val2017.zip -d $DATA_DIR
!unzip -q $DATA_DIR/train2017.zip -d $DATA_DIR

# Cleanup zip files
# NOTE: this removes every *.zip directly inside DATA_DIR, so the archives
# would have to be downloaded again for a re-extraction.
!rm -rf $DATA_DIR/*.zip

## 2.2 Data Preprocessing

We follow the data preprocessing steps that are specific to human pose estimation, based on the paper <a href="https://arxiv.org/abs/1911.07524">The Devil is in the Details</a>.
1.  **Filtering**: We first check if an image actually contains a person with valid annotations. 
2.  **Cropping**: Once we identify a valid image, we locate the person using their bounding box and **crop** the image to center on them. 
3.  **Coordinate Transformation**: We transform the original keypoint annotations from the full image space into our new **cropped coordinate space**.
4.  **Ground Truth Generation**: Finally, we take these transformed keypoints and generate **Gaussian Heatmaps**, which serve as the training targets for our model.

In [16]:
class DataProcessor:
    """Prepare single-person crops and Gaussian heatmap targets for pose estimation.

    Conventions used throughout the class:
      * ``target_size`` and ``heatmap_size`` are ``(width, height)`` tuples.
      * Keypoints are ``(x, y)`` pixel coordinates, one row per joint.
      * A keypoint equal to ``(0, 0)`` is the marker for "unlabelled / unusable";
        such rows never produce a heatmap target.
    """

    def __init__(self, target_size=(288, 384), heatmap_size=(72, 96), sigma=2.0):
        self.target_size = target_size
        self.heatmap_size = heatmap_size
        self.sigma = sigma

        # Pre-compute the (H, W) coordinate grids reused for every heatmap.
        W, H = self.heatmap_size
        self.yy, self.xx = np.meshgrid(np.arange(H), np.arange(W), indexing='ij')

    @staticmethod
    def extract_keypoints_and_visibility(keypoints: List[float]):
        """Split a flat ``[x1, y1, v1, x2, y2, v2, ...]`` list.

        Returns:
            ``(coords, visibility)`` where ``coords`` is a float32 array of shape
            (K, 2) and ``visibility`` an int32 array of shape (K,).

        Raises:
            Whatever NumPy raises when the list cannot be reshaped to (K, 3);
            the error is printed and re-raised unchanged.
        """
        try:
            kps_array = np.array(keypoints, dtype=np.float32).reshape(-1, 3)
        except Exception as e:
            print(f"Error extracting keypoints: {e}")
            raise
        return kps_array[:, :2], kps_array[:, 2].astype(np.int32)

    def process_image_and_keypoints(
        self,
        image: np.ndarray,
        keypoints: np.ndarray,
        bbox: Union[List[int], Tuple[int, int, int, int]],
        angle: float = 0,
        flip: bool = False
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Crop the person, resize to ``target_size`` and transform the keypoints.

        Args:
            image: Image array of shape (H, W, C).
            keypoints: (K, 2) array of ``(x, y)`` coordinates in full-image space.
            bbox: ``(x, y, width, height)`` box of the person.
            angle: Rotation in degrees around the image centre (0 disables it).
            flip: Mirror the crop horizontally.

        Returns:
            ``(image, keypoints)``: the processed crop and the (K, 2) keypoints in
            crop coordinates. Keypoints that were ``(0, 0)`` on input or that fall
            outside the crop are returned as ``(0, 0)``.

        Raises:
            ValueError: If the box does not overlap the image (empty crop).
        """
        try:
            x1, y1, w, h = np.round(bbox).astype(int)
            img_h, img_w = image.shape[:2]

            # Step 1: Safe crop. Boxes may stick out of the image (annotation noise,
            # rounding); a negative start index would silently wrap around in NumPy
            # slicing, so the box is clamped to the image bounds first.
            cx1, cy1 = max(0, x1), max(0, y1)
            cx2, cy2 = min(img_w, x1 + w), min(img_h, y1 + h)
            if cx2 <= cx1 or cy2 <= cy1:
                # An empty crop would otherwise surface as a division by zero
                # or an opaque OpenCV resize failure further down.
                raise ValueError(
                    f"Bounding box {list(bbox)} yields an empty crop for a {img_w}x{img_h} image"
                )
            image = image[cy1:cy2, cx1:cx2]
            crop_h, crop_w = image.shape[:2]

            # Step 2: Move keypoints into crop coordinates, remembering which
            # rows were the (0, 0) "unlabelled" marker before the shift.
            kps = keypoints.copy().astype(np.float32)
            zero_mask = np.all(kps == 0, axis=1)
            kps -= [cx1, cy1]

            # Step 3: Boundary masking (vectorized): keypoints outside the crop.
            invalid_mask = (kps[:, 0] < 0) | (kps[:, 1] < 0) | (kps[:, 0] >= crop_w) | (kps[:, 1] >= crop_h)

            # Step 4: Resize the crop and scale the keypoints by the same factors.
            image = cv2.resize(image, self.target_size, interpolation=cv2.INTER_LINEAR)
            kps *= [self.target_size[0] / crop_w, self.target_size[1] / crop_h]

            # Step 5: Rotation around the image centre, applied to image and keypoints.
            if angle != 0:
                center = (self.target_size[0] / 2, self.target_size[1] / 2)
                M = cv2.getRotationMatrix2D(center, angle, 1.0)
                image = cv2.warpAffine(image, M, self.target_size)
                ones = np.ones((kps.shape[0], 1))
                kps_homo = np.hstack([kps, ones])
                # Cast back so the returned dtype does not depend on `angle`.
                kps = (kps_homo @ M.T).astype(np.float32)

            # Step 6: Horizontal flip. Pixel column x moves to (W - 1 - x).
            # NOTE(review): left/right joints are not swapped here because the
            # visibility flags live outside this method; callers enabling `flip`
            # need to swap the paired joints (and their visibility) themselves.
            if flip:
                image = cv2.flip(image, 1)
                kps[:, 0] = (self.target_size[0] - 1) - kps[:, 0]

            # Step 7: Restore invalid points to zero
            kps[zero_mask | invalid_mask] = 0

            return image, kps
        except Exception as e:
            print(f"Error in processing: {e}")
            raise

    def generate_heatmaps(
        self,
        keypoints: np.ndarray,
        keypoints_visible: np.ndarray
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Render one Gaussian heatmap per usable keypoint.

        Args:
            keypoints: (K, 2) ``(x, y)`` coordinates in ``target_size`` space.
            keypoints_visible: (K,) visibility flags; values >= 0.5 count as labelled.

        Returns:
            ``(heatmaps, mask)``: a float32 tensor of shape (K, H_hm, W_hm) and a
            float32 tensor of shape (K,) that is 1 for joints with a target.
            A joint gets a target only if it is labelled, is not the ``(0, 0)``
            placeholder and lies inside the heatmap.
        """
        keypoints = np.asarray(keypoints)
        W_hm, H_hm = self.heatmap_size
        W_img, H_img = self.target_size
        kps_hm = keypoints * [W_hm / W_img, H_hm / H_img]

        # `process_image_and_keypoints` zeroes keypoints that fall outside the crop
        # but cannot touch their visibility flag. Without this check such joints
        # would be rendered as a bogus Gaussian in the top-left corner.
        is_placeholder = np.all(keypoints == 0, axis=1)
        mask = (np.asarray(keypoints_visible) >= 0.5) & ~is_placeholder & \
               (kps_hm[:, 0] >= 0) & (kps_hm[:, 0] < W_hm) & \
               (kps_hm[:, 1] >= 0) & (kps_hm[:, 1] < H_hm)

        num_kps = kps_hm.shape[0]
        heatmaps = np.zeros((num_kps, H_hm, W_hm), dtype=np.float32)
        for i in np.flatnonzero(mask):
            mu_x, mu_y = kps_hm[i]
            dist_sq = (self.xx - mu_x) ** 2 + (self.yy - mu_y) ** 2
            heatmaps[i] = np.exp(-dist_sq / (2 * self.sigma**2))

        return torch.from_numpy(heatmaps), torch.from_numpy(mask.astype(np.float32))

### Dataloader

In [17]:
# Emit INFO-level records prefixed with a timestamp and the level name.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
# Logger shared by the dataset code defined in this cell.
logger = logging.getLogger(__name__)


def get_downloaded_image_ids(root_dir: str, coco: COCO) -> List[int]:
    """Return the ids of all annotated images whose file is present in ``root_dir``.

    Args:
        root_dir: Directory that is listed to find the available image files.
        coco: Annotation index providing ``getImgIds`` and ``loadImgs``.

    Returns:
        Image ids, in the order reported by ``coco``, whose ``file_name``
        appears in ``root_dir``.
    """
    available_files = set(os.listdir(root_dir))
    image_ids = []
    for record in coco.loadImgs(coco.getImgIds()):
        if record['file_name'] in available_files:
            image_ids.append(record['id'])
    return image_ids

class CocoKeypoints(Dataset):
    """Per-person COCO keypoint dataset.

    Every annotated person with at least one labelled keypoint becomes one
    sample; images missing from ``root`` are skipped.

    Each item is a 4-tuple ``(image, heatmaps, mask, keypoints)``:
      * ``image``: float tensor (3, H, W) scaled to [0, 1], RGB order,
      * ``heatmaps``: float tensor (K, H_hm, W_hm),
      * ``mask``: float tensor (K,), 1 where a heatmap target exists,
      * ``keypoints``: tensor (K, 2) of ``(x, y)`` in the resized crop.

    ``target_size`` and ``heatmap_size`` are ``(width, height)``.
    """

    def __init__(self, root: str, annFile: str, target_size=(288, 384), heatmap_size=(72, 96), sigma=2.0) -> None:
        self.root = root
        self.coco = COCO(annFile)
        self.processor = DataProcessor(
            target_size=target_size,
            heatmap_size=heatmap_size,
            sigma=sigma
        )

        # Only index images that were actually downloaded to `root`.
        on_disk = set(os.listdir(root))
        all_imgs = self.coco.loadImgs(self.coco.getImgIds())
        img_id_to_file = {img['id']: img['file_name'] for img in all_imgs if img['file_name'] in on_disk}

        self.samples = []
        for img_id, file_name in img_id_to_file.items():
            for ann in self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id)):
                if ann.get("num_keypoints", 0) <= 0:
                    continue
                # A box that rounds to zero width/height would produce an empty
                # crop and crash a DataLoader worker in the middle of training.
                box_w, box_h = np.round(ann["bbox"][2:4]).astype(int)
                if box_w <= 0 or box_h <= 0:
                    continue
                self.samples.append({
                    "file_name": file_name,
                    "keypoints": ann["keypoints"],
                    "bbox": ann["bbox"]
                })

        logger.info(f"Initialized {len(self.samples)} samples.")

    def __len__(self) -> int:
        return len(self.samples)

    def _empty_sample(self, num_keypoints: int):
        """Build an all-zero sample with the same shapes as a regular item."""
        img_w, img_h = self.processor.target_size
        hm_w, hm_h = self.processor.heatmap_size
        return (
            torch.zeros((3, img_h, img_w)),
            torch.zeros((num_keypoints, hm_h, hm_w)),
            torch.zeros(num_keypoints),
            torch.zeros((num_keypoints, 2)),
        )

    def __getitem__(self, index: int):
        sample = self.samples[index]
        img_path = os.path.join(self.root, sample["file_name"])
        image = cv2.imread(img_path)
        if image is None:
            # Unreadable file: return a fully masked sample (mask == 0 switches
            # off the loss) that still collates with regular items.
            logger.warning(f"Could not read image {img_path}; returning an empty sample.")
            return self._empty_sample(len(sample["keypoints"]) // 3)

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        coords, visibility = self.processor.extract_keypoints_and_visibility(sample["keypoints"])
        processed_img_np, processed_kps = self.processor.process_image_and_keypoints(
            image,
            coords,
            sample["bbox"]
        )
        heatmaps, masks = self.processor.generate_heatmaps(processed_kps, visibility)
        # HWC uint8 -> CHW float in [0, 1]
        img_tensor = torch.from_numpy(processed_img_np.copy()).permute(2, 0, 1).float().div(255.0)
        coords_tensors = torch.from_numpy(processed_kps.copy())
        return img_tensor, heatmaps, masks, coords_tensors

def collate_fn(batch):
    """Combine a list of dataset samples into one batch with PyTorch's default collation."""
    default_collate = torch.utils.data.dataloader.default_collate
    batched = default_collate(batch)
    return batched

def get_coco_dataloaders(root_train, ann_train, root_val, ann_val, batch_size=32, num_workers=8):
    """Create the training and validation ``DataLoader``s for COCO keypoints.

    Args:
        root_train, root_val: Directories holding the training / validation images.
        ann_train, ann_val: Paths of the matching keypoint annotation files.
        batch_size: Number of person crops per batch.
        num_workers: Number of loader worker processes (0 loads in the main process).

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    train_dataset = CocoKeypoints(root_train, ann_train)
    val_dataset = CocoKeypoints(root_val, ann_val)
    loader_args = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        # Pinned host memory only helps host-to-GPU copies.
        pin_memory=torch.cuda.is_available(),
        collate_fn=collate_fn,
        # DataLoader rejects persistent_workers=True when there are no workers.
        persistent_workers=num_workers > 0,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_args)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_args)
    return train_loader, val_loader

In [18]:
# Directory in which FiftyOne stores datasets downloaded from its zoo.
zoo_dir = fo.config.dataset_zoo_dir
zoo_dir

'/home/clinton-mwangi/fiftyone'

In [19]:
# FiftyOne downloads zoo datasets below `fo.config.dataset_zoo_dir`
# (a folder in the home directory by default).

# Image folders of the two downloaded splits.
coco_root = os.path.join(fo.config.dataset_zoo_dir, "coco-2017")
root_train = os.path.join(coco_root, "train", "data")
root_val = os.path.join(coco_root, "validation", "data")

# Keypoint annotation files, read from the "raw" folder of the download.
ann_train = os.path.join(coco_root, "raw", "person_keypoints_train2017.json")
ann_val = os.path.join(coco_root, "raw", "person_keypoints_val2017.json")

# Build both loaders; only images present in the image folders are indexed.
train_loader, val_loader = get_coco_dataloaders(
    root_train=root_train,
    ann_train=ann_train,
    root_val=root_val,
    ann_val=ann_val,
    batch_size=30,
    num_workers=4
)

loading annotations into memory...
Done (t=4.91s)
creating index...


2026-02-01 18:41:31,248 - INFO - Initialized 3854 samples.


index created!
loading annotations into memory...


2026-02-01 18:41:31,518 - INFO - Initialized 1275 samples.


Done (t=0.22s)
creating index...
index created!


# 3. Training and distillation

## 3.1 Baseline Student Training

We start by training the **Student Model (SqueezeNet)** purely on the Ground Truth labels of the sample data, **without any distillation**.


In [23]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Student for the baseline run (trained without a teacher).
student_model = SqueezeNetHPE(num_keypoints=17).to(device)

In [None]:
class BaselineTrainer:
    """Train a heatmap model on ground-truth heatmaps only (no distillation).

    The loss is a spatial KL divergence: predictions are turned into a
    distribution over pixels with a softmax, targets are normalised to sum to
    one, and the per-keypoint KL values are averaged over the keypoints whose
    mask is 1.

    Args:
        model: Network returning heatmaps of shape (B, K, H, W), either directly
            or as an object with a ``heatmaps`` attribute.
        device: Device used for the model and the batches.
        checkpoint_dir: Directory (created if missing) for the best checkpoint.
        lr: Peak learning rate of AdamW.
        num_epochs: Number of training epochs.
    """

    def __init__(self, model, device,
                 checkpoint_dir='checkpoints_baseline',
                 lr=1e-3, num_epochs=10):
        self.device = device
        self.model = model.to(device)
        self.num_epochs = num_epochs
        self.checkpoint_dir = checkpoint_dir
        self.best_val_loss = float('inf')

        os.makedirs(self.checkpoint_dir, exist_ok=True)
        self.criterion = nn.KLDivLoss(reduction='none').to(device)
        self.optimizer = optim.AdamW(self.model.parameters(), lr=lr, weight_decay=1e-4)

        # Scheduler: 5 epoch linear warmup followed by cosine decay
        lambda_lr = lambda e: (e + 1) / 5 if e < 5 else 0.5 * (1 + math.cos(math.pi * (e - 5) / (max(1, num_epochs - 5))))
        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda_lr)

        # Mixed precision only exists on CUDA; requesting it on the CPU merely
        # triggers warnings before being disabled, so it is gated explicitly.
        self.use_amp = torch.device(device).type == 'cuda' and torch.cuda.is_available()
        self.scaler = GradScaler(enabled=self.use_amp)

    @staticmethod
    def _heatmaps(outputs):
        """Unwrap model outputs that expose their prediction as ``.heatmaps``."""
        return outputs.heatmaps if hasattr(outputs, 'heatmaps') else outputs

    def _masked_kl_loss(self, preds, targets, masks):
        """Spatial KL divergence between target and predicted heatmaps.

        Shared by training and validation so both always measure the same quantity.
        """
        B, K = preds.shape[:2]
        flat_preds = preds.reshape(B, K, -1)
        flat_targets = targets.reshape(B, K, -1)

        log_probs = F.log_softmax(flat_preds, dim=-1)
        # Normalise each target heatmap into a distribution over pixels.
        target_probs = flat_targets / (flat_targets.sum(dim=-1, keepdim=True) + 1e-8)

        loss_per_keypoint = self.criterion(log_probs, target_probs).sum(dim=-1)
        # Average over labelled keypoints only; the epsilon guards all-masked batches.
        return (loss_per_keypoint * masks).sum() / (masks.sum() + 1e-8)

    def train(self, train_loader, val_loader):
        """Run the training loop; validation and checkpointing start at epoch 5."""
        print(f"--- Starting Training ---")
        for epoch in range(self.num_epochs):
            self.model.train()
            pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{self.num_epochs}')

            for imgs, targets, masks, _ in pbar:
                imgs = imgs.to(self.device)
                targets = targets.to(self.device)
                masks = masks.to(self.device)

                self.optimizer.zero_grad(set_to_none=True)
                with autocast(enabled=self.use_amp):
                    preds = self._heatmaps(self.model(imgs))
                    loss = self._masked_kl_loss(preds, targets, masks)

                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()

                pbar.set_postfix({'loss': f"{loss.item():.4f}"})

            self.scheduler.step()
            if (epoch + 1) >= 5:
                val_loss = self.validate(val_loader)
                print(f"Epoch {epoch+1} | Val Loss: {val_loss:.6f}")

                if val_loss < self.best_val_loss:
                    self.best_val_loss = val_loss
                    save_path = os.path.join(self.checkpoint_dir, 'best_baseline_kl.pth')
                    torch.save(self.model.state_dict(), save_path)
                    print(f"New best model saved to {save_path}")
            else:
                print(f"Epoch {epoch+1} completed.")

    def validate(self, val_loader):
        """Return the mean per-batch validation loss.

        Returns ``nan`` when ``val_loader`` yields no batch, so an empty loader
        neither crashes training nor is mistaken for a new best result.
        """
        self.model.eval()
        batch_losses = []
        with torch.no_grad():
            for imgs, targets, masks, _ in val_loader:
                imgs = imgs.to(self.device)
                targets = targets.to(self.device)
                masks = masks.to(self.device)

                preds = self._heatmaps(self.model(imgs))
                batch_losses.append(self._masked_kl_loss(preds, targets, masks).item())

        if not batch_losses:
            return float('nan')
        return sum(batch_losses) / len(batch_losses)

In [None]:
# Baseline run: train the student on ground-truth heatmaps only.
# The best checkpoint is written below `experiments/run_01`.
trainer = BaselineTrainer(
    model=student_model,           
    device=device,
    checkpoint_dir='experiments/run_01', 
    lr=1e-3,
    num_epochs=50              
)
trainer.train(train_loader, val_loader)

## 3.2. Logits-based Knowledge Distillation

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import os
import math

class LogitDistillationTrainer:
    """Train a student on ground truth plus softened teacher heatmaps.

    Total loss: ``L_GT + alpha_logit * L_KD`` where ``L_GT`` is the spatial KL
    divergence to the normalised ground-truth heatmaps and ``L_KD`` the
    temperature-softened spatial KL divergence to the teacher, scaled by ``T^2``.

    Args:
        student_model: Trainable network returning heatmaps (B, K, H, W),
            directly or through a ``heatmaps`` attribute.
        teacher_model: Frozen reference network; put in eval mode and only
            evaluated under ``torch.no_grad()``.
        device: Device used for both models and the batches.
        checkpoint_dir: Directory (created if missing) for the best checkpoint.
        lr: Peak learning rate of AdamW.
        num_epochs: Number of training epochs.
        temp: Softmax temperature ``T`` of the distillation term.
        alpha_logit: Weight of the distillation term.

    NOTE(review): student and teacher receive the very same image tensor; if the
    teacher checkpoint expects differently normalised inputs this needs to be
    handled before distilling - confirm against the teacher's preprocessing.
    """

    def __init__(self, student_model, teacher_model, device,
                 checkpoint_dir='checkpoints_distillation',
                 lr=1e-3, num_epochs=10, temp=4.0, alpha_logit=1.0):
        self.device = device
        self.student = student_model.to(device)
        self.teacher = teacher_model.to(device).eval()
        self.num_epochs = num_epochs
        self.checkpoint_dir = checkpoint_dir

        # Distillation Hyperparameters
        self.temp = temp
        self.alpha_logit = alpha_logit
        self.best_val_loss = float('inf')

        os.makedirs(self.checkpoint_dir, exist_ok=True)

        self.kl_criterion = nn.KLDivLoss(reduction='none').to(device)

        self.optimizer = optim.AdamW(self.student.parameters(), lr=lr, weight_decay=1e-4)

        # Scheduler: 5 epoch linear warmup followed by cosine decay
        lambda_lr = lambda e: (e + 1) / 5 if e < 5 else 0.5 * (1 + math.cos(math.pi * (e - 5) / (max(1, num_epochs - 5))))
        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda_lr)

        # Mixed precision only exists on CUDA; requesting it on the CPU merely
        # triggers warnings before being disabled, so it is gated explicitly.
        self.use_amp = torch.device(device).type == 'cuda' and torch.cuda.is_available()
        self.scaler = GradScaler(enabled=self.use_amp)

    @staticmethod
    def _heatmaps(outputs):
        """Unwrap model outputs that expose their prediction as ``.heatmaps``."""
        return outputs.heatmaps if hasattr(outputs, 'heatmaps') else outputs

    def _gt_loss(self, preds, targets, masks):
        """Spatial KL divergence between ground-truth and predicted heatmaps."""
        B, K = preds.shape[:2]
        log_probs = F.log_softmax(preds.reshape(B, K, -1), dim=-1)
        flat_targets = targets.reshape(B, K, -1)
        # Normalise each target heatmap into a distribution over pixels.
        target_probs = flat_targets / (flat_targets.sum(dim=-1, keepdim=True) + 1e-8)
        per_keypoint = self.kl_criterion(log_probs, target_probs).sum(dim=-1)
        return (per_keypoint * masks).sum() / (masks.sum() + 1e-8)

    def spatial_distillation_loss(self, s_logits, t_logits, masks):
        """Temperature-softened spatial KL divergence KL(teacher || student).

        Args:
            s_logits: Student heatmap logits, shape (B, K, H, W).
            t_logits: Teacher heatmap logits, shape (B, K, H_t, W_t). When the
                spatial size differs from the student's, the teacher maps are
                bilinearly resized to the student's resolution first.
            masks: (B, K) tensor, 1 for keypoints that contribute to the loss.

        Returns:
            Scalar tensor: mean KL over unmasked keypoints, scaled by ``temp ** 2``.
        """
        if t_logits.shape[-2:] != s_logits.shape[-2:]:
            t_logits = F.interpolate(
                t_logits, size=s_logits.shape[-2:], mode="bilinear", align_corners=False
            )

        B, K = s_logits.shape[:2]

        # 1. Flatten spatial dimensions
        s_logits_flat = s_logits.reshape(B, K, -1)
        t_logits_flat = t_logits.reshape(B, K, -1)

        # 2. Soften and compute distributions
        s_log_prob = F.log_softmax(s_logits_flat / self.temp, dim=-1)
        t_prob = F.softmax(t_logits_flat / self.temp, dim=-1)

        # 3. Per-keypoint KL divergence (summed over pixels)
        kl_loss = self.kl_criterion(s_log_prob, t_prob).sum(dim=-1)

        # 4. Mask, average over valid keypoints and rescale by T^2 so the gradient
        #    magnitude stays comparable across temperatures.
        masked_kl = kl_loss * masks
        return (masked_kl.sum() / (masks.sum() + 1e-8)) * (self.temp ** 2)

    def train(self, train_loader, val_loader):
        """Run the distillation loop; validation and checkpointing start at epoch 5."""
        print(f"--- Starting distillation training (T={self.temp}) ---")
        for epoch in range(self.num_epochs):
            self.student.train()
            pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{self.num_epochs}')

            for imgs, targets, masks, _ in pbar:
                imgs, targets, masks = imgs.to(self.device), targets.to(self.device), masks.to(self.device)

                self.optimizer.zero_grad(set_to_none=True)

                with autocast(enabled=self.use_amp):
                    s_hms = self._heatmaps(self.student(imgs))

                    # The teacher is a fixed reference: no gradients are needed.
                    with torch.no_grad():
                        t_hms = self._heatmaps(self.teacher(imgs))

                    loss_gt = self._gt_loss(s_hms, targets, masks)
                    loss_kd = self.spatial_distillation_loss(s_hms, t_hms, masks)

                    total_loss = loss_gt + (self.alpha_logit * loss_kd)

                self.scaler.scale(total_loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()

                pbar.set_postfix({'loss': f"{total_loss.item():.4f}", 'kd': f"{loss_kd.item():.4f}"})

            self.scheduler.step()

            if (epoch + 1) >= 5:
                val_loss = self.validate(val_loader)
                print(f"Epoch {epoch+1} | Val Loss: {val_loss:.6f}")

                if val_loss < self.best_val_loss:
                    self.best_val_loss = val_loss
                    save_path = os.path.join(self.checkpoint_dir, f'best_logit_kd_T{self.temp}.pth')
                    torch.save(self.student.state_dict(), save_path)
                    print(f"New best model saved to {save_path}")
            else:
                print(f"Epoch {epoch+1} finished. Validation starts at epoch 5.")

    def validate(self, val_loader):
        """Return the student's mean per-batch ground-truth loss (no teacher term).

        Returns ``nan`` when ``val_loader`` yields no batch, so an empty loader
        neither crashes training nor is mistaken for a new best result.
        """
        self.student.eval()
        batch_losses = []
        with torch.no_grad():
            for imgs, targets, masks, _ in val_loader:
                imgs, targets, masks = imgs.to(self.device), targets.to(self.device), masks.to(self.device)

                preds = self._heatmaps(self.student(imgs))
                batch_losses.append(self._gt_loss(preds, targets, masks).item())

        if not batch_losses:
            return float('nan')
        return sum(batch_losses) / len(batch_losses)

In [None]:
# Fresh student so the distilled run does not start from the baseline's weights.
student_model_logits = SqueezeNetHPE(num_keypoints=17).to(device)

# Same optimisation settings as the baseline run, plus the teacher term
# (temperature 4.0, distillation weight 1.0).
trainer_logits = LogitDistillationTrainer(
    student_model=student_model_logits, 
    teacher_model=teacher_model, 
    device=device, 
    checkpoint_dir='checkpoints_logits_kd',
    lr=1e-3, 
    num_epochs=50, 
    temp=4.0, 
    alpha_logit=1.0
)
trainer_logits.train(train_loader, val_loader)

## Evaluation
In this part we will do a simple evaluation of the models.

1.  **Qualitative Accuracy (Visual)** (Using FiftyOne to visualize model predictions vs Ground Truth.)
2.  **Model Complexity** (FLOPs and Parameters.)
3.  **Coco performance metrics** - Average Precision


In [None]:
class ModelEvaluator:
    """Profiling and qualitative-evaluation helpers for heatmap-based pose models.

    Bundles three tasks: counting parameters, estimating MACs/FLOPs with ``ptflops``
    (when installed) and writing predicted keypoints onto FiftyOne samples.
    """

    def __init__(self, device: torch.device):
        """Store the device on which models are evaluated."""
        self.device = device

    def count_trainable_parameters(self, model: nn.Module, model_name: str = "Model") -> int:
        """Print a short parameter profile of ``model`` and return its trainable count.

        The printed size counts *all* parameters (trainable or not) at 4 bytes each.

        Args:
            model: Network to profile.
            model_name: Label used in the printed header.

        Returns:
            Number of parameters with ``requires_grad=True``.
        """
        params = list(model.parameters())
        trainable_params = sum(p.numel() for p in params if p.requires_grad)
        total_size_mb = (sum(p.numel() for p in params) * 4) / (1024**2)
        print(f"--- {model_name} Profile ---")
        print(f"Trainable Parameters: {trainable_params:,}")
        print(f"Model Size: {total_size_mb:.2f} MB\n")
        return trainable_params

    def compute_model_complexity(self, model: nn.Module, input_shape: Tuple[int, int, int], name_of_model: str = "Model"):
        """Estimate MACs and FLOPs of ``model`` for a single input of ``input_shape``.

        FLOPs are reported as ``2 * MACs``. The model is switched to eval mode.

        Args:
            model: Network to profile.
            input_shape: Input resolution handed to ``ptflops`` (without batch dimension).
            name_of_model: Label used in the printed header.

        Returns:
            ``{"GMACs": float, "GFLOPs": float}``, or ``None`` when ``ptflops`` is not
            installed or could not produce a MAC count.
        """
        try:
            from ptflops import get_model_complexity_info
        except ImportError:
            # Best-effort feature: report why nothing is printed instead of failing silently.
            print(f"ptflops is not installed; skipping complexity computation for {name_of_model}.\n")
            return None

        silent_buffer = io.StringIO()  # swallows ptflops' own report
        model.eval()
        macs, _ = get_model_complexity_info(model, input_shape, as_strings=False,
                                           print_per_layer_stat=False, verbose=False, ost=silent_buffer)
        if macs is None:
            print(f"ptflops could not compute the complexity of {name_of_model}.\n")
            return None

        results = {"GMACs": macs / 1e9, "GFLOPs": (2 * macs) / 1e9}
        print(f"--- {name_of_model} Complexity ---")
        print(f"MACs: {results['GMACs']:.3f} G | FLOPs: {results['GFLOPs']:.3f} G\n")
        return results

    @torch.no_grad()
    def add_predictions_to_fiftyone(self, dataset, model, num_samples=20, batch_size=8,
                                   input_size=(288, 384), heatmap_size=(72, 96), field_name="predictions",
                                   seed=51):
        """Run ``model`` on a subset of ``dataset`` and store its keypoints on each sample.

        Args:
            dataset: FiftyOne dataset or view exposing ``take``.
            model: Network returning heatmaps, either as a tensor or via ``.heatmaps``.
            num_samples: Upper bound on the number of samples to annotate.
            batch_size: Number of images per forward pass.
            input_size: ``(width, height)`` passed to ``cv2.resize``.
            heatmap_size: ``(width, height)`` used to normalise predicted coordinates.
            field_name: Sample field that receives the ``fo.Keypoints`` label.
            seed: Seed forwarded to ``dataset.take``. The fixed default makes repeated
                calls (e.g. one per model) pick the same samples so their predictions
                can be compared side by side; pass ``None` for a new random subset.

        Returns:
            The view of samples that were selected.

        Raises:
            FileNotFoundError: If an image cannot be read from ``sample.filepath``.
        """
        model.eval().to(self.device)
        view = dataset.take(num_samples, seed=seed)
        samples_list = list(view)
        if not samples_list:
            print(f"No samples available; nothing to predict for {field_name}.")
            return view

        # Iterate over what was actually returned: the dataset may hold fewer than
        # `num_samples` samples, and an empty trailing batch would break np.stack.
        for start in tqdm(range(0, len(samples_list), batch_size), desc=f"Predicting {field_name}"):
            batch_samples = samples_list[start : start + batch_size]
            batch_imgs = [self._load_rgb_image(s.filepath, input_size) for s in batch_samples]
            # NHWC uint8 -> NCHW float in [0, 1]
            inputs = torch.from_numpy(np.stack(batch_imgs)).permute(0, 3, 1, 2).float().to(self.device) / 255.0
            outputs = model(inputs)
            hms = outputs.heatmaps if hasattr(outputs, "heatmaps") else outputs
            preds, maxvals = self._get_max_preds_vectorized(hms.cpu().numpy())
            for j, sample in enumerate(batch_samples):
                # Heatmap pixel coordinates -> relative [0, 1] coordinates.
                norm_points = (preds[j] / [heatmap_size[0], heatmap_size[1]]).clip(0, 1).tolist()
                sample[field_name] = fo.Keypoints(keypoints=[fo.Keypoint(points=norm_points, confidence=maxvals[j].flatten().tolist())])
                sample.save()
        return view

    def _load_rgb_image(self, filepath, input_size):
        """Read ``filepath`` as RGB and resize it to ``input_size`` (``(width, height)``)."""
        image_bgr = cv2.imread(filepath)
        if image_bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {filepath}")
        return cv2.resize(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB), input_size)

    def _get_max_preds_vectorized(self, batch_heatmaps):
        """Locate the arg-max of every heatmap in a ``(B, K, H, W)`` array.

        Returns:
            ``(preds, maxvals)`` where ``preds`` is a float32 ``(B, K, 2)`` array of
            ``(x, y)`` heatmap coordinates and ``maxvals`` is a ``(B, K, 1)`` array of
            the corresponding peak values.
        """
        B, K, _, W = batch_heatmaps.shape
        flat = batch_heatmaps.reshape((B, K, -1))
        idx = np.argmax(flat, axis=2)
        maxvals = np.amax(flat, axis=2).reshape((B, K, 1))
        preds = np.zeros((B, K, 2), dtype=np.float32)
        preds[:, :, 0] = idx % W   # column -> x
        preds[:, :, 1] = idx // W  # row -> y
        return preds, maxvals

In [None]:
# Keypoint skeleton for visualization: 17 joint labels plus the connections
# between them (each edge is a pair of indices into `labels`).
skeleton = fo.KeypointSkeleton(
    labels=[
        "nose",
        "l_eye", "r_eye",
        "l_ear", "r_ear",
        "l_shoulder", "r_shoulder",
        "l_elbow", "r_elbow",
        "l_wrist", "r_wrist",
        "l_hip", "r_hip",
        "l_knee", "r_knee",
        "l_ankle", "r_ankle",
    ],
    edges=[
        # Head
        [0, 1], [0, 2], [1, 3], [2, 4],
        # Arms
        [5, 6], [5, 7], [7, 9], [6, 8], [8, 10],
        # Torso
        [5, 11], [6, 12], [11, 12],
        # Legs
        [11, 13], [13, 15], [12, 14], [14, 16],
    ],
)

# Register it as the dataset-wide default and persist the change.
val_data.default_skeleton = skeleton
val_data.save()

In [None]:
evaluator = ModelEvaluator(device=device)

# `add_predictions_to_fiftyone` picks its samples with `dataset.take(...)`, which samples
# at random. Calling it once per model on `val_data` would therefore annotate a different
# subset each time. Select the subset once (fixed seed) and hand the same view to every
# model so that all prediction fields end up on the same samples.
NUM_VIS_SAMPLES = 20
eval_view = val_data.take(NUM_VIS_SAMPLES, seed=51)
num_vis_samples = len(eval_view)  # may be smaller than requested on a tiny dataset

# --- Teacher ---
evaluator.count_trainable_parameters(teacher_model, "Teacher (HRNet)")
evaluator.compute_model_complexity(teacher_model, (3, 288, 384), name_of_model="Teacher (HRNet)")
evaluator.add_predictions_to_fiftyone(eval_view, teacher_model, num_samples=num_vis_samples, field_name="teacher_preds")

# --- Student trained with logit distillation (best checkpoint, T=4.0) ---
student_model_logits = SqueezeNetHPE(num_keypoints=17).to(device)
student_model_logits.load_state_dict(torch.load('checkpoints_logits_kd/best_logit_kd_T4.0.pth', map_location=device))
evaluator.count_trainable_parameters(student_model_logits, "Student (Logit-KD)")
evaluator.add_predictions_to_fiftyone(eval_view, student_model_logits, num_samples=num_vis_samples, field_name="student_logit_kd")

# --- Baseline student ---
student_model_baseline = SqueezeNetHPE(num_keypoints=17).to(device)
student_model_baseline.load_state_dict(torch.load('checkpoints_baseline_kd/run_01/best_baseline_kl.pth', map_location=device))
evaluator.count_trainable_parameters(student_model_baseline, "Student baseline")
evaluator.add_predictions_to_fiftyone(eval_view, student_model_baseline, num_samples=num_vis_samples, field_name="student_baseline_kd")

# --- Launch Visualization ---
session = fo.launch_app(val_data)

In [None]:
# Compare the arg-max keypoints of the three models on one validation batch.
images, targets, masks, _ = next(iter(val_loader))
images = images.to(device)

teacher_model.eval()
student_model_logits.eval()
student_model_baseline.eval()

with torch.no_grad():
    # Each model may return a raw heatmap tensor or an object exposing `.heatmaps`,
    # so every output is unwrapped with the same fallback.
    t_out = teacher_model(images)
    t_hms = t_out.heatmaps if hasattr(t_out, 'heatmaps') else t_out
    t_preds, _ = evaluator._get_max_preds_vectorized(t_hms.cpu().numpy())

    s_logit_out = student_model_logits(images)
    # Previously read `.heatmaps` unconditionally, which raises AttributeError when
    # the student returns a plain tensor.
    s_logit_hms = s_logit_out.heatmaps if hasattr(s_logit_out, 'heatmaps') else s_logit_out
    s_logit_preds, _ = evaluator._get_max_preds_vectorized(s_logit_hms.cpu().numpy())

    s_base_out = student_model_baseline(images)
    s_base_hms = s_base_out.heatmaps if hasattr(s_base_out, 'heatmaps') else s_base_out
    s_base_preds, _ = evaluator._get_max_preds_vectorized(s_base_hms.cpu().numpy())


# Coordinates are (x, y) positions on the heatmap grid, one row per keypoint.
print("--- Keypoint Predictions for First Image ---")
print("Teacher Keypoints:\n", t_preds[0])
print("\nStudent (Logit-KD) Keypoints:\n", s_logit_preds[0])
print("\nStudent (Baseline) Keypoints:\n", s_base_preds[0])

## 3.3. Feature-based Knowledge Distillation

Recently, it has been understood that rich information is normally located in the **intermediate layers** of a model. While logits capture the final outcome, intermediate features capture the structural and semantic hierarchies (e.g., edges, textures, and joint relationships) that the teacher has learned.

Feature-based KD encourages the Student to learn intermediate representations that resemble the Teacher's. Since Student and Teacher features often have different dimensions in terms of both channels and spatial resolution, we typically use a **Connector** (also known as an **Adaptor**, e.g., $1 \times 1$ Conv) to map Student features into the Teacher's feature space.

**Feature Extraction**
We utilize **Forward Hooks** to extract these intermediate feature maps during the forward pass without requiring structural modifications to the original model architectures.

**Loss Function**
The distillation loss for features is usually calculated using the Mean Squared Error (MSE) between the transformed student features and the teacher features. The total objective is defined as:

$$L_{total} = L_{GT} + \beta \cdot L_{Feat}(F_{Adaptor}(F_S), F_T)$$

Where:
* $L_{GT}$ is the standard ground truth loss (e.g., MSE or KL Divergence).
* $F_S$ and $F_T$ are the intermediate feature maps from the Student and Teacher, respectively.
* $F_{Adaptor}$ is a learnable transformation (typically a $1 \times 1$ Convolution) used to align the channel dimensions.
* $\beta$ is the weighting hyperparameter for the feature distillation task.

In [None]:
from helper.cka_implementation import CKACalculator
from helper.cka_plots import plot_cka_heatmap


In [None]:
# Both networks are put in eval mode before CKA is computed.
student_model_logits.eval()
teacher_model.eval()

print("Calculating CKA: Teacher vs Student (Logit KD)...")
cka_calc_logit = CKACalculator(
    model1=teacher_model, model2=student_model_logits,
    dataloader=val_loader, num_epochs=1, is_main_process=True,
)
cka_matrix_logit = cka_calc_logit.calculate_cka_matrix()

# Plot the teacher-vs-student CKA matrix and write it to disk.
plot_cka_heatmap(
    cka_matrix=cka_matrix_logit,
    model1_name='Teacher (HRNet)', model2_name='Student (Logit KD)',
    title="CKA: Teacher vs Student (Logit KD)",
    save_path='cka_teacher_vs_logit_kd.png',
)

In [None]:
# Same CKA comparison as for the logit-KD student, now against the baseline student.
student_model_baseline.eval()

print("Calculating CKA: Teacher vs Student (Baseline)...")
cka_calc_base = CKACalculator(
    model1=teacher_model, model2=student_model_baseline,
    dataloader=val_loader, num_epochs=1, is_main_process=True,
)
cka_matrix_base = cka_calc_base.calculate_cka_matrix()

# Plot the teacher-vs-baseline CKA matrix and write it to disk.
plot_cka_heatmap(
    cka_matrix=cka_matrix_base,
    model1_name='Teacher (HRNet)', model2_name='Student (Baseline)',
    title="CKA: Teacher vs Student (Baseline)",
    save_path='cka_teacher_vs_baseline.png',
)