# Milestone M3 — Centralized Baseline

**Goal**: Obtain a solid and reproducible centralized baseline for DINO ViT-S/16 on CIFAR-100.

## Targets
- Training loop (`train_one_epoch` / `evaluate` are defined inline below; `src.train` is not imported in this notebook)
- Hyperparameter sanity checks
- Save best checkpoint and metrics

In [36]:
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import tqdm

# Import utilities and data functions
from src.utils import set_seed, get_device, ensure_dir, save_checkpoint, save_metrics_json, count_parameters, AverageMeter, accuracy
from src.data import load_cifar100, create_dataloader
from src.model import build_model


## 1. Setup & Configuration

In [37]:
def train_one_epoch(model, loader, optimizer, criterion, device):
    """
    Run one optimisation pass over ``loader`` and report running averages.

    Args:
        model: Network to optimise; it is switched to training mode here.
        loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimiser that is zeroed and stepped once per batch.
        criterion: Loss callable, invoked as ``criterion(outputs, labels)``.
        device: Target handed to ``.to(...)`` for every batch.

    Returns:
        ``(loss_avg, acc_avg)`` read from the two ``AverageMeter`` objects
        after the last batch. Each meter is updated with the batch value and
        the batch size, so the averages are presumably sample-weighted
        (depends on ``AverageMeter`` -- confirm in ``src.utils``).
    """
    model.train()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    progress = tqdm(loader, desc='Train', leave=False)
    for inputs, targets in progress:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Clear stale gradients, then forward / backward / parameter update.
        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        # Metrics are computed from the logits produced before this update.
        top1 = accuracy(logits, targets, topk=(1,))[0]
        n_samples = inputs.size(0)
        loss_meter.update(batch_loss.item(), n_samples)
        acc_meter.update(top1.item(), n_samples)

        # Show the running averages on the progress bar.
        progress.set_postfix(loss=f'{loss_meter.avg:.4f}', acc=f'{acc_meter.avg:.2f}%')

    return loss_meter.avg, acc_meter.avg


def evaluate(model, loader, criterion, device):
    """
    Score ``model`` on every batch of ``loader`` with gradients disabled.

    Args:
        model: Network to score; it is switched to eval mode and left there.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable, invoked as ``criterion(outputs, labels)``.
        device: Target handed to ``.to(...)`` for every batch.

    Returns:
        ``(loss_avg, acc_avg)`` read from the two ``AverageMeter`` objects
        after the last batch (top-1 accuracy as returned by ``accuracy``).
    """
    # The whole body runs without autograd tracking.
    with torch.no_grad():
        model.eval()
        loss_meter = AverageMeter()
        acc_meter = AverageMeter()

        for inputs, targets in tqdm(loader, desc='Eval', leave=False):
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model(inputs)
            batch_loss = criterion(logits, targets)
            top1 = accuracy(logits, targets, topk=(1,))[0]

            n_samples = inputs.size(0)
            loss_meter.update(batch_loss.item(), n_samples)
            acc_meter.update(top1.item(), n_samples)

        return loss_meter.avg, acc_meter.avg

In [38]:
# Settings for the centralized baseline; later cells read this dict directly.
config = {
    # Experiment identity and filesystem layout
    'exp_name': 'central_baseline',
    'seed': 42,
    'data_dir': './data',
    'output_dir': './outputs',

    # Model
    'model_name': 'dino_vits16',
    'num_classes': 100,
    'freeze_policy': 'head_only',  # accepted values: 'head_only' | 'finetune_all'
    'dropout': 0.0,

    # Optimisation
    'epochs': 15,          # 10-20 is usually enough for a linear probe
    'batch_size': 64,
    'lr': 1e-3,
    'weight_decay': 1e-4,
    'num_workers': 0,      # forwarded to create_dataloader; 0 kept for stability
    'device': None,        # placeholder, filled in right below
}

# Seed first so everything that follows is reproducible.
set_seed(config['seed'])

# Resolve the compute device once and record it in the config as well.
config['device'] = device = get_device()

print(f"Device: {device}")
print(f"Seed: {config['seed']}")

# Output locations: checkpoints, per-experiment logs, and figures.
output_root = config['output_dir']
checkpoint_dir = os.path.join(output_root, 'checkpoints')
log_dir = os.path.join(output_root, 'logs', config['exp_name'])
figures_dir = os.path.join(output_root, 'figures')

for directory in (checkpoint_dir, log_dir, figures_dir):
    ensure_dir(directory)

print(f"\nDirectories created:")
print(f"  Checkpoints: {checkpoint_dir}")
print(f"  Logs: {log_dir}")
print(f"  Figures: {figures_dir}")

Device: cpu
Seed: 42

Directories created:
  Checkpoints: ./outputs/checkpoints
  Logs: ./outputs/logs/central_baseline
  Figures: ./outputs/figures


## 2. Load Data (CIFAR-100)

In [39]:
print("Loading CIFAR-100...")
# image_size=224 is forwarded to load_cifar100 (per the original note, this
# selects the 224x224 DINO-style transforms).
train_trainval, test_dataset = load_cifar100(data_dir=config['data_dir'], image_size=224)

# Hold out 10% of the training pool for validation. The generator is seeded
# from the config so the same split is produced on every run.
n_trainval = len(train_trainval)
train_size = int(0.9 * n_trainval)
val_size = n_trainval - train_size
split_rng = torch.Generator().manual_seed(config['seed'])
train_dataset, val_dataset = torch.utils.data.random_split(
    train_trainval,
    [train_size, val_size],
    generator=split_rng,
)
# NOTE(review): val_dataset is carved out of train_trainval, so it shares
# whatever transforms load_cifar100 attached to the training split -- confirm
# no train-time augmentation is applied to validation images.

print(f"\nDataset sizes:")
print(f"  Train: {len(train_dataset)} samples")
print(f"  Val:   {len(val_dataset)} samples")
print(f"  Test:  {len(test_dataset)} samples")

# Only the training loader shuffles; all three share batch size and workers.
loader_kwargs = {
    'batch_size': config['batch_size'],
    'num_workers': config['num_workers'],
}
train_loader = create_dataloader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = create_dataloader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = create_dataloader(test_dataset, shuffle=False, **loader_kwargs)

print(f"\nDataloaders created:")
print(f"  Train batches per epoch: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
print(f"  Test batches: {len(test_loader)}")

Loading CIFAR-100...

Dataset sizes:
  Train: 45000 samples
  Val:   5000 samples
  Test:  10000 samples

Dataloaders created:
  Train batches per epoch: 703
  Val batches: 79
  Test batches: 157


## 3. Build Model

In [40]:
# Build model
# build_model is a project helper driven entirely by the config dict; the
# resulting module is then moved to the device selected during setup.
model = build_model(config)
model.to(device)

# Count parameters
# Both totals come from the project helper count_parameters; the ratio printed
# below shows how much of the network is actually being trained.
total_params = count_parameters(model, trainable_only=False)
trainable_params = count_parameters(model, trainable_only=True)

print(f"Model: {config['model_name']}")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Trainable %: {100 * trainable_params / total_params:.2f}%")

# Optimizer + Scheduler
# get_trainable_params() is a project-specific method on the model --
# presumably it yields only the parameters left unfrozen by
# config['freeze_policy']; confirm in src.model.
optimizer = optim.AdamW(model.get_trainable_params(), lr=config['lr'], weight_decay=config['weight_decay'])
# T_max equals the epoch count and the training loop steps the scheduler once
# per epoch, so the learning rate follows one cosine half-period over the run.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['epochs'])
criterion = nn.CrossEntropyLoss()

print(f"\nOptimizer: AdamW (lr={config['lr']}, wd={config['weight_decay']})")
print(f"Scheduler: CosineAnnealingLR (T_max={config['epochs']})")

Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Model: dino_vits16
  Total parameters: 21,704,164
  Trainable parameters: 38,500
  Trainable %: 0.18%

Optimizer: AdamW (lr=0.001, wd=0.0001)
Scheduler: CosineAnnealingLR (T_max=15)


## 4. Training Loop

In [None]:
# Training loop: keep per-epoch curves and checkpoint the best validation epoch.
best_acc = 0.0
history = {key: [] for key in ('train_loss', 'train_acc', 'val_loss', 'val_acc')}
best_ckpt_path = os.path.join(checkpoint_dir, 'central_best.pt')

print(f"\n{'='*50}")
print(f"Starting training for {config['epochs']} epochs...")
print(f"{'='*50}\n")

for epoch in range(1, config['epochs'] + 1):
    print(f"Epoch {epoch}/{config['epochs']}")

    # One optimisation pass, then a gradient-free pass over the held-out split.
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)

    # The learning-rate schedule advances once per epoch.
    scheduler.step()

    # Append this epoch's numbers to the curves.
    for key, value in (
        ('train_loss', train_loss),
        ('train_acc', train_acc),
        ('val_loss', val_loss),
        ('val_acc', val_acc),
    ):
        history[key].append(value)

    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"  Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc:.2f}%")

    # A strict improvement in validation accuracy overwrites the checkpoint.
    is_best = val_acc > best_acc
    if is_best:
        best_acc = val_acc
        checkpoint_state = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_acc': best_acc,
            'config': config,
        }
        save_checkpoint(checkpoint_state, filepath=best_ckpt_path)
        print(f"  ✓ New best! Saved checkpoint (acc={val_acc:.2f}%)")

    print()

# The curves are written to disk once, after the final epoch.
metrics_path = os.path.join(log_dir, 'metrics.json')
save_metrics_json(metrics_path, history)
print(f"{'='*50}")
print(f"Training complete! Best Val Acc: {best_acc:.2f}%")
print(f"{'='*50}")


Starting training for 15 epochs...

Epoch 1/15


Train:   1%|          | 5/703 [01:14<2:42:25, 13.96s/it, acc=1.25%, loss=7.4937]

## 5. Final Evaluation on Test Set

In [None]:
# Restore the weights saved at the best validation epoch before testing.
best_ckpt_path = os.path.join(checkpoint_dir, 'central_best.pt')
ckpt = torch.load(best_ckpt_path, map_location=device)
best_weights = ckpt['model_state_dict']
model.load_state_dict(best_weights)

# Single pass over the test split with the restored weights.
test_loss, test_acc = evaluate(model, test_loader, criterion, device)

print(f"\n{'='*50}")
print(f"Final Test Results")
print(f"{'='*50}")
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.2f}%")
print(f"Best Val Accuracy: {ckpt['best_acc']:.2f}%")

## 6. Plotting

In [None]:
# Draw loss and accuracy curves side by side, one x tick per recorded epoch.
epochs_range = range(1, len(history['train_loss']) + 1)

plt.figure(figsize=(12, 5))

# Each row: subplot slot, train key, val key, train label, val label, title, y label.
panels = (
    (1, 'train_loss', 'val_loss', 'Train Loss', 'Val Loss',
     'Training and Validation Loss', 'Loss'),
    (2, 'train_acc', 'val_acc', 'Train Acc', 'Val Acc',
     'Training and Validation Accuracy', 'Accuracy (%)'),
)

for slot, train_key, val_key, train_label, val_label, title, y_label in panels:
    plt.subplot(1, 2, slot)
    plt.plot(epochs_range, history[train_key], 'b-', label=train_label, linewidth=2)
    plt.plot(epochs_range, history[val_key], 'r-', label=val_label, linewidth=2)
    plt.title(title, fontsize=12, fontweight='bold')
    plt.xlabel('Epoch', fontsize=10)
    plt.ylabel(y_label, fontsize=10)
    plt.legend(fontsize=10)
    plt.grid(True, alpha=0.3)

# Save to the figures directory before showing the figure.
plt.tight_layout()
figure_path = os.path.join(figures_dir, 'central_training_curves.png')
plt.savefig(figure_path, dpi=150, bbox_inches='tight')
print(f"Training curves saved to: {figure_path}")
plt.show()