# 2. Implementing a basic ViT

### About this notebook

This notebook was used in the 50.039 Deep Learning course at the Singapore University of Technology and Design.

**Author:** Matthieu DE MARI (matthieu_demari@sutd.edu.sg)

**Version:** 1.0 (23/01/2025)

**Requirements:**
- Python 3 (tested on v3.13.1)
- Numpy (tested on v2.2.1)
- Sklearn (tested on v1.6.1)
- Torch (tested on v2.7.0+cu124)

In [1]:
import numpy as np
from sklearn.metrics import accuracy_score
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


### Imports from the previous notebook (patches embeddings and 2D sinusoidal positional encodings)

For simplicity, we will include everything in a single object.

In [3]:
class VisionTransformerProcessor:
    def __init__(self, img_size, patch_size, embed_dim, device):
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size//patch_size)**2
        self.embed_dim = embed_dim
        self.grid_size = (img_size//patch_size, img_size//patch_size)
        self.device = device
        self.linear_proj = nn.Conv2d(in_channels = 3, out_channels = embed_dim, 
                                     kernel_size = patch_size, stride = patch_size).to(device)
        # Generate 2D positional encodings, do it once, instead of every single batch of images
        self.positional_encodings = self.generate_2d_positional_encoding()

    def generate_2d_positional_encoding(self):
        # Create row and column positional encodings
        num_rows, num_cols = self.grid_size
        row_pos = np.arange(num_rows).reshape(-1, 1)
        col_pos = np.arange(num_cols).reshape(-1, 1)
        d = np.arange(self.embed_dim//2).reshape(1, -1)
        # Sinusoidal encodings for rows and columns
        angle_rates = 1/np.power(10000, (2*d)/self.embed_dim)
        row_encoding = np.concatenate([np.sin(row_pos*angle_rates), np.cos(row_pos*angle_rates)], axis = -1)
        col_encoding = np.concatenate([np.sin(col_pos*angle_rates), np.cos(col_pos*angle_rates)], axis = -1)
        # Combine row and column encodings
        row_encoding = np.tile(row_encoding[:, np.newaxis, :], (1, num_cols, 1))
        col_encoding = np.tile(col_encoding[np.newaxis, :, :], (num_rows, 1, 1))
        pos_encoding = row_encoding + col_encoding
        return pos_encoding.reshape(-1, self.embed_dim)

    def process_images(self, images):
        # Process a batch of images to generate patch embeddings and apply positional encodings.
        batch_size, channels, height, width = images.shape
        # Generate patch embeddings
        patches = self.linear_proj(images)  # Shape: (batch_size, embed_dim, grid_h, grid_w)
        patches = patches.flatten(2).transpose(1, 2)  # Shape: (batch_size, num_patches, embed_dim)
        # Add positional encodings
        pos_encodings = torch.tensor(self.positional_encodings, dtype = torch.float32).to(images.device)
        patches_with_pos = patches + pos_encodings.unsqueeze(0)  # Add positional encodings
        return patches_with_pos

Let us show how it works on one batch of images coming from CIFAR-10

In [4]:
# Example Usage with CIFAR-10
# Example image size for CIFAR 10 is 32x32 pixels
img_size = 32
# Each patch is 8x8
patch_size = 8
# Embedding dimension set to 64
embed_dim = 64

# Create an instance of the processor
vi_processor = VisionTransformerProcessor(img_size = img_size, patch_size = patch_size, embed_dim = embed_dim, device = device)

# CIFAR-10 Dataset and DataLoader, uses batch size 16
transform = transforms.Compose([transforms.Resize((img_size, img_size)), transforms.ToTensor()])
cifar10_dataset = datasets.CIFAR10(root = "./data", train = True, download = True, transform = transform)
dataloader = DataLoader(cifar10_dataset, batch_size = 16, shuffle = True)

# Process a batch of images
for images, labels in dataloader:
    patch_embeddings = vi_processor.process_images(images.to(device))
    # We expect (batch_size = 16, patches_number = 16, embed_dim = 64)
    print(f"Processed Batch Shape: {patch_embeddings.shape}")
    print(f"Example Patch Embedding for First Image: {patch_embeddings[0, 0, :]}")
    break

Processed Batch Shape: torch.Size([16, 16, 64])
Example Patch Embedding for First Image: tensor([-0.0376, -0.3324,  0.0313,  0.3214, -0.0649, -0.1658, -0.1404,  0.1826,
         0.5735, -0.1980,  0.3530,  0.1537, -0.1557, -0.1961, -0.5471, -0.6755,
        -0.0901,  0.5010, -0.1805, -0.1357,  0.5971,  0.0898, -0.3910, -0.2648,
        -0.0728, -0.0758, -0.1184,  0.0698, -0.1251, -0.9450, -0.0135,  0.5654,
         2.1114,  2.0324,  2.1717,  1.7041,  1.9677,  2.0746,  1.7042,  2.2752,
         2.0074,  2.4681,  1.8811,  2.4472,  1.9353,  2.0653,  1.6227,  1.6465,
         1.8564,  1.8609,  2.8056,  1.9274,  1.8273,  1.9682,  1.8559,  1.9932,
         2.1965,  1.7839,  1.6668,  1.9211,  2.3594,  1.5442,  2.1772,  2.6053],
       device='cuda:0', grad_fn=<SliceBackward0>)


### Implementing a basic ViT

In [5]:
class ViT(nn.Module):
    def __init__(self, embed_dim, num_patches, num_classes, num_heads, num_layers, mlp_dim, dropout = 0.1):
        super(ViT, self).__init__()
        self.num_patches = num_patches
        # Class token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Transformer Encoder, using PyTorch prototypes for transformers
        self.encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model = embed_dim,
                                                                        nhead = num_heads,
                                                                        dim_feedforward = mlp_dim,
                                                                        dropout = dropout,
                                                                        activation = 'gelu',
                                                                        batch_first = True),
                                             num_layers=num_layers)
        # Classification head
        self.mlp_head = nn.Sequential(nn.LayerNorm(embed_dim),
                                      nn.Linear(embed_dim, num_classes))

    def forward(self, x):
        batch_size, num_patches, embed_dim = x.shape
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim = 1)
        x = self.encoder(x)
        cls_output = x[:, 0]
        logits = self.mlp_head(cls_output)
        return logits

### Training loop for ViT

Finally, we will write a simple training loop for our model and the CIFAR-10 dataset

In [6]:
# CIFAR-10 Dataset and DataLoader
transform = transforms.Compose([transforms.Resize((img_size, img_size)),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),])

batch_size = 128
train_dataset = datasets.CIFAR10(root="./data", train = True, download = True, transform = transform)
test_dataset = datasets.CIFAR10(root="./data", train = False, download = True, transform = transform)
train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
test_loader = DataLoader(test_dataset, batch_size = batch_size, shuffle = False)

In [7]:
# Parameters for model
img_size = 32
patch_size = 8
embed_dim = 64
num_patches = (img_size//patch_size)**2
num_classes = 10
num_heads = 2
num_layers = 2
mlp_dim = 128

# Vision Transformer Processor for Patches and Encodings
vi_processor = VisionTransformerProcessor(img_size = img_size, patch_size = patch_size, embed_dim = embed_dim, device = device)

# Vision Transformer Model
model = ViT(embed_dim = embed_dim, num_patches = num_patches, num_classes = num_classes, 
            num_heads = num_heads, num_layers = num_layers, mlp_dim = mlp_dim).to(device)

In [8]:
# Hyperparameters
epochs = 100
lr = 5e-3

# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = lr)

In [9]:
# Training Loop
def train_model(model, dataloader, criterion, optimizer, epochs):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            # Generate patch embeddings with positional encodings
            patch_embeddings = vi_processor.process_images(images)
            # Forward pass
            outputs = model(patch_embeddings)
            loss = criterion(outputs, labels)
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {running_loss/len(dataloader):.4f}")

In [10]:
# Train the model
# (Takes a while as transformers have lots of trainable parameters!)
train_model(model, train_loader, criterion, optimizer, epochs)

Epoch [1/100], Loss: 2.0111
Epoch [2/100], Loss: 1.8319
Epoch [3/100], Loss: 1.7671
Epoch [4/100], Loss: 1.7204
Epoch [5/100], Loss: 1.6859
Epoch [6/100], Loss: 1.6555
Epoch [7/100], Loss: 1.6378
Epoch [8/100], Loss: 1.6034
Epoch [9/100], Loss: 1.5809
Epoch [10/100], Loss: 1.5567
Epoch [11/100], Loss: 1.5376
Epoch [12/100], Loss: 1.5346
Epoch [13/100], Loss: 1.5146
Epoch [14/100], Loss: 1.5023
Epoch [15/100], Loss: 1.5007
Epoch [16/100], Loss: 1.4896
Epoch [17/100], Loss: 1.4797
Epoch [18/100], Loss: 1.4680
Epoch [19/100], Loss: 1.4597
Epoch [20/100], Loss: 1.4520
Epoch [21/100], Loss: 1.4479
Epoch [22/100], Loss: 1.4394
Epoch [23/100], Loss: 1.4260
Epoch [24/100], Loss: 1.4214
Epoch [25/100], Loss: 1.4110
Epoch [26/100], Loss: 1.4086
Epoch [27/100], Loss: 1.3965
Epoch [28/100], Loss: 1.3892
Epoch [29/100], Loss: 1.3851
Epoch [30/100], Loss: 1.3809
Epoch [31/100], Loss: 1.3760
Epoch [32/100], Loss: 1.3689
Epoch [33/100], Loss: 1.3628
Epoch [34/100], Loss: 1.3596
Epoch [35/100], Loss: 1

In [11]:
# Evaluation
def evaluate_model(model, dataloader):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            # Generate patch embeddings with positional encodings
            patch_embeddings = vi_processor.process_images(images)
            # Forward pass
            outputs = model(patch_embeddings)
            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    # Calculate accuracy
    accuracy = accuracy_score(all_labels, all_preds)
    return accuracy

In [12]:
# Evaluate the model
test_accuracy = evaluate_model(model, test_loader)
print(f"Test Accuracy: {test_accuracy*100:.2f}%")

Test Accuracy: 56.27%
