# M2177.003100 Deep Learning Assignment #1<br> Part 1-3. Training Vision Transformers (PyTorch)

Copyright (C) Data Science & AI Laboratory, Seoul National University. This material is for educational uses only. Some contents are based on the material provided by other paper/book authors and may be copyrighted by them. Written by Jaehoon Lee, September 2023

**To understand this assignment, please read the provided PDF file carefully.**

Now, you're going to leave behind your own implementations and migrate to one of the most popular deep learning frameworks, **PyTorch**. <br>
In this notebook, you will learn to understand and build the basic components of the Vision Transformer (ViT). Then, you will classify images in the FashionMNIST dataset and explore the effects of different ViT components.
<br>
There are **2 sections**. In each section, follow the instructions to complete the skeleton code and explain it.

**Note**: Certain details are missing or ambiguous on purpose, in order to test your knowledge of the related materials. However, if you feel that something essential is missing and you cannot proceed to the next step, contact the teaching staff with a clear description of your problem.

### Submitting your work:
<font color=red>**DO NOT clear the final outputs**</font> so that TAs can grade both your code and results. 

### Some helpful tutorials and references for assignment #1-2:
- [1] PyTorch official documentation. [[link]](https://pytorch.org/docs/stable/index.html)
- [2] Stanford CS231n lectures. [[link]](http://cs231n.stanford.edu/)
- [3] Alexey Dosovitskiy et al., "An Image is Worth 16 x 16 Words: Transformers for Image Recognition at Scale", ICLR 2021. [[pdf]](https://arxiv.org/pdf/2010.11929.pdf)

## 1. Building Vision Transformer
Here, you will build the basic components of the Vision Transformer (ViT). <br>

![Vision Transformer](imgs/ViT.png)

Using the explanation and code provided as guidance, <br>
define each component of ViT. <br>


#### ViT architecture:
The ViT model consists of the following components:
* Patch embedding
* Positional embeddings
* Transformer encoder with
    * Attention module
    * MLP module

In [None]:
import torch
import torch.nn as nn

##### Patch Embed

**Initialization**: When you create an instance of the PatchEmbed class, you specify img_size, patch_size, in_chans, and embed_dim. img_size is the height and width of the (square) input image, patch_size is the side length of each patch, in_chans is the number of input image channels (e.g., 3 for RGB images), and embed_dim is the dimension of each patch embedding.

**Convolutional Projection**: Inside the PatchEmbed class, a 2D convolutional layer (nn.Conv2d) performs a patch-based projection. This layer has a kernel size of patch_size, which defines the size of each patch, and a stride of patch_size, which ensures that patches do not overlap. The convolutional layer therefore extracts each image patch and projects it to embed_dim channels in a single step.

**Reshaping**: After the convolutional projection, the output tensor is reshaped from a 4D tensor with dimensions (batch_size, embed_dim, H / patch_size, W / patch_size) to a 3D tensor with dimensions (batch_size, num_patches, embed_dim): the two spatial dimensions are flattened into one and moved ahead of the channel dimension. num_patches is the total number of non-overlapping patches in the image, and embed_dim is the number of output channels of the convolutional layer.
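
The shape flow described above can be checked on a toy tensor. The following is a minimal sketch, not the reference solution: the sizes (a 28x28 single-channel image, patch_size 4, embed_dim 16) and the variable names are assumptions chosen for illustration, and `flatten` followed by `transpose` is just one possible way to do the reshaping.

```python
import torch
import torch.nn as nn

# Toy sizes (assumptions for illustration): a FashionMNIST-like 28x28 grayscale input.
batch_size, in_chans, img_size = 2, 1, 28
patch_size, embed_dim = 4, 16

x = torch.randn(batch_size, in_chans, img_size, img_size)

# kernel_size == stride == patch_size, so each output position covers exactly one patch.
proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

feat = proj(x)                     # (batch_size, embed_dim, 7, 7)
tokens = feat.flatten(2)           # (batch_size, embed_dim, 49): merge the two spatial dims
tokens = tokens.transpose(1, 2)    # (batch_size, 49, embed_dim): one row per patch

num_patches = (img_size // patch_size) ** 2
print(feat.shape, tokens.shape, num_patches)
```

Note that the channel dimension has to end up last, so a plain reshape of the 4D tensor directly to (batch_size, num_patches, embed_dim) without the transpose would mix values from different patches.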

In [None]:
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        ##############################################################################
        #                           IMPLEMENT YOUR CODE                              #
        ##############################################################################


        ##############################################################################
        #                              END YOUR CODE                                 #
        ##############################################################################
    def forward(self, x):
        ##############################################################################
        #                           IMPLEMENT YOUR CODE                              #
        ##############################################################################


        ##############################################################################
        #                              END YOUR CODE                                 #
        ##############################################################################
        return x # output dimension must be: (batch size, number of patches, embed_dim)

##### Attention

**Initialization**
* dim: The input dimension of the sequence. This is the dimensionality of the queries, keys, and values.
* num_heads: The number of attention heads to use. Multi-head attention allows the model to focus on different parts of the input simultaneously.

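**Linear Projections (qkv and proj)**: The qkv linear layer takes the input sequence and projects it into three parts: queries (q), keys (k), and values (v). The output of this layer has a shape of (batch_size, sequence_length, 3 * dim). The proj linear layer maps the combined output of all heads, of shape (batch_size, sequence_length, dim), to the final output of the same shape.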

**Forward Pass (forward method)**: In the forward pass, the input tensor x is processed through the attention mechanism. Here's what happens:<br>
* The linear projection qkv is applied to x, producing a tensor of shape (batch_size, sequence_length, 3 * dim).
* This tensor is reshaped to have dimensions (batch_size, sequence_length, 3, num_heads, head_dim). The permute operation rearranges the dimensions to (3, batch_size, num_heads, sequence_length, head_dim), making it suitable for multi-head attention.
* The three parts, q, k, and v, are extracted from the reshaped tensor.
* The attention scores are computed by taking the dot product of queries q and keys k. The result is scaled by self.scale.
* The attention scores are passed through a softmax activation along the last dimension (sequence_length), producing attention weights.
* The weighted sum of values v is computed using the attention weights.
* The result is transposed and reshaped to its original shape, and then passed through the proj linear layer.
* The final output is returned.
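
The steps above can be traced on a small random tensor. This is a minimal sketch under assumed toy sizes (B=2, N=5, dim=8, num_heads=2); the layer `qkv_layer` and the other variable names are illustrative stand-ins, and the final proj layer is omitted.

```python
import torch
import torch.nn as nn

# Toy sizes (assumptions for illustration).
B, N, dim, num_heads = 2, 5, 8, 2
head_dim = dim // num_heads
scale = head_dim ** -0.5

x = torch.randn(B, N, dim)
qkv_layer = nn.Linear(dim, dim * 3)   # stand-in for self.qkv

qkv = qkv_layer(x)                                 # (B, N, 3 * dim)
qkv = qkv.reshape(B, N, 3, num_heads, head_dim)    # split into q/k/v and heads
qkv = qkv.permute(2, 0, 3, 1, 4)                   # (3, B, num_heads, N, head_dim)
q, k, v = qkv[0], qkv[1], qkv[2]                   # each (B, num_heads, N, head_dim)

scores = (q @ k.transpose(-2, -1)) * scale         # (B, num_heads, N, N)
attn = scores.softmax(dim=-1)                      # each row sums to 1 over the keys
out = (attn @ v).transpose(1, 2).reshape(B, N, dim)  # merge heads back to (B, N, dim)

print(attn.shape, out.shape)
```

Scaling by `head_dim ** -0.5` keeps the dot products from growing with the head dimension, which would otherwise push the softmax toward near one-hot weights.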

In [None]:
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        ##############################################################################
        #                           IMPLEMENT YOUR CODE                              #
        ##############################################################################

        ##############################################################################
        #                              END YOUR CODE                                 #
        ##############################################################################

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        
        return x # output dimension must be: (batch size, number of patches, embed_dim)

##### MLP

The MLP module must consist of three layers:
* fully connected layer 1
* activation layer
* fully connected layer 2

In [None]:
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        ##############################################################################
        #                           IMPLEMENT YOUR CODE                              #
        ##############################################################################

        ##############################################################################
        #                              END YOUR CODE                                 #
        ##############################################################################

    def forward(self, x):
        ##############################################################################
        #                           IMPLEMENT YOUR CODE                              #
        ##############################################################################

        ##############################################################################
        #                              END YOUR CODE                                 #
        ##############################################################################
        return x # output dimension must be: (batch size, number of patches, out_features)

##### Transformer Block
The transformer block contains the attention module and the MLP module, each wrapped in a residual connection. 
Refer to the following image and build the forward pass.

![Transformer Block](imgs/TransformerBlock.png)
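
A residual connection adds a sublayer's input back to its output. The sketch below shows the pattern in isolation, assuming the pre-norm arrangement used in the ViT paper [3] (normalize, apply the sublayer, then add the input). The `sublayer` here is a hypothetical stand-in (a single linear layer), not the attention or MLP module of this assignment, and the figure above remains the authority on the exact ordering.

```python
import torch
import torch.nn as nn

dim = 8
norm = nn.LayerNorm(dim)
sublayer = nn.Linear(dim, dim)   # hypothetical stand-in for an attention or MLP module

x = torch.randn(2, 5, dim)       # (batch size, number of tokens, dim)

# Pre-norm residual (assumption): the sublayer sees the normalized input,
# and the unnormalized input is added back to its output.
y = x + sublayer(norm(x))

print(y.shape)
```

Because the sublayer output is added to `x`, the input and output of the sublayer must have the same shape.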

In [None]:
class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer)

    def forward(self, x):
        ##############################################################################
        #                           IMPLEMENT YOUR CODE                              #
        ##############################################################################


        ##############################################################################
        #                              END YOUR CODE                                 #
        ##############################################################################
        return x


##### Vision Transformer

Using all the components that you built above, **complete** the vision transformer class.
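
Before completing the class, it can help to check how the class token and positional embedding interact with the patch tokens. The following is a minimal sketch on a random stand-in for the patch embedding output. The sizes are toy values, and giving the positional embedding `num_patches + 1` positions (one extra for the class token) is an assumption that follows the ViT paper [3].

```python
import torch
import torch.nn as nn

# Toy sizes (assumptions for illustration).
B, num_patches, embed_dim = 2, 49, 16
patch_tokens = torch.randn(B, num_patches, embed_dim)  # stand-in for the PatchEmbed output

cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# Assumption: one position per patch plus one for the class token.
pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

cls = cls_token.expand(B, -1, -1)               # (B, 1, embed_dim), same token for every sample
tokens = torch.cat((cls, patch_tokens), dim=1)  # (B, num_patches + 1, embed_dim)
tokens = tokens + pos_embed                     # broadcast over the batch dimension

cls_out = tokens[:, 0]                          # (B, embed_dim): the class-token position
print(tokens.shape, cls_out.shape)
```

In the full model, the class-token slice would be taken after the encoder blocks and the final normalization, then passed to the classifier head.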

In [None]:
class VisionTransformer(nn.Module):
    """ Vision Transformer """

    def __init__(self, img_size=28, patch_size=4, in_chans=1, num_classes=10, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., norm_layer=nn.LayerNorm, ):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.depth = depth

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        ##############################################################################
        #                           IMPLEMENT YOUR CODE                              #
        ############################################################################## 
        # similarly to cls_token, define a learnable positional embedding that matches the patchified input token size.

        ##############################################################################
        #                              END YOUR CODE                                 #
        ##############################################################################

        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,  norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.head = nn.Linear(
            embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward(self, x):
        ##############################################################################
        #                           IMPLEMENT YOUR CODE                              #
        ##############################################################################
        B = x.shape[0]
        
        # Patch Embedding
        x = self.patch_embed(x)
        
        # Concatenate class tokens to patch embedding

        # Add positional embedding to patches

        # Forward through encoder blocks

        # Use class token for classification
        
        # Classifier head

        ##############################################################################
        #                              END YOUR CODE                                 #
        ##############################################################################
        return x

## 2. Training a small ViT model on FashionMNIST dataset.

Define and train a Vision Transformer on the FashionMNIST dataset. **(You must reach a test accuracy above 85% for full points.)** <br>
Train with at least 5 different hyperparameter settings, varying the following ViT hyperparameters. 
Report the setting that gives the best performance.

#### ViT hyperparameters:
* patch_size
* embed_dim
* depth
* num_heads
* mlp_ratio
* etc.
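
Not every combination of these hyperparameters is valid. The helper below is a hypothetical convenience (`check_vit_config` is not part of the assignment code) that checks the two divisibility constraints implied by the components above and reports the resulting sequence length. It assumes one class token is prepended to the patch tokens.

```python
def check_vit_config(img_size, patch_size, embed_dim, num_heads):
    """Validate a ViT setting and return the sizes it implies (illustrative helper)."""
    if img_size % patch_size != 0:
        raise ValueError("img_size must be divisible by patch_size")
    if embed_dim % num_heads != 0:
        raise ValueError("embed_dim must be divisible by num_heads")
    num_patches = (img_size // patch_size) ** 2
    return {
        "num_patches": num_patches,
        "seq_len": num_patches + 1,          # assumption: one class token is prepended
        "head_dim": embed_dim // num_heads,
    }


# FashionMNIST images are 28x28, so patch_size must divide 28.
print(check_vit_config(img_size=28, patch_size=4, embed_dim=64, num_heads=4))
```

Smaller patches give longer sequences, and the cost of attention grows quadratically with the sequence length, so patch_size has a large effect on training time.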


In [None]:
import numpy as np

from tqdm import tqdm, trange

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader

from torchvision.transforms import ToTensor
from torchvision.datasets.mnist import FashionMNIST

In [None]:
def Train():
    ##############################################################################
    #                           IMPLEMENT YOUR CODE                              #
    ##############################################################################

    patch_size=
    embed_dim=
    depth=
    num_heads= # make sure embed_dim is divisible by num_heads!
    mlp_ratio=
    
    ##############################################################################
    #                              END YOUR CODE                                 #
    ##############################################################################
    
    # Loading data
    transform = ToTensor()

    train_set = FashionMNIST(root='./data', train=True, download=True, transform=transform)
    test_set = FashionMNIST(root='./data', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_set, shuffle=True, batch_size=128)
    test_loader = DataLoader(test_set, shuffle=False, batch_size=128)

    # Defining model and training options
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")
    
    model = VisionTransformer(patch_size=patch_size, embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio).to(device)
    model_path = './vit.pth'
    N_EPOCHS = 5
    LR = 0.005

    # Training loop
    optimizer = Adam(model.parameters(), lr=LR)
    criterion = CrossEntropyLoss()
    for epoch in trange(N_EPOCHS, desc="Training"):
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)

            train_loss += loss.detach().cpu().item() / len(train_loader)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}")

    # Test loop
    with torch.no_grad():
        correct, total = 0, 0
        test_loss = 0.0
        for batch in tqdm(test_loader, desc="Testing"):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)
            test_loss += loss.detach().cpu().item() / len(test_loader)

            correct += torch.sum(torch.argmax(y_hat, dim=1) == y).detach().cpu().item()
            total += len(x)
        print(f"Test loss: {test_loss:.2f}")
        print(f"Test accuracy: {correct / total * 100:.2f}%")

    torch.save(model.state_dict(), model_path)
    print('Saved Trained Model.')
    
Train()

### Describe what you did and discovered here
In this cell, list every setting you tried and the performance you obtained. Report what you did and what you discovered from the trials.
You may write in Korean.

_Tell us here_