# Knowledge Distillation for EfficientNetV2 on CIFAR-100 (SOTA)

This notebook implements **State-of-the-Art** knowledge distillation with:

- **Decoupled Knowledge Distillation (DKD)** - Separates target and non-target class knowledge
- **CutMix / Mixup** - Advanced data augmentation: each training batch is augmented with either CutMix or Mixup, chosen at random (50/50)
- **Google Drive integration** for persistent checkpoint storage

---

## Setup Instructions

### Google Colab Setup:

1. **Enable GPU:** Runtime -> Change runtime type -> **GPU (T4)**
2. **Run Cell 1** to mount Google Drive
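3. **Training will auto-resume** from the most recent checkpoint if interrupted. Checkpoints are written every `CHECKPOINT_FREQUENCY` epochs (20 by default), so up to that many epochs may be repeated after a restart.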

### Expected Results:

- **Teacher (EfficientNetV2-L):** ~68-70% accuracy
- **Distilled Student (DKD):** ~69-72% accuracy

---


In [None]:
# Cell 1: Setup and Imports
!pip install thop -q

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import efficientnet_v2_s, efficientnet_v2_l, EfficientNet_V2_S_Weights, EfficientNet_V2_L_Weights
from tqdm import tqdm
import time
import copy
import numpy as np
import matplotlib.pyplot as plt
import glob
import json
import random
from datetime import datetime
from google.colab import drive

# Mount Google Drive so that models, data and checkpoints persist across Colab sessions.
drive.mount('/content/drive')

# Setup Directories (all artefacts live under one Drive folder)
DRIVE_ROOT = '/content/drive/MyDrive/KnowledgeDistillation_SOTA'
MODEL_DIR = f'{DRIVE_ROOT}/models'
DATA_DIR = f'{DRIVE_ROOT}/data'
CHECKPOINT_DIR = f'{DRIVE_ROOT}/checkpoints'

for _sub_dir in (MODEL_DIR, DATA_DIR, CHECKPOINT_DIR):
    os.makedirs(_sub_dir, exist_ok=True)  # no-op when the folder already exists

print(f"Directories ready at: {DRIVE_ROOT}")

# Reproducibility: seed torch, NumPy and the stdlib RNG with one shared value.
SEED = 42
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(SEED)

_cuda_available = torch.cuda.is_available()
if _cuda_available:
    torch.cuda.manual_seed(SEED)
    # NOTE: cudnn.benchmark auto-tunes convolution algorithms for speed; this can make
    # GPU results differ slightly between runs even with the seeds above.
    torch.backends.cudnn.benchmark = True

device = torch.device("cuda" if _cuda_available else "cpu")
print(f"Using device: {device}")

In [None]:
# Cell 2: Experiment Configuration (SOTA)
# ------------------------------------------
# Every tunable knob for the run lives here; later cells only read these names.

# --- Optimisation ---------------------------------------------------------
NUM_EPOCHS = 200            # Upper bound on epochs; early stopping may end a run sooner
BATCH_SIZE = 128
LEARNING_RATE = 0.001       # Initial AdamW LR, cosine-annealed over NUM_EPOCHS
WEIGHT_DECAY = 0.05         # Strong (decoupled) weight-decay regularisation
PATIENCE = 30               # Epochs without val-accuracy improvement before stopping

# --- Decoupled Knowledge Distillation -------------------------------------
DKD_ALPHA = 1.0             # Weight of the target-class term (TCKD)
DKD_BETA = 8.0              # Weight of the non-target-class term (NCKD)
TEMPERATURE = 4.0           # Softmax temperature shared by teacher and student

# --- Augmentation / bookkeeping -------------------------------------------
MIXUP_ALPHA = 0.8           # Beta(alpha, alpha) parameter for Mixup
CUTMIX_ALPHA = 1.0          # Beta(alpha, alpha) parameter for CutMix
CHECKPOINT_FREQUENCY = 20   # Write a resumable checkpoint every N epochs
NUM_CLASSES = 100           # CIFAR-100

_config_rule = '=' * 50
print(_config_rule)
print(f"CONFIG: Epochs={NUM_EPOCHS} | Batch={BATCH_SIZE} | Temp={TEMPERATURE}")
print(f"DKD: Alpha={DKD_ALPHA}, Beta={DKD_BETA}")
print(_config_rule)

In [None]:
# Cell 3: Data Loading
# Per-channel CIFAR-100 mean / std, shared by the train and test pipelines.
_CIFAR100_MEAN = (0.5071, 0.4867, 0.4408)
_CIFAR100_STD = (0.2675, 0.2565, 0.2761)

# Train-time: light geometric augmentation only -- CutMix/Mixup happen per batch
# inside the training loop, not here.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR100_MEAN, _CIFAR100_STD),
])

# Test-time: no augmentation, identical normalisation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR100_MEAN, _CIFAR100_STD),
])

# Loader settings common to both splits.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)

trainset = torchvision.datasets.CIFAR100(root=DATA_DIR, train=True, download=True, transform=transform_train)
# drop_last=True: every training step sees a full batch.
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, drop_last=True, **_loader_kwargs)

testset = torchvision.datasets.CIFAR100(root=DATA_DIR, train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **_loader_kwargs)

print(f"Data loaded: {len(trainset)} Training, {len(testset)} Test images")

In [None]:
# Cell 4: Helper Functions (DKD Loss, CutMix, Utils)

# ==========================================
# 1. Decoupled Knowledge Distillation (DKD)
# Paper: https://arxiv.org/abs/2203.08679
# ==========================================
def dkd_loss(student_logits, teacher_logits, target, alpha=1.0, beta=8.0, temp=4.0):
    """Return the Decoupled Knowledge Distillation loss ``alpha * TCKD + beta * NCKD``.

    Args:
        student_logits: Raw student outputs, shape (batch, num_classes).
        teacher_logits: Raw teacher outputs, same shape as ``student_logits``.
        target: Integer class index per sample (any shape with ``batch`` elements).
        alpha: Weight of the target-class term (TCKD).
        beta: Weight of the non-target-class term (NCKD).
        temp: Softmax temperature; both KL terms are rescaled by ``temp ** 2``.

    Returns:
        Scalar loss tensor (KL terms use ``batchmean`` reduction).
    """
    gt_mask = _get_gt_mask(student_logits, target)
    other_mask = _get_other_mask(student_logits, target)
    temp_sq = temp ** 2

    # --- TCKD: KL between the binary "target vs. everything else" distributions.
    probs_s = F.softmax(student_logits / temp, dim=1)
    probs_t = F.softmax(teacher_logits / temp, dim=1)
    binary_s = _cat_mask(probs_s, gt_mask, other_mask)
    binary_t = _cat_mask(probs_t, gt_mask, other_mask)
    # The epsilon guards log(0) when one of the two binary probabilities underflows.
    tckd = F.kl_div(torch.log(binary_s + 1e-8), binary_t, reduction='batchmean') * temp_sq

    # --- NCKD: KL over the non-target classes only. Subtracting 1000 from the target
    # logit drives its softmax probability to ~0, renormalising over the rest.
    suppress_target = 1000.0 * gt_mask
    nontarget_t = F.softmax(teacher_logits / temp - suppress_target, dim=1)
    log_nontarget_s = F.log_softmax(student_logits / temp - suppress_target, dim=1)
    nckd = F.kl_div(log_nontarget_s, nontarget_t, reduction='batchmean') * temp_sq

    return alpha * tckd + beta * nckd

def _get_gt_mask(logits, target):
    target = target.reshape(-1)
    mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool()
    return mask

def _get_other_mask(logits, target):
    target = target.reshape(-1)
    mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool()
    return mask

def _cat_mask(t, mask1, mask2):
    t1 = (t * mask1).sum(dim=1, keepdims=True)
    t2 = (t * mask2).sum(1, keepdims=True)
    rt = torch.cat([t1, t2], dim=1)
    return rt

# ==========================================
# 2. Augmentations: Mixup & CutMix
# ==========================================
def mixup_data(x, y, alpha=1.0):
    """Blend each sample with a randomly chosen partner from the same batch (Mixup).

    Args:
        x: Input batch; the first dimension indexes samples.
        y: Labels aligned with ``x`` along the first dimension (same device as ``x``).
        alpha: Beta(alpha, alpha) concentration for the mixing weight; ``alpha <= 0``
            disables mixing (``lam = 1.0``).

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` with ``mixed_x = lam * x + (1 - lam) * x[perm]``,
        ``y_a = y``, ``y_b = y[perm]`` and ``lam`` a Python float. Train with
        ``lam * loss(y_a) + (1 - lam) * loss(y_b)``.
    """
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1.0
    # Build the permutation on x's own device rather than the notebook-global
    # ``device``, so the helper works for any batch placement.
    index = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y, y[index], lam

def rand_bbox(size, lam):
    """Sample a random box whose sides are ``sqrt(1 - lam)`` of the spatial extents.

    ``size`` is a tensor shape whose entries 2 and 3 are the spatial extents. Returns
    ``(bbx1, bby1, bbx2, bby2)``: ``bbx*`` bound dimension 2 and ``bby*`` bound
    dimension 3. The box centre is uniform over the plane and the bounds are clipped
    to it, so boxes near the border come out smaller than requested.
    """
    extent2, extent3 = size[2], size[3]
    side_scale = np.sqrt(1. - lam)
    half2 = int(extent2 * side_scale) // 2
    half3 = int(extent3 * side_scale) // 2
    centre2 = np.random.randint(extent2)
    centre3 = np.random.randint(extent3)
    return (
        np.clip(centre2 - half2, 0, extent2),
        np.clip(centre3 - half3, 0, extent3),
        np.clip(centre2 + half2, 0, extent2),
        np.clip(centre3 + half3, 0, extent3),
    )

def cutmix_data(x, y, alpha=1.0):
    """Paste a random rectangle from a shuffled copy of the batch into ``x`` (CutMix).

    ``x`` is modified IN PLACE (pass a clone if the original batch is still needed);
    its dimensions 2 and 3 are treated as the spatial plane.

    Args:
        x: Input batch; the first dimension indexes samples.
        y: Labels aligned with ``x`` along the first dimension (same device as ``x``).
        alpha: Beta(alpha, alpha) concentration for the target mixing weight;
            ``alpha <= 0`` yields an empty box and ``lam = 1.0``.

    Returns:
        ``(x, y_a, y_b, lam)`` where ``y_a = y``, ``y_b = y[perm]`` and ``lam`` (a Python
        float) is recomputed from the clipped box as the fraction of each image that
        still comes from its own sample.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
    # Permutation lives on x's device instead of the notebook-global ``device``.
    index = torch.randperm(x.size(0), device=x.device)
    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    # The right-hand side uses advanced indexing and therefore materialises a copy,
    # so reading from and writing into ``x`` in one statement is safe.
    x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]
    # Border clipping can shrink the box: derive lam from the area actually pasted.
    pasted_area = float((bbx2 - bbx1) * (bby2 - bby1))
    lam = 1.0 - pasted_area / (x.size(-1) * x.size(-2))
    return x, y, y[index], lam

# ==========================================
# 3. Utilities (Save/Load/Evaluate)
# ==========================================
def evaluate_model_with_loss(model, dataloader, criterion, device=None):
    """Evaluate ``model`` and return ``(accuracy_percent, mean_loss)``.

    The loss is averaged per *sample*: each batch's loss is weighted by its size, so
    a short final batch no longer counts as much as a full one. This assumes
    ``criterion`` uses mean reduction, as ``nn.CrossEntropyLoss()`` does by default.
    The model is left in eval mode.

    Args:
        model: Classifier returning logits of shape (batch, num_classes).
        dataloader: Iterable of ``(images, labels)`` batches.
        criterion: ``criterion(outputs, labels)`` returning a scalar loss tensor.
        device: Device each batch is moved to; defaults to the device of the
            model's first parameter.

    Returns:
        ``(acc, avg_loss)`` with ``acc`` in percent.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            batch_size = labels.size(0)
            loss_sum += criterion(outputs, labels).item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size
    if total == 0:
        raise ValueError("evaluate_model_with_loss: dataloader yielded no samples")
    return 100 * correct / total, loss_sum / total

def save_checkpoint(model, optimizer, scheduler, epoch, best_acc, history, model_name, epochs_no_improve):
    """Persist the full training state of ``model_name`` after the 0-based ``epoch``.

    Writes ``{CHECKPOINT_DIR}/{model_name}_epoch{epoch+1}.pth`` containing the model,
    optimizer and scheduler state dicts plus the bookkeeping needed to resume
    (``epoch``, ``best_acc``, ``history``, ``epochs_no_improve``).

    The data is first written to ``<path>.tmp`` and then renamed into place, so an
    interruption mid-write (e.g. a Colab disconnect) leaves only the ``.tmp`` file,
    which the ``*_epoch*.pth`` glob used for resuming does not match.
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_acc': best_acc,
        'history': history,
        'epochs_no_improve': epochs_no_improve
    }
    path = f"{CHECKPOINT_DIR}/{model_name}_epoch{epoch+1}.pth"
    tmp_path = f"{path}.tmp"
    torch.save(checkpoint, tmp_path)
    os.replace(tmp_path, path)
    print(f"  Checkpoint saved: {path}")
    
def load_checkpoint(model, optimizer, scheduler, model_name):
    """Restore the newest checkpoint of ``model_name`` into ``model``/``optimizer``/``scheduler``.

    "Newest" means the highest epoch number parsed from ``{model_name}_epoch{N}.pth``.
    A plain lexicographic sort would rank ``epoch80`` above ``epoch100``. Tensors are
    first loaded onto the CPU, and ``load_state_dict`` copies them onto whatever device
    the model and optimizer already use. A checkpoint written on a GPU runtime
    therefore also loads on a CPU-only one.

    Returns:
        The loaded checkpoint dict, or ``None`` when no checkpoint exists.
    """
    prefix = f"{model_name}_epoch"

    def epoch_of(path):
        # "<prefix><N>.pth" -> N; anything unparsable sorts first.
        stem = os.path.basename(path)[len(prefix):-len(".pth")]
        return int(stem) if stem.isdigit() else -1

    checkpoints = glob.glob(f"{CHECKPOINT_DIR}/{prefix}*.pth")
    if not checkpoints:
        return None
    latest = max(checkpoints, key=epoch_of)
    print(f"  Loading checkpoint: {latest}")
    checkpoint = torch.load(latest, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    return checkpoint

def cleanup_old_checkpoints(model_name, keep=3):
    """Delete all but the ``keep`` highest-epoch checkpoints of ``model_name``.

    Checkpoints are ordered by the epoch number parsed from the filename. A plain
    lexicographic sort places ``epoch100`` before ``epoch40``, so it would delete the
    checkpoint just written once training passes epoch 99.
    """
    prefix = f"{model_name}_epoch"

    def epoch_of(path):
        # "<prefix><N>.pth" -> N; anything unparsable sorts first (deleted first).
        stem = os.path.basename(path)[len(prefix):-len(".pth")]
        return int(stem) if stem.isdigit() else -1

    checkpoints = sorted(glob.glob(f"{CHECKPOINT_DIR}/{prefix}*.pth"), key=epoch_of)
    excess = len(checkpoints) - keep
    for chk in checkpoints[:max(excess, 0)]:
        os.remove(chk)
        print(f"  Cleaned up: {os.path.basename(chk)}")

print("Helper functions loaded (DKD, CutMix, Mixup)")

In [None]:
# Cell 5: Optimized Training Loop
def _kd_cpu_state_dict(model):
    """Return a detached CPU copy of ``model.state_dict()``.

    Keeping the best-so-far weights on the CPU avoids holding a second full copy of
    the model in GPU memory.
    """
    return {k: v.detach().to('cpu', copy=True) for k, v in model.state_dict().items()}


def _kd_mix_batch(inputs, labels):
    """Augment one batch with CutMix or Mixup (50/50); return ``(inputs, labels_a, labels_b, lam)``."""
    if np.random.rand() > 0.5:
        # cutmix_data writes into its input, so give it a copy.
        return cutmix_data(inputs.clone(), labels, alpha=CUTMIX_ALPHA)
    return mixup_data(inputs, labels, alpha=MIXUP_ALPHA)


def _kd_batch_loss(student_outputs, teacher_outputs, labels_a, labels_b, lam,
                   dkd_alpha, dkd_beta, temp, ce_weight, dkd_scale):
    """Compute the lam-weighted mixed-label loss for one batch.

    Without teacher outputs this is plain cross-entropy. With them it is
    ``dkd_scale * DKD + ce_weight * CE``.
    """
    def mixed(loss_fn):
        # Mixup/CutMix target: interpolate the losses of the two label sets.
        return lam * loss_fn(labels_a) + (1 - lam) * loss_fn(labels_b)

    def ce(y):
        return F.cross_entropy(student_outputs, y)

    if teacher_outputs is None:
        return mixed(ce)
    loss = dkd_scale * mixed(
        lambda y: dkd_loss(student_outputs, teacher_outputs, y, dkd_alpha, dkd_beta, temp))
    if ce_weight > 0:
        loss = loss + ce_weight * mixed(ce)
    return loss


def train_model_optimized(model, dataloader, optimizer, scheduler, num_epochs, model_name,
                          teacher_model=None, dkd_alpha=1.0, dkd_beta=8.0, temp=4.0,
                          patience=30, grad_clip=1.0, ce_weight=1.0, dkd_warmup=20):
    """Train ``model`` with CutMix/Mixup, optionally distilling from ``teacher_model`` via DKD.

    Loss:
        * no teacher: mixed-label cross-entropy (used to fine-tune the teacher);
        * with teacher: ``ce_weight * CE + min((epoch + 1) / dkd_warmup, 1) * DKD``.
          This is the DKD paper's objective: ground-truth CE plus a DKD term warmed up
          linearly over ``dkd_warmup`` epochs. Pass ``ce_weight=0, dkd_warmup=0`` for
          the pure-DKD loss.

    Persistence / resumption:
        * The newest periodic checkpoint in ``CHECKPOINT_DIR`` is restored if present.
        * Each new best is written to ``MODEL_DIR/{model_name}_best.pth`` together with
          its accuracy, so a resumed run recovers the true best weights.
        * ``MODEL_DIR/{model_name}.pth`` is written only once training finishes. Code
          that treats that file's existence as "already trained" therefore does not
          skip an interrupted run.

    Args:
        model: Network to train (already on ``device``).
        dataloader: Training batches of ``(inputs, labels)``.
        optimizer, scheduler: Optimizer and per-epoch LR scheduler for ``model``.
        num_epochs: Total number of epochs (including any already completed).
        model_name: Prefix for checkpoint and weight files.
        teacher_model: Frozen teacher for DKD, or ``None`` for plain training.
        dkd_alpha, dkd_beta, temp: DKD hyper-parameters (TCKD weight, NCKD weight,
            temperature).
        patience: Epochs without validation-accuracy improvement before stopping.
        grad_clip: Max gradient norm; ``<= 0`` disables clipping.
        ce_weight: Weight of the ground-truth CE term when distilling.
        dkd_warmup: Epochs over which the DKD term ramps up linearly; ``<= 0`` disables.

    Returns:
        ``(model, history)``: ``model`` holds the best validation weights, and ``history``
        has per-epoch ``train_loss``, ``val_loss`` and ``val_accuracy`` lists.
    """
    best_path = f"{MODEL_DIR}/{model_name}_best.pth"
    final_path = f"{MODEL_DIR}/{model_name}.pth"

    # 1. Load Checkpoint
    checkpoint = load_checkpoint(model, optimizer, scheduler, model_name)
    if checkpoint is not None:
        start_epoch = checkpoint['epoch'] + 1
        best_acc = checkpoint['best_acc']
        history = checkpoint['history']
        epochs_no_improve = checkpoint['epochs_no_improve']
        # The checkpoint stores the *latest* weights, not the best ones. Prefer the
        # best-weights file when it records at least the checkpoint's best accuracy.
        best_model_wts = _kd_cpu_state_dict(model)
        if os.path.exists(best_path):
            saved_best = torch.load(best_path, map_location='cpu')
            if saved_best['best_acc'] >= best_acc:
                best_acc = saved_best['best_acc']
                best_model_wts = saved_best['model_state_dict']
        print(f"  Resuming from epoch {start_epoch}, Best Acc: {best_acc:.2f}%")
    else:
        start_epoch = 0
        best_acc = 0.0
        history = {'train_loss': [], 'val_loss': [], 'val_accuracy': []}
        epochs_no_improve = 0
        best_model_wts = _kd_cpu_state_dict(model)
        print("  Starting fresh training...")

    # 2. Setup
    use_amp = torch.cuda.is_available()
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
    val_criterion = nn.CrossEntropyLoss()

    if teacher_model is not None:
        teacher_model.eval()

    # 3. Training Loop
    for epoch in range(start_epoch, num_epochs):
        model.train()
        running_loss = 0.0
        dkd_scale = min((epoch + 1) / dkd_warmup, 1.0) if dkd_warmup > 0 else 1.0

        loop = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for inputs, labels in loop:
            inputs, labels = inputs.to(device), labels.to(device)
            inputs_aug, labels_a, labels_b, lam = _kd_mix_batch(inputs, labels)

            optimizer.zero_grad()

            with torch.amp.autocast('cuda', enabled=use_amp):
                teacher_outputs = None
                if teacher_model is not None:
                    with torch.no_grad():
                        teacher_outputs = teacher_model(inputs_aug)
                student_outputs = model(inputs_aug)
                loss = _kd_batch_loss(student_outputs, teacher_outputs, labels_a, labels_b, lam,
                                      dkd_alpha, dkd_beta, temp, ce_weight, dkd_scale)

            # Backward (gradients are unscaled before clipping so the norm is meaningful)
            scaler.scale(loss).backward()
            if grad_clip > 0:
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()

            loss_value = loss.item()
            running_loss += loss_value
            loop.set_postfix(loss=loss_value)

        # Step Scheduler (once per epoch)
        scheduler.step()

        # Validation
        train_loss = running_loss / len(dataloader)
        val_acc, val_loss = evaluate_model_with_loss(model, testloader, val_criterion)

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_accuracy'].append(val_acc)

        current_lr = scheduler.get_last_lr()[0]
        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | LR: {current_lr:.6f}")

        # Save Best (to the in-progress file, with its accuracy, so resume can recover it)
        if val_acc > best_acc:
            best_acc = val_acc
            best_model_wts = _kd_cpu_state_dict(model)
            torch.save({'model_state_dict': best_model_wts, 'best_acc': best_acc}, best_path)
            epochs_no_improve = 0
            print(f"  New best model saved! Accuracy: {best_acc:.2f}%")
        else:
            epochs_no_improve += 1

        # Checkpointing
        if (epoch + 1) % CHECKPOINT_FREQUENCY == 0:
            save_checkpoint(model, optimizer, scheduler, epoch, best_acc, history, model_name, epochs_no_improve)
            cleanup_old_checkpoints(model_name)

        # Early Stopping
        if epochs_no_improve >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs.")
            break

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"\nTraining complete. Best accuracy: {best_acc:.2f}%")
    model.load_state_dict(best_model_wts)
    # Written last: its existence marks the run as finished.
    torch.save(best_model_wts, final_path)
    return model, history

print("Training function loaded")

In [None]:
# Cell 6: Initialize Models
def _build_cifar_efficientnet(builder, weights):
    """Load an ImageNet-pretrained EfficientNetV2 variant, swap its head to NUM_CLASSES outputs, move it to ``device``."""
    net = builder(weights=weights)
    net.classifier[1] = nn.Linear(net.classifier[1].in_features, NUM_CLASSES)
    return net.to(device)

# NOTE(review): the data pipeline feeds 32x32 CIFAR images without resizing, while these
# ImageNet weights were presumably trained at much larger resolutions -- confirm intended.
print("Loading Teacher (EfficientNetV2-L)...")
teacher_model = _build_cifar_efficientnet(efficientnet_v2_l, EfficientNet_V2_L_Weights.IMAGENET1K_V1)

print("Loading Student (EfficientNetV2-S)...")
student_model = _build_cifar_efficientnet(efficientnet_v2_s, EfficientNet_V2_S_Weights.IMAGENET1K_V1)

print("Models loaded")

In [None]:
# Cell 7: Train/Load Teacher Model
# Loads the finished teacher weights if present; otherwise fine-tunes the teacher with
# plain (mixed-label) cross-entropy. Either way, its test accuracy is reported at the end.
print("\n" + "="*70)
print("TEACHER MODEL")
print("="*70)

teacher_path = f"{MODEL_DIR}/teacher_model.pth"
if os.path.exists(teacher_path):
    print(f"Found existing Teacher Model: {teacher_path}")
    # map_location lets weights saved on a GPU runtime load on the CPU fallback too.
    teacher_model.load_state_dict(torch.load(teacher_path, map_location=device))
else:
    print("Training Teacher Model (This may take a while)...")
    opt_t = optim.AdamW(teacher_model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    sch_t = optim.lr_scheduler.CosineAnnealingLR(opt_t, T_max=NUM_EPOCHS)
    # Pass PATIENCE explicitly (as the student cell does) so the config cell controls
    # early stopping for both models.
    teacher_model, teacher_history = train_model_optimized(
        teacher_model, trainloader, opt_t, sch_t, NUM_EPOCHS, "teacher_model",
        teacher_model=None, patience=PATIENCE,
    )

# Evaluate Teacher
teacher_model.eval()
teacher_accuracy, _ = evaluate_model_with_loss(teacher_model, testloader, nn.CrossEntropyLoss())
print(f"\nTeacher Accuracy: {teacher_accuracy:.2f}%")

In [None]:
# Cell 8: Train Distilled Student (SOTA - DKD + CutMix)
# Loads finished student weights if present; otherwise distils from the teacher.
print("\n" + "="*70)
print("DISTILLED STUDENT MODEL (DKD + CutMix/Mixup)")
print("="*70)

student_name = "distilled_student_dkd_sota"
student_path = f"{MODEL_DIR}/{student_name}.pth"

if os.path.exists(student_path):
    print(f"Found existing Distilled Model: {student_path}")
    # map_location lets weights saved on a GPU runtime load on the CPU fallback too.
    student_model.load_state_dict(torch.load(student_path, map_location=device))
    distilled_accuracy, _ = evaluate_model_with_loss(student_model, testloader, nn.CrossEntropyLoss())
    print(f"Distilled Student Accuracy: {distilled_accuracy:.2f}%")
else:
    print(f"\nStarting SOTA Distillation (DKD + CutMix)...")
    print(f"  DKD Alpha (TCKD): {DKD_ALPHA}")
    print(f"  DKD Beta (NCKD): {DKD_BETA}")
    print(f"  Temperature: {TEMPERATURE}")

    # Optimizer & Scheduler
    opt_s = optim.AdamW(student_model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    sch_s = optim.lr_scheduler.CosineAnnealingLR(opt_s, T_max=NUM_EPOCHS)

    # Run SOTA Training
    trained_student, distilled_history = train_model_optimized(
        model=student_model,
        dataloader=trainloader,
        optimizer=opt_s,
        scheduler=sch_s,
        num_epochs=NUM_EPOCHS,
        model_name=student_name,
        teacher_model=teacher_model,
        dkd_alpha=DKD_ALPHA,
        dkd_beta=DKD_BETA,
        temp=TEMPERATURE,
        patience=PATIENCE
    )

    distilled_accuracy, _ = evaluate_model_with_loss(trained_student, testloader, nn.CrossEntropyLoss())
    print(f"\nDistilled Student Final Accuracy: {distilled_accuracy:.2f}%")

In [None]:
# Cell 9: Results Summary
# Prints a Markdown-style table. Any accuracy that was never computed (e.g. a skipped
# or failed cell) is shown as "N/A", and the comparison line is omitted.
_summary_rule = "=" * 80
print("\n" + _summary_rule)
print("FINAL RESULTS SUMMARY")
print(_summary_rule)

print("\n| Model                    | Accuracy (%) |")
print("|--------------------------|--------------|")
for _row_label, _acc_name in (
    ("| Teacher (EfficientNet-L) | ", "teacher_accuracy"),
    ("| Distilled (DKD+CutMix)   | ", "distilled_accuracy"),
):
    if _acc_name in globals():
        print(f"{_row_label}{globals()[_acc_name]:12.2f} |")
    else:
        print(f"{_row_label}{'N/A':>12} |")

if 'distilled_accuracy' in globals() and 'teacher_accuracy' in globals():
    improvement = distilled_accuracy - teacher_accuracy
    print(f"\n{_summary_rule}")
    print(f"Student SURPASSED Teacher by: {improvement:+.2f}%" if improvement > 0
          else f"Gap from Teacher: {improvement:.2f}%")
    print(_summary_rule)

print(f"\nAll models saved to: {MODEL_DIR}")
print(f"Checkpoints saved to: {CHECKPOINT_DIR}")