# Knowledge Distillation Training Notebook

This notebook demonstrates how to:
1. Load the teacher model (DINOv2)
2. Create the student model (ViT-S/16)
3. Load and visualize the training data
4. Train the student to mimic the teacher

**No labeled data is used** - the only training signal is the frozen teacher's embeddings, so this is label-free knowledge distillation!


## 1. Setup and Imports


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader
import yaml
import timm
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import os

from data_loader import PretrainDataset
from transforms import MultiCropTransform
from optimizer import build_optimizer, build_scheduler

# Choose the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# On GPU, also report the name and total memory of device index 0.
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")


## 2. Load Configuration Files


In [None]:
def load_config(config_path):
    """Read a YAML file and return the parsed document.

    Args:
        config_path: Path of the YAML file to open.

    Returns:
        The object produced by ``yaml.safe_load`` for the file's contents.
    """
    with open(config_path, 'r') as stream:
        parsed = yaml.safe_load(stream)
    return parsed

# Parse the data, training and model settings, in that order.
data_cfg, train_cfg, model_cfg = map(load_config, (
    'data_config_optimized.yaml',
    'train_config_kd.yaml',
    'model_config_kd.yaml',
))

# Echo a few key settings from each config, one section per file.
print(
    "Data Config:",
    f"  Dataset: {data_cfg['dataset_name']}",
    f"  Image size: {data_cfg['image_size']}",
    f"  Workers: {data_cfg['num_workers']}",
    sep="\n",
)

print(
    "\nTraining Config:",
    f"  Batch size: {train_cfg['batch_size']}",
    f"  Epochs: {train_cfg['num_epochs']}",
    f"  Learning rate: {train_cfg['learning_rate']}",
    sep="\n",
)

print(
    "\nModel Config:",
    f"  Teacher: {model_cfg['teacher_name']}",
    f"  Student: {model_cfg['student_name']}",
    f"  Student image size: {model_cfg['student_img_size']}",
    sep="\n",
)


## 3. Load Teacher Model (DINOv2)


In [None]:
print("Loading teacher model (DINOv2)...")
teacher_name = model_cfg['teacher_name']

try:
    # Fetch the configured variant through torch.hub, move it to the selected
    # device and switch it to eval mode.
    teacher = torch.hub.load("facebookresearch/dinov2", teacher_name).to(device)
    teacher.eval()

    # The teacher is a fixed target: disable gradients on every weight and
    # tally the parameter count in the same pass.
    num_params = 0
    for param in teacher.parameters():
        param.requires_grad = False
        num_params += param.numel()
except Exception as e:
    # Report the failure, then let the original exception propagate.
    print(f"✗ Failed to load teacher: {e}")
    raise
else:
    print(f"✓ Teacher loaded: {teacher_name}")
    print(f"  Parameters: {num_params / 1e6:.2f}M")
    print(f"  Frozen: True")


## 4. Create Student Model (ViT-S/16)


In [None]:
print("Creating student model...")
student_name = model_cfg['student_name']
student_img_size = model_cfg['student_img_size']

# Build the backbone with random initialization (pretrained=False) and no
# classification head (num_classes=0), then place it on the device in
# training mode.
student_kwargs = dict(
    pretrained=False,
    img_size=student_img_size,
    num_classes=0,
)
student = timm.create_model(student_name, **student_kwargs).to(device)
student.train()

# Count every weight and, separately, the ones that require gradients.
param_sizes = [(p.numel(), p.requires_grad) for p in student.parameters()]
num_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, needs_grad in param_sizes if needs_grad)

print(f"✓ Student created: {student_name}")
print(f"  Parameters: {num_params / 1e6:.2f}M")
print(f"  Trainable: {trainable_params / 1e6:.2f}M")
print(f"  Image size: {student_img_size}x{student_img_size}")


## 5. Load Training Data


In [None]:
from torchvision import transforms
from torchvision.transforms import InterpolationMode

# Pick the augmentation pipeline: multi-crop when the training config asks
# for it, otherwise a single randomly augmented view per image.
use_multi_crop = train_cfg.get('use_multi_crop', False)

if not use_multi_crop:
    # Simple augmentation for KD. The two stochastic stages are named up
    # front so the pipeline below reads as a flat list.
    random_color_jitter = transforms.RandomApply(
        [transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)],
        p=0.8,
    )
    random_blur = transforms.RandomApply(
        [transforms.GaussianBlur(kernel_size=9, sigma=(0.1, 2.0))],
        p=0.5,
    )
    transform = transforms.Compose([
        transforms.RandomResizedCrop(
            student_img_size,
            scale=(0.2, 1.0),
            interpolation=InterpolationMode.BICUBIC,
        ),
        transforms.RandomHorizontalFlip(),
        random_color_jitter,
        transforms.RandomGrayscale(p=0.2),
        random_blur,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
else:
    # Crop scales and the local-crop count come from the data config, with
    # defaults used when a key is absent.
    multi_crop_kwargs = {
        'global_crops_scale': tuple(data_cfg.get('global_crops_scale', [0.4, 1.0])),
        'local_crops_scale': tuple(data_cfg.get('local_crops_scale', [0.05, 0.4])),
        'local_crops_number': data_cfg.get('local_crops_number', 8),
        'image_size': student_img_size,
    }
    transform = MultiCropTransform(**multi_crop_kwargs)

# Wrap the pretraining images with the chosen transform.
dataset = PretrainDataset(transform=transform)
print(f"✓ Dataset loaded: {len(dataset)} images")

# Shuffled batches; the final incomplete batch of each epoch is dropped.
loader_kwargs = dict(
    batch_size=train_cfg['batch_size'],
    shuffle=True,
    num_workers=data_cfg.get('num_workers', 4),
    pin_memory=data_cfg.get('pin_memory', True),
    drop_last=True,
)
dataloader = DataLoader(dataset, **loader_kwargs)

print(f"✓ DataLoader created: {len(dataloader)} batches per epoch")


## 6. Visualize Sample Data


In [None]:
# Pull one batch to inspect what the loader yields.
sample_batch = next(iter(dataloader))

if isinstance(sample_batch, list):
    # Multi-crop: show first crop
    images = sample_batch[0]
    print(f"Multi-crop batch: {len(sample_batch)} crops")
else:
    images = sample_batch
    print(f"Single-crop batch")

print(f"Batch shape: {images.shape}")

# Constants for undoing the Normalize() step of the single-crop transform.
# Built once, outside the loop, since they are the same for every image.
# NOTE(review): assumes MultiCropTransform normalizes with the same
# statistics - confirm against its implementation.
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# Visualize first 4 images (fewer if the batch is smaller)
num_shown = min(4, images.shape[0])
fig, axes = plt.subplots(1, 4, figsize=(12, 3))
for i, ax in enumerate(axes):
    # Turn the axis off for every panel, so panels left without an image by
    # a small batch render blank instead of as empty frames.
    ax.axis('off')
    if i >= num_shown:
        continue

    # Denormalize and clamp into the [0, 1] range imshow expects for floats.
    img = torch.clamp(images[i].cpu() * std + mean, 0, 1)

    # Channels-first -> channels-last for matplotlib.
    ax.imshow(img.permute(1, 2, 0))
    ax.set_title(f'Image {i+1}')

plt.tight_layout()
plt.show()


## 7. Define Feature Extraction Functions


In [None]:
def extract_teacher_features(teacher, images, use_cls_token=True):
    """Return L2-normalized global and patch embeddings from the teacher.

    The forward pass runs under ``torch.no_grad()``, so no autograd graph is
    built for the teacher.

    ``teacher.forward_features(images)`` may return either:

    * a dict - the global embedding is read from ``'x_norm_clstoken'`` or
      ``'cls_token'`` and the patch embeddings from ``'x_norm_patchtokens'``
      or ``'patch_tokens'``. Whatever is missing is sliced from the token
      sequence stored under ``'x'`` or ``'tokens'`` (index 0 along dim 1 as
      the CLS token, the remainder as patches).
    * a tensor - index 0 along dim 1 is taken as the CLS token and the
      remainder as patches (assumed layout ``[B, N+1, D]``).

    Args:
        teacher: Model exposing ``forward_features``.
        images: Batch forwarded to ``teacher.forward_features`` unchanged.
        use_cls_token: When True the global embedding is the CLS token; when
            False it is the mean of the patch embeddings. This is honored
            for dict outputs as well as tensor outputs, so it matches what
            ``extract_student_features`` does with the same flag.

    Returns:
        Tuple ``(cls_embedding, patch_embeddings)``, each L2-normalized along
        the last dimension.

    Raises:
        KeyError: If a dict output contains none of the recognized keys
            needed to recover the embeddings.
    """
    with torch.no_grad():
        features = teacher.forward_features(images)

        if isinstance(features, dict):
            # Preferred, explicitly named entries (first matching key wins).
            cls_embedding = next(
                (features[key] for key in ('x_norm_clstoken', 'cls_token')
                 if key in features),
                None,
            )
            patch_embeddings = next(
                (features[key] for key in ('x_norm_patchtokens', 'patch_tokens')
                 if key in features),
                None,
            )

            if cls_embedding is None or patch_embeddings is None:
                # Fall back to slicing a full token sequence.
                # NOTE(review): if the sequence also carries register tokens
                # the [:, 1:] slice would include them - confirm for the
                # teacher in use.
                tokens = next(
                    (features[key] for key in ('x', 'tokens') if key in features),
                    None,
                )
                if tokens is None:
                    raise KeyError(
                        "Teacher feature dict has no usable embedding keys; "
                        f"got keys: {sorted(features.keys())}"
                    )
                if cls_embedding is None:
                    cls_embedding = tokens[:, 0]
                if patch_embeddings is None:
                    patch_embeddings = tokens[:, 1:]
        else:
            # Tensor format [B, N+1, D]
            cls_embedding = features[:, 0]
            patch_embeddings = features[:, 1:]

        if not use_cls_token:
            # Global embedding = average of the (not yet normalized) patches.
            cls_embedding = patch_embeddings.mean(dim=1)

        # Normalize
        cls_embedding = F.normalize(cls_embedding, dim=-1, p=2)
        patch_embeddings = F.normalize(patch_embeddings, dim=-1, p=2)

    return cls_embedding, patch_embeddings


def extract_student_features(student, images, use_cls_token=True):
    """Return L2-normalized global and patch embeddings from the student.

    The output of ``student.forward_features(images)`` is indexed along
    dim 1: position 0 is treated as the CLS token and positions 1 onward as
    patch tokens. Gradients are not disabled here, so the result can be
    backpropagated through.

    Args:
        student: Model exposing ``forward_features``.
        images: Batch forwarded to ``student.forward_features`` unchanged.
        use_cls_token: When True the global embedding is the CLS token;
            otherwise it is the mean of the patch tokens.

    Returns:
        Tuple ``(cls_embedding, patch_embeddings)``, each L2-normalized along
        the last dimension.
    """
    tokens = student.forward_features(images)
    patch_tokens = tokens[:, 1:]

    # Global descriptor: CLS token, or the average over patch tokens.
    global_token = tokens[:, 0] if use_cls_token else patch_tokens.mean(dim=1)

    return (
        F.normalize(global_token, dim=-1, p=2),
        F.normalize(patch_tokens, dim=-1, p=2),
    )


## 8. Test Feature Extraction


In [None]:
print("Testing feature extraction...")

# Run both networks on the same two images from the sample batch.
test_images = images[:2].to(device)

# Teacher side
teacher_cls, teacher_patches = extract_teacher_features(teacher, test_images)
print(
    f"Teacher CLS shape: {teacher_cls.shape}",
    f"Teacher patches shape: {teacher_patches.shape}",
    sep="\n",
)

# Student side
student_cls, student_patches = extract_student_features(student, test_images)
print(
    f"Student CLS shape: {student_cls.shape}",
    f"Student patches shape: {student_patches.shape}",
    sep="\n",
)

# Baseline agreement between the two global embeddings, averaged over images.
per_image_similarity = F.cosine_similarity(teacher_cls, student_cls, dim=-1)
cosine_sim_cls = per_image_similarity.mean().item()
print(f"\nInitial CLS cosine similarity: {cosine_sim_cls:.4f}")
print("(This will increase during training)")


## 9. Define Distillation Loss


In [None]:
def compute_distillation_loss(student_cls, student_patches,
                             teacher_cls, teacher_patches,
                             loss_weights=None):
    """Compute the weighted distillation loss between student and teacher.

    Both terms are mean-squared errors: one between the global (CLS)
    embeddings and one between the patch embeddings.

    Student and teacher must share the same embedding width. A cosine
    similarity between vectors of different widths is not defined
    (``F.cosine_similarity`` requires broadcastable inputs), so a mismatch is
    reported up front with an actionable message instead of surfacing as an
    opaque shape error from inside PyTorch.

    Args:
        student_cls: Student global embeddings; last dim is the embedding.
        student_patches: Student patch embeddings, unpacked as ``(B, N, D)``.
        teacher_cls: Teacher global embeddings; last dim is the embedding.
        teacher_patches: Teacher patch embeddings, unpacked as ``(B, N, D)``.
        loss_weights: Mapping with ``'cls'`` and ``'patch'`` weights.
            Defaults to ``{'cls': 1.0, 'patch': 0.5}``.

    Returns:
        Tuple ``(total_loss, metrics)`` where ``total_loss`` is a tensor and
        ``metrics`` holds the Python-float values under ``'total'``,
        ``'cls'`` and ``'patch'``.

    Raises:
        ValueError: If the student and teacher embedding widths differ.
    """
    if loss_weights is None:
        loss_weights = {'cls': 1.0, 'patch': 0.5}

    _, N_s, D_s = student_patches.shape
    _, N_t, D_t = teacher_patches.shape

    cls_dim_s = student_cls.shape[-1]
    cls_dim_t = teacher_cls.shape[-1]
    if cls_dim_s != cls_dim_t or D_s != D_t:
        raise ValueError(
            "Student and teacher embedding dimensions differ "
            f"(CLS: {cls_dim_s} vs {cls_dim_t}, patch: {D_s} vs {D_t}). "
            "Pick a teacher/student pair with the same width or project the "
            "student embeddings to the teacher width before computing the loss."
        )

    # CLS token loss
    loss_cls = F.mse_loss(student_cls, teacher_cls)

    # Patch embeddings loss. With unequal patch counts only the leading
    # min(N_s, N_t) tokens of each side are compared.
    # NOTE(review): truncation pairs tokens by index, which presumably does
    # not correspond to the same image regions when the two patch grids
    # differ - confirm this is intended or resample one grid onto the other.
    if N_s != N_t:
        num_shared = min(N_s, N_t)
        student_patches = student_patches[:, :num_shared, :]
        teacher_patches = teacher_patches[:, :num_shared, :]
    loss_patch = F.mse_loss(student_patches, teacher_patches)

    # Weighted combination
    total_loss = loss_weights['cls'] * loss_cls + loss_weights['patch'] * loss_patch

    return total_loss, {
        'total': total_loss.item(),
        'cls': loss_cls.item(),
        'patch': loss_patch.item()
    }

# Sanity-check the loss on the embeddings produced by the feature-extraction
# test, using the configured weights (or the defaults when none are set).
test_loss_weights = train_cfg.get('distill_loss_weights', {'cls': 1.0, 'patch': 0.5})
test_loss, test_metrics = compute_distillation_loss(
    student_cls,
    student_patches,
    teacher_cls,
    teacher_patches,
    loss_weights=test_loss_weights,
)
print(
    f"Test loss: {test_loss.item():.4f}",
    f"  CLS loss: {test_metrics['cls']:.4f}",
    f"  Patch loss: {test_metrics['patch']:.4f}",
    sep="\n",
)


## 10. Setup Optimizer and Scheduler


In [None]:
# GPU-only features (fused optimizer kernels, loss scaling) are switched on
# from a single flag so the notebook also runs when it falls back to the CPU.
use_cuda = device.type == 'cuda'

# Build optimizer
# Fused kernels are requested only on CUDA: depending on the PyTorch version,
# asking for a fused optimizer over CPU parameters raises at construction.
# NOTE(review): presumably build_optimizer forwards `fused` to
# torch.optim.AdamW - confirm against optimizer.py.
optimizer = build_optimizer(
    student,
    lr=train_cfg['learning_rate'],
    weight_decay=train_cfg['weight_decay'],
    fused=use_cuda
)

# Build scheduler
scheduler = build_scheduler(
    optimizer,
    num_epochs=train_cfg['num_epochs'],
    warmup_epochs=train_cfg['warmup_epochs']
)

# GradScaler for mixed precision (a pass-through when disabled)
scaler = GradScaler(enabled=use_cuda)

print(f"✓ Optimizer: AdamW (lr={train_cfg['learning_rate']})")
print(f"✓ Scheduler: Cosine with {train_cfg['warmup_epochs']} warmup epochs")
print(f"✓ Mixed precision: {'Enabled' if device.type == 'cuda' else 'Disabled'}")


## 11. Training Loop


In [None]:
# Settings read once before the loop starts.
num_epochs = train_cfg['num_epochs']
loss_weights = train_cfg.get('distill_loss_weights', {'cls': 1.0, 'patch': 0.5})
use_cls_token = model_cfg.get('use_cls_token', True)

# Per-epoch averages, one entry appended per completed epoch.
train_losses, cls_losses, patch_losses = [], [], []

# Enable TF32 for faster training (matmul and cuDNN), GPU only.
if device.type == 'cuda':
    for backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
        backend.allow_tf32 = True
    print("✓ TF32 enabled")

print(f"\nStarting training for {num_epochs} epochs...")
print(f"Loss weights: CLS={loss_weights['cls']}, Patch={loss_weights['patch']}")
print("-" * 60)


In [None]:
# Training loop
#
# Each step runs the frozen teacher and the student on the same images,
# computes the distillation loss under autocast, and updates the student.

# Autocast is only useful on GPU here. On CPU it is turned off through
# `enabled`, which leaves the computation in full precision.
use_amp = device.type == 'cuda'

# Checkpoint settings do not change between epochs, so resolve them once.
checkpoint_dir = train_cfg.get('checkpoint_dir', './checkpoints')
checkpoint_dir = os.path.expandvars(checkpoint_dir)
save_freq = train_cfg.get('save_freq', 10)

for epoch in range(num_epochs):
    student.train()
    epoch_losses = {'total': [], 'cls': [], 'patch': []}

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for batch in progress_bar:
        # Handle multi-crop or single image
        if isinstance(batch, list):
            images = batch[0].to(device)  # Use first crop
        else:
            images = batch.to(device)

        optimizer.zero_grad()

        # Mixed precision training.
        # torch.autocast is the device-agnostic context manager; the
        # torch.cuda.amp.autocast class imported at the top of the notebook
        # does not accept a `device_type` argument and would raise TypeError.
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16,
                            enabled=use_amp):
            # Teacher forward (frozen)
            teacher_cls, teacher_patches = extract_teacher_features(
                teacher, images, use_cls_token=use_cls_token
            )

            # Student forward
            student_cls, student_patches = extract_student_features(
                student, images, use_cls_token=use_cls_token
            )

            # Compute distillation loss
            loss, metrics = compute_distillation_loss(
                student_cls, student_patches,
                teacher_cls, teacher_patches,
                loss_weights=loss_weights
            )

        # Backward and optimizer step go through the scaler.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Track losses
        for name, history in epoch_losses.items():
            history.append(metrics[name])

        # Update progress bar. metrics['total'] already holds loss.item(),
        # so reuse it rather than reading the value from the device again.
        current_lr = optimizer.param_groups[0]['lr']
        progress_bar.set_postfix({
            'loss': f'{metrics["total"]:.4f}',
            'cls': f'{metrics["cls"]:.4f}',
            'patch': f'{metrics["patch"]:.4f}',
            'lr': f'{current_lr:.6f}'
        })

    # Step scheduler at end of epoch
    scheduler.step()

    # Compute epoch averages
    avg_loss = np.mean(epoch_losses['total'])
    avg_cls = np.mean(epoch_losses['cls'])
    avg_patch = np.mean(epoch_losses['patch'])

    train_losses.append(avg_loss)
    cls_losses.append(avg_cls)
    patch_losses.append(avg_patch)

    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f} "
          f"(CLS: {avg_cls:.4f}, Patch: {avg_patch:.4f})")

    # Save checkpoint every `save_freq` epochs and always after the last one
    is_last_epoch = (epoch + 1) == num_epochs
    if (epoch + 1) % save_freq == 0 or is_last_epoch:
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint = {
            'student': student.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'scaler': scaler.state_dict(),
            'epoch': epoch,
        }
        torch.save(checkpoint, f"{checkpoint_dir}/checkpoint_epoch_{epoch+1}.pth")
        print(f"  ✓ Saved checkpoint: checkpoint_epoch_{epoch+1}.pth")

print("\n✓ Training completed!")


## 12. Visualize Training Progress


In [None]:
# Two side-by-side panels: every loss term on a linear axis, then the total
# loss alone on a logarithmic axis.
plt.figure(figsize=(12, 4))

# Left panel: total, CLS and patch losses per epoch.
plt.subplot(1, 2, 1)
loss_curves = (
    (train_losses, {'label': 'Total Loss'}),
    (cls_losses, {'label': 'CLS Loss', 'alpha': 0.7}),
    (patch_losses, {'label': 'Patch Loss', 'alpha': 0.7}),
)
for curve_values, line_kwargs in loss_curves:
    plt.plot(curve_values, **line_kwargs)
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True, alpha=0.3)

# Right panel: total loss only, log-scaled y axis.
plt.subplot(1, 2, 2)
plt.plot(train_losses, label='Total Loss')
plt.title('Total Loss (Log Scale)')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.yscale('log')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Summary: last epoch's loss and the lowest loss with its 1-based epoch.
print(f"Final loss: {train_losses[-1]:.4f}")
print(f"Best loss: {min(train_losses):.4f} (epoch {np.argmin(train_losses)+1})")


## 13. Test Final Embedding Similarity


In [None]:
# Compare teacher and student embeddings on up to 8 images from one batch,
# with the student in eval mode.
student.eval()
eval_batch = next(iter(dataloader))
first_view = eval_batch[0] if isinstance(eval_batch, list) else eval_batch
test_images = first_view[:8].to(device)

with torch.no_grad():
    teacher_cls, teacher_patches = extract_teacher_features(teacher, test_images)
    student_cls, student_patches = extract_student_features(student, test_images)

    # Global embeddings: mean cosine similarity over the images.
    cosine_sim_cls = F.cosine_similarity(teacher_cls, student_cls, dim=-1).mean().item()

    if student_patches.shape != teacher_patches.shape:
        # Shapes differ: compare the per-image mean patch embedding instead.
        student_pooled = student_patches.mean(dim=1)
        teacher_pooled = teacher_patches.mean(dim=1)
        cosine_sim_patches = F.cosine_similarity(student_pooled, teacher_pooled, dim=-1).mean().item()
    else:
        # Same shape: compare patch by patch after flattening batch and
        # token dimensions into one.
        embed_dim = student_patches.shape[-1]
        flat_student = student_patches.view(-1, embed_dim)
        flat_teacher = teacher_patches.view(-1, embed_dim)
        cosine_sim_patches = F.cosine_similarity(flat_student, flat_teacher, dim=-1).mean().item()

print(f"Final CLS cosine similarity: {cosine_sim_cls:.4f}")
print(f"Final patch cosine similarity: {cosine_sim_patches:.4f}")
print(f"\nHigher similarity = better distillation!")


## 14. Save Final Model


In [None]:
# Resolve the output directory (environment variables expanded) and make
# sure it exists.
checkpoint_dir = os.path.expandvars(train_cfg.get('checkpoint_dir', './checkpoints'))
os.makedirs(checkpoint_dir, exist_ok=True)

# Bundle the student weights, the optimizer/scheduler/scaler state, the loss
# history and the three configs used for this run.
run_config = {
    'model': model_cfg,
    'train': train_cfg,
    'data': data_cfg,
}
final_checkpoint = dict(
    student=student.state_dict(),
    optimizer=optimizer.state_dict(),
    scheduler=scheduler.state_dict(),
    scaler=scaler.state_dict(),
    epoch=num_epochs - 1,
    train_losses=train_losses,
    config=run_config,
)

torch.save(final_checkpoint, f"{checkpoint_dir}/checkpoint_final.pth")
print(f"✓ Final model saved to: {checkpoint_dir}/checkpoint_final.pth")
print(f"\nYou can now use this checkpoint for feature extraction and evaluation!")
