# Knowledge Distillation for 2D Human Pose Estimation

## What is Knowledge Distillation?
Knowledge Distillation (KD) is a technique in deep learning where a small, compact model (the **Student**) is trained to reproduce the behavior of a large, complex model (the **Teacher**), or an ensemble of models. The core idea was popularized by Hinton et al. in ["Distilling the Knowledge in a Neural Network" (2015)](https://arxiv.org/abs/1503.02531).

## Why use KD?
1. **Model Compression**: Reduce model size and inference latency.
2. **Accuracy Boost**: The Student often learns better from the Teacher's "soft targets" (probability distributions or heatmaps) than from one-hot ground truth labels alone, as the Teacher provides structural information about the data.

## Types of Knowledge Distillation covered here:
1. **Response-based (Logits) KD**: The Student mimics the final output (heatmaps) of the Teacher.
2. **Feature-based KD**: The Student mimics the intermediate feature maps of the Teacher, learning to extract similar spatial features.

---


# 1. Introduction

In Human Pose Estimation (HPE), the Teacher is typically a high-performance model (like HRNet-W48) that is accurate but computationally expensive. The Student is a lightweight model (like SqueezeNet or MobileNet) designed for real-time inference on edge devices.

**Gaussian Heatmaps** are the standard target representation for 2D Pose Estimation. Instead of regressing exact (x, y) coordinates directly, which tends to be harder to optimize, the model predicts one map per keypoint whose peak marks the keypoint location.

**Generation Formula:**
For a keypoint $k$ at $(x_k, y_k)$, the value at pixel $(i, j)$ is:

$$H_k(i, j) = \exp\left( -\frac{(i - x_k)^2 + (j - y_k)^2}{2\sigma^2} \right)$$

Where $\sigma$ controls the spread of the Gaussian peak.
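To make the formula concrete, here is a minimal NumPy sketch that renders one heatmap and checks where it peaks. The function name `gaussian_heatmap` and the example coordinates are illustrative choices, not part of the notebook's code.

```python
import numpy as np

def gaussian_heatmap(x_k, y_k, width, height, sigma=2.0):
    """Return a (height, width) map that peaks at column x_k, row y_k."""
    # xs holds column indices (i in the formula), ys holds row indices (j)
    ys, xs = np.meshgrid(np.arange(height), np.arange(width), indexing="ij")
    dist_sq = (xs - x_k) ** 2 + (ys - y_k) ** 2
    return np.exp(-dist_sq / (2.0 * sigma ** 2))

hm = gaussian_heatmap(x_k=30, y_k=50, width=72, height=96, sigma=2.0)
peak_row, peak_col = np.unravel_index(hm.argmax(), hm.shape)
print(hm.shape, int(peak_col), int(peak_row), float(hm.max()))
```

A larger `sigma` widens the blob, which makes the target easier to hit but less precise.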

**Masks (Visibility):**
We also associate a binary **Mask** with each keypoint. This mask is **1** if the keypoint is visible and annotated, and **0** otherwise. During training, we multiply the loss by this mask to essentially 'switch off' the learning for missing or invisible joints, preventing the model from being confused by unlabelled data.
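As a rough illustration of how a mask switches off the loss for missing joints, the sketch below uses a plain masked MSE. The helper `masked_mse` is hypothetical; the trainer later in this notebook applies the mask to a different criterion and normalizes differently.

```python
import torch

def masked_mse(pred, target, mask):
    """MSE over heatmaps, ignoring joints whose mask is 0.

    pred, target: (B, K, H, W); mask: (B, K) with 1 = labelled, 0 = missing.
    """
    mask = mask[:, :, None, None]                      # broadcast over H and W
    sq_err = (pred - target) ** 2 * mask
    # Normalize by the number of unmasked pixels so the scale does not
    # depend on how many joints happen to be labelled in the batch.
    denom = (mask.sum() * pred.shape[-2] * pred.shape[-1]).clamp(min=1.0)
    return sq_err.sum() / denom

pred = torch.zeros(1, 2, 4, 4)
target = torch.ones(1, 2, 4, 4)
mask = torch.tensor([[1.0, 0.0]])                      # second joint is unlabelled
loss = masked_mse(pred, target, mask)
```

Whatever the model predicts for the masked joint, the loss and its gradient are unaffected.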


In [1]:
import logging
import math
import os
from collections import namedtuple
from typing import Callable, Dict, List, Optional, Tuple, Union

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from ptflops import get_model_complexity_info
from pycocotools.coco import COCO
from torch import Tensor
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader, Dataset
from torchvision import models
from tqdm import tqdm

import heatmaps_to_keypoints



#### Student model

We design a lightweight model around a pretrained <a href="https://arxiv.org/pdf/1602.07360">SqueezeNet backbone</a> that extracts features from the input image. We remove the classification tail of the model and add a custom head for generating heatmaps.

In [19]:
KeypointOutput = namedtuple('KeypointOutput', ['heatmaps'])
class SqueezeNetHPE(nn.Module):
    def __init__(self, num_keypoints=17):
        super().__init__()
        squeezenet = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1)
        self.backbone = squeezenet.features
        self.relu = nn.ReLU(inplace=True)
        self.conv_heatmap = nn.Conv2d(
            in_channels=512,
            out_channels=num_keypoints,
            kernel_size=3,
            padding=1
        )

    def forward(self, x):
        x = self.backbone(x)
        x = self.relu(x)
        x = F.interpolate(
            x,
            size=(96, 72), 
            mode="bilinear",
            align_corners=False
        )
        heatmaps = self.conv_heatmap(x)

        return KeypointOutput(heatmaps=heatmaps)

1.  **Backbone (`self.backbone`)**: We use the feature extractor from `squeezenet1_1` pretrained on ImageNet.
2.  **Upsampling**: The SqueezeNet feature extractor downsamples the input substantially (roughly 16x for `squeezenet1_1`, so a 384x288 crop becomes about 23x17). Heatmaps need more spatial precision than that, so we use `F.interpolate` to upsample the features to **96x72** (height x width).
3.  **Heatmap Head (`self.conv_heatmap`)**: In place of the classification head, a final **Convolutional Layer** (kernel size 3x3) produces **17 output channels**, one heatmap per keypoint.
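The backbone's output resolution can be checked with plain stride arithmetic. This is a minimal sketch, assuming the layer configuration of torchvision's `squeezenet1_1` feature extractor: one stride-2 3x3 convolution without padding, followed by three stride-2 3x3 max-pools with `ceil_mode=True`, with the Fire modules preserving spatial size. The helper names are illustrative.

```python
import math

def conv_out(size, kernel, stride, padding=0):
    """Output length of a convolution along one axis."""
    return (size + 2 * padding - kernel) // stride + 1

def pool_out_ceil(size, kernel, stride):
    """Output length of max pooling with ceil_mode=True and no padding."""
    out = math.ceil((size - kernel) / stride) + 1
    # PyTorch drops the last window if it would start outside the input.
    if (out - 1) * stride >= size:
        out -= 1
    return out

def squeezenet1_1_feature_size(size):
    size = conv_out(size, kernel=3, stride=2)           # first convolution
    for _ in range(3):                                  # three max-pool stages
        size = pool_out_ceil(size, kernel=3, stride=2)
    return size

h = squeezenet1_1_feature_size(384)
w = squeezenet1_1_feature_size(288)
print(h, w, round(384 / h, 1), round(288 / w, 1))
```

For a 384x288 crop this gives a 23x17 feature map, which is why the features are upsampled before the heatmap head.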



#### Teacher model


In this tutorial we will use <a href="https://github.com/HRNet/HRNet-Human-Pose-Estimation?tab=readme-ov-file">HRNet-pose</a> as the teacher model.

**HRNet** (High-Resolution Net) maintains high-resolution representations throughout the network. Unlike traditional architectures (ResNet, VGG) that downsample the image and then upsample, HRNet connects high-to-low resolution subnetworks in parallel.


**Role as Teacher:**
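HRNet-pose is a high-accuracy CNN-based model for 2D Pose Estimation. Here we load the HRNet-W48 variant trained at 384x288 and use its heatmaps (and, later, its intermediate features) as the distillation targets.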

In [20]:
from hrnet_pose.hrnetpose_model import get_pose_net
from hrnet_pose.configs import load_configs

cfg = load_configs('hrnet_pose/hrnet_w48_model_configs.yaml')
teacher_model = get_pose_net(cfg, is_train=False)
checkpoint = torch.load('hrnet_pose/hrnet_pose_models/pose_hrnet_w48_384x288.pth', map_location='cpu')
teacher_model.load_state_dict(checkpoint)


<All keys matched successfully>

# 2. Dataset

## 2.1 COCO Keypoint Dataset
The COCO (Common Objects in Context) dataset is a large-scale object detection, segmentation, and captioning dataset. <br>
For Human Pose Estimation, we use the **Keypoints** task, which annotates **17 keypoints** (joints) on human bodies.

The 17 Keypoints are:
0: Nose, 1: Left Eye, 2: Right Eye, 3: Left Ear, 4: Right Ear, 5: Left Shoulder, 6: Right Shoulder, 7: Left Elbow, 8: Right Elbow, 9: Left Wrist, 10: Right Wrist, 11: Left Hip, 12: Right Hip, 13: Left Knee, 14: Right Knee, 15: Left Ankle, 16: Right Ankle. <br>

The keypoints are connected to form a skeleton structure, helping in visualizing pose.
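The index order above can be kept as a small lookup table. The sketch below also derives the left/right pairs that have to be swapped whenever an image is flipped horizontally. The names `COCO_KEYPOINT_NAMES`, `FLIP_PAIRS`, and `swap_left_right` are illustrative and not part of the notebook's code.

```python
COCO_KEYPOINT_NAMES = [
    "nose", "left_eye", "right_eye", "left_ear", "right_ear",
    "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
    "left_wrist", "right_wrist", "left_hip", "right_hip",
    "left_knee", "right_knee", "left_ankle", "right_ankle",
]

# Pair every left_* joint with its right_* counterpart.
FLIP_PAIRS = [
    (i, COCO_KEYPOINT_NAMES.index(name.replace("left_", "right_")))
    for i, name in enumerate(COCO_KEYPOINT_NAMES)
    if name.startswith("left_")
]

def swap_left_right(per_joint_values):
    """Return a copy with left/right entries exchanged (for use after a horizontal flip)."""
    swapped = list(per_joint_values)
    for left, right in FLIP_PAIRS:
        swapped[left], swapped[right] = swapped[right], swapped[left]
    return swapped

print(FLIP_PAIRS)
```

The same swap applies to any per-joint array, such as coordinates, visibility flags, or heatmap channels.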

![COCO Skeleton](images/keypoints-skeleton.png) ![COCO Examples](images/keypoints-examples.png)

*Example of COCO Keypoint Annotations (Source: COCO Dataset)*

Read more about COCO dataset <a href="https://cocodataset.org/#keypoints-2017">here</a>.

### 2.1.1 Get data using fiftyone library

[FiftyOne](https://voxel51.com/fiftyone/) is an open-source toolset by Voxel51 designed for building high-quality datasets and computer vision models. It acts as a visual interface for your datasets, allowing you to:

- **Visualize** raw images and their annotations (bounding boxes, keypoints, etc.) instantly.
- **Query** specific samples (e.g., "show me all images with > 10 people").
- **Debug** model predictions by overlaying them on ground truth.
- **Manage** data splits and export to various formats (COCO, YOLO, TFRecord).

In [4]:
import fiftyone as fo
import fiftyone.zoo as foz

In [5]:
train_data = foz.load_zoo_dataset("coco-2017", split="train", max_samples=2000, label_types=["keypoints"])

Downloading split 'train' to '/home/clinton-mwangi/fiftyone/coco-2017/train' if necessary
Found annotations at '/home/clinton-mwangi/fiftyone/coco-2017/raw/instances_train2017.json'
Sufficient images already downloaded
Existing download of split 'train' is sufficient
Loading existing dataset 'coco-2017-train-2000'. To reload from disk, either delete the existing dataset or provide a custom `dataset_name` to use


In [6]:
val_data = foz.load_zoo_dataset("coco-2017", split="validation", max_samples=1000, label_types=["keypoints"])

Downloading split 'validation' to '/home/clinton-mwangi/fiftyone/coco-2017/validation' if necessary
Found annotations at '/home/clinton-mwangi/fiftyone/coco-2017/raw/instances_val2017.json'
Sufficient images already downloaded
Existing download of split 'validation' is sufficient
Loading existing dataset 'coco-2017-validation-1000'. To reload from disk, either delete the existing dataset or provide a custom `dataset_name` to use


In [7]:
train_data

Name:        coco-2017-train-2000
Media type:  None
Num samples: 0
Persistent:  False
Tags:        []
Sample fields:
    id:               fiftyone.core.fields.ObjectIdField
    filepath:         fiftyone.core.fields.StringField
    tags:             fiftyone.core.fields.ListField(fiftyone.core.fields.StringField)
    metadata:         fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.metadata.Metadata)
    created_at:       fiftyone.core.fields.DateTimeField
    last_modified_at: fiftyone.core.fields.DateTimeField

In [8]:
session = fo.launch_app(val_data, auto=False)


Session launched. Run `session.show()` to open the App in a cell output.


In [9]:
session.show()

### 2.1.2 Get data by direct download

In [None]:
# Path to store the data
DATA_DIR="/path/to/store/data"

!mkdir -p $DATA_DIR

!wget -c http://images.cocodataset.org/annotations/annotations_trainval2017.zip -P $DATA_DIR
!wget -c http://images.cocodataset.org/zips/val2017.zip -P $DATA_DIR
!wget -c http://images.cocodataset.org/zips/train2017.zip -P $DATA_DIR

In [None]:
# Extracting the files
!unzip -q $DATA_DIR/annotations_trainval2017.zip -d $DATA_DIR
!unzip -q $DATA_DIR/val2017.zip -d $DATA_DIR
!unzip -q $DATA_DIR/train2017.zip -d $DATA_DIR

# Cleanup zip files
!rm -rf $DATA_DIR/*.zip

## 2.2 Data Preprocessing

We follow preprocessing steps that are specific to human pose estimation, based on the paper <a href="https://arxiv.org/abs/1911.07524">The Devil is in the Details</a>:
1.  **Filtering**: We first check if an image actually contains a person with valid annotations. 
2.  **Cropping**: Once we identify a valid image, we locate the person using their bounding box and **crop** the image to center on them. 
3.  **Coordinate Transformation**: We transform the original keypoint annotations from the full image space into our new **cropped coordinate space**.
4.  **Ground Truth Generation**: Finally, we take these transformed keypoints and generate **Gaussian Heatmaps**, which serve as the training targets for our model.
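Steps 2 to 4 come down to a chain of shifts and rescalings. The sketch below follows one keypoint from full-image pixels to heatmap pixels. It assumes the `(width, height)` size convention used in this notebook and ignores clamping the box to the image border. The function name is hypothetical.

```python
def image_to_heatmap_coords(kp_xy, bbox_xywh, target_size=(288, 384), heatmap_size=(72, 96)):
    """Map an (x, y) keypoint from full-image pixels to heatmap pixels.

    Sizes are (width, height).
    """
    x, y = kp_xy
    bx, by, bw, bh = bbox_xywh
    # 1) shift into the crop's coordinate frame
    x, y = x - bx, y - by
    # 2) rescale from crop size to the network input size
    x, y = x * target_size[0] / bw, y * target_size[1] / bh
    # 3) rescale from input size to heatmap size
    x, y = x * heatmap_size[0] / target_size[0], y * heatmap_size[1] / target_size[1]
    return x, y

hx, hy = image_to_heatmap_coords(kp_xy=(250.0, 300.0), bbox_xywh=(200, 100, 144, 384))
print(hx, hy)
```

The resulting `(hx, hy)` is the center at which the Gaussian for that joint is drawn.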

In [21]:
class DataProcessor:
    def __init__(self, target_size=(288, 384), heatmap_size=(72, 96), sigma=2.0):
        self.target_size = target_size
        self.heatmap_size = heatmap_size
        self.sigma = sigma
        
        # Pre-compute the coordinate grid for heatmap generation
        W, H = self.heatmap_size
        self.yy, self.xx = np.meshgrid(np.arange(H), np.arange(W), indexing='ij')

    @staticmethod
    def extract_keypoints_and_visibility(keypoints: List[float]):
        try:
            kps_array = np.array(keypoints, dtype=np.float32).reshape(-1, 3)
            coords = kps_array[:, :2]
            visibility = kps_array[:, 2]
            return coords, visibility.astype(np.int32)
        except Exception as e:
            print(f"Error extracting keypoints: {e}")
            raise e

    def process_image_and_keypoints(
        self,
        image: np.ndarray,
        keypoints: np.ndarray,
        bbox: Union[List[int], Tuple[int, int, int, int]],
        angle: float = 0,
        flip: bool = False
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Crop to the person box, resize, and optionally rotate/flip the image,
        applying the same transform to the keypoints.
        """
        try:
            x1, y1, w, h = np.round(bbox).astype(int)
            img_h, img_w = image.shape[:2]

            # Step 1: Safe Crop
            # Clamp the box to the image: annotated boxes can extend past the border
            cx1, cy1 = max(0, x1), max(0, y1)
            cx2, cy2 = min(img_w, x1 + w), min(img_h, y1 + h)
            image = image[cy1:cy2, cx1:cx2]
            crop_h, crop_w = image.shape[:2]

            # Step 2: Prepare Keypoints
            kps = keypoints.copy().astype(np.float32)
            zero_mask = np.all(kps == 0, axis=1)
            kps -= [cx1, cy1]

            # Step 3: Boundary Masking (Vectorized)
            invalid_mask = (kps[:, 0] < 0) | (kps[:, 1] < 0) | (kps[:, 0] >= crop_w) | (kps[:, 1] >= crop_h)

            # Step 4: Resize (High Speed via OpenCV)
            image = cv2.resize(image, self.target_size, interpolation=cv2.INTER_LINEAR)
            kps *= [self.target_size[0] / crop_w, self.target_size[1] / crop_h]

            # Step 5: Rotation
            if angle != 0:
                center = (self.target_size[0] / 2, self.target_size[1] / 2)
                M = cv2.getRotationMatrix2D(center, angle, 1.0)
                image = cv2.warpAffine(image, M, self.target_size)
                ones = np.ones((kps.shape[0], 1))
                kps_homo = np.hstack([kps, ones])
                kps = kps_homo @ M.T

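            # Step 6: Horizontal Flip
            # NOTE: a flip also turns left joints into right joints, so the caller must
            # swap the left/right keypoint rows (and their visibility flags) as well.
            if flip:
                image = cv2.flip(image, 1)
                # Column c maps to column (W - 1 - c) under cv2.flip
                kps[:, 0] = self.target_size[0] - 1 - kps[:, 0]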

            # Step 7: Restore invalid points to zero
            kps[zero_mask | invalid_mask] = 0

            return image, kps
        except Exception as e:
            print(f"Error in processing: {e}")
            raise e

    def generate_heatmaps(
        self,
        keypoints: np.ndarray,
        keypoints_visible: np.ndarray
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Highly optimized Gaussian Heatmap generation using pre-computed grids.
        """
        W_hm, H_hm = self.heatmap_size
        W_img, H_img = self.target_size
        scale_x = W_hm / W_img
        scale_y = H_hm / H_img
        kps_hm = keypoints * [scale_x, scale_y]

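        # Keypoints zeroed out during processing (unlabelled or outside the crop)
        # sit at (0, 0); exclude them so no peak is drawn in the corner.
        mask = (keypoints_visible >= 0.5) & np.any(keypoints != 0, axis=1) & \
               (kps_hm[:, 0] >= 0) & (kps_hm[:, 0] < W_hm) & \
               (kps_hm[:, 1] >= 0) & (kps_hm[:, 1] < H_hm)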
 
        num_kps = kps_hm.shape[0]
        heatmaps = np.zeros((num_kps, H_hm, W_hm), dtype=np.float32)
        for i in range(num_kps):
            if not mask[i]:
                continue
                
            mu_x, mu_y = kps_hm[i]
            dist_sq = (self.xx - mu_x) ** 2 + (self.yy - mu_y) ** 2
            heatmaps[i] = np.exp(-dist_sq / (2 * self.sigma**2))

        return torch.from_numpy(heatmaps), torch.from_numpy(mask.astype(np.float32))

### Dataloader

In [25]:
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


def get_downloaded_image_ids(root_dir: str, coco: COCO) -> List[int]:
    on_disk = set(os.listdir(root_dir))
    all_img_info = coco.loadImgs(coco.getImgIds())
    existing_ids = [info['id'] for info in all_img_info if info['file_name'] in on_disk]
    return existing_ids

class CocoKeypoints(Dataset):
    def __init__(self, root: str, annFile: str, target_size=(288, 384), heatmap_size=(72, 96), sigma=2.0) -> None:
        self.root = root
        self.coco = COCO(annFile)
        self.processor = DataProcessor(
            target_size=target_size, 
            heatmap_size=heatmap_size, 
            sigma=sigma
        )
        
        on_disk = set(os.listdir(root))
        self.samples = []
        img_ids = self.coco.getImgIds()
        all_imgs = self.coco.loadImgs(img_ids)
        img_id_to_file = {img['id']: img['file_name'] for img in all_imgs if img['file_name'] in on_disk}

        for img_id in img_id_to_file:
            ann_ids = self.coco.getAnnIds(imgIds=img_id)
            anns = self.coco.loadAnns(ann_ids)
            for ann in anns:
                if ann.get("num_keypoints", 0) > 0:
                    self.samples.append({
                        "file_name": img_id_to_file[img_id],
                        "keypoints": ann["keypoints"],
                        "bbox": ann["bbox"]
                    })

        logger.info(f"Initialized {len(self.samples)} samples.")

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int):
        sample = self.samples[index]
        img_path = os.path.join(self.root, sample["file_name"])
        image = cv2.imread(img_path)
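        if image is None:
            # Unreadable file: return an all-zero sample with the same shapes as a real one.
            # Sizes are stored as (width, height); tensors are (C, H, W).
            W, H = self.processor.target_size
            W_hm, H_hm = self.processor.heatmap_size
            return (torch.zeros((3, H, W)), torch.zeros((17, H_hm, W_hm)),
                    torch.zeros(17), torch.zeros((17, 2)))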
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        coords, visibility = self.processor.extract_keypoints_and_visibility(sample["keypoints"])
        processed_img_np, processed_kps = self.processor.process_image_and_keypoints(
            image, 
            coords, 
            sample["bbox"]
        )
        heatmaps, masks = self.processor.generate_heatmaps(processed_kps, visibility)
        img_tensor = torch.from_numpy(processed_img_np.copy()).permute(2, 0, 1).float().div(255.0)
        coords_tensors = torch.from_numpy(processed_kps.copy()) 
        return img_tensor, heatmaps, masks, coords_tensors

def collate_fn(batch):
    return torch.utils.data.dataloader.default_collate(batch)

def get_coco_dataloaders(root_train, ann_train, root_val, ann_val, batch_size=32, num_workers=8):
    train_dataset = CocoKeypoints(root_train, ann_train)
    val_dataset = CocoKeypoints(root_val, ann_val)    
    loader_args = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True, 
        collate_fn=collate_fn,
        persistent_workers=num_workers > 0  # persistent workers require at least one worker
    )
    return (
        DataLoader(train_dataset, shuffle=True, **loader_args),
        DataLoader(val_dataset, shuffle=False, **loader_args))

In [26]:
zoo_dir = fo.config.dataset_zoo_dir
zoo_dir

'/home/clinton-mwangi/fiftyone'

In [27]:
#Fiftyone downloads the data in your home directory by default

coco_root = os.path.join(fo.config.dataset_zoo_dir, "coco-2017")
root_train = os.path.join(coco_root, "train", "data")
root_val = os.path.join(coco_root, "validation", "data")

ann_train = os.path.join(coco_root, "raw", "person_keypoints_train2017.json")
ann_val = os.path.join(coco_root, "raw", "person_keypoints_val2017.json")

train_loader, val_loader = get_coco_dataloaders(
    root_train=root_train,
    ann_train=ann_train,
    root_val=root_val,
    ann_val=ann_val,
    batch_size=32,
    num_workers=4
)

loading annotations into memory...
Done (t=9.33s)
creating index...
index created!


2026-02-01 17:16:07,226 - INFO - Initialized 38151 samples.


loading annotations into memory...


2026-02-01 17:16:07,706 - INFO - Initialized 6352 samples.


Done (t=0.39s)
creating index...
index created!


# 3. Training and distillation

## 3.1 Baseline Student Training

We start by training the **Student Model (SqueezeNet)** purely on the Ground Truth labels of the sample data, **without any distillation**. This gives a baseline to compare the distilled students against.

In [14]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
student_model = SqueezeNetHPE(num_keypoints=17).to(device)

### Training class

In [16]:
class PoseDistillationTrainer:
    def __init__(
        self, 
        student_model, 
        device, 
        teacher_model=None,
        mode='baseline', 
        checkpoint_dir='checkpoints',
        lr=1e-3,
        num_epochs=10,
        temp=4.0,
        alpha_logit=1.0,
        alpha_feat=0.5,
        student_layer=None,
        teacher_layer=None,
        s_channels=None,
        t_channels=None
    ):
        self.device = device
        self.mode = mode
        self.student = student_model.to(device)
        self.teacher = teacher_model.to(device).eval() if teacher_model else None
        self.num_epochs = num_epochs
        self.checkpoint_dir = checkpoint_dir
        
        # KD Parameters
        self.temp = temp
        self.alpha_logit = alpha_logit
        self.alpha_feat = alpha_feat
        self.best_val_loss = float('inf')

        os.makedirs(self.checkpoint_dir, exist_ok=True)

        # 1. Setup Feature Distillation (Hooks & Adapters)
        self.adapter = None
        if 'feature' in mode or 'full' in mode:
            self.t_features, self.s_features = {}, {}
            self._register_hooks(student_layer, teacher_layer)
            self.adapter = nn.Conv2d(s_channels, t_channels, 1).to(device)

        # 2. Optimization Setup
        self.criterion = AdaptiveWingLoss().to(device)
        self.mse_loss = nn.MSELoss()
        
        train_params = list(self.student.parameters())
        if self.adapter:
            train_params += list(self.adapter.parameters())
            
        self.optimizer = optim.AdamW(train_params, lr=lr, weight_decay=1e-4)
        
        # Consistent Scheduler: 5 epoch warmup + Cosine Decay
        # max(1, ...) avoids a division by zero when num_epochs <= 5
        lambda_lr = lambda e: (e + 1) / 5 if e < 5 else 0.5 * (1 + math.cos(math.pi * (e - 5) / max(1, num_epochs - 5)))
        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda_lr)
        self.scaler = GradScaler()

    def _register_hooks(self, s_path, t_path):
        def get_activation(name, storage):
            def hook(m, i, o): storage[name] = o
            return hook
        
        # Registering on teacher and student
        dict(self.teacher.named_modules())[t_path].register_forward_hook(get_activation('feat', self.t_features))
        dict(self.student.named_modules())[s_path].register_forward_hook(get_activation('feat', self.s_features))

    def spatial_kl_loss(self, s_logits, t_logits):
        B, K, H, W = s_logits.shape
        s_prob = F.log_softmax(s_logits.view(B, K, -1) / self.temp, dim=-1)
        t_prob = F.softmax(t_logits.view(B, K, -1) / self.temp, dim=-1)
        return F.kl_div(s_prob, t_prob, reduction='batchmean') * (self.temp**2)

    def train(self, train_loader, val_loader):
        print(f"--- Starting {self.mode.upper()} Training ---")
        for epoch in range(self.num_epochs):
            self.student.train()
            if self.adapter: self.adapter.train()
            
            running_loss = 0.0
            pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{self.num_epochs}')
            
            for imgs, targets, masks, _ in pbar:
                imgs, targets = imgs.to(self.device), targets.to(self.device)
                masks_exp = masks.to(self.device).unsqueeze(-1).unsqueeze(-1)
                
                self.optimizer.zero_grad(set_to_none=True)
                
                with autocast():
                    # Student Forward
                    s_out = self.student(imgs)
                    s_hms = s_out.heatmaps if hasattr(s_out, 'heatmaps') else s_out
                    
                    # 1. Base Loss
                    loss = self.criterion(s_hms * masks_exp, targets * masks_exp)

                    # 2. Knowledge Distillation Logic
                    if self.mode != 'baseline':
                        with torch.no_grad():
                            t_out = self.teacher(imgs)
                            t_hms = t_out.heatmaps if hasattr(t_out, 'heatmaps') else t_out

                        # Logit-based KD
                        if self.mode in ['logits', 'full']:
                            loss += self.alpha_logit * self.spatial_kl_loss(s_hms, t_hms)
                        
                        # Feature-based KD
                        if self.mode in ['feature', 'full']:
                            s_f, t_f = self.s_features['feat'], self.t_features['feat']
                            if s_f.shape[2:] != t_f.shape[2:]:
                                s_f = F.interpolate(s_f, size=t_f.shape[2:], mode='bilinear')
                            loss += self.alpha_feat * self.mse_loss(self.adapter(s_f), t_f)

                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
                
                running_loss += loss.item()
                pbar.set_postfix({'loss': f"{loss.item():.4f}"})

            self.scheduler.step()
            val_loss = self.validate(val_loader)
            print(f"Epoch {epoch+1} | Val Loss: {val_loss:.6f}")
            
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                torch.save(self.student.state_dict(), os.path.join(self.checkpoint_dir, f'best_{self.mode}.pth'))

    def validate(self, val_loader):
        self.student.eval()
        total_loss = 0.0
        with torch.no_grad():
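            for imgs, targets, masks, _ in val_loader:
                imgs, targets = imgs.to(self.device), targets.to(self.device)
                masks_exp = masks.to(self.device).unsqueeze(-1).unsqueeze(-1)
                outputs = self.student(imgs)
                s_hms = outputs.heatmaps if hasattr(outputs, 'heatmaps') else outputs
                # Mask unlabelled joints so validation is measured like the training loss
                total_loss += self.criterion(s_hms * masks_exp, targets * masks_exp).item()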
        return total_loss / len(val_loader)
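The trainer instantiates `AdaptiveWingLoss`, which is not defined or imported anywhere in this notebook. If you do not already have an implementation, the following is a minimal sketch of the Adaptive Wing loss as formulated by Wang et al. (2019). It is an assumption about what the trainer expects (a module taking predicted and target heatmaps and returning a scalar), not the author's original code. Define it before creating a `PoseDistillationTrainer`.

```python
import torch
import torch.nn as nn

class AdaptiveWingLoss(nn.Module):
    """Adaptive Wing loss for heatmap regression (sketch).

    Defaults (omega=14, theta=0.5, epsilon=1, alpha=2.1) follow the values
    suggested in the paper; treat them as a starting point.
    """

    def __init__(self, omega=14.0, theta=0.5, epsilon=1.0, alpha=2.1):
        super().__init__()
        self.omega, self.theta = omega, theta
        self.epsilon, self.alpha = epsilon, alpha

    def forward(self, pred, target):
        # Work in float32 so pow/log stay stable under mixed precision.
        pred, target = pred.float(), target.float()
        diff = (target - pred).abs()
        power = self.alpha - target            # exponent adapts to the target value
        ratio = self.theta / self.epsilon

        # A and C make the two branches meet at |diff| == theta.
        A = (self.omega * (1.0 / (1.0 + ratio ** power)) * power
             * ratio ** (power - 1.0) / self.epsilon)
        C = self.theta * A - self.omega * torch.log1p(ratio ** power)

        small = self.omega * torch.log1p((diff / self.epsilon) ** power)
        large = A * diff - C
        return torch.where(diff < self.theta, small, large).mean()
```

Compared with MSE, this loss puts more weight on small errors near the heatmap peak, where target values are close to 1.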

In [None]:
trainer = PoseDistillationTrainer(
    student_model=student_model, 
    device=device, 
    mode='baseline', 
    checkpoint_dir="checkpoints_baseline",  # keep the baseline apart from the KD runs
    num_epochs=10
)
trainer.train(train_loader, val_loader)
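The trainer's scheduler scales the base learning rate with a 5-epoch linear warmup followed by cosine decay. The sketch below reproduces that multiplier in plain Python so the schedule can be inspected before training. The function name and the `max(1, ...)` guard are additions for illustration.

```python
import math

def lr_multiplier(epoch, num_epochs, warmup_epochs=5):
    """Factor applied to the base learning rate at a given (0-indexed) epoch."""
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs              # linear warmup
    # Guard against num_epochs <= warmup_epochs, which would divide by zero.
    span = max(1, num_epochs - warmup_epochs)
    return 0.5 * (1 + math.cos(math.pi * (epoch - warmup_epochs) / span))

schedule = [round(lr_multiplier(e, num_epochs=10), 3) for e in range(10)]
print(schedule)
```

With `num_epochs=10`, the multiplier ramps up to 1.0 over the first five epochs and then decays towards zero over the remaining five.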

## 3.2. Logits-based Knowledge Distillation

The distillation loss is the Kullback-Leibler (KL) divergence from the softened teacher distribution to the softened student distribution, computed from the teacher logits $z_t$ and student logits $z_s$ and scaled by $T^2$:

$$ L_{KD} = T^2 \cdot KL\left( \sigma\left(\frac{z_t}{T}\right) \,\Big\|\, \sigma\left(\frac{z_s}{T}\right) \right) $$

Where $\sigma$ is the softmax function. The teacher distribution comes first because it is the target: this is the direction `F.kl_div` computes when given the student's log-probabilities as input and the teacher's probabilities as target.

**In Human Pose Estimation (Regression/Heatmaps):**
Typically, MSE is used for heatmaps. However, we can treat each heatmap as a **Spatial Probability Distribution** by flattening the spatial dimensions ($H \times W$) and applying Softmax. This allows us to use Hinton's formulation directly by softening the peaky heatmap distributions.

**Loss Function:**
$$ L_{total} = L_{GT} + \alpha \cdot T^2 \cdot KL\left(\text{SpatialSoftmax}(H_T/T) \,\|\, \text{SpatialSoftmax}(H_S/T)\right) $$
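To see what the KL term computes, here is a NumPy sketch of the spatial-softmax KL. It mirrors the reduction used by `F.kl_div(..., reduction='batchmean')`: sum over joints and pixels, divide by the batch size. The function names are illustrative.

```python
import numpy as np

def spatial_log_softmax(heatmaps, temperature):
    """Log-softmax over the flattened H*W axis of a (B, K, H, W) array."""
    B, K, H, W = heatmaps.shape
    z = heatmaps.reshape(B, K, H * W) / temperature
    z = z - z.max(axis=-1, keepdims=True)               # numerical stability
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def spatial_kl(student_hm, teacher_hm, temperature=4.0):
    """T^2 * KL(teacher || student), summed over joints/pixels, averaged over the batch."""
    log_q = spatial_log_softmax(student_hm, temperature)    # student
    log_p = spatial_log_softmax(teacher_hm, temperature)    # teacher
    p = np.exp(log_p)
    kl = (p * (log_p - log_q)).sum()
    return temperature ** 2 * kl / student_hm.shape[0]

rng = np.random.default_rng(0)
teacher_hm = rng.normal(size=(2, 17, 96, 72))
student_hm = rng.normal(size=(2, 17, 96, 72))
print(spatial_kl(teacher_hm, teacher_hm), spatial_kl(student_hm, teacher_hm))
```

The loss is zero when the student reproduces the teacher's distribution and positive otherwise.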


#### Understanding Temperature ($T$) in Logits Loss

You might wonder: **Why do we divide logits by a Temperature $T$?**

The standard Softmax function is defined as:
$$ \sigma(z_i) = \frac{\exp(z_i)}{\sum_j \exp(z_j)} $$

When a model is very confident, one value in the input $z$ is much larger than the others. This makes the output probability distribution **extremely sharp** (almost 1.0 for the peak, and 0.0 for everything else).

**The Problem:** If the Teacher predicts a perfect 1.0 for the correct keypoint and 0.0 everywhere else, it provides **no more information** than the Ground Truth labels! We lose the rich structural knowledge (e.g., "this pixel is 0.001 likely, but that pixel is 0.05 likely").

**The Solution (Temperature):**
By introducing $T > 1$, we "soften" the distribution:
$$ q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)} $$

- **High T (> 1)**: Flattens the peaks. The Teacher's output becomes softer, spreading probability mass to neighboring pixels. This reveals the **"Dark Knowledge"**: the relationships between the peak and its surrounding area.
- **T = 1**: The standard Softmax, with no softening.

We implement this in `spatial_kl_loss`. We take the raw heatmap logits (before softmax), divide them by `temp`, and *then* apply Softmax. This forces the Student to learn not just the peak location, but the entire spatial uncertainty shape of the Teacher.
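The effect of temperature is easy to see on a toy example. The sketch below uses four made-up logits standing in for four pixels and compares the distribution at $T = 1$ and $T = 4$.

```python
import math

def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    m = max(scaled)                                  # subtract max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [8.0, 4.0, 2.0, 0.0]                        # one dominant "pixel"
sharp = softmax(logits, temperature=1.0)
soft = softmax(logits, temperature=4.0)
print([round(p, 3) for p in sharp])
print([round(p, 3) for p in soft])
```

At $T = 1$ nearly all the mass sits on the first entry. At $T = 4$ the ranking is unchanged but the other entries carry visible probability, which is the extra signal the Student learns from.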


In [None]:
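# Start from a fresh student so the comparison with the baseline is fair;
# reusing `student_model` would continue from already-trained weights.
student_model_logits = SqueezeNetHPE(num_keypoints=17).to(device)

trainer = PoseDistillationTrainer(
    student_model=student_model_logits, 
    teacher_model=teacher_model, 
    device=device, 
    mode='logits', 
    temp=4.0, 
    alpha_logit=1.0
)
trainer.train(train_loader, val_loader)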

## 3.3 Feature-based Knowledge Distillation

Feature-based KD encourages the Student to learn intermediate representations that resemble the Teacher's. Since Student and Teacher features often differ in channel count and resolution, we use a **Connector** (a 1x1 convolution, `self.adapter` in the trainer) to map Student features to the Teacher's channel space, and bilinear interpolation to match the spatial size when needed.

We will use **Forward Hooks** to extract intermediate feature maps.

**Loss Function:**
$$ L_{total} = L_{GT} + \beta L_{Feat}(F_{Adaptor}(F_S), F_T) $$
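The mechanics (hook, adapter, resize, MSE) can be tried on toy networks before wiring up the real models. Everything in this sketch, including the two `nn.Sequential` stand-ins and the `save_to` helper, is hypothetical and only illustrates the pattern.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-ins for the real networks (for illustration only).
student = nn.Sequential(nn.Conv2d(3, 8, 3, stride=2, padding=1), nn.ReLU(), nn.Conv2d(8, 4, 1))
teacher = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.Conv2d(16, 4, 1)).eval()

features = {}

def save_to(key):
    """Build a forward hook that stores the module's output under `key`."""
    def hook(module, inputs, output):
        features[key] = output
    return hook

student[0].register_forward_hook(save_to("student"))
teacher[0].register_forward_hook(save_to("teacher"))

adapter = nn.Conv2d(8, 16, kernel_size=1)            # student channels -> teacher channels

x = torch.randn(2, 3, 32, 32)
student(x)
with torch.no_grad():                                # the teacher is never updated
    teacher(x)

s_f, t_f = features["student"], features["teacher"]
if s_f.shape[2:] != t_f.shape[2:]:                   # match spatial size if needed
    s_f = F.interpolate(s_f, size=t_f.shape[2:], mode="bilinear", align_corners=False)
feat_loss = F.mse_loss(adapter(s_f), t_f)
feat_loss.backward()
```

Gradients flow into the adapter and the hooked student layer, while the teacher stays frozen. The adapter's input and output channel counts must match the hooked layers exactly.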


In [None]:
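# Start from a fresh student rather than one already trained in an earlier run.
student_model_feat = SqueezeNetHPE(num_keypoints=17).to(device)

trainer = PoseDistillationTrainer(
    student_model=student_model_feat, 
    teacher_model=teacher_model, 
    device=device, 
    mode='full',  # 'full' = logits KD + feature KD; use mode='feature' for the feature term alone
    # Feature specific
    student_layer="backbone.12",  # last Fire module of squeezenet1_1 (512 output channels)
    teacher_layer="layer1",
    s_channels=512, 
    t_channels=256,  # assumes the official HRNet layout, where layer1 emits 64 * 4 = 256 channels; verify for your copy
    # KD weights
    alpha_logit=1.0,
    alpha_feat=0.5
)
trainer.train(train_loader, val_loader)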

# 4. Results Comparison
### Evaluate the Models
We evaluate the teacher and the trained student models in three ways:

1.  **Qualitative Accuracy (Visual)**: Using FiftyOne to visualize model predictions vs Ground Truth.
2.  **Model Complexity**: FLOPs and Parameters.
3.  **COCO performance metrics**: Average Precision.
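COCO Average Precision for keypoints is built on Object Keypoint Similarity (OKS), which plays the role IoU plays in detection. The sketch below computes OKS for a single pose, using the per-keypoint constants from the COCO keypoint evaluation. It is meant to show what the metric measures. A full AP evaluation is normally run with pycocotools' `COCOeval` using the `'keypoints'` type, rather than hand-rolled.

```python
import numpy as np

# Per-keypoint falloff constants used by the COCO keypoint evaluation.
COCO_SIGMAS = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72,
                        .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0

def object_keypoint_similarity(pred_xy, gt_xy, gt_visibility, area):
    """OKS between one predicted and one ground-truth pose.

    pred_xy, gt_xy: (17, 2) pixel coordinates; gt_visibility: (17,) COCO flags;
    area: ground-truth object area in pixels^2.
    """
    labelled = gt_visibility > 0
    if not labelled.any():
        return 0.0
    d2 = ((pred_xy - gt_xy) ** 2).sum(axis=1)
    variances = (2.0 * COCO_SIGMAS) ** 2
    e = d2 / variances / (area + np.spacing(1)) / 2.0
    return float(np.exp(-e)[labelled].mean())

gt = np.arange(34, dtype=float).reshape(17, 2)
vis = np.full(17, 2)
print(object_keypoint_similarity(gt + 5.0, gt, vis, area=1000.0))
```

A prediction counts as correct at a given threshold (0.50 to 0.95 in steps of 0.05) when its OKS exceeds it. AP averages precision over those thresholds.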


In [None]:
class ModelEvaluator:
    def __init__(self, device: torch.device):
        self.device = device

    def count_trainable_parameters(self, model: nn.Module, model_name: str = "Model") -> int:
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total_size_mb = (sum(p.numel() for p in model.parameters()) * 4) / (1024**2)
        print(f"--- {model_name} Profile ---")
        print(f"Trainable Parameters: {trainable_params:,}")
        print(f"Model Size: {total_size_mb:.2f} MB\n")
        return trainable_params

    def compute_model_complexity(self, model: nn.Module, input_shape: Tuple[int, int, int], name_of_model: str = "Model"):
        try:
            import io  # needed for the StringIO buffer below
            from ptflops import get_model_complexity_info
            silent_buffer = io.StringIO()
            model.eval()
            macs, _ = get_model_complexity_info(model, input_shape, as_strings=False, 
                                               print_per_layer_stat=False, verbose=False, ost=silent_buffer)
            results = {"GMACs": macs / 1e9, "GFLOPs": (2 * macs) / 1e9}
            print(f"--- {name_of_model} Complexity ---")
            print(f"MACs: {results['GMACs']:.3f} G | FLOPs: {results['GFLOPs']:.3f} G\n")
            return results
        except ImportError:
            return None

    @torch.no_grad()
    def add_predictions_to_fiftyone(self, dataset, model, num_samples=20, batch_size=8, 
                                   input_size=(288, 384), heatmap_size=(72, 96), field_name="predictions"):
        model.eval().to(self.device)
        view = dataset.take(num_samples)
        samples_list = list(view)
        # Iterate over the samples actually returned; the dataset may hold fewer than num_samples
        for i in tqdm(range(0, len(samples_list), batch_size), desc=f"Predicting {field_name}"):
            batch_samples = samples_list[i : i + batch_size]
            batch_imgs = [cv2.resize(cv2.cvtColor(cv2.imread(s.filepath), cv2.COLOR_BGR2RGB), input_size) for s in batch_samples]
            inputs = torch.from_numpy(np.stack(batch_imgs)).permute(0, 3, 1, 2).float().to(self.device) / 255.0
            outputs = model(inputs)
            hms = outputs.heatmaps if hasattr(outputs, "heatmaps") else outputs
            preds, maxvals = self._get_max_preds_vectorized(hms.cpu().numpy())
            for j, sample in enumerate(batch_samples):
                norm_points = (preds[j] / [heatmap_size[0], heatmap_size[1]]).clip(0, 1).tolist()
                sample[field_name] = fo.Keypoints(keypoints=[fo.Keypoint(points=norm_points, confidence=maxvals[j].flatten().tolist())])
                sample.save()
        return view

    def _get_max_preds_vectorized(self, batch_heatmaps):
        B, K, H, W = batch_heatmaps.shape
        heatmaps_reshaped = batch_heatmaps.reshape((B, K, -1))
        idx = np.argmax(heatmaps_reshaped, axis=2).reshape((B, K, 1))
        maxvals = np.amax(heatmaps_reshaped, axis=2).reshape((B, K, 1))
        preds = np.zeros((B, K, 2), dtype=np.float32)
        preds[:, :, 0], preds[:, :, 1] = idx[:, :, 0] % W, idx[:, :, 0] // W
        return preds, maxvals

In [None]:
# Add the COCO skeleton so FiftyOne draws the limbs
skeleton = fo.KeypointSkeleton(
    labels=[
        "nose", "l_eye", "r_eye", "l_ear", "r_ear", 
        "l_shoulder", "r_shoulder", "l_elbow", "r_elbow", 
        "l_wrist", "r_wrist", "l_hip", "r_hip", 
        "l_knee", "r_knee", "l_ankle", "r_ankle"
    ],
    edges=[
        [0, 1], [0, 2], [1, 3], [2, 4], # Head
        [5, 6], [5, 7], [7, 9], [6, 8], [8, 10], # Arms
        [5, 11], [6, 12], [11, 12], # Torso
        [11, 13], [13, 15], [12, 14], [14, 16] # Legs
    ]
)
val_data.default_skeleton = skeleton
val_data.save()

In [None]:
evaluator = ModelEvaluator(device=device)

# --- 1. Teacher Evaluation ---
evaluator.count_trainable_parameters(teacher_model, "Teacher (HRNet)")
evaluator.compute_model_complexity(teacher_model, (3, 384, 288), "Teacher (HRNet)")  # (C, H, W)
evaluator.add_predictions_to_fiftyone(val_data, teacher_model, field_name="teacher_preds")

# --- 2. Logit-KD Student Evaluation ---
# The trainer saves to <checkpoint_dir>/best_<mode>.pth (checkpoint_dir defaults to 'checkpoints')
student_model_logits = SqueezeNetHPE(num_keypoints=17).to(device)
student_model_logits.load_state_dict(torch.load('checkpoints/best_logits.pth', map_location=device))
evaluator.count_trainable_parameters(student_model_logits, "Student (Logit-KD)")
evaluator.add_predictions_to_fiftyone(val_data, student_model_logits, field_name="student_logit_kd")

# --- 3. Feature-KD Student Evaluation ---
# The feature run used mode='full', so its checkpoint is best_full.pth
student_model_feat = SqueezeNetHPE(num_keypoints=17).to(device)
student_model_feat.load_state_dict(torch.load('checkpoints/best_full.pth', map_location=device))
evaluator.count_trainable_parameters(student_model_feat, "Student (Feature-KD)")
evaluator.add_predictions_to_fiftyone(val_data, student_model_feat, field_name="student_feat_kd")

# --- 4. Launch Visualization ---
session = fo.launch_app(val_data)