# Milestone M2: DINO ViT-S/16 + linear head for CIFAR-100

Goal (from `docs/project_tasks.md`):
- Build a model: **DINO backbone** + **linear head** for 100 classes
- Support **freeze policies**:
  - `head_only` (backbone frozen)
  - `finetune_all` (everything trainable)
  - `last_blocks_only` (optional: only last N blocks trainable)
- Provide helpers: `get_trainable_params(model)` and `count_params(model)`

Note: this notebook is kept **code-only** (cells are left unexecuted). Run it only once the `src` imports resolve in your environment.

In [None]:
import sys

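# Make `src` importable by adding the project root to sys.path.
# 'AML-Project-2' resolves only when the working directory contains the repo folder;
# '..' covers the case where the notebook is run from notebooks/.
for root in ('AML-Project-2', '..'):
    if root not in sys.path:
        sys.path.append(root)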

import torch

from src.utils import get_device
from src.model import build_model, count_params


## 1) Build model from config

This matches the deliverable: `build_model(config)` returns a ready-to-train model.

In [None]:
device = get_device()

config = {
    'model_name': 'dino_vits16',
    'num_classes': 100,
    'dropout': 0.1,
    # freeze_policy: 'head_only' | 'finetune_all' | 'last_blocks_only'
    'freeze_policy': 'head_only',
    # used only for 'last_blocks_only'
    'last_n_blocks': 2,
    'device': device,
}

model = build_model(config)
model.to(device)
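
For readers without access to `src/model.py`, here is a minimal sketch of what a `build_model(config)`-style factory could look like. It is not the project's implementation. `TinyBackbone`, `BackboneWithHead`, and `build_model_sketch` are hypothetical names, and the backbone is a small stand-in so the cell runs offline. In the real model the backbone would presumably be DINO ViT-S/16 (for example via `torch.hub.load('facebookresearch/dino:main', 'dino_vits16')`), which outputs 384-dimensional features.

```python
import torch
from torch import nn


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for DINO ViT-S/16: images -> 384-dim features."""

    def __init__(self, embed_dim: int = 384, depth: int = 4):
        super().__init__()
        self.embed_dim = embed_dim
        # 16x16 non-overlapping patches, as in a /16 ViT.
        self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=16, stride=16)
        self.blocks = nn.ModuleList(
            [nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, embed_dim), nn.GELU())
             for _ in range(depth)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        tokens = self.patch_embed(x).flatten(2).transpose(1, 2)  # [B, N, D]
        for block in self.blocks:
            tokens = tokens + block(tokens)  # residual connection
        return tokens.mean(dim=1)  # mean-pool patch tokens -> [B, D]


class BackboneWithHead(nn.Module):
    """Backbone features -> dropout -> linear classification head."""

    def __init__(self, backbone: nn.Module, embed_dim: int, num_classes: int, dropout: float):
        super().__init__()
        self.backbone = backbone
        self.dropout = nn.Dropout(dropout)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.dropout(self.backbone(x)))


def build_model_sketch(config: dict) -> nn.Module:
    # Assumption: only 'num_classes' and 'dropout' are read in this sketch.
    backbone = TinyBackbone()
    return BackboneWithHead(backbone, backbone.embed_dim,
                            config['num_classes'], config.get('dropout', 0.0))


sketch = build_model_sketch({'num_classes': 100, 'dropout': 0.1})
out = sketch(torch.randn(2, 3, 224, 224))
print(tuple(out.shape))  # (2, 100)
```

Keeping the backbone and head as separate attributes (`backbone`, `head`) is what makes the freeze policies easy to apply later: each policy only has to toggle `requires_grad` on one of the two submodules.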


## 2) Count params (logging helper)

Milestone M2 requests `count_params(model)` for logging.

In [None]:
total = count_params(model, trainable_only=False)
trainable = count_params(model, trainable_only=True)
print('Total params:', total)
print('Trainable params:', trainable)
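
As a reference for what these helpers are expected to compute, here is a minimal sketch in plain PyTorch. The `_sketch` suffix marks both functions as illustrative versions, not the ones in `src/model.py`, and the two-layer `toy` model is a stand-in with a frozen first layer.

```python
from torch import nn


def count_params_sketch(model: nn.Module, trainable_only: bool = False) -> int:
    """Sum parameter element counts, optionally restricted to requires_grad=True."""
    return sum(p.numel() for p in model.parameters()
               if p.requires_grad or not trainable_only)


def get_trainable_params_sketch(model: nn.Module) -> list:
    """Return the parameters an optimizer should update."""
    return [p for p in model.parameters() if p.requires_grad]


# Toy model: a frozen "backbone" layer followed by a 384 -> 100 linear head.
toy = nn.Sequential(nn.Linear(8, 384), nn.Linear(384, 100))
for p in toy[0].parameters():
    p.requires_grad = False

print(count_params_sketch(toy))                       # 41956 = (8*384 + 384) + (384*100 + 100)
print(count_params_sketch(toy, trainable_only=True))  # 38500 = 384*100 + 100
print(len(get_trainable_params_sketch(toy)))          # 2 (head weight and bias)
```

This also gives a handy sanity check for `head_only`: if the head is a single `nn.Linear(384, 100)` with bias on top of ViT-S features, the trainable count should be 38,500.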


## 3) Stop condition checks (forward + single backward)

M2 stop condition:
- A forward pass on a dummy batch returns logits of shape `[B, 100]`
- A single training step (loss + backward) does not error

Below is the code to run once your environment imports work.

In [None]:
# Dummy forward pass
B = 4
x = torch.randn(B, 3, 224, 224, device=device)
logits = model(x)
print('logits shape:', tuple(logits.shape))  # expected: (B, 100)

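# Single training step (example): uncomment to run the second stop-condition check.
# The list comprehension keeps only parameters with requires_grad=True, which is
# what the `get_trainable_params(model)` helper is expected to return.
# model.train()
# y = torch.randint(low=0, high=100, size=(B,), device=device)
# loss_fn = torch.nn.CrossEntropyLoss()
# trainable_params = [p for p in model.parameters() if p.requires_grad]
# optimizer = torch.optim.SGD(trainable_params, lr=0.01, momentum=0.9)
# optimizer.zero_grad()
# loss = loss_fn(model(x), y)
# loss.backward()
# optimizer.step()
# print('loss:', float(loss))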
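
If the project imports are not available yet, the same two stop-condition checks can be rehearsed on a stand-in model. The sketch below assumes a `head_only`-style setup, with a tiny frozen `toy_backbone` in place of DINO and 32x32 inputs to keep it fast. All names are illustrative.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-in model: frozen feature extractor + trainable 100-class linear head.
toy_backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 384), nn.GELU())
toy_head = nn.Linear(384, 100)
toy_model = nn.Sequential(toy_backbone, toy_head)

for p in toy_backbone.parameters():
    p.requires_grad = False  # head_only: backbone frozen

# Only pass trainable parameters to the optimizer.
trainable_params = [p for p in toy_model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable_params, lr=0.01, momentum=0.9)
loss_fn = nn.CrossEntropyLoss()

batch_size = 4
images = torch.randn(batch_size, 3, 32, 32)
labels = torch.randint(low=0, high=100, size=(batch_size,))

toy_model.train()
head_before = toy_head.weight.detach().clone()

optimizer.zero_grad()
toy_logits = toy_model(images)       # check 1: forward pass
toy_loss = loss_fn(toy_logits, labels)
toy_loss.backward()                  # check 2: backward + optimizer step
optimizer.step()

print('logits shape:', tuple(toy_logits.shape))
print('loss is finite:', torch.isfinite(toy_loss).item())
print('backbone received no grads:', all(p.grad is None for p in toy_backbone.parameters()))
print('head weights changed:', not torch.equal(head_before, toy_head.weight))
```

The last two prints go one step beyond "does not error": they confirm that the frozen part stays untouched while the head actually updates.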


## 4) Freeze policy quick examples

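Switching `freeze_policy` changes which backbone parameters are trainable. Comparing the trainable counts is a quick check: `head_only` should give the smallest count and `finetune_all` the largest, with `last_blocks_only` in between.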

In [None]:
# Head-only (backbone frozen)
m_head_only = build_model({**config, 'freeze_policy': 'head_only'})
print('head_only trainable:', count_params(m_head_only, trainable_only=True))

# Full fine-tuning
m_all = build_model({**config, 'freeze_policy': 'finetune_all'})
print('finetune_all trainable:', count_params(m_all, trainable_only=True))

# Last blocks only (optional)
m_last = build_model({**config, 'freeze_policy': 'last_blocks_only', 'last_n_blocks': 2})
print('last_blocks_only trainable:', count_params(m_last, trainable_only=True))
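
For reference, the three policies can be sketched as `requires_grad` toggles. This is an illustration under assumptions, not the code in `src/model.py`: `apply_freeze_policy` is a hypothetical helper, the model is assumed to expose `backbone.blocks` (a `ModuleList`, as DINO's ViT does) and `head`, and the head is assumed to stay trainable under every policy.

```python
from torch import nn


class ToyBackbone(nn.Module):
    def __init__(self, dim: int = 16, depth: int = 4):
        super().__init__()
        self.embed = nn.Linear(dim, dim)
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])


class ToyModel(nn.Module):
    # forward() is omitted: this sketch only inspects requires_grad flags.
    def __init__(self, dim: int = 16, depth: int = 4, num_classes: int = 100):
        super().__init__()
        self.backbone = ToyBackbone(dim, depth)
        self.head = nn.Linear(dim, num_classes)


def apply_freeze_policy(model: nn.Module, policy: str, last_n_blocks: int = 0) -> nn.Module:
    if policy not in ('head_only', 'finetune_all', 'last_blocks_only'):
        raise ValueError(f'unknown freeze_policy: {policy!r}')
    # Backbone: fully trainable only for finetune_all, otherwise frozen first.
    for p in model.backbone.parameters():
        p.requires_grad = (policy == 'finetune_all')
    # Then unfreeze the last N blocks (guard: blocks[-0:] would select every block).
    if policy == 'last_blocks_only' and last_n_blocks > 0:
        for block in model.backbone.blocks[-last_n_blocks:]:
            for p in block.parameters():
                p.requires_grad = True
    # Assumption: the head is trainable under every policy.
    for p in model.head.parameters():
        p.requires_grad = True
    return model


def n_trainable(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


counts = {
    policy: n_trainable(apply_freeze_policy(ToyModel(), policy, last_n_blocks=2))
    for policy in ('head_only', 'last_blocks_only', 'finetune_all')
}
print(counts)  # {'head_only': 1700, 'last_blocks_only': 2244, 'finetune_all': 3060}
```

In the toy model each 16x16 block holds 272 parameters and the head 1,700, so the counts grow by exactly the unfrozen blocks. The same ordering (`head_only` < `last_blocks_only` < `finetune_all`) is what the three prints in the cell above should show for the real model.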
