# Vision Transformer (ViT) from Scratch

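Let's dive into one of the most significant contributions to the field of Computer Vision: the Vision Transformer (ViT). We build it from scratch in PyTorch: first the Transformer Encoder, then the patch embedding, CLS token, positional embedding, and classification head on top of it.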

In [1]:
import os
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import inspect
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

In [2]:
# Run on the GPU when one is available; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed the RNGs the notebook draws from (torch for weight init / random inputs,
# numpy for array data) so that Restart & Run All reproduces the same numbers.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# The Transformer Encoder

In [3]:
class MultiHeadSelfAttention(nn.Module):
    """
    Multi-head scaled dot-product self-attention (no attention mask).

    Args:
        n_embed: embedding (model) dimension C; must be divisible by ``n_head``.
        n_head: number of attention heads; each head works on C // n_head features.
        dropout: dropout probability applied to the attention weights.

    forward(x): ``x`` has shape (B, T, C) with C == n_embed; the result has the same shape.
    """

    def __init__(self, n_embed, n_head, dropout=0.1):
        super().__init__()
        assert n_embed % n_head == 0, "n_embed must be divisible by n_head"
        self.n_embed = n_embed
        self.n_head = n_head
        self.d_head = n_embed // n_head
        # fused projection: a single matmul yields queries, keys and values (3 * C features)
        self.c_attn = nn.Linear(n_embed, 3 * n_embed)
        # mixes the concatenated head outputs back into C features
        self.o_proj = nn.Linear(n_embed, n_embed)
        # applied to the softmax-normalized attention weights
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        batch, seq_len, width = x.size()
        assert width == self.n_embed, "Input embedding dimension must match model embedding dimension"

        # (B, T, 3C) -> three (B, T, C) chunks -> each viewed as (B, T, nh, dh) and moved to (B, nh, T, dh)
        per_head = (batch, seq_len, self.n_head, self.d_head)
        query, key, value = (
            chunk.view(per_head).permute(0, 2, 1, 3)
            for chunk in self.c_attn(x).split(self.n_embed, dim=2)
        )

        # scaled dot-product scores, shape (B, nh, T, T)
        scale = 1.0 / math.sqrt(self.d_head)
        scores = torch.matmul(query, key.transpose(-2, -1)) * scale
        weights = self.dropout(scores.softmax(dim=-1))

        # weighted sum of values (B, nh, T, dh), then heads placed side by side -> (B, T, C)
        context = torch.matmul(weights, value)
        merged = context.permute(0, 2, 1, 3).reshape(batch, seq_len, width)
        return self.o_proj(merged)


In [4]:
class FeedForward(nn.Module):
    """
    Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear.

    Args:
        n_embed: input and output feature dimension.
        d_ff: hidden (expansion) dimension.
        dropout: dropout probability applied after the GELU activation.
    """

    def __init__(self, n_embed, d_ff, dropout=0.1):
        super().__init__()
        layers = [
            nn.Linear(n_embed, d_ff),   # expand n_embed -> d_ff
            nn.GELU(),
            nn.Dropout(p=dropout),
            nn.Linear(d_ff, n_embed),   # project back d_ff -> n_embed
        ]
        self.ffn = nn.Sequential(*layers)

    def forward(self, x):
        # the Linear layers act on the last dimension only, so each position is transformed independently
        return self.ffn(x)


In [5]:
class EncoderBlock(nn.Module):
    """
    A single Transformer encoder block in the pre-normalization (pre-LN) layout:

        x = x + Dropout(Attention(LayerNorm(x)))
        x = x + Dropout(FFN(LayerNorm(x)))

    Args:
        n_embed: embedding dimension.
        n_head: number of attention heads.
        d_ff: hidden dimension of the feed-forward network.
        dropout: dropout probability used in attention, FFN and on both residual branches.
    """

    def __init__(self, n_embed, n_head, d_ff, dropout=0.1) -> None:
        super().__init__()
        # attention sub-layer
        self.ln_1 = nn.LayerNorm(n_embed)
        self.attn = MultiHeadSelfAttention(n_embed, n_head, dropout)
        self.dropout1 = nn.Dropout(p=dropout)
        # feed-forward sub-layer
        self.ln_2 = nn.LayerNorm(n_embed)
        self.ffn = FeedForward(n_embed, d_ff, dropout)
        self.dropout2 = nn.Dropout(p=dropout)

    def forward(self, x):
        # normalization is applied to the sub-layer input only; the residual path stays un-normalized
        attn_out = self.attn(self.ln_1(x))
        x = x + self.dropout1(attn_out)
        ffn_out = self.ffn(self.ln_2(x))
        return x + self.dropout2(ffn_out)


With the Attention Layer and Feed Forward Network in place, we can assemble a Transformer Encoder. The Transformer Encoder is essentially a stack of N Encoder Blocks.

Remember, Transformers are like Legos — the input dimension is the same as the output dimension, so you can stack as many blocks as you want (or as your memory allows).

In [6]:
class TransformerEncoder(nn.Module):
    """
    A stack of ``n_layer`` pre-LN encoder blocks followed by a final LayerNorm.

    Every block maps (batch, seq_len, n_embed) to the same shape, so blocks can be
    stacked freely; the output shape equals the input shape.
    """

    def __init__(self, n_embed=512, n_layer=6, n_head=8, d_ff=2048, dropout=0.1):
        super().__init__()
        blocks = [EncoderBlock(n_embed, n_head, d_ff, dropout) for _ in range(n_layer)]
        self.transformer = nn.ModuleList(blocks)
        # the pre-LN blocks return the raw residual stream, so normalize once at the end
        self.ln_final = nn.LayerNorm(n_embed)

    def forward(self, x):
        hidden = x
        for block in self.transformer:
            hidden = block(hidden)
        return self.ln_final(hidden)


In [7]:
# Sanity check: a (batch=16, seq_len=128, n_embed=512) input keeps its shape through the encoder.
encoder = TransformerEncoder().to(device)
data = torch.randn(16, 128, 512, device=device)
encoder.eval()
with torch.no_grad():  # inference only: don't build an autograd graph
    encoder_out = encoder(data)
assert encoder_out.shape == data.shape
encoder_out.shape

torch.Size([16, 128, 512])

# Building the final ViT

We mainly need to add 3 components:

- Converting the image into patches, and then vectors.
- Add positional embedding.
- Add the CLS token.

We need to check that the image is split into a whole number of patches. In other words, image_height must be divisible by patch_height, and image_width must be divisible by patch_width.

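Next, we convert each patch into an embedding. Remember that an image here has C = 3 channels. Each patch of size patch_size × patch_size × C is flattened into a vector of length patch_dim = patch_size · patch_size · C. That vector is then normalized and linearly projected to the embedding dimension n_embed.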

Then we define the CLS token and the positional embedding. The CLS token is used to represent the whole image as a single vector. The positional embedding gives the model spatial awareness of where each token comes from. Both are learned parameters (randomly initialized).

Finally, we instantiate the Transformer Encoder defined above and add a classification head.

In [8]:
class ViT(nn.Module):
    """
    Vision Transformer (ViT) for image classification.

    The image is cut into non-overlapping patches. Each patch is flattened,
    normalized and linearly projected to ``n_embed``. A learnable CLS token is
    prepended, learnable positional embeddings are added, and the sequence is
    processed by a ``TransformerEncoder``. A linear head maps the pooled
    representation to class logits.

    Args:
        image_size: int, or (height, width) tuple/list, of the input images.
        patch_size: int, or (height, width) tuple/list, of each patch; must divide ``image_size``.
        channels: number of input channels (e.g. 3 for RGB).
        num_classes: number of outputs of the classification head.
        pool: 'cls' to classify from the CLS token, 'mean' to average the patch tokens.
        n_embed, n_layer, n_head, d_ff, dropout: TransformerEncoder hyperparameters.

    forward(img) expects ``img`` of shape (batch, channels, height, width) and returns
    ``(cls_token, feature_map, logits)`` of shapes (batch, n_embed),
    (batch, num_patches, n_embed) and (batch, num_classes).
    """

    def __init__(self, image_size, patch_size, channels, num_classes, pool='cls',
                 n_embed=512, n_layer=6, n_head=8, d_ff=2048, dropout=0.1):
        super(ViT, self).__init__()
        image_height, image_width = self.pair(image_size)
        patch_height, patch_width = self.pair(patch_size)
        # ensure that the image dimensions are divisible by the patch sizes
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        # number of patches and the flattened size of each patch
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        # pooling strategy: CLS token or mean of the patch tokens
        assert pool in {'cls', 'mean'}, 'Pool type must be either cls (cls token) or mean (mean pooling).'
        self.pool = pool

        self.patch_embed = nn.Sequential(
            # unfold the image into patches of shape (batch_size, num_patches, patch_dim)
            Rearrange('b c (h ph) (w pw) -> b (h w) (ph pw c)', ph=patch_height, pw=patch_width),
            nn.LayerNorm(patch_dim),       # normalize each patch
            nn.Linear(patch_dim, n_embed), # project patches to embedding dimension
            nn.LayerNorm(n_embed),         # normalize the embedding
        ) # embedding shape (batch_size, num_patches, n_embed)

        # learnable CLS token and positional embeddings (one position per patch + 1 for CLS);
        # their values are set in _init_weights
        self.cls_token = nn.Parameter(torch.empty(1, 1, n_embed))
        self.pos_embed = nn.Parameter(torch.empty(1, num_patches + 1, n_embed))
        self.dropout = nn.Dropout(p=dropout)

        # the transformer encoder
        self.encoder = TransformerEncoder(n_embed, n_layer, n_head, d_ff, dropout)

        # identity layer (no change to the tensor)
        self.latent_space = nn.Identity()
        # classification head
        self.lm_head = nn.Linear(n_embed, num_classes)

        self._init_weights()


    def _init_weights(self):
        """Xavier-normal for weight matrices; normal(std=0.02) for the CLS token and positional embeddings."""
        embeddings = {id(self.cls_token), id(self.pos_embed)}
        for p in self.parameters():
            if p.dim() > 1 and id(p) not in embeddings:
                nn.init.xavier_normal_(p)
        # cls_token / pos_embed are not weight matrices. Xavier would read their
        # (1, N, n_embed) layout as conv fan-in/fan-out; for pos_embed that gives
        # std = sqrt(2 / ((num_patches + 2) * n_embed)), roughly 0.003 for ViT-L.
        # Use the std-0.02 init common in reference ViT implementations instead.
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)


    def pair(self, x):
        """
        Utility function: return ``x`` as a (height, width) tuple.

        A tuple or list is converted to a tuple as is; any other value is duplicated.
        """

        return tuple(x) if isinstance(x, (tuple, list)) else (x, x)


    def forward(self, img):
        x = self.patch_embed(img)
        B, P, _ = x.size() # (batch_size, num_patches, n_embed)
        # repeat the class token (CLS) for each image in the batch
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=B)
        # prepend the class token to the patch embeddings -> (batch_size, num_patches + 1, n_embed)
        x = torch.cat((cls_tokens, x), dim=1)
        # add positional embeddings, sliced to the actual sequence length
        x = x + self.pos_embed[:, :(P + 1)]
        x = self.dropout(x)

        # forward through the transformer encoder
        x = self.encoder(x) # (batch_size, num_patches + 1, n_embed)

        # extract class token and feature map
        cls_token = x[:, 0]
        feature_map = x[:, 1:]
        # apply pooling operation: 'cls' token or mean of patches
        pool_output = cls_token if self.pool == 'cls' else feature_map.mean(dim=1)

        # apply the identity transformation (no change to the tensor)
        pool_output = self.latent_space(pool_output)
        # forward the classifier
        logits = self.lm_head(pool_output)

        # return CLS token, patch embeddings, and classification results
        return cls_token, feature_map, logits


**Forward pass:** We have initialized all the components of our ViT, now we just have to call them in the right order for the forward pass.

- We first convert the input image into patches, and unfold each patch into a vector.

- Then we repeat the CLS token along the batch dimension and concatenate it with the patch embeddings along dimension 1, which is the sequence dimension. We learn the parameters of a single vector, but that vector has to be prepended to every image in the batch. This is why we repeat it along the batch dimension.

- Then we add the position embedding to each token.

Next, we apply the Transformer Encoder and use its output to return three things:

- The CLS Token (a single vector representation of the image).

- The Feature Map (one vector representation per patch of the image).

- Classification Head Logits (optional): these are used for classification tasks. Note that a Vision Transformer can be trained on different tasks, but image classification was the task used in the original paper.

In [9]:
# --- ViT-Large hyperparameters (16x16 patches) ---
image_size = 224
patch_size = 16
channels = 3
num_classes = 1000
pool = 'cls'
n_embed = 1024
n_layer = 24
n_head = 16
d_ff = 4 * n_embed  # MLP hidden size = 4 x model width

model = ViT(
    image_size, patch_size, channels, num_classes,
    pool=pool, n_embed=n_embed, n_layer=n_layer, n_head=n_head, d_ff=d_ff,
).to(device)

trainable_params = [p for p in model.parameters() if p.requires_grad]
total_params = sum(p.numel() for p in trainable_params)
print(f'Number of parameters: {total_params}\n')

model

Number of parameters: 304330216



ViT(
  (patch_embed): Sequential(
    (0): Rearrange('b c (h ph) (w pw) -> b (h w) (ph pw c)', ph=16, pw=16)
    (1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (2): Linear(in_features=768, out_features=1024, bias=True)
    (3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (encoder): TransformerEncoder(
    (transformer): ModuleList(
      (0-23): 24 x EncoderBlock(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): MultiHeadSelfAttention(
          (c_attn): Linear(in_features=1024, out_features=3072, bias=True)
          (o_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (dropout1): Dropout(p=0.1, inplace=False)
        (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (ffn): FeedForward(
          (ffn): Sequential(
            (0): Linear(in_features=1024, out_features=40

In [10]:
# Forward a random batch of 32 images (channels x image_size x image_size) through the ViT.
img = torch.randn(32, channels, image_size, image_size, device=device)

model.eval()
with torch.no_grad():  # inference only: don't store activations for backprop
    cls_token, feature_map, logits = model(img)
assert logits.shape == (img.shape[0], num_classes)

print(f'CLS Token Shape: {cls_token.shape}')
print(f'Feature Map Shape: {feature_map.shape}')
print(f'Classification Head Logits Shape: {logits.shape}')

CLS Token Shape: torch.Size([32, 1024])
Feature Map Shape: torch.Size([32, 196, 1024])
Classification Head Logits Shape: torch.Size([32, 1000])


In [11]:
# https://towardsdatascience.com/how-to-train-a-vision-transformer-vit-from-scratch-f26641f26af2

# Training the ViT model from scratch

TODO