# DINO Pretraining Notebook

This notebook runs the full DINO self-supervised pretraining on unlabeled images.

**Note**: Evaluation datasets are not yet available, so this notebook covers pretraining only.


## 1. Setup and Imports


In [None]:
import torch
import torch.nn as nn
import yaml
import os
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
from datetime import datetime

# Import our modules
from data_loader import PretrainDataset
from transforms import MultiCropTransform
from vit_model import build_vit
from dino_wrapper import DINO
from optimizer import build_optimizer, build_scheduler, cosine_schedule
from train_dino import train_dino, train_epoch, dino_loss

# Report the runtime environment up front so a missing GPU or an unexpected
# PyTorch build is visible before any expensive work starts.
print("✓ All imports successful")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Device index 0 is queried for both the name and the memory size.
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    # total_memory is reported in bytes; divide by 1e9 to show decimal gigabytes.
    print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")


## 2. Load Configuration


In [None]:
def load_config(config_path):
    """Load a YAML configuration file.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        The parsed document exactly as returned by ``yaml.safe_load``.

    Raises:
        OSError: If the file cannot be opened. Parsing errors raised by
            ``yaml.safe_load`` propagate unchanged.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding (e.g. cp1252 on Windows).
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)

# Load configs
# OPTIMIZED: Use optimized configs for ~5-8x speedup (recommended)
# DEFAULT: Use original configs for baseline comparison
config_mode = "optimized"  # Options: "optimized", "default"

if config_mode == "optimized":
    data_cfg = load_config('data_config_optimized.yaml')
    train_cfg = load_config('train_config_optimized.yaml')
    # NOTE(review): the bullet points below are hard-coded text, not derived
    # from the loaded files; the values actually in effect are printed in the
    # "=== Configuration ===" summary further down.
    print("⚡ Using OPTIMIZED configs (~5-8x speedup)")
    print("  - 2 global + 2 local crops (4 total)")
    print("  - 75 epochs, 5 warmup")
    print("  - Reduced projection head (32k)")
    print("  - All performance optimizations enabled")
else:
    if config_mode != "default":
        # Any other value still falls back to the baseline configs, but a
        # mistyped mode is now made visible instead of passing silently.
        print(f"Warning: unknown config_mode {config_mode!r}; falling back to DEFAULT configs")
    data_cfg = load_config('data_config.yaml')
    train_cfg = load_config('train_config.yaml')
    print("📊 Using DEFAULT configs (baseline)")

# The model config file is the same for both modes.
model_cfg = load_config('model_config.yaml')

# Summary of the values that were actually loaded.
print("=== Configuration ===")
print(f"Model: {model_cfg['model_name']}")
print(f"Image size: {data_cfg['image_size']}")
print(f"Batch size: {train_cfg['batch_size']}")
print(f"Epochs: {train_cfg['num_epochs']}")
print(f"Learning rate: {train_cfg['learning_rate']}")
print(f"Local crops: {data_cfg['local_crops_number']}")
# The count of 2 global crops is hard-coded here; only the local count is configurable.
print(f"Total crops per image: {2 + data_cfg['local_crops_number']} (2 global + {data_cfg['local_crops_number']} local)")
print(f"Data workers: {data_cfg.get('num_workers', 4)}")


## 3. Setup Device and Checkpoint Directory


In [None]:
# Select the compute device: CUDA when it is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Resolve the checkpoint directory (config value, else ./checkpoints) and
# create it if needed so later saves have somewhere to write.
checkpoint_dir = train_cfg.get('checkpoint_dir', './checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)
print(f"Checkpoints will be saved to: {checkpoint_dir}")


## 4. Load Pretraining Dataset


In [None]:
# Create multi-crop transform
transform = MultiCropTransform(
    global_crops_scale=tuple(data_cfg['global_crops_scale']),
    local_crops_scale=tuple(data_cfg['local_crops_scale']),
    local_crops_number=data_cfg['local_crops_number'],
    image_size=data_cfg['image_size']
)

# Load dataset
print("Loading pretraining dataset...")
pretrain_dataset = PretrainDataset(transform=transform)
print(f"Dataset size: {len(pretrain_dataset)}")

# Same default as the configuration summary cell, so the reported and the
# effective worker count cannot disagree when the key is missing.
num_workers = data_cfg.get('num_workers', 4)

# Create DataLoader with optimizations
loader_kwargs = {
    'batch_size': train_cfg['batch_size'],
    'shuffle': True,
    'num_workers': num_workers,
    'pin_memory': data_cfg['pin_memory'],
    'drop_last': True,
}
# persistent_workers and prefetch_factor are multiprocessing-only options:
# DataLoader raises ValueError if they are set while num_workers == 0, so
# they are only forwarded when worker processes are actually used.
persistent_workers = False
if num_workers > 0:
    persistent_workers = data_cfg.get('persistent_workers', False)  # Keep workers alive between epochs
    loader_kwargs['persistent_workers'] = persistent_workers
    loader_kwargs['prefetch_factor'] = data_cfg.get('prefetch_factor', 2)  # Prefetch batches

train_loader = DataLoader(pretrain_dataset, **loader_kwargs)

# 2 global crops (hard-coded) plus the configured number of local crops.
crops_per_image = 2 + data_cfg['local_crops_number']

print(f"Number of batches per epoch: {len(train_loader)}")
print(f"Total training steps: {len(train_loader) * train_cfg['num_epochs']}")
print(f"Effective batch size: {train_cfg['batch_size']} × {crops_per_image} crops = {train_cfg['batch_size'] * crops_per_image} forward passes per batch")
print(f"Data workers: {num_workers} (with persistent_workers={persistent_workers})")


## 5. Create Model


In [None]:
# Assemble the backbone arguments from the model config, then build the ViT.
print(f"Building {model_cfg['model_name']}...")
vit_kwargs = {
    'model_name': model_cfg['model_name'],
    'img_size': model_cfg['img_size'],
    'patch_size': model_cfg['patch_size'],
    'drop_path_rate': model_cfg['drop_path_rate'],
}
backbone = build_vit(**vit_kwargs)

# Wrap the backbone in the DINO model; the output dimension comes from the
# training config, the CLS-token switch from the model config.
model = DINO(
    backbone,
    out_dim=train_cfg['out_dim'],
    use_cls_token=model_cfg['use_cls_token']
)

# Place the model on the selected device.
model = model.to(device)

# Tally all parameters and the subset that requires gradients in one pass.
total_params = 0
trainable_params = 0
for param in model.parameters():
    total_params += param.numel()
    if param.requires_grad:
        trainable_params += param.numel()

print("\nModel created successfully!")
print(f"Total parameters: {total_params / 1e6:.2f}M")
print(f"Trainable parameters: {trainable_params / 1e6:.2f}M")
# 4 bytes per parameter, i.e. the size if every parameter is stored as float32.
print(f"Model size: {total_params * 4 / 1e6:.2f} MB (float32)")


## 6. Setup Optimizer and Scheduler


In [None]:
# Enable performance optimizations
if device.type == 'cuda':
    use_tf32 = train_cfg.get('use_tf32', True)
    torch.backends.cuda.matmul.allow_tf32 = use_tf32
    torch.backends.cudnn.allow_tf32 = use_tf32
    if use_tf32:
        print("✓ TF32 enabled for faster training")

# Convert model to channels_last if enabled
if train_cfg.get('use_channels_last', False) and device.type == 'cuda':
    model = model.to(memory_format=torch.channels_last)
    print("✓ Model converted to channels_last format")

# Create optimizer with fused AdamW if enabled.
# Like the other GPU-specific switches in this cell, the fused kernel is only
# requested on CUDA: several PyTorch releases reject fused=True for CPU
# parameters. NOTE(review): build_optimizer is defined elsewhere and may
# already guard this itself - confirm.
use_fused = train_cfg.get('use_fused_adamw', True) and device.type == 'cuda'
optimizer = build_optimizer(
    model,
    lr=train_cfg['learning_rate'],
    weight_decay=train_cfg['weight_decay'],
    fused=use_fused
)

# Create scheduler
scheduler = build_scheduler(
    optimizer,
    num_epochs=train_cfg['num_epochs'],
    warmup_epochs=train_cfg['warmup_epochs']
)

# Create gradient scaler for mixed precision (disabled, i.e. a no-op, on CPU)
scaler = torch.cuda.amp.GradScaler(enabled=(device.type == 'cuda'))

# Initialize center for DINO: a zero vector with one entry per output dimension
out_dim = train_cfg['out_dim']
center = torch.zeros(out_dim, device=device)

# Compilation is deliberately deferred to the training-loop cell so that a
# checkpoint (if resuming) is loaded before the model is compiled.
# This cell only records the intent.
use_compile = train_cfg.get('use_torch_compile', False)
compiled_model = None  # read by the training-loop cell before compiling
if use_compile and hasattr(torch, 'compile'):
    print("Note: Model will be compiled after checkpoint loading (if resuming)")

print("\nOptimizer and scheduler created")
print(f"Initial learning rate: {optimizer.param_groups[0]['lr']}")
print(f"Warmup epochs: {train_cfg['warmup_epochs']}")
print(f"Total epochs: {train_cfg['num_epochs']}")
print(f"Projection head size: {out_dim}")


## 7. Training Loop


In [None]:
# Training history (one entry per epoch trained in this session)
train_losses = []
learning_rates = []

# Resume from checkpoint if specified
start_epoch = 0
resume_from = None  # Set to checkpoint path if resuming, e.g., "checkpoints/checkpoint_epoch_50.pth"

if resume_from:
    # Fail loudly on a bad path: silently starting from scratch would
    # overwrite checkpoint_latest.pth with a fresh run.
    if not os.path.exists(resume_from):
        raise FileNotFoundError(f"Checkpoint not found: {resume_from}")
    print(f"Loading checkpoint from {resume_from}...")
    checkpoint = torch.load(resume_from, map_location=device)

    # A model wrapped by torch.compile prefixes its state-dict keys with
    # "_orig_mod."; strip it so checkpoints written from a compiled model
    # load into the plain module (a no-op for unprefixed checkpoints).
    compile_prefix = '_orig_mod.'
    model_state = {
        (key[len(compile_prefix):] if key.startswith(compile_prefix) else key): value
        for key, value in checkpoint['model'].items()
    }
    # If this cell is re-run after compilation, load into the wrapped module.
    getattr(model, '_orig_mod', model).load_state_dict(model_state)

    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    scaler.load_state_dict(checkpoint['scaler'])
    center = checkpoint['center']
    # The checkpoint stores the last finished epoch, so continue with the next.
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resumed from epoch {start_epoch}")
else:
    print("Starting training from scratch")

# Compile model now (after checkpoint loading if resuming).
# The _orig_mod check prevents compiling an already-compiled model when the
# cell is executed more than once.
if (use_compile and hasattr(torch, 'compile') and compiled_model is None
        and not hasattr(model, '_orig_mod')):
    print("Compiling model with torch.compile...")
    model = torch.compile(model, mode='reduce-overhead')
    print("✓ Model compiled")

# Performance optimizations summary:
# ✓ Reduced crops: 2 global + 2 local (4 total vs 10)
# ✓ Restricted loss pairings (avoid local-to-local)
# ✓ Reduced projection head (32k vs 65k)
# ✓ Shorter training (75 epochs vs 200)
# ✓ torch.compile, channels_last, BF16, TF32, fused AdamW
# ✓ Optimized data loading (24 workers, persistent, prefetch)

print(f"\n{'='*60}")
print(f"Starting training for {train_cfg['num_epochs']} epochs")
print(f"{'='*60}")


In [None]:
# Main training loop
num_epochs = train_cfg['num_epochs']
save_freq = train_cfg.get('save_freq', 10)

# Crop counts do not change during the run, so resolve them once.
num_global = 2
num_local = data_cfg['local_crops_number']

# Wall-clock reference for the remaining-time estimate.
training_start_time = datetime.now()

for epoch in range(start_epoch, num_epochs):
    epoch_start_time = datetime.now()

    # Train one epoch with optimized settings; train_epoch returns the
    # average loss and the updated DINO center.
    avg_loss, center = train_epoch(
        model=model,
        dataloader=train_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        center=center,
        device=device,
        scaler=scaler,
        epoch=epoch,
        num_epochs=num_epochs,
        teacher_temp=train_cfg['teacher_temp'],
        student_temp=train_cfg['student_temp'],
        warmup_teacher_temp=train_cfg['warmup_teacher_temp'],
        warmup_teacher_temp_epochs=train_cfg['warmup_teacher_temp_epochs'],
        num_global=num_global,
        num_local=num_local
    )

    # Record history
    current_lr = scheduler.get_last_lr()[0]
    train_losses.append(avg_loss)
    learning_rates.append(current_lr)

    # Calculate epoch time
    epoch_time = (datetime.now() - epoch_start_time).total_seconds()

    # Print epoch summary
    print(f"\nEpoch {epoch+1}/{num_epochs} Summary:")
    print(f"  Loss: {avg_loss:.4f}")
    print(f"  Learning rate: {current_lr:.6f}")
    print(f"  Time: {epoch_time:.1f}s ({epoch_time/60:.1f} min)")

    # Save checkpoint.
    # A torch.compile wrapper prefixes state-dict keys with "_orig_mod.";
    # saving the wrapped module keeps the checkpoint loadable into an
    # uncompiled model (e.g. when resuming before compilation).
    model_to_save = getattr(model, '_orig_mod', model)
    checkpoint = {
        'model': model_to_save.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'scaler': scaler.state_dict(),
        'center': center,
        'epoch': epoch,
        'config': {
            'model': model_cfg,
            'train': train_cfg,
            'data': data_cfg
        }
    }

    # Save latest (overwritten every epoch)
    torch.save(checkpoint, f"{checkpoint_dir}/checkpoint_latest.pth")

    # Save periodic checkpoints; the final epoch is always kept.
    # A save_freq of 0/None disables periodic saves instead of dividing by zero.
    is_periodic = bool(save_freq) and (epoch + 1) % save_freq == 0
    is_final = (epoch + 1) == num_epochs
    if is_periodic or is_final:
        torch.save(checkpoint, f"{checkpoint_dir}/checkpoint_epoch_{epoch+1}.pth")
        print(f"  Checkpoint saved: checkpoint_epoch_{epoch+1}.pth")

    # Estimate remaining time from the average wall-clock time of all epochs
    # completed in this session (training plus checkpointing).
    remaining_epochs = num_epochs - epoch - 1
    if remaining_epochs > 0:
        epochs_done = epoch - start_epoch + 1
        elapsed = (datetime.now() - training_start_time).total_seconds()
        remaining_time = elapsed / epochs_done * remaining_epochs
        print(f"  Estimated time remaining: {remaining_time/3600:.1f} hours")

print(f"\n{'='*60}")
print("Training completed!")
print(f"{'='*60}")


## 8. Plot Training Curves


In [None]:
# Draw the loss history and the learning-rate history side by side.
fig, (loss_ax, lr_ax) = plt.subplots(1, 2, figsize=(12, 4))

# Left panel: per-epoch training loss.
loss_ax.plot(train_losses)
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training Loss')
loss_ax.grid(True)

# Right panel: per-epoch learning rate on a logarithmic y-axis.
lr_ax.plot(learning_rates)
lr_ax.set_xlabel('Epoch')
lr_ax.set_ylabel('Learning Rate')
lr_ax.set_title('Learning Rate Schedule')
lr_ax.grid(True)
lr_ax.set_yscale('log')

# Write the figure next to the checkpoints, then display it.
fig.tight_layout()
fig.savefig(f'{checkpoint_dir}/training_curves.png', dpi=150)
plt.show()

print(f"Training curves saved to {checkpoint_dir}/training_curves.png")


## 9. Final Model Summary


In [None]:
print("=== Training Summary ===")
if train_losses:
    # train_losses only holds epochs trained in this session, so absolute
    # epoch numbers must be offset by start_epoch (non-zero after a resume).
    last_epoch = start_epoch + len(train_losses)
    best_loss = min(train_losses)
    best_epoch = start_epoch + train_losses.index(best_loss) + 1
    print(f"Total epochs: {last_epoch}")
    print(f"Final loss: {train_losses[-1]:.4f}")
    print(f"Best loss: {best_loss:.4f} (epoch {best_epoch})")
else:
    # Nothing was trained (e.g. the resumed checkpoint was already at the
    # last epoch); avoid indexing or taking min() of an empty history.
    print("No epochs were trained in this session.")
print(f"\nCheckpoints saved in: {checkpoint_dir}")
print(f"Latest checkpoint: {checkpoint_dir}/checkpoint_latest.pth")
if train_losses:
    print(f"Final checkpoint: {checkpoint_dir}/checkpoint_epoch_{last_epoch}.pth")

# Show checkpoint files with their size in decimal megabytes
checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.endswith('.pth')]
print(f"\nAll checkpoints ({len(checkpoint_files)}):")
for f in sorted(checkpoint_files):
    size = os.path.getsize(f"{checkpoint_dir}/{f}") / 1e6
    print(f"  - {f} ({size:.1f} MB)")


## 10. Next Steps

Once evaluation datasets become available, you can:

1. **Extract features** using `extract_features_main.py`:
   ```python
   !python extract_features_main.py \
       --checkpoint checkpoints/checkpoint_latest.pth \
       --data_config data_config.yaml \
       --model_config model_config.yaml \
       --eval_config eval_config.yaml \
       --output_dir ./features \
       --device cuda
   ```

2. **Evaluate with k-NN** using `knn_eval_main.py`:
   ```python
   !python knn_eval_main.py \
       --features features/features.pt \
       --eval_config eval_config.yaml
   ```

3. **Optional: run a linear probe** by setting `linear_probe: true` in `eval_config.yaml`.
