## Vision Transformer (ViT)

In this assignment we're going to work with Vision Transformer. We will start to build our own vit model and train it on an image classification task.
The purpose of this homework is for you to get familar with ViT and get prepared for the final project.

In [1]:
import math

import torch

import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

from tqdm import tqdm

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# VIT Implementation

The vision transformer can be seperated into three parts, we will implement each part and combine them in the end.

For the implementation, feel free to experiment different kinds of setup, as long as you use attention as the main computation unit and the ViT can be train to perform the image classification task present later.
You can read about the ViT implement from other libary: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py and https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py

## PatchEmbedding
PatchEmbedding is responsible for dividing the input image into non-overlapping patches and projecting them into a specified embedding dimension. It uses a 2D convolution layer with a kernel size and stride equal to the patch size. The output is a sequence of linear embeddings for each patch.

In [3]:
class PatchEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, in_channels, embed_dim):
        super(PatchEmbedding, self).__init__()
        self.image_size = image_size
        self.patch_size = patch_size//2 # to compensate for pooling
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        self.num_patches = (image_size // self.patch_size) ** 2
        self.patch_dim = in_channels * self.patch_size ** 2
        
        self.emb = nn.Linear(self.patch_dim, self.embed_dim)
        
        self.proj = nn.Sequential(
            nn.Conv2d(in_channels, embed_dim, kernel_size=self.patch_size, stride=self.patch_size),
            nn.ReLU(inplace=True),
            nn.Conv2d(embed_dim, embed_dim, kernel_size=1, stride=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

    def forward(self, x):
        out = self.proj(x)
        out = out.flatten(2).transpose(1, 2)
        
        return out

## MultiHeadSelfAttention

This class implements the multi-head self-attention mechanism, which is a key component of the transformer architecture. It consists of multiple attention heads that independently compute scaled dot-product attention on the input embeddings. This allows the model to capture different aspects of the input at different positions. The attention outputs are concatenated and linearly transformed back to the original embedding size.

In [4]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        self.q = nn.Linear(embed_dim, embed_dim, bias = False)
        self.k = nn.Linear(embed_dim, embed_dim, bias = False)
        self.v = nn.Linear(embed_dim, embed_dim, bias = False)

        self.out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        b, n, f = x.size()
        q = self.q(x).view(b, n, self.num_heads, self.embed_dim//self.num_heads).transpose(1,2)
        k = self.k(x).view(b, n, self.num_heads, self.embed_dim//self.num_heads).transpose(1,2)
        v = self.v(x).view(b, n, self.num_heads, self.embed_dim//self.num_heads).transpose(1,2)

        attn = F.softmax(torch.einsum("bhif, bhjf->bhij", q, k)/self.embed_dim**0.5, dim=-1)
        out = torch.einsum("bhij, bhjf->bihf", attn, v)
        out = self.out(out.flatten(2))
        return out

## TransformerBlock
This class represents a single transformer layer. It includes a multi-head self-attention sublayer followed by a position-wise feed-forward network (MLP). Each sublayer is surrounded by residual connections.
You may also want to use layer normalization or other type of normalization.

In [5]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_dim, dropout):
        super(TransformerBlock, self).__init__()
        self.la1 = nn.LayerNorm(embed_dim,eps=1e-12)
        self.attention = MultiHeadSelfAttention(embed_dim, num_heads)
        self.la2 = nn.LayerNorm(embed_dim,eps=1e-12)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, embed_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        out = self.attention(self.la1(x)) + x
        out = self.mlp(self.la2(out)) + out
        return out

## VisionTransformer:
This is the main class that assembles the entire Vision Transformer architecture. It starts with the PatchEmbedding layer to create patch embeddings from the input image. A special class token is added to the sequence, and positional embeddings are added to both the patch and class tokens. The sequence of patch embeddings is then passed through multiple TransformerBlock layers. The final output is the logits for all classes

In [6]:
class VisionTransformer(nn.Module):
    def __init__(self, image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout=0.1):
        super(VisionTransformer, self).__init__()
        num_patches = (image_size // patch_size)**2

        self.patch_embedding = PatchEmbedding(image_size, patch_size, in_channels, embed_dim)

        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_emb = nn.Parameter(torch.randn(1, num_patches+1, embed_dim))
        enc_list = [TransformerBlock(embed_dim, num_heads, mlp_dim, dropout) for _ in range(num_layers)]
        self.enc = nn.Sequential(*enc_list)
        self.fc = nn.Sequential(
            nn.LayerNorm(embed_dim, eps=1e-12),
            nn.Linear(embed_dim, num_classes)
        )

    def forward(self, x):
        out = self.patch_embedding(x)
        out = torch.cat([self.cls_token.repeat(out.size(0),1,1), out],dim=1)
        out = out + self.pos_emb
        out = self.enc(out)
        out = out[:,0]
        out = self.fc(out)
        return out

## Let's train the ViT!

We will train the vit to do the image classification with cifar100. Free free to change the optimizer and or add other tricks to improve the training

In [7]:
image_size = 128
patch_size = 16
in_channels = 3
embed_dim = 192
num_heads = 8
mlp_dim = 1024
num_layers = 8
num_classes = 100
dropout = 0.1

batch_size = 256

In [8]:
model = VisionTransformer(image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout)

Using Sharpness-Aware Minimization (SAM) Optimizer for improving generalization (Code: https://github.com/davda54/sam)

In [9]:
class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None: continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.data = self.state[p]["old_p"]  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        norm = torch.norm(
                    torch.stack([
                        ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
                        for group in self.param_groups for p in group["params"]
                        if p.grad is not None
                    ]),
                    p=2
               )
        return norm

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups


In [10]:
criterion = nn.CrossEntropyLoss()
base_optimizer = torch.optim.SGD
optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9, adaptive=True)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer.base_optimizer, T_max=50)

In [11]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.Resize(image_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train,);
testset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test,);

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

Files already downloaded and verified
Files already downloaded and verified


In [12]:
# Uncomment if running with Multiple GPUs
# model = nn.DataParallel(model, device_ids=[0, 1]) 

model.to(device)

num_epochs = 30
best_val_acc = 0
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for i, data in tqdm(enumerate(trainloader, 0), desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch", total=len(trainloader), leave=False):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        
        optimizer.first_step(zero_grad=True)
        
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.second_step(zero_grad=True)
        
        running_loss += loss.item() * inputs.size(0)

        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(trainloader)
    epoch_accuracy = 100 * correct / total

    scheduler.step()
    
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    val_acc = 100 * correct / total
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Training Accuracy: {epoch_accuracy:.2f}%, Validation Accuracy: {val_acc:.2f}%")
    
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_model.pth")

                                                                

Epoch 1/30, Loss: 1082.6579, Training Accuracy: 5.42%, Validation Accuracy: 8.43%


                                                                

Epoch 2/30, Loss: 967.8433, Training Accuracy: 10.90%, Validation Accuracy: 13.56%


                                                                

Epoch 3/30, Loss: 875.8873, Training Accuracy: 16.66%, Validation Accuracy: 19.29%


                                                                

Epoch 4/30, Loss: 796.5486, Training Accuracy: 22.56%, Validation Accuracy: 24.72%


                                                                

Epoch 5/30, Loss: 734.8067, Training Accuracy: 27.24%, Validation Accuracy: 30.97%


                                                                

Epoch 6/30, Loss: 689.4392, Training Accuracy: 30.87%, Validation Accuracy: 33.43%


                                                                

Epoch 7/30, Loss: 648.8958, Training Accuracy: 34.22%, Validation Accuracy: 36.99%


                                                                

Epoch 8/30, Loss: 616.4746, Training Accuracy: 36.79%, Validation Accuracy: 38.67%


                                                                

Epoch 9/30, Loss: 586.6314, Training Accuracy: 39.21%, Validation Accuracy: 40.38%


                                                                 

Epoch 10/30, Loss: 560.8058, Training Accuracy: 41.22%, Validation Accuracy: 42.76%


                                                                 

Epoch 11/30, Loss: 538.8508, Training Accuracy: 43.25%, Validation Accuracy: 43.12%


                                                                 

Epoch 12/30, Loss: 514.0950, Training Accuracy: 45.61%, Validation Accuracy: 44.58%


                                                                 

Epoch 13/30, Loss: 491.2745, Training Accuracy: 47.63%, Validation Accuracy: 46.82%


                                                                 

Epoch 14/30, Loss: 472.5086, Training Accuracy: 49.17%, Validation Accuracy: 47.29%


                                                                 

Epoch 15/30, Loss: 451.6610, Training Accuracy: 51.10%, Validation Accuracy: 47.67%


                                                                 

Epoch 16/30, Loss: 425.8101, Training Accuracy: 53.59%, Validation Accuracy: 51.19%


                                                                 

Epoch 17/30, Loss: 404.7301, Training Accuracy: 55.37%, Validation Accuracy: 50.38%


                                                                 

Epoch 18/30, Loss: 386.2508, Training Accuracy: 57.11%, Validation Accuracy: 52.36%


                                                                 

Epoch 19/30, Loss: 362.7247, Training Accuracy: 59.42%, Validation Accuracy: 52.67%


                                                                 

Epoch 20/30, Loss: 344.4425, Training Accuracy: 61.17%, Validation Accuracy: 53.31%


                                                                 

Epoch 21/30, Loss: 323.3215, Training Accuracy: 63.15%, Validation Accuracy: 54.28%


                                                                 

Epoch 22/30, Loss: 302.9631, Training Accuracy: 65.24%, Validation Accuracy: 54.51%


                                                                 

Epoch 23/30, Loss: 281.8075, Training Accuracy: 67.51%, Validation Accuracy: 55.01%


                                                                 

Epoch 24/30, Loss: 261.9907, Training Accuracy: 69.39%, Validation Accuracy: 54.87%


                                                                 

Epoch 25/30, Loss: 245.3479, Training Accuracy: 70.93%, Validation Accuracy: 55.28%


                                                                 

Epoch 26/30, Loss: 223.9455, Training Accuracy: 72.91%, Validation Accuracy: 55.92%


                                                                 

Epoch 27/30, Loss: 202.3032, Training Accuracy: 75.50%, Validation Accuracy: 56.49%


                                                                 

Epoch 28/30, Loss: 182.4770, Training Accuracy: 77.67%, Validation Accuracy: 57.26%


                                                                 

Epoch 29/30, Loss: 164.4295, Training Accuracy: 79.62%, Validation Accuracy: 57.11%


                                                                 

Epoch 30/30, Loss: 144.8760, Training Accuracy: 81.91%, Validation Accuracy: 56.92%


Please submit your best_model.pth with this notebook. And report the best test results you get.

As we can observe, the model is already overfitting. With more epochs, it will only get worse at generalization. Vanilla VITs isn't very competitive on CIFAR-100. We need to pass the images through a good feature extractor to have some on inductive bias on locality. The best result here is close to 57%.