# DINO Pipeline Testing Notebook

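This notebook smoke-tests each stage of the DINO self-supervised learning pipeline. The stages are data loading, model construction, the loss, the optimizer and schedulers, a single training step, feature extraction, and k-NN evaluation. The goal is to confirm every component runs end to end before launching full training.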


## 1. Setup and Imports


In [None]:
import torch
import torch.nn as nn
import numpy as np
import yaml
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt

# Import our modules
from data_loader import PretrainDataset, EvalDataset
from transforms import MultiCropTransform, EvalTransform
from vit_model import build_vit
from dino_wrapper import DINO, DINOHead
from optimizer import build_optimizer, build_scheduler, cosine_schedule
from train_dino import dino_loss, train_epoch
from extract_features import extract_features
from knn_eval import knn_evaluate, knn_evaluate_multiple_k

# Reaching this line means every project module above imported cleanly.
print("✓", "All imports successful")


## 2. Load Configuration Files


In [None]:
def load_config(config_path):
    """Load a YAML config file and return its parsed contents.

    Args:
        config_path: Path to the YAML file.

    Returns:
        The parsed YAML document (this notebook indexes the result like a
        dict). An empty file yields ``{}`` instead of ``None`` so a missing
        key surfaces as a clear ``KeyError`` rather than a ``TypeError``.
    """
    # Explicit UTF-8 so parsing does not depend on the platform's locale
    # encoding (e.g. cp1252 on Windows).
    with open(config_path, 'r', encoding='utf-8') as f:
        cfg = yaml.safe_load(f)
    return {} if cfg is None else cfg

# Load the four pipeline configs (data / model / train / eval) from the
# current working directory, in that order.
data_cfg, model_cfg, train_cfg, eval_cfg = (
    load_config(f'{stem}_config.yaml') for stem in ('data', 'model', 'train', 'eval')
)

# Echo each config so mismatched or missing keys are easy to spot.
for heading, cfg in (
    ("Data Config:", data_cfg),
    ("\nModel Config:", model_cfg),
    ("\nTrain Config:", train_cfg),
    ("\nEval Config:", eval_cfg),
):
    print(heading, cfg)


## 3. Test Data Loading


In [None]:
# Pretraining dataset: each item is the list of crops produced by the
# multi-crop transform.
print("=== Testing Pretraining Dataset ===")
crop_kwargs = dict(
    global_crops_scale=tuple(data_cfg['global_crops_scale']),
    local_crops_scale=tuple(data_cfg['local_crops_scale']),
    local_crops_number=data_cfg['local_crops_number'],
    image_size=data_cfg['image_size'],
)
transform = MultiCropTransform(**crop_kwargs)

print("Creating pretraining dataset...")
pretrain_dataset = PretrainDataset(transform=transform)
print(f"Dataset size: {len(pretrain_dataset)}")

# Inspect the crops of the first sample.
sample = pretrain_dataset[0]
crop_shapes = [crop.shape for crop in sample]
crop_dtypes = [crop.dtype for crop in sample]
crop_ranges = [(crop.min().item(), crop.max().item()) for crop in sample]
print(f"\nNumber of crops: {len(sample)}")
print(f"Crop shapes: {crop_shapes}")
print(f"Crop dtypes: {crop_dtypes}")
print(f"Crop value ranges: {crop_ranges}")


In [None]:
# Evaluation datasets: both splits share the same eval transform.
print("=== Testing Evaluation Dataset ===")
eval_transform = EvalTransform(image_size=data_cfg['image_size'])

eval_splits = {}
for split in ('train', 'test'):
    lead = '' if split == 'train' else '\n'
    print(f"{lead}Creating evaluation {split} dataset...")
    eval_splits[split] = EvalDataset(split=split, transform=eval_transform)
    print(f"{split.capitalize()} dataset size: {len(eval_splits[split])}")
eval_train_dataset = eval_splits['train']
eval_test_dataset = eval_splits['test']

# Inspect a single (image, label) pair from the train split.
image, label = eval_train_dataset[0]
pixel_lo, pixel_hi = image.min().item(), image.max().item()
print(f"\nImage shape: {image.shape}")
print(f"Label: {label}")
print(f"Image dtype: {image.dtype}")
print(f"Image value range: ({pixel_lo:.3f}, {pixel_hi:.3f})")


In [None]:
# Test DataLoader
print("=== Testing DataLoader ===")
pretrain_loader = DataLoader(
    pretrain_dataset,
    batch_size=4,  # Small batch for testing
    shuffle=True,
    num_workers=0,  # Set to 0 for debugging
    pin_memory=False,
)

# With default collation, each entry of `batch` presumably stacks one crop
# position across the batch, i.e. (batch, C, H, W) — TODO confirm.
batch = next(iter(pretrain_loader))
print(f"Batch type: {type(batch)}")
print(f"Number of crops in batch: {len(batch)}")
print(f"Batch shapes: {[c.shape for c in batch]}")

# Crops are assumed to be ordered global-first, followed by
# data_cfg['local_crops_number'] local crops (TODO confirm against
# MultiCropTransform). Index defensively so configs with no local crops, or a
# single global crop, do not raise IndexError.
num_local = data_cfg['local_crops_number']
num_global = len(batch) - num_local
print(f"\nFirst crop shape: {batch[0].shape}")
if num_global > 1:
    print(f"Second crop shape: {batch[1].shape}")
if num_local > 0:
    print(f"First local crop shape: {batch[num_global].shape}")
else:
    print("No local crops configured (local_crops_number = 0)")


## 4. Test Model Creation


In [None]:
print("=== Testing Model Creation ===")

# Build backbone
print(f"Building {model_cfg['model_name']}...")
backbone = build_vit(
    model_name=model_cfg['model_name'],
    img_size=model_cfg['img_size'],
    patch_size=model_cfg['patch_size'],
    drop_path_rate=model_cfg['drop_path_rate'],
)

print(f"Backbone embed_dim: {backbone.embed_dim}")
print(f"Backbone parameters: {sum(p.numel() for p in backbone.parameters()) / 1e6:.2f}M")

# Test forward pass at the resolution the backbone was built for. A hard-coded
# size may fail whenever model_cfg['img_size'] is changed. img_size may be an
# int or an (H, W) pair.
img_size = model_cfg['img_size']
img_h, img_w = (img_size, img_size) if isinstance(img_size, int) else tuple(img_size)
dummy_input = torch.randn(2, 3, img_h, img_w)
with torch.no_grad():
    features = backbone.forward_features(dummy_input)
    print(f"\nInput shape: {dummy_input.shape}")
    print(f"Features shape: {features.shape}")
    # Token 0 is treated as CLS, the remainder as patch tokens.
    print(f"CLS token shape: {features[:, 0].shape}")
    print(f"Patch tokens shape: {features[:, 1:].shape}")


In [None]:
# Build DINO model
print("=== Testing DINO Model ===")
model = DINO(
    backbone,
    out_dim=train_cfg['out_dim'],
    use_cls_token=model_cfg['use_cls_token'],
)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total parameters: {total_params / 1e6:.2f}M")
print(f"Trainable parameters: {trainable_params / 1e6:.2f}M")
print(f"Student head parameters: {sum(p.numel() for p in model.student_head.parameters()) / 1e6:.2f}M")

# Test forward pass with multi-crop. The dummy crops mirror the crop count and
# resolutions actually produced by the MultiCropTransform (taken from `batch`,
# the DataLoader batch fetched in section 3), so the mixed-resolution
# global/local path is exercised. A batch of 2 keeps it cheap.
dummy_crops = [torch.randn(2, *crop.shape[1:]) for crop in batch]
print(f"\nTesting forward pass with {len(dummy_crops)} crops...")
print(f"Dummy crop shapes: {[c.shape for c in dummy_crops]}")

with torch.no_grad():
    student_outputs = model(dummy_crops, is_teacher=False)
    teacher_outputs = model(dummy_crops, is_teacher=True)

print(f"Number of student outputs: {len(student_outputs)}")
print(f"Student output shape: {student_outputs[0].shape}")
print(f"Teacher output shape: {teacher_outputs[0].shape}")
print(f"Output dimension: {student_outputs[0].shape[-1]}")


## 5. Test Loss Function


In [None]:
print("=== Testing DINO Loss ===")

# Create dummy outputs
batch_size = 4
out_dim = train_cfg['out_dim']
num_crops = 4

# Student logits require grad so the check below can confirm the loss is
# differentiable w.r.t. the student. Teacher logits are plain constants, as
# they are under torch.no_grad() in the training step.
student_outputs = [
    torch.randn(batch_size, out_dim, requires_grad=True) for _ in range(num_crops)
]
teacher_outputs = [torch.randn(batch_size, out_dim) for _ in range(num_crops)]
center = torch.zeros(out_dim)

loss = dino_loss(
    student_outputs,
    teacher_outputs,
    center,
    teacher_temp=train_cfg['teacher_temp'],
    student_temp=train_cfg['student_temp'],
)

print(f"Loss value: {loss.item():.4f}")
print(f"Loss shape: {loss.shape}")
print(f"Loss requires grad: {loss.requires_grad}")

# Backpropagate and report how many student outputs received finite gradients.
loss.backward()
num_with_grad = sum(
    1 for s in student_outputs
    if s.grad is not None and bool(torch.isfinite(s.grad).all())
)
print(f"Student outputs with finite gradients: {num_with_grad}/{num_crops}")


## 6. Test Optimizer and Scheduler


In [None]:
print("=== Testing Optimizer and Scheduler ===")

optimizer = build_optimizer(
    model,
    lr=train_cfg['learning_rate'],
    weight_decay=train_cfg['weight_decay'],
)

scheduler = build_scheduler(
    optimizer,
    num_epochs=train_cfg['num_epochs'],
    warmup_epochs=train_cfg['warmup_epochs'],
)

print(f"Optimizer: {type(optimizer).__name__}")
print(f"Number of parameter groups: {len(optimizer.param_groups)}")
print(f"Initial learning rate: {optimizer.param_groups[0]['lr']}")

# Step the scheduler a few times and watch the LR of the first param group.
# No optimizer.step() is taken here, so PyTorch may warn about the call order.
# That is harmless because section 7 rebuilds the optimizer and scheduler.
print("\nTesting scheduler (first 5 steps):")
for i in range(5):
    lr = optimizer.param_groups[0]['lr']
    print(f"Step {i}: LR = {lr:.6f}")
    scheduler.step()

# Test momentum schedule over the configured training length rather than a
# hard-coded 200 epochs. Probes are the first epoch, ~5%, 25%, 50% and the
# last epoch (0, 10, 50, 100, 199 when num_epochs == 200).
num_epochs = train_cfg['num_epochs']
probe_epochs = sorted(
    {0, num_epochs // 20, num_epochs // 4, num_epochs // 2, max(num_epochs - 1, 0)}
)
print("\nTesting momentum schedule:")
for epoch in probe_epochs:
    momentum = cosine_schedule(epoch, max_epochs=num_epochs, base_value=0.996, final_value=1.0)
    print(f"Epoch {epoch}: Momentum = {momentum:.6f}")


## 7. Test Training Step (Single Batch)


In [None]:
print("=== Testing Single Training Step ===")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# torch.cuda.amp only applies on CUDA. On CPU it warns that CUDA is
# unavailable and disables itself, so enable AMP explicitly only on GPU.
use_amp = device.type == 'cuda'

model = model.to(device)

# Create a small dataloader
small_loader = DataLoader(
    pretrain_dataset,
    batch_size=2,
    shuffle=False,
    num_workers=0,
)

# Get one batch
crops = next(iter(small_loader))
crops = [c.to(device) for c in crops]

print(f"Batch crops: {len(crops)}")
print(f"Crop shapes: {[c.shape for c in crops]}")

# Initialize optimizer and scaler (the scaler is a pass-through when disabled)
optimizer = build_optimizer(model, lr=train_cfg['learning_rate'], weight_decay=train_cfg['weight_decay'])
scheduler = build_scheduler(optimizer, num_epochs=train_cfg['num_epochs'], warmup_epochs=train_cfg['warmup_epochs'])
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

out_dim = train_cfg['out_dim']
center = torch.zeros(out_dim, device=device)

# Forward pass
model.train()
optimizer.zero_grad()

with torch.cuda.amp.autocast(enabled=use_amp):
    student_outputs = model(crops, is_teacher=False)
    # The teacher receives no gradients.
    with torch.no_grad():
        teacher_outputs = model(crops, is_teacher=True)

    loss = dino_loss(
        student_outputs, teacher_outputs, center,
        teacher_temp=train_cfg['teacher_temp'],
        student_temp=train_cfg['student_temp'],
    )

print(f"\nLoss: {loss.item():.4f}")

# Backward pass
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
scheduler.step()

print(f"Learning rate after step: {scheduler.get_last_lr()[0]:.6f}")

# Update teacher using the epoch-0 momentum of the configured schedule length
momentum = cosine_schedule(0, max_epochs=train_cfg['num_epochs'], base_value=0.996, final_value=1.0)
model.update_teacher(momentum)
print(f"Teacher updated with momentum: {momentum:.6f}")

# Update center: EMA (0.9 / 0.1) of the mean teacher output over crops and
# batch. Stack in float32 since teacher outputs may be half precision under
# autocast.
with torch.no_grad():
    teacher_out = torch.stack(teacher_outputs).float()
    center = 0.9 * center + 0.1 * teacher_out.mean(dim=0).mean(dim=0)
print(f"Center updated, mean: {center.mean().item():.6f}")


## 8. Test Feature Extraction


In [None]:
print("=== Testing Feature Extraction ===")

# Small, deterministic (unshuffled) loader over the eval 'train' split.
# NOTE(review): this cell re-implements extraction inline; the imported
# `extract_features` helper itself is not exercised here.
eval_transform = EvalTransform(image_size=data_cfg['image_size'])
eval_dataset = EvalDataset(split='train', transform=eval_transform)
eval_loader = DataLoader(eval_dataset, batch_size=4, shuffle=False, num_workers=0)

print(f"Evaluation dataset size: {len(eval_dataset)}")

# Switch to inference mode and pull the backbone out of the DINO wrapper.
model.eval()
backbone = model.get_backbone()
backbone.eval()

max_batches = 2
print("\nExtracting features from first 2 batches...")
features_list, labels_list = [], []

with torch.no_grad():
    for batch_idx, (images, labels) in enumerate(eval_loader):
        if batch_idx >= max_batches:
            break

        tokens = backbone.forward_features(images.to(device))
        # Either the CLS token (index 0) or the mean over the remaining tokens.
        feat = tokens[:, 0] if model_cfg['use_cls_token'] else tokens[:, 1:].mean(dim=1)
        feat = nn.functional.normalize(feat, dim=-1, p=2)

        features_list.append(feat.cpu())
        labels_list.append(labels)

        print(f"Batch {batch_idx + 1}: features shape = {feat.shape}, labels shape = {labels.shape}")

features = torch.cat(features_list, dim=0)
labels = torch.cat(labels_list, dim=0)

print(f"\nTotal features shape: {features.shape}")
print(f"Total labels shape: {labels.shape}")
print(f"Feature norm (should be ~1.0): {features.norm(dim=1).mean().item():.4f}")
print(f"Unique labels: {torch.unique(labels).tolist()}")


## 9. Test k-NN Evaluation


In [None]:
print("=== Testing k-NN Evaluation ===")


def _random_labelled_split(num_samples, dim):
    """Return L2-normalised random features and random labels in [0, 10)."""
    feats = nn.functional.normalize(torch.randn(num_samples, dim), dim=-1, p=2)
    return feats, torch.randint(0, 10, (num_samples,))


# Synthetic data only: the labels are independent of the features, so this
# checks that the k-NN API runs, not that the accuracy is meaningful.
embed_dim = backbone.embed_dim
train_features, train_labels = _random_labelled_split(100, embed_dim)
test_features, test_labels = _random_labelled_split(20, embed_dim)

for name, tensor in (
    ("Train features", train_features),
    ("Train labels", train_labels),
    ("Test features", test_features),
    ("Test labels", test_labels),
):
    print(f"{name}: {tensor.shape}")

# Test k-NN with different k values
k_values = [5, 10, 20]
results = knn_evaluate_multiple_k(
    train_features, train_labels,
    test_features, test_labels,
    k_values=k_values,
)

print(f"\nResults: {results}")


## 10. Summary and Checks


In [None]:
# These messages are printed unconditionally; they reflect success only if
# every cell above ran without raising.
print("=== Pipeline Summary ===")
completed_checks = [
    "Data loading works",
    "Model creation works",
    "Loss computation works",
    "Optimizer and scheduler work",
    "Training step works",
    "Feature extraction works",
    "k-NN evaluation works",
]
print("\n" + "\n".join(f"✓ {item}" for item in completed_checks))
print("\nAll components are ready for full training!")
