## Vision Transformer (ViT)

We're going to work with the Vision Transformer (ViT). We will build our own ViT model from scratch and train it on an image classification task.

In [39]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

import numpy as np
import random
from torch.optim.lr_scheduler import _LRScheduler

In [40]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


# VIT Implementation

The Vision Transformer can be separated into three parts. We will implement each part and combine them at the end.

Reference ViT implementations from other libraries: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py and https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py

## PatchEmbedding
PatchEmbedding is responsible for dividing the input image into non-overlapping patches and projecting them into a specified embedding dimension. It uses a 2D convolution layer with a kernel size and stride equal to the patch size. The output is a sequence of linear embeddings for each patch.

In [41]:
class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping patches and embed each patch.

    A convolution whose kernel size and stride both equal ``patch_size``
    maps every patch to a vector of length ``embed_dim``.
    """

    def __init__(self, image_size, patch_size, in_channels, embed_dim):
        super().__init__()

        assert image_size % patch_size == 0, 'Image size must be divisible by patch size'

        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Map images (B, C, H, W) to patch tokens (B, num_patches, embed_dim)."""
        # Unpacking four values rejects anything that is not a 4-D batch.
        _, _, _, _ = x.shape

        feature_map = self.proj(x)               # (B, embed_dim, H', W')
        token_major = feature_map.flatten(2)     # (B, embed_dim, H'*W')
        return token_major.transpose(1, 2)       # (B, H'*W', embed_dim)

## MultiHeadSelfAttention

This class implements the multi-head self-attention mechanism, which is a key component of the transformer architecture. It consists of multiple attention heads that independently compute scaled dot-product attention on the input embeddings. This allows the model to capture different aspects of the input at different positions. The attention outputs are concatenated and linearly transformed back to the original embedding size.

In [42]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    One fused linear layer produces queries, keys and values for all heads.
    Each head attends independently; the head outputs are concatenated and
    projected back to ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads, dropout):
        super().__init__()

        assert embed_dim % num_heads == 0, 'Embedding dimension must be divisible by the number of heads'

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Creation order is kept so parameter initialisation draws the same
        # random numbers as before.
        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
        self.attention_dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, x):
        """Apply self-attention to x of shape (B, seq_length, embed_dim)."""
        n_batch, n_tokens, _ = x.size()

        # Fused projection, laid out per head as [q | k | v] on the last axis.
        fused = self.qkv_proj(x).reshape(n_batch, n_tokens, self.num_heads, 3 * self.head_dim)
        fused = fused.transpose(1, 2)  # (B, num_heads, seq_length, 3 * head_dim)
        query, key, value = fused.split(self.head_dim, dim=-1)

        # Scaled dot-product attention.
        scores = (query @ key.transpose(-2, -1)) / math.sqrt(self.head_dim)
        weights = self.attention_dropout(scores.softmax(dim=-1))
        context = weights @ value  # (B, num_heads, seq_length, head_dim)

        # Put the heads back side by side and mix them with the output projection.
        merged = context.transpose(1, 2).contiguous().view(n_batch, n_tokens, self.embed_dim)
        return self.proj_drop(self.proj(merged))
        

## TransformerBlock
This class represents a single transformer layer. It includes a multi-head self-attention sublayer followed by a position-wise feed-forward network (MLP). Each sublayer is surrounded by residual connections.
We may also use layer normalization or other types of normalization.

In [43]:
class TransformerBlock(nn.Module):
    """One pre-norm transformer encoder layer.

    LayerNorm feeds multi-head self-attention, then a second LayerNorm feeds
    a two-layer GELU MLP. Each sub-layer is added back to its input through
    a residual connection.
    """

    def __init__(self, embed_dim, num_heads, mlp_dim, dropout):
        super().__init__()

        self.norm1 = nn.LayerNorm(embed_dim)
        self.attention = MultiHeadSelfAttention(embed_dim, num_heads, dropout)

        self.norm2 = nn.LayerNorm(embed_dim)
        feed_forward = [
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, embed_dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x):
        """Apply the attention and MLP sub-layers with residual connections."""
        attended = self.attention(self.norm1(x))
        x = x + attended
        transformed = self.mlp(self.norm2(x))
        return x + transformed

## VisionTransformer
This is the main class that assembles the entire Vision Transformer architecture. It starts with the PatchEmbedding layer to create patch embeddings from the input image. A special class token is prepended to the sequence, and positional embeddings are added to both the patch and class tokens. The sequence is then passed through multiple TransformerBlock layers. Finally, the class token's output is layer-normalized and mapped by a linear layer to the logits for all classes.

In [44]:
class VisionTransformer(nn.Module):
    """Vision Transformer image classifier.

    The image is turned into patch tokens, a learnable class token is
    prepended, learnable positional embeddings are added, and the sequence
    runs through ``num_layers`` transformer blocks. The class token's final
    state is layer-normalized and mapped to ``num_classes`` logits.
    """

    def __init__(self, image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout=0.1):
        super().__init__()

        self.patch_embed = PatchEmbedding(image_size, patch_size, in_channels, embed_dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))

        # One positional vector per patch, plus one for the class token.
        num_patches = (image_size // patch_size) ** 2
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))

        blocks = []
        for _ in range(num_layers):
            blocks.append(TransformerBlock(embed_dim, num_heads, mlp_dim, dropout))
        self.transformer_layers = nn.ModuleList(blocks)

        self.norm = nn.LayerNorm(embed_dim)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape (B, num_classes) for a batch of images."""
        patch_tokens = self.patch_embed(x)

        # Broadcast the single learnable class token over the batch and put
        # it in front of the patch tokens.
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        tokens = torch.cat((cls_tokens, patch_tokens), dim=1)

        tokens = tokens + self.pos_embed

        for block in self.transformer_layers:
            tokens = block(tokens)

        # Only the class token (position 0) is used for classification.
        return self.fc(self.norm(tokens[:, 0]))

## Let's train the ViT!

We will train the ViT to do image classification on CIFAR-100. Feel free to change the optimizer and/or add other tricks to improve the training.

In [45]:
#@title Augmentation
def cutmix_data(x, y, alpha=1.0):
    """Prepare a CutMix augmentation for a batch without modifying it.

    A random box is drawn and the matching region is gathered from a
    shuffled copy of the batch. The caller pastes ``x_sliced`` into ``x`` at
    the returned box.

    Args:
        x: image batch; dimensions 2 and 3 are the spatial axes that get cut.
        y: labels, indexable along the batch dimension.
        alpha: Beta(alpha, alpha) parameter for the mixing ratio. A value
            <= 0 disables mixing (``lam`` is 1 and the box is empty).

    Returns:
        ``([bbx1, bby1, bbx2, bby2], y_a, y_b, lam, x_sliced)`` where ``y_a``
        are the original labels, ``y_b`` the labels of the pasted images, and
        ``lam`` the fraction of each image that is NOT replaced.
    """
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1.0

    batch_size = x.size(0)

    # Draw the permutation on the CPU (same RNG stream as before), then move
    # it to wherever the batch lives. The previous hard-coded .cuda() failed
    # on CPU-only machines.
    index = torch.randperm(batch_size).to(x.device)

    y_a, y_b = y, y[index]

    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    x_sliced = x[index, :, bbx1:bbx2, bby1:bby2]

    # The box is clipped at the image border, so recompute lam from the
    # actual box area rather than trusting the sampled value.
    box_area = int(bbx2 - bbx1) * int(bby2 - bby1)
    lam = 1.0 - box_area / (x.size(-1) * x.size(-2))

    return [bbx1, bby1, bbx2, bby2], y_a, y_b, lam, x_sliced

def rand_bbox(size, lam):
    """Sample a box covering roughly ``1 - lam`` of the spatial area.

    ``size`` is a 4-element shape; the box lives on ``size[2]`` (x bounds)
    and ``size[3]`` (y bounds). The centre is uniform over the image and the
    box is clipped at the borders, so the realised area can be smaller.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)`` with ``bbx*`` bounded by ``size[2]`` and
        ``bby*`` bounded by ``size[3]``.
    """
    extent_x, extent_y = size[2], size[3]

    # Side ratio sqrt(1 - lam) gives an area ratio of (1 - lam).
    side_ratio = np.sqrt(1. - lam)
    half_x = np.int64(extent_x * side_ratio) // 2
    half_y = np.int64(extent_y * side_ratio) // 2

    # Uniformly sampled box centre (x first, then y, to keep the RNG order).
    centre_x = np.random.randint(extent_x)
    centre_y = np.random.randint(extent_y)

    bbx1 = np.clip(centre_x - half_x, 0, extent_x)
    bby1 = np.clip(centre_y - half_y, 0, extent_y)
    bbx2 = np.clip(centre_x + half_x, 0, extent_x)
    bby2 = np.clip(centre_y + half_y, 0, extent_y)

    return bbx1, bby1, bbx2, bby2

def mixup_data(x, y, alpha=1.0):
    """Blend every sample in a batch with a randomly chosen partner (Mixup).

    Args:
        x: input batch; the first dimension is the batch dimension.
        y: labels, indexable along the batch dimension.
        alpha: Beta(alpha, alpha) parameter for the mixing ratio. A value
            <= 0 disables mixing (``lam`` is 1 and ``mixed_x`` equals ``x``).

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where
        ``mixed_x = lam * x + (1 - lam) * x[perm]``, ``y_a`` are the original
        labels and ``y_b`` the labels of the partners.
    """
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1.0

    batch_size = x.size(0)

    # Draw the permutation on the CPU (same RNG stream as before), then move
    # it to wherever the batch lives. The previous hard-coded .cuda() failed
    # on CPU-only machines.
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Combine the loss against both label sets, weighted by ``lam``.

    Returns ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

class RandomErasing(object):
    """Random Erasing augmentation for a single (C, H, W) image tensor.

    With probability ``probability`` a random rectangle of the image is
    overwritten, in place, with per-channel fill values.

    Args:
        probability: chance that an image is erased at all.
        sl: minimum erased area as a fraction of the image area.
        sh: maximum erased area as a fraction of the image area.
        r1: minimum aspect ratio of the rectangle; the maximum is ``1 / r1``.
        mean: fill value per channel. A 3-channel image uses entries 0-2;
            any other image has only channel 0 filled, with ``mean[0]``.
    """

    def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)):
        self.EPSILON = probability
        self.mean = mean
        self.sl = sl
        self.sh = sh
        self.r1 = r1

    def __call__(self, img):
        """Return ``img``, possibly with one rectangle erased in place."""
        if random.uniform(0, 1) > self.EPSILON:
            return img

        channels, height, width = img.size()[0], img.size()[1], img.size()[2]
        area = height * width

        # A sampled rectangle may not fit inside the image; retry a bounded
        # number of times and give up (image unchanged) if none fits.
        for _ in range(100):
            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, 1 / self.r1)

            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))

            if w < width and h < height:
                x1 = random.randint(0, height - h)
                y1 = random.randint(0, width - w)
                if channels == 3:
                    for channel in range(3):
                        img[channel, x1:x1 + h, y1:y1 + w] = self.mean[channel]
                else:
                    # Channel 0 takes the channel-0 fill value. This used to
                    # read mean[1], which picked the wrong channel and raised
                    # IndexError for a one-element mean.
                    img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
                return img

        return img

In [46]:
#@title Cosine Scheduler
class CosineAnnealingWarmupRestarts(_LRScheduler):    
    """Cosine-annealing learning-rate schedule with linear warmup and restarts.

    Within a cycle the learning rate rises linearly from ``min_lr`` to the
    cycle's ``max_lr`` over ``warmup_steps`` steps, then follows a half
    cosine back down to ``min_lr``. When a cycle ends a new one starts with
    its peak multiplied by ``gamma``; the cycle length is rescaled using
    ``cycle_mult``.

    Args:
        optimizer: optimizer whose param-group learning rates are driven.
        first_cycle_steps: number of steps in the first cycle.
        cycle_mult: growth factor applied to the cycle length at each restart.
        max_lr: peak learning rate of the first cycle.
        min_lr: learning rate at the start and end of every cycle.
        warmup_steps: linear warmup steps at the start of each cycle; must
            be smaller than ``first_cycle_steps``.
        gamma: factor applied to the peak learning rate after each cycle.
        last_epoch: starting step index; -1 starts from scratch.
    """
    def __init__(self,
                 optimizer : torch.optim.Optimizer,
                 first_cycle_steps : int,
                 cycle_mult : float = 1.,
                 max_lr : float = 0.1,
                 min_lr : float = 0.001,
                 warmup_steps : int = 0,
                 gamma : float = 1.,
                 last_epoch : int = -1
        ):
        assert warmup_steps < first_cycle_steps
        
        self.first_cycle_steps = first_cycle_steps # first cycle step size
        self.cycle_mult = cycle_mult # cycle steps magnification
        self.base_max_lr = max_lr # first max learning rate
        self.max_lr = max_lr # max learning rate in the current cycle
        self.min_lr = min_lr # min learning rate
        self.warmup_steps = warmup_steps # warmup step size
        self.gamma = gamma # decrease rate of max learning rate by cycle
        
        self.cur_cycle_steps = first_cycle_steps # first cycle step size
        self.cycle = 0 # cycle count
        self.step_in_cycle = last_epoch # step size of the current cycle
        
        # All attributes read by step()/get_lr() are assigned before the base
        # constructor runs.
        # NOTE(review): the base class presumably calls step() during
        # construction, which would rely on that ordering - confirm against
        # the installed PyTorch version.
        super(CosineAnnealingWarmupRestarts, self).__init__(optimizer, last_epoch)
        
        # set learning rate min_lr
        self.init_lr()
    
    def init_lr(self):
        """Reset every param group's lr, and ``base_lrs``, to ``min_lr``.

        This overwrites whatever learning rate the optimizer was built with.
        """
        self.base_lrs = []
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.min_lr
            self.base_lrs.append(self.min_lr)
    
    def get_lr(self):
        """Return one learning rate per param group for the current step."""
        if self.step_in_cycle == -1:
            # No step taken yet.
            return self.base_lrs
        elif self.step_in_cycle < self.warmup_steps:
            # Linear warmup from base_lr towards max_lr.
            return [(self.max_lr - base_lr)*self.step_in_cycle / self.warmup_steps + base_lr for base_lr in self.base_lrs]
        else:
            # Half-cosine decay: equals max_lr when step_in_cycle == warmup_steps
            # and approaches base_lr as step_in_cycle approaches cur_cycle_steps.
            return [base_lr + (self.max_lr - base_lr) \
                    * (1 + math.cos(math.pi * (self.step_in_cycle-self.warmup_steps) \
                                    / (self.cur_cycle_steps - self.warmup_steps))) / 2
                    for base_lr in self.base_lrs]

    def step(self, epoch=None):
        """Advance the schedule and write the new rates into the optimizer.

        Args:
            epoch: when ``None`` the schedule advances by one step; otherwise
                the cycle state is recomputed for that absolute step index.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
            self.step_in_cycle = self.step_in_cycle + 1
            if self.step_in_cycle >= self.cur_cycle_steps:
                # Cycle finished: start the next one. Only the post-warmup
                # part of the cycle length is scaled by cycle_mult.
                self.cycle += 1
                self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps
                self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps
        else:
            if epoch >= self.first_cycle_steps:
                if self.cycle_mult == 1.:
                    # All cycles have equal length.
                    self.step_in_cycle = epoch % self.first_cycle_steps
                    self.cycle = epoch // self.first_cycle_steps
                else:
                    # Cycle lengths form a geometric series; n is the number
                    # of completed cycles at this epoch.
                    # NOTE(review): here the whole cycle length is scaled by
                    # cycle_mult**n, whereas the epoch-is-None branch excludes
                    # warmup_steps from the scaling - confirm this is intended.
                    n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult))
                    self.cycle = n
                    self.step_in_cycle = epoch - int(self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1))
                    self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** (n)
            else:
                # Still inside the first cycle.
                self.cur_cycle_steps = self.first_cycle_steps
                self.step_in_cycle = epoch
                
        # The peak learning rate is scaled by gamma once per completed cycle.
        self.max_lr = self.base_max_lr * (self.gamma**self.cycle)
        self.last_epoch = math.floor(epoch)
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

In [47]:
#@title Initialization
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place from N(mean, std^2) truncated to [a, b].

    Inverse-CDF sampling: draw uniformly between the CDF values of the two
    bounds, then map back through the inverse normal CDF. Method based on
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    All tensor operations run under ``torch.no_grad()``.
    """
    sqrt_two = math.sqrt(2.)

    def _standard_normal_cdf(value):
        # Cumulative distribution function of the standard normal.
        return (1. + math.erf(value / sqrt_two)) / 2.

    cdf_low = _standard_normal_cdf((a - mean) / std)
    cdf_high = _standard_normal_cdf((b - mean) / std)

    with torch.no_grad():
        # Uniform on [2*cdf_low - 1, 2*cdf_high - 1], the domain erfinv expects.
        tensor.uniform_(2 * cdf_low - 1, 2 * cdf_high - 1)

        # erfinv turns the uniform draw into a truncated standard normal
        # (up to a factor of sqrt(2), applied together with std below).
        tensor.erfinv_()

        # Rescale and shift to the requested std and mean.
        tensor.mul_(std * sqrt_two)
        tensor.add_(mean)

        # Guard against values that land just outside the bounds.
        tensor.clamp_(min=a, max=b)

    return tensor


def trunc_normal_(tensor: torch.Tensor, mean: float = 0., std: float = 1.,
                  a: float = -2., b: float = 2.) -> torch.Tensor:
    """Fill ``tensor`` in place with truncated-normal values and return it.

    Values follow N(mean, std^2) restricted to the interval [a, b].
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)

In [48]:
#@title Sampler
class RASampler(torch.utils.data.Sampler):
    """Batch sampler that repeats each index and oversamples the dataset.

    One pass yields ``int(dataset_len * len_factor)`` indices in total,
    grouped into lists of ``batch_size``. Indices come from an endless
    stream over the dataset (reshuffled on each wrap-around when ``shuffle``
    is set), and each drawn index is emitted ``repetitions`` times in a row.
    """

    def __init__(self, dataset_len, batch_size, repetitions=1, len_factor=3.0, shuffle=False, drop_last=False):
        self.dataset_len = dataset_len
        self.batch_size = batch_size
        self.repetitions = repetitions
        self.len_images = int(dataset_len * len_factor)
        self.shuffle = shuffle
        self.drop_last = drop_last

    def shuffler(self):
        """Yield dataset indices forever, each repeated ``repetitions`` times."""
        make_order = np.random.permutation if self.shuffle else np.arange

        order = iter(make_order(self.dataset_len))
        while True:
            try:
                index = next(order)
            except StopIteration:
                # Dataset exhausted: begin a fresh pass.
                order = iter(make_order(self.dataset_len))
                index = next(order)
            for _ in range(self.repetitions):
                yield index

    def __iter__(self):
        """Yield index batches until ``len_images`` indices have been drawn."""
        stream = iter(self.shuffler())
        pending = []
        for _ in range(self.len_images):
            pending.append(next(stream))
            if len(pending) == self.batch_size:
                yield pending
                pending = []
        # Emit the trailing partial batch unless it is being dropped.
        if pending and not self.drop_last:
            yield pending

    def __len__(self):
        """Number of batches produced by one pass."""
        if self.drop_last:
            return self.len_images // self.batch_size
        return (self.len_images + self.batch_size - 1) // self.batch_size

In [49]:
# Example usage:
# Model and data hyperparameters for the CIFAR-100 run below.
image_size = 32      # input images are 32x32 (the training crop size)
patch_size = 4       # (32 // 4) ** 2 = 64 patches per image
in_channels = 3      # RGB
embed_dim = 192      # token width
num_heads = 12       # 192 // 12 = 16 dimensions per head
mlp_dim = 384        # hidden width of each block's MLP
num_layers = 8       # number of transformer blocks
num_classes = 100    # CIFAR-100
dropout = 0          # passed to every nn.Dropout in the model

batch_size = 128

In [50]:
# Build the model and push one random image through it as a shape check:
# the output should have shape (1, num_classes).
model = VisionTransformer(image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout).to(device)
input_tensor = torch.randn(1, in_channels, image_size, image_size).to(device)
output = model(input_tensor)
print(output.shape)

torch.Size([1, 100])


In [51]:
# Load the CIFAR-100 dataset
# Training augmentation: flip, padded random crop, normalisation, then
# RandomErasing. RandomErasing runs after Normalize, so it writes its fill
# values into the already-normalised tensor.
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5070, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)),
    RandomErasing(probability=0.25, sh=0.4, r1=0.3, mean=(0.5070, 0.4865, 0.4409))
])

# Evaluation uses no augmentation, only the same normalisation as training.
transform_test = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5070, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)),
])

trainset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
testset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)

# RASampler is called with repetitions=1 and len_factor=3, so one pass over
# trainloader draws 3 * len(trainset) samples (three shuffled passes over the
# data) and drops the last partial batch.
trainloader = torch.utils.data.DataLoader(
    trainset, num_workers=2, pin_memory=True,
    batch_sampler=RASampler(len(trainset), batch_size, 1, 3, shuffle=True, drop_last=True))
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [52]:
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# The lr given here is only a placeholder: the scheduler's init_lr() resets
# every param group to min_lr when it is constructed, and from then on the
# scheduler sets the learning rate.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.05)   
num_epochs = 100
# The scheduler is stepped once per batch in the training loop, so its
# lengths are counted in batches: len(trainloader) batches per epoch.
num_steps = int(num_epochs * len(trainloader))
warmup_steps = int(10 * len(trainloader))  # warm up over the first 10 epochs
# A single cycle spanning the whole run (first_cycle_steps == num_steps).
scheduler = CosineAnnealingWarmupRestarts(
    optimizer,
    first_cycle_steps = num_steps,
    cycle_mult = 1.,
    max_lr = 0.001,
    min_lr = 1e-6,
    warmup_steps=warmup_steps
)

In [None]:
# Train the model
best_val_acc = 0
for epoch in range(num_epochs):
    # ---- Training pass ----
    model.train()
    for images, labels in trainloader:
        images = images.to(device)
        labels = labels.to(device)

        # Half of the batches are left as they are; the other half gets
        # CutMix or Mixup with equal probability.
        if np.random.rand() < 0.5:
            if np.random.rand() < 0.5:
                # CutMix: paste a box taken from shuffled images into the batch.
                (x1, y1, x2, y2), y_a, y_b, lam, patch = cutmix_data(images, labels)
                images[:, :, x1:x2, y1:y2] = patch
            else:
                # Mixup: blend each image with a shuffled partner.
                images, y_a, y_b, lam = mixup_data(images, labels)

            outputs = model(images)
            loss = mixup_criterion(criterion, outputs, y_a, y_b, lam)
        else:
            outputs = model(images)
            loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # The scheduler advances once per batch, not once per epoch.
        scheduler.step()

    # ---- Validation pass ----
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_acc = 100 * correct / total
    print(f"Epoch: {epoch + 1}, Validation Accuracy: {val_acc:.2f}%")

    # Keep the weights with the best validation accuracy seen so far.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_model.pth")