# Tutorial 6.3: Self-distillation with no labels (DINO)

Author: [Erik Syniawa](mailto:erik.syniawa@informatik.tu-chemnitz.de)

The self-supervised approach Self-**di**stillation with **no** labels (DINO) was introduced by Caron et al. (2021) [[1](#caron2021self)].

The Vision Transformer (ViT; [[3](#dosovit)]) performs competitively with CNNs, but did not show dramatic benefits (see our notebook).
A key factor in the success of Transformers in NLP is self-supervised pre-training. Two prominent examples:

- [BERT (Devlin et al., 2019)](https://aclanthology.org/N19-1423/?utm_campaign=The%20Batch&utm_source=hs_email&utm_medium=email):
    - **Learning**: Learns by predicting masked words using context from both directions
    - **Prediction**: Predicts missing words in text, enabling understanding of language relationships
- [GPT (Radford et al., 2019)](https://storage.prod.researchhub.com/uploads/papers/2020/06/01/language-models.pdf):
    - **Learning**: Learns by predicting the next word given all previous words in a sequence
    - **Prediction**: Predicts and generates coherent text continuations by anticipating what comes next

Our ViT, however, was trained in a supervised manner on a large amount of labeled data. Would self-supervised training help ViTs as well?

<div align="center">
    <img src="figures/example.gif" width="700"/>
    <p><i>Figure 1: Self-attention extracted from DINO (animation from [2])</i></p>
</div>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader, DistributedSampler
import torchvision.transforms.v2 as v2
from torchvision.io import read_image

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import sys
import copy
import time
import math
from tqdm.notebook import tqdm
from PIL import Image, ImageFilter, ImageOps

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'PyTorch version: {torch.__version__} running on {device}')

# Add utils path
notebook_dir = os.getcwd()
root_path = os.path.abspath(os.path.join(notebook_dir, ".."))
if root_path not in sys.path:
    sys.path.append(root_path)
    print(f"Added {root_path} to sys.path")

from Utils.dataloaders import prepare_UTKFace_age_task
from Utils.little_helpers import timer, set_seed, get_parameters
from Utils.functions import train_model, evaluate_model, test_model
from Utils.optimizers import LARS

set_seed(42)

## 1. Self-Distillation Architecture

DINO employs a novel approach to self-supervised learning, building on the concept of knowledge distillation but without any labels. The key insight is to use a teacher-student architecture similar to [[4](#grill2020bootstrap)] where both networks share the same architecture but have different sets of parameters (see Tutorial 8.2 (BYOL)).

### 1.1 Teacher-Student Framework

<div align="center">
    <img src="figures/dino.gif" width="700"/>
    <p><i>Figure 2: DINO main architecture (animation from [2])</i></p>
</div>

The DINO framework consists of:

1. **Student Network**: Updated through direct gradient backpropagation
2. **Teacher Network**: Updated through exponential moving average (EMA) of the student's weights
3. **Different Views**: Multiple augmented views of the same image are created, with:
   - Global views (covering most of the image)
   - Local views (smaller crops of the image)

The teacher-student framework in DINO addresses a fundamental challenge in self-supervised learning: **how to avoid collapse** (all images mapping to the same representation).
The asymmetric setup helps with that: the student sees all views (global and local), whereas the teacher sees only the global views. This:

- forces student to predict global context from local patches
- creates a natural pretext task: "local-to-global correspondence"

The teacher network is updated as a **momentum teacher**: its parameters are an EMA of the student parameters. This leads to:

- **Stability**: Teacher provides consistent, slowly-changing targets
- **Quality**: Teacher represents a "better" version of the student (ensemble effect)
- **Preventing Collapse**: Student can't immediately change teacher's outputs

#### Mathematical Intuition Behind EMA Update

The exponential moving average update mechanism forms the mathematical foundation of DINO's stability. The teacher parameters are updated according to the following equation:

$$\theta_t \leftarrow \lambda \theta_t + (1 - \lambda) \theta_s$$

where $\theta_t$ represents the teacher parameters, $\theta_s$ represents the student parameters, and $\lambda$ is the momentum coefficient.

**Mathematical Properties and Implications:** The momentum coefficient $\lambda \approx 0.996$ ensures that teacher parameters change slowly, providing stable targets for the student network. This high momentum value means that each update incorporates only 0.4% of the current student parameters, creating a smoothed version of the student's learning trajectory.

**Ensemble Effect Through Temporal Averaging:** The teacher network represents accumulated knowledge over many training steps. Using the subscript $t$ for the training step and assuming the teacher is initialized with the student's weights $\theta_0^{student}$, unrolling the update expresses the teacher parameters as a weighted sum of all previous student parameters:

$$\theta_t^{teacher} = \lambda^t \theta_0^{student} + (1-\lambda) \sum_{i=0}^{t-1} \lambda^i \theta_{t-i}^{student}$$

This formulation demonstrates that the teacher maintains a memory of the student's evolution, with exponentially decaying weights for older parameters.
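
The unrolled form can be checked numerically. The following is a minimal NumPy sketch, not part of the DINO code base: the student trajectory is random toy data, and the teacher is assumed to start from the student's initial weights, so the exact expansion also carries the initial weights with factor $\lambda^t$.

```python
import numpy as np

rng = np.random.default_rng(0)
lam = 0.996          # momentum coefficient quoted in the text
steps = 50           # number of simulated updates (arbitrary toy value)

# Hypothetical student trajectory: one small parameter vector per step.
student = rng.normal(size=(steps + 1, 4))

# Recursive EMA, teacher initialised with the student's initial weights.
teacher = student[0].copy()
for t in range(1, steps + 1):
    teacher = lam * teacher + (1 - lam) * student[t]

# Unrolled form: exponentially decaying weights on past students
# plus the contribution of the initial weights.
weights = (1 - lam) * lam ** np.arange(steps)          # i = 0 .. steps-1
past_students = student[steps - np.arange(steps)]      # theta_{t-i}
unrolled = lam ** steps * student[0] + (weights[:, None] * past_students).sum(axis=0)

print(np.allclose(teacher, unrolled))   # both forms agree
print(weights.sum() + lam ** steps)     # all weights sum to 1
```

Because the weights sum to one, the teacher is a proper weighted average of the student's history, which is what makes the ensemble interpretation plausible.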

**Quality Improvement Mechanism:** In practice, the EMA teacher often outperforms the student throughout training due to the variance reduction effect of parameter averaging. The mathematical intuition follows from the bias-variance decomposition: while individual student parameters may have high variance due to stochastic gradient updates, the EMA teacher reduces this variance while maintaining low bias.

The center vector follows a similar EMA update rule with momentum parameter $m$:

$$c \leftarrow m \cdot c + (1-m) \cdot \frac{1}{B} \sum_{i=1}^{B} g_{\theta_t}(x_i)$$

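where $B$ is the batch size and $g_{\theta_t}(x_i)$ is the teacher output for sample $x_i$. Subtracting $c$ keeps the batch mean of each teacher output dimension close to zero, so that no single dimension can dominate. On its own, centering pushes the teacher towards a uniform output, which is why it is combined with sharpening (a low teacher temperature), which has the opposite effect.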


```python
# Pseudo-code for DINO training process
# gs, gt: student and teacher networks
# C: center (dimensionality K)
# tps, tpt: student and teacher temperatures
# l, m: network and center momentum rates

gt.params = gs.params  # Initialize teacher with student weights

for x in loader:  # Load a minibatch with n samples
    x1, x2 = augment(x), augment(x)  # Create random views
    local_views = [augment(x) for _ in range(local_crops_number)]
    
    s1, s2 = gs(x1), gs(x2)  # Student output for global views
    s_local = [gs(v) for v in local_views]  # Student output for local views
    
    t1, t2 = gt(x1), gt(x2)  # Teacher output for global views (only)
    
    # Calculate loss: cross-entropy between student and teacher
    loss = H(t1, s2)/2 + H(t2, s1)/2
    for s_l in s_local:
        loss += (H(t1, s_l) + H(t2, s_l)) / (2 * len(local_views))
    
    loss.backward()  # Back-propagate
    
    # Update student with SGD
    update(gs)
    
    # Update teacher with momentum
    gt.params = l * gt.params + (1-l) * gs.params
    
    # Update center with momentum
    C = m * C + (1-m) * cat([t1, t2]).mean(dim=0)
```


### 1.2 Loss Function

DINO uses a cross-entropy loss between the student and teacher outputs:

```python
def H(t, s):
    t = t.detach()  # Stop gradient on teacher
    s = softmax(s / tps, dim=1)  # Student output with temperature
    t = softmax((t - C) / tpt, dim=1)  # Teacher output with centering and sharpening
    return - (t * log(s)).sum(dim=1).mean()
```

The teacher output is processed with:
1. **Centering**: Subtract a moving average of the teacher output
2. **Sharpening**: Apply a lower temperature to make the distribution more peaked

These two simple techniques effectively prevent the model from collapsing to trivial solutions.
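
To make the two effects tangible, here is a small NumPy sketch on toy logits (random values, not real teacher outputs). The batch mean stands in for the EMA center `C`, and the temperature 0.04 is taken from the configuration used later in this notebook.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)    # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def entropy(p):
    return -(p * np.log(p + 1e-12)).sum(axis=-1)

rng = np.random.default_rng(0)
K = 8                                          # toy output dimensionality
logits = rng.normal(size=(16, K))
logits[:, 0] += 10.0                           # dimension 0 dominates every sample

center = logits.mean(axis=0, keepdims=True)    # stand-in for the EMA center C

raw = softmax(logits)                          # no centering, temperature 1
centered = softmax(logits - center)            # centering only
sharpened = softmax((logits - center) / 0.04)  # centering + low temperature

print("fraction with argmax 0:", (raw.argmax(1) == 0).mean(), (centered.argmax(1) == 0).mean())
print("mean entropy:", entropy(centered).mean(), entropy(sharpened).mean())
```

Centering removes the dimension that wins for every sample, while the low temperature keeps each individual distribution peaked rather than uniform.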

In [None]:
class DINOLoss(nn.Module):
    def __init__(self, 
                 out_dim: int, 
                 ncrops: int, 
                 warmup_teacher_temp: float, 
                 teacher_temp: float,
                 warmup_teacher_temp_epochs: int,
                 nepochs: int, 
                 student_temp: float = 0.1,
                 center_momentum: float = 0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.ncrops = ncrops
        self.register_buffer("center", torch.zeros(1, out_dim))

        # Teacher temperature schedule
        self.teacher_temp_schedule = np.concatenate((
            np.linspace(warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs),
            np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
        ))

    def forward(self, student_output, teacher_output, epoch):
        student_out = student_output / self.student_temp
        student_out = student_out.chunk(self.ncrops)

        # Teacher centering and sharpening
        temp = self.teacher_temp_schedule[epoch]
        teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
        teacher_out = teacher_out.detach().chunk(2)

        total_loss = 0
        n_loss_terms = 0
        for iq, q in enumerate(teacher_out):
            for v in range(len(student_out)):
                if v == iq:
                    continue
                loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
                total_loss += loss.mean()
                n_loss_terms += 1

        total_loss /= n_loss_terms
        self.update_center(teacher_output)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
        batch_center = batch_center / len(teacher_output)
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)


The complete DINO loss function incorporates multiple views of the same image. Given a set of views $\mathcal{V} = \{x_1^g, x_2^g, x_1^l, ..., x_N^l\}$ where $g$ denotes global views and $l$ denotes local views, the loss is computed as:

$$\mathcal{L}_{DINO} = \sum_{x \in \{x_1^g, x_2^g\}} \sum_{\substack{x' \in \mathcal{V} \\ x' \neq x}} H(P_t(x), P_s(x'))$$ 

This asymmetric formulation ensures that:

1. **Global-to-Global**: Student global views learn from teacher global views
2. **Local-to-Global**: Student local views learn from teacher global views
3. **Teacher Consistency**: Teacher only processes global views, providing stable targets

The mathematical intuition is that local views must predict the global context, forcing the model to understand part-to-whole relationships without supervision.
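
As a quick sanity check of the double sum, the sketch below enumerates all (teacher view, student view) pairs for the setting used in this tutorial (2 global and 8 local views). The view names are purely illustrative labels.

```python
n_global, n_local = 2, 8   # view counts used in this tutorial
views = [f"g{i}" for i in range(n_global)] + [f"l{i}" for i in range(n_local)]

# The teacher sees only global views; the student sees every view.
# A pair is skipped when both networks would see the same view.
pairs = [(t, s) for t in views[:n_global] for s in views if s != t]

print(len(pairs))   # 18
print(pairs[:3])    # [('g0', 'g1'), ('g0', 'l0'), ('g0', 'l1')]
```

The 18 terms correspond to `n_loss_terms` in `DINOLoss.forward`, which averages over exactly these pairs.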


## 2. Image Augmentations

### 2.1 Multi-Crop Strategy

The multi-crop strategy creates two types of views:
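- **Global Views**: Higher resolution crops (224×224) covering >50% of the image
- **Local Views**: Lower resolution crops (96×96) covering <50% of the image

Note that the `DataAugmentationDINO` class in this notebook uses the scale ranges `(0.4, 1.0)` for global and `(0.05, 0.4)` for local crops, and resizes all crops to the same `input_size`.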

Let's examine key augmentations used in DINO (mostly they are the same as in BYOL [[4](#grill2020bootstrap)]):

### 2.2 Specialized Augmentations

Besides simple augmentations such as cropping, flipping, or grayscale conversion, DINO uses more complex augmentations such as the following:

#### 2.2.1 Gaussian Blur

In [None]:
class GaussianBlur:
    def __init__(self, p=0.5, radius_min=0.1, radius_max=2.0):
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        if isinstance(img, torch.Tensor):
            img = v2.functional.to_pil_image(img)

        if torch.rand(1).item() <= self.prob:
            radius = torch.empty(1).uniform_(self.radius_min, self.radius_max).item()
            img = img.filter(ImageFilter.GaussianBlur(radius=radius))
        return img


#### 2.2.2 Solarization

In [None]:
class Solarization:
    def __init__(self, p=0.2):
        self.prob = p

    def __call__(self, img):
        if isinstance(img, torch.Tensor):
            img = v2.functional.to_pil_image(img)

        if torch.rand(1).item() <= self.prob:
            img = ImageOps.solarize(img)
        return img
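
Solarization inverts only the bright pixels of an image. The NumPy sketch below illustrates the idea on a single row of 8-bit values; it assumes the default threshold of 128 and is meant as an illustration of the operation, not as a replacement for `ImageOps.solarize`.

```python
import numpy as np

def solarize(pixels: np.ndarray, threshold: int = 128) -> np.ndarray:
    """Invert every 8-bit value at or above `threshold`; keep darker values."""
    return np.where(pixels >= threshold, 255 - pixels, pixels).astype(np.uint8)

row = np.array([0, 64, 127, 128, 200, 255], dtype=np.uint8)
print(solarize(row))   # [  0  64 127 127  55   0]
```

Dark regions stay untouched while bright regions are flipped, which produces the characteristic partially inverted look.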


### 2.3 Data Augmentation Pipeline

Now we can stitch the different augmentations together.

In [None]:

class DataAugmentationDINO:
    def __init__(self, global_crops_scale, local_crops_scale, local_crops_number, input_size=128):
        self.input_size = input_size
        self.local_crops_number = local_crops_number

        color_jitter = v2.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)

        # Global crop 1 (always blur)
        self.global_transform1 = v2.Compose([
            v2.RandomResizedCrop(self.input_size, scale=global_crops_scale),
            v2.RandomHorizontalFlip(p=0.5),
            v2.RandomApply([color_jitter], p=0.8),
            v2.RandomGrayscale(p=0.2),
            GaussianBlur(p=1.0),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])

        # Global crop 2 (sometimes blur + solarization)
        self.global_transform2 = v2.Compose([
            v2.RandomResizedCrop(self.input_size, scale=global_crops_scale),
            v2.RandomHorizontalFlip(p=0.5),
            v2.RandomApply([color_jitter], p=0.8),
            v2.RandomGrayscale(p=0.2),
            GaussianBlur(p=0.1),
            Solarization(p=0.2),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])

        # Local crops
        self.local_transform = v2.Compose([
            v2.RandomResizedCrop(self.input_size, scale=local_crops_scale),
            v2.RandomHorizontalFlip(p=0.5),
            v2.RandomApply([color_jitter], p=0.8),
            v2.RandomGrayscale(p=0.2),
            GaussianBlur(p=0.5),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])

    def __call__(self, img):
        crops = []
        crops.append(self.global_transform1(img))
        crops.append(self.global_transform2(img))

        for _ in range(self.local_crops_number):
            crops.append(self.local_transform(img))

        return crops

### 2.4 Applying the Augmentation Pipeline

Let's have a look at the augmentation pipeline for an example image.

First, see what the original image looks like.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import requests

# Load a sample image 
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)

# Display the original image
plt.figure(figsize=(6, 6))
plt.imshow(image)
plt.title("Original Image")
plt.axis('off')
plt.show()

Now, have a look at the individual augmentations.

In [None]:
# First resize the image to maintain consistent size
image_resized = v2.Resize(224)(image)

# Create a figure with subplots to show all augmentations together
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
fig.suptitle("Individual DINO Augmentations", fontsize=16)

# Original Image
axes[0, 0].imshow(image_resized)
axes[0, 0].set_title("Original")

# Random Resized Crop (Global)
crop_transform = v2.RandomResizedCrop(224, scale=(0.4, 1.0))
crop_img = crop_transform(image_resized)
axes[0, 1].imshow(crop_img)
axes[0, 1].set_title("Random Resized Crop (Global)")

# Color Jitter
color_jitter = v2.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)
jitter_img = color_jitter(image_resized)
axes[0, 2].imshow(jitter_img)
axes[0, 2].set_title("Color Jitter")

# Grayscale
grayscale = v2.RandomGrayscale(p=1.0)  # p=1.0 to force grayscale
gray_img = grayscale(image_resized)
axes[1, 0].imshow(gray_img)
axes[1, 0].set_title("Grayscale")

# Gaussian Blur
blur = GaussianBlur(p=1.0, radius_min=1.0, radius_max=1.0)  # p=1.0 to force blur
blur_img = blur(image_resized)
axes[1, 1].imshow(blur_img)
axes[1, 1].set_title("Gaussian Blur")

# Solarization
solarize = Solarization(p=1.0)  # p=1.0 to force solarization
solar_img = solarize(image_resized)
axes[1, 2].imshow(solar_img)
axes[1, 2].set_title("Solarization")

# Remove axis ticks
for ax in axes.flatten():
    ax.axis('off')

plt.tight_layout()
plt.subplots_adjust(top=0.9)
plt.show()

# Also demonstrate a local crop comparison
plt.figure(figsize=(10, 5))

plt.subplot(1, 2, 1)
plt.imshow(image_resized)
plt.title("Original (Resized)")
plt.axis('off')

plt.subplot(1, 2, 2)
local_crop_transform = v2.RandomResizedCrop(96, scale=(0.05, 0.4))
local_crop = local_crop_transform(image)
plt.imshow(local_crop)
plt.title("Random Resized Crop (Local View)")
plt.axis('off')

plt.tight_layout()
plt.show()

Finally, the full pipeline produces images like these:

In [None]:
# Apply DINO augmentations with explicit parameters
transform = DataAugmentationDINO(
    global_crops_scale=(0.4, 1.0),
    local_crops_scale=(0.05, 0.4),
    local_crops_number=8,
    input_size=224  # Use 224x224 crops for this visualization
)

# Generate different crops of the image
crops = transform(image)

# Prepare a resized version of the original for consistent display
original_resized = v2.Resize(224)(image)
original_tensor = v2.ToImage()(original_resized)
original_tensor = v2.ToDtype(torch.float32, scale=True)(original_tensor)

# Display the first few augmented views
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
fig.suptitle("DINO Multi-Crop Augmentations", fontsize=16)

# The crops are normalized with ImageNet statistics; undo this before plotting,
# otherwise matplotlib clips all values outside [0, 1]
mean = torch.tensor((0.485, 0.456, 0.406)).view(3, 1, 1)
std = torch.tensor((0.229, 0.224, 0.225)).view(3, 1, 1)

def to_displayable(crop):
    return (crop * std + mean).clamp(0, 1).permute(1, 2, 0).numpy()

# First row: Original + 2 global views + 2 local views
axes[0, 0].imshow(original_tensor.permute(1, 2, 0).numpy())
axes[0, 0].set_title("Original (Resized)")

titles = ["Global View 1", "Global View 2", "Local View 1", "Local View 2"]
for i, title in enumerate(titles):
    axes[0, i + 1].imshow(to_displayable(crops[i]))
    axes[0, i + 1].set_title(title)

# Second row: 5 more local views
for i in range(5):
    if i + 4 < len(crops):
        axes[1, i].imshow(to_displayable(crops[i + 4]))
        axes[1, i].set_title(f"Local View {i + 3}")

# Remove axis ticks
for ax in axes.flatten():
    ax.axis('off')

plt.tight_layout()
plt.subplots_adjust(top=0.9)
plt.show()



## 3. Backbone & DINO Head Architecture

DINO supports both traditional convolutional architectures (ResNet) and Vision Transformers (ViT). First we will implement a ViT via [`timm`](https://github.com/huggingface/pytorch-image-models) and in a later exercise you will implement a ResNet.

### 3.1 Vision Transformer (ViT)

The Vision Transformer divides an image into fixed-size patches, linearly embeds them, adds position embeddings, and processes the sequence with transformer blocks:

In [None]:
# create a ViT from timm
import timm
def create_vit_model(model_name='vit_small_patch16_224', img_size=128, num_classes=0):
    """Create ViT model from timm"""
    model = timm.create_model(
        model_name, 
        pretrained=False,
        img_size=img_size,
        num_classes=num_classes
    )
    return model
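
To get a feeling for the sequence length the transformer blocks operate on, the token count for the configuration used in this tutorial (128×128 inputs, 16×16 patches) can be worked out by hand. The sketch assumes one additional class token, as in the standard ViT.

```python
img_size, patch_size = 128, 16       # tutorial setting: 128x128 inputs, 16x16 patches

patches_per_side = img_size // patch_size
num_patches = patches_per_side ** 2
seq_len = num_patches + 1            # plus one class token

print(patches_per_side, num_patches, seq_len)   # 8 64 65
```

For comparison, a 224×224 input with the same patch size yields 14 × 14 = 196 patches, so the smaller input size reduces the sequence length considerably.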




### 3.2 DINO Head

The DINO Head is a projection head that maps the backbone features to the space where the self-distillation is performed:



In [None]:
class DINOHead(nn.Module):
    def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=False):
        super().__init__()
        hidden_dim = in_dim
        
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.BatchNorm1d(hidden_dim) if use_bn else nn.Identity(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.BatchNorm1d(hidden_dim) if use_bn else nn.Identity(),
        )

        self.last_layer = nn.Linear(hidden_dim, out_dim, bias=False)
        self.norm_last_layer = norm_last_layer

    def forward(self, x):
        x = self.mlp(x)
        
        if self.norm_last_layer:
            w = F.normalize(self.last_layer.weight, dim=1, p=2)
            x = F.linear(x, w)
        else:
            x = self.last_layer(x)
            
        return x



### 3.3 MultiCropWrapper

A special wrapper class helps handle multiple crops of different resolutions:


In [None]:

class MultiCropWrapper(nn.Module):
    def __init__(self, backbone, head):
        super().__init__()
        backbone.head = nn.Identity()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        if not isinstance(x, list):
            x = [x]

        idx_crops = torch.cumsum(torch.unique_consecutive(
            torch.tensor([inp.shape[-1] for inp in x]),
            return_counts=True,
        )[1], 0)

        start_idx, output = 0, torch.empty(0).to(x[0].device)
        for end_idx in idx_crops:
            _out = self.backbone(torch.cat(x[start_idx: end_idx]))
            
            if isinstance(_out, tuple):
                _out = _out[0]
            
            output = torch.cat((output, _out))
            start_idx = end_idx

        return self.head(output)
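
The `idx_crops` computation groups consecutive crops of equal width so that each group needs only one backbone forward pass. The same grouping logic can be sketched in plain Python; `crop_group_boundaries` is a hypothetical helper written for illustration only.

```python
from itertools import accumulate, groupby

def crop_group_boundaries(widths):
    """End index of each run of consecutive crops sharing the same width."""
    run_lengths = [len(list(run)) for _, run in groupby(widths)]
    return list(accumulate(run_lengths))

# Two resolutions: 2 global crops at 224 px followed by 8 local crops at 96 px.
print(crop_group_boundaries([224, 224] + [96] * 8))   # [2, 10]

# This tutorial resizes every crop to the same size, so there is a single group.
print(crop_group_boundaries([128] * 10))              # [10]
```

With two resolutions the backbone is called twice (crops 0-1 and crops 2-9); with a single resolution all ten crops are concatenated into one batch.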


## 4. Evaluation Methods

DINO produces rich visual representations that can be evaluated in various ways:

### 4.1 k-NN Classification

One of the simplest evaluation methods is k-nearest neighbors classification:

In [None]:

def evaluate_knn(model, data_loader, k=20):
    """
    Evaluate the model using k-NN classification.
    """
    model.eval()
    features = []
    targets = []
    
    # Extract features
    with torch.no_grad():
        for imgs, labels in data_loader:
            imgs = imgs.to(device)
            feats = model(imgs)
            
            # Normalize features
            feats = F.normalize(feats, dim=1, p=2)
            
            features.append(feats.cpu())
            targets.append(labels)
    
    # Concatenate all features and targets
    features = torch.cat(features, dim=0)
    targets = torch.cat(targets, dim=0)
    
    # Compute similarities
    similarity = torch.mm(features, features.t())
    
    # Get top-k neighbors
    _, indices = similarity.topk(k + 1, dim=1, largest=True)
    indices = indices[:, 1:]  # Exclude self
    
    # Get labels of neighbors by indexing the label vector with the neighbor indices
    neighbor_labels = targets[indices]
    
    # Predict by majority voting
    predictions = torch.mode(neighbor_labels, dim=1)[0]
    
    # Compute accuracy
    correct = (predictions == targets).sum().item()
    accuracy = 100 * correct / len(targets)
    
    return accuracy
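
The voting logic can be tried out without a trained model. Below is a minimal NumPy sketch of leave-one-out k-NN with cosine similarity on two toy clusters; the data is synthetic and `knn_predict` is a hypothetical helper, not part of the tutorial's `Utils` package.

```python
import numpy as np

def knn_predict(features, labels, k):
    """Leave-one-out k-NN with cosine similarity and majority voting."""
    feats = features / np.linalg.norm(features, axis=1, keepdims=True)
    sim = feats @ feats.T
    np.fill_diagonal(sim, -np.inf)                # a sample must not vote for itself
    neighbors = np.argsort(-sim, axis=1)[:, :k]   # indices of the k most similar samples
    neighbor_labels = labels[neighbors]           # shape (N, k)
    return np.array([np.bincount(row).argmax() for row in neighbor_labels])

# Two well-separated toy clusters (synthetic data, not AgeDB features).
rng = np.random.default_rng(0)
features = np.concatenate([
    rng.normal(loc=(5, 0), scale=0.3, size=(20, 2)),
    rng.normal(loc=(0, 5), scale=0.3, size=(20, 2)),
])
labels = np.array([0] * 20 + [1] * 20)

pred = knn_predict(features, labels, k=5)
print((pred == labels).mean())
```

Since no classifier is trained, the k-NN accuracy directly reflects how well the feature space separates the classes.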


### 4.2 Linear Probing

Another common evaluation method is training a linear classifier on top of frozen features:


In [None]:

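def linear_eval(model, train_loader, val_loader, epochs=100, num_classes=200):
    """
    Evaluate model by training a linear classifier on frozen features.

    `model` is expected to be a backbone exposing `embed_dim` (e.g. a timm ViT).
    `num_classes` defaults to 200 (TinyImageNet) and must match the dataset.
    """
    # Freeze backbone and keep it in eval mode
    model.requires_grad_(False)
    model.eval()
    
    # Create linear classifier
    classifier = nn.Linear(model.embed_dim, num_classes).to(device)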
    
    # Optimizer
    optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = nn.CrossEntropyLoss()
    
    best_acc = 0
    
    for epoch in range(epochs):
        # Training
        classifier.train()
        for images, targets in train_loader:
            images, targets = images.to(device), targets.to(device)
            
            # Extract features
            with torch.no_grad():
                features = model(images)
            
            # Forward pass through classifier
            outputs = classifier(features)
            loss = criterion(outputs, targets)
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        scheduler.step()
        
        # Validation
        classifier.eval()
        correct = 0
        total = 0
        
        with torch.no_grad():
            for images, targets in val_loader:
                images, targets = images.to(device), targets.to(device)
                
                # Extract features
                features = model(images)
                
                # Forward pass through classifier
                outputs = classifier(features)
                _, predicted = outputs.max(1)
                
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        
        accuracy = 100 * correct / total
        
        if accuracy > best_acc:
            best_acc = accuracy
    
    return best_acc

## 5. References

1. Caron, M., Touvron, H., Misra, I., Jégou, H., Mairal, J., Bojanowski, P., & Joulin, A. (2021). Emerging properties in self-supervised vision transformers. *International Conference on Computer Vision (ICCV)*.  <a id="caron2021self"></a>
2. <https://github.com/facebookresearch/dino>   <a id="caron2021github"></a>
3. Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., ... & Houlsby, N. (2020). An image is worth 16x16 words: Transformers for image recognition at scale. *International Conference on Learning Representations (ICLR)*. <a id="dosovit"></a>
4. Grill, J. B., Strub, F., Altché, F., Tallec, C., Richemond, P. H., Buchatskaya, E., ... & Valko, M. (2020). Bootstrap your own latent: A new approach to self-supervised learning. *Advances in Neural Information Processing Systems (NeurIPS)*. <a id="grill2020bootstrap"></a>

## DINO in action

In [None]:
# Configuration
config = {
    'model_name': 'vit_small_patch16_224',
    'img_size': 128,
    'batch_size': 32,
    'epochs': 100,
    'lr': 0.0005,
    'weight_decay': 0.04,
    'momentum_teacher': 0.996,
    'out_dim': 2048,
    'local_crops_number': 8,
    'warmup_teacher_temp': 0.04,
    'teacher_temp': 0.07,
    'warmup_teacher_temp_epochs': 30,
    'student_temp': 0.1,
    'center_momentum': 0.9,
}
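
With the configuration values above, the teacher temperature schedule built inside `DINOLoss` can be inspected in isolation. The sketch below repeats the same `np.concatenate` construction outside the class, with the values copied from `config`.

```python
import numpy as np

warmup_teacher_temp, teacher_temp = 0.04, 0.07   # values from the config above
warmup_epochs, epochs = 30, 100

schedule = np.concatenate((
    np.linspace(warmup_teacher_temp, teacher_temp, warmup_epochs),   # linear warm-up
    np.ones(epochs - warmup_epochs) * teacher_temp,                  # constant afterwards
))

print(len(schedule))
print(schedule[[0, 15, 29, 30, 99]].round(4))
```

The temperature rises linearly from 0.04 to 0.07 during the first 30 epochs and stays constant afterwards, so the teacher targets are sharpest early in training.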


### Training Process

The training process follows these steps:

1. Initialize student and teacher networks with the same architecture
2. Setup loss function, optimizer, and learning rate scheduler
3. For each epoch:
   - Generate multiple views of each image
   - Compute student and teacher outputs
   - Calculate DINO loss
   - Update student via backpropagation
   - Update teacher with EMA
   - Update center with EMA


First, prepare the dataset:

In [None]:
def make_annotation_file(path_data, split):
    list_data = []
    
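    for file in os.listdir(path_data):
        try:
            # Expected pattern: <id>_<name>_<age>_<gender>.<ext>
            _, name, age, gender = file.split('_')
        except ValueError:
            # Skip files that do not follow the naming pattern
            continue
        gender = gender.split('.')[0]
        list_data.append((os.path.join(path_data, file), name, age, gender))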

    n_samples = len(list_data)
    idx = np.linspace(0, n_samples-1, n_samples, dtype='int32')
    np.random.shuffle(idx)    
    list_data = np.asarray(list_data)[idx]

    n_train = int(n_samples * split[0])
    n_test = int(n_samples * split[1])

    train_data = list_data[:n_train]
    test_data = list_data[n_train:n_train+n_test]
    
    csv_train = pd.DataFrame(train_data, columns=['Images','Name','Age','Gender'])
    csv_test = pd.DataFrame(test_data, columns=['Images','Name','Age','Gender'])

    csv_train.to_csv('annot_train.csv', sep=',', index=False)
    csv_test.to_csv('annot_test.csv', sep=',', index=False)

    return csv_train, csv_test

# Create dataset annotations
path_data = '../Dataset/AgeDB/'
split = [0.8, 0.2]

labels_train, labels_test = make_annotation_file(path_data, split)
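
The annotation function relies on file names that split into exactly four parts at the underscores. The sketch below shows this parsing step in isolation; `parse_filename` is a hypothetical helper and the file names are made-up examples of the expected pattern.

```python
def parse_filename(file: str):
    """Split '<id>_<name>_<age>_<gender>.<ext>' into (name, age, gender)."""
    _, name, age, gender = file.split('_')   # raises ValueError for any other pattern
    return name, age, gender.split('.')[0]

print(parse_filename('0_JaneDoe_35_f.jpg'))   # ('JaneDoe', '35', 'f')

try:
    parse_filename('notes.txt')
except ValueError:
    print('skipped: unexpected file name')
```

Note that the age is kept as a string here, just as in the annotation files, so it has to be converted before it can serve as a numeric target.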


*Remember: On a Windows system, set `num_workers = 0` in the data loaders. Windows creates worker processes with "spawn" instead of "fork", which can cause multiprocessing issues.*

In [None]:
class DataSetAugmentDINO(Dataset):
    def __init__(self, phase, annotations_file, local_crops_number=8):
        self.phase = phase
        self.img_labels = pd.read_csv(annotations_file)
        self.local_crops_number = local_crops_number
        
        if phase == 'train':
            self.transform = DataAugmentationDINO(
                global_crops_scale=(0.4, 1.0),
                local_crops_scale=(0.05, 0.4),
                local_crops_number=local_crops_number,
                input_size=128
            )
        else:
            self.transform = v2.Compose([
                v2.Resize(128),
                v2.CenterCrop(128),
                v2.ToImage(),
                v2.ToDtype(torch.float32, scale=True),
                v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ])

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = self.img_labels.iloc[idx, 0]
        label = self.img_labels.iloc[idx, 3]
        
        try:
            # Load as PIL image (Image is already imported at the top of the notebook)
            x = Image.open(img_path).convert('RGB')
            
            if self.phase == 'train':
                crops = self.transform(x)
                return crops, label
            else:
                x = self.transform(x)
                return x, label
        
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            if self.phase == 'train':
                dummy_crops = [torch.zeros(3, 128, 128) for _ in range(2 + self.local_crops_number)]
                return dummy_crops, 'unknown'
            else:
                return torch.zeros(3, 128, 128), 'unknown'

# Create datasets
batch_size = 32
local_crops_number = 8

train_set = DataSetAugmentDINO('train', 'annot_train.csv', local_crops_number)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4)

valid_set = DataSetAugmentDINO('valid', 'annot_test.csv', local_crops_number)
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, num_workers=4)

Second, configure our network:

In [None]:
# Create model
print(f"Creating model: {config['model_name']}")
student = create_vit_model(config['model_name'], config['img_size'], num_classes=0)
embed_dim = student.embed_dim

# Create DINO head
head = DINOHead(
    in_dim=embed_dim,
    out_dim=config['out_dim'],
    use_bn=False,
    norm_last_layer=True,
)

# Wrap with MultiCropWrapper
student = MultiCropWrapper(student, head)
student = student.to(device)

# Create teacher (EMA copy)
teacher = copy.deepcopy(student)
for param in teacher.parameters():
    param.requires_grad = False

print(f'Trainable Parameters: {get_parameters(student):.3f}M')

Finally, define loss function, optimizer, and learning rate scheduler:

In [None]:
dino_loss = DINOLoss(
    out_dim=config['out_dim'],
    ncrops=2 + config['local_crops_number'],
    warmup_teacher_temp=config['warmup_teacher_temp'],
    teacher_temp=config['teacher_temp'],
    warmup_teacher_temp_epochs=config['warmup_teacher_temp_epochs'],
    nepochs=config['epochs'],
    student_temp=config['student_temp'],
    center_momentum=config['center_momentum'],
).to(device)

# Create optimizer
optimizer = torch.optim.AdamW(
    student.parameters(),
    lr=config['lr'] * config['batch_size'] / 256,
    weight_decay=config['weight_decay'],
)
# Learning rate scheduler
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0):
    warmup_iters = warmup_epochs * niter_per_ep
    total_iters = epochs * niter_per_ep
    
    schedule = []
    if warmup_iters > 0:
        warmup_schedule = np.linspace(0, base_value, warmup_iters)
        schedule.extend(warmup_schedule)
    
    iters = np.arange(total_iters - len(schedule))
    cosine_schedule = final_value + 0.5 * (base_value - final_value) * \
                      (1 + np.cos(np.pi * iters / len(iters)))
    
    schedule.extend(cosine_schedule)
    return schedule

niter_per_ep = len(train_loader)
lr_schedule = cosine_scheduler(
    config['lr'] * config['batch_size'] / 256,
    0,
    config['epochs'], 
    niter_per_ep,
    warmup_epochs=10
)

momentum_schedule = cosine_scheduler(config['momentum_teacher'], 1, config['epochs'], niter_per_ep)

wd_schedule = cosine_scheduler(
    config['weight_decay'],
    0.4,  # weight_decay_end
    config['epochs'], 
    len(train_loader)
)
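
The shape of these schedules may be easier to see on a toy example. The sketch below is a compact re-implementation of the same idea (linear warmup followed by a half-cosine) and is not the function from the cell above; the name `warmup_cosine` and all numeric values (100 iterations, a base learning rate of `1e-3`, a start momentum of `0.996`) are assumptions chosen for illustration.

```python
import numpy as np

def warmup_cosine(base_value, final_value, total_iters, warmup_iters=0):
    """Linear warmup from 0 to base_value, then half-cosine towards final_value."""
    warmup = np.linspace(0.0, base_value, warmup_iters)
    steps = np.arange(total_iters - warmup_iters)
    decay = final_value + 0.5 * (base_value - final_value) * (
        1 + np.cos(np.pi * steps / len(steps))
    )
    return np.concatenate([warmup, decay])

# Learning rate: rises during warmup, then decays towards 0.
lr = warmup_cosine(base_value=1e-3, final_value=0.0, total_iters=100, warmup_iters=10)

# Teacher momentum: no warmup, rises from the start value towards 1.
momentum = warmup_cosine(base_value=0.996, final_value=1.0, total_iters=100)

print("lr at start, end of warmup, end:", lr[0], lr[9], lr[-1])
print("momentum at start, end:", momentum[0], momentum[-1])
```

Because the cosine argument never quite reaches $\pi$, the last entry of each schedule stays slightly short of `final_value`. For the momentum this means the teacher is updated by a tiny amount even in the last iteration.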

Put everything together and start the training.

In [None]:
print(f"Starting DINO training for {config['epochs']} epochs...")

train_losses = []
lr_values = []

for epoch in range(config['epochs']):
    student.train()
    epoch_loss = 0
    num_batches = 0
    
    train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{config["epochs"]}')
    
    for i, (images, _) in enumerate(train_pbar):
        # Update learning rate and weight decay
        it = len(train_loader) * epoch + i
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr_schedule[it]
            # There is only one parameter group, so the weight decay schedule is applied on every step
            param_group["weight_decay"] = wd_schedule[it]
        
        # Move images to device
        images = [im.to(device, non_blocking=True) for im in images]

        # Forward passes (the teacher is not trained by backpropagation, so no graph is needed)
        with torch.no_grad():
            teacher_output = teacher(images[:2])  # Only global views
        student_output = student(images)          # All views
        
        loss = dino_loss(student_output, teacher_output, epoch)

        if not math.isfinite(loss.item()):
            # Abort the whole run; a bare `break` would only leave the inner loop
            raise RuntimeError(f"Loss is {loss.item()}, stopping training")

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        
        # Gradient clipping
        nn.utils.clip_grad_norm_(student.parameters(), 3.0)
        optimizer.step()

        # EMA update for teacher
        with torch.no_grad():
            m = momentum_schedule[it]
            for param_q, param_k in zip(student.parameters(), teacher.parameters()):
                param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)

        epoch_loss += loss.item()
        num_batches += 1
        
        train_pbar.set_postfix({
            'loss': loss.item(),
            'lr': optimizer.param_groups[0]["lr"]
        })

    avg_loss = epoch_loss / num_batches
    train_losses.append(avg_loss)
    lr_values.append(optimizer.param_groups[0]["lr"])
    
    print(f"Epoch {epoch+1}/{config['epochs']}: Loss = {avg_loss:.4f}, LR = {optimizer.param_groups[0]['lr']:.6f}")
    
    # Save a checkpoint every 20 epochs and after the final epoch
    if (epoch + 1) % 20 == 0 or epoch == config['epochs'] - 1:
        torch.save({
            'student': student.state_dict(),
            'teacher': teacher.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch + 1,
            'config': config,
            'loss': avg_loss
        }, f'dino_checkpoint_{epoch+1:04d}.pth')

print("DINO pretraining completed!")
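
The teacher update in the loop above is an exponential moving average (EMA) of the student weights. The following minimal sketch isolates that step on two tiny linear layers. The layer sizes, the artificial "+1" student update, and the momentum `m = 0.9` are assumptions for illustration only; in the training loop the momentum is much closer to 1.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-ins for the student and teacher networks.
student = nn.Linear(4, 2)
teacher = nn.Linear(4, 2)
teacher.load_state_dict(student.state_dict())  # teacher starts as a copy
teacher.requires_grad_(False)                  # teacher receives no gradients

# Pretend an optimizer step moved every student parameter by +1.
with torch.no_grad():
    for p in student.parameters():
        p.add_(1.0)

m = 0.9  # illustrative momentum
before = teacher.weight.clone()

# EMA update: p_teacher = m * p_teacher + (1 - m) * p_student
with torch.no_grad():
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.mul_(m).add_(p_s, alpha=1 - m)

# The teacher moved a fraction (1 - m) of the way towards the student.
step = teacher.weight - before
print(step)
```

With a momentum close to 1 the teacher changes slowly, which is why it can provide stable targets for the student.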

Plot the training loss and the learning rate over the epochs.

In [None]:
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(train_losses)
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(lr_values)
plt.title('Learning Rate Schedule')
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.grid(True)

plt.tight_layout()
plt.show()

### Fine-Tuning to predict gender

As in the previous notebooks, we want to fine-tune the model to predict the two gender classes.

First, create a class for the fine-tuning model.

In [None]:
class FineTuneModel(nn.Module):
    def __init__(self, base_model, num_classes, freeze_base=False):
        super().__init__()
        self.base_model = base_model
        
        if freeze_base:
            self.base_model.requires_grad_(False)
        else:
            self.base_model.requires_grad_(True)

        embed_dim = self.base_model.backbone.embed_dim
        
        self.new_head = nn.Sequential(
            nn.Linear(embed_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        x = x.to(device)
        features = self.base_model.backbone(x)
        if isinstance(features, tuple):
            features = features[0]
        
        out = self.new_head(features)
        return out

# Load pretrained DINO model
checkpoint_path = 'dino_checkpoint_0100.pth'
if os.path.exists(checkpoint_path):
    print("Loading pretrained DINO model...")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    # Recreate model with same config
    config_loaded = checkpoint['config']
    student_loaded = create_vit_model(config_loaded['model_name'], config_loaded['img_size'], num_classes=0)
    embed_dim = student_loaded.embed_dim
    
    head_loaded = DINOHead(embed_dim, config_loaded['out_dim'], use_bn=False, norm_last_layer=True)
    dino_model = MultiCropWrapper(student_loaded, head_loaded)
    dino_model.load_state_dict(checkpoint['student'])
    dino_model = dino_model.to(device)
    
    # Create fine-tuning model
    gender_model = FineTuneModel(dino_model, num_classes=2, freeze_base=False)
    gender_model = gender_model.to(device)
    
    print(f'Fine-tuning model parameters: {get_parameters(gender_model):.3f}M')
else:
    print(f"Checkpoint {checkpoint_path} not found. Please complete DINO training first.")
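
The `freeze_base` flag decides whether the pretrained backbone is updated during fine-tuning or only the new head is trained (often called linear probing when the head is a single linear layer). The sketch below shows the effect on the number of trainable parameters with small stand-in modules; the helper `count_trainable` and the layer sizes are hypothetical and unrelated to the real ViT.

```python
import torch.nn as nn

def count_trainable(model):
    """Number of parameters that will receive gradients."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Tiny stand-ins for the pretrained backbone and the new classification head.
backbone = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 8))
head = nn.Linear(8, 2)
model = nn.Sequential(backbone, head)

full = count_trainable(model)    # full fine-tuning: everything is trainable
backbone.requires_grad_(False)   # freeze the backbone
probe = count_trainable(model)   # only the head remains trainable

print(f"trainable (full fine-tuning): {full}")
print(f"trainable (frozen backbone):  {probe}")
```

Freezing the backbone is cheaper and tests the quality of the pretrained features directly, while full fine-tuning usually reaches higher accuracy at the risk of overfitting on small datasets.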

Define the optimizer and the loss function.

In [None]:
import torch.optim as optim

num_epochs = 5
init_lr = 1e-4
optimizer = optim.AdamW(gender_model.parameters(), lr=init_lr, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss()


Start fine-tuning.

In [None]:
for epoch in range(num_epochs):
    gender_model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    train_pbar = tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs} [Train]')

    for x_i, labels in train_pbar:
        images = x_i[0]
        if images.size()[0] < batch_size:
            # Skip the last batch if it is smaller than batch_size
            break
        # Move images to device
        images = images.to(device).float()

        # Encode the string labels as class indices: 'm' -> 0, 'f' -> 1
        y = torch.tensor([1 if label == 'f' else 0 for label in labels], dtype=torch.int64)
        
        optimizer.zero_grad()

        # Forward pass
        out = gender_model(images)
        out = out.to(device)
        y = y.to(device)
        
        loss = criterion(out, y)

        loss.backward()
        optimizer.step()
        # Statistics
        running_loss += loss.item() * images.size(0)
        _, predicted = torch.max(out, 1)
        total += y.size(0)
        correct += (predicted == y).sum().item()

        # Update progress bar
        train_pbar.set_postfix({'loss': loss.item(), 'acc': 100 * correct / total})

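    # Calculate epoch statistics over the samples actually seen (an incomplete last batch is skipped)
    epoch_train_loss = running_loss / total
    epoch_train_acc = 100 * correct / total
    print(f'Epoch {epoch + 1}/{num_epochs}: Loss = {epoch_train_loss:.4f}, Acc = {epoch_train_acc:.2f}%')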

Evaluate the gender prediction.

In [None]:
from sklearn.metrics import classification_report

gender_model.eval()
gender_model.to(device)

predictions = []
label_list = []

with torch.no_grad():
    eval_pbar = tqdm(valid_loader, desc='Evaluation')

    for x_i,labels in eval_pbar:
        if x_i.size()[0] < batch_size:
            # Skip the last batch if it is smaller than batch_size
            break
        x_i = x_i.to(device)
        out = gender_model(x_i)
        if isinstance(out, tuple):
            out = out[0]
        _, predicted = torch.max(out, 1)
        predictions.extend(predicted.cpu().tolist())
        label_list.extend(labels)

# Encode the string labels the same way as during fine-tuning: 'm' -> 0, 'f' -> 1
predictions = np.asarray(predictions)
label_list = np.asarray(label_list)
labels = (label_list == 'f').astype(int)

print(classification_report(labels, predictions))
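
The report summarizes precision and recall per class. To see which counts these numbers are built from, the following sketch computes the accuracy and a 2x2 confusion matrix by hand. The labels and predictions are made up for illustration and are not results of the model; `class_to_idx` is a hypothetical helper mapping.

```python
import numpy as np

# Made-up labels and predictions, for illustration only.
label_strings = np.array(['m', 'f', 'f', 'm', 'f', 'm'])
predicted = np.array([0, 1, 0, 0, 1, 1])

# Same encoding as in the fine-tuning loop: 'm' -> 0, 'f' -> 1
class_to_idx = {'m': 0, 'f': 1}
targets = np.array([class_to_idx[s] for s in label_strings])

accuracy = (predicted == targets).mean()

# confusion[i, j] counts samples of true class i that were predicted as class j
confusion = np.zeros((2, 2), dtype=int)
np.add.at(confusion, (targets, predicted), 1)

# Recall of class 'f': correctly predicted 'f' divided by all true 'f'
recall_f = confusion[1, 1] / confusion[1].sum()

print(f"accuracy: {accuracy:.3f}")
print(confusion)
print(f"recall (f): {recall_f:.3f}")
```

If the two classes are imbalanced, the per-class recall from the confusion matrix is more informative than the overall accuracy.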

## Exercises

### 1. Plot the attention maps on a random AgeDB image (see Tutorial 4.1 (ViT))

In [None]:
# your code here

### 2. Change the encoder network to ResNet50

In [None]:
# your code here

### 3. Use t-SNE to see how gender is represented in the latent representation from the student network.

In [None]:
# your code here

### 4. Change the augmentation pipeline

In [None]:
# your code here