# Instructions


*   Feel free to modify the skeleton code as needed. It is provided solely as a framework for your solution.
*   You must submit this notebook in both .html and .ipynb formats.
*   Make sure to switch your runtime to GPU for faster training.




# Objectives

In this assignment, you will use a Vision Transformer (ViT) for image classification on the MNIST dataset.

Most of the code for this homework has been provided to you in this notebook, and you should be able to train the model as is to see how it works.

You need to complete the following tasks:


1.   **Patch Extraction**: Run *img_to_patch* on a batch of images from *train_loader* with both *flatten_channels=True* and *flatten_channels=False* and *patch_size=4*, then discuss the shape of the output: what does each dimension represent?
2.   **Training**: Train the model for 10 epochs and show the learning curves.
3.   **Visualization of cosine similarity of positional embeddings**: Extract the positional embeddings from the model before and after training, calculate the cosine similarity between all pairs of positional embeddings, and visualize the results. You can use *torch.nn.functional.cosine_similarity* for this task.
4.   **Visualization of attention maps**: Modify the TransformerBlock class and the ViT model so that they return the attention weights from the self_attention layer during the forward pass. Then visualize the attention maps of the last layer, averaged over the heads. (You need to train the model again after making these changes.)




# Importing Required Libraries

Let's start by importing the necessary libraries.

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import torch.nn.functional as F
from torchvision.utils import make_grid

import os
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt


# For repeatable experiments
random_seed = 32
torch.manual_seed(random_seed)

# Loading the Dataset and creating dataloaders

### Download Dataset

In this assignment, you'll be working with the MNIST dataset, which contains 70,000 grayscale 28×28 images of handwritten digits (60,000 for training and 10,000 for testing).
The first step is to download the dataset:



In [3]:
# Downloading training data for MNIST dataset
train_dataset = MNIST(root="dataset/",
                      train=True,
                      download=True,
                      transform=transforms.Compose([
                      transforms.ToTensor(),
                      transforms.Normalize((0.1307,), (0.3081,))
                      ]))

# Downloading testing data for MNIST dataset
test_dataset = MNIST(root="dataset/",
                     train=False,
                     download=True,
                     transform=transforms.Compose([
                     transforms.ToTensor(),
                     transforms.Normalize((0.1307,), (0.3081,))
                     ]))

### Create DataLoader

Now, let's define our dataloader to pass samples in **minibatches**.

In [None]:
train_batch_size = 64
test_batch_size=128

# Create data loaders to iterate over data
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

print(f"Training data size: {len(train_loader.dataset)}")
print(f"Test data size: {len(test_loader.dataset)}")



Let's check the size of each batch in train_loader:

In [None]:
# Displaying the shape of a single batch from the train loader
train_examples = enumerate(train_loader)
batch_idx, (train_images, train_labels) = next(train_examples)

print(f"Shape of images in train loader [B, C, H, W]: {train_images.shape}")
print(f"Shape of labels in train loader [B]: {train_labels.shape}")

### Visualizing Sample Images

Let's visualize some sample images from train_dataset:

In [None]:
fig = plt.figure()
for i in range(6):
    plt.subplot(2, 3, i + 1)
    plt.tight_layout()
    plt.imshow(train_images[i][0], cmap='gray', interpolation='none')
    plt.title("Label: {}".format(train_labels[i]))
    plt.xticks([])
    plt.yticks([])
plt.show()


# Vision Transformer




### Patch Extraction

Patch extraction is the first step of the Vision Transformer (ViT) architecture: it divides the input image into smaller, non-overlapping patches, each of which is later embedded as a separate token.
Here, we define a function that divides images into patches:

In [7]:
def img_to_patch(x, patch_size, flatten_channels=True):
    """
    Inputs:
        x - torch.Tensor representing the image of shape [B, C, H, W]
        patch_size - Number of pixels per dimension of the patches (integer);
                     H and W must be divisible by patch_size
        flatten_channels - If True, the patches will be returned in a flattened format
                           as a feature vector instead of an image grid.
    """
    B, C, H, W = x.shape
    x = x.reshape(B, C, H//patch_size, patch_size, W//patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5) # [B, H', W', C, p_H, p_W]
    x = x.flatten(1,2)              # [B, H'*W', C, p_H, p_W]
    if flatten_channels:
        x = x.flatten(2,4)          # [B, H'*W', C*p_H*p_W]
    return x
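
To see what the reshape/permute sequence does, here is a small sketch on a toy 1×1×4×4 tensor whose pixel values are their own row-major indices, using patch_size=2. The toy sizes are chosen only for readability (not the MNIST setting), and the function is copied from the cell above so the snippet runs on its own.

```python
import torch

def img_to_patch(x, patch_size, flatten_channels=True):
    # Same logic as the notebook's img_to_patch, copied so this snippet is self-contained
    B, C, H, W = x.shape
    x = x.reshape(B, C, H // patch_size, patch_size, W // patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5)   # [B, H', W', C, p_H, p_W]
    x = x.flatten(1, 2)               # [B, H'*W', C, p_H, p_W]
    if flatten_channels:
        x = x.flatten(2, 4)           # [B, H'*W', C*p_H*p_W]
    return x

# Toy "image": batch of 1, one channel, 4x4 pixels numbered 0..15 row by row
img = torch.arange(16).reshape(1, 1, 4, 4)

grid_patches = img_to_patch(img, patch_size=2, flatten_channels=False)
flat_patches = img_to_patch(img, patch_size=2, flatten_channels=True)

print(grid_patches.shape)  # torch.Size([1, 4, 1, 2, 2])
print(flat_patches.shape)  # torch.Size([1, 4, 4])
# Patches are ordered row by row: the first patch is the top-left 2x2 block
print(flat_patches[0, 0])  # tensor([0, 1, 4, 5])
```

Because the patch-grid axes are flattened row by row, patch index 1 is the top-right 2×2 block (pixels 2, 3, 6, 7).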

Run this function with **flatten_channels=True** and **flatten_channels=False** on the same batch from the previous cells, set **patch_size=4**, and discuss the shape of the output: what does each dimension correspond to?

In [None]:
'''

# Your code here

'''

Notice how setting flatten_channels affects the shape of the output.

Now, let's visualize the patches to see how they look. In this case, we will use the output of the **img_to_patch** function with **flatten_channels=False**.
We will use **torchvision.utils.make_grid** to visualize the patches as a grid. The **plot_patches** function lets you control the display: set **seq=False** to arrange the patches in the image's original spatial layout, or **seq=True** to display the image as a single row (sequence) of patches.


In [9]:
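def plot_patches(patches, seq=False):
    """
    patches - tensor of shape [num_patches, C, p_H, p_W] for a single image,
              e.g. the output of img_to_patch(..., flatten_channels=False) for one image
    seq     - If False, arrange the patches in the original 2D layout;
              if True, show all patches in a single row.
    """
    fig, ax = plt.subplots(1, 1, figsize=(14, 3))

    # set nrow: number of patches displayed in each row of the grid
    # (derived from the patches themselves instead of global variables;
    #  assumes a square patch grid, e.g. 7x7 for 28x28 images with patch_size=4)
    num_patches = patches.shape[0]
    if not seq:
        nrow = int(round(num_patches ** 0.5))
    else:
        nrow = num_patches

    img_grid = make_grid(patches.cpu(), nrow=nrow, normalize=True, pad_value=0.9)
    img_grid = img_grid.permute(1, 2, 0)  # Convert from CxHxW to HxWxC for imshow

    ax.imshow(img_grid, cmap='gray')
    ax.axis('off')

    plt.show()
    plt.close()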

You need to plot the extracted patches for the first image in the batch, with **patch_size=4**:

In [None]:
'''

# Your code here

'''

Looking at the individual patches, you can see that recognizing the digit is much harder than when viewing the whole image. Yet this is the input we give to the Transformer for digit classification: the model must learn on its own how to piece these patches together to identify the digit accurately.

### Transformer Block


In Vision Transformers, after the image is divided into patches during patch extraction, each patch is flattened and projected by a linear layer into a token embedding. These token embeddings are then processed by Transformer blocks, which are key to learning the relationships between different regions of the image.

The Transformer block consists of two main components:


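*   **Multi-head self-attention**, which lets each token attend to every other token in the sequence, with several heads capturing different relationships in parallel.
*   **Feedforward neural network (MLP)**, which further transforms each token independently.

Layer normalization is applied before each component (a "pre-norm" design), and each component's output is added back to its input through a residual connection; both help stabilize training and make learning more efficient.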
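
Written as equations, the pre-norm block defined in the next cell computes the following for an input token sequence $x$ (a summary of its forward method, with dropout omitted), where $W_1$ projects from hidden_dim to 4·hidden_dim and $W_2$ projects back:

$$
\begin{aligned}
x' &= x + \mathrm{MSA}\big(\mathrm{LN}_1(x)\big),\\
y &= x' + \mathrm{MLP}\big(\mathrm{LN}_2(x')\big), \qquad \mathrm{MLP}(z) = W_2\,\mathrm{ReLU}(W_1 z + b_1) + b_2.
\end{aligned}
$$

Both sub-layers preserve the shape of $x$, which is what allows the residual additions and the stacking of several blocks.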


In [8]:
# Transformer Block with Self-Attention and MLP
class TransformerBlock(nn.Module):
    def __init__(self, hidden_dim=128, num_heads=4):
        # Initialize the parent nn.Module
        super(TransformerBlock, self).__init__()

        # First Layer Normalization
        self.layer_norm1 = nn.LayerNorm(hidden_dim)

        # Multi-Head Self-Attention
        self.self_attention = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads,
                                                    batch_first=True, dropout=0.1)

        # Second Layer Normalization
        self.layer_norm2 = nn.LayerNorm(hidden_dim)

        # Feed-forward network (MLP) with hidden layer and activation
        self.feed_forward = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim * 4, hidden_dim),
            )



    def forward(self, x):

        # First Layer Normalization
        norm_x = self.layer_norm1(x)

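        # Self-attention (query = key = value = norm_x) with residual connection.
        # nn.MultiheadAttention returns (output, attention weights); [0] keeps only the output.
        attn_output = self.self_attention(norm_x, norm_x, norm_x)[0]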
        x = x + attn_output

        # Second Layer Normalization
        norm_x=self.layer_norm2(x)

        #feed-forward network with residual connection
        feed_forward_output = self.feed_forward(norm_x)
        x = x + feed_forward_output

        return x
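
The forward pass above keeps only element [0] of what nn.MultiheadAttention returns. The module also returns attention weights; this sketch uses random inputs, with sizes chosen to match this notebook's hidden_size=128, num_heads=8 and 49 patches + 1 class token, to show their shapes under PyTorch's need_weights and average_attn_weights arguments.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch, seq_len, hidden, heads = 2, 50, 128, 8   # 49 patch tokens + 1 class token
mha = nn.MultiheadAttention(embed_dim=hidden, num_heads=heads, batch_first=True)
mha.eval()  # disable attention dropout so both calls below are deterministic

x = torch.randn(batch, seq_len, hidden)

# Default: weights averaged over heads -> [B, L, S]
out, w_avg = mha(x, x, x, need_weights=True)
# Per-head weights -> [B, num_heads, L, S]
_, w_heads = mha(x, x, x, need_weights=True, average_attn_weights=False)

print(out.shape)      # torch.Size([2, 50, 128])
print(w_avg.shape)    # torch.Size([2, 50, 50])
print(w_heads.shape)  # torch.Size([2, 8, 50, 50])

# Each row is a softmax distribution over the 50 keys
print(torch.allclose(w_avg.sum(-1), torch.ones(batch, seq_len)))
# Averaging the per-head weights manually reproduces the default output
print(torch.allclose(w_heads.mean(dim=1), w_avg, atol=1e-6))
```

Requesting per-head weights costs some extra memory, so returning them only when needed (for example, behind a flag) is one reasonable design.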

### ViT Model

The Vision Transformer architecture for an image classification task follows these steps:


*   **Patch Extraction**: Split the image into fixed-size patches and flatten them.
*   **Patch Embedding**: Apply a linear embedding to the flattened patches (project them to the hidden size).
*   **Add positional embeddings**: Add learnable positional embeddings to the patch embeddings.
*   **Prepend class token**: Prepend a learnable class (output) token to the sequence of patch embeddings.
*   **Transformer block processing**: Pass the sequence of embeddings through a stack of Transformer blocks.
*   **Classification**: Use the final embedding of the class token for classification.
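
To make these steps concrete for this notebook's configuration (28×28 single-channel images, patch_size=4, hidden_size=128, 10 classes), here is a shape-tracing sketch. It uses freshly initialized layers and a zero tensor as a placeholder for the learned positional embeddings, and it skips the Transformer blocks because they leave the shape unchanged.

```python
import torch
from torch import nn

B, C, H, W = 4, 1, 28, 28        # a small batch of MNIST-sized images
p, hidden, n_classes = 4, 128, 10

x = torch.randn(B, C, H, W)

# 1) Patch extraction (same reshape/permute as img_to_patch with flatten_channels=True)
patches = (x.reshape(B, C, H // p, p, W // p, p)
             .permute(0, 2, 4, 1, 3, 5)
             .flatten(1, 2)
             .flatten(2, 4))                      # [4, 49, 16]

# 2) Patch embedding: linear projection of each 16-pixel patch to 128 dimensions
tokens = nn.Linear(C * p * p, hidden)(patches)   # [4, 49, 128]

# 3) Positional embeddings (placeholder zeros here), broadcast over the batch
pos_embed = torch.zeros(1, tokens.shape[1], hidden)
tokens = tokens + pos_embed                       # [4, 49, 128]

# 4) Prepend the class token -> sequence length 49 + 1 = 50
cls_token = torch.zeros(1, 1, hidden).expand(B, 1, -1)
seq = torch.cat((cls_token, tokens), dim=1)       # [4, 50, 128]

# 5) Transformer blocks map [B, 50, 128] -> [B, 50, 128], so they are skipped here
# 6) Classify from the class token's final embedding (position 0)
logits = nn.Linear(hidden, n_classes)(seq[:, 0])  # [4, 10]
print(patches.shape, seq.shape, logits.shape)
```

Note that in this notebook the positional embeddings are added only to the 49 patch tokens, so the class token carries no positional embedding.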



In [9]:
# Vision Transformer (ViT) with Transformer Blocks
class ViT(nn.Module):
    def __init__(self, img_size, in_channels, num_classes, patch_size, hidden_size, num_layers, num_heads=8):
        super(ViT, self).__init__()

        # Store patch dimension
        self.patch_size = patch_size

        # Linear layer to embed image patches into the hidden dimension
        self.patch_embed = nn.Linear(in_channels * patch_size * patch_size, hidden_size)

        # Stack of Transformer blocks
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(hidden_size, num_heads) for _ in range(num_layers)
        ])

        # Linear layer to output class predictions
        self.classifier = nn.Linear(hidden_size, num_classes)

        # Learnable output token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))

        # Positional embedding to maintain spatial information
        num_patches = (img_size // patch_size) ** 2
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, hidden_size) * 0.001)

    def forward(self, x):
        batch_size = x.size(0)

        # Convert image to patches and flatten
        patch_seq = img_to_patch(x, self.patch_size, flatten_channels=True)

        # Embed patches into the hidden dimension
        patch_embeddings = self.patch_embed(patch_seq)

        # Add positional embeddings to the patch embeddings
        patch_embeddings += self.pos_embed

        # Concatenate class token to the patch embeddings
        cls_token = self.cls_token.expand(batch_size, 1, -1)
        embeddings = torch.cat((cls_token, patch_embeddings), dim=1)

        # Pass the embeddings through transformer blocks
        for block in self.transformer_blocks:
            embeddings = block(embeddings)

        # Classification based on the class token output
        return self.classifier(embeddings[:, 0])


# Training

### Device Configuration for Training

Ensure that you have set your device to GPU:

In [None]:
# Get device for training.
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")

Let's check the model's structure:




In [None]:
patch_size=4
num_classes = 10

# Let's define our model:
model = ViT(img_size=train_images.shape[2],
            in_channels=train_images.shape[1],
            num_classes = num_classes,
            patch_size=patch_size,
            hidden_size=128,
            num_layers=8,
            num_heads=8).to(device)

# View the model's architecture
print(model)

Before we start training, let's check the shape of the model's output for a single batch:

In [None]:
train_examples = enumerate(train_loader)
batch_idx, (train_images, train_labels) = next(train_examples)

# Pass image through network
out= model(train_images.to(device))
# Check input and output's shapes
print(f"Shape of images in train loader [B, C, H, W]: {train_images.shape}")
print(f"Shape of model output [B, num_classes]: {out.shape}")

### Train Function

Here, we define the training function, which trains the model for one epoch and records the epoch's average loss and accuracy.



In [24]:
# Let's define our training function
def train(dataloader, model, loss_fn, optimizer, epoch, train_losses, train_acc):

    model.train()
    training_loss = 0
    correct = 0
    for Image, Label in tqdm(dataloader):
        Image = Image.to(device)
        Label = Label.to(device)

        pred = model(Image)

        loss = loss_fn(pred, Label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        training_loss += loss.item()
        correct += (pred.argmax(1) == Label).type(torch.float).sum().item()

    # Average loss and accuracy for the epoch
    training_loss /= len(dataloader)
    train_accuracy = 100 * correct / len(dataloader.dataset)
    train_losses.append(training_loss)
    train_acc.append(train_accuracy)

    print(f"Epoch {epoch+1} \n Train Accuracy: {train_accuracy:.1f}%, Train loss: {training_loss:.6f}")



### Test Function

The test function evaluates the model's predictive performance on the test set using **test_loader** (in this notebook, the test set also serves as the validation set).




In [25]:
# Now, let's define our test function
def test(dataloader, model, loss_fn, epoch, val_losses, val_acc):

    model.eval()
    val_loss, correct = 0, 0
    # Disable gradient tracking: evaluation needs no backward pass, which saves memory and time
    with torch.no_grad():
        for Image, Label in tqdm(dataloader):
            Image = Image.to(device)
            Label = Label.to(device)

            pred = model(Image)
            loss = loss_fn(pred, Label)

            val_loss += loss.item()
            correct += (pred.argmax(1) == Label).type(torch.float).sum().item()

    # Average loss and accuracy for the epoch
    val_loss /= len(dataloader)
    val_accuracy = 100 * correct / len(dataloader.dataset)
    val_losses.append(val_loss)
    val_acc.append(val_accuracy)

    print(f"Epoch {epoch+1} \n Test Accuracy: {val_accuracy:>0.1f}%, Test loss: {val_loss:>7f}")



### Training Loop

Now, we need to define our loss function and optimizer and start training.


In [None]:
# Define our learning rate, loss function and optimizer
epochs = 10
learning_rate = 0.001
patch_size = 4
num_classes = 10

model = ViT(img_size=train_images.shape[2],
            in_channels=train_images.shape[1],
            num_classes = num_classes,
            patch_size=patch_size,
            hidden_size=128,
            num_layers=8,
            num_heads=8).to(device)

loss_fn = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                    T_max=epochs,
                                                    eta_min=0)
train_losses, train_acc = [], []
val_losses, val_acc = [], []

# Let's start training:
for epoch in range(epochs):
    train(train_loader, model, loss_fn, optimizer,epoch,train_losses,train_acc)
    test(test_loader, model, loss_fn,epoch, val_losses,val_acc)
    lr_scheduler.step()
print("Finished!")
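
The loop above calls lr_scheduler.step() once per epoch, so CosineAnnealingLR with T_max=epochs decays the learning rate from 0.001 toward eta_min=0 following $\eta_t = \eta_{\min} + \tfrac{1}{2}(\eta_0 - \eta_{\min})\big(1 + \cos(\pi t / T_{\max})\big)$. The sketch below uses a dummy parameter and optimizer (not the ViT) to record the rate used in each of the 10 epochs and compare it with that closed form.

```python
import math
import torch

# Dummy parameter/optimizer just to drive the scheduler; no actual training happens
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=0.001)
epochs = 10
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0)

lrs = []
for epoch in range(epochs):
    lrs.append(optimizer.param_groups[0]["lr"])  # learning rate used during this epoch
    optimizer.step()      # one epoch of training would happen here
    scheduler.step()

# Closed-form cosine schedule for comparison
expected = [0.5 * 0.001 * (1 + math.cos(math.pi * t / epochs)) for t in range(epochs)]
for t, (lr, ref) in enumerate(zip(lrs, expected)):
    print(f"epoch {t + 1}: lr = {lr:.6f} (closed form {ref:.6f})")
```

Epoch 1 therefore trains at the full 0.001, epoch 6 at half of it, and the rate keeps shrinking smoothly toward zero by the end of training.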


# Inference and Plotting the Learning Curves

Let's save the trained model, load it into a new model instance, and see how it performs on the first 10 samples of the test dataset:

In [None]:
# Save our model parameters
os.makedirs('model', exist_ok=True)

torch.save(model.state_dict(), "model/VIT_Model.pth")
print("Saved PyTorch Model State to model/VIT_Model.pth")

# Load the saved model parameters into a new instance of the model
model = ViT(img_size=train_images.shape[2],
            in_channels=train_images.shape[1],
            num_classes = num_classes,
            patch_size=patch_size,
            hidden_size=128,
            num_layers=8,
            num_heads=8).to(device)
model.load_state_dict(torch.load("model/VIT_Model.pth", map_location=device))

model.eval()
with torch.no_grad():  # no gradients are needed for inference
    for i in range(10):
        x, y = test_dataset[i][0], test_dataset[i][1]
        x = x.to(device)

        pred = model(x.unsqueeze(0))  # add a batch dimension: [1, C, H, W]

        # The predicted class is the index of the largest of the 10 logits
        predicted, actual = pred[0].argmax(0).item(), y
        print(f'Predicted: "{predicted}", Actual: "{actual}"')


Let's plot the learning curves:

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

# Plot the learning curve for loss
ax1.plot(range(1, epochs + 1), train_losses, label='Training Loss')
ax1.plot(range(1, epochs + 1), val_losses, label='Validation Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Loss Curve')
ax1.legend()

# Plot the learning curve for accuracy
ax2.plot(range(1, epochs + 1), train_acc, label='Training Accuracy')
ax2.plot(range(1, epochs + 1), val_acc, label='Validation Accuracy')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('Accuracy Curve')
ax2.legend()

# Adjust layout
plt.tight_layout()
plt.show()

# Positional Embeddings

Positional embeddings in ViT are typically initialized randomly, with no clear pattern or relationship at first. As the model trains, these embeddings evolve to encode meaningful positional information.

In this part of the assignment, you'll visualize how this evolves by examining the cosine similarity between positional embeddings before and after training. Here's what you need to do:



*   Extract positional embeddings from the ViT model before and after training.
*   Compute the cosine similarity between each pair of embeddings to measure how similar each position's embedding is to those of all other positions.
*   For each patch, reshape its row of similarities to the patch grid (7×7 for 28×28 images with patch_size=4) and visualize it as a heatmap.
*   Use **torch.nn.functional.cosine_similarity** for the calculations.
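
A minimal sketch of the pairwise computation, using a random tensor as a stand-in for model.pos_embed (shape [1, 49, 128] in this notebook): broadcasting a [49, 1, 128] view against a [1, 49, 128] view lets a single call to F.cosine_similarity fill the whole 49×49 matrix.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Stand-in for model.pos_embed; in the notebook it has shape [1, num_patches, hidden_size]
pos_embed = torch.randn(1, 49, 128)

emb = pos_embed[0]                                  # [49, 128]
# Broadcast [49, 1, 128] against [1, 49, 128] and compare along the embedding dimension
sim = F.cosine_similarity(emb.unsqueeze(1), emb.unsqueeze(0), dim=-1)  # [49, 49]

side = int(emb.shape[0] ** 0.5)                     # 7 patches per side
# Row i reshaped to the patch grid: similarity of patch i's embedding to every position
heatmaps = sim.reshape(emb.shape[0], side, side)    # [49, 7, 7]
print(sim.shape, heatmaps.shape)
```

With the real model, you would call .detach().cpu() on the parameter first; each heatmaps[i] can then be drawn in its own subplot with plt.imshow.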







In [None]:
'''

# Your code here

'''

# Attention Map

Attention in Transformers allows the model to focus on different parts of the input by assigning weights that reflect how relevant each patch is to the others. In Vision Transformers (ViTs), attention helps capture spatial relationships by highlighting the areas of an image that matter most for the prediction. In deeper layers, attention maps often become more focused, indicating which patches the model relies on to make its decision. Visualizing these maps gives insight into how the model interprets the image and which regions it focuses on during inference.

In this part of the assignment, you'll visualize how the model distributes its attention across patches. For this task, you need to:

*   Modify the TransformerBlock and ViT model to return attention weights during the forward pass (and adjust the train and test functions accordingly).
*   Average the attention weights across heads.
*   Exclude the class token from the attention map before visualization.
*   Use **torch.nn.functional.interpolate** to resize the attention map to the image size for better interpretability.
*   Use matplotlib or any other library of your choice to visualize the attention maps.
*   Visualize the attention maps for test_dataset[50] and test_dataset[160].
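
One common way to turn a head-averaged attention matrix into an image-sized heatmap is sketched below. It assumes a [50, 50] matrix for one image (class token at index 0, followed by the 49 patch tokens) and visualizes the class token's attention to the patches; a random softmax matrix stands in for real weights, and choosing the class-token row is a design choice, not something the assignment prescribes.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_patches, img_size = 49, 28
side = int(num_patches ** 0.5)                      # 7x7 patch grid

# Stand-in for last-layer attention weights averaged over heads: [tokens, tokens]
attn = torch.softmax(torch.randn(num_patches + 1, num_patches + 1), dim=-1)

# Class-token query (row 0), keys restricted to the patch tokens (drop column 0)
cls_to_patches = attn[0, 1:]                        # [49]
attn_grid = cls_to_patches.reshape(1, 1, side, side)   # [N, C, H, W] layout for interpolate

# Upsample the 7x7 map to the 28x28 image resolution
attn_img = F.interpolate(attn_grid, size=(img_size, img_size),
                         mode="bilinear", align_corners=False)[0, 0]   # [28, 28]
print(attn_img.shape)
```

The resulting map can be overlaid on the digit, for example with plt.imshow(image, cmap='gray') followed by plt.imshow(attn_img, alpha=0.5).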


In [None]:
'''

# Your code here

'''

# Acknowledgment

This assignment is based on:

*   A course provided by [IBM](https://cognitiveclass.ai/courses/vision-transformers-for-image-classification-hands-on)
