#Vision Transformer

![vit_figure.png]( https://raw.githubusercontent.com/google-research/vision_transformer/main/vit_figure.png)

##Imports

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader


import torch.optim as optim
import math

from tqdm import tqdm


##Model

In [4]:
def initialize_weights(x):
  nn.init.xavier_uniform_(x.weight)
  if x.bias is not None:
    nn.init.constant_(x.bias,0)

In [5]:
class Attention(nn.Module):
  def __init__(self,embed_dim,head_dim,dropout_rate):
    super(Attention,self).__init__()

    self.query=nn.Linear(embed_dim,head_dim)
    self.key=nn.Linear(embed_dim,head_dim)
    self.value=nn.Linear(embed_dim,head_dim)
    self.dropout=nn.Dropout(dropout_rate)

    initialize_weights(self.query)
    initialize_weights(self.key)
    initialize_weights(self.value)


  def forward(self,query,key,value,mask=None):
    d_k=query.size(-1)

    q=self.query(query)
    k=self.key(key)
    v=self.value(value)

    scores=q @ k.transpose(1,2) /math.sqrt(d_k)
    if mask is not None:
      # mask = mask.unsqueeze(1)
      scores = scores.masked_fill(mask == 0, float('-inf'))

    weights=F.softmax(scores,dim=-1)
    weights=self.dropout(weights)
    out=weights @ v
    return out

In [6]:
class MultiheadAttention(nn.Module):
  def __init__(self,embed_dim,head_size,dropout_rate):
    super(MultiheadAttention,self).__init__()

    # Validate that embed_dim is divisible by head_size
    assert embed_dim % head_size == 0, "embed_dim must be divisible by head_size"

    self.head_dim=embed_dim//head_size
    self.embed_dim=embed_dim
    self.head_size=head_size

    self.attn_heads=nn.ModuleList([Attention(embed_dim,self.head_dim,dropout_rate) for _ in range(head_size)])
    self.out_layer=nn.Linear(self.head_dim*head_size,embed_dim)
    self.dropout=nn.Dropout(dropout_rate)

    initialize_weights(self.out_layer)


  def forward(self,query,key,value,mask=None):

    out=torch.cat([h(query,key,value,mask) for h in self.attn_heads],dim=-1)
    out=self.dropout(self.out_layer(out))
    return out

In [7]:
class ViTMLP(nn.Module):
  def __init__(self,hidden_dim,filter_size,dropout=0.5):
    super(ViTMLP,self).__init__()
    self.linear1=nn.Linear(hidden_dim,filter_size)
    self.gelu=nn.GELU()
    self.dropout1=nn.Dropout(dropout)
    self.linear2=nn.Linear(filter_size,hidden_dim)
    self.dropout2=nn.Dropout(dropout)

    initialize_weights(self.linear1)
    initialize_weights(self.linear2)

  def forward(self,x):
    x=self.linear1(x)
    x=self.gelu(x)
    x=self.dropout1(x)
    x=self.linear2(x)
    x=self.dropout2(x)
    return x

In [8]:
class PatchEmbedding(nn.Module):
  def __init__(self,img_size=96,patch_size=16,hidden_dim=512):
    super(PatchEmbedding,self).__init__()

    self.num_patches = (img_size // patch_size) ** 2
    self.conv=nn.LazyConv2d(hidden_dim,kernel_size=patch_size,stride=patch_size)

  def forward(self,x):
    return self.conv(x).flatten(2).transpose(1,2)

In [9]:
class VitBlock(nn.Module):
  def __init__(self,hidden_dim,norm_shape,filter_size,num_heads,dropout):
    super(VitBlock,self).__init__()
    self.attn_norm=nn.LayerNorm(norm_shape)
    self.attn=MultiheadAttention(embed_dim=hidden_dim,head_size=num_heads,dropout_rate=dropout)

    self.mlp_norm=nn.LayerNorm(norm_shape)
    self.mlp=ViTMLP(hidden_dim=hidden_dim,filter_size=filter_size,dropout=dropout)


  def forward(self,x,valid_lens=None):
    y=self.attn_norm(x)
    y=self.attn(y,y,y,valid_lens)
    x=x+y

    y=self.mlp_norm(x)
    y=self.mlp(y)
    x=x+y

    return x

In [10]:
x=torch.ones((2,100,24))
encoder_blk=VitBlock(24,24,48,8,0.5)
encoder_blk.eval()
y=encoder_blk(x)

y.shape,x.shape

(torch.Size([2, 100, 24]), torch.Size([2, 100, 24]))

In [11]:
class ViT(nn.Module):
  def __init__(self,img_size,patch_size,hidden_dim,filter_size,num_heads,n_layers,dropout_rate,lr=0.1,num_classes=10):
    super(ViT,self).__init__()
    self.patch_embedding=PatchEmbedding(img_size=img_size,patch_size=patch_size,hidden_dim=hidden_dim)

    self.cls_token=nn.Parameter(torch.zeros(1,1,hidden_dim))
    num_steps=self.patch_embedding.num_patches+1
    self.pos_embedding=nn.Parameter(torch.randn(1,num_steps,hidden_dim))

    self.dropout=nn.Dropout(dropout_rate)
    self.layers=nn.ModuleList([VitBlock(hidden_dim=hidden_dim,norm_shape=hidden_dim,filter_size=filter_size,num_heads=num_heads,dropout=dropout_rate) for _ in range(n_layers)])

    self.out=nn.Sequential(nn.LayerNorm(hidden_dim),nn.Linear(hidden_dim,num_classes))


  def forward(self,x):
    x=self.patch_embedding(x)
    x=torch.cat((self.cls_token.expand(x.shape[0],-1,-1),x),1)
    x=self.dropout(x+self.pos_embedding)

    for layer in self.layers:
      x=layer(x)

    return self.out(x[:,0])

## Dataset

In [12]:
# Data Transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to ViT input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])


In [13]:
# Load Train & Test Datasets
train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)

# Data Loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:12<00:00, 13.2MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [14]:
# Model Initialization
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ViT(img_size=224, patch_size=16, hidden_dim=768, filter_size=2048, num_heads=8, n_layers=6, dropout_rate=0.1, num_classes=10).to(device)


In [15]:
criterion = nn.CrossEntropyLoss()  # Loss function
optimizer = optim.Adam(model.parameters(), lr=1e-4)  # Optimizer

In [16]:
def test(model, test_loader, criterion, device):
    model.eval()  # Set model to evaluation mode
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():  # Disable gradient computation
        for images, labels in tqdm(test_loader,leave=False):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)  # Forward pass
            loss = criterion(outputs, labels)  # Compute loss

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    avg_loss = running_loss / len(test_loader)
    accuracy = 100 * correct / total
    return avg_loss, accuracy

In [17]:
def train(model, train_loader,test_loader, criterion, optimizer, device,num_epochs=10):

    for epoch in range(num_epochs):
      model.train()  # Set model to training mode
      running_loss = 0.0
      correct = 0
      total = 0
      for images, labels in tqdm(train_loader):
          images, labels = images.to(device), labels.to(device)  # Move to GPU/CPU

          optimizer.zero_grad()  # Reset gradients
          outputs = model(images)  # Forward pass
          loss = criterion(outputs, labels)  # Compute loss
          loss.backward()  # Backpropagation
          optimizer.step()  # Update weights

          running_loss += loss.item()
          _, predicted = torch.max(outputs, 1)  # Get predictions
          correct += (predicted == labels).sum().item()
          total += labels.size(0)

      train_avg_loss = running_loss / len(train_loader)
      train_accuracy = 100 * correct / total
      test_loss,test_acc=test(model, test_loader, criterion, device)
      print(f"Training: Loss({train_avg_loss:4f}), Acc({train_accuracy:4f})| Test: Loss({test_loss:4f}), Acc({test_acc:4f})")


In [18]:
train(model,train_loader=train_loader,test_loader=test_loader,criterion=criterion,device=device,optimizer=optimizer)

100%|██████████| 782/782 [03:53<00:00,  3.35it/s]


Training: Loss(1.758894), Acc(35.598000)| Test: Loss(1.492640), Acc(46.160000)


100%|██████████| 782/782 [03:52<00:00,  3.37it/s]


Training: Loss(1.425696), Acc(48.292000)| Test: Loss(1.342738), Acc(51.590000)


100%|██████████| 782/782 [03:51<00:00,  3.38it/s]


Training: Loss(1.264460), Acc(54.356000)| Test: Loss(1.232427), Acc(55.240000)


100%|██████████| 782/782 [03:51<00:00,  3.38it/s]


Training: Loss(1.140659), Acc(59.172000)| Test: Loss(1.147892), Acc(58.460000)


100%|██████████| 782/782 [03:52<00:00,  3.37it/s]


Training: Loss(1.037167), Acc(62.876000)| Test: Loss(1.108667), Acc(61.160000)


100%|██████████| 782/782 [03:52<00:00,  3.37it/s]


Training: Loss(0.958200), Acc(65.630000)| Test: Loss(1.100553), Acc(61.660000)


100%|██████████| 782/782 [03:52<00:00,  3.37it/s]


Training: Loss(0.873590), Acc(68.850000)| Test: Loss(1.065220), Acc(62.610000)


100%|██████████| 782/782 [03:51<00:00,  3.38it/s]


Training: Loss(0.800464), Acc(71.236000)| Test: Loss(1.096791), Acc(62.570000)


100%|██████████| 782/782 [03:51<00:00,  3.37it/s]


Training: Loss(0.722993), Acc(74.176000)| Test: Loss(1.077523), Acc(64.250000)


100%|██████████| 782/782 [03:51<00:00,  3.38it/s]
                                                 

Training: Loss(0.644595), Acc(76.566000)| Test: Loss(1.111263), Acc(64.350000)




In [20]:
def save_model(model, optimizer, epoch, file_path):
    """
    Save the model and optimizer state dictionaries along with the epoch.
    """
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, file_path)
    print(f"Model saved at epoch {epoch} to {file_path}")

def load_model(model, optimizer, file_path, device):
    """
    Load the model and optimizer state dictionaries from a checkpoint.

    Returns:
        epoch (int): The epoch number stored in the checkpoint.
    """
    checkpoint = torch.load(file_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    model.to(device)
    print(f"Model loaded from {file_path}, starting at epoch {epoch}")
    return epoch

In [21]:
save_model(model, optimizer, 10, "/content/drive/MyDrive/dataset/vit_model_v1.pth")

Model saved at epoch 10 to /content/drive/MyDrive/dataset/vit_model_v1.pth


In [None]:
train(model,train_loader=train_loader,test_loader=test_loader,criterion=criterion,device=device,optimizer=optimizer)

 29%|██▉       | 225/782 [01:06<02:48,  3.31it/s]