In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import csv

## Image Embedding


1.   Patchify (using Conv2d)
2.   Prepend a CLS token to every image
3.   Add positional encodings






In [None]:
# class ImageEmbedding(nn.Module):
#   def __init__ (self, img_size : int, in_channels : int = 3, embed_dim : int = 768, patch_size : int = 16):
#     super().__init__()

#     num_patches = (img_size // patch_size) ** 2 # assuming sq. images and patches
#     self.patchify = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
#     self.cls_token = nn.Parameter(torch.zeros(1,1,embed_dim))
#     self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

#     nn.init.trunc_normal_(self.cls_token, std=0.02)
#     nn.init.trunc_normal_(self.pos_embedding, std=0.02)

#   def forward(self, x):
#     batch_size = x.size(0)

#     x = self.patchify(x) # (B x embed_dim x H/p x W/p)
#     x = x.flatten(2, 3)  # (B x embed_dim x N), N -> number of patches
#     x = x.transpose(1,2) # (B x N x embed_dim)

#     cls_token = self.cls_token.expand(batch_size, -1, -1)
#     x = torch.cat((cls_token, x), dim=1) # (B x N+1 x embed_dim)

#     x = x + self.pos_embedding
#     return x


class ImageEmbedding(nn.Module):
  """Turn a batch of images into a sequence of patch tokens.

  Patches are extracted with a strided convolution, a learnable CLS token is
  prepended to every sequence, and a learnable positional embedding is added.
  Square images and square patches are assumed.

  Args:
    img_size: Height/width of the (square) input images.
    in_channels: Number of channels of the input images.
    embed_dim: Dimensionality of every output token.
    patch_size: Height/width of each (square) patch, i.e. the conv kernel size.
    stride: Step between neighbouring patches. A value smaller than
      ``patch_size`` yields overlapping patches; ``stride == patch_size``
      yields a non-overlapping tiling. Defaults to 8, the value that used to
      be hard-coded.

  Raises:
    ValueError: If ``patch_size`` or ``stride`` is not positive, or if
      ``img_size`` is smaller than ``patch_size`` (no patch would fit).
  """

  def __init__(self, img_size: int, in_channels: int = 3, embed_dim: int = 768, patch_size: int = 16, stride: int = 8):
    super().__init__()

    if patch_size <= 0 or stride <= 0:
      raise ValueError(f"patch_size and stride must be positive, got patch_size={patch_size}, stride={stride}")
    if img_size < patch_size:
      raise ValueError(f"img_size ({img_size}) must be at least patch_size ({patch_size})")

    # Number of patch positions along one side for an unpadded strided conv.
    patches_per_side = (img_size - patch_size) // stride + 1
    num_patches = patches_per_side * patches_per_side

    # Patch extraction + linear projection in a single convolution.
    self.patchify = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=stride)
    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    # One positional vector per patch plus one for the CLS token.
    self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

    nn.init.trunc_normal_(self.cls_token, std=0.02)
    nn.init.trunc_normal_(self.pos_embedding, std=0.02)

  def forward(self, x):
    """Embed images of shape (B, in_channels, img_size, img_size).

    Returns:
      Tensor of shape (B, N + 1, embed_dim), where N is the number of patches
      and index 0 along dim 1 is the CLS token.
    """
    B = x.size(0)

    x = self.patchify(x)              # (B, embed_dim, H', W')
    x = x.flatten(2).transpose(1, 2)  # (B, N, embed_dim)

    # Share the single learnable CLS token across the batch without copying.
    cls_token = self.cls_token.expand(B, -1, -1)
    x = torch.cat((cls_token, x), dim=1)  # (B, N+1, embed_dim)

    x = x + self.pos_embedding
    return x

### Testing

In [3]:
# Smoke test: push a random batch of 16 RGB 224x224 images through the
# embedding layer and inspect the resulting token tensor shape.
ie = ImageEmbedding(img_size=224)
x = torch.rand((16, 3, 224, 224))

out = ie(x)  # (B, N+1, embed_dim)
print(out.shape)

torch.Size([16, 197, 768])


## Self-Attention

In [4]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product self-attention.

  Queries, keys and values are all projected from the same input sequence,
  split into ``num_heads`` heads of size ``embed_dim // num_heads``, attended
  independently, then concatenated and passed through an output projection.

  Args:
    embed_dim: Size of the token embeddings (input and output).
    num_heads: Number of attention heads; must divide ``embed_dim`` evenly.

  Raises:
    ValueError: If ``num_heads`` is not positive or does not divide
      ``embed_dim``.
  """

  def __init__(self, embed_dim : int = 768, num_heads : int = 12):
    super().__init__()
    # Fail at construction time instead of with an opaque reshape error in forward().
    if num_heads <= 0 or embed_dim % num_heads != 0:
      raise ValueError(
          f"embed_dim ({embed_dim}) must be divisible by a positive num_heads ({num_heads})"
      )

    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.head_dim = embed_dim // num_heads

    self.query_linear = nn.Linear(embed_dim, embed_dim, bias=False) # W_q
    self.key_linear = nn.Linear(embed_dim, embed_dim, bias=False)   # W_k
    self.value_linear = nn.Linear(embed_dim, embed_dim, bias=False) # W_v

    self.output_linear = nn.Linear(embed_dim, embed_dim)

  def split_heads(self, x):
    """Reshape (B, S, embed_dim) into (B, num_heads, S, head_dim)."""
    batch_size, seq_len, _ = x.size()
    x = x.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
    x = x.permute(0, 2, 1, 3)
    return x

  def compute_attention(self, query, key, value):
    """Return softmax(Q K^T / sqrt(head_dim)) V for per-head tensors."""
    scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
    attention_weights = F.softmax(scores, dim=-1)
    return torch.matmul(attention_weights, value)

  def combine_heads(self, x):
    """Merge (B, num_heads, S, head_dim) back into (B, S, embed_dim)."""
    batch_size = x.size(0)
    # permute() returns a non-contiguous view; view() needs contiguous memory.
    x = x.permute(0, 2, 1, 3).contiguous()
    return x.view(batch_size, -1, self.embed_dim)

  def forward(self, x):
    """Apply self-attention to ``x`` of shape (B, S, embed_dim); shape is preserved."""
    query = self.split_heads(self.query_linear(x))
    key = self.split_heads(self.key_linear(x))
    value = self.split_heads(self.value_linear(x))

    # Per-head attended values (not the attention weights themselves).
    attended = self.compute_attention(query, key, value)
    output = self.combine_heads(attended)
    return self.output_linear(output)

### Testing

In [5]:
# Smoke test: self-attention must preserve the (batch, seq_len, embed_dim) shape.
mha = MultiHeadAttention(num_heads=8, embed_dim=256)
x = torch.randn((2, 4, 256))

output = mha(x)

print("Input : ", x.shape)
print("Output : ", output.shape)



Input :  torch.Size([2, 4, 256])
Output :  torch.Size([2, 4, 256])


## Encoder Block

In [6]:
class EncoderBlock(nn.Module):
  """Pre-norm transformer encoder block: self-attention followed by an MLP.

  Each sub-layer is applied to a layer-normalised copy of its input and the
  result (after dropout) is added back through a residual connection:

      x = x + dropout(attention(norm1(x)))
      x = x + dropout(mlp(norm2(x)))

  Args:
    embed_dim: Size of the token embeddings.
    num_heads: Number of attention heads.
    mlp_dim: Hidden size of the two-layer feed-forward network.
    dropout: Dropout probability applied to each sub-layer output.
  """

  def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.2):
    super().__init__()
    self.multihead_attention = MultiHeadAttention(embed_dim, num_heads)
    self.norm1 = nn.LayerNorm(embed_dim)
    self.mlp = nn.Sequential(
        nn.Linear(embed_dim, mlp_dim),
        nn.GELU(),
        nn.Linear(mlp_dim, embed_dim)
    )
    self.norm2 = nn.LayerNorm(embed_dim)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    """Transform ``x`` of shape (B, S, embed_dim); the shape is preserved."""
    # norm1 feeds the attention sub-layer and norm2 feeds the MLP sub-layer.
    # Previously attention received the raw input, the MLP reused norm1 and
    # norm2 was never applied.
    x = x + self.dropout(self.multihead_attention(self.norm1(x)))
    x = x + self.dropout(self.mlp(self.norm2(x)))
    return x

## ViT Model

In [7]:
class ViT(nn.Module):
  """Vision Transformer classifier.

  Pipeline: image embedding -> ``depth`` encoder blocks -> LayerNorm ->
  linear head applied to the token at sequence index 0 (the CLS token).
  """

  def __init__(self, img_size, patch_size, in_channels, num_classes, embed_dim, num_heads, mlp_dim, depth=6, dropout=0.2):
    super().__init__()
    self.image_embeddings = ImageEmbedding(
        img_size=img_size,
        in_channels=in_channels,
        embed_dim=embed_dim,
        patch_size=patch_size,
    )

    # Stack `depth` identical encoder blocks.
    blocks = []
    for _ in range(depth):
      blocks.append(EncoderBlock(embed_dim, num_heads, mlp_dim, dropout))
    self.encoder = nn.ModuleList(blocks)

    self.norm = nn.LayerNorm(embed_dim)
    self.head = nn.Linear(embed_dim, num_classes)

  def forward(self, x):
    """Map an image batch to class logits of shape (B, num_classes)."""
    tokens = self.image_embeddings(x)
    for layer in self.encoder:
      tokens = layer(tokens)

    # Classify from the first token of the normalised sequence.
    cls_repr = self.norm(tokens)[:, 0]
    return self.head(cls_repr)

In [8]:
# Smoke test on a tiny configuration: 2 random 32x32 RGB images -> 10 logits each.
x = torch.randn((2, 3, 32, 32))
model = ViT(
    img_size=32,
    patch_size=8,
    in_channels=3,
    num_classes=10,
    embed_dim=64,
    num_heads=4,
    mlp_dim=128,
    depth=4,
)

logits = model(x)
print(logits.shape)  # (2, 10)

torch.Size([2, 10])


## Data Loading

In [9]:
# Preprocessing: upscale to 224, convert to tensors and normalise each channel
# with mean 0.5 / std 0.5. Only the training pipeline adds a random
# horizontal flip.
transform_train = transforms.Compose([
    transforms.Resize(size=224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
transform_test = transforms.Compose([
    transforms.Resize(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# CIFAR-10 splits, downloaded into ./data when not already present.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train
)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test
)

# Shuffle only the training data; evaluation order stays fixed.
trainloader = DataLoader(dataset=trainset, batch_size=64, shuffle=True, num_workers=2)
testloader = DataLoader(dataset=testset, batch_size=64, shuffle=False, num_workers=2)

100%|██████████| 170M/170M [00:13<00:00, 12.3MB/s]


## Training

In [None]:
# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ViT configured for 224x224 RGB inputs and the 10 CIFAR-10 classes.
model = ViT(
    img_size=224,
    patch_size=16,
    in_channels=3,
    num_classes=10,
    embed_dim=128,
    num_heads=4,
    mlp_dim=256,
    depth=4,
    dropout=0.5,
)
model = model.to(device)

In [11]:
# Objective and optimiser for the classification task.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(params=model.parameters(), lr=3e-4, weight_decay=0.05)

# Start a fresh per-epoch metrics log; opening in "w" mode truncates any old file.
csv_file = "training_log.csv"
with open(csv_file, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(("epoch", "train_loss", "train_acc", "test_acc"))

In [None]:
num_epochs = 200
best_test_acc = 0.0
model_save_path = "best_vit_cifar10.pth"

# One iteration = one full pass over the training set, followed by an
# evaluation pass, a CSV log row and (when accuracy improves) a checkpoint.
for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    running_loss = 0.0   # sum of per-sample losses over the epoch
    correct_train = 0
    total_train = 0

    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)
        batch_count = labels.size(0)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean; weight it by the batch size so a
        # shorter final batch does not skew the epoch average.
        running_loss += loss.item() * batch_count
        predicted = outputs.argmax(dim=1)
        correct_train += (predicted == labels).sum().item()
        total_train += batch_count

    train_loss = running_loss / total_train
    train_accuracy = 100 * correct_train / total_train

    # ---- evaluation pass (no gradients, dropout disabled) ----
    model.eval()
    correct_test = 0
    total_test = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)
            total_test += labels.size(0)
            correct_test += (predicted == labels).sum().item()

    test_accuracy = 100 * correct_test / total_test

    # Append this epoch's metrics to the CSV log.
    with open(csv_file, mode="a", newline="") as f:
        writer = csv.writer(f)
        writer.writerow([epoch+1, train_loss, train_accuracy, test_accuracy])

    # NOTE(review): the checkpoint is selected on the test set, so the reported
    # best test accuracy is optimistically biased; a held-out validation split
    # would avoid this.
    if test_accuracy > best_test_acc:
        best_test_acc = test_accuracy
        torch.save(model.state_dict(), model_save_path)

    print(f"Epoch [{epoch+1}/{num_epochs}] "
          f"Loss: {train_loss:.4f} "
          f"Train Acc: {train_accuracy:.2f}% "
          f"Test Acc: {test_accuracy:.2f}% (Best: {best_test_acc:.2f}%)")

Epoch [1/10] Loss: 1.8418 Train Acc: 31.64% Test Acc: 39.53% (Best: 39.53%)
