<a target="_blank" href="https://colab.research.google.com/github/mlde-ms/vision-transformer-from-scratch.git">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Setup

In [1]:
import math
import matplotlib.pyplot as plt
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, NamedTuple, Optional

from PIL import Image
from torchvision import transforms, datasets

import torch
import torch.nn as nn

from torchinfo import summary

import mnist_vit

# VisionTransformer from Scratch

This notebook demonstrates how the popular Vision Transformer (ViT) architecture works. It was presented in the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) and is also implemented in [`torchvision`](https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py), which this notebook uses as inspiration.

First, we explore the modules used in the ViT, successively combining them into larger parts until we arrive at a minimal working version. Along the way, we look at the *code*, at *graphics* representing the architecture, at the *math* behind it, and at the transformations the *tensor* flowing through the network undergoes to produce the output. Second, we define a tiny ViT to solve the toy *example* of [MNIST](https://yann.lecun.com/exdb/mnist/) digit classification.

## The General MLP Module

[Source Code](https://github.com/pytorch/vision/blob/main/torchvision/ops/misc.py)

<img src="TODO">

In [2]:
class MLP(nn.Sequential):
    """Multi-layer perceptron (MLP) module."""

    def __init__(
        self,
        in_channels: int,
        hidden_channels: List[int],
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        activation_layer: Optional[Callable[..., nn.Module]] = nn.ReLU,
        bias: bool = True,
        dropout: float = 0.0,
    ):
        params = {}

        layers = []
        in_dim = in_channels
        for hidden_dim in hidden_channels[:-1]:
            layers.append(nn.Linear(in_dim, hidden_dim, bias=bias))
            if norm_layer is not None:
                layers.append(norm_layer(hidden_dim))
            layers.append(activation_layer(**params))
            layers.append(nn.Dropout(dropout, **params))
            in_dim = hidden_dim  # update input dimension for next layer

        layers.append(nn.Linear(in_dim, hidden_channels[-1], bias=bias))
        layers.append(nn.Dropout(dropout, **params))

        super().__init__(*layers)
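To see what `MLP` actually assembles, here is a standalone, hand-built `nn.Sequential` with the same layer order that `MLP(16, [32, 8], dropout=0.1)` would produce with the defaults (no norm layer, ReLU). The sizes 16, 32 and 8 are arbitrary illustration values.

```python
import torch
import torch.nn as nn

# Layer order produced by MLP(16, [32, 8], dropout=0.1) with the defaults:
# every entry of hidden_channels except the last gets Linear -> activation -> Dropout,
# the last entry gets only Linear -> Dropout.
equivalent_mlp = nn.Sequential(
    nn.Linear(16, 32),   # hidden layer
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(32, 8),    # final projection to hidden_channels[-1]
    nn.Dropout(0.1),
)

x = torch.randn(4, 16)   # batch of 4 vectors with 16 features
y = equivalent_mlp(x)
print(y.shape)           # torch.Size([4, 8])
```

Note that there is no activation after the last linear layer, so the module can output arbitrary real values.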

## Transformer MLP Block

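The Transformer MLP Block is a special case of the General MLP Module: it always has exactly two linear layers, a hidden layer of width `mlp_dim` followed by an output layer that projects back to `in_dim`, with a GELU activation in between. Its weights are initialized specially (Xavier-uniform weights, near-zero biases). Usually `mlp_dim` = 4 $\cdot$ `in_dim`.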

<img src="TODO">

In [3]:
class MLPBlock(MLP):
    """Transformer MLP block."""
    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
        super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, dropout=dropout)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)
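The initialization in `MLPBlock` can be checked directly. This standalone sketch recreates the two linear layers of a block with `in_dim=8` and `mlp_dim=32` (the usual 4× ratio; the sizes are only for illustration), applies the same `xavier_uniform_` / `normal_(std=1e-6)` initialization, and confirms that the weights stay within the Xavier-uniform bound $\sqrt{6 / (\text{fan\_in} + \text{fan\_out})}$ and that the biases are close to zero.

```python
import math
import torch
import torch.nn as nn

in_dim, mlp_dim = 8, 32  # illustrative sizes, mlp_dim = 4 * in_dim

# The two linear layers of a Transformer MLP block: expand, then project back
fc1 = nn.Linear(in_dim, mlp_dim)
fc2 = nn.Linear(mlp_dim, in_dim)

for layer in (fc1, fc2):
    nn.init.xavier_uniform_(layer.weight)   # same init as MLPBlock
    nn.init.normal_(layer.bias, std=1e-6)   # near-zero biases

def xavier_bound(layer: nn.Linear) -> float:
    fan_out, fan_in = layer.weight.shape    # nn.Linear stores weights as (out, in)
    return math.sqrt(6.0 / (fan_in + fan_out))

for layer in (fc1, fc2):
    # small tolerance for float32 rounding of the bound
    assert layer.weight.abs().max().item() <= xavier_bound(layer) + 1e-6
    assert layer.bias.abs().max().item() < 1e-4

# Forward pass: Linear -> GELU -> Linear keeps the feature dimension
x = torch.randn(2, 5, in_dim)               # (batch, tokens, features)
y = fc2(nn.GELU()(fc1(x)))
print(y.shape)                              # torch.Size([2, 5, 8])
```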

## Encoder Block

<img src="TODO">

In [4]:
class EncoderBlock(nn.Module):
    """Transformer encoder block."""

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor, return_attention_output=False):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)

        x, _ = self.self_attention(x, x, x, need_weights=False)

        # Added to get attention output
        if return_attention_output: return x

        x = self.dropout(x)
        x = x + input

        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y
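Written as equations (dropout omitted), the `forward` method above computes the pre-norm residual update used in the ViT paper, where $\mathbf{z}_{\ell-1}$ is the block's `input`, $\mathbf{z}'_\ell$ the intermediate `x` after the first residual connection, and $\mathbf{z}_\ell$ the returned value; $\operatorname{LN}$ corresponds to `ln_1` and `ln_2`, and $\operatorname{MSA}$ to `self_attention`:

$$
\begin{aligned}
\mathbf{z}'_\ell &= \operatorname{MSA}\big(\operatorname{LN}(\mathbf{z}_{\ell-1})\big) + \mathbf{z}_{\ell-1}, \\
\mathbf{z}_\ell &= \operatorname{MLP}\big(\operatorname{LN}(\mathbf{z}'_\ell)\big) + \mathbf{z}'_\ell.
\end{aligned}
$$

With `return_attention_output=True`, the block instead returns $\operatorname{MSA}\big(\operatorname{LN}(\mathbf{z}_{\ell-1})\big)$, i.e. the attention output before dropout and the residual connection.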

## Encoder

<img src="TODO">

In [5]:
class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation."""

    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        # Note that batch_size is on the first dim because
        # we pass batch_first=True to nn.MultiheadAttention() in EncoderBlock
        self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))  # from BERT
        self.dropout = nn.Dropout(dropout)
        layers: OrderedDict[str, nn.Module] = OrderedDict()
        for i in range(num_layers):
            layers[f"encoder_layer_{i}"] = EncoderBlock(
                num_heads,
                hidden_dim,
                mlp_dim,
                dropout,
                attention_dropout,
                norm_layer,
            )
        self.layers = nn.Sequential(layers)
        self.ln = norm_layer(hidden_dim)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        input = input + self.pos_embedding
        return self.ln(self.layers(self.dropout(input)))
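The line `input = input + self.pos_embedding` relies on broadcasting: the learned embedding has shape `(1, seq_length, hidden_dim)`, so the same position-dependent offset is added to every sample in the batch. A standalone sketch with illustrative sizes:

```python
import torch
import torch.nn as nn

batch_size, seq_length, hidden_dim = 4, 5, 8    # illustrative sizes

# One learned vector per sequence position, shared by all samples (BERT-style init)
pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))

tokens = torch.randn(batch_size, seq_length, hidden_dim)
out = tokens + pos_embedding                     # (1, S, D) broadcasts over the batch dim

print(out.shape)                                 # torch.Size([4, 5, 8])

# Every sample received the same offset at a given position
offsets = (out - tokens).detach()
print(torch.allclose(offsets, pos_embedding.detach().expand_as(offsets), atol=1e-5))  # True
```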

## Vision Transformer

Note the following functionality I have added to make the model easier to probe:

1. Added `print` statements that print the shape of the tensor at each stage; they can be turned on and off via the `print_shapes` parameter of the constructor.
2. Added the method `get_attention_output` to be able to inspect the output of the MultiheadAttention module for a given encoder layer and head.
3. Added the `in_channels` parameter to the constructor to be able to specify the number of input channels (it was hardcoded to 3, but MNIST images have only 1).

<img src="TODO">

In [6]:
class VisionTransformer(nn.Module):
    """Vision Transformer as per https://arxiv.org/abs/2010.11929."""

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        num_classes: int = 1000,
        representation_size: Optional[int] = None,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        print_shapes: bool = False,  # Added by me to print the shapes of the tensors
        in_channels: int = 3,        # Added by me to specify the number of input channels (needed, as MNIST has only 1 channel)
    ):
        super().__init__()
        torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.num_classes = num_classes
        self.representation_size = representation_size
        self.norm_layer = norm_layer
        self.print_shapes = print_shapes  # Added by me to print the shapes of the tensors

        # Patchify and flatten input
        self.conv_proj = nn.Conv2d(
            in_channels=in_channels, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
        )

        # Number of patches
        seq_length = (image_size // patch_size) ** 2

        # Class token
        self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        # self.class_token = nn.Parameter(torch.rand(1, 1, hidden_dim))
        seq_length += 1

        # Creating the encoder
        self.encoder = Encoder(
            seq_length,
            num_layers,
            num_heads,
            hidden_dim,
            mlp_dim,
            dropout,
            attention_dropout,
            norm_layer,
        )
        self.seq_length = seq_length

        # Projection head
        heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
        if representation_size is None:
            heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
        else:
            heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
            heads_layers["act"] = nn.Tanh()
            heads_layers["head"] = nn.Linear(representation_size, num_classes)

        self.heads = nn.Sequential(heads_layers)

        # Init the patchify stem
        fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
        nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
        if self.conv_proj.bias is not None:
            nn.init.zeros_(self.conv_proj.bias)


    def _process_input(self, x: torch.Tensor) -> torch.Tensor:
        n, c, h, w = x.shape
        p = self.patch_size
        torch._assert(h == self.image_size, f"Wrong image height! Expected {self.image_size} but got {h}!")
        torch._assert(w == self.image_size, f"Wrong image width! Expected {self.image_size} but got {w}!")
        n_h = h // p
        n_w = w // p

        # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
        x = self.conv_proj(x)
        if (self.print_shapes): print("After Projection:", x.shape)
        # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
        x = x.reshape(n, self.hidden_dim, n_h * n_w)
        if (self.print_shapes): print("After Reshape:", x.shape)
        # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
        # The self attention layer expects inputs in the format (N, S, E)
        # where S is the source sequence length, N is the batch size, E is the
        # embedding dimension
        x = x.permute(0, 2, 1)
        if (self.print_shapes): print("After Permute:", x.shape)
        return x


    # Added to be able to inspect the output of MultiheadAttention modules
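    def get_attention_output(self, x: torch.Tensor, layer: int, head: int) -> torch.Tensor:
        x = self._process_input(x)
        n = x.shape[0]
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)
        # Mirror Encoder.forward: add the positional embedding (and dropout) before the blocks
        x = self.encoder.dropout(x + self.encoder.pos_embedding)
        for i, encoder_layer in enumerate(self.encoder.layers):
            if i == layer:
                x = encoder_layer(x, return_attention_output=True)
                break
            else:
                x = encoder_layer(x)
        # Split the hidden dimension into per-head chunks. Note that nn.MultiheadAttention
        # has already applied its output projection, which mixes information across heads.
        num_heads = self.encoder.layers[layer].num_heads
        head_dim = self.hidden_dim // num_heads
        x = x.reshape(n, -1, num_heads, head_dim)
        x = x[:, :, head, :]
        return x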


    def forward(self, x: torch.Tensor):
        if (self.print_shapes): print("Input:", x.shape)
        # Reshape and permute the input tensor
        x = self._process_input(x)
        n = x.shape[0]
        if (self.print_shapes): print("Batch Size:", n)

        # Expand the class token to the full batch
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)
        if (self.print_shapes): print("Batch Class Token:", batch_class_token.shape)
        if (self.print_shapes): print("Encoder Input:", x.shape)

        x = self.encoder(x)
        if (self.print_shapes): print("Encoder Output:", x.shape)
        
        # Classifier "token" as used by standard language architectures
        x = x[:, 0]
        if (self.print_shapes): print("Only Batch Class Token:", x.shape)

        x = self.heads(x)
        if (self.print_shapes): print("Output:", x.shape)
        return x
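The `conv_proj` layer is how the ViT "patchifies" the image: with `kernel_size == stride == patch_size`, each output position sees exactly one non-overlapping patch, so the convolution is equivalent to flattening every patch and multiplying it by one shared weight matrix (the linear patch embedding described in the paper). This standalone sketch checks the equivalence numerically with `torch.nn.functional.unfold`; the image size, patch size and channel counts are arbitrary illustration values.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
n, c, image_size, patch_size, hidden_dim = 2, 3, 8, 4, 6  # illustrative sizes

conv_proj = nn.Conv2d(c, hidden_dim, kernel_size=patch_size, stride=patch_size)
x = torch.randn(n, c, image_size, image_size)

# Path 1: convolution, then the same reshape/permute as _process_input
conv_tokens = conv_proj(x).flatten(2).permute(0, 2, 1)           # (n, num_patches, hidden_dim)

# Path 2: cut out non-overlapping patches and apply a linear map to each one
patches = F.unfold(x, kernel_size=patch_size, stride=patch_size)  # (n, c*p*p, num_patches)
patches = patches.transpose(1, 2)                                 # (n, num_patches, c*p*p)
weight = conv_proj.weight.reshape(hidden_dim, -1)                 # (hidden_dim, c*p*p)
linear_tokens = patches @ weight.T + conv_proj.bias               # (n, num_patches, hidden_dim)

print(conv_tokens.shape)                                          # torch.Size([2, 4, 6])
print(torch.allclose(conv_tokens, linear_tokens, atol=1e-5))      # True
```

Using a strided convolution rather than an explicit unfold-and-matmul is simply a compact, efficient way to express the same linear projection.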

# Creating a Vision Transformer

Now that we have seen how the Vision Transformer is assembled bottom-up, we can have a look at the whole architecture, see how the tensors flow through it, and pass inputs to the model.

In [7]:
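model = VisionTransformer(
    image_size=512, patch_size=16, num_layers=6, num_heads=6, hidden_dim=768, mlp_dim=3072,
    num_classes=1000, representation_size=2000, print_shapes=True,
)

# Load the image as 3-channel RGB (PNG files may carry an extra alpha channel)
image = Image.open("windmill.png").convert("RGB")
normalize = transforms.Compose([
    transforms.ToTensor(),                 # PIL image -> float tensor in [0, 1], shape (C, H, W)
    transforms.Normalize((0.0,), (1.0,)),  # mean 0 and std 1 leave the values unchanged
])
# Normalize the image
image = normalize(image)
# Add a batch dimension of 1
image = image[None]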

# Run the image through the model and print the shapes at different stages
print("\033[1mShapes of the tensor at different stages:\033[0m")
out = model(image)

Shapes of the tensor at different stages:
Input: torch.Size([1, 3, 512, 512])
After Projection: torch.Size([1, 768, 32, 32])
After Reshape: torch.Size([1, 768, 1024])
After Permute: torch.Size([1, 1024, 768])
Batch Size: 1
Batch Class Token: torch.Size([1, 1, 768])
Encoder Input: torch.Size([1, 1025, 768])
Encoder Output: torch.Size([1, 1025, 768])
Only Batch Class Token: torch.Size([1, 768])
Output: torch.Size([1, 1000])


In [8]:
# Model summary
summary(VisionTransformer(512, 16, 6, 6, 768, 3072, num_classes=1000, representation_size=2000), input_size=(1, 3, 512, 512))

Layer (type:depth-idx)                        Output Shape              Param #
VisionTransformer                             [1, 1000]                 768
├─Conv2d: 1-1                                 [1, 768, 32, 32]          590,592
├─Encoder: 1-2                                [1, 1025, 768]            787,200
│    └─Dropout: 2-1                           [1, 1025, 768]            --
│    └─Sequential: 2-2                        [1, 1025, 768]            --
│    │    └─EncoderBlock: 3-1                 [1, 1025, 768]            7,087,872
│    │    └─EncoderBlock: 3-2                 [1, 1025, 768]            7,087,872
│    │    └─EncoderBlock: 3-3                 [1, 1025, 768]            7,087,872
│    │    └─EncoderBlock: 3-4                 [1, 1025, 768]            7,087,872
│    │    └─EncoderBlock: 3-5                 [1, 1025, 768]            7,087,872
│    │    └─EncoderBlock: 3-6                 [1, 1025, 768]            7,087,872
│    └─LayerNorm: 2-3                     
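The parameter counts in the summary can be reproduced by hand. This sketch recomputes them from the layer definitions for `hidden_dim=768`, `mlp_dim=3072`, `patch_size=16`, 3 input channels and a sequence length of 1025 (1024 patches plus the class token), and cross-checks the attention formula against PyTorch's own `nn.MultiheadAttention`.

```python
import torch.nn as nn

d, mlp_dim, p, c, seq_length = 768, 3072, 16, 3, 1025

class_token = d                                    # shape (1, 1, d)
conv_proj = c * p * p * d + d                      # kernel weights + bias
pos_embedding = seq_length * d                     # shape (1, seq_length, d)

attention = 4 * d * d + 4 * d                      # Q, K, V and output projections, with biases
layer_norms = 2 * (2 * d)                          # ln_1 and ln_2, each with weight and bias
mlp = (d * mlp_dim + mlp_dim) + (mlp_dim * d + d)  # the two linear layers of MLPBlock
encoder_block = attention + layer_norms + mlp

print(class_token, conv_proj, pos_embedding, encoder_block)
# 768 590592 787200 7087872

# Cross-check the attention formula against PyTorch
mha = nn.MultiheadAttention(d, num_heads=6, batch_first=True)
assert sum(t.numel() for t in mha.parameters()) == attention
```

These values match the `VisionTransformer`, `Conv2d`, `Encoder` and `EncoderBlock` rows of the table. Roughly two thirds of each encoder block's parameters sit in the MLP, not in the attention.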

# Applying Our ViT to an Example: MNIST Classification

In [9]:
model = VisionTransformer(28, 2, 2, 2, 64, 256, num_classes=10, representation_size=32, in_channels=1, print_shapes=False)
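For MNIST, the configuration above cuts each 28×28 single-channel image into 2×2 patches, i.e. 14 × 14 = 196 patch tokens plus the class token, each embedded into 64 dimensions and split across 2 heads of 32 dimensions. This standalone sketch checks those shapes using only a patch projection and `nn.MultiheadAttention`, independently of `mnist_vit`; the batch size of 8 and the random input are arbitrary.

```python
import torch
import torch.nn as nn

image_size, patch_size, in_channels, hidden_dim, num_heads = 28, 2, 1, 64, 2

seq_length = (image_size // patch_size) ** 2 + 1   # patches + class token
head_dim = hidden_dim // num_heads

conv_proj = nn.Conv2d(in_channels, hidden_dim, kernel_size=patch_size, stride=patch_size)
x = torch.randn(8, in_channels, image_size, image_size)      # a fake batch of 8 digits
tokens = conv_proj(x).flatten(2).permute(0, 2, 1)            # (8, 196, 64)
class_token = torch.zeros(1, 1, hidden_dim).expand(8, -1, -1)
tokens = torch.cat([class_token, tokens], dim=1)             # (8, 197, 64)

attention = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
_, weights = attention(tokens, tokens, tokens, average_attn_weights=False)

print(seq_length, head_dim)   # 197 32
print(weights.shape)          # torch.Size([8, 2, 197, 197]): one 197x197 map per head
```

Because attention compares every token with every other token, the small patch size makes each attention map 197 × 197, even though the images themselves are tiny.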

In [10]:
# Uncomment the following line to also train the model
# mnist_vit.train(model)

# Expects model weights to be stored in `vit_weights.pth`
mnist_vit.evaluate(model)

# Expects model weights to be stored in `vit_weights.pth`
mnist_vit.visualize_attention(model)

Using cpu for evaluation.
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Accuracy on the train data: 95.88%
[[ 957    0    3    0    2    2    6    0    5    2]
 [   0 1121    0    3    0    0    2    4    2    1]
 [   6    2  984    5    1    0    0   14    5    3]
 [   2    1   15  966    0   15    0    7   41    2]
 [   1    3    4    2  943    1    5   11    8   21]
 [   1    0    0   19    0  858    5    0   17    3]
 [   7    4    7    0    8    5  938    0    5    2]
 [   0    1   14   11    1    3    0  976    5    8]
 [   3    2    5    3    1    2    2    1  883    5]
 [   3    1    0    1   26    6    0   15    3  962]]
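The confusion matrix can be summarized with a few lines of PyTorch. Its entries sum to 10,000, and its column sums match the per-digit counts of the MNIST test split (980 zeros, 1,135 ones, and so on). This suggests that columns are the true labels and rows are the predictions, and that the evaluation actually ran on the test split despite the "train data" wording in the log. Both readings are inferences from the numbers, not documented behavior of `mnist_vit`; the sketch below assumes the column-equals-true-label orientation.

```python
import torch

# Confusion matrix as printed above (assumed: rows = predicted, columns = true digit)
cm = torch.tensor([
    [ 957,    0,    3,    0,    2,    2,    6,    0,    5,    2],
    [   0, 1121,    0,    3,    0,    0,    2,    4,    2,    1],
    [   6,    2,  984,    5,    1,    0,    0,   14,    5,    3],
    [   2,    1,   15,  966,    0,   15,    0,    7,   41,    2],
    [   1,    3,    4,    2,  943,    1,    5,   11,    8,   21],
    [   1,    0,    0,   19,    0,  858,    5,    0,   17,    3],
    [   7,    4,    7,    0,    8,    5,  938,    0,    5,    2],
    [   0,    1,   14,   11,    1,    3,    0,  976,    5,    8],
    [   3,    2,    5,    3,    1,    2,    2,    1,  883,    5],
    [   3,    1,    0,    1,   26,    6,    0,   15,    3,  962],
])

total = cm.sum().item()                            # number of evaluated images
accuracy = cm.trace().item() / total               # fraction on the diagonal
recall = cm.diag().float() / cm.sum(dim=0).float() # per true digit (column-wise)

print(total, f"{accuracy:.2%}")                    # 10000 95.88%
print("Hardest digit:", recall.argmin().item(), f"({recall.min().item():.1%})")
```

Under this reading, the largest single confusion (41 images) is true 8s predicted as 3s.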


Now you can look at the images generated in the folder `attention` and inspect the attention heads.