# Fine-Tuning DINOv2 ViT on STL-10 with LoRA Adapters

This notebook demonstrates how to fine-tune a pre-trained DINOv2 Vision Transformer on the STL-10 dataset using Low-Rank Adaptation (LoRA) for parameter-efficient tuning. We implement a lightweight version of DINO self-distillation training (with a momentum teacher and multi-crop augmentations) that can run on a MacBook Pro M2 (Metal/MPS device) with 16GB RAM. The code is commented throughout and covers custom augmentations, learning rate scheduling, and evaluation metrics such as k-NN and few-shot classification.

# Imports

First, let's import necessary libraries and set up the computing device. We'll use PyTorch with MPS support if available (for Mac M1/M2 GPUs), otherwise default to CPU. We also ensure all required libraries (like Torchvision for dataset and Hugging Face for the DINOv2 model) are available.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, ConcatDataset, Subset
import torchvision.transforms as transforms
from torchvision.datasets import STL10
from transformers import AutoModel
import random


# Select MPS device if available (for Apple Silicon GPUs), otherwise CPU
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
print("Using device:", device)


# Data Preparation and Augmentation

- Dataset: We use the STL-10 dataset, which provides a small labeled training set (5,000 images, 10 classes), a labeled test set (8,000 images), and 100,000 unlabeled images. Since training is self-supervised (as in DINO), we train on a random 15,000-image subset of the unlabeled split plus the 5,000 labeled training images with their labels ignored, for 20,000 images in total. The model learns representations without using labels, and we evaluate the learned features on a subset of the labeled test set.

- Image Size: DINOv2 ViT models are pre-trained on 224×224 images. STL-10 images are 96×96, so we will apply random resized crops that can upscale images to 224 for global views. The ViT's positional embeddings will be automatically resized via interpolation (handled by the model). 

- Multi-Crop Augmentation: We implement DINO's augmentation strategy:
	- Global Crops: 2 large crops covering a substantial part of the image (scale range ~0.3 to 1.0 of the image area), each resized to 224×224. These serve as two different "views" of the same image. In this notebook's simplified loop, the teacher encodes the first global crop and the student encodes the second.

	- Local Crops: Several smaller crops (e.g., 6 crops of scale range ~0.05 to 0.3 of area), resized to 96×96. These are only seen by the student (the teacher uses only global crops). The student must predict teacher representations for these partial views, which encourages learning global features.

	- Color Distortions: We apply strong color jitter, random grayscale, Gaussian blur, and Solarization (for one of the global crops) following DINOv2 settings.

	- Normalization: Finally, we convert images to tensors and normalize with ImageNet mean and std.
	
We'll define a custom transformation function that given a PIL image produces multiple crops: two global crops and N local crops. We'll then wrap STL-10 dataset so that each sample yields these crops.
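
To make the scale ranges concrete, the sketch below converts an area fraction into an approximate crop side length for a 96×96 STL-10 image. It assumes a square crop, whereas RandomResizedCrop also samples an aspect ratio; the helper name `crop_side` is hypothetical and only for illustration.

```python
import math

def crop_side(scale: float, image_side: int = 96) -> float:
    """Side length (pixels) of a square crop covering `scale` of the image area."""
    return math.sqrt(scale) * image_side

# (label, area fraction, output size) for the extremes of each crop type
cases = [
    ("global min", 0.32, 224),
    ("global max", 1.00, 224),
    ("local min",  0.05, 96),
    ("local max",  0.32, 96),
]

for name, scale, out_size in cases:
    side = crop_side(scale)
    # Upscaling factor applied when the crop is resized to its output size
    print(f"{name:10s}: ~{side:5.1f}px crop -> {out_size}px (x{out_size / side:.1f})")
```

Even the largest global crop is upscaled by more than 2×, so the views the model sees are noticeably smoother than native 224×224 images.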

In [None]:
from PIL import Image
import math
import random

# Define normalization (ImageNet mean/std)
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Augmentation parameters
global_crops_scale = (0.32, 1.0)   # scale range for global crops (fraction of image area)
local_crops_scale  = (0.05, 0.32)  # scale range for local crops
global_size = 224  # output size of global crops
local_size  = 96   # output size of local crops
local_crops_number = 6  # number of local crops

# Color jitter and other augmentations settings
jitter_strength = 0.4
color_jitter = transforms.ColorJitter(brightness=jitter_strength, contrast=jitter_strength,
                                      saturation=0.2, hue=0.1)

# Define augmentation pipelines for global and local crops:
# RandomResizedCrop with given scale and flip
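# Use torchvision's InterpolationMode enum; passing PIL integer constants is deprecated
global_crop = transforms.RandomResizedCrop(global_size, scale=global_crops_scale,
                                          interpolation=transforms.InterpolationMode.BICUBIC)
local_crop  = transforms.RandomResizedCrop(local_size,  scale=local_crops_scale,
                                          interpolation=transforms.InterpolationMode.BICUBIC)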
flip = transforms.RandomHorizontalFlip(p=0.5)

# Define additional transform pipelines for color + blur/solarize
# Global crop 1: heavy color jitter + grayscale + Gaussian blur (always)
global_color1 = transforms.Compose([
    transforms.RandomApply([color_jitter], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0))], p=1.0)
])
# Global crop 2: color jitter + grayscale + Gaussian blur (10% chance) + solarize (20% chance)
global_color2 = transforms.Compose([
    transforms.RandomApply([color_jitter], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0))], p=0.1),
    transforms.RandomSolarize(threshold=128, p=0.2)
])
# Local crops: color jitter + grayscale + Gaussian blur (50% chance)
local_color = transforms.Compose([
    transforms.RandomApply([color_jitter], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=9, sigma=(0.1, 2.0))], p=0.5)
])

# Compose full pipelines for each crop type
pipeline_global1 = transforms.Compose([global_crop, flip, global_color1, transforms.ToTensor(), normalize])
pipeline_global2 = transforms.Compose([global_crop, flip, global_color2, transforms.ToTensor(), normalize])
pipeline_local   = transforms.Compose([local_crop, flip, local_color, transforms.ToTensor(), normalize])


# Validation set for quick representation probes (k-NN / few-shot) with deterministic “light” augmentations
val_tx = transforms.Compose([
    transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406],
                         std=[0.229,0.224,0.225]),
])

def multi_crop_transform(img: Image.Image):
    """Apply multi-crop augmentation to an image.
    Returns a tuple (g1, g2, local_views):
      g1, g2: the two global crop tensors,
      local_views: list of N local crop tensors.
    """
    # Two global views
    g1 = pipeline_global1(img)
    g2 = pipeline_global2(img)
    # Local views (named local_views to avoid shadowing the built-in locals())
    local_views = [pipeline_local(img) for _ in range(local_crops_number)]
    return g1, g2, local_views

torch.manual_seed(0)
random.seed(0)

# Load STL-10 dataset
# We'll use both labeled and unlabeled splits as unlabeled data for training
data_path = "./data"  # or any path for storing data

train_labeled = STL10(root=data_path, split='train', download=True)
train_unlabeled = STL10(root=data_path, split='unlabeled', download=True) # PIL images; labels are all -1 for this split

# Choose 15 000 of the 100 000 unlabeled images; plain Python ints are the safest index type for Subset
ul_idx  = torch.randperm(len(train_unlabeled))[:15_000].tolist()
ul_15k  = Subset(train_unlabeled, ul_idx)
train_set = ConcatDataset([train_labeled, ul_15k])        # 5 000 + 15 000 = 20 000

val_data = STL10(root=data_path, split="test", download=True, transform=val_tx)

vl_idx = torch.randperm(len(val_data))[:2_000].tolist()  # random 2 000-image subset of the test split
val_ds = Subset(val_data, vl_idx)

# Labels are not used for training; the wrapper below discards them

# Wrap dataset to apply our multi-crop transform
class MultiCropSTL10(torch.utils.data.Dataset):
    def __init__(self, base_dataset):
        self.base_dataset = base_dataset
    def __len__(self):
        return len(self.base_dataset)
    def __getitem__(self, idx):
        img, _ = self.base_dataset[idx]  # get PIL image, ignore label
        # Apply multi-crop augmentation
        g1, g2, locals = multi_crop_transform(img)
        return {'global1': g1, 'global2': g2, 'locals': locals}

train_multi_dataset = MultiCropSTL10(train_set)

# DataLoader for training
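batch_size = 8  # adjust based on memory (8 is safe for MPS 16GB; you can try higher if memory allows)
# Pinned memory only benefits CUDA transfers; it is not supported on MPS and triggers a warning there.
use_pin = device.type == "cuda"
# If worker processes fail to start inside a notebook on macOS, try num_workers=0.
train_loader = DataLoader(train_multi_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=use_pin)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=use_pin)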
print(f"Loaded STL-10 dataset with {len(train_multi_dataset)} images. Batch size: {batch_size}")
print(f"Validation loader ready — {len(val_ds)} images, batch {batch_size}")

### Explanation: 
In the code above, for each image we generate 2 global crops and local_crops_number local crops. Each dataset item is a dictionary of one image's augmented views, and the default collate function batches them as follows:

 - batch['global1'] and batch['global2'] hold the two global crops (tensor shape [B, 3, 224, 224] each, where B is the batch size).
 - batch['locals'] is a list of local_crops_number tensors (each [B, 3, 96, 96]); the i-th tensor holds the i-th local crop of every image in the batch.
 
During training, we combine the local crops into a single tensor for a vectorized forward pass. We do this manually in the training loop, since local crops have a different size than global crops.
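
As a rough illustration of that collation step, the sketch below builds a fake batch with the same dictionary layout and shows one way to flatten the local crops so that all crops of one image are contiguous. The fill values and the name `owner` are illustrative assumptions, and `default_collate` is assumed to be the collate function in use (it is the DataLoader default).

```python
import torch
from torch.utils.data import default_collate

B, N = 4, 6  # batch size and number of local crops per image

# Each fake sample mimics the dataset's dict; crops are filled with the
# sample index so we can track where every crop ends up after collation.
samples = [
    {
        "global1": torch.full((3, 224, 224), float(i)),
        "locals": [torch.full((3, 96, 96), float(i)) for _ in range(N)],
    }
    for i in range(B)
]

batch = default_collate(samples)
print(batch["global1"].shape)                           # batched global crops
print(len(batch["locals"]), batch["locals"][0].shape)   # N tensors, one per crop slot

# Stack to (B, N, C, H, W), then merge the first two axes to get (B*N, C, H, W).
local_imgs = torch.stack(batch["locals"], dim=1).flatten(0, 1)

# Recover which image each flattened crop came from.
owner = local_imgs[:, 0, 0, 0].long()
print(owner.tolist())
```

With this image-major ordering, per-image teacher targets can be aligned with the local crops using `repeat_interleave(N, dim=0)`; a crop-major ordering (from `torch.cat`) would need `repeat(N, 1)` instead.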

# Model Setup: DINOv2 Backbone with LoRA and Projection Head

**Backbone Model:** We use a pre-trained DINOv2 Vision Transformer backbone from Hugging Face. Specifically, we load facebook/dinov2-small, a ViT-Small (embedding dim 384) trained with DINOv2. This backbone outputs a 384-dimensional embedding (the [CLS] token representation) for each image. 

**LoRA Adapters:** To fine-tune efficiently, we will freeze the backbone's original weights and insert LoRA adapters into the model's linear layers:

	- In each transformer block: the Query, Key, and Value projection layers of self-attention, the attention output projection, and the two linear layers of the MLP (the code below wraps every nn.Linear in the backbone).

	- LoRA adds two low-rank matrices per linear weight (down-projection and up-projection) with a small rank (we'll use rank r=4). Only these LoRA matrices will be trained (plus the new head), keeping the number of trainable parameters small.
	
We implement a custom LinearWithLoRA module that wraps an existing nn.Linear and adds LoRA weights. The original weight is frozen, and the LoRA weights produce a learnable offset:
$W_{\text{eff}} = W_{\text{base}} + \Delta W,$

where $\Delta W = B A$ is factorized into a down-projection $A: \text{in\_features} \to r$ and an up-projection $B: r \to \text{out\_features}$. We scale $\Delta W$ by $\alpha/r$ (with $\alpha$ set equal to $r$ here, so the scale is 1). Initializing $B$ to zero and $A$ to small random values makes the initial LoRA contribution exactly zero while still letting gradients flow; if both matrices started at zero, neither would receive a nonzero gradient. 
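
The following sketch checks the identity above numerically on a toy layer: applying the adapter as a separate branch gives the same output as a single linear layer with the merged weight $W_{\text{base}} + (\alpha/r) B A$. The sizes and the nonzero `B` (as if some training had already happened) are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
in_f, out_f, r, alpha = 16, 8, 4, 4
scaling = alpha / r

base = nn.Linear(in_f, out_f)        # stands in for a frozen pre-trained layer
A = torch.randn(r, in_f) * 0.01      # down-projection: in_features -> r
B = torch.randn(out_f, r) * 0.01     # up-projection:   r -> out_features

x = torch.randn(5, in_f)
with torch.no_grad():
    # Adapter form: base output plus the scaled low-rank branch
    y_adapter = base(x) + scaling * (x @ A.T @ B.T)
    # Merged form: fold the low-rank update into one weight matrix
    W_eff = base.weight + scaling * (B @ A)
    y_merged = F.linear(x, W_eff, base.bias)

print(torch.allclose(y_adapter, y_merged, atol=1e-6))
```

Because the two forms agree, LoRA weights can be merged into the base weights after training so inference incurs no extra cost.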

**Projection Head:** DINO uses a projection head (a small MLP) on top of the backbone's embedding to produce "prototype" vectors for computing the self-distillation loss. We'll implement a 3-layer MLP:

- Input dim = 384 (backbone CLS dim)
- Hidden dim = 2048, followed by BatchNorm and GELU activation
- Bottleneck dim = 256, followed by BatchNorm and GELU
- Output dim = n_prototypes (the number of prototypes; we use a smaller number like 1024 for efficiency instead of 65k in the original)

The output of this head is passed through a softmax to produce the probability distribution used in the DINO loss.

**Teacher and Student:** We maintain two models:
- Student: backbone (with LoRA) + head, trained with gradient descent.
- Teacher: backbone (with LoRA) + head, updated only by exponential moving average (EMA) of the student (no direct gradient). The teacher provides target outputs for the student to match. Initially, teacher weights are cloned from the student (so they start identical). As training progresses, teacher = m * teacher + (1-m) * student (for each parameter), with momentum m close to 1 (here rising from 0.996 to 0.9995 over training).
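
A minimal sketch of the EMA rule on a toy layer is shown below. The helper `ema_update` and the artificial "+1.0" student step are assumptions for illustration only; the notebook's own update appears in the training loop.

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)
student = nn.Linear(4, 2)
teacher = copy.deepcopy(student)          # teacher starts identical to the student
for p in teacher.parameters():
    p.requires_grad_(False)               # teacher never receives gradients

# Pretend an optimizer step shifted every student parameter by +1.0
with torch.no_grad():
    for p in student.parameters():
        p.add_(1.0)

@torch.no_grad()
def ema_update(teacher: nn.Module, student: nn.Module, m: float) -> None:
    """In place: teacher <- m * teacher + (1 - m) * student, parameter by parameter."""
    for tp, sp in zip(teacher.parameters(), student.parameters()):
        tp.mul_(m).add_(sp, alpha=1.0 - m)

before = teacher.weight.clone()
ema_update(teacher, student, m=0.996)
# Each teacher entry moves only (1 - m) of the way toward the student
gap = (teacher.weight - before).mean().item()
print(f"mean teacher shift: {gap:.4f}")
```

With m = 0.996 the teacher absorbs only 0.4% of each student change, which is what makes it a slowly moving, stable target.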

In [None]:
# LoRA configuration
lora_rank = 4
lora_alpha = 4  # scaling, typically equal to rank

class LinearWithLoRA(nn.Module):
    """Wrap an nn.Linear layer with LoRA adapters (low-rank adaptation)."""
    def __init__(self, linear: nn.Linear, r: int, alpha: int):
        super().__init__()
        self.in_features = linear.in_features
        self.out_features = linear.out_features
        self.r = r
        self.alpha = alpha
        # Freeze original weight and bias
        self.weight = linear.weight  # keep reference to original weight
        self.weight.requires_grad_(False)
        self.bias = linear.bias
        if self.bias is not None:
            self.bias.requires_grad_(False)
        # Create LoRA weights
        if r > 0:
            # Down-projection: in_features -> r (no bias), Up-projection: r -> out_features (no bias)
            self.lora_down = nn.Linear(self.in_features, r, bias=False)
            self.lora_up   = nn.Linear(r, self.out_features, bias=False)
            # Initialize LoRA weights: set lora_up to zero so that the initial LoRA output is 0
            nn.init.zeros_(self.lora_up.weight)
            # lora_down must be nonzero: if both matrices started at zero, neither would
            # receive a nonzero gradient. Here we use small random values:
            nn.init.normal_(self.lora_down.weight, std=1e-3)
            # Scaling factor
            self.scaling = alpha / r
        else:
            # No LoRA (r=0): define dummy modules for completeness
            self.lora_down = None
            self.lora_up = None
            self.scaling = 1.0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.r > 0:
            # Compute base linear output (using frozen weight and (optional) bias)
            base_out = x.matmul(self.weight.T)
            if self.bias is not None:
                base_out += self.bias
            # Compute LoRA output
            lora_out = self.lora_up(self.lora_down(x)) * self.scaling
            return base_out + lora_out
        else:
            # If r=0, just do a regular linear
            return F.linear(x, self.weight, self.bias)

def apply_lora_to_module(module: nn.Module, r: int, alpha: int):
    """Recursively replace Linear layers in module with LinearWithLoRA."""
    for name, child in list(module.named_children()):
        # Recursively apply to child modules first
        apply_lora_to_module(child, r, alpha)
        # If child itself is linear, replace it
        if isinstance(child, nn.Linear):
            setattr(module, name, LinearWithLoRA(child, r, alpha))

# Load DINOv2 small backbone (no classifier head)
print("Loading DINOv2 backbone...")
backbone = AutoModel.from_pretrained('facebook/dinov2-small')
# The AutoModel returns a base model without any projection head.
# We'll manually extract the CLS embedding from it during forward passes.

# Apply LoRA to backbone
apply_lora_to_module(backbone, r=lora_rank, alpha=lora_alpha)

# Define DINO projection head (3-layer MLP)
class DINOHead(nn.Module):
    def __init__(self, in_dim=384, hidden_dim=2048, bottleneck_dim=256, out_dim=1024):
        super().__init__()
        # Layer 1: in -> hidden
        self.linear1 = nn.Linear(in_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        # Layer 2: hidden -> bottleneck
        self.linear2 = nn.Linear(hidden_dim, bottleneck_dim)
        self.bn2 = nn.BatchNorm1d(bottleneck_dim)
        # Layer 3: bottleneck -> out_dim (prototypes)
        self.linear3 = nn.Linear(bottleneck_dim, out_dim)
        # Initialize weights
        # (We can use default init or something like Kaiming. BatchNorm layers init gamma=1, beta=0 by default.)
        nn.init.trunc_normal_(self.linear1.weight, std=0.02)
        nn.init.trunc_normal_(self.linear2.weight, std=0.02)
        nn.init.trunc_normal_(self.linear3.weight, std=0.02)
        if self.linear1.bias is not None:
            nn.init.zeros_(self.linear1.bias)
        if self.linear2.bias is not None:
            nn.init.zeros_(self.linear2.bias)
        if self.linear3.bias is not None:
            nn.init.zeros_(self.linear3.bias)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear1(x)
        # Apply batch norm and GELU
        x = self.bn1(x)
        x = F.gelu(x, approximate='tanh')
        x = self.linear2(x)
        x = self.bn2(x)
        x = F.gelu(x, approximate='tanh')
        x = self.linear3(x)
        return x

# Create student and teacher networks
student_backbone = backbone
student_head = DINOHead(in_dim=backbone.config.hidden_size if hasattr(backbone, 'config') else 384,
                        hidden_dim=2048, bottleneck_dim=256, out_dim=1024)
# Clone the teacher from student (deep copy to separate weights)
import copy
teacher_backbone = copy.deepcopy(student_backbone)
teacher_head = copy.deepcopy(student_head)

# Move to device
student_backbone.to(device)
student_head.to(device)
teacher_backbone.to(device)
teacher_head.to(device)

# Freeze teacher parameters (no grad)
for p in teacher_backbone.parameters():
    p.requires_grad = False
for p in teacher_head.parameters():
    p.requires_grad = False

# Freeze every student backbone parameter except the LoRA matrices.
# LinearWithLoRA already froze the base Linear weights and biases; this loop also
# freezes the remaining parameters (LayerNorms, position embeddings, patch embedding, etc.).
for name, param in student_backbone.named_parameters():
    # LoRA parameters are named '...lora_down.weight' or '...lora_up.weight' inside LinearWithLoRA
    param.requires_grad = ('lora_down' in name) or ('lora_up' in name)


Notes:
 - We replaced every linear layer in the ViT with our LinearWithLoRA. The original weights are kept but frozen; new lora_down and lora_up parameters are added and are the only trainable parts of those layers.

 - We also froze other backbone parameters such as layer norm weights and positional embeddings. This is optional — one might fine-tune normalization layers — but to stay parameter-efficient, we freeze everything except LoRA and the DINO head.
 
 - The teacher is a copy of the student model at initialization, so it starts with identical weights (including the LoRA adapters, whose contribution is exactly zero at initialization because lora_up starts at zero). We will not train the teacher by gradient; we'll update it using momentum.
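
For a rough sense of scale, the sketch below estimates how many parameters LoRA adds to the backbone. It assumes the standard ViT-S layout (12 blocks, hidden size 384, MLP size 1536) and six wrapped nn.Linear layers per block; these figures are assumptions, and the count printed by the next cell is the authoritative one.

```python
def lora_params(in_features: int, out_features: int, r: int) -> int:
    """Parameters added by one LoRA pair: down (r x in) plus up (out x r)."""
    return r * (in_features + out_features)

# Assumed ViT-S dimensions (not read from the loaded model)
hidden, mlp_dim, n_blocks, r = 384, 1536, 12, 4

per_block = (
    4 * lora_params(hidden, hidden, r)   # query, key, value, attention output
    + lora_params(hidden, mlp_dim, r)    # MLP up-projection
    + lora_params(mlp_dim, hidden, r)    # MLP down-projection
)
lora_total = n_blocks * per_block

# Weights in the same linear layers if they were fully fine-tuned (biases ignored)
full_linear_total = n_blocks * (4 * hidden * hidden + 2 * hidden * mlp_dim)

print(f"LoRA parameters:        {lora_total:,}")
print(f"Wrapped linear weights: {full_linear_total:,}")
print(f"Ratio:                  {lora_total / full_linear_total:.2%}")
```

Under these assumptions the adapters amount to under 2% of the weights they modify; most of the trainable budget in this notebook sits in the projection head.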

# Verify number of trainable parameters

In [None]:
# Count trainable parameters
total_params = 0
trainable_params = 0
for param in list(student_backbone.parameters()) + list(student_head.parameters()):
    numel = param.numel()
    total_params += numel
    if param.requires_grad:
        trainable_params += numel
print(f"Total parameters (student backbone+head): {total_params:,}")
print(f"Trainable parameters (with LoRA): {trainable_params:,}")


# Helper functions

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
import numpy as np

def extract_feats():
    """Forward the student through the *labelled* validation set
       and collect its head outputs (1024-D) and integer labels."""
    student_backbone.eval()
    student_head.eval()
    feats, labs = [], []

    with torch.no_grad():
        for imgs, lbl in val_loader:                 # val_loader returns (B,3,H,W) + (B,)
            imgs = imgs.to(device)
            # CLS token from backbone
            cls = student_backbone(imgs,
                                   output_hidden_states=False).last_hidden_state[:, 0]
            vec = student_head(cls)                  # (B,1024): the head's final layer has out_dim=1024
            feats.append(vec.cpu())
            labs.append(lbl)

    feats = torch.cat(feats).numpy()                 # (N,1024)  ← N = len(val_ds)
    labs  = torch.cat(labs).numpy()                  # (N,)
    return feats, labs

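def knn_acc(feats, labs, k=5, n_fit=500):
    """Return top-1 accuracy of a k-NN probe: fit on n_fit random points, score on the rest."""
    perm = np.random.permutation(len(feats))
    fit_idx, test_idx = perm[:n_fit], perm[n_fit:]   # disjoint, so fitted points are not scored
    knn = KNeighborsClassifier(n_neighbors=k)
    knn.fit(feats[fit_idx], labs[fit_idx])
    return knn.score(feats[test_idx], labs[test_idx])   # scalar in [0,1]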

def fewshot(feats, labs, shots=20):
    """20-shot linear probe accuracy (logistic regression, max_iter=1000)."""
    classes = np.unique(labs)
    train_idx = []
    for c in classes:
        cls_idx = np.where(labs == c)[0]
        train_idx += list(np.random.choice(cls_idx, shots, replace=False))
    mask = np.ones(len(labs), dtype=bool)
    mask[train_idx] = False

    clf = LogisticRegression(max_iter=1000)
    clf.fit(feats[train_idx], labs[train_idx])
    return clf.score(feats[mask], labs[mask])


def evaluate(epoch):
    feats, labs = extract_feats()                           # (N,1024) & (N,)
    knn   = knn_acc(feats, labs)
    few   = fewshot(feats, labs)
    print(f"\n📊  Epoch {epoch+1}  |  k-NN: {knn*100:5.2f}%  ·  20-shot: {few*100:5.2f}%")
    return knn, few

## 🔍 What Happens in the Training Loop — Line by Line (Deep Dive)


The loop starts by putting the student network (student_backbone + student_head) into training mode while forcing the teacher into evaluation mode. This lets the student update its Batch-Norm statistics and apply dropout if any, while the teacher remains a stable, fixed reference that never changes its running means or variances.

During each pass over the train_loader we receive a multi-crop batch. For every original STL-10 image the dataloader has already produced two large “global” crops and several smaller “local” crops. We move those tensors onto the device with non_blocking=True. On CUDA with pinned memory this lets the host-to-device copy overlap with compute; on MPS any gain is likely to be small.

 - **Teacher forward** – under torch.no_grad() we feed the first global crop (g1) through the frozen teacher backbone, take only the [CLS] token (index 0) and push it through the teacher’s projection head. The result is a 1024-dimensional “prototype-logits” vector. Because gradients are disabled, no activations are kept for the backward pass, so this step adds little memory.

 - **Student forward**

	- The second global crop (g2) goes through the student backbone and head in the usual way, yielding another set of logits.

	- All local crops are packed into one large tensor of shape (B × N_local, C, H, W) and processed in a single vectorised call, which is far faster than looping. Their outputs are concatenated with the global student output.

 - **Target duplication** – the teacher produced one vector per image, so for local crops we simply repeat the teacher’s global vector N_local times. Now the student and teacher tensors have matching size and order.

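 - **DINO loss** – we compute the cross-entropy between the teacher’s sharpened distribution (temperature = 0.04) and the student’s smoother distribution (temperature = 0.1). This equals their KL divergence up to the teacher’s entropy, a term that carries no student gradient. Sharpening the teacher discourages collapse to a uniform output. The original DINO recipe also centers the teacher outputs so that no single dimension dominates; this simplified loop omits centering.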

 - **Optimisation** – only LoRA matrices and the projection-head weights receive gradients, so optimizer.zero_grad(), loss.backward(), and clip_grad_norm_(…, 3.0) touch a tiny subset of parameters (< 2 million). Gradient clipping guards against rare spikes that can produce NaNs on the M2 GPU. A quick optimizer.step() updates those parameters.

 - **EMA teacher update** – immediately after the weight update we nudge every teacher parameter toward its student counterpart:
 θ_teacher ← m·θ_teacher + (1–m)·θ_student, where the momentum m slowly rises from 0.996 to 0.9995 over the course of training. That schedule makes the teacher track the student fairly closely early on (good for fast learning) and then become more stable later (good for convergence).

 - **Book-keeping** – we accumulate a running mean of the loss and log the allocated device memory after each step; the tqdm progress bar shows the loss live along with the current learning rate.
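
The loss step described above can be sketched in isolation. The snippet uses random stand-ins for the teacher and student head outputs and checks that the cross-entropy form equals KL divergence plus the teacher's entropy; the tensor names are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
t_logits = torch.randn(8, 1024)   # stand-in for teacher head outputs
s_logits = torch.randn(8, 1024)   # stand-in for student head outputs
temp_teacher, temp_student = 0.04, 0.1

# Low teacher temperature -> sharply peaked targets
p_t = F.softmax(t_logits / temp_teacher, dim=-1)
log_p_t = F.log_softmax(t_logits / temp_teacher, dim=-1)
log_p_s = F.log_softmax(s_logits / temp_student, dim=-1)

# Cross-entropy H(p_t, p_s), averaged over the batch
cross_entropy = -(p_t * log_p_s).sum(-1).mean()

# Decomposition: H(p_t, p_s) = KL(p_t || p_s) + H(p_t)
kl = (p_t * (log_p_t - log_p_s)).sum(-1).mean()
entropy_t = -(p_t * log_p_t).sum(-1).mean()

print(cross_entropy.item(), (kl + entropy_t).item())
```

Since the teacher entropy does not depend on the student, minimizing the cross-entropy and minimizing the KL divergence give the student the same gradients.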

At the end of each epoch we set the learning rate for the next one. Over the first five epochs the LR follows a linear warm-up toward 6 × 10⁻⁴; afterwards a cosine annealing schedule gradually decays it toward 1 × 10⁻⁶. Warm-up followed by cosine decay is a common default that usually needs little tuning.
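
As a reference point, here is a closed-form sketch of a warm-up plus cosine schedule with the same constants. It is a simplified stand-in, not the notebook's mechanism (which mutates the optimizer's param groups and calls CosineAnnealingLR), so per-epoch values may differ slightly; the function name `lr_at_epoch` is hypothetical.

```python
import math

def lr_at_epoch(epoch: int, base_lr: float = 6e-4, min_lr: float = 1e-6,
                warmup_epochs: int = 5, max_epochs: int = 20) -> float:
    """Learning rate used during `epoch` (0-indexed)."""
    if epoch < warmup_epochs:
        # Linear warm-up: reaches base_lr in the last warm-up epoch
        return base_lr * (epoch + 1) / warmup_epochs
    # Cosine decay from base_lr toward min_lr over the remaining epochs
    progress = (epoch - warmup_epochs) / (max_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))

schedule = [lr_at_epoch(e) for e in range(20)]
for e, lr in enumerate(schedule):
    print(f"epoch {e:2d}: lr = {lr:.2e}")
```

Writing the schedule as a pure function of the epoch makes it easy to plot or unit-test before committing to a long training run.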

In [None]:
# Training loop ─ self-distillation with multi-crop and EMA teacher
import math
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm

# ── Optimiser: only LoRA + head are trainable 
trainable_params = list(filter(lambda p: p.requires_grad,
                               list(student_backbone.parameters()) +
                               list(student_head.parameters())))
optimizer = AdamW(trainable_params, lr=6e-4, weight_decay=0.04)

# ── Cosine schedule with linear warm-up (5 epochs) 
max_epochs = 20
warmup_epochs = 5
eval_every = 2           # run metrics every 2 epochs
memory_log = []
history = {"loss":[], "knn":[], "few":[]}

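scheduler = CosineAnnealingLR(optimizer,
                              T_max=max_epochs - warmup_epochs,
                              eta_min=1e-6)

# Start the first epoch at the lowest warm-up value rather than the full 6e-4;
# train_one_epoch sets the LR for each following epoch when it finishes.
for pg in optimizer.param_groups:
    pg["lr"] = 6e-4 / warmup_epochs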

# Helper to get LR for tqdm bar
def current_lr():
    for pg in optimizer.param_groups:
        return pg["lr"]

# Momentum schedule for EMA teacher
def teacher_momentum(epoch, base_m=0.996, final_m=0.9995):
    """Increase momentum from base_m to final_m along a cosine curve across training."""
    return final_m - (final_m - base_m) * (1 + math.cos(math.pi * epoch / max_epochs)) / 2

# ── Training function 
def train_one_epoch(epoch):
    student_backbone.train()
    student_head.train()
    teacher_backbone.eval()      # teacher always in eval
    teacher_head.eval()
    total_loss, step = 0.0, 0
    loop = tqdm(train_loader, leave=False, desc=f"Epoch {epoch+1}/{max_epochs}")
    for batch in loop:
        # ----- Unpack multi-crop batch -------------------------------------------------
        g1 = batch['global1'].to(device, non_blocking=True)
        g2 = batch['global2'].to(device, non_blocking=True)
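        # After default collation, batch['locals'] is a list of N tensors, each (B, C, H, W)
        locals_list = batch['locals']
        if locals_list:                          # guard if 0 locals
            # Stack to (B, N, C, H, W), then flatten to (B*N, C, H, W) so each image's
            # crops are contiguous and line up with repeat_interleave on the teacher targets
            locals_imgs = torch.stack(locals_list, dim=1).flatten(0, 1).to(device, non_blocking=True)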

        # ----- Forward: teacher on g1, student on all crops ---------------------------
        with torch.no_grad():
            t_feat = teacher_backbone(g1, output_hidden_states=False).last_hidden_state[:, 0]
            t_out  = teacher_head(t_feat)                     # (B, 1024)

        # Student on global2
        s_feat_g2 = student_backbone(g2, output_hidden_states=False).last_hidden_state[:, 0]
        s_out_g2  = student_head(s_feat_g2)                   # (B, 1024)

        # Student on locals (if any)
        if locals_list:
            n_loc = locals_imgs.size(0)
            s_feat_loc = student_backbone(locals_imgs,
                                          output_hidden_states=False).last_hidden_state[:, 0]
            s_out_loc  = student_head(s_feat_loc)             # (B*N_loc, 1024)
            # Repeat teacher targets for local crops
            t_out_loc  = t_out.repeat_interleave(local_crops_number, dim=0)
            s_out = torch.cat([s_out_g2, s_out_loc], dim=0)
            t_out = torch.cat([t_out,    t_out_loc], dim=0)
        else:
            s_out = s_out_g2

        # ----- DINO loss --------------------------------------------------------------
        temp_student = 0.1
        temp_teacher = 0.04
        loss = -(F.softmax(t_out / temp_teacher, dim=-1).detach() *
                 F.log_softmax(s_out / temp_student, dim=-1)).sum(-1).mean()

        # ----- Optim step -------------------------------------------------------------
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable_params, 3.0)
        optimizer.step()

        # ----- EMA update for teacher -------------------------------------------------
        m = teacher_momentum(epoch + step / len(train_loader))
        with torch.no_grad():
            for sp, tp in zip(student_backbone.parameters(), teacher_backbone.parameters()):
                tp.data = tp.data * m + sp.data * (1.0 - m)
            for sp, tp in zip(student_head.parameters(), teacher_head.parameters()):
                tp.data = tp.data * m + sp.data * (1.0 - m)

        # ----- Book-keeping -----------------------------------------------------------
        total_loss += loss.item()
        step += 1
        loop.set_postfix(loss=f"{total_loss/step:.4f}", lr=current_lr())

        if device.type == "mps":
            mem_mb = torch.mps.current_allocated_memory() / 1e6
        else:
            mem_mb = torch.cuda.memory_allocated() / 1e6 if torch.cuda.is_available() else 0
        memory_log.append(mem_mb)     


    # Scheduler step AFTER each epoch
    if epoch >= warmup_epochs:
        scheduler.step()
    else:
        # linear warm-up
        warm_lr = 6e-4 * (epoch + 1) / warmup_epochs
        for pg in optimizer.param_groups:
            pg["lr"] = warm_lr

    return total_loss / step


for epoch in range(max_epochs):
    loss = train_one_epoch(epoch)
    history["loss"].append(loss)
    if (epoch % eval_every) == 0 or epoch == max_epochs-1:
        knn, few = evaluate(epoch)
        history["knn"].append(knn); history["few"].append(few)
    print(f"Epoch {epoch+1}: DINO loss {loss:.4f}")
    # memory_log grows across epochs, so average only this epoch's steps
    epoch_mem = np.mean(memory_log[-len(train_loader):])
    print(f" avg allocated device memory this epoch: {epoch_mem:6.1f} MB")


In [None]:
# Count params
total_params = sum(p.numel() for p in list(student_backbone.parameters()) + list(student_head.parameters()))
train_params = sum(p.numel() for p in list(student_backbone.parameters()) + list(student_head.parameters()) if p.requires_grad)

# Quick bar chart
import matplotlib.pyplot as plt
plt.figure()
plt.bar(["Total", "Trainable"], [total_params/1e6, train_params/1e6])
plt.ylabel("Millions of parameters")
plt.title("Parameter budget: full model vs. LoRA-tuned subset")
plt.show()
