# Milestone M2: DINO ViT-S/16 + linear head for CIFAR-100

Goal (from `docs/project_tasks.md`):
- Build a model: **DINO backbone** + **linear head** for 100 classes
- Support **freeze policies**:
  - `head_only` (backbone frozen)
  - `finetune_all` (everything trainable)
  - `last_blocks_only` (optional: only last N blocks trainable)
- Provide helpers: `get_trainable_params(model)` and `count_params(model)`
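
The two helpers could look roughly like the sketch below. It is not the code in `src/model.py`; it only assumes that "trainable" means `requires_grad=True` and that `count_params` takes the `trainable_only` flag used later in this notebook. The `toy` model is a hypothetical stand-in for backbone + head.

```python
import torch.nn as nn


def get_trainable_params(model: nn.Module):
    """Return the parameters that will receive gradients (requires_grad=True)."""
    return [p for p in model.parameters() if p.requires_grad]


def count_params(model: nn.Module, trainable_only: bool = False) -> int:
    """Count parameter elements, optionally restricted to trainable ones."""
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )


# Hypothetical stand-in: first layer plays the frozen backbone, second the head.
toy = nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 2))
for p in toy[0].parameters():
    p.requires_grad = False

print(count_params(toy))                       # 46
print(count_params(toy, trainable_only=True))  # 10
```

Passing the list from `get_trainable_params` to the optimizer keeps frozen weights out of the optimizer state.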

Note: keep this notebook **code-only** (don't run it if the imports are broken in your environment).

In [2]:
!git clone https://github.com/emanueleR3/AML-Project-2.git

Cloning into 'AML-Project-2'...
remote: Enumerating objects: 51, done.
remote: Counting objects: 100% (51/51), done.
remote: Compressing objects: 100% (37/37), done.
remote: Total 51 (delta 23), reused 41 (delta 13), pack-reused 0 (from 0)
Receiving objects: 100% (51/51), 252.32 KiB | 7.88 MiB/s, done.
Resolving deltas: 100% (23/23), done.


In [1]:
%ls
%cd AML-Project-2

[0m[01;34mAML-Project-2[0m/  [01;34mdata[0m/  [01;34msample_data[0m/
/content/AML-Project-2


In [2]:
import torch

from src.utils import get_device
from src.model import build_model, count_params


## 1) Build model from config

This matches the deliverable: `build_model(config)` returns a ready-to-train model.

In [3]:
device = get_device()

config = {
    'model_name': 'dino_vits16',
    'num_classes': 100,
    'dropout': 0.1,
    # freeze_policy: 'head_only' | 'finetune_all' | 'last_blocks_only'
    'freeze_policy': 'head_only',
    # used only for 'last_blocks_only'
    'last_n_blocks': 2,
    'device': device,
}

model = build_model(config)
model.to(device)


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


DINOClassifier(
  (backbone): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=384, out_features=1152, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=384, out_features=384, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=1536, out_features=384, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (norm): L

## 2) Count params (logging helper)

Milestone M2 requests `count_params(model)` for logging.

In [4]:
total = count_params(model, trainable_only=False)
trainable = count_params(model, trainable_only=True)
print('Total params:', total)
print('Trainable params:', trainable)


Total params: 21704164
Trainable params: 38500
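
As a sanity check, the trainable count matches a single linear layer from the 384-dimensional ViT-S embedding to 100 classes. The short calculation below assumes the head is exactly one `Linear(384, 100)` (dropout adds no parameters); the total is copied from the output above.

```python
embed_dim = 384      # ViT-S/16 embedding width, as shown in the printed model
num_classes = 100

head_params = embed_dim * num_classes + num_classes  # weights + biases
total_params = 21_704_164                            # from the output above
backbone_params = total_params - head_params

print(head_params)      # 38500
print(backbone_params)  # 21665664
```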


## 3) Stop condition checks (forward + single backward)

M2 stop condition:
- Forward on dummy batch returns logits shape `[B, 100]`
- A single training step (loss + backward) does not error

In [5]:
# Dummy forward pass
B = 4
x = torch.randn(B, 3, 224, 224, device=device)
logits = model(x)
print('logits shape:', tuple(logits.shape))


logits shape: (4, 100)
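
The cell above covers only the forward check. The second stop condition (one loss + backward step) could be checked along the lines of the sketch below. To stay self-contained it uses `toy_model`, a hypothetical stand-in with 100 outputs; in this notebook the same steps would be applied to `model` and a 224x224 batch.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the DINO classifier (100 output classes).
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 100))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(toy_model.parameters(), lr=1e-3)

B = 4
x = torch.randn(B, 3, 32, 32)
y = torch.randint(0, 100, (B,))  # random dummy labels

toy_model.train()
optimizer.zero_grad()
logits = toy_model(x)
loss = criterion(logits, y)
loss.backward()   # must not raise
optimizer.step()  # must not raise

# Every trainable parameter should have received a gradient.
grads_ok = all(p.grad is not None for p in toy_model.parameters())
print('logits shape:', tuple(logits.shape), '| grads ok:', grads_ok)
```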


## 4) Freeze policy quick examples

Switching `freeze_policy` changes which backbone params are trainable.
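
A minimal sketch of how such a policy switch might be implemented, assuming the model exposes `backbone.blocks` and `head` as in the printed `DINOClassifier`. `ToyClassifier` and `apply_freeze_policy` are hypothetical names for illustration, not the functions in `src/model.py`.

```python
import torch.nn as nn


class ToyClassifier(nn.Module):
    """Hypothetical stand-in with the same layout: backbone.blocks + head."""

    def __init__(self, dim=8, depth=4, num_classes=3):
        super().__init__()
        self.backbone = nn.Module()
        self.backbone.blocks = nn.ModuleList(
            [nn.Linear(dim, dim) for _ in range(depth)]
        )
        self.head = nn.Linear(dim, num_classes)


def apply_freeze_policy(model, policy, last_n_blocks=2):
    """Set requires_grad flags according to the chosen policy."""
    if policy not in ('head_only', 'finetune_all', 'last_blocks_only'):
        raise ValueError(f'unknown freeze_policy: {policy!r}')

    # Backbone: fully trainable only under 'finetune_all'.
    for p in model.backbone.parameters():
        p.requires_grad = policy == 'finetune_all'

    # Optionally re-enable the last N transformer blocks.
    if policy == 'last_blocks_only':
        for block in model.backbone.blocks[-last_n_blocks:]:
            for p in block.parameters():
                p.requires_grad = True

    # The classification head is trainable under every policy.
    for p in model.head.parameters():
        p.requires_grad = True
    return model


def n_trainable(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


toy = ToyClassifier()
counts = {
    policy: n_trainable(apply_freeze_policy(toy, policy))
    for policy in ('head_only', 'finetune_all', 'last_blocks_only')
}
print(counts)  # {'head_only': 27, 'finetune_all': 315, 'last_blocks_only': 171}
```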

In [6]:
# Head-only (backbone frozen)
m_head_only = build_model({**config, 'freeze_policy': 'head_only'})
print('head_only trainable:', count_params(m_head_only, trainable_only=True))

# Full fine-tuning
m_all = build_model({**config, 'freeze_policy': 'finetune_all'})
print('finetune_all trainable:', count_params(m_all, trainable_only=True))

# Last blocks only (optional)
m_last = build_model({**config, 'freeze_policy': 'last_blocks_only', 'last_n_blocks': 2})
print('last_blocks_only trainable:', count_params(m_last, trainable_only=True))


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


head_only trainable: 38500


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


finetune_all trainable: 21704164


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


last_blocks_only trainable: 3588196
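
The `last_blocks_only` count can be reproduced from the layer shapes in the printed model. The arithmetic below suggests that the policy unfreezes the last two blocks, the head, and also the backbone's final norm layer. That last point is an inference from the numbers (assuming the final norm is a `LayerNorm` of width 384), not something confirmed from `src/model.py`.

```python
D, H = 384, 1536  # embedding width and MLP hidden width from the printed model


def linear(n_in, n_out):
    return n_in * n_out + n_out  # weights + biases


layer_norm = 2 * D  # scale + shift

# One transformer block: norm1, qkv, proj, norm2, fc1, fc2.
block = (layer_norm + linear(D, 3 * D) + linear(D, D)
         + layer_norm + linear(D, H) + linear(H, D))
head = linear(D, 100)

print(block)                          # 1774464
print(2 * block + head)               # 3587428
print(2 * block + layer_norm + head)  # 3588196, the reported count
```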


## 5) Training Loop on CIFAR-100

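For a **frozen DINO backbone + linear head** (the linear-probe setting), convergence is fast, so the setup is:
- **10 epochs**, which is typically enough for a linear probe
- **AdamW** with a cosine learning-rate schedule
- The best checkpoint is saved based on accuracy on the CIFAR-100 test split (no separate validation split is used here)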
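
To make the schedule concrete, the closed form behind `CosineAnnealingLR` can be evaluated directly. The sketch assumes `eta_min = 0` (the PyTorch default) and one scheduler step per epoch, with the `LR` and `EPOCHS` values used in this notebook.

```python
import math

LR, T_MAX, ETA_MIN = 1e-3, 10, 0.0


def cosine_lr(epoch):
    """Learning rate after `epoch` scheduler steps (cosine annealing)."""
    return ETA_MIN + 0.5 * (LR - ETA_MIN) * (1 + math.cos(math.pi * epoch / T_MAX))


for epoch in range(T_MAX + 1):
    print(epoch, f'{cosine_lr(epoch):.6f}')
```

The rate starts at `LR`, passes through `LR / 2` at the halfway point, and reaches `eta_min` after `T_max` steps.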

In [7]:
import os
from tqdm import tqdm
from src.data import load_cifar100, create_dataloader
from src.utils import save_checkpoint, AverageMeter, accuracy, ensure_dir

# Hyperparameters
EPOCHS = 10
BATCH_SIZE = 64
LR = 1e-3
WEIGHT_DECAY = 1e-4
NUM_WORKERS = 4
DATA_DIR = './data'
CHECKPOINT_DIR = './outputs/checkpoints'

ensure_dir(CHECKPOINT_DIR)

In [8]:
# Load CIFAR-100 with DINO transforms (224x224).
# Note: load_cifar100() does not accept a `download` keyword; passing it raises
# "TypeError: load_cifar100() got an unexpected keyword argument 'download'",
# so the argument is omitted here.
train_dataset, test_dataset = load_cifar100(data_dir=DATA_DIR, image_size=224)

# Create dataloaders
train_loader = create_dataloader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_loader = create_dataloader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

print(f'Train samples: {len(train_dataset)}')
print(f'Test samples: {len(test_dataset)}')
print(f'Batches per epoch: {len(train_loader)}')

In [21]:
# Re-build model fresh for training
model = build_model(config)
model.to(device)

# Optimizer + Scheduler
optimizer = torch.optim.AdamW(model.get_trainable_params(), lr=LR, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
criterion = torch.nn.CrossEntropyLoss()

print(f'Optimizer: AdamW (lr={LR}, wd={WEIGHT_DECAY})')
print(f'Scheduler: CosineAnnealingLR (T_max={EPOCHS})')

Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


Optimizer: AdamW (lr=0.001, wd=0.0001)
Scheduler: CosineAnnealingLR (T_max=10)


In [23]:
def train_one_epoch(model, loader, optimizer, criterion, device):
    model.train()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    
    pbar = tqdm(loader, desc='Train', leave=False)
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        acc1 = accuracy(outputs, labels, topk=(1,))[0]
        loss_meter.update(loss.item(), images.size(0))
        acc_meter.update(acc1.item(), images.size(0))
        pbar.set_postfix(loss=f'{loss_meter.avg:.4f}', acc=f'{acc_meter.avg:.2f}%')
    
    return loss_meter.avg, acc_meter.avg


@torch.no_grad()
def evaluate(model, loader, criterion, device):
    model.eval()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    
    for images, labels in tqdm(loader, desc='Eval', leave=False):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        acc1 = accuracy(outputs, labels, topk=(1,))[0]
        loss_meter.update(loss.item(), images.size(0))
        acc_meter.update(acc1.item(), images.size(0))
    
    return loss_meter.avg, acc_meter.avg
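
The loops above rely on `AverageMeter` and `accuracy` from `src.utils`, which are not shown in this notebook. The sketch below is one plausible shape for them, assuming `accuracy` returns one percentage tensor per `k` (which is what the `.item()` calls and `%` formatting above imply). It is not the repository's implementation.

```python
import torch


class AverageMeter:
    """Running average weighted by the number of samples per update."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n

    @property
    def avg(self):
        return self.sum / max(self.count, 1)


def accuracy(outputs, targets, topk=(1,)):
    """Top-k accuracy in percent, one scalar tensor per k."""
    maxk = max(topk)
    _, pred = outputs.topk(maxk, dim=1)      # (B, maxk) predicted class ids
    correct = pred.eq(targets.unsqueeze(1))  # (B, maxk) boolean hits
    return [correct[:, :k].any(dim=1).float().mean() * 100.0 for k in topk]


outputs = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]])
targets = torch.tensor([1, 0, 0, 0])
acc1 = accuracy(outputs, targets, topk=(1,))[0]
print(acc1.item())  # 75.0

meter = AverageMeter()
meter.update(75.0, 4)
meter.update(50.0, 4)
print(meter.avg)  # 62.5
```

Weighting each update by the batch size keeps the epoch average correct when the last batch is smaller than the others.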

In [24]:
# Training loop
best_acc = 0.0
history = {'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': []}

print(f'Starting training for {EPOCHS} epochs...\n')

for epoch in range(1, EPOCHS + 1):
    print(f'Epoch {epoch}/{EPOCHS}')
    
    # Train
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    
    # Evaluate
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    
    # Step scheduler
    scheduler.step()
    
    # Log
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['test_loss'].append(test_loss)
    history['test_acc'].append(test_acc)
    
    print(f'  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')
    print(f'  Test  Loss: {test_loss:.4f} | Test  Acc: {test_acc:.2f}%')
    
    # Save best checkpoint
    is_best = test_acc > best_acc
    if is_best:
        best_acc = test_acc
        save_checkpoint({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_acc': best_acc,
            'config': config,
        }, filepath=os.path.join(CHECKPOINT_DIR, 'dino_cifar100_best.pt'))
        print(f'  ✓ New best! Saved checkpoint (acc={best_acc:.2f}%)')
    
    print()

print(f'Training complete! Best Test Acc: {best_acc:.2f}%')

Starting training for 10 epochs...

Epoch 1/10

KeyboardInterrupt: 

In [None]:
# Save final checkpoint (state after the last completed epoch)
save_checkpoint({
    'epoch': len(history['train_loss']),  # completed epochs; equals EPOCHS for a full run
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'best_acc': best_acc,
    'history': history,
    'config': config,
}, filepath=os.path.join(CHECKPOINT_DIR, 'dino_cifar100_last.pt'))

print(f'Final checkpoint saved to {CHECKPOINT_DIR}/dino_cifar100_last.pt')
print(f'Best checkpoint saved to {CHECKPOINT_DIR}/dino_cifar100_best.pt')