# Knowledge Distillation Experiments - Unified Framework

## Master's Thesis: Robust Knowledge Distillation for Compact Vision Models

**Author:** Gheith Alrawahi  
**Institution:** Nankai University  
**Supervisor:** Prof. Jing Wang

---

### Experiments Overview:

| ID   | Name       | Method       | Key Feature                     |
| :--- | :--------- | :----------- | :------------------------------ |
| v1   | Baseline   | Standard KD  | Mixup + CutMix only             |
| v2   | Enhanced   | Standard KD  | + AutoAugment + Label Smoothing |
| v3   | DKD β=8.0  | Decoupled KD | Default DKD parameters          |
| v3.1 | DKD β=2.0  | Decoupled KD | Tuned beta parameter            |
| v4   | Saturation | Standard KD  | Strong teacher + Standard KD    |

---


## 1. Setup and Configuration


In [1]:
# Cell 1: Imports and Setup
import os
import sys
import time
import copy
from pathlib import Path
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm

# Local imports
from config import (
    RESULTS_DIR, MODELS_DIR, CHECKPOINTS_DIR,
    get_experiment, ALL_EXPERIMENTS,
    ExperimentConfig
)
from utils import (
    set_seed, mixup_data, cutmix_data,
    kd_loss_with_mixup, evaluate_model,
    save_checkpoint, load_checkpoint, cleanup_checkpoints,
    TrainingLogger  # Use unified logger
)
from data import get_dataloaders
from models import create_teacher_model, create_student_model, load_teacher_model

# Choose the compute device once and report what was found
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
print(f"Device: {device}")
if cuda_available:
    # Properties of GPU index 0; total_memory is reported in bytes
    gpu_properties = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {gpu_properties.total_memory / 1e9:.1f} GB")

Device: cuda
GPU: NVIDIA GeForce RTX 5070 Laptop GPU
Memory: 8.5 GB


In [2]:
# Cell 2: Select Experiment
# ============================================================
# EDIT THE VALUE BELOW TO RUN A DIFFERENT EXPERIMENT
# Valid choices: "v1", "v2", "v3", "v3.1", "v4"
# ============================================================

EXPERIMENT_NAME = "v1"  # <-- CHANGE THIS

# ============================================================

# Resolve the chosen name to its configuration object
config = get_experiment(EXPERIMENT_NAME)

# Collect the whole summary first, then emit it in a single print call
banner = "=" * 60
summary_lines = [
    banner,
    f"EXPERIMENT: {config.experiment_id}",
    f"Name: {config.experiment_name}",
    f"Description: {config.description}",
    banner,
]

# Distillation settings: DKD reports its own alpha/beta pair, every other
# method reports the single KD alpha.
distill_cfg = config.distillation
summary_lines.append(f"\nMethod: {distill_cfg.method}")
summary_lines.append(f"Temperature: {distill_cfg.temperature}")
if distill_cfg.method == "dkd":
    summary_lines.append(f"DKD Alpha: {distill_cfg.dkd_alpha}")
    summary_lines.append(f"DKD Beta: {distill_cfg.dkd_beta}")
else:
    summary_lines.append(f"KD Alpha: {distill_cfg.alpha}")

# Augmentation switches
aug_cfg = config.augmentation
summary_lines.extend([
    "\nAugmentation:",
    f"  AutoAugment: {aug_cfg.auto_augment}",
    f"  RandomErasing: {aug_cfg.random_erasing}",
    f"  Mixup: {aug_cfg.mixup}",
    f"  CutMix: {aug_cfg.cutmix}",
])

# Core training hyper-parameters
base_cfg = config.base
summary_lines.extend([
    "\nTraining:",
    f"  Epochs: {base_cfg.num_epochs}",
    f"  Batch Size: {base_cfg.batch_size}",
    f"  Learning Rate: {base_cfg.learning_rate}",
    f"  Early Stopping Patience: {base_cfg.patience}",
])

print("\n".join(summary_lines))

EXPERIMENT: v1_baseline
Name: Baseline Standard KD
Description: Standard KD with basic Mixup/CutMix augmentation

Method: standard_kd
Temperature: 4.0
KD Alpha: 0.7

Augmentation:
  AutoAugment: False
  RandomErasing: False
  Mixup: True
  CutMix: True

Training:
  Epochs: 200
  Batch Size: 32
  Learning Rate: 0.001
  Early Stopping Patience: 30


In [3]:
# Cell 3: Set Seed and Initialize Logger
# Apply the experiment's configured seed through the project helper before any
# model or data object is created.
set_seed(config.base.seed)

# Get results directory with timestamp (prevents overwriting)
# Set use_timestamp=False if you want to overwrite previous results
USE_TIMESTAMP = True  # <-- Set to False to overwrite previous runs

# NOTE(review): keep this call ahead of the TrainingLogger construction below --
# config._run_id is read immediately afterwards and appears to be populated
# here; confirm against config.py.
results_dir = config.get_results_dir(use_timestamp=USE_TIMESTAMP)
print(f"Results will be saved to: {results_dir}")

# Initialize Student Logger (same structure as Teacher)
# NOTE(review): _run_id is a private attribute of the config object; presumably
# it names the run's sub-directory under RESULTS_DIR -- a public accessor would
# be safer if one exists.
logger = TrainingLogger(config._run_id, RESULTS_DIR, model_type="student")

# Save configuration
config.save()

Random seed set to: 42
Results will be saved to: d:\Projects\KnowledgeDistillation\code_v3_224\results\v1_baseline_20251209_024347
Logger initialized: d:\Projects\KnowledgeDistillation\code_v3_224\results\v1_baseline_20251209_024347
Config saved: d:\Projects\KnowledgeDistillation\code_v3_224\results\v1_baseline_20251209_024347\config.json


## 2. Data Loading


In [4]:
# Cell 4: Load Data
# Re-import the data module so edits made to it since the kernel started are
# picked up, then rebind get_dataloaders to the reloaded definition.
import importlib
import data
data = importlib.reload(data)
from data import get_dataloaders

# Loader settings all come from the experiment configuration
base_cfg = config.base
loader_options = {
    "aug_config": config.augmentation,
    "batch_size": base_cfg.batch_size,
    "num_workers": base_cfg.num_workers,
    "pin_memory": base_cfg.pin_memory,
    "image_size": base_cfg.image_size,  # image_size for 224x224 upscaling
}
train_loader, test_loader = get_dataloaders(**loader_options)

Using pre-loaded 224x224 data from .pt files...
  Note: Using num_workers=0 (required for preloaded data on Windows)
Data loaded (preloaded):
  Training samples: 50000
  Test samples: 10000
  Batch size: 32
  Image size: 224x224
  Augmentation: AutoAugment=False, RandomErasing=False, Mixup=True, CutMix=True


## 3. Model Initialization


In [5]:
# Cell 5: Create/Load Teacher Model
section_rule = "=" * 60
print(f"\n{section_rule}")
print("TEACHER MODEL")
print(section_rule)

# Location where a previously trained teacher would have been saved
teacher_path = MODELS_DIR / "teacher_trained.pth"

if not teacher_path.exists():
    # Nothing on disk: build a fresh teacher and flag it for training.
    # teacher_accuracy stays at 0.0 until a training run replaces it.
    print("No trained teacher found. Will train from scratch.")
    teacher_model = create_teacher_model(device=device)
    teacher_accuracy = 0.0
    TRAIN_TEACHER = True
else:
    print(f"Loading existing teacher from: {teacher_path}")
    teacher_model = load_teacher_model(str(teacher_path), device=device)

    # Measure the loaded teacher on test_loader
    teacher_results = evaluate_model(teacher_model, test_loader, device)
    teacher_accuracy = teacher_results['accuracy']
    print(f"Teacher Accuracy: {teacher_accuracy:.2f}%")

    TRAIN_TEACHER = False


TEACHER MODEL
No trained teacher found. Will train from scratch.
Teacher Model: EfficientNetV2-L
  Parameters: 117,362,372 (117.36M)
  Size: 449.66 MB
  Pretrained: True


In [6]:
# Cell 6: Train Teacher (if needed)
# Runs only when Cell 5 set TRAIN_TEACHER. Each epoch performs: optional linear
# LR warmup, one pass over train_loader with Mixup/CutMix and mixed precision,
# evaluation on test_loader, best-model and periodic checkpointing, and
# patience-based early stopping. On completion the best weights are restored
# into teacher_model and written to teacher_path.
if TRAIN_TEACHER:
    print("\n" + "=" * 60)
    print("TRAINING TEACHER MODEL")
    print("=" * 60)

    # Teacher training configuration
    TEACHER_EPOCHS = config.base.num_epochs

    # Create checkpoint directory for teacher
    teacher_checkpoint_dir = CHECKPOINTS_DIR / "teacher"
    teacher_checkpoint_dir.mkdir(parents=True, exist_ok=True)
    print(f"Checkpoints: {teacher_checkpoint_dir}")

    # Initialize Teacher Logger (same structure as Student)
    teacher_logger = TrainingLogger("teacher", RESULTS_DIR, model_type="teacher")
    teacher_logger.start_training()

    # Optimizer and scheduler
    teacher_optimizer = optim.AdamW(
        teacher_model.parameters(),
        lr=config.base.learning_rate,
        weight_decay=config.base.weight_decay
    )
    # The cosine schedule covers only the post-warmup epochs; during warmup the
    # learning rate is written into the optimizer by hand (see the loop below).
    teacher_scheduler = optim.lr_scheduler.CosineAnnealingLR(
        teacher_optimizer,
        T_max=TEACHER_EPOCHS - config.base.warmup_epochs
    )

    # Training setup
    scaler = torch.amp.GradScaler('cuda')
    best_teacher_acc = 0.0
    best_teacher_weights = None
    best_teacher_epoch = 0  # 1-based epoch at which best_teacher_weights was captured
    epochs_no_improve = 0

    for epoch in range(TEACHER_EPOCHS):
        teacher_model.train()
        running_loss = 0.0
        valid_batches = 0  # batches that actually contributed to running_loss

        # Learning rate warmup: ramp linearly up to the configured rate
        if epoch < config.base.warmup_epochs:
            warmup_lr = config.base.learning_rate * (epoch + 1) / config.base.warmup_epochs
            for param_group in teacher_optimizer.param_groups:
                param_group['lr'] = warmup_lr

        loop = tqdm(train_loader, desc=f"Teacher Epoch {epoch+1}/{TEACHER_EPOCHS}")

        for inputs, labels in loop:
            inputs, labels = inputs.to(device), labels.to(device)

            # Apply Mixup or CutMix.
            # A coin flip picks CutMix when it is enabled; otherwise Mixup is
            # used if enabled. With only CutMix enabled, roughly half of the
            # batches therefore pass through unaugmented.
            if config.augmentation.mixup or config.augmentation.cutmix:
                if np.random.rand() > 0.5 and config.augmentation.cutmix:
                    # NOTE(review): the clone presumably protects the original
                    # batch from in-place patching inside cutmix_data -- confirm.
                    inputs, labels_a, labels_b, lam = cutmix_data(
                        inputs.clone(), labels, config.augmentation.cutmix_alpha, device
                    )
                elif config.augmentation.mixup:
                    inputs, labels_a, labels_b, lam = mixup_data(
                        inputs, labels, config.augmentation.mixup_alpha, device
                    )
                else:
                    labels_a, labels_b, lam = labels, labels, 1.0
            else:
                labels_a, labels_b, lam = labels, labels, 1.0

            teacher_optimizer.zero_grad()

            with torch.amp.autocast('cuda'):
                outputs = teacher_model(inputs)
                # Cross-entropy against both label sets, blended by lam
                loss = lam * nn.functional.cross_entropy(outputs, labels_a, label_smoothing=config.distillation.label_smoothing) + \
                       (1 - lam) * nn.functional.cross_entropy(outputs, labels_b, label_smoothing=config.distillation.label_smoothing)

            # Skip non-finite losses so a single bad batch cannot turn the
            # epoch's average loss into NaN/inf.
            if not torch.isfinite(loss):
                continue

            scaler.scale(loss).backward()
            # Unscale before clipping so the clip threshold applies to true gradient norms
            scaler.unscale_(teacher_optimizer)
            nn.utils.clip_grad_norm_(teacher_model.parameters(), config.base.grad_clip)
            scaler.step(teacher_optimizer)
            scaler.update()

            running_loss += loss.item()
            valid_batches += 1
            loop.set_postfix(loss=f"{loss.item():.4f}")

        # Step scheduler after warmup
        if epoch >= config.base.warmup_epochs:
            teacher_scheduler.step()

        # Validation
        # NOTE(review): model selection and early stopping are driven by
        # test_loader; confirm that this is intended for the reported numbers.
        train_loss = running_loss / max(valid_batches, 1)
        val_results = evaluate_model(teacher_model, test_loader, device)
        val_acc = val_results['accuracy']
        val_loss = val_results['loss']
        current_lr = teacher_optimizer.param_groups[0]['lr']

        # Log epoch (same as student); the logger decides whether this epoch is the best so far
        is_best = teacher_logger.log_epoch(epoch + 1, train_loss, val_loss, val_acc, current_lr)

        print(f"Epoch {epoch+1}/{TEACHER_EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | LR: {current_lr:.6f}")

        # Save best model
        if is_best:
            best_teacher_acc = val_acc
            # Deep copy: state_dict() otherwise hands back references to the live tensors
            best_teacher_weights = copy.deepcopy(teacher_model.state_dict())
            best_teacher_epoch = epoch + 1
            epochs_no_improve = 0
            print(f"  * New best teacher! Accuracy: {best_teacher_acc:.2f}%")

            # Save best teacher immediately
            torch.save({
                'model_state_dict': best_teacher_weights,
                'accuracy': best_teacher_acc,
                'epoch': best_teacher_epoch
            }, teacher_path)
        else:
            epochs_no_improve += 1

        # Save checkpoint and history every N epochs
        if (epoch + 1) % config.base.checkpoint_frequency == 0:
            # Save checkpoint
            checkpoint_path = teacher_checkpoint_dir / f"teacher_epoch_{epoch+1}.pth"
            torch.save({
                'model_state_dict': teacher_model.state_dict(),
                'optimizer_state_dict': teacher_optimizer.state_dict(),
                'scheduler_state_dict': teacher_scheduler.state_dict(),
                'epoch': epoch + 1,
                'best_accuracy': best_teacher_acc
            }, checkpoint_path)
            print(f"  Checkpoint saved: {checkpoint_path.name}")

            # Save history (same as student)
            teacher_logger.save_checkpoint_history()

            # Cleanup old checkpoints
            cleanup_checkpoints(teacher_checkpoint_dir, keep=config.base.keep_checkpoints)

        # Early stopping
        if epochs_no_improve >= config.base.patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    # Fail with a clear message instead of passing None to load_state_dict
    # (happens when no epoch ran or no epoch was ever flagged as best).
    if best_teacher_weights is None:
        raise RuntimeError(
            "Teacher training produced no best checkpoint "
            "(no epoch ran, or no epoch was flagged as best)."
        )

    # Training complete - save final results
    teacher_model.load_state_dict(best_teacher_weights)
    teacher_accuracy = best_teacher_acc

    # Save final model; keep the 'epoch' key so this file carries the same
    # fields as the best-model save made during training.
    torch.save({
        'model_state_dict': best_teacher_weights,
        'accuracy': best_teacher_acc,
        'epoch': best_teacher_epoch
    }, teacher_path)

    # Save final results (same structure as student)
    teacher_logger.save_final_results(
        model_name="EfficientNetV2-L",
        total_epochs=epoch + 1,
        early_stopped=(epochs_no_improve >= config.base.patience)
    )


TRAINING TEACHER MODEL
Checkpoints: d:\Projects\KnowledgeDistillation\code_v3_224\checkpoints\teacher
Logger initialized: d:\Projects\KnowledgeDistillation\code_v3_224\results\teacher


Teacher Epoch 1/200:  16%|█▌        | 246/1562 [11:31<1:01:38,  2.81s/it, loss=3.2354]


KeyboardInterrupt: 

In [None]:
# Cell 7: Create Student Model
header_rule = "=" * 60
print(f"\n{header_rule}")
print("STUDENT MODEL")
print(header_rule)

student_model = create_student_model(device=device)

# From here on the teacher only provides targets: switch it to eval mode and
# turn off gradient tracking for every one of its parameters.
teacher_model.eval()
teacher_model.requires_grad_(False)

## 4. Knowledge Distillation Training


In [None]:
# Cell 8: Setup Training
title_rule = "=" * 60
print(f"\n{title_rule}")
print(f"KNOWLEDGE DISTILLATION - {config.experiment_name}")
print(title_rule)

# Start the logger's training timer
logger.start_training()

# AdamW on the student's parameters; the cosine schedule spans only the
# epochs that remain once warmup is finished.
train_cfg = config.base
optimizer = optim.AdamW(
    student_model.parameters(),
    lr=train_cfg.learning_rate,
    weight_decay=train_cfg.weight_decay,
)
cosine_epochs = train_cfg.num_epochs - train_cfg.warmup_epochs
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cosine_epochs)

# Mixed-precision gradient scaler plus best-model / early-stopping bookkeeping
scaler = torch.amp.GradScaler('cuda')
best_acc, best_weights = 0.0, None
epochs_no_improve, early_stopped = 0, False

# Per-experiment checkpoint folder
exp_checkpoint_dir = CHECKPOINTS_DIR / config.experiment_id
exp_checkpoint_dir.mkdir(parents=True, exist_ok=True)

print(f"\nDistillation Method: {config.distillation.method}")
print(f"Teacher Accuracy: {teacher_accuracy:.2f}%")
print(f"Checkpoints: {exp_checkpoint_dir}")

In [None]:
# Cell 9: Training Loop
# Distills the teacher into the student. Each epoch performs: optional linear
# LR warmup, one pass over train_loader with Mixup/CutMix and mixed precision
# (teacher forward pass under no_grad), evaluation on test_loader, best-model
# and periodic checkpointing, and patience-based early stopping. Afterwards the
# best weights are restored into student_model.
for epoch in range(config.base.num_epochs):
    student_model.train()
    running_loss = 0.0
    valid_batches = 0  # batches that actually contributed to running_loss

    # Learning rate warmup: ramp linearly up to the configured rate
    if epoch < config.base.warmup_epochs:
        warmup_lr = config.base.learning_rate * (epoch + 1) / config.base.warmup_epochs
        for param_group in optimizer.param_groups:
            param_group['lr'] = warmup_lr

    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.base.num_epochs}")

    for inputs, labels in loop:
        inputs, labels = inputs.to(device), labels.to(device)

        # Apply Mixup or CutMix.
        # A coin flip picks CutMix when it is enabled; otherwise Mixup is used
        # if enabled. With only CutMix enabled, roughly half of the batches
        # therefore pass through unaugmented.
        if config.augmentation.mixup or config.augmentation.cutmix:
            if np.random.rand() > 0.5 and config.augmentation.cutmix:
                # NOTE(review): the clone presumably protects the original
                # batch from in-place patching inside cutmix_data -- confirm.
                inputs, labels_a, labels_b, lam = cutmix_data(
                    inputs.clone(), labels, config.augmentation.cutmix_alpha, device
                )
            elif config.augmentation.mixup:
                inputs, labels_a, labels_b, lam = mixup_data(
                    inputs, labels, config.augmentation.mixup_alpha, device
                )
            else:
                labels_a, labels_b, lam = labels, labels, 1.0
        else:
            labels_a, labels_b, lam = labels, labels, 1.0

        optimizer.zero_grad()

        with torch.amp.autocast('cuda'):
            # Forward pass
            student_outputs = student_model(inputs)

            # The teacher sees the same (possibly mixed) batch; no graph is built for it
            with torch.no_grad():
                teacher_outputs = teacher_model(inputs)

            # Calculate KD loss
            loss = kd_loss_with_mixup(
                student_outputs, teacher_outputs,
                labels_a, labels_b, lam,
                method=config.distillation.method,
                temperature=config.distillation.temperature,
                alpha=config.distillation.alpha,
                dkd_alpha=config.distillation.dkd_alpha,
                dkd_beta=config.distillation.dkd_beta,
                label_smoothing=config.distillation.label_smoothing
            )

        # Skip non-finite losses (NaN and +/-inf) so one bad batch cannot
        # corrupt the epoch's average loss.
        if not torch.isfinite(loss):
            continue

        scaler.scale(loss).backward()
        # Unscale before clipping so the clip threshold applies to true gradient norms
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(student_model.parameters(), config.base.grad_clip)
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
        valid_batches += 1
        loop.set_postfix(loss=f"{loss.item():.4f}")

    # Step scheduler after warmup
    if epoch >= config.base.warmup_epochs:
        scheduler.step()

    # Validation (average over the batches that were not skipped)
    train_loss = running_loss / max(valid_batches, 1)
    val_results = evaluate_model(student_model, test_loader, device)
    val_acc = val_results['accuracy']
    val_loss = val_results['loss']
    current_lr = optimizer.param_groups[0]['lr']

    # Log epoch (same structure as teacher); the logger decides whether this epoch is the best so far
    is_best = logger.log_epoch(epoch + 1, train_loss, val_loss, val_acc, current_lr)

    print(f"Epoch {epoch+1}/{config.base.num_epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | LR: {current_lr:.6f}")

    # Save best model
    if is_best:
        best_acc = val_acc
        # Deep copy: state_dict() otherwise hands back references to the live tensors
        best_weights = copy.deepcopy(student_model.state_dict())
        epochs_no_improve = 0
        print(f"  * New best model! Accuracy: {best_acc:.2f}%")

        # Save best model
        torch.save(
            {'model_state_dict': best_weights, 'accuracy': best_acc, 'epoch': epoch + 1},
            MODELS_DIR / f"{config.experiment_id}_best.pth"
        )
    else:
        epochs_no_improve += 1

    # Checkpointing and save history
    if (epoch + 1) % config.base.checkpoint_frequency == 0:
        checkpoint_path = exp_checkpoint_dir / f"checkpoint_epoch_{epoch+1}.pth"
        # Pass the logger's verdict rather than re-deriving it with a float
        # equality test, which would also flag a later epoch that merely ties
        # the best accuracy.
        save_checkpoint(
            student_model, optimizer, scheduler,
            epoch + 1, best_acc, logger.history,
            checkpoint_path, is_best=is_best
        )
        print(f"  Checkpoint saved: {checkpoint_path.name}")

        # Save history (same as teacher)
        logger.save_checkpoint_history()

        cleanup_checkpoints(exp_checkpoint_dir, keep=config.base.keep_checkpoints)

    # Early stopping
    if epochs_no_improve >= config.base.patience:
        print(f"\nEarly stopping triggered at epoch {epoch+1}")
        early_stopped = True
        break

    # Clear cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Fail with a clear message instead of passing None to load_state_dict
# (happens when no epoch ran or no epoch was ever flagged as best).
if best_weights is None:
    raise RuntimeError(
        "Student training produced no best checkpoint "
        "(no epoch ran, or no epoch was flagged as best)."
    )

# Training complete
total_epochs = epoch + 1

# Load best weights
student_model.load_state_dict(best_weights)

In [None]:
# Cell 10: Save Final Results
# Write out the per-epoch history held by the logger
logger.save_history()

# Record the run summary (same structure as teacher) and keep the returned
# results so they can be embedded in the final model file.
final_results = logger.save_final_results(
    model_name="EfficientNetV2-S",
    total_epochs=total_epochs,
    early_stopped=early_stopped,
    config=config.to_dict(),
    teacher_accuracy=teacher_accuracy,
)

# Bundle the best weights with the config, results and history, then save
final_model_path = MODELS_DIR / f"{config.experiment_id}_final.pth"
final_payload = {
    'model_state_dict': best_weights,
    'config': config.to_dict(),
    'results': final_results,
    'history': logger.history,
}
torch.save(final_payload, final_model_path)

print(f"\nFinal model saved: {final_model_path}")

In [None]:
# Cell 11: Summary
# Prints the final report for the run: accuracies, retention rate, training
# statistics and the locations of the saved artefacts.
print("\n" + "=" * 60)
print("EXPERIMENT COMPLETE")
print("=" * 60)
print(f"\nExperiment: {config.experiment_id}")
print(f"Method: {config.distillation.method}")
print(f"\nResults:")
print(f"  Teacher Accuracy: {teacher_accuracy:.2f}%")
print(f"  Student Accuracy: {logger.best_accuracy:.2f}%")
# teacher_accuracy is still the 0.0 placeholder when the teacher was created
# but its training never completed; guard the ratio so the summary does not
# abort with ZeroDivisionError.
if teacher_accuracy > 0:
    print(f"  Retention Rate: {(logger.best_accuracy/teacher_accuracy)*100:.2f}%")
else:
    print("  Retention Rate: N/A (teacher accuracy is 0)")
print(f"\nTraining:")
print(f"  Best Epoch: {logger.best_epoch}")
print(f"  Total Epochs: {total_epochs}")
print(f"  Early Stopped: {early_stopped}")
print(f"  Training Time: {logger.get_training_time():.1f} minutes")
print(f"\nSaved Files:")
print(f"  Results: {logger.results_dir}")
print(f"    - config.json")
print(f"    - training_history.csv")
print(f"    - training_history.json")
print(f"    - final_results.json")
print(f"  Model: {final_model_path}")