In [76]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import random
import math

In [77]:


# Preprocessing shared by every split: resize to the 224x224 resolution the
# patch embedding below assumes, convert to a float tensor, then normalize
# each RGB channel with commonly used CIFAR-10 mean/std statistics.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2023, 0.1994, 0.2010],
        ),
    ]
)


In [78]:
# Official CIFAR-10 training split, downloaded to ./data on first use and
# preprocessed with `transform`; it is split into train/val further below.
cifar10 = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

In [79]:
# Sanity check: number of images in the training split (displayed as cell output).
len(cifar10)

50000

In [80]:
from torch.utils.data import random_split

# Hold out 20% of the official training split for validation; any remainder
# from flooring the 80% share goes to the validation subset.
train_size = int(len(cifar10) * 0.8)
val_size = len(cifar10) - train_size

In [81]:
# Split the training set into train/val subsets with a seeded generator so the
# partition is identical across kernel restarts. Without it, random_split uses
# the global RNG and every Run-All yields a different split, which makes
# validation metrics (and saved "best" checkpoints) non-comparable between runs.
cifar10train, cifar10val = random_split(
    cifar10,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

In [82]:
# Official CIFAR-10 test split, preprocessed with the same `transform` as training.
cifar10test = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

In [83]:
# Shape of one preprocessed test image: (channels, 224, 224) after Resize/ToTensor.
cifar10test[0][0].shape

torch.Size([3, 224, 224])

In [84]:
# Mini-batches of 128. Only the training loader reshuffles each epoch; the
# validation and test loaders iterate in a fixed order.
train_loader = DataLoader(cifar10train, shuffle=True, batch_size=128)
val_loader = DataLoader(cifar10val, shuffle=False, batch_size=128)
test_loader = DataLoader(cifar10test, shuffle=False, batch_size=128)

In [85]:
# Global model hyperparameters, taken from the ViT-Base row of Table 1 in the
# ViT paper: hidden size, MLP size and number of attention heads.
d_model, d_ff, n_heads = 768, 3072, 12

# Input geometry after the Resize transform: 224x224 images with 3 channels.
height, width, channels = 224, 224, 3

# Side length of each square image patch.
patch_size = 16

# Number of samples in the training subset (after the train/val split).
N = len(cifar10train)

In [86]:
# Dummy input for shape smoke tests. PyTorch's Conv2d (used by PatchEmbedding)
# expects channels-first batches, i.e. (batch, channels, height, width); the
# previous (height, width, channels) layout could not be fed to the model.
x = torch.randn(1, channels, height, width)

In [87]:
# Patch embeddings: image -> sequence of patch tokens.
class PatchEmbedding(nn.Module):
    """Split images into non-overlapping patches and project each one to ``d_model``.

    A Conv2d whose kernel size and stride both equal ``patch_size`` is the same
    as flattening every patch and applying one shared linear projection.

    Args:
        channels: number of input image channels.
        img_size: side length of the square input image; used only to
            compute ``num_patches``.
        patch_size: side length of each square patch.
        d_model: embedding dimension of every patch token.
    """

    def __init__(self, channels=3, img_size=224, patch_size=16, d_model=768):
        super().__init__()
        self.patch_size = patch_size
        # Tokens produced for an img_size x img_size input.
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(
            in_channels=channels,
            out_channels=d_model,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Map ``(B, C, H, W)`` images to ``(B, num_tokens, d_model)`` patch tokens."""
        # (B, d_model, H/p, W/p) -> (B, d_model, tokens) -> (B, tokens, d_model)
        feature_map = self.proj(x)
        return feature_map.flatten(start_dim=2).transpose(1, 2)
        

In [88]:
class InputEmbed(nn.Module):
    """Prepend a learnable [CLS] token, add learnable position embeddings, apply dropout.

    Args:
        num_patches: number of patch tokens per image. The position table has
            ``num_patches + 1`` rows so that it also covers the [CLS] token.
        emb_dim: token embedding dimension.
        dropout: dropout probability applied after adding position embeddings
            (default 0.1, the previously hard-coded value).
    """

    def __init__(self, num_patches, emb_dim, dropout=0.1):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, emb_dim))
        self.dropout = nn.Dropout(dropout)

        # Initialization order (pos_embed, then cls_token) is kept so seeded
        # runs reproduce the same parameters as before.
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        """Map ``(B, tokens, emb_dim)`` patch tokens to ``(B, tokens + 1, emb_dim)``."""
        # Unpacking also enforces that x is rank 3; only the batch size is needed.
        batch_size, _, _ = x.shape

        # expand() broadcasts the single [CLS] parameter across the batch (no copy).
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Slice the table to the actual sequence length so shorter inputs still work.
        x = x + self.pos_embed[:, :x.size(1), :]
        return self.dropout(x)


In [89]:
class ScaledDotProdAttn(nn.Module):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``.

    Parameter-free; wrapped as a module so it composes like the other layers.
    """

    def __init__(self):
        super().__init__()

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: queries, ``(..., L_q, d_k)``.
            k: keys, ``(..., L_k, d_k)``.
            v: values, ``(..., L_k, d_v)``.
            mask: optional tensor broadcastable to ``(..., L_q, L_k)``. Positions
                where ``mask == 0`` are excluded from attention. ``None``
                (the default) attends everywhere.

        Returns:
            ``(outputs, weights)`` with shapes ``(..., L_q, d_v)`` and
            ``(..., L_q, L_k)``.
        """
        d_k = q.size(-1)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            # Use the dtype's most negative finite value instead of -inf. A fully
            # masked row then gets uniform weights instead of NaNs (softmax over
            # all -inf is 0/0). Partially masked rows still give masked keys zero
            # weight, because exp() underflows.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        weights = torch.softmax(scores, dim=-1)
        outputs = torch.matmul(weights, v)

        return outputs, weights

In [90]:
class MultiHeadAttn(nn.Module):
    """Multi-head attention: project Q/K/V, split into heads, attend, merge, project.

    Args:
        d_model: token dimension; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        mask: accepted for backward compatibility but not used by the module;
            pass the mask to ``forward`` instead. Defaults to ``None``.
    """

    def __init__(self, d_model, num_heads, mask=None):
        super().__init__()

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads

        # Per-head query/key width and value width.
        self.d_k = d_model // num_heads
        self.d_v = d_model // num_heads

        self.w_q = nn.Linear(d_model, self.d_k * num_heads)
        self.w_k = nn.Linear(d_model, self.d_k * num_heads)
        self.w_v = nn.Linear(d_model, self.d_v * num_heads)

        self.attn = ScaledDotProdAttn()

        self.w_o = nn.Linear(self.d_v * num_heads, d_model)

    def forward(self, q, k, v, mask=None):
        """Apply multi-head attention.

        Args:
            q, k, v: inputs assumed to be ``(B, L, d_model)`` (the ``view`` calls
                below require a leading batch dimension).
            mask: optional attention mask forwarded to ``ScaledDotProdAttn``;
                ``None`` (default) means no masking.

        Returns:
            ``(outputs, weights)``: outputs ``(B, L_q, d_model)`` and per-head
            attention weights ``(B, num_heads, L_q, L_k)``.
        """
        batch_size = q.size(0)

        # (B, L, H*d) -> (B, L, H, d) -> (B, H, L, d)
        Q = self.w_q(q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.w_k(k).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.w_v(v).view(batch_size, -1, self.num_heads, self.d_v).transpose(1, 2)

        outputs, weights = self.attn(Q, K, V, mask)

        # (B, H, L_q, d_v) -> (B, L_q, H*d_v): merge heads to match w_o's input width.
        outputs = (
            outputs.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.num_heads * self.d_v)
        )
        outputs = self.w_o(outputs)

        return outputs, weights

In [91]:
class MLP(nn.Module):
    """Transformer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        d_model: input/output token dimension.
        d_ff: hidden (expansion) dimension.
        dropout: dropout probability (default 0.1, the previously hard-coded value).
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the feed-forward block position-wise; output has the same shape as ``x``."""
        x = self.linear1(x)
        # Apply the nonlinearity BEFORE dropout. Dropout rescales kept units by
        # 1/(1-p) in training; doing that before GELU (as this block previously
        # did) feeds GELU inflated inputs in train mode but not in eval mode,
        # so train and eval activations no longer match in expectation.
        x = F.gelu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        x = self.dropout(x)

        return x
        
    

In [92]:
class EncoderLayer(nn.Module):
    """One pre-norm Transformer encoder block.

    Computes ``x + MHA(LN1(x))`` followed by ``x + MLP(LN2(x))``. Attention is
    unmasked, so every token can attend to every other token.
    """

    def __init__(self, d_model, d_ff, num_heads):
        super().__init__()
        self.d_model, self.d_ff, self.num_heads = d_model, d_ff, num_heads

        # Sub-layer 1: self-attention. Construction order (LN1, MHA, LN2, MLP)
        # is preserved so seeded parameter initialization is unchanged.
        self.LayerNorm1 = nn.LayerNorm(d_model)
        self.MHA = MultiHeadAttn(d_model, num_heads, mask=None)

        # Sub-layer 2: position-wise feed-forward network.
        self.LayerNorm2 = nn.LayerNorm(d_model)
        self.MLP = MLP(d_model, d_ff)

    def forward(self, x):
        """Apply both residual sub-layers to a ``(B, tokens, d_model)`` sequence."""
        normed = self.LayerNorm1(x)
        x = x + self.MHA(normed, normed, normed, mask=None)[0]
        return x + self.MLP(self.LayerNorm2(x))
        
        

In [93]:
class Encoder(nn.Module):
    """A stack of ``num_layers`` independent EncoderLayer blocks applied in sequence.

    Args:
        d_model: token dimension.
        d_ff: hidden dimension of each layer's MLP.
        n_heads: attention heads per layer.
        num_layers: number of stacked layers.
    """

    def __init__(self, d_model, d_ff, n_heads, num_layers):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        self.n_heads = n_heads
        self.num_layers = num_layers

        blocks = [EncoderLayer(d_model, d_ff, n_heads) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Feed ``x`` through every layer in order and return the final hidden states."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden
        
        
    

In [94]:
# NOTE(review): no later cell in this notebook uses this standalone Encoder
# (Transformer builds its own), yet it allocates a full 12-layer parameter
# stack. Consider deleting this cell.
encoder_out = Encoder(d_model=d_model, d_ff=d_ff, n_heads=n_heads, num_layers=12)

In [95]:
class Transformer(nn.Module):
    """Vision Transformer (ViT) image classifier.

    Pipeline: patch embedding -> [CLS] + position embeddings -> encoder stack
    -> final LayerNorm -> linear head on the [CLS] token.

    The constructor reads the notebook-level ``d_model``, ``d_ff`` and ``n_heads``.
    Every argument is optional and defaults to the previously hard-coded
    value, so ``Transformer()`` still builds a 12-layer, patch-16 model for
    224x224 3-channel input and 10 classes.

    NOTE: the final LayerNorm adds ``norm.weight``/``norm.bias`` to the state
    dict. Checkpoints saved before this change will not load with ``strict=True``.

    Args:
        num_classes: number of output logits (10 for CIFAR-10).
        img_size: side length of the square input image.
        patch_size: side length of each square patch.
        channels: number of input image channels.
        num_layers: number of encoder layers.
    """

    def __init__(self, num_classes=10, img_size=224, patch_size=16, channels=3, num_layers=12):
        super().__init__()
        # Use the shared d_model (not a literal 768) so every component agrees.
        self.patch_embeds = PatchEmbedding(
            channels=channels, img_size=img_size, patch_size=patch_size, d_model=d_model
        )
        # Derive the token count from the patch embedding rather than hard-coding
        # 196, so changing img_size/patch_size keeps the position table in sync.
        self.input_embeds = InputEmbed(num_patches=self.patch_embeds.num_patches, emb_dim=d_model)
        self.encoder_out = Encoder(d_model, d_ff, n_heads, num_layers=num_layers)
        # The encoder layers are pre-norm, so their output residual stream is
        # un-normalized; ViT applies one last LayerNorm before the head.
        self.norm = nn.LayerNorm(d_model)

        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """Map an image batch (assumed ``(B, C, H, W)``) to logits ``(B, num_classes)``."""
        x = self.patch_embeds(x)
        x = self.input_embeds(x)
        x = self.encoder_out(x)
        x = self.norm(x)
        cls_token = x[:, 0]  # hidden state of the prepended [CLS] token
        return self.head(cls_token)

In [96]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [97]:
# Display the selected device. NOTE(review): the saved output below is CPU;
# training this 12-layer, 224x224 model on CPU is likely impractically slow.
device

device(type='cpu')

In [98]:
# Build the ViT classifier with its default configuration and move its
# parameters to the selected device.
model = Transformer().to(device)

In [99]:
# Epoch budget for training. The cosine schedule below is sized from it, and
# the training loop calls scheduler.step() once per epoch.
num_epochs = 30

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
# Anneal over the full run. With the previous T_max=10 over 30 epochs,
# CosineAnnealingLR keeps following the cosine past T_max: the LR would reach
# ~0 at epoch 10 and climb back to 1e-4 by epoch 20.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

In [102]:

# Let cuDNN benchmark and cache the fastest convolution algorithms. This is a
# good fit here because inputs are resized to a fixed 224x224. It has no
# effect when running on the CPU.
torch.backends.cudnn.benchmark = True  

In [103]:
# Training with per-epoch validation, best-checkpoint saving and early stopping.

# Number of epochs; the CosineAnnealingLR scheduler's T_max should equal this.
num_epochs = 30


def _run_epoch(loader, train):
    """Run one pass of ``model`` over ``loader``; return ``(mean_loss, accuracy_pct)``.

    With ``train=True`` the model is in train mode and parameters are updated
    (gradient norm clipped to 1.0). Otherwise the model is in eval mode and
    gradients are disabled. The loss is averaged per sample (each batch is
    weighted by its size), so a smaller final batch does not skew the mean.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train(train)
    loss_sum = 0.0
    correct = 0
    seen = 0

    with torch.set_grad_enabled(train):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            if train:
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            batch = labels.size(0)
            loss_sum += loss.item() * batch
            correct += outputs.argmax(dim=1).eq(labels).sum().item()
            seen += batch

    if seen == 0:
        raise ValueError("loader yielded no samples")
    return loss_sum / seen, 100. * correct / seen


best_val_acc = 0
patience = 5  # epochs without a val-accuracy improvement before stopping
patience_counter = 0

for epoch in range(num_epochs):
    train_loss, train_acc = _run_epoch(train_loader, train=True)
    val_loss, val_acc = _run_epoch(val_loader, train=False)

    # One scheduler step per epoch, so T_max is measured in epochs.
    scheduler.step()

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

    # Keep only the checkpoint with the best validation accuracy so far.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_acc': val_acc,
        }, 'best_model.pth')
        print(f"Saved new best model with validation accuracy: {val_acc:.2f}%")
        patience_counter = 0
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print("Early stopping triggered")
        break