# Knowledge Distillation for 2D Human Pose Estimation

## What is Knowledge Distillation?
Knowledge Distillation (KD) is a technique in deep learning where a small, compact model (the **Student**) is trained to reproduce the behavior of a large, complex model (the **Teacher**), or an ensemble of models. The core idea was popularized by Hinton et al. in ["Distilling the Knowledge in a Neural Network" (2015)](https://arxiv.org/abs/1503.02531).

In Human Pose Estimation (HPE), the Teacher is typically a high-performance model (like HRNet-W48) that is accurate but computationally expensive. The Student is a lightweight model (like SqueezeNet or MobileNet) designed for real-time inference on edge devices.

## Why use KD?
1. **Model Compression**: Reduce model size and inference latency.
2. **Accuracy Boost**: The Student often learns better from the Teacher's "soft targets" (probability distributions or heatmaps) than from one-hot ground truth labels alone, as the Teacher provides structural information about the data.

## Types of Knowledge Distillation covered here:
1. **Response-based (Logits) KD**: The Student mimics the final output (heatmaps) of the Teacher.
2. **Feature-based KD**: The Student mimics the intermediate feature maps of the Teacher, learning to extract similar spatial features.

---


# Introduction.
What is KD, importance, why we need to learn from it, why is expect of the members to get.
Requirements needed etc.

#### Student model

We design a lightweight model with <a href="https://arxiv.org/pdf/1602.07360">SquuezeNet backbone</a> pretrained to extarct features from the an images. WSo we just remove tthe classification tail of the model and add a custom backbone for generating heatmaps.

**SqueezeNet** is designed for efficiency. It achieves AlexNet-level accuracy with 50x fewer parameters. Its core building block is the **Fire Module**, which 'squeezes' the feature map depth with 1x1 filters before 'expanding' it with a mix of 1x1 and 3x3 filters.

**Why use it as a Student?**
- **Lightweight**: Ideal for deployment on mobile or embedded devices.
- **Speed**: Fast inference times.
- **Capacity gap**: It has significantly less capacity than HRNet, making it a perfect candidate to test the effectiveness of Knowledge Distillation.



In [1]:
import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
from collections import namedtuple

KeypointOutput = namedtuple('KeypointOutput', ['heatmaps'])
class SqueezeNetHPE(nn.Module):
    def __init__(self, num_keypoints=17):
        super().__init__()
        squeezenet = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1)
        self.backbone = squeezenet.features
        self.relu = nn.ReLU(inplace=True)
        self.conv_heatmap = nn.Conv2d(
            in_channels=512,
            out_channels=num_keypoints,
            kernel_size=3,
            padding=1
        )

    def forward(self, x):
        x = self.backbone(x)
        x = self.relu(x)
        x = F.interpolate(
            x,
            size=(96, 72), 
            mode="bilinear",
            align_corners=False
        )
        heatmaps = self.conv_heatmap(x)

        return KeypointOutput(heatmaps=heatmaps)



The code above modifies the standard SqueezeNet classifier for the task of Pose Estimation (Heatmap Regression):

1.  **Backbone (`self.backbone`)**: We use the feature extractor from `squeezenet1_1` pretrained on ImageNet. This provides rich, low-level visual features.
2.  **Upsampling**: SqueezeNet significantly downsamples the input (typically by 32x). To generate high-quality heatmaps (which require spatial precision), we use `F.interpolate` to upsample the features to **96x72**.
3.  **Heatmap Head (`self.conv_heatmap`)**: Instead of a Fully Connected layer for classification, we use a final **Convolutional Layer** (kernel size 3x3) to produce **17 output channels**, each corresponding to a keypoint heatmap.

This simple "Backbone + Upsample + 1x1 Conv" structure is extremely lightweight/fast but often lacks the spatial refinement of HRNet, making it an ideal candidate for Knowledge Distillation.

#### Teacher model


In this tutprial we will use <a href="https://github.com/HRNet/HRNet-Human-Pose-Estimation?tab=readme-ov-file"> HRNet-pose as a teacher model </a>.

**HRNet** (High-Resolution Net) maintains high-resolution representations throughout the network. Unlike traditional architectures (ResNet, VGG) that downsample the image and then upsample, HRNet connects high-to-low resolution subnetworks in parallel.

**Key Features:**
- **Parallel Resolutions**: Maintains high-res stream from start to finish.
- **Multi-Scale Fusion**: Repeatedly exchanges information across resolutions.

**Role as Teacher:**
HRNet is a state-of-the-art CNN based model for 2D Pose Estimation. It captures fine spatial details (crucial for keypoints) which we want our Student to mimic.

In [2]:
from hrnet_pose.hrnetpose_model import get_pose_net
from hrnet_pose.configs import load_configs

cfg = load_configs('hrnet_pose/hrnet_w48_model_configs.yaml')
teacher_model = get_pose_net(cfg, is_train=False)
checkpoint = torch.load('hrnet_pose/hrnet_pose_models/pose_hrnet_w48_384x288.pth', map_location='cpu')
teacher_model.load_state_dict(checkpoint)


  checkpoint = torch.load('hrnet_pose/hrnet_pose_models/pose_hrnet_w48_384x288.pth', map_location='cpu')


<All keys matched successfully>

### Dataset

#### COCO Keypoint Dataset
The COCO (Common Objects in Context) dataset is a large-scale object detection, segmentation, and captioning dataset. For Human Pose Estimation, we use the **Keypoints** task, which annotates **17 keypoints** (joints) on human bodies.

The 17 Keypoints are:
0: Nose, 1: Left Eye, 2: Right Eye, 3: Left Ear, 4: Right Ear, 5: Left Shoulder, 6: Right Shoulder, 7: Left Elbow, 8: Right Elbow, 9: Left Wrist, 10: Right Wrist, 11: Left Hip, 12: Right Hip, 13: Left Knee, 14: Right Knee, 15: Left Ankle, 16: Right Ankle.

##### COCO Skeleton Visualization
The keypoints are connected to form a skeleton structure, helping in visualizing pose.

![COCO Skeleton](images/keypoints-skeleton.png) ![COCO Examples](images/keypoints-examples.png)

*Example of COCO Keypoint Annotations (Source: COCO Dataset)*

Access the <a href="https://cocodataset.org/#keypoints-2017"> COCO 2017 keypoint dataset here</a>.

#### 1. Accessing using fiftyone

#### What is FiftyOne?
[FiftyOne](https://voxel51.com/fiftyone/) is an open-source toolset by Voxel51 designed for building high-quality datasets and computer vision models. It acts as a visual interface for your datasets, allowing you to:

- **Visualize** raw images and their annotations (bounding boxes, keypoints, etc.) instantly.
- **Query** specific samples (e.g., "show me all images with > 10 people").
- **Debug** model predictions by overlaying them on ground truth.
- **Manage** data splits and export to various formats (COCO, YOLO, TFRecord).

In this notebook, we use FiftyOne to easily download specific subsets of COCO and visualize the ground truth keypoints interactively.


In [25]:
import fiftyone as fo
import fiftyone.zoo as foz

In [None]:
coco2017_train = foz.load_zoo_dataset("coco-2017", split='train',max_samples=30000,  label_types=["person_keypoints"])

In [None]:
coco2017_val = foz.load_zoo_dataset("coco-2017", split='validation')

In [None]:
coco2017_val

In [None]:
session = fo.launch_app(coco2017_train, auto=False)


#### 2. Direct download

In [None]:
# Path to store the data
!mkdir -p /capstor/scratch/cscs/ckuya/coco_data /capstor/scratch/cscs/ckuya/coco_data/annotations 

!wget -c http://images.cocodataset.org/annotations/annotations_trainval2017.zip -P /capstor/scratch/cscs/ckuya/coco_data/
!wget -c http://images.cocodataset.org/zips/val2017.zip -P /capstor/scratch/cscs/ckuya/coco_data/
!wget -c http://images.cocodataset.org/zips/train2017.zip -P /capstor/scratch/cscs/ckuya/coco_data/


In [None]:
!unzip  /capstor/scratch/cscs/ckuya/coco_data/annotations_trainval2017.zip -d /capstor/scratch/cscs/ckuya/coco_data
!unzip -q /capstor/scratch/cscs/ckuya/coco_data/val2017.zip -d /capstor/scratch/cscs/ckuya/coco_data/
!unzip -q /capstor/scratch/cscs/ckuya/coco_data/train2017.zip -d /capstor/scratch/cscs/ckuya/coco_data/

# Cleanup zip files
!rm -rf /capstor/scratch/cscs/ckuya/coco_data/*.zip

### Data Prepocessing

1.  **Filtering**: We first check if an image actually contains a person with valid annotations. We filter out images where the number of annotated keypoints is zero.
2.  **Cropping**: Once we identify a valid image, we locate the person using their bounding box and **crop** the image to center on them. This removes background noise and focuses the model's attention.
3.  **Coordinate Transformation**: We transform the original keypoint annotations from the full image space into our new **cropped coordinate space**. This ensures the labels match the input image we feed to the network.
4.  **Ground Truth Generation**: Finally, we take these transformed keypoints and generate **Gaussian Heatmaps**, which serve as the training targets for our model.

In [13]:
import numpy as np
from typing import List, Tuple, Union
import cv2

def extract_keypoints_and_visibility(keypoints: List[float]):
    try:
        kps_array = np.array(keypoints, dtype=np.float32).reshape(-1, 3)
        coords = kps_array[:,:2]
        visibility = kps_array[:,2]

        return coords, visibility.astype(np.int32)

    except Exception as e:
        print(f"Error extracting keypoints and visibility: {e}")
        raise e


def process_image_and_keypoints(
    image: np.ndarray,
    keypoints: Union[List[List[float]], np.ndarray],
    bbox: Union[List[int], Tuple[int, int, int, int]],
    target_size: Tuple[int, int],
    angle: float = 0,
    flip: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:

    try:

        x1, y1, w, h = np.round(bbox).astype(int)
        img_h, img_w = image.shape[:2]

        cx1, cy1 = max(0, x1), max(0, y1)
        cx2, cy2 = min(img_w, x1 + w), min(img_h, y1 + h)

        # Step 1: Crop
        image = image[cy1:cy2, cx1:cx2]
        crop_h, crop_w = image.shape[:2]

        # Step 2: Prepare Keypoints (N, 2)
        kps = np.ascontiguousarray(keypoints, dtype=np.float32).reshape(-1, 2)

        zero_mask = np.all(kps == 0, axis=1)
        kps -= [cx1, cy1]

        # Step 3:Masking
        invalid_mask = (kps[:, 0] < 0) | (kps[:, 1] < 0) | (kps[:, 0] >= crop_w) | (kps[:, 1] >= crop_h)

        # Step 4: Resize (Image and KPs together)
        image = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)
        kps *= [target_size[0] / crop_w, target_size[1] / crop_h]

        # Step 5: Rotation 
        if angle != 0:
            center = (target_size[0] / 2, target_size[1] / 2)
            M = cv2.getRotationMatrix2D(center, angle, 1.0)
            image = cv2.warpAffine(image, M, target_size)
            kps = kps @ M[:, :2].T + M[:, 2]

        # Step 6: Horizontal Flip
        if flip:
            image = cv2.flip(image, 1)
            kps[:, 0] = target_size[0] - kps[:, 0]

        # Step 7: Final Restoration
        kps[zero_mask | invalid_mask] = 0

        return image, kps
    except Exception as e:
        print(f"Error in process_image_and_keypoints: {e}")
        raise e




#### Heatmaps

Nowadays, **Gaussian Heatmaps** are the standard representation for 2D Pose Estimation. Instead of regressing exact (x, y) coordinates directly (which is hard to learn), the model predicts a probability map where the peak represents the keypoint location.

**Generation Formula:**
For a keypoint $k$ at $(x_k, y_k)$, the value at pixel $(i, j)$ is:

$$H_k(i, j) = \exp\left( -\frac{(i - x_k)^2 + (j - y_k)^2}{2\sigma^2} \right)$$

Where $\sigma$ controls the spread of the Gaussian peak.

**Masks (Visibility):**
We also associate a binary **Mask** with each keypoint. This mask is **1** if the keypoint is visible and annotated, and **0** otherwise. During training, we multiply the loss by this mask to essentially 'switch off' the learning for missing or invisible joints, preventing the model from being confused by unlabelled data.


In [4]:
def generate_heatmaps(
    keypoints: np.ndarray,
    keypoints_visible: np.ndarray,
    input_size: Tuple[int, int],
    heatmap_size: Tuple[int, int] = (72, 96),
    sigma: float = 2.0) -> Tuple[torch.Tensor, torch.Tensor]:

    W, H = heatmap_size
    w, h = input_size
    K = keypoints.shape[0]

    scale = np.array([(W - 1) / (w - 1), (H - 1) / (h - 1)], dtype=np.float32)
    kps_hm = keypoints * scale  # Shape (K, 2)

    mask = (keypoints_visible >= 0.5) & \
           (kps_hm[:, 0] >= 0) & (kps_hm[:, 0] < W) & \
           (kps_hm[:, 1] >= 0) & (kps_hm[:, 1] < H)
    mask = mask.astype(np.float32)

    yy, xx = np.meshgrid(np.arange(H), np.arange(W), indexing='ij')

    mu_x = kps_hm[:, 0, np.newaxis, np.newaxis]
    mu_y = kps_hm[:, 1, np.newaxis, np.newaxis]

    dist_sq = (xx - mu_x) ** 2 + (yy - mu_y) ** 2

    heatmaps = np.exp(-dist_sq / (2 * sigma**2))

    heatmaps *= mask[:, np.newaxis, np.newaxis]

    return torch.from_numpy(heatmaps), torch.from_numpy(mask)

#### Dataloader

In [5]:
import os
import cv2
import math
import torch
import logging
import torch.nn as nn
import torch.optim as optim
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from pycocotools.coco import COCO
from tqdm.auto import tqdm
from typing import Optional, Callable, List, Tuple

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

def get_downloaded_image_ids(root_dir: str, coco: COCO) -> List[int]:
    on_disk = set(os.listdir(root_dir))
    all_img_info = coco.loadImgs(coco.getImgIds())
    existing_ids = [info['id'] for info in all_img_info if info['file_name'] in on_disk]
    return existing_ids

class CocoKeypoints(Dataset):
    def __init__(self, root: str, annFile: str, target_size=(288, 384), heatmap_size=(72, 96)) -> None:
        self.root = root
        self.coco = COCO(annFile)
        self.target_size = target_size
        self.heatmap_size = heatmap_size

        img_ids = get_downloaded_image_ids(self.root, self.coco)
        self.valid_anns = []

        for img_id in img_ids:
            ann_ids = self.coco.getAnnIds(imgIds=img_id)
            anns = self.coco.loadAnns(ann_ids)
            for ann in anns:
                if ann.get("num_keypoints", 0) > 0:
                    self.valid_anns.append((img_id, ann))

        logger.info(f"Initialized {len(self.valid_anns)} person samples.")

    def __len__(self) -> int:
        return len(self.valid_anns)

    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        img_id, target = self.valid_anns[index]
        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.root, img_info["file_name"])
        
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        coords, visibility = extract_keypoints_and_visibility(target["keypoints"])
        coords_tensors = torch.from_numpy(coords)

        processed_img_np, processed_kps = process_image_and_keypoints(
            image, coords, target["bbox"], self.target_size
        )

        hm_scale = [(self.heatmap_size[0]-1)/(self.target_size[0]-1),
                    (self.heatmap_size[1]-1)/(self.target_size[1]-1)]
        kps_hm = processed_kps * hm_scale

        heatmaps, masks = generate_heatmaps(kps_hm, visibility, self.heatmap_size)
        img_tensor = torch.from_numpy(processed_img_np).permute(2, 0, 1).float() / 255.0

        return img_tensor, heatmaps, masks, coords_tensors

def collate_fn(batch):
    imgs, hms, masks, coords = zip(*batch)
    return torch.stack(imgs), torch.stack(hms), torch.stack(masks), torch.stack(coords)

def get_coco_dataloaders(root_train, ann_train, root_val, ann_val, batch_size=128, num_workers=8):
    train_dataset = CocoKeypoints(root_train, ann_train)
    val_dataset = CocoKeypoints(root_val, ann_val)
    
    loader_args = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=collate_fn,
        persistent_workers=True
    )

    return (
        DataLoader(train_dataset, shuffle=True, **loader_args),
        DataLoader(val_dataset, shuffle=False, **loader_args)
    )


In [6]:
root_train = "/capstor/scratch/cscs/ckuya/coco_data/train2017"
root_val = "/capstor/scratch/cscs/ckuya/coco_data/val2017"
ann_train = "/capstor/scratch/cscs/ckuya/coco_data/annotations/person_keypoints_train2017.json"
ann_val = "/capstor/scratch/cscs/ckuya/coco_data/annotations/person_keypoints_val2017.json"

train_loader, val_loader = get_coco_dataloaders(
    root_train=root_train,
    ann_train=ann_train,
    root_val=root_val,
    ann_val=ann_val,
    batch_size=128,
    num_workers=8
)

loading annotations into memory...
Done (t=5.10s)
creating index...
index created!


2026-02-01 07:39:00,339 - INFO - Initialized 149813 person samples.
2026-02-01 07:39:00,469 - INFO - Initialized 6352 person samples.


loading annotations into memory...
Done (t=0.11s)
creating index...
index created!


In [7]:
def validate(
    student_model: nn.Module,
    val_loader: torch.utils.data.DataLoader,
    criterion: nn.Module,
) -> float:
    """
    Validate the model against ground truth heatmaps only.
    """
    student_model.eval()
    total_val_loss = 0.0

    with torch.no_grad():
        for imgs, gt_heatmaps, _, _ in tqdm(val_loader, desc="Validation process"):
            imgs = imgs.to(device, non_blocking=True)
            heatmaps = gt_heatmaps.to(device, non_blocking=True)

            with torch.amp.autocast(device_type=device.type):
                outputs = student_model(imgs)
                student_heatmaps = outputs.heatmaps

                loss = criterion(student_heatmaps, heatmaps)

            total_val_loss += loss.item()

    return total_val_loss / len(val_loader)


In [None]:
count_parameters_and_flops(student_model)

In [None]:
import heatmaps_to_keypoints

In [None]:
import heatmaps_to_keypoints
results = heatmaps_to_keypoint.decode(gt_ht)

In [None]:
out_s = student_model(imgb.to(device))

In [None]:
out_s.heatmaps.shape

In [None]:
teacher_model.to(device).eval()
out_t = teacher_model(imgb.to(device))

### Training and distillation

#### 1. Baseline Student Training

We start by training the **Student Model (SqueezeNet)** purely on the Ground Truth labels, **without any distillation**.

This establishes a **baseline performance**. It represents how well the lightweight model can learn the task on its own. Later, we will compare our distilled models against this baseline to measure the improvement gained from the Teacher's guidance.


In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
student_model = SqueezeNetHPE(num_keypoints=17).to(device)

In [9]:
class AdaptiveWingLoss(nn.Module):
    def __init__(
        self,
        omega: float = 14.0,
        theta: float = 0.5,
        epsilon: float = 1.0,
        alpha: float = 2.1,
        weight: float = 10.0,
    ):
        super().__init__()
        self.omega = omega
        self.theta = theta
        self.epsilon = epsilon
        self.alpha = alpha
        self.weight = weight
        self.inv_epsilon = 1.0 / epsilon

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        delta = (pred - target).abs()
        y_pow = self.alpha - target

        theta_eps_pow = torch.pow(self.theta * self.inv_epsilon, y_pow)

        A = (self.omega * (1.0 / (1.0 + theta_eps_pow))
             * y_pow * torch.pow(self.theta * self.inv_epsilon, y_pow - 1.0)
             * self.inv_epsilon)

        C = self.theta * A - self.omega * torch.log1p(theta_eps_pow)

        losses = torch.where(
            delta < self.theta,
            self.omega * torch.log1p(torch.pow(delta * self.inv_epsilon, y_pow)),
            A * delta - C
        )
        with torch.no_grad():
            Hd = F.max_pool2d(target, kernel_size=3, stride=1, padding=1)
            importance_mask = (Hd >= 0.2).to(pred.dtype)

        weighted_losses = losses * (self.weight * importance_mask + 1.0)

        return weighted_losses.mean()

In [10]:
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from tqdm.auto import tqdm
import math


checkpoint_dir = 'checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

criterion = AdaptiveWingLoss().to(device)
optimizer = optim.AdamW(student_model.parameters(), lr=1e-3, weight_decay=1e-4)

num_epochs = 3
warmup_epochs = 1
best_acc = 0.0

lambda_lr = lambda epoch: (epoch + 1) / warmup_epochs if epoch < warmup_epochs else 0.5 * (1 + math.cos(math.pi * (epoch - warmup_epochs) / (num_epochs - warmup_epochs)))
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_lr)

scaler = GradScaler()

print(f"Starting Baseline Training on {device}...")

for epoch in range(num_epochs):
    student_model.train()
    running_loss = 0.0

    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=True)
    
    for images, targets, masks, _ in pbar:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)
        masks_expanded = masks.unsqueeze(-1).unsqueeze(-1)
        
        optimizer.zero_grad()
        
        # Mixed Precision Context
        with autocast():
            outputs = student_model(images)
            preds = outputs.heatmaps
            
            loss = criterion(preds * masks_expanded, targets * masks_expanded)
        
        # Backward pass with scaler
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        running_loss += loss.item()
        pbar.set_postfix({'loss': running_loss / (pbar.n + 1), 'lr': optimizer.param_groups[0]['lr']})
    
    # Step Scheduler
    scheduler.step()
    
    # Evaluation & Checkpointing
    if (epoch + 1) % 5 == 0 or epoch == num_epochs - 1:
        val_acc = validate(student_model, val_loader, criterion)
        print(f"Validation Accuracy (Epoch {epoch+1}): {val_acc*100:.2f}%")
        
        # Save Periodic Checkpoint
        torch.save(student_model.state_dict(), os.path.join(checkpoint_dir, f'student_epoch_{epoch+1}.pth'))
        print(f"Saved checkpoint for epoch {epoch+1}")
        
        # Save Best Model
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(student_model.state_dict(), os.path.join(checkpoint_dir, 'student_best.pth'))
            print(f"New Best Model Saved! (Acc: {best_acc*100:.2f}%)")
    
print("Baseline Training Completed.")

  from .autonotebook import tqdm as notebook_tqdm
  scaler = GradScaler()


Starting Baseline Training on cuda...


  with autocast():
Epoch 1/3: 100%|██████████| 1171/1171 [01:37<00:00, 12.02it/s, loss=0.221, lr=0.001]
Epoch 2/3: 100%|██████████| 1171/1171 [01:31<00:00, 12.75it/s, loss=0.162, lr=0.001]
Epoch 3/3: 100%|██████████| 1171/1171 [01:31<00:00, 12.78it/s, loss=0.162, lr=0.0005]
Validation process: 100%|██████████| 50/50 [00:05<00:00,  8.47it/s]

Validation Accuracy (Epoch 3): 16.21%
Saved checkpoint for epoch 3
New Best Model Saved! (Acc: 16.21%)
Baseline Training Completed.





#### Evaluate Baseline Student
We evaluate the trained student model in two ways:
1.  **Qualitative Accuracy (Visual)**: Using FiftyOne to visualize model predictions vs Ground Truth.
2.  **Model Complexity**: Using `calflops` to measure FLOPs and Parameters.

In [None]:
def add_predictions_to_fiftyone(dataset, model, num_samples=10):
    model.eval()
    print("Generating predictions for FiftyOne...")
    view = dataset.take(num_samples)
    
    with torch.no_grad():
        for sample in view.iter_samples(autosave=True):
            # Load and Preprocess image (simplified for demo)
            # Note: In a real pipeline, we'd use the exact Transform pipeline.
            # Here we just assume we can load it.
            img = cv2.imread(sample.filepath)
            if img is None:
                continue
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)a            h, w, _ = img.shape
            
            # Resize to Model Input (288x384)
            img_resized = cv2.resize(img, (288, 384))
            input_tensor = torch.from_numpy(img_resized).permute(2,0,1).float() / 255.0
            input_tensor = input_tensor.unsqueeze(0).to(device)
            
            # Inference
            output = model(input_tensor)
            heatmaps = output.heatmaps
            
            # Decode
            preds, _ = get_max_preds(heatmaps.cpu())
            preds = preds[0] # First batch item
            
            # Scale back to Original Image Dimensions
            # Heatmap size is 72x96
            # scale_x = w / 72.0
            # scale_y = h / 96.0
            
            points = []
            for node_i in range(17):
                 x, y, conf = preds[node_i]
                 # Predictions are in Heatmap coords
                 # FiftyOne expects Normalized coords (0-1)
                 norm_x = (x / 72.0)
                 norm_y = (y / 96.0)
                 
                 # Clip to 0-1 to avoid plotting errors
                 norm_x = max(0, min(1, norm_x))
                 norm_y = max(0, min(1, norm_y))
                 
                 points.append((norm_x, norm_y))
            
            # Add to Sample
            sample['predictions'] = fo.Keypoints(keypoints=[fo.Keypoint(points=points)])
    
    return view

# Visualizing Predictions
val_view = add_predictions_to_fiftyone(coco2017_val, student_model, num_samples=20)
session = fo.launch_app(val_view)

In [None]:
from calflops import calculate_flops

input_shape = (1, 3, 384, 288)
flops, macs, params = calculate_flops(model=student_model, input_shape=input_shape)

print(f"Student Model Complexity:")
print(f" - FLOPs: {flops}")
print(f" - MACs:  {macs}")
print(f" - Params: {params}")


# 2. Logits-based Knowledge Distillation

## Theory: Hinton's KD with Temperature
In the seminal work by Hinton et al., Knowledge Distillation softens the logits of the Teacher to provide more information about the class distribution.

The distillation loss is defined as the Kullback-Leibler (KL) Divergence between the softened student logits $z_s$ and teacher logits $z_t$, scaled by temperature $T^2$:

$$ L_{KD} = T^2 \cdot KL\left( \sigma\left(\frac{z_s}{T}\right), \sigma\left(\frac{z_t}{T}\right) \right) $$

Where $\sigma$ is the softmax function.

**In Human Pose Estimation (Regression/Heatmaps):**
Typically, MSE is used for heatmaps. However, we can treat the heatmaps as **Spatial Probability Distributions** by flattening the spatial dimensions ($H \times W$) and applying Softmax. This allows us to use Hinton's formulation directly by softening the peaky heatmap distributions.

**Loss Function:**
$$ L_{total} = L_{GT} + \alpha \cdot T^2 \cdot KL(\text{SpatialSoftmax}(H_S/T), \text{SpatialSoftmax}(H_T/T)) $$


### Understanding Temperature ($T$) in Logits Loss

You might wonder: **Why do we divide logits by a Temperature $T$?**

The standard Softmax function is defined as:
$$ \sigma(z_i) = \frac{\exp(z_i)}{\sum_j \exp(z_j)} $$

When a model is very confident, one value in inputs $z$ is much larger than the others. This makes the output probability distribution **extremely sharp** (almost 1.0 for the peak, and 0.0 for everything else).

**The Problem:** If the Teacher predicts a perfect 1.0 for the correct keypoint and 0.0 everywhere else, it provides **no more information** than the Ground Truth labels! We lose the rich structural knowledge (e.g., "this pixel is 0.001 likely, but that pixel is 0.05 likely").

**The Solution (Temperature):**
By introducing $T > 1$, we "soften" the distribution:
$$ q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)} $$

- **High T (> 1)**: Flattens the peaks. The Teacher's output becomes softer, spreading probability mass to neighboring pixels. This reveals the **"Dark Knowledge"**—the relationships between the peak and its surrounding area.
- **Low T (1)**: Standard sharp Softmax.

**In our Code:**
We implement this in `spatial_kl_loss`. We take the raw heatmap logits (before softmax), divide them by `temp`, and *then* apply Softmax. This forces the Student to learn not just the peak location, but the entire spatial uncertainty shape of the Teacher.


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os, math
from tqdm import tqdm

teacher_model = teacher_model.to(device).eval()
student_model_logits = SqueezeNetHPE(num_keypoints=17).to(device)

optimizer = optim.AdamW(student_model_logits.parameters(), lr=1e-3, weight_decay=1e-4)
lambda_lr = lambda epoch: (epoch + 1) / 5 if epoch < 5 else 0.5 * (1 + math.cos(math.pi * (epoch - 5) / 25))
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_lr)
scaler = torch.amp.GradScaler('cuda')
criterion = AdaptiveWingLoss().to(device)
def spatial_kl_loss(s_logits, t_logits, temp=4.0):
    B, K, H, W = s_logits.shape
    s_prob = F.log_softmax(s_logits.view(B, K, -1) / temp, dim=-1)
    t_prob = F.softmax(t_logits.view(B, K, -1) / temp, dim=-1)
    return F.kl_div(s_prob, t_prob, reduction='batchmean') * (temp**2)

best_val_loss = float('inf')
checkpoint_dir_kd = 'checkpoints_logits_kd'
os.makedirs(checkpoint_dir_kd, exist_ok=True)

for epoch in range(30):
    student_model_logits.train()
    running_loss = 0.0
    pbar = tqdm(train_loader, desc=f'KD Epoch {epoch+1}/30')
    
    for images, targets, masks, _ in pbar:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        masks_expanded = masks.to(device, non_blocking=True).unsqueeze(-1).unsqueeze(-1)
        
        optimizer.zero_grad(set_to_none=True)
        
        with torch.amp.autocast('cuda'):
            with torch.no_grad():
                t_heatmaps = teacher_model(images)
            
            s_out = student_model_logits(images)
            s_heatmaps = s_out.heatmaps
            
            loss_gt = criterion(s_heatmaps * masks_expanded, targets * masks_expanded)
            loss_kd = spatial_kl_loss(s_heatmaps, t_heatmaps, temp=4.0)
            loss = loss_gt + (1.0 * loss_kd)
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        running_loss += loss.item()
        pbar.set_postfix({'loss': f"{loss.item():.4f}"})
    
    scheduler.step()
    
    # Updated validation call to match your specific function signature
    val_loss = validate(student_model_logits, val_loader, criterion)
    print(f"Epoch {epoch+1} | Train Loss: {running_loss/len(train_loader):.6f} | Val Loss: {val_loss:.6f}")
    
    # Save if validation loss improves (decreases)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(student_model_logits.state_dict(), os.path.join(checkpoint_dir_kd, 'best_model.pth'))
        print(f"-> Best model saved (Val Loss: {val_loss:.6f})")

print(f"Training Complete. Best Val Loss: {best_val_loss:.6f}")

### Evaluate Logits KD Model


We now evaluate the performance of the Student model trained with Logits-based Knowledge Distillation.
Compare this accuracy with the Baseline Student to see the improvement.


In [None]:
print("Evaluating Logits-KD Student...")
acc_kd = validate(None, val_loader, coco2017_val, student_model_logits)
print(f'Logits KD Student Accuracy (Approx PCK): {acc_kd*100:.2f}%')


In [None]:
# launch a new session or update existing
# Note: You might need to close the previous session manually in the UI if it conflicts,
# but fiftyone usually handles multiple sessions fine or we can reuse `session`.

print("Visualizing Logits-KD predictions...")
val_view_kd = add_predictions_to_fiftyone(coco2017_val, student_model_logits, num_samples=20)
session_kd = fo.launch_app(val_view_kd)


# 3. Feature-based Knowledge Distillation

## Theory
Feature-based KD encourages the Student to learn intermediate representations that resemble the Teacher's. Since Student and Teacher features often have different dimensions (channels/resolution), we typically use a **Connector** (e.g., 1x1 Conv) to map Student features to the Teacher's space.

We will use **Forward Hooks** to extract intermediate feature maps.

**Loss Function:**
$$ L_{total} = L_{GT} + \beta L_{Feat}(F_{Adaptor}(F_S), F_T) $$


In [14]:
# Feature Extraction using Hooks

teacher_features = {}
student_features = {}

def get_activation(name, storage_dict):
    def hook(model, input, output):
        storage_dict[name] = output
    return hook

# Register Hooks
# We need to know the layer names. For SqueezeNet/HRNet, we pick a mid-level layer.
# Example: 'features.12' for SqueezeNet (Fire8) and a corresponding stage for HRNet.

# Note: You must check your specific model structure (print(model)) to choose layers.
# Here, we assume generic names for demonstration.

# teacher_model.layerX.register_forward_hook(get_activation('mid_layer', teacher_features))
# student_model.features[12].register_forward_hook(get_activation('mid_layer', student_features))

print("Hooks defined (Need to register on specific layers depending on model struct)")


Hooks defined (Need to register on specific layers depending on model struct)


In [15]:
# Hooks to capture features
teacher_features = {}
student_features = {}

def get_activation(name, storage):
    def hook(model, input, output):
        storage[name] = output
    return hook

# Register hooks
# Inspect models to find good layers. For SqueezeNet, maybe 'features.12' (Fire module).
# For HRNet, standard layers.
# Here we assume we looked at the architecture. Let's pick a mid-level feature.

# Checking SqueezeNet structure
# print(student_model)
# SqueezeNet: features is a Sequential. 'features.10' is a Fire module.

handle_t = teacher_model.layer1.register_forward_hook(get_activation('feat', teacher_features)) if hasattr(teacher_model, 'layer1') else None
handle_s = student_model.backbone[10].register_forward_hook(get_activation('feat', student_features))

if handle_t is None:
    # Fallback if layer name differs (common in different HRNet impls)
    # Using a dummy forward to find a valid layer if needed, but for now we trust `layer1` or similar exists
    # If HRNet structure is complex, we might skip or assume a specific index.
    pass


Note: We need a **Feature Adaptor** to match dimensions. Let's assume we capture features and define an adaptor dynamically or pre-set it.


In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import os, math
from tqdm.auto import tqdm

# Ensure models are on the correct device and teacher is in eval mode
teacher_model = teacher_model.to(device).eval()
student_model_feat = SqueezeNetHPE(num_keypoints=17).to(device)

# FeatureAdapter to match student channels to teacher channels
# Adjust in_channels/out_channels based on your specific model layers
adapter = nn.Conv2d(in_channels=512, out_channels=48, kernel_size=1).to(device)

# Optimization: Combine parameters for the optimizer
optimizer = optim.AdamW(list(student_model_feat.parameters()) + list(adapter.parameters()), lr=1e-3)
lambda_lr = lambda epoch: (epoch + 1) / 5 if epoch < 5 else 0.5 * (1 + math.cos(math.pi * (epoch - 5) / 5))
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_lr)
criterion = AdaptiveWingLoss().to(device)
# Modern AMP setup
scaler = torch.amp.GradScaler('cuda')
mse_loss = nn.MSELoss()

best_val_loss = float('inf')
checkpoint_dir_feat = 'checkpoints_feat_kd'
os.makedirs(checkpoint_dir_feat, exist_ok=True)

print("--- Starting Feature-based KD Training ---")

for epoch in range(10):
    student_model_feat.train()
    adapter.train()
    
    running_loss = 0.0
    pbar = tqdm(train_loader, desc=f'Feature KD Epoch {epoch+1}/10')
    
    for imgs, targets, masks, _ in pbar:
        imgs = imgs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        masks_expanded = masks.to(device, non_blocking=True).unsqueeze(-1).unsqueeze(-1)
        
        optimizer.zero_grad(set_to_none=True)
        
        with torch.amp.autocast('cuda'):
            # Teacher forward (No gradients)
            with torch.no_grad():
                t_out = teacher_model(imgs)
                t_activation = t_out # Adjust this if using hooks for intermediate layers
            
            # Student forward
            outputs = student_model_feat(imgs)
            s_heatmaps = outputs.heatmaps
            s_activation = s_heatmaps # Placeholder; adjust if using intermediate features
            
            # Adapt student features to teacher spatial/channel dimensions
            s_adapted = adapter(s_activation) if s_activation.shape != t_activation.shape else s_activation
            
            # 1. Ground Truth Loss
            loss_gt = criterion(s_heatmaps * masks_expanded, targets * masks_expanded)
            
            # 2. Feature-based MSE Loss
            loss_feat = mse_loss(s_adapted, t_activation)
            
            total_loss = loss_gt + (0.5 * loss_feat)
        
        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        running_loss += total_loss.item()
        pbar.set_postfix({'loss': f"{total_loss.item():.4f}"})
    
    scheduler.step()
    
    # Validation using your heatmap-level loss function
    val_loss = validate(student_model_feat, val_loader, criterion)
    print(f"Epoch {epoch+1} | Train Loss: {running_loss/len(train_loader):.6f} | Val Loss: {val_loss:.6f}")
    
    # Checkpoint based on validation loss improvement
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(student_model_feat.state_dict(), os.path.join(checkpoint_dir_feat, 'best_feat_model.pth'))
        print(f"-> Best model saved (Val Loss: {val_loss:.6f})")

print(f"Feature KD Training Completed. Best Val Loss: {best_val_loss:.6f}")

--- Starting Feature-based KD Training ---


Feature KD Epoch 1/10:  12%|█▏        | 145/1171 [00:29<03:25,  4.99it/s, loss=0.1763] 


KeyboardInterrupt: 

### Evaluate Feature KD Model


In [None]:
print("--- Final Results ---")
print(f"Baseline Student:     {best_acc*100:.2f}%")
print(f"Logits-KD Student:    {best_acc_kd*100:.2f}%")
print(f"Feature-KD Student:   {best_acc_feat*100:.2f}%")

# Teacher Eval (Reference)
acc_teacher = validate(None, val_loader, coco2017_val, teacher_model)
print(f"Teacher (HRNet):      {acc_teacher*100:.2f}%")


# Conclusion Results Comparison
Let's compare the performance of all three approaches.


In [None]:
print(f'Baseline: {acc*100:.2f}%')
print(f'Logits KD: {acc_logits*100:.2f}%')
print(f'Feature KD: {acc_feat*100:.2f}%')
