# Setup

In [1]:
import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, NamedTuple, Optional

from PIL import Image
from torchvision import transforms

import torch
import torch.nn as nn

# Introduction
This notebook is based on the `torchvision` implementation of the Vision Transformer ([Source Code](https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py)) 

We build up the model architecture step by step into a minimal working version. To do that, we start with the smallest building blocks and combine them to more complex blocks.
1. General MLP
2. Transformer MLP Block
3. Encoder Block
4. Encoder
5. Vision Transformer

# General MLP Module

[Source Code](https://github.com/pytorch/vision/blob/main/torchvision/ops/misc.py)

<img src="General_MLP_Module.png"  width="50%">

In [2]:
class MLP(nn.Sequential):
    """Multi-layer perceptron (MLP) module."""

    def __init__(
        self,
        in_channels: int,
        hidden_channels: List[int],
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        activation_layer: Optional[Callable[..., nn.Module]] = nn.ReLU,
        bias: bool = True,
        dropout: float = 0.0,
    ):
        params = {}

        layers = []
        in_dim = in_channels
        for hidden_dim in hidden_channels[:-1]:
            layers.append(nn.Linear(in_dim, hidden_dim, bias=bias))
            if norm_layer is not None:
                layers.append(norm_layer(hidden_dim))
            layers.append(activation_layer(**params))
            layers.append(nn.Dropout(dropout, **params))
            in_dim = hidden_dim  # update input dimension for next layer

        layers.append(nn.Linear(in_dim, hidden_channels[-1], bias=bias))
        layers.append(nn.Dropout(dropout, **params))

        super().__init__(*layers)

# Transformer MLP Block

The Transformer MLP Block is a special case of the General MLP Module. There are always 2 hidden layers with a GELU activation function and special initialization of the weights. Usually `mlp_dim` = 4 $\cdot$ `in_dim`.

Note that **the dimensionality of the tensor is the same before and after each module** of the Transformer Encoder! 

<img src="Transformer_MLP_Block.png"  width="50%">

In [3]:
class MLPBlock(MLP):
    """Transformer MLP block."""
    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
        super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, dropout=dropout)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)

# Encoder Block

<img src="Encoder_Block.png"  width="70%">

In [4]:
class EncoderBlock(nn.Module):
    """Transformer encoder block."""

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)

        # Modified to be able to inspect the attention weights
        # x, _ = self.self_attention(x, x, x, need_weights=False)
        x, attn_weights = self.self_attention(x, x, x, need_weights=True)
        self.attention_weights = attn_weights  # Store the attention weights

        x = self.dropout(x)
        x = x + input

        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y

# Encoder

<img src="Encoder.png"  width="90%">

In [5]:
class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation."""

    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        # Note that batch_size is on the first dim because
        # we have batch_first=True in nn.MultiAttention() by default
        self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))  # from BERT
        self.dropout = nn.Dropout(dropout)
        layers: OrderedDict[str, nn.Module] = OrderedDict()
        for i in range(num_layers):
            layers[f"encoder_layer_{i}"] = EncoderBlock(
                num_heads,
                hidden_dim,
                mlp_dim,
                dropout,
                attention_dropout,
                norm_layer,
            )
        self.layers = nn.Sequential(layers)
        self.ln = norm_layer(hidden_dim)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        input = input + self.pos_embedding
        return self.ln(self.layers(self.dropout(input)))

# Vision Transformer

<img src="Vision_Transformer.png"  width="90%">

In [6]:
class VisionTransformer(nn.Module):
    """Vision Transformer as per https://arxiv.org/abs/2010.11929."""

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        num_classes: int = 1000,
        representation_size: Optional[int] = None,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.num_classes = num_classes
        self.representation_size = representation_size
        self.norm_layer = norm_layer

        # Flattening of patches
        self.conv_proj = nn.Conv2d(
            in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
        )

        # Number of patches
        seq_length = (image_size // patch_size) ** 2

        # Class token
        self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        # self.class_token = nn.Parameter(torch.rand(1, 1, hidden_dim))
        seq_length += 1

        # Creating the encoder
        self.encoder = Encoder(
            seq_length,
            num_layers,
            num_heads,
            hidden_dim,
            mlp_dim,
            dropout,
            attention_dropout,
            norm_layer,
        )
        self.seq_length = seq_length

        # Projection head
        heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
        if representation_size is None:
            heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
        else:
            heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
            heads_layers["act"] = nn.Tanh()
            heads_layers["head"] = nn.Linear(representation_size, num_classes)

        self.heads = nn.Sequential(heads_layers)

        # Init the patchify stem
        fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
        nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
        if self.conv_proj.bias is not None:
            nn.init.zeros_(self.conv_proj.bias)


    def _process_input(self, x: torch.Tensor) -> torch.Tensor:
        n, c, h, w = x.shape
        p = self.patch_size
        torch._assert(h == self.image_size, f"Wrong image height! Expected {self.image_size} but got {h}!")
        torch._assert(w == self.image_size, f"Wrong image width! Expected {self.image_size} but got {w}!")
        n_h = h // p
        n_w = w // p

        # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
        x = self.conv_proj(x)
        print("After Projection:", x.shape)
        # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
        x = x.reshape(n, self.hidden_dim, n_h * n_w)
        print("After Reshape:", x.shape)
        # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
        # The self attention layer expects inputs in the format (N, S, E)
        # where S is the source sequence length, N is the batch size, E is the
        # embedding dimension
        x = x.permute(0, 2, 1)
        print("After Permute:", x.shape)
        return x


    # Added to be able to inspect the MultiheadAttention
    def get_attention_head(self, layer: int, head: int) -> torch.Tensor:
        return self.encoder.layers[layer].attention_weights[:, head]


    def forward(self, x: torch.Tensor):
        print("Input:", x.shape)
        # Reshape and permute the input tensor
        x = self._process_input(x)
        n = x.shape[0]
        print("Batch Size:", n)

        # Expand the class token to the full batch
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)
        print("Batch Class Token:", batch_class_token.shape)
        print("Before Encoder", x.shape)

        x = self.encoder(x)
        print("After Encoder:", x.shape)
        
        # Classifier "token" as used by standard language architectures
        x = x[:, 0]
        print("Only Batch Class Token:", x.shape)

        x = self.heads(x)
        print("After MLP Head:", x.shape)
        return x

In [7]:
def _vision_transformer(
    patch_size: int,
    num_layers: int,
    num_heads: int,
    hidden_dim: int,
    mlp_dim: int,
    **kwargs: Any,
) -> VisionTransformer:

    image_size = kwargs.pop("image_size", 224)
    model = VisionTransformer(
        image_size=image_size,
        patch_size=patch_size,
        num_layers=num_layers,
        num_heads=num_heads,
        hidden_dim=hidden_dim,
        mlp_dim=mlp_dim,
        **kwargs,
    )
    return model

# Create Model

Now that we have seen how the Vision Transformer is assembled from the simpler building blocks, we can have a look at the whole architecture, see how the tensors flow through it, and pass inputs to the model.

<img src="Vision_Transformer_Big_Picture.png"  width="90%">

In [9]:
model = _vision_transformer(16, 6, 6, 768, 3072, image_size=512)
# print("Model:", model)
image = Image.open("windmill.png")
normalize = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0), (1))
])
# Normalize the image
image = normalize(image)
# Add a batch dimension of 1
image = image[None]
# print("Input image:", image)
out = model(image)
# print("Output of ViT:", out)

Input: torch.Size([1, 3, 512, 512])
After Projection: torch.Size([1, 768, 32, 32])
After Reshape: torch.Size([1, 768, 1024])
After Permute: torch.Size([1, 1024, 768])
Batch Size: 1
Batch Class Token: torch.Size([1, 1, 768])
Before Encoder torch.Size([1, 1025, 768])
After Encoder: torch.Size([1, 1025, 768])
Only Batch Class Token: torch.Size([1, 768])
After MLP Head: torch.Size([1, 1000])


In [11]:
# Example of using get_attention_head:
torch.set_printoptions(profile="full")

for layer in [0, 1]:
    for head in [0, 1, 2, 3]:
        attention_head = model.get_attention_head(layer, head)
        print(f"Attention head {head} from layer {layer}:")
        print(attention_head)

Attention head 0 from layer 0:
tensor([[0.0013, 0.0010, 0.0011, 0.0010, 0.0009, 0.0011, 0.0010, 0.0010, 0.0010,
         0.0011, 0.0010, 0.0011, 0.0011, 0.0012, 0.0009, 0.0010, 0.0010, 0.0009,
         0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0009, 0.0010, 0.0010,
         0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0011, 0.0010,
         0.0009, 0.0009, 0.0009, 0.0010, 0.0011, 0.0010, 0.0010, 0.0010, 0.0009,
         0.0010, 0.0010, 0.0009, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
         0.0010, 0.0010, 0.0010, 0.0009, 0.0010, 0.0010, 0.0009, 0.0010, 0.0010,
         0.0010, 0.0011, 0.0010, 0.0011, 0.0010, 0.0010, 0.0011, 0.0010, 0.0009,
         0.0011, 0.0009, 0.0010, 0.0009, 0.0010, 0.0010, 0.0011, 0.0009, 0.0009,
         0.0009, 0.0009, 0.0010, 0.0010, 0.0010, 0.0010, 0.0009, 0.0010, 0.0009,
         0.0010, 0.0010, 0.0009, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0011,
         0.0010, 0.0009, 0.0010, 0.0009, 0.0009, 0.0009, 0.0011, 0.0009, 0.001