# Knowledge Distillation Training with Cached Tensors (2-Stage Pipeline)

This notebook demonstrates the **2-stage cached training pipeline**:
1. **Stage 1**: Precompute and cache preprocessed image tensors
2. **Stage 2**: Train using cached tensors for fast data loading

## Key Features

- ✅ **2-Stage Pipeline**: Precompute once, train many times
- ✅ **Cached Tensors**: 10-20× faster data loading (no HuggingFace API calls)
- ✅ **Step-capped epochs**: Only `max_steps_per_epoch` batches per epoch
- ✅ **Auto-resume**: Automatically detects and resumes from latest checkpoint
- ✅ **Optimized DataLoader**: Fast tensor loading from shard files

## Workflow

1. **Precompute cache** (one-time, ~2-4 hours): Process all 500K images and save as tensor shards
2. **Train with cache** (fast): Load from cached tensors instead of HuggingFace

**No labeled data is used**: the student learns only by matching the embeddings of a frozen DINOv2 teacher (label-free knowledge distillation).


## 1. Setup and Imports


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler
# Try to import new autocast API (PyTorch 2.0+), fall back to old API
try:
    from torch.amp import autocast  # PyTorch 2.0+
    AUTOCAST_NEW_API = True
except ImportError:
    from torch.cuda.amp import autocast  # PyTorch < 2.0
    AUTOCAST_NEW_API = False
from torch.utils.data import DataLoader
import yaml
import timm
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import os
import time
import itertools  # For limiting DataLoader iterations
import hashlib  # For teacher feature caching
from pathlib import Path  # For teacher feature caching

from data_loader import build_pretraining_dataloader, build_precompute_dataset, CachedTensorDataset
from transforms import SimpleTransform, FastMultiCropTransform
from optimizer import build_optimizer, build_scheduler

# Pick the compute device once: CUDA when a GPU is visible, otherwise the CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device('cuda') if _cuda_ok else torch.device('cpu')
print(f"Using device: {device}")

if not _cuda_ok:
    print("⚠️  WARNING: Running on CPU - training will be very slow!")
else:
    # Report the hardware the run will use (device index 0).
    _gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {_gpu_props.total_memory / 1e9:.2f} GB")
    print(f"CUDA Version: {torch.version.cuda}")
    # Allow TF32 in both matmul and cuDNN kernels.
    for _backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
        _backend.allow_tf32 = True
    print("✓ TF32 enabled for faster training")


## 2. Load Configuration Files


In [None]:
def load_config(config_path):
    """Parse a YAML configuration file.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        The object produced by ``yaml.safe_load`` for the file contents
        (the configs in this notebook are indexed like dicts).
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # default locale encoding (e.g. cp1252 on Windows).
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)

# Parse the three YAML files that drive the run (data pipeline, training, models).
data_cfg, train_cfg, model_cfg = (
    load_config(_cfg_path)
    for _cfg_path in ('data_config.yaml', 'train_config_kd.yaml', 'model_config_kd.yaml')
)

# Settings reused by later cells.
use_cached = data_cfg.get('use_cached', False)
max_steps_per_epoch = train_cfg.get('max_steps_per_epoch', None)

# --- Summary of the loaded configuration ---
print("Data Config:")
print(f"  Dataset: {data_cfg['dataset_name']}")
print(f"  Image size: {data_cfg['image_size']}")
print(f"  Workers: {data_cfg['num_workers']}")
print(f"  Use cached: {use_cached}")
if use_cached:
    print(f"  Cache dir: {data_cfg.get('cache_dir', 'N/A')}")
    print(f"  Shard size: {data_cfg.get('cache_shard_size', 'N/A')}")

print("\nTraining Config:")
print(f"  Batch size: {train_cfg['batch_size']}")
print(f"  Epochs: {train_cfg['num_epochs']}")
print(f"  Learning rate: {train_cfg['learning_rate']}")
print(f"  Max steps per epoch: {train_cfg.get('max_steps_per_epoch', 'None (full epoch)')}")

print("\nModel Config:")
print(f"  Teacher: {model_cfg['teacher_name']}")
print(f"  Student: {model_cfg['student_name']}")
print(f"  Student image size: {model_cfg['student_img_size']}")

# --- Data mode: cached tensor shards vs. the original dataset ---
if not use_cached:
    print("\n✓ Original mode (HuggingFace/raw images)")
    print("  To use cached mode, set use_cached: true in data_config.yaml")
else:
    # Environment variables in the configured path are expanded.
    cache_dir = os.path.expandvars(data_cfg.get('cache_dir', './cache_images'))
    print("\n✓ Cached mode enabled")
    print(f"  Cache directory: {cache_dir}")
    # The presence of index.json is used as the "cache exists" signal.
    index_path = os.path.join(cache_dir, 'index.json')
    if not os.path.exists(index_path):
        print("  ⚠️  Cache not found! Run Stage 1 (precompute) first.")
    else:
        print(f"  ✓ Cache found: {index_path}")


## Stage 1: Precompute Cache (One-Time Setup)

**Run this section once** to preprocess and cache all images. This takes ~2-4 hours but only needs to be done once.


In [None]:
# Stage 1: Precompute cache
# Uncomment and run this cell to create the cache (one-time, ~2-4 hours)

# from precompute_cache import precompute_cache
# 
# print("Starting cache precomputation...")
# print("This will process all 500K images and save as tensor shards.")
# print("Expected time: ~2-4 hours")
# print("-" * 60)
# 
# precompute_cache(data_cfg, train_cfg, batch_size=256)
# 
# print("\n✓ Cache precomputation complete!")
# print("  You can now set use_cached: true in data_config.yaml and proceed to training.")


## Stage 2: Training with Cached Data

This section loads the models and trains on cached tensors (if `use_cached: true`) or on the original dataset.


## 3. Load Teacher Model (DINOv2) - NOT Compiled


In [None]:
import sys
import warnings

def _load_teacher_from_hub(name):
    """Load DINOv2 model `name` through torch.hub with xFormers UserWarnings silenced."""
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning, message=".*xFormers.*")
        return torch.hub.load("facebookresearch/dinov2", name, verbose=False)


print("Loading teacher model (DINOv2)...")
print("⚠️  Teacher will NOT be compiled (frozen, no benefit)")
teacher_name = model_cfg['teacher_name']

if sys.version_info >= (3, 10):
    # Python 3.10+: direct load
    teacher = _load_teacher_from_hub(teacher_name)
else:
    # Older interpreters: prefer the compatibility patcher, fall back to a direct load.
    try:
        from dinov2_patcher import load_dinov2_with_patch
        print("Using Python 3.9 compatibility patcher...")
        teacher = load_dinov2_with_patch(teacher_name, verbose=False)
    except ImportError:
        print("⚠️  Warning: dinov2_patcher not available. Trying direct load...")
        teacher = _load_teacher_from_hub(teacher_name)

# The teacher only provides targets: inference mode, no gradients on any parameter.
teacher = teacher.to(device).eval()
teacher.requires_grad_(False)

num_params = sum(weight.numel() for weight in teacher.parameters())
print(f"✓ Teacher loaded: {teacher_name}")
print(f"  Parameters: {num_params / 1e6:.2f}M")
print("  Frozen: True")
print("  NOT compiled (frozen model, no benefit)")


## 4. Create Student Model (ViT-S/16) - Will Be Compiled


In [None]:
print("Creating student model...")
student_name = model_cfg['student_name']
student_img_size = model_cfg['student_img_size']

# Randomly initialised backbone with the classification head removed.
student = timm.create_model(
    student_name,
    pretrained=False,
    img_size=student_img_size,
    num_classes=0,
).to(device)
student.train()

# Count all parameters and the subset that will receive gradients.
_param_info = [(weight.numel(), weight.requires_grad) for weight in student.parameters()]
num_params = sum(count for count, _unused in _param_info)
trainable_params = sum(count for count, needs_grad in _param_info if needs_grad)

print(f"✓ Student created: {student_name}")
print(f"  Parameters: {num_params / 1e6:.2f}M")
print(f"  Trainable: {trainable_params / 1e6:.2f}M")
print(f"  Image size: {student_img_size}x{student_img_size}")

# Optional torch.compile step: needs the config flag and a PyTorch build that has it.
compile_student = train_cfg.get('compile_student', True)
if not (compile_student and hasattr(torch, 'compile')):
    print("✓ Student model NOT compiled (compile_student=False or PyTorch < 2.0)")
else:
    print("\nCompiling student model with torch.compile...")
    print("⚠️  First compilation may take 5-10 minutes - this is normal!")
    student = torch.compile(student, mode='reduce-overhead')
    print("✓ Student model compiled successfully")


## 5. Load Training Data (Cached or Original Mode)

The `build_pretraining_dataloader` function automatically handles:
- **Cached mode** (`use_cached: true`): Loads from tensor shards (fast)
- **Original mode** (`use_cached: false`): Loads from HuggingFace (slower)


In [None]:
# The factory decides between cached shards and the original dataset from the configs.
print("Building DataLoader...")
dataloader = build_pretraining_dataloader(data_cfg, train_cfg)
total_batches = len(dataloader)
print(f"\n✓ DataLoader created: {total_batches} batches per epoch")

if not max_steps_per_epoch:
    print(f"  Full epoch: All {total_batches} batches will be processed")
else:
    _capped_steps = min(max_steps_per_epoch, total_batches)
    print(f"  With step cap: Each epoch will process {_capped_steps} batches")
    # Rough estimate that assumes 0.5 s per batch.
    print(f"  Estimated time per epoch: ~{_capped_steps * 0.5 / 60:.1f} minutes")

# Smoke test: pull one batch and report its layout; failures are reported, not raised.
print("\nTesting data loading...")
try:
    test_batch = next(iter(dataloader))
    if not isinstance(test_batch, list):
        print(f"✓ Data loading works! Batch shape: {test_batch.shape}")
    else:
        print(f"✓ Data loading works! Batch has {len(test_batch)} crops")
        print(f"  First crop shape: {test_batch[0].shape}")
except Exception as err:
    print(f"⚠️  Data loading test failed: {err}")
    if use_cached:
        print("   Make sure you've run Stage 1 (precompute cache) first!")


## 6. Define Feature Extraction Functions (with Optional Caching)


In [None]:
def extract_teacher_features(teacher, images, use_cls_token=True, 
                            cache_dir=None, cache_key=None):
    """Extract L2-normalised global and patch embeddings from the frozen teacher.

    Args:
        teacher: Model exposing ``forward_features(images)``. The result may be
            a dict of named token tensors or a single token tensor whose
            index 0 along dim 1 is the CLS token.
        images: Input batch passed straight to the teacher.
        use_cls_token: If True the global embedding is the CLS token; otherwise
            it is the mean over patch tokens. This now applies to dict outputs
            as well, so the teacher target matches what
            ``extract_student_features`` produces for the same flag.
        cache_dir: Directory for cached features. Caching is active only when
            both ``cache_dir`` and ``cache_key`` are given.
        cache_key: File stem identifying this batch inside ``cache_dir``.

    Returns:
        ``(cls_embedding, patch_embeddings)``, both normalised along the last dim.

    Raises:
        KeyError: The teacher returned a dict without any recognised token key.
    """
    cache_path = None
    if cache_dir is not None and cache_key is not None:
        cache_path = Path(cache_dir) / f"{cache_key}.pt"
        if cache_path.exists():
            # NOTE(review): torch.load unpickles the file - only point cache_dir
            # at a location this notebook itself writes to.
            # NOTE(review): the key does not encode use_cls_token; clear the
            # cache when changing that flag.
            cached = torch.load(cache_path, map_location=images.device)
            return cached['cls'], cached['patches']

    with torch.no_grad():
        features = teacher.forward_features(images)

        if isinstance(features, dict):
            # Full token sequence, used only when the dedicated keys are absent.
            all_tokens = features.get('x', features.get('tokens'))

            if 'x_norm_patchtokens' in features:
                patch_embeddings = features['x_norm_patchtokens']
            elif 'patch_tokens' in features:
                patch_embeddings = features['patch_tokens']
            elif all_tokens is not None:
                patch_embeddings = all_tokens[:, 1:]
            else:
                raise KeyError(
                    "Teacher output has no patch tokens; available keys: "
                    f"{sorted(features.keys())}"
                )

            if not use_cls_token:
                # Same pooling as the tensor branch below and as the student.
                cls_embedding = patch_embeddings.mean(dim=1)
            elif 'x_norm_clstoken' in features:
                cls_embedding = features['x_norm_clstoken']
            elif 'cls_token' in features:
                cls_embedding = features['cls_token']
            elif all_tokens is not None:
                cls_embedding = all_tokens[:, 0]
            else:
                raise KeyError(
                    "Teacher output has no CLS token; available keys: "
                    f"{sorted(features.keys())}"
                )
        else:
            # Tensor format [B, N+1, D]: token 0 is CLS, the rest are patches.
            patch_embeddings = features[:, 1:]
            if use_cls_token:
                cls_embedding = features[:, 0]
            else:
                cls_embedding = patch_embeddings.mean(dim=1)

        # Normalize
        cls_embedding = F.normalize(cls_embedding, dim=-1, p=2)
        patch_embeddings = F.normalize(patch_embeddings, dim=-1, p=2)

    # Save to cache if enabled
    if cache_path is not None:
        os.makedirs(cache_dir, exist_ok=True)
        torch.save({'cls': cls_embedding, 'patches': patch_embeddings}, cache_path)

    return cls_embedding, patch_embeddings


def extract_student_features(student, images, use_cls_token=True):
    """Run the student and return its L2-normalised global and patch embeddings.

    The output of ``student.forward_features`` is indexed as a token tensor:
    position 0 along dim 1 is treated as the CLS token, everything after it as
    patch tokens. With ``use_cls_token=False`` the global embedding is the mean
    of the patch tokens instead of the CLS token.

    Returns:
        ``(cls_embedding, patch_embeddings)``, both normalised along the last dim.
    """
    tokens = student.forward_features(images)
    patch_tokens = tokens[:, 1:]
    pooled = tokens[:, 0] if use_cls_token else patch_tokens.mean(dim=1)
    return F.normalize(pooled, p=2, dim=-1), F.normalize(patch_tokens, p=2, dim=-1)


In [None]:
def compute_distillation_loss(student_cls, student_patches, 
                             teacher_cls, teacher_patches,
                             loss_weights=None):
    """Compute the weighted CLS + patch distillation loss.

    Args:
        student_cls, teacher_cls: Global embeddings; compared element-wise
            when their last dimensions match.
        student_patches, teacher_patches: 3-D patch embeddings (unpacked as
            batch, tokens, dim).
        loss_weights: Optional dict with ``'cls'`` and/or ``'patch'`` weights.
            Missing keys fall back to the defaults ``{'cls': 1.0, 'patch': 0.5}``.

    Returns:
        ``(total_loss, metrics)`` where ``total_loss`` is a tensor and
        ``metrics`` holds Python floats under ``'total'``, ``'cls'``, ``'patch'``.
    """
    # Merge over the defaults so a partial dict such as {'cls': 2.0} is valid.
    weights = {'cls': 1.0, 'patch': 0.5, **(loss_weights or {})}

    def _squared_norm_mse(a, b):
        # Fallback for mismatched embedding widths: compare squared L2 norms.
        # NOTE(review): the extractors in this notebook L2-normalise their
        # outputs, so for CLS embeddings both squared norms are 1 and this
        # term is ~0 with no useful gradient. Distilling across different
        # widths needs a learned projection - confirm teacher/student dims.
        return F.mse_loss((a ** 2).sum(dim=-1), (b ** 2).sum(dim=-1))

    # CLS token loss
    if student_cls.shape[-1] == teacher_cls.shape[-1]:
        loss_cls = F.mse_loss(student_cls, teacher_cls)
    else:
        loss_cls = _squared_norm_mse(student_cls, teacher_cls)

    # Patch embeddings loss
    _, n_student, d_student = student_patches.shape
    _, n_teacher, d_teacher = teacher_patches.shape

    if d_student == d_teacher:
        # Same width: compare the first min(N_s, N_t) tokens of each side.
        # NOTE(review): truncation does not spatially align grids of
        # different sizes - confirm both models use the same patch grid.
        n_common = min(n_student, n_teacher)
        loss_patch = F.mse_loss(
            student_patches[:, :n_common, :], teacher_patches[:, :n_common, :]
        )
    else:
        # Different widths: pool over tokens, then compare squared norms.
        loss_patch = _squared_norm_mse(
            student_patches.mean(dim=1), teacher_patches.mean(dim=1)
        )

    # Weighted combination
    total_loss = weights['cls'] * loss_cls + weights['patch'] * loss_patch

    return total_loss, {
        'total': total_loss.item(),
        'cls': loss_cls.item(),
        'patch': loss_patch.item()
    }


## 8. Setup Optimizer and Scheduler


In [None]:
# Hyper-parameters are read once and reused for the summary below.
_base_lr = train_cfg['learning_rate']
_fused_adamw = train_cfg.get('use_fused_adamw', True)

# Optimizer over the student's parameters.
optimizer = build_optimizer(
    student,
    lr=_base_lr,
    weight_decay=train_cfg['weight_decay'],
    fused=_fused_adamw,
)

# Scheduler; the training loop steps it once per epoch.
_num_epochs_cfg = train_cfg['num_epochs']
_warmup_epochs = train_cfg['warmup_epochs']
scheduler = build_scheduler(
    optimizer,
    num_epochs=_num_epochs_cfg,
    warmup_epochs=_warmup_epochs,
)

# Gradient scaling is active only on CUDA; elsewhere the scaler is disabled.
_amp_enabled = device.type == 'cuda'
scaler = GradScaler(enabled=_amp_enabled)

print(f"✓ Optimizer: AdamW (lr={_base_lr})")
print(f"  Fused: {_fused_adamw}")
print(f"✓ Scheduler: Cosine with {_warmup_epochs} warmup epochs")
print(f"✓ Mixed precision: {'Enabled' if _amp_enabled else 'Disabled'}")


## 9. Training Loop with Auto-Resume

This training loop includes:
- Step-capped epochs (max_steps_per_epoch)
- Auto-resume from latest checkpoint
- Teacher feature caching (optional)
- Detailed timing (GPU/data/batch)
- Step-based checkpointing


In [None]:
# Run-level settings pulled from the configs (defaults apply when a key is absent).
num_epochs = train_cfg['num_epochs']
loss_weights = train_cfg.get('distill_loss_weights', {'cls': 1.0, 'patch': 0.5})
use_cls_token = model_cfg.get('use_cls_token', True)
use_multi_crop = train_cfg.get('use_multi_crop', False)
save_every = train_cfg.get('save_every', 0)  # 0 = only at end of epoch
cache_teacher_features = train_cfg.get('cache_teacher_features', False)

# Optional teacher-feature cache location, with environment variables expanded.
teacher_feature_dir = train_cfg.get('teacher_feature_dir', None)
teacher_feature_dir = (
    os.path.expandvars(teacher_feature_dir) if teacher_feature_dir else teacher_feature_dir
)

# Steps per epoch: the configured cap, never more than the loader can supply.
total_batches = len(dataloader)
if max_steps_per_epoch is None:
    max_steps = total_batches
elif max_steps_per_epoch <= total_batches:
    max_steps = max_steps_per_epoch
else:
    print(f"⚠️  Warning: max_steps_per_epoch ({max_steps_per_epoch}) > total batches ({total_batches})")
    print("   Falling back to full dataset pass")
    max_steps = total_batches

# Checkpoint directory (created on demand).
checkpoint_dir = os.path.expandvars(train_cfg.get('checkpoint_dir', './checkpoints'))
os.makedirs(checkpoint_dir, exist_ok=True)

# Auto-detect and load latest checkpoint if available
def find_latest_checkpoint(checkpoint_dir):
    """Return the path of the checkpoint to resume from, or None.

    ``checkpoint_latest.pth`` wins when present. Otherwise the
    ``checkpoint_epoch_<N>.pth`` file with the largest integer ``N`` is
    returned; files whose ``<N>`` part is not an integer are ignored.
    Returns None when the directory is missing or has no matching file.
    """
    if not os.path.exists(checkpoint_dir):
        return None

    rolling_path = os.path.join(checkpoint_dir, 'checkpoint_latest.pth')
    if os.path.exists(rolling_path):
        return rolling_path

    prefix, suffix = 'checkpoint_epoch_', '.pth'
    candidates = []
    for fname in os.listdir(checkpoint_dir):
        if not (fname.startswith(prefix) and fname.endswith(suffix)):
            continue
        try:
            epoch_num = int(fname.replace(prefix, '').replace(suffix, ''))
        except ValueError:
            continue
        candidates.append((epoch_num, os.path.join(checkpoint_dir, fname)))

    if not candidates:
        return None
    # max() keeps the first entry among equal epoch numbers.
    return max(candidates, key=lambda item: item[0])[1]

# Try to resume from latest checkpoint
latest_checkpoint = find_latest_checkpoint(checkpoint_dir)
start_epoch = 0
global_step = 0
# Per-epoch loss history (for visualization); replaced below when a
# checkpoint carries a saved history.
train_losses, cls_losses, patch_losses = [], [], []

if latest_checkpoint:
    print(f"\n✓ Found checkpoint: {latest_checkpoint}")
    print(f"  Loading checkpoint to resume training...")
    # NOTE(review): torch.load unpickles the file; only resume from
    # checkpoints written by this notebook.
    checkpoint = torch.load(latest_checkpoint, map_location=device)

    # torch.compile wraps the model and prefixes its state-dict keys with
    # "_orig_mod.". Strip that prefix and load into the underlying module so a
    # checkpoint loads whether or not compile_student matched the saving run.
    _compile_prefix = '_orig_mod.'
    _student_state = {
        (key[len(_compile_prefix):] if key.startswith(_compile_prefix) else key): value
        for key, value in checkpoint['student'].items()
    }
    getattr(student, '_orig_mod', student).load_state_dict(_student_state)

    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    # A disabled GradScaler saves an empty state, which an enabled one refuses
    # to load; skip it and keep the freshly initialised scaler.
    if checkpoint.get('scaler'):
        scaler.load_state_dict(checkpoint['scaler'])
    # The stored epoch index is the one the checkpoint was written for, so
    # training continues with the following epoch.
    start_epoch = checkpoint.get('epoch', 0) + 1
    global_step = checkpoint.get('global_step', 0)

    print(f"✓ Resumed from epoch {start_epoch}, step {global_step}")
    if max_steps_per_epoch is not None:
        print(f"  Continuing with step cap: {max_steps_per_epoch} steps/epoch")

    # Load previous losses if available (for visualization)
    if 'train_losses' in checkpoint:
        train_losses = checkpoint['train_losses']
        cls_losses = checkpoint.get('cls_losses', [])
        patch_losses = checkpoint.get('patch_losses', [])
        print(f"  Loaded previous training history: {len(train_losses)} epochs")
else:
    print(f"\n✓ No checkpoint found, starting from scratch")

print(f"\nStarting training for {num_epochs} epochs...")
print(f"  Starting from epoch: {start_epoch + 1}")
print(f"Loss weights: CLS={loss_weights['cls']}, Patch={loss_weights['patch']}")
if max_steps < total_batches:
    print(f"Step-capped: {max_steps} steps per epoch (out of {total_batches} total)")
    # Rough estimate that assumes 0.5 s per step.
    print(f"Estimated time per epoch: ~{max_steps * 0.5 / 60:.1f} minutes")
else:
    print(f"Full epoch: {total_batches} steps per epoch")
if save_every > 0:
    print(f"Checkpointing: Every {save_every} steps")
else:
    print(f"Checkpointing: Only at end of each epoch")
if use_cached:
    print(f"Data mode: Cached tensors (fast loading)")
else:
    print(f"Data mode: Original (HuggingFace)")
print("-" * 60)


In [None]:
# Training loop with all optimizations

def _distillation_forward(images):
    """Run teacher and student on one batch; return (loss tensor, metrics dict).

    Shared by the three execution paths (new autocast API, old autocast API,
    CPU) so the forward/loss code exists only once.
    """
    cache_key = None
    # Hashing copies the batch to host memory, so only do it when a cache
    # directory is actually configured.
    if cache_teacher_features and teacher_feature_dir is not None:
        cache_key = hashlib.md5(images.cpu().numpy().tobytes()).hexdigest()

    teacher_cls, teacher_patches = extract_teacher_features(
        teacher, images, use_cls_token=use_cls_token,
        cache_dir=teacher_feature_dir, cache_key=cache_key
    )
    student_cls, student_patches = extract_student_features(
        student, images, use_cls_token=use_cls_token
    )
    return compute_distillation_loss(
        student_cls, student_patches,
        teacher_cls, teacher_patches,
        loss_weights=loss_weights
    )


def _build_checkpoint(last_completed_epoch):
    """Snapshot everything needed to resume.

    ``'epoch'`` is the index of the last fully completed epoch, because the
    resume logic restarts at ``epoch + 1``.
    """
    return {
        'student': student.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'scaler': scaler.state_dict(),
        'epoch': last_completed_epoch,
        'global_step': global_step,
        'train_losses': train_losses,  # Save training history
        'cls_losses': cls_losses,
        'patch_losses': patch_losses,
    }


for epoch in range(start_epoch, num_epochs):
    student.train()
    epoch_losses = {'total': [], 'cls': [], 'patch': []}

    # Use itertools.islice to limit iterations (avoids KeyError with DataLoader)
    limited_dataloader = itertools.islice(dataloader, max_steps)

    desc = f"Epoch {epoch+1}/{num_epochs}"
    if max_steps < total_batches:
        desc += f" (capped at {max_steps}/{total_batches} steps)"

    progress_bar = tqdm(limited_dataloader, desc=desc, total=max_steps)

    # Rolling windows (last 10 steps) for the timing read-out.
    batch_times = []
    data_times = []
    gpu_times = []
    prev_iter_time = time.time()
    steps_completed = 0

    for batch_idx, batch in enumerate(progress_bar):
        iter_start = time.time()
        # Time spent waiting for the loader since the previous step finished.
        data_load_time = iter_start - prev_iter_time if batch_idx > 0 else 0

        batch_start = time.time()

        # Multi-crop loaders yield a list of crops; train on the first one.
        # Any list/tuple batch is handled this way so a multi-crop loader does
        # not crash when use_multi_crop is left disabled.
        if isinstance(batch, (list, tuple)):
            images = batch[0].to(device)
        else:
            images = batch.to(device)

        # channels_last is only defined for 4-D tensors.
        if images.dim() == 4:
            images = images.to(memory_format=torch.channels_last)

        optimizer.zero_grad()

        # Forward + loss only; backward and the optimizer step are counted in
        # the batch time below.
        gpu_start = time.time()

        # Mixed precision training
        if device.type == 'cuda':
            if AUTOCAST_NEW_API:
                with autocast(device_type='cuda', dtype=torch.bfloat16):
                    loss, metrics = _distillation_forward(images)
            else:
                with autocast():
                    loss, metrics = _distillation_forward(images)
        else:
            # CPU: no autocast
            loss, metrics = _distillation_forward(images)

        gpu_time = time.time() - gpu_start

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Track losses
        for _loss_name in ('total', 'cls', 'patch'):
            epoch_losses[_loss_name].append(metrics[_loss_name])
        steps_completed += 1
        global_step += 1

        # Track times
        batch_time = time.time() - batch_start
        batch_times.append(batch_time)
        data_times.append(data_load_time)
        gpu_times.append(gpu_time)
        if len(batch_times) > 10:
            batch_times.pop(0)
            data_times.pop(0)
            gpu_times.pop(0)

        avg_batch_time = sum(batch_times) / len(batch_times)
        avg_data_time = sum(data_times) / len(data_times) if data_times else 0
        avg_gpu_time = sum(gpu_times) / len(gpu_times)

        # Update progress bar with detailed timing. metrics['total'] is the
        # already-extracted float of `loss`, so no extra device sync is needed.
        current_lr = optimizer.param_groups[0]['lr']
        progress_bar.set_postfix({
            'loss': f'{metrics["total"]:.4f}',
            'cls': f'{metrics["cls"]:.4f}',
            'patch': f'{metrics["patch"]:.4f}',
            'lr': f'{current_lr:.6f}',
            'gpu': f'{avg_gpu_time:.2f}s',
            'data': f'{avg_data_time:.2f}s',
            'batch': f'{avg_batch_time:.2f}s',
            'step': f'{steps_completed}/{max_steps}'
        })

        # Step-based checkpointing. The current epoch is still in progress, so
        # the previous epoch is recorded as completed: resuming then restarts
        # this epoch instead of skipping its remaining steps and its
        # scheduler.step().
        if save_every > 0 and global_step % save_every == 0:
            checkpoint = _build_checkpoint(last_completed_epoch=epoch - 1)
            torch.save(checkpoint, f"{checkpoint_dir}/checkpoint_step_{global_step}.pth")
            torch.save(checkpoint, f"{checkpoint_dir}/checkpoint_latest.pth")
            print(f"\n  ✓ Saved checkpoint at step {global_step}")

        prev_iter_time = time.time()

    # Step scheduler at end of epoch
    scheduler.step()

    # Compute epoch averages. Converted to plain floats so the checkpoint
    # holds no NumPy scalars, which torch.load rejects under its
    # weights_only=True default.
    avg_loss = float(np.mean(epoch_losses['total']))
    avg_cls = float(np.mean(epoch_losses['cls']))
    avg_patch = float(np.mean(epoch_losses['patch']))

    train_losses.append(avg_loss)
    cls_losses.append(avg_cls)
    patch_losses.append(avg_patch)

    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f} "
          f"(CLS: {avg_cls:.4f}, Patch: {avg_patch:.4f}) "
          f"[{steps_completed}/{max_steps} steps]")

    # Save checkpoint at end of epoch
    checkpoint = _build_checkpoint(last_completed_epoch=epoch)
    torch.save(checkpoint, f"{checkpoint_dir}/checkpoint_latest.pth")
    torch.save(checkpoint, f"{checkpoint_dir}/checkpoint_epoch_{epoch+1}.pth")
    print(f"  ✓ Saved checkpoint: checkpoint_epoch_{epoch+1}.pth")

print("\n✓ Training completed!")


## 10. Visualize Training Progress


In [None]:
# Plot training curves
if not train_losses:
    # Happens when no epoch has completed in this session and the checkpoint
    # carried no history; indexing or min() on an empty list would raise.
    print("No completed epochs recorded yet - nothing to plot.")
else:
    fig, (ax_linear, ax_log) = plt.subplots(1, 2, figsize=(12, 4))

    # Left panel: all three loss terms on a linear scale.
    ax_linear.plot(train_losses, label='Total Loss')
    ax_linear.plot(cls_losses, label='CLS Loss', alpha=0.7)
    ax_linear.plot(patch_losses, label='Patch Loss', alpha=0.7)
    ax_linear.set_xlabel('Epoch')
    ax_linear.set_ylabel('Loss')
    ax_linear.set_title('Training Loss')
    ax_linear.legend()
    ax_linear.grid(True, alpha=0.3)

    # Right panel: total loss on a log scale.
    ax_log.plot(train_losses, label='Total Loss')
    ax_log.set_xlabel('Epoch')
    ax_log.set_ylabel('Loss')
    ax_log.set_title('Total Loss (Log Scale)')
    ax_log.set_yscale('log')
    ax_log.legend()
    ax_log.grid(True, alpha=0.3)

    fig.tight_layout()
    plt.show()

    print(f"Final loss: {train_losses[-1]:.4f}")
    print(f"Best loss: {min(train_losses):.4f} (epoch {np.argmin(train_losses)+1})")
