# Tutorial 4.1: Vision Transformer (ViT)

Author: [Erik Syniawa](mailto:erik.syniawa@informatik.tu-chemnitz.de)

Published by Dosovitskiy et al. (2020) in [[1](#6-references)].

In [None]:
import torch
import torch.nn as nn
import numpy as np
import os, sys
from typing import Optional, Tuple, List

notebook_dir = os.getcwd()
root_path = os.path.abspath(os.path.join(notebook_dir, ".."))
if root_path not in sys.path:
    sys.path.append(root_path)
    print(f"Added {root_path} to sys.path")
    
from Utils.little_helpers import timer, set_seed

set_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'PyTorch version: {torch.__version__} running on {device}')


## Architecture

Vision Transformers (ViT) represent a significant shift in computer vision: they apply the Transformer architecture, originally designed for NLP, directly to image classification tasks.

<div align="center">
    <img src="figures/vit_animation.gif" width="700"/>
    <p><i>Figure 1: ViT architecture. Source: [2]</i></p>
</div>

The general architecture of ViT consists of the following components:

### 1. Patch Embedding

Self-attention has quadratic computational complexity in the sequence length, so treating every pixel as a token would be prohibitively expensive. ViT therefore shortens the input sequence with a patch embedding:
- The input image is divided into fixed-size patches (e.g., 16x16 pixels)
- Each patch is flattened into a vector
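- A __learnable__ linear projection (implemented as a linear or a convolutional layer) maps each flattened patch to an embedding vector of size `embed_dim` (for ViT-B/16 this is 768, the same as the 16·16·3 = 768 raw values of a patch)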
- This creates a sequence of patch embeddings that serve as the input to the Transformer


Let's see what this looks like for an example image:

In [None]:
import torchvision.transforms as transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from PIL import Image
import requests
from io import BytesIO

# Load and preprocess image
image_url = "https://images.unsplash.com/photo-1501854140801-50d01698950b"
response = requests.get(image_url, timeout=30)
response.raise_for_status()
img = Image.open(BytesIO(response.content)).convert("RGB")  # ensure 3 color channels

img_size = 224
patch_size = 16

transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
])

# Apply transformations and convert to tensor
x = transform(img).unsqueeze(0)  # [1, C, H, W]

# Extract patches directly using unfold
B, C, H, W = x.shape
patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
patches = patches.contiguous().view(B, C, -1, patch_size, patch_size)
patches = patches.permute(0, 2, 1, 3, 4)  # [B, num_patches, C, patch_size, patch_size]

# Reshape for visualization
patches_for_grid = patches.squeeze(0)  # [num_patches, C, patch_size, patch_size]

# Create grid visualization
grid = make_grid(patches_for_grid, nrow=img_size//patch_size, padding=1)
grid_np = grid.permute(1, 2, 0).numpy()

# Create a linear layout (patches in reading order, 64 per row)
num_display_patches = 64
linear_grid = make_grid(patches_for_grid, nrow=num_display_patches, padding=1)
linear_grid_np = linear_grid.permute(1, 2, 0).numpy()

# Display original, grid, and linear layout
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(14, 10))

ax1.imshow(x.squeeze(0).permute(1, 2, 0).numpy())
ax1.set_title(f'Original Image ({img_size}×{img_size})')
ax1.axis('off')

ax2.imshow(grid_np)
ax2.set_title(f'Patches in Grid ({patch_size}×{patch_size})')
ax2.axis('off')

ax3.imshow(linear_grid_np)
ax3.set_title(f'Patches in Linear Sequence: {num_display_patches} in one row')
ax3.axis('off')

plt.tight_layout()
plt.show()

Now we understand how the image is divided into patches. The next step is to project each patch into the embedding space with a learnable layer, wrapped in an `nn.Module`. We use a convolutional layer whose kernel size and stride both equal the patch size, so each output position sees exactly one non-overlapping patch.
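A convolution with kernel size and stride equal to the patch size is the same operation as flattening each patch and applying one shared linear layer. The following sketch checks this numerically on random data (the embedding size of 64 is an arbitrary choice to keep it small): it copies the convolution's weights into an `nn.Linear` layer applied to explicitly flattened patches and compares both outputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
p, C, D = 16, 3, 64                       # patch size, input channels, embedding dim (kept small)
conv = nn.Conv2d(C, D, kernel_size=p, stride=p)
x = torch.randn(2, C, 224, 224)           # [B, C, H, W]

# Path 1: strided convolution, then flatten the 14x14 output grid into a sequence
out_conv = conv(x).flatten(2).transpose(1, 2)                                    # [B, 196, D]

# Path 2: cut out patches explicitly and flatten each one to C*p*p values
patches = x.unfold(2, p, p).unfold(3, p, p)                                      # [B, C, 14, 14, p, p]
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(x.shape[0], -1, C * p * p)  # [B, 196, C*p*p]

linear = nn.Linear(C * p * p, D)
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(D, -1))  # conv kernel flattened in (C, p, p) order
    linear.bias.copy_(conv.bias)
    out_linear = linear(patches)

print(out_conv.shape, torch.allclose(out_conv, out_linear, atol=1e-4))
```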

In [None]:
class PatchEmbed(nn.Module):
    """ 
    Image to Patch Embedding
    """
    def __init__(self, 
                 img_size: int = 224, # 224x224
                 patch_size: int = 16,  # 16x16 
                 in_chans: int = 3,  # color channels 
                 embed_dim: int = 768):  # embedding dimension
        super(PatchEmbed, self).__init__()
        
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.num_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
        self.num_patches_h = self.patch_shape[0]  # number of patch rows
        self.num_patches_w = self.patch_shape[1]  # number of patch columns
        
        # LeCun-style truncated normal init (std = sqrt(1 / fan_in)), as in torchvision's ViT
        fan_in = in_chans * patch_size[0] * patch_size[1]
        nn.init.trunc_normal_(self.proj.weight, std=np.sqrt(1 / fan_in))
        if self.proj.bias is not None:
            nn.init.zeros_(self.proj.bias)

    def _get_h_w(self):
        return self.num_patches_h, self.num_patches_w

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x shape: [B, C, H, W]
        B, C, H, W = x.shape
        assert H == W, f"Input image must be square, got {H}x{W}"
        
        # Apply patch embedding conv
        # (B, C, H, W) -> (B, embed_dim, H//patch_size, W//patch_size)
        x = self.proj(x)
        
        # Flatten and transpose: (B, embed_dim, H', W') -> (B, num_patches, embed_dim)
        x = x.flatten(2).transpose(1, 2)
        
        return x
    

Implementation from [[3](#6-references)].

> What would be the computational cost of calculating the self-attention of a 224x224 image with and without patches? Assume the image is RGB, the embedding dimension is 768, and the patch size is 16x16. 


> What could be the advantage of using a convolutional layer (like in the code) instead of a linear projection of each patch?

#### 1.1 Patch Embedding Filters after training

Figure 2 shows the top 28 principal components of the learned patch embedding filters from ViT-L/32. These visualizations reveal:

- The components resemble plausible basis functions for a low-dimensional representation of the fine structure within each patch
- They capture different frequency patterns and edge orientations
- Some filters detect color variations (RGB channels), while others focus on texture and structural elements
- These learned embedding filters transform raw pixel data into meaningful feature representations
- The diversity of these filters suggests that the model learns to extract various visual features at the patch level before self-attention processing

<div align="center">
    <img src="figures/rgb_filter_embeddings.PNG" width="500"/>
    <p><i>Figure 2: Filters of the patch embedding of RGB values of ViT-L/32 model. Source: [1]</i></p>
</div>
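A figure like this can be reproduced from a trained model's patch-embedding weights. The following is a minimal sketch, assuming the embedding is an `nn.Conv2d` like `PatchEmbed.proj` above: it flattens each of the `embed_dim` filters, centers them, and takes the top right-singular vectors (the principal components) via `torch.linalg.svd`. The helper name `embedding_filter_pcs` is our own, and on an untrained layer the components are just noise.

```python
import torch
import torch.nn as nn

def embedding_filter_pcs(proj: nn.Conv2d, k: int = 28) -> torch.Tensor:
    """Top-k principal components of a patch-embedding conv's filters, shaped [k, C, p, p]."""
    W = proj.weight.detach()                 # [D, C, p, p]: one filter per embedding dimension
    D, C, ph, pw = W.shape
    X = W.reshape(D, C * ph * pw)            # each row is one flattened filter
    X = X - X.mean(dim=0, keepdim=True)      # center over the D filters
    # Rows of Vh are the principal directions, sorted by singular value (descending)
    _, _, Vh = torch.linalg.svd(X, full_matrices=False)
    return Vh[:k].reshape(k, C, ph, pw)

# Example on an untrained ViT-B/16-sized embedding (random weights, so no structure yet)
pcs = embedding_filter_pcs(nn.Conv2d(3, 768, kernel_size=16, stride=16))
print(pcs.shape)  # torch.Size([28, 3, 16, 16])
```

For display, `make_grid(pcs, nrow=7, normalize=True, scale_each=True)` rescales each component to [0, 1] individually before plotting.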



### 2. Positional Embedding

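Self-attention has no built-in notion of token order (shuffling the input tokens just shuffles the outputs), so positional information must be added explicitly. Unlike the original Transformer for NLP ([Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)), which used fixed sinusoidal encodings, the ViT from [[1](#6-references)] uses learnable positional embeddings that are initialized randomly.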

- 1D positional embeddings are added to the patch embeddings (Patch Embedding + Positional Embedding)
- Each embedding belongs to one position in the 1D (raster-order) sequence of patches; the 2D layout of the image is not given to the model explicitly
- The model learns to represent spatial relationships between patches during training
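That self-attention by itself ignores token order can be checked directly: without positional information it is permutation-equivariant, i.e. shuffling the input tokens only shuffles the outputs in the same way. A minimal sketch using `nn.MultiheadAttention` with random weights, where `pos` is a random stand-in for learned positional embeddings:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, D = 10, 32
mha = nn.MultiheadAttention(embed_dim=D, num_heads=4, batch_first=True).eval()
tokens = torch.randn(1, N, D)        # [B, N, D] patch embeddings
pos = torch.randn(1, N, D)           # random stand-in for positional embeddings
perm = torch.randperm(N)             # a random reordering of the patches

def self_attention(x: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():
        return mha(x, x, x, need_weights=False)[0]

# Without positions: permuting the inputs just permutes the outputs
no_pos = torch.allclose(self_attention(tokens)[:, perm], self_attention(tokens[:, perm]), atol=1e-5)

# With positions: the content moves, but each slot keeps its own positional vector
with_pos = torch.allclose(self_attention(tokens + pos)[:, perm],
                          self_attention(tokens[:, perm] + pos), atol=1e-5)
print(no_pos, with_pos)
```

The first check should print `True` and the second `False`: once positions are added, the output depends on where each patch sits in the sequence.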

In [None]:
class PositionalEmbedding(nn.Module):
    """
    Learnable positional embeddings for the transformer
    """
    def __init__(self, 
                 num_patches: int, 
                 embed_dim: int, 
                 dropout: float = 0.1):
        super(PositionalEmbedding, self).__init__()
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.dropout = nn.Dropout(dropout)
        
        # Initialize positional embeddings
        self._init_weights()
        
    def _init_weights(self):
        # Standard initialization strategy from the ViT paper
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
            
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x shape: [B, num_patches, embed_dim]
        
        # Add positional embeddings (broadcast over the batch dimension)
        x = x + self.pos_embed
        
        return self.dropout(x)

Implementation from [[3](#6-references)]. 

We will implement the positional embedding directly in the ViT class for simplicity. A side note on positional embeddings: later models often replace the learnable absolute embeddings, for example MAE uses fixed 2D sine-cosine positional embeddings, and the Swin Transformer uses a learned relative position bias inside attention. These variants are not the focus of this notebook, which follows [[1](#6-references)].

#### 2.1 Positional Encoding after training

The heatmap of Figure 3 visualizes how similar the learned positional embeddings are to each other after training:

- Each tile shows the cosine similarity between one patch positional embedding and all others
- You can clearly see that patches close to each other (diagonally, horizontally, or vertically) have more similar embeddings
- The grid-like structure that emerges shows that the model has learned 2D spatial relationships even though it uses 1D positional embeddings
- This is consistent with the authors' finding that more advanced 2D-aware positional embeddings gave no significant performance gains: the model already learns to represent the 2D structure

<div align="center">
    <img src="figures/position_embeddings.PNG" width="500"/>
    <p><i>Figure 3: Positional encoding after training. Source: [1]</i></p>
</div>
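The similarity map in Figure 3 can be computed directly from a positional-embedding parameter such as `pos_embed` in the `VisionTransformer` class (section 5.4). A minimal sketch, assuming a square patch grid and a leading class-token position as in that class; the helper name `pos_embed_similarity` is hypothetical:

```python
import torch
import torch.nn.functional as F

def pos_embed_similarity(pos_embed: torch.Tensor, grid_h: int, grid_w: int,
                         has_cls: bool = True) -> torch.Tensor:
    """Cosine similarity of every patch position with every other one.
    Returns [grid_h, grid_w, grid_h, grid_w]; sim[i, j] is the tile for the patch at row i, column j."""
    E = pos_embed.detach().squeeze(0)        # [N(+1), D]
    if has_cls:
        E = E[1:]                            # drop the class-token position
    E = F.normalize(E, dim=-1)               # unit-length rows, so dot product = cosine similarity
    sim = E @ E.T                            # [N, N]
    return sim.reshape(grid_h, grid_w, grid_h, grid_w)

# Random (untrained) embeddings in the ViT-B/16 layout: 14x14 patches + class token
sim = pos_embed_similarity(torch.randn(1, 197, 768), 14, 14)
print(sim.shape)  # torch.Size([14, 14, 14, 14])
```

`plt.imshow(sim[i, j])` then shows a single tile of the figure; with trained embeddings, the bright region should sit around position `(i, j)`.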



### 3. Class Token [CLS] ([Devlin et al., 2019](https://aclanthology.org/N19-1423/))

- Following BERT's approach, a special learnable embedding called the "cls token" is prepended to the sequence
- This token aggregates information from all patches through self-attention
- The final state of this token at the output of the Transformer serves as the image representation
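- A classification head is attached to the output of this token for image classification (in [[1](#6-references)], an MLP with one hidden layer during pre-training and a single linear layer during fine-tuning; our implementation uses a single linear layer, see section [5.4](#54-final-architecture))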


### 4. Layer Normalization ([Ba et al., 2016](https://arxiv.org/abs/1607.06450))

Unlike in ResNets where BatchNorm is used, the ViT uses Layer Normalization. 
Contrary to BatchNorm, which normalizes across the batch dimension, LayerNorm normalizes across the feature dimension for each individual sample.

For an input tensor $X \in \mathbb{R}^{B \times d}$ where $B$ is the batch size and $d$ is the feature dimension, LayerNorm operates as follows:

1. Compute the mean and variance **across the feature dimension** for each sample in the batch:
   - $\mu_i = \frac{1}{d} \sum_{j=1}^{d} x_{i,j}$
   - $\sigma_i^2 = \frac{1}{d} \sum_{j=1}^{d} (x_{i,j} - \mu_i)^2$

2. Normalize the features:
   - $\hat{x}_{i,j} = \frac{x_{i,j} - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}}$

3. Apply learnable per-feature scale ($\gamma$) and shift ($\beta$) parameters:
   - $y_{i,j} = \gamma_j \cdot \hat{x}_{i,j} + \beta_j$

Where $\epsilon$ is a small constant added for numerical stability. Let's see how this looks in code:

```python
class LayerNorm(nn.Module):
    """
    Layer Normalization module
    """
    def __init__(self, 
                 embed_dim: int, 
                 eps: float = 1e-6):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(embed_dim))
        self.beta = nn.Parameter(torch.zeros(embed_dim))
        self.eps = eps
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x shape: [B, ..., d]
        # Calculate mean and variance along the last dimension (features)
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, unbiased=False, keepdim=True)
        
        # Normalize
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        
        # Scale and shift
        y = self.gamma * x_norm + self.beta
        
        return y
```

We don't need to implement this in the ViT class, as PyTorch already has a built-in LayerNorm module, which we will use.
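To check that the formulas above match PyTorch's built-in version, the following sketch re-implements them as a plain function and compares it with `torch.nn.functional.layer_norm` on random data. Random values for $\gamma$ and $\beta$ are used so that the scale and shift are exercised as well.

```python
import torch
import torch.nn.functional as F

def manual_layer_norm(x, gamma, beta, eps=1e-6):
    # Statistics over the last (feature) dimension, separately for every token
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)   # biased variance, as in the formula
    return gamma * (x - mean) / torch.sqrt(var + eps) + beta

torch.manual_seed(0)
x = torch.randn(4, 197, 768)                            # [B, N, D]
gamma, beta = torch.randn(768), torch.randn(768)
reference = F.layer_norm(x, (768,), weight=gamma, bias=beta, eps=1e-6)
print(torch.allclose(manual_layer_norm(x, gamma, beta), reference, atol=1e-4))  # → True
```

Note that `unbiased=False` matters: `torch.var` uses the unbiased estimator (dividing by $d-1$) by default, which would not match LayerNorm.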

#### 4.1 Why LayerNorm works better for Transformers

As discussed by [Shen et al. (2020)](https://proceedings.mlr.press/v119/shen20e.html), BatchNorm performs poorly in Transformer architectures for several reasons:

1. **High Variance in Batch Statistics**: NLP data exhibits much higher variance in batch statistics compared to CV data, which causes instability when using BatchNorm.

2. **Variable Sequence Lengths**: NLP tasks often have variable-length inputs, making BatchNorm statistics inconsistent across batches.

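3. **Token Independence**: LayerNorm normalizes each token with its own statistics, so a token's output does not depend on the other samples (or padding) in the batch, and training and inference behave identically because no running statistics are needed.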

4. **Gradient Flow**: LayerNorm helps recenter and rescale backward gradients, which is particularly important for deep Transformer networks.

In ViT, LayerNorm is applied twice in each Transformer block:
- Before the Multi-Head Self-Attention (MSA) module
- Before the MLP block

This "pre-norm" configuration helps stabilize training by ensuring that the input to each sub-layer has a consistent distribution, which is crucial for the convergence of deep Transformer networks.

<div align="center">
    <img src="figures/batchvslayer.png" width="500"/>
    <p><i>Figure 4: Difference between BatchNorm and LayerNorm. The colored sections show the entries on which the statistics are calculated. Source: Shen et al., 2020</i></p>
</div>
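The practical difference shown in Figure 4 appears when the rest of the batch changes. A minimal sketch with random data: in training mode, `nn.BatchNorm1d` normalizes each feature with statistics pooled over all samples and tokens, so the output for sample 0 changes when the other samples change, while `nn.LayerNorm` only uses the statistics of each token itself.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, N, D = 4, 10, 16
ln = nn.LayerNorm(D)
bn = nn.BatchNorm1d(D).train()         # training mode: normalizes with the current batch's statistics

def apply_bn(t: torch.Tensor) -> torch.Tensor:
    # BatchNorm1d expects [B, C, L]: treat the D features as channels and the N tokens as length
    return bn(t.transpose(1, 2)).transpose(1, 2)

x = torch.randn(B, N, D)
x_other = x.clone()
x_other[1:] = torch.randn(B - 1, N, D)  # keep sample 0, replace the rest of the batch

with torch.no_grad():
    ln_unchanged = torch.allclose(ln(x)[0], ln(x_other)[0])
    bn_unchanged = torch.allclose(apply_bn(x)[0], apply_bn(x_other)[0])

print(ln_unchanged, bn_unchanged)  # → True False
```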

#### 4.2 Implementation Considerations

When implementing LayerNorm in Vision Transformers, several considerations are important:

1. **Epsilon Value**: A small constant ($\epsilon = 10^{-6}$ is common) is added to the variance for numerical stability.

2. **Parameter Initialization**: The scale parameter $\gamma$ is typically initialized to 1, and the shift parameter $\beta$ to 0.

3. **Pre-norm vs Post-norm**: In the original Transformer ([Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)), normalization was applied *after* each sub-layer (post-norm). In ViT and many modern Transformer variants, normalization is applied *before* each sub-layer (pre-norm), which improves training stability.

4. **Dimensionality**: In ViT, LayerNorm normalizes over the full embedding dimension of each token (`nn.LayerNorm(embed_dim)`), not over the token or batch dimensions.

#### 4.3 "Recent" Advances

- **RMSNorm** ([Zhang & Sennrich, 2019](https://arxiv.org/abs/1910.07467)): Simplifies LayerNorm by removing the mean-centering step and only normalizing by the root mean square.

- **PowerNorm** ([Shen et al., 2020](https://proceedings.mlr.press/v119/shen20e.html)): Addresses the issues of BatchNorm in Transformers by relaxing zero-mean normalization and using running statistics for the quadratic mean.

- **T-Fixup** ([Huang et al., 2020](https://proceedings.mlr.press/v119/huang20f)): An initialization technique that enables training deep Transformers without any normalization layers and without a learning-rate warm-up phase.
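To make the RMSNorm difference concrete, here is a minimal sketch following the description above: each token is divided by its root mean square and multiplied by a learnable per-feature gain, with no mean subtraction and no shift. Placing $\epsilon$ inside the square root is our choice; implementations differ in this detail.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """y = g * x / sqrt(mean(x^2) + eps): no mean-centering, no shift parameter."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learnable per-feature gain g

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * x / rms

torch.manual_seed(0)
x = torch.randn(2, 197, 768) + 3.0          # tokens with a clear non-zero mean
y = RMSNorm(768)(x)
print(y.pow(2).mean(dim=-1).mean().item())  # ~1.0: unit RMS per token
print(y.mean(dim=-1).mean().item())         # clearly non-zero: the mean is not removed
```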



### 5. Transformer (Encoder-only)

The standard Transformer encoder consists of alternating Multi-Head Self-Attention (MSA) and MLP blocks. Each block is preceded by LayerNorm and wrapped in a residual (skip) connection, as in ResNet.


<div align="center">
    <img src="figures/transformer_block.PNG" width="200"/>
    <p><i>Figure 5: Structure of a Transformer block. Source: [1]</i></p>
</div>
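The block structure in Figure 5 corresponds to the following equations from [[1](#6-references)], where $\mathbf{x}_p^i \in \mathbb{R}^{P^2 \cdot C}$ is the $i$-th of $N$ flattened patches, $\mathbf{E}$ the patch projection, $\mathbf{E}_{pos}$ the positional embeddings, $\mathbf{x}_\text{class}$ the class token, and $L$ the number of blocks:

$$
\begin{aligned}
\mathbf{z}_0 &= [\mathbf{x}_\text{class};\, \mathbf{x}_p^1\mathbf{E};\, \mathbf{x}_p^2\mathbf{E};\, \dots;\, \mathbf{x}_p^N\mathbf{E}] + \mathbf{E}_{pos}, && \mathbf{E} \in \mathbb{R}^{(P^2 \cdot C) \times D},\ \mathbf{E}_{pos} \in \mathbb{R}^{(N+1) \times D} \\
\mathbf{z}'_\ell &= \mathrm{MSA}(\mathrm{LN}(\mathbf{z}_{\ell-1})) + \mathbf{z}_{\ell-1}, && \ell = 1, \dots, L \\
\mathbf{z}_\ell &= \mathrm{MLP}(\mathrm{LN}(\mathbf{z}'_\ell)) + \mathbf{z}'_\ell, && \ell = 1, \dots, L \\
\mathbf{y} &= \mathrm{LN}(\mathbf{z}_L^0)
\end{aligned}
$$

Here $\mathbf{z}_L^0$ is the output at the class-token position, which is what `forward_features` in section 5.4 returns.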

#### 5.1 Multi-Head Self-Attention (MSA)

- Self-attention allows each patch to attend to all other patches, enabling global information integration
- We will use PyTorch's built-in `nn.MultiheadAttention`, so the code below is for reference only

In [None]:
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        # Each head gets a fraction of the total embedding dimension
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        head_dim = dim // self.num_heads
        self.scale = head_dim ** -0.5

        # Single linear projection for queries, keys, and values (more efficient)
        # Projects from dim → dim*3 (to create queries, keys, and values at once)
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        # Dropout applied to attention weights for regularization
        self.attn_drop = nn.Dropout(attention_dropout)
        # Final projection to mix information from different heads
        self.proj = nn.Linear(dim, dim)
        # Dropout applied to the output projection
        self.proj_drop = nn.Dropout(projection_dropout)

    def forward(self, x):
        # Extract dimensions from input tensor
        B, N, C = x.shape  # Batch size, sequence length, embedding dimension
        
        # Step 1: Project input to queries, keys, and values all at once
        # - Apply linear projection to get [B, N, 3*C] tensor
        # - Reshape to [B, N, 3, num_heads, head_dim]
        # - Permute to [3, B, num_heads, N, head_dim] to separate q, k, v
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # Split into separate q, k, v tensors, each with shape [B, num_heads, N, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Step 2: Compute attention scores (dot product of queries and keys)
        # - Transpose keys to align dimensions for matrix multiplication
        # - Scale by 1/sqrt(head_dim) so large dot products don't saturate the softmax (tiny gradients)
        attn = (q @ k.transpose(-2, -1)) * self.scale  # Shape: [B, num_heads, N, N]
        
        # Step 3: Apply softmax to get attention probabilities
        # Each row sums to 1, representing probability distribution over the sequence
        attn = attn.softmax(dim=-1)
        
        # Step 4: Apply dropout for regularization
        attn = self.attn_drop(attn)
        
        # Step 5: Apply attention weights to values
        # - Matrix multiply attention weights with values
        # - Transpose and reshape back to original dimensions
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        
        # Step 6: Final projection and dropout
        # This mixes information from different heads
        x = self.proj(x)
        x = self.proj_drop(x)
        
        return x

The key feature is that the attention computation is split across multiple "heads". Each head can focus on different parts of the sequence simultaneously.
This allows the model to capture various types of relationships in parallel (see section [5.6](#56-preview)):
- Some heads might focus on local patterns
- Others might capture long-range dependencies
- Others might specialize in specific semantic features

The `num_heads` parameter $h$ determines how many attention mechanisms run in parallel, and $d_{\text{head}} = d / h$ (`head_dim` in the code) is the dimension of each head, where $d$ is the embedding dimension.
At the end, outputs from all heads are concatenated and projected back to the original dimension.

> **Note**: As mentioned above, self-attention has quadratic computational complexity with respect to sequence length. For a $224\times 224$ image with $16\times 16$ patches, we have 196 patch tokens (197 including the class token), resulting in $197^2 \approx$ 38.8K attention scores per layer per head. This becomes a significant bottleneck as the image resolution grows or the patch size shrinks.
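
The following small sketch makes the quadratic growth concrete by counting tokens (including the class token) and the entries of one $N \times N$ attention map for a few resolutions and patch sizes. It counts attention scores only, not FLOPs.

```python
from typing import Tuple

def attention_cost(img_size: int, patch_size: int, with_cls: bool = True) -> Tuple[int, int]:
    """Return (number of tokens, number of attention scores in one N x N map)."""
    n_tokens = (img_size // patch_size) ** 2 + (1 if with_cls else 0)
    return n_tokens, n_tokens ** 2

for img, p in [(224, 32), (224, 16), (224, 8), (384, 16)]:
    n, scores = attention_cost(img, p)
    print(f"{img}x{img}, {p}x{p} patches: {n:4d} tokens -> {scores:>9,} scores per head per layer")
```

Halving the patch size from 16 to 8 roughly quadruples the number of tokens (197 to 785) and multiplies the number of scores by about 16 (38,809 to 616,225).
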
Several approaches can optimize this computation:

- **Flash attention**: [Flash Attention](https://arxiv.org/abs/2205.14135) and [Flash Attention-2](https://arxiv.org/abs/2307.08691) significantly speed up attention computation through IO-aware implementation, tiling, and recomputation strategies. 

- **Linear attention**: Linear complexity approximations of attention have been successfully applied to ViTs, including Performer's [FAVOR+](https://arxiv.org/abs/2009.14794) and [Nyströmformer](https://arxiv.org/abs/2102.03902) (used in ViTGAN).

- **Hierarchical designs**: Models like [Swin Transformer](https://arxiv.org/abs/2103.14030) and [PVT](https://arxiv.org/abs/2102.12122) apply hierarchical attention with progressive spatial reduction, combining the multiscale feature maps of CNNs with the Transformer architecture.

- **Sparse attention**: Approaches like [Longformer](https://arxiv.org/abs/2004.05150) and [Reformer](https://arxiv.org/abs/2001.04451) let each token attend only to a subset of the others, allowing for longer sequences without quadratic scaling. In computer vision, the local window attention of the [Swin Transformer](https://arxiv.org/abs/2103.14030) follows a similar idea. [Linformer](https://arxiv.org/abs/2006.04768), in contrast, reduces complexity with a low-rank projection of keys and values rather than a sparse pattern.

- **Memory-efficient implementations**: Libraries like [xformers](https://github.com/facebookresearch/xformers) provide memory-efficient Multi-Head Attention implementations for PyTorch that reduce both memory footprint and computational requirements. 

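> If you implement your own ViT and are limited by computational resources, consider these optimizations. Exact methods such as FlashAttention speed up computation without changing the result, whereas approximations such as linear or sparse attention may cost some accuracy.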
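Since PyTorch 2.0, fused attention is available as `torch.nn.functional.scaled_dot_product_attention`, which can dispatch to FlashAttention or memory-efficient kernels on supported hardware and falls back to a plain implementation otherwise. A minimal sketch with random tensors in ViT-B/16 sizes, checking that it computes the same result as the manual steps in the `Attention` class from section 5.1 (without dropout):

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, N, Dh = 2, 12, 197, 64                          # batch, heads, tokens, head_dim (ViT-B/16 sizes)
q, k, v = (torch.randn(B, H, N, Dh) for _ in range(3))

# Manual path, as in the Attention class: softmax(q k^T / sqrt(head_dim)) v
scores = (q @ k.transpose(-2, -1)) / math.sqrt(Dh)    # [B, H, N, N], the quadratic part
manual = scores.softmax(dim=-1) @ v                   # [B, H, N, Dh]

# Fused path; its default scale is also 1 / sqrt(head_dim)
fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(manual, fused, atol=1e-4))
```

This should print `True`. Inside a module, passing `dropout_p=self.attn_drop.p if self.training else 0.0` would keep the attention dropout of the original class.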

#### 5.2 MLP

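The MLP contains two linear layers with a GELU non-linearity in between. Its hidden size is `mlp_ratio` times the embedding dimension (4× in all standard ViT variants).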


In [None]:
class MLP(nn.Module):
    """
    MLP block as used in Vision Transformer with GeLU activation
    """
    def __init__(self, 
                 in_features: int, 
                 hidden_features: int, 
                 out_features: int,
                 dropout: float = 0.0,
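                 activation: Optional[nn.Module] = None,  # defaults to a fresh nn.GELU()
                 weight_init: bool = True):
        super(MLP, self).__init__()
        
        # Create the activation per instance instead of sharing one default module object
        activation = activation if activation is not None else nn.GELU()
        
        self.mlp = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            activation,
            nn.Dropout(dropout),
            nn.Linear(hidden_features, out_features),
            nn.Dropout(dropout)
        )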
        
        if weight_init:
            self._init_weights()
            
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Apply MLP block
        x = self.mlp(x)
        return x
    
    def _init_weights(self):
        # Initialize weights using Xavier uniform distribution (from torchvision vit.py)
        for layer in self.mlp:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.normal_(layer.bias, std=1e-6)
        

#### 5.3 Encoder Block

The encoder block is a combination of MSA and MLP with layer normalization (see Figure 5).


In [None]:
class TransformerEncoderBlock(nn.Module):
    """
    Transformer encoder block consisting of Multi-Head Self-Attention followed by MLP
    Using pre-norm architecture from TorchVision
    """
    def __init__(self, 
                 embed_dim: int, 
                 num_heads: int, 
                 mlp_ratio: float = 4.0, 
                 dropout: float = 0.0, 
                 attention_dropout: float = 0.0):
        super(TransformerEncoderBlock, self).__init__()
        
        # Layer Normalization before self-attention (pre-norm architecture)
        self.norm1 = nn.LayerNorm(embed_dim, eps=1e-6)
        
        # Multi-Head Self-Attention using PyTorch's built-in implementation
        self.attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=attention_dropout,
            batch_first=True
        )
        
        # Dropout after attention
        self.dropout = nn.Dropout(dropout)
        
        # Layer Normalization before MLP (pre-norm architecture)
        self.norm2 = nn.LayerNorm(embed_dim, eps=1e-6)
        
        # MLP block
        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = MLP(
            in_features=embed_dim,
            hidden_features=mlp_hidden_dim,
            out_features=embed_dim,
            dropout=dropout
        )
                
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Apply Layer Norm -> Self-Attention -> Residual
        h = self.norm1(x)  # normalize once and reuse as query, key, and value
        attn_output, _ = self.attn(h, h, h, need_weights=False)
        x = x + self.dropout(attn_output)
        
        # Apply Layer Norm -> MLP -> Residual
        x = x + self.mlp(self.norm2(x))  # Note: MLP already includes dropout so we don't need to apply it again
        
        return x

> **Note**: PyTorch's `nn.LayerNorm` initializes its parameters as follows:
> - The scale parameter weight ($\gamma$) is initialized to 1.0
> - The shift parameter bias ($\beta$) is initialized to 0.0
> - So we only have to set $\epsilon$ to 1e-6 (PyTorch's default is 1e-5)

#### 5.4 Final Architecture

Now we can put it all together:

In [None]:
class VisionTransformer(nn.Module):
    """
    Vision Transformer (ViT) model
    """
    def __init__(self, 
                 img_size: int = 224, 
                 patch_size: int = 16, 
                 in_chans: int = 3, 
                 num_classes: int = 1000,
                 embed_dim: int = 768, 
                 depth: int = 12, 
                 num_heads: int = 12, 
                 mlp_ratio: float = 4.0,  # hidden layer size = embed_dim * mlp_ratio
                 dropout: float = 0.0,
                 attention_dropout: float = 0.0,):
        super(VisionTransformer, self).__init__()
        
        # Patch Embedding
        self.patch_embed = PatchEmbed(
            img_size=img_size, 
            patch_size=patch_size, 
            in_chans=in_chans, 
            embed_dim=embed_dim
        )
        num_patches = self.patch_embed.num_patches
        
        # Class token
        self.class_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.normal_(self.class_token, std=0.02)
        
        # Positional embedding
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        nn.init.normal_(self.pos_embed, std=0.02)
        
        # Dropout
        self.pos_dropout = nn.Dropout(dropout)
        
        # Transformer Encoder Blocks
        self.blocks = nn.ModuleList([
            TransformerEncoderBlock(
                embed_dim=embed_dim, 
                num_heads=num_heads, 
                mlp_ratio=mlp_ratio, 
                dropout=dropout,
                attention_dropout=attention_dropout
            )
            for _ in range(depth)
        ])
        
        # Layer Norm before classification head
        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
        
        # Classification head if valid number of classes is provided
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        
    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        # Get batch size
        B = x.shape[0]
        
        # Convert image to patch embeddings
        x = self.patch_embed(x)
        
        # Prepend class token
        cls_token = self.class_token.expand(B, -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        
        # Add positional embeddings
        x = x + self.pos_embed
        x = self.pos_dropout(x)
        
        # Apply transformer blocks
        for block in self.blocks:
            x = block(x)
        
        # Apply final normalization
        x = self.norm(x)
        
        # Return class token representation
        return x[:, 0]
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Get features from the class token
        x = self.forward_features(x)
        
        # Final classification
        x = self.head(x)
        
        return x
    
    def save_model(self, path: str):
        folder_name, _ = os.path.split(path)
        if folder_name:
            os.makedirs(folder_name, exist_ok=True)
        torch.save(self.state_dict(), path)
        
    def load_model(self, path: str):
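        # Load onto the CPU first; load_state_dict then copies into the model's parameters on its device
        self.load_state_dict(torch.load(path, map_location="cpu"))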

#### 5.5 Overview 

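There are different variants of the ViT that differ in width, depth, and patch size; the most common ones are listed below. The parameter counts are rounded and vary slightly with the patch size, because the size of the patch projection and the number of positional embeddings depend on it. ViT-S is not part of [[1](#6-references)] but was introduced in later work, and the Huge model in [[1](#6-references)] uses $14\times 14$ patches (ViT-H/14).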


| Model     | Patch Size | Embedding Dim | Heads | Blocks | MLP Size | Parameters |
|-----------|------------|---------------|-------|--------|----------|------------|
| ViT-S/8   | 8×8        | 384           | 6     | 12     | 1536     | 22M        |
| ViT-S/16  | 16×16      | 384           | 6     | 12     | 1536     | 22M        |
| ViT-S/32  | 32×32      | 384           | 6     | 12     | 1536     | 22M        |
| ViT-B/8   | 8×8        | 768           | 12    | 12     | 3072     | 86M        |
| ViT-B/16  | 16×16      | 768           | 12    | 12     | 3072     | 86M        |
| ViT-B/32  | 32×32      | 768           | 12    | 12     | 3072     | 86M        |
| ViT-L/8   | 8×8        | 1024          | 16    | 24     | 4096     | 307M       |
| ViT-L/16  | 16×16      | 1024          | 16    | 24     | 4096     | 307M       |
| ViT-L/32  | 32×32      | 1024          | 16    | 24     | 4096     | 307M       |
| ViT-H/8   | 8×8        | 1280          | 16    | 32     | 5120     | 632M       |
| ViT-H/16  | 16×16      | 1280          | 16    | 32     | 5120     | 632M       |
| ViT-H/32  | 32×32      | 1280          | 16    | 32     | 5120     | 632M       |
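The parameter column can be checked by counting the tensors of the architecture from section 5.4 layer by layer. A minimal sketch, assuming a 1000-class linear head, biases in every linear layer (the default for `nn.MultiheadAttention` and `nn.Linear`), and a learnable class token and positional embedding. The number of heads does not appear because splitting into heads adds no parameters; `vit_param_count` is a hypothetical helper name.

```python
def vit_param_count(embed_dim: int, depth: int, mlp_dim: int, patch_size: int,
                    img_size: int = 224, in_chans: int = 3, num_classes: int = 1000) -> int:
    D, M = embed_dim, mlp_dim
    num_patches = (img_size // patch_size) ** 2
    patch_embed = in_chans * patch_size ** 2 * D + D       # conv weight + bias
    cls_and_pos = D + (num_patches + 1) * D                # class token + positional embeddings
    attn = (3 * D * D + 3 * D) + (D * D + D)               # q/k/v input projection + output projection
    mlp = (D * M + M) + (M * D + D)                        # two linear layers
    norms = 2 * (2 * D)                                    # two LayerNorms per block (weight + bias)
    final_norm = 2 * D
    head = D * num_classes + num_classes if num_classes > 0 else 0
    return patch_embed + cls_and_pos + depth * (attn + mlp + norms) + final_norm + head

for name, D, L, M in [("ViT-S/16", 384, 12, 1536), ("ViT-B/16", 768, 12, 3072), ("ViT-L/16", 1024, 24, 4096)]:
    print(f"{name}: {vit_param_count(D, L, M, patch_size=16):,} parameters")
```

For ViT-B/16 this gives 86,567,656 parameters, about 85M of which sit in the 12 encoder blocks. With 32×32 patches the patch projection grows, giving 88,224,232 for ViT-B/32, so the counts differ slightly between patch sizes of the same family.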

#### 5.6 Preview

Understanding how information flows through Vision Transformers compared to CNNs is crucial for interpreting their behavior. The concept of "attention distance" in ViTs is analogous to receptive field size in CNNs but shows fundamentally different patterns of information integration.

<div align="center">
    <img src="figures/attention_metrics.PNG" width="500"/>
    <p><i>Figure 6: Attention distance of ViT-L/32 model. Source: [1]</i></p>
</div>

Unlike CNNs, where receptive fields grow gradually through the network hierarchy, attention heads in ViTs display a unique pattern of information processing (as shown in Figure 6):

- **Parallel local-global processing**: Even in early layers, ViT has some attention heads with very large attention distances (processing global context) while others focus locally, enabling simultaneous processing of both broader context and fine details.
- **Heterogeneous attention specialization**: Different attention heads within the same layer specialize in different spatial relationships, with some focusing on adjacent patches and others connecting distant regions of the image.
- **Rapid transition to global processing**: By the middle layers, most attention heads operate globally, unlike CNNs which require many more layers to achieve the same receptive field size.

This multi-scale parallel processing appears to be a key advantage of the Transformer architecture, allowing it to build rich representations that combine local features with global context in ways fundamentally different from the strictly hierarchical processing in CNNs. For a detailed overview, check out ["Do Vision Transformers See Like Convolutional Neural Networks?"](https://proceedings.neurips.cc/paper_files/paper/2021/hash/652cf38361a209088302ba2b8b7f51e0-Abstract.html).
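Figure 6 plots, per head and layer, the average distance between a query and the positions it attends to, weighted by the attention weights. A minimal patch-level sketch of that metric, assuming attention weights of shape `[heads, N, N]` over the $N$ patch tokens only (class token removed and rows renormalized to sum to 1) and distances measured in pixels between patch positions. The function name is hypothetical; per-head weights can be obtained from `nn.MultiheadAttention` with `need_weights=True, average_attn_weights=False`.

```python
import torch

def mean_attention_distance(attn: torch.Tensor, grid_h: int, grid_w: int,
                            patch_size: int) -> torch.Tensor:
    """attn: [heads, N, N] with N = grid_h * grid_w and rows summing to 1.
    Returns the attention-weighted query-key distance in pixels, averaged over queries: [heads]."""
    rows, cols = torch.meshgrid(torch.arange(grid_h), torch.arange(grid_w), indexing="ij")
    # Patch positions in pixels, in the same raster order as the token sequence
    coords = torch.stack([rows.flatten(), cols.flatten()], dim=-1).float() * patch_size  # [N, 2]
    dist = (coords[:, None, :] - coords[None, :, :]).norm(dim=-1)  # [N, N] Euclidean distances
    per_query = (attn * dist).sum(dim=-1)                          # [heads, N] expected distance
    return per_query.mean(dim=-1)                                  # [heads]

# Two toy heads on a 14x14 grid: one attends only to itself, one attends uniformly everywhere
N = 14 * 14
attn = torch.stack([torch.eye(N), torch.full((N, N), 1.0 / N)])
print(mean_attention_distance(attn, 14, 14, patch_size=16))  # first value 0, second much larger
```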

#### 5.7 Summary

> **Are Vision Transformers inherently better than CNNs?**

When we examine the original ViT [[1](#6-references)], we find that the answer isn't straightforward. Without large-scale pre-training, ViT needs an enormous dataset to outperform CNNs. As shown in Figure 7, only when pre-trained on JFT-300M (303M images, not publicly available) does ViT slightly outperform the CNNs. For "smaller" datasets, CNNs still maintain an advantage.

<div align="center">
    <img src="figures/comparison_CNNs.png" width="1000"/>
    <p><i>Figure 7: Scaling of performance between ResNets (=BiT) and ViTs with dataset size. Source: [1]</i></p>
</div>

> **But why is the ViT so widely used nowadays?**

There are several key factors that have driven ViT adoption despite the data-hungry nature of the architecture:

1. **Pre-trained foundation models**: Large tech companies like Google and Meta have released pre-trained ViT checkpoints (for example, trained on ImageNet-21k). These models transfer exceptionally well: fine-tuned on smaller datasets, they perform significantly better than models trained from scratch. Research shows ViTs can outperform CNNs in transfer learning scenarios, particularly as model size increases ([Steiner et al., 2021](https://arxiv.org/abs/2106.10270)).

2. **Architectural improvements**: Many variants of the original ViT have been developed to improve efficiency and performance on smaller datasets:
   - **DeiT** (Data-efficient image Transformers): Trains competitively on ImageNet-1k alone by combining strong data augmentation with knowledge distillation from a CNN teacher
   - **Swin Transformer**: Introduces hierarchical structure and local attention to improve efficiency (see [section 5.1](#51-multi-head-self-attention-msa))
   - **CCT** (Compact Convolutional Transformer): Combines convolutional layers with Transformers for better inductive bias. See our Tutorial 4.2 for detailed information!
   - Check out [Han et al. (2022)](https://ieeexplore.ieee.org/abstract/document/9716741) for a comprehensive survey on Vision Transformers

3. **Interpretability**: ViTs provide a more transparent view of their decision-making process through attention maps, which help us understand which parts of an image influence a classification decision. In [[1](#6-references)], the authors compute the attention maps of the ViT-L/32 model via Attention Rollout [[4](#6-references)].

<div align="center">
    <img src="figures/attention_map.PNG" width="250"/>
    <p><i>Figure 8: Attention maps of ViT-L/32 model. Source: [1]</i></p>
</div>

Implementing attention rollout is straightforward once you have access to the model's attention weights, which makes it a nice exercise to implement yourself (see Exercise 2).

> **But I want to learn from scratch!**

If you are interested in training a ViT from scratch, check out ["How to train your ViT?"](https://arxiv.org/abs/2106.10270) for suitable data augmentation and other tricks. For example, you can also use DropPath (stochastic depth), as in our ResNet Tutorial (2.4), as a regularization technique.
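DropPath randomly skips the residual branch of a block for entire samples during training, so the block acts as an identity for those samples. Below is a minimal sketch of the idea, assuming the common convention of rescaling the kept samples by `1 / keep_prob` so the expected output is unchanged. It is not the implementation from our ResNet tutorial; `timm` already provides a ready-made `DropPath` layer (`timm.layers.DropPath`). The attribute names in the usage comment are hypothetical.

```python
import torch
import torch.nn as nn

class DropPath(nn.Module):
    """Stochastic depth: drop the residual branch for whole samples (minimal sketch)."""

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        if not 0.0 <= drop_prob < 1.0:
            raise ValueError("drop_prob must be in [0, 1)")
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.drop_prob == 0.0:
            return x  # identity at inference time
        keep_prob = 1.0 - self.drop_prob
        # one Bernoulli draw per sample, broadcast over all remaining dimensions
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = torch.empty(shape, dtype=x.dtype, device=x.device).bernoulli_(keep_prob)
        return x * mask / keep_prob  # rescale so the expected value is unchanged

# Usage inside a pre-norm Transformer block (hypothetical attribute names):
#   x = x + self.drop_path(self.attn(self.norm1(x)))
#   x = x + self.drop_path(self.mlp(self.norm2(x)))
```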

## 6. References

[1] [Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., ... & Houlsby, N. (2020). An image is worth 16x16 words: Transformers for image recognition at scale. *International Conference on Learning Representations (ICLR)*.](https://arxiv.org/abs/2010.11929)

[2] [https://www.v7labs.com/blog/vision-transformer-guide/](https://www.v7labs.com/blog/vision-transformer-guide/)

[3] [https://medium.com/@sanjithkumar986/vision-transformers-1-19e4c052aab9](https://medium.com/@sanjithkumar986/vision-transformers-1-19e4c052aab9)

[4] [Abnar, S., & Zuidema, W. (2020). Quantifying attention flow in transformers. arXiv preprint arXiv:2005.00928.](https://arxiv.org/abs/2005.00928)

## Training and evaluation of ViT on Imagenette

Let's put this into action and train the ViT on a dataset. You can adjust hyperparameters such as the learning rate, batch size, and number of epochs in the code below.

In [None]:
import torchvision.transforms.v2 as v2
from Utils.dataloaders import prepare_imagenette

# define hyperparameters
batch_size = 256
patch_size = 32
num_workers = 4

transform_augm = v2.Compose([
    v2.ToImage(),
    # Core transformations
    v2.RandomResizedCrop(size=224, scale=(0.75, 1.0), ratio=(0.9, 1.05)),
    v2.RandomHorizontalFlip(p=0.5),  # objects can face either direction
    v2.RandomRotation(degrees=(-10, 10)),  # small rotations
    
    # Lighting and appearance variations
    v2.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.1, hue=0.05),
    v2.RandomAutocontrast(p=0.2),
    
    # Occasional realistic variations, each applied with its own probability
    v2.RandomApply([v2.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0))], p=0.3),
    v2.RandomAdjustSharpness(sharpness_factor=1.5, p=0.3),
    v2.RandomPerspective(distortion_scale=0.15, p=0.3),
    v2.RandomErasing(p=0.1, scale=(0.02, 0.08), ratio=(0.3, 3.3)),
    
    # Convert to float in [0, 1] and normalize with ImageNet statistics
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

transform_norm = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Resize(size=(224, 224)),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the Imagenette dataset
train_loader, val_loader, classes = prepare_imagenette(train_compose=transform_augm, 
                                                       test_compose=transform_norm, 
                                                       save_path='../Dataset',
                                                       batch_size=batch_size, 
                                                       num_workers=num_workers)


In [None]:
dataiter = iter(train_loader)
images, labels = next(dataiter)

num_classes = len(classes)
img_size = images.shape[3]
in_chans = images.shape[1]

print(f"Number of classes: {num_classes}")
print(f"Image size: {img_size}")
print(f"Number of channels: {in_chans}")

In [None]:
# init ViT model
vit_model = VisionTransformer(
    img_size=img_size,
    patch_size=patch_size,
    in_chans=in_chans,
    num_classes=num_classes,
    embed_dim=384,  # ViT-S size; ViT-B uses 768
    depth=12,  # same as ViT-S and ViT-B
    num_heads=6,  # ViT-S size; ViT-B uses 12
    mlp_ratio=4.0,  # same as ViT-S and ViT-B
    dropout=0.1,
    attention_dropout=0.1,
)

# model summary
print(vit_model)

# print number of trainable parameters
from Utils.little_helpers import get_parameters
print(f"Number of trainable parameters: {get_parameters(vit_model):.3f}M")

Similar to [[1](#6-references)], we use a learning rate schedule with a linear warm-up followed by cosine annealing decay. The learning rate is increased linearly from `warmup_lr_init` to `base_lr` over `warmup_epochs` epochs and is then decreased to `min_lr` using cosine annealing. See notebook '04_1_WarmUpSchedular.ipynb' for a detailed explanation of the learning rate scheduler.
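The warm-up plus cosine schedule can be written in closed form, which is handy for sanity-checking a scheduler or plotting the learning rate before training. This is a minimal sketch: the function name `warmup_cosine_lr` is hypothetical, its default arguments mirror the hyperparameters in the next cell, and it assumes the scheduler is stepped once per epoch with the cosine phase spanning the `num_epochs - warmup_epochs` epochs after warm-up.

```python
import math

def warmup_cosine_lr(epoch: int, num_epochs: int = 40, warmup_epochs: int = 4,
                     warmup_lr_init: float = 1e-6, base_lr: float = 1e-4,
                     min_lr: float = 1e-7) -> float:
    """Learning rate used during `epoch` (0-indexed): linear warm-up, then cosine decay."""
    if epoch < warmup_epochs:
        # linear interpolation from warmup_lr_init (epoch 0) towards base_lr
        return warmup_lr_init + (base_lr - warmup_lr_init) * epoch / warmup_epochs
    # cosine annealing from base_lr (end of warm-up) down to min_lr (epoch == num_epochs)
    t = epoch - warmup_epochs
    T = num_epochs - warmup_epochs
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t / T))

for e in (0, 2, 4, 22, 39):
    print(e, f"{warmup_cosine_lr(e):.2e}")
```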

In [None]:
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

warmup_epochs = 4  # 5-10% of training epochs
num_epochs = 40 

warmup_lr_init = 1e-6
base_lr = 1e-4
min_lr = 1e-7

optimizer = optim.AdamW(vit_model.parameters(), lr=base_lr, weight_decay=0.05)  # test different values for weight decay (0.05-0.1) and learning rate (1e-3 - 1e-4)

# linear warm-up: lr rises from warmup_lr_init to base_lr over warmup_epochs
warmup_scheduler = LinearLR(
    optimizer, 
    start_factor=warmup_lr_init/base_lr,
    end_factor=1.0, 
    total_iters=warmup_epochs
)

# cosine decay: lr falls from base_lr to min_lr over the remaining epochs
cosine_scheduler = CosineAnnealingLR(
    optimizer, 
    T_max=num_epochs - warmup_epochs, 
    eta_min=min_lr
)

# run the schedulers one after the other (ChainedScheduler would apply both at every step);
# assumes scheduler.step() is called once per epoch
scheduler = SequentialLR(
    optimizer,
    schedulers=[warmup_scheduler, cosine_scheduler],
    milestones=[warmup_epochs],
)


In [None]:
from Utils.functions import train_model

# folder for results and model checkpoints
results_folder = 'vit_results/'
os.makedirs(results_folder, exist_ok=True)

# train model
with timer("Training process"):
    history = train_model(model=vit_model,
                          train_loader=train_loader,
                          val_loader=val_loader,
                          criterion=nn.CrossEntropyLoss(label_smoothing=0.0),  # set label_smoothing > 0 to reduce overfitting; typical values are 0.01 to 0.2
                          optimizer=optimizer,
                          scheduler=scheduler,
                          patience=5,  # number of epochs to wait before early stopping 
                          monitor='val_loss',  # metric to monitor for early stopping. If this metric doesn't improve for the last `patience` epochs, early stopping is triggered
                          device=device,
                          num_epochs=num_epochs,
                          checkpoint_path=results_folder,  # path to save the "best" model with best value of `monitor`
                          )

In [None]:
from Utils.plotting import visualize_training_results

# save history
np.save(os.path.join(results_folder, 'history.npy'), history, allow_pickle=True)

visualize_training_results(train_losses=history['train_loss'],
                           train_accs=history['train_acc'],
                           test_losses=history['val_loss'],
                           test_accs=history['val_acc'],
                           output_dir=None)


In [None]:
from Utils.functions import test_model
from Utils.plotting import visualize_test_results

# evaluate model on validation set
with timer("Evaluating process"):
    aggregate_df, per_image_df, overall_accuracy = test_model(model=vit_model,
                                                              test_loader=val_loader,
                                                              device=device,
                                                              class_names=classes,
                                                              print_per_class_summary=False)


# save dataframes as parquet (requires pyarrow or fastparquet); fall back to pickle otherwise
try:
    aggregate_df.to_parquet(os.path.join(results_folder, 'aggregate_df.parquet'))
    per_image_df.to_parquet(os.path.join(results_folder, 'per_image_df.parquet'))
except ImportError:
    aggregate_df.to_pickle(os.path.join(results_folder, 'aggregate_df.pkl'))
    per_image_df.to_pickle(os.path.join(results_folder, 'per_image_df.pkl'))
    
visualize_test_results(aggregate_df=aggregate_df,
                       per_image_df=per_image_df,
                       overall_accuracy=overall_accuracy,
                       max_classes_display=10,
                       output_dir=None)


In [None]:
# show patch embeddings
from Utils.plotting import visualize_patch_embeddings

visualize_patch_embeddings(model=vit_model,
                           num_components=28,
                           output_dir=None)

In [None]:
# show position embeddings
from Utils.plotting import visualize_position_embeddings

visualize_position_embeddings(model=vit_model, output_dir=None)


As our results show, the ViT performs worse than ResNet and even simple CNNs when trained from scratch on our dataset. Its patch and positional embeddings also differ clearly from those shown in [[1](#6-references)]. The author of this notebook wants to make clear that this is not due to a flawed implementation, but simply because our dataset is not large and diverse enough to learn meaningful representations.

This observation aligns with [[1](#6-references)]: Vision Transformers lack the inductive biases that are inherently built into CNNs. Convolutional architectures have translation equivariance and locality baked into their design, which helps them learn efficiently from limited data. In contrast, our ViT must learn these spatial relationships entirely from the data itself.
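The translation equivariance of convolutions can be checked numerically: shifting the input and then convolving gives the same result as convolving and then shifting the output. The sketch below uses circular padding (an assumption made so that the identity also holds exactly at the image borders) and an arbitrary layer size. A ViT has no such built-in guarantee: a shift that is not a multiple of the patch size changes which pixels end up in each patch, and its learned position embeddings are tied to absolute positions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Circular padding makes the equivariance exact at the image borders
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, padding_mode="circular")

x = torch.randn(1, 3, 32, 32)
shift = (5, -3)  # shift by 5 rows down and 3 columns left

with torch.no_grad():
    out_shift_first = conv(torch.roll(x, shifts=shift, dims=(2, 3)))  # shift, then convolve
    out_conv_first = torch.roll(conv(x), shifts=shift, dims=(2, 3))   # convolve, then shift

print(torch.allclose(out_shift_first, out_conv_first, atol=1e-5))  # True
```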

Without these built-in structural priors, Transformers require substantially more training examples to achieve comparable performance. This is why the original ViT paper demonstrated that these models only outperform CNNs when pre-trained on very large datasets like JFT-300M. On smaller datasets, the lack of inductive bias becomes a disadvantage, resulting in poorer generalization. 

This trade-off represents a fundamental principle in deep learning: models with fewer inductive biases can potentially learn more flexible representations, but only when provided with sufficient data to compensate for the absence of built-in priors. Pre-trained ViT models leverage this by learning general visual representations from massive datasets before being fine-tuned on smaller, task-specific datasets.

## 1. Exercise: Visualize the patch and positional embeddings of a pre-trained ViT

In this exercise, you will visualize the patch and positional embeddings of a pre-trained Vision Transformer (ViT) model from `timm`.

In [None]:
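# Remove all references to the model, optimizer and schedulers (the schedulers hold a reference to the optimizer)
del vit_model, optimizer, scheduler, warmup_scheduler, cosine_scheduler

# Garbage collect to free memory
import gc
gc.collect()

if device.type == 'cuda':
    # Release cached GPU memory back to the driver
    torch.cuda.empty_cache()
    print(torch.cuda.memory_summary())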

In [None]:
import timm

# Load a pretrained ViT model
# You can change this to any other ViT variant. Have a look at https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py or
# https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/tiny_vit.py
model_name = 'vit_base_patch32_224.orig_in21k'  

# initialize the model with a new classification head for the Imagenette classes
pretrained_model = timm.create_model(model_name, pretrained=True, num_classes=num_classes)

# TODO: Implement training and eval pipelines. Tip: Use the `train_model` and `test_model` functions from the Utils folder, as well as the already initialized `train_loader` and `val_loader` from previous sections.

In [None]:
from Utils.plotting import visualize_patch_embeddings, visualize_position_embeddings

# TODO: Visualize the patch and position embeddings with the `visualize_patch_embeddings` and `visualize_position_embeddings` functions from the Utils folder.

> Note: These functions use predefined attributes to access the patch and positional embeddings (like `model.patch_embed` and `model.pos_embed`). These attributes might be named differently in other libraries (like `torchvision`), so check the model architecture and adapt the code accordingly if you are using a library other than `timm`.
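As an illustration of what can be read off the position embeddings, the sketch below computes the cosine similarity between the position embedding of each patch and those of all other patches, similar to the position-embedding similarity plots in [[1](#6-references)]. `pos_embed_similarity` is a hypothetical helper, not the implementation of `visualize_position_embeddings`. It assumes `pos_embed` has shape `(1, num_prefix_tokens + H*W, D)` with a square patch grid, as in `timm` ViTs that store a position embedding for the class token; for models without one, pass `num_prefix_tokens=0`.

```python
import torch
import torch.nn.functional as F

def pos_embed_similarity(pos_embed: torch.Tensor, num_prefix_tokens: int = 1) -> torch.Tensor:
    """Cosine similarity between the position embeddings of all patch pairs.

    pos_embed: tensor of shape (1, num_prefix_tokens + H*W, D).
    Returns a tensor of shape (H, W, H, W): entry [i, j] is the similarity map of patch (i, j).
    """
    patch_pos = pos_embed[0, num_prefix_tokens:]  # drop prefix (class) token(s) -> (H*W, D)
    n = patch_pos.shape[0]
    grid = int(round(n ** 0.5))
    if grid * grid != n:
        raise ValueError("expected a square patch grid")
    patch_pos = F.normalize(patch_pos, dim=-1)  # unit-length rows
    sim = patch_pos @ patch_pos.T               # (H*W, H*W) cosine similarities
    return sim.reshape(grid, grid, grid, grid)

# Usage with the timm model from above (pos_embed of vit_base_patch32_224 has shape (1, 50, 768)):
#   sim = pos_embed_similarity(pretrained_model.pos_embed.detach().cpu(), num_prefix_tokens=1)
#   plt.imshow(sim[3, 3]); plt.colorbar()  # similarity of patch (3, 3) to every other patch
```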

> After the first run, try increasing the weight decay in the optimizer, re-initialize the pre-trained model, and train again. Look at the positional embeddings. What do you observe?

> Imagenette is a subset of ImageNet, on which your pre-trained ViT model was probably trained. Most pre-trained models achieve an accuracy of around 85% on ImageNet. So why does your model start with a lower accuracy than that? What could be the reasons?

## 2. Exercise: Plot the attention maps of the pre-trained ViT for a random Imagenette image

In this exercise, you will implement the attention rollout technique to visualize how a pre-trained Vision Transformer (ViT) attends to different parts of an image. This technique, introduced by [[4](#6-references)], helps us understand how information propagates through the self-attention layers of a Transformer.

You will:

- Use the pre-trained ViT model from `timm` from Exercise 1
- Recompute the attention from all layers (use the `attention_hook` function). Note that you can also use the `AttentionExtractor` class from [`timm`](https://github.com/huggingface/pytorch-image-models/blob/main/timm/utils/attention_extract.py).
- Implement the attention rollout algorithm **yourself** (see [[4](#6-references)] for details)
- Visualize attention maps on random Imagenette images
- Compare raw attention to attention rollout

### Function to extract attention weights

In [None]:
def get_attention(model: nn.Module, 
                  img: torch.Tensor,) -> List[torch.Tensor]:
    
    attention_maps = []
    def attention_hook(module: nn.Module, 
                       input: Tuple[torch.Tensor, ...], 
                       output: torch.Tensor) -> torch.Tensor:
        # Executed after the forward pass of each attention module this hook is registered on.
        # timm's attention modules do not return their attention weights, so we recompute them here.
        
        # Get the input to the attention module
        x = input[0]
        
        # Get qkv projections
        qkv = module.qkv(x)
        B, N, C = x.shape
        
        # Reshape qkv to separate heads
        qkv = qkv.reshape(B, N, 3, module.num_heads, C // module.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # Unbind along first dimension
        
        # Calculate attention weights
        attn = (q @ k.transpose(-2, -1)) * module.scale
        attn = attn.softmax(dim=-1)
        
        # Store the attention weights
        attention_maps.append(attn.detach())
        
        # Let the original forward pass continue
        return output
    
    # Register hooks for all attention modules
    hooks = []
    for block in model.blocks:
        hooks.append(block.attn.register_forward_hook(attention_hook))
    
    # Forward pass through the model
    with torch.no_grad():
        model(img)
    
    # Remove hooks after forward pass
    for hook in hooks:
        hook.remove()
    
    # Post-process attention maps
    # TODO: This is your task now!
    # 1. Average or max over the attention heads (have a look at how the approaches differ)
    # 2. Add the residual connection (0.5 * attention + 0.5 * identity)
    # 3. Re-normalize the attention maps so that each row sums to 1
    
    processed_maps = []
    for attn in attention_maps:
        pass
    return processed_maps

In [None]:
def attention_rollout(attention_maps: List[torch.Tensor],) -> torch.Tensor:
    # TODO: Recursively compute rollout through all layers
    # A_roll(l) = A(l) @ A_roll(l-1)   (@ = matrix multiplication)
    # normalize the final attention map
    rollout = None
    return rollout

In [None]:
def visualize_attention(img_tensor: torch.Tensor, 
                        attention_map: torch.Tensor, 
                        save_path: Optional[str] = None):
    
    # TODO: Implement the visualization of attention maps
    # Remember to resize the attention map to the original image size. Use a bilinear interpolation or something similar
    # Also remember to denormalize the image for visualization
    if save_path:
        plt.savefig(save_path)
    else:
        plt.show()

In [None]:
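# TODO: Get a random image from the dataset
# For example, you can use the first image of a batch; keep the batch dimension, e.g. `img = images[:1]`

# Extract attention matrices
pretrained_model.eval()  # disable dropout
pretrained_model.to("cpu")  # Move model to CPU to make data processing easier (img must be on the CPU as well)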

# TODO: Implement post-processing of attention matrices
attention_matrices = get_attention(pretrained_model, img)

# TODO: Implement attention rollout
rollout = attention_rollout(attention_matrices)

# Visualize raw attention vs attention rollout
# First, visualize raw attention from the last layer
raw_attention = attention_matrices[-1]
visualize_attention(img, raw_attention, save_path=None)

# Then, visualize attention rollout
visualize_attention(img, rollout, save_path=None)


Your main task is to complete the `get_attention`, `attention_rollout` and `visualize_attention` functions. 

For a ViT with L layers, attention rollout is computed as:
```
A_rollout(0) = A(0)  # First layer's attention
A_rollout(i) = A(i) @ A_rollout(i-1)  # For i=1 to L-1
```

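Here, `@` denotes matrix multiplication and A(i) is the processed attention matrix of layer i (averaged over heads, mixed with the identity to account for the residual connection, and re-normalized). The idea is to track how attention propagates from the input tokens through all layers of the Transformer.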
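To see the recursion in isolation, here is a toy sketch on made-up 3×3 matrices (think of the tokens as [CLS] and two patches). It assumes the per-layer matrices have already been averaged over heads, and it applies the 0.5/0.5 residual mixing and row re-normalization from the `get_attention` TODO list. Hooking into the ViT, fusing the heads, and mapping the [CLS] row onto the patch grid are still left to you. The function name `rollout_toy` is hypothetical.

```python
import torch
from typing import List

def rollout_toy(attn_per_layer: List[torch.Tensor]) -> torch.Tensor:
    """Attention rollout for (N, N) head-averaged attention matrices, first layer first."""
    n = attn_per_layer[0].shape[-1]
    eye = torch.eye(n)
    rollout = None
    for attn in attn_per_layer:
        a = 0.5 * attn + 0.5 * eye           # account for the residual connection
        a = a / a.sum(dim=-1, keepdim=True)  # re-normalize rows (matters e.g. after max-fusion of heads)
        rollout = a if rollout is None else a @ rollout  # A_rollout(i) = A(i) @ A_rollout(i-1)
    return rollout

# two layers, three tokens ([CLS], patch 1, patch 2); every row sums to 1
A0 = torch.tensor([[0.2, 0.4, 0.4],
                   [0.1, 0.8, 0.1],
                   [0.1, 0.1, 0.8]])
A1 = torch.tensor([[0.0, 1.0, 0.0],
                   [0.5, 0.5, 0.0],
                   [0.0, 0.0, 1.0]])
R = rollout_toy([A0, A1])
print(R)
print(R.sum(dim=-1))  # rows still sum to 1
print(R[0, 1:])       # how strongly [CLS] in the last layer draws on each input patch
```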

In [[4](#6-references)] the authors also propose attention flow, which is a different way to visualize attention. If you are interested in this, you can implement it as well. The main difference is that attention flow treats the attention graph as a flow network with capacities, using maximum flow algorithms from graph theory instead of matrix multiplication. While attention rollout multiplies attention weights along paths (treating them as proportion factors), attention flow finds the maximum possible flow from any layer to input tokens by treating attention weights as capacity constraints. This results in more distributed attention patterns that better correlate with ground truth importance scores, though at higher computational cost ($O(d^2 \cdot n^4)$ vs $O(d \cdot n^2)$ for rollout).