# Importing Libraries

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.utils.data as dataloader

# Variables

In [2]:
# Seed PyTorch's global RNG so weight initialisation and DataLoader shuffling
# are repeatable across "Restart Kernel -> Run All".
seed = 42
torch.manual_seed(seed)

# --- data / patching ---
batch_size = 64
img_size = 28
patch_size = 7
num_channels = 1
# The patch embedding uses a convolution whose stride equals the patch size,
# so a side length that does not divide evenly would silently drop border pixels.
assert img_size % patch_size == 0, "img_size must be divisible by patch_size"
num_patches = (img_size // patch_size) ** 2

# --- model ---
attention_heads = 4
embed_dim = 16
# nn.MultiheadAttention splits embed_dim evenly across its heads; fail here,
# next to the values, rather than later at model construction.
assert embed_dim % attention_heads == 0, "embed_dim must be divisible by attention_heads"
transformer_blocks = 4
mlp_nodes = 64
num_classes = 10

# --- optimisation ---
learning_rate = 0.001
epochs = 10

# Image Transform

In [3]:
# Single-step preprocessing pipeline: convert each image to a tensor.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Loading MNIST dataset

In [4]:
# train=True selects the MNIST training split; train=False selects the
# held-out split, which this notebook uses for validation.
# Both are downloaded into ./data on first use and share the same transform.
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
val_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

100%|██████████| 9.91M/9.91M [00:00<00:00, 20.6MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 495kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.66MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 10.0MB/s]


# Create Train and Validation batches

In [5]:
# Shuffle the training set each epoch so mini-batches differ between epochs.
train_data = dataloader.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Validation accuracy does not depend on sample order, so keep the order fixed:
# it makes evaluation deterministic and avoids needless RNG use.
val_data = dataloader.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Class for PatchEmbedding

In [6]:
class PatchEmbedding(nn.Module):
  """Cut an image into non-overlapping square patches and embed each patch.

  A convolution whose kernel size equals its stride applies one learned linear
  projection to every patch, producing one embedding vector per patch.

  Args:
    in_channels: Number of image channels. Defaults to the notebook-level
      ``num_channels`` when omitted.
    dim: Size of each patch embedding. Defaults to the notebook-level
      ``embed_dim`` when omitted.
    patch: Side length of a square patch in pixels. Defaults to the
      notebook-level ``patch_size`` when omitted.
  """

  def __init__(self, in_channels=None, dim=None, patch=None):
    super().__init__()
    # Fall back to the notebook-level hyperparameters so `PatchEmbedding()`
    # keeps working exactly as before.
    if in_channels is None:
      in_channels = num_channels
    if dim is None:
      dim = embed_dim
    if patch is None:
      patch = patch_size
    self.patch_embed = nn.Conv2d(in_channels, dim, kernel_size=patch, stride=patch)

  def forward(self, x):
    """Map images (B, C, H, W) to patch tokens (B, N, dim).

    N is the number of patches, (H // patch) * (W // patch), ordered row by row.
    """
    x = self.patch_embed(x)  # (B, dim, H // patch, W // patch)
    x = x.flatten(2)         # merge the patch grid into one axis: (B, dim, N)
    x = x.transpose(1, 2)    # tokens first, features last: (B, N, dim)
    return x

## Checking dimensions

In [7]:
# Walk one training batch through the patch-embedding steps by hand and print
# the tensor shape after each step.
images, labels = next(iter(train_data))
print("Dimension of input train batch:", images.shape)

# NOTE(review): 20 output channels are used here instead of embed_dim,
# presumably so the channel axis can be told apart from the patch axis in the
# printed shapes; confirm this is intentional.
patch_embed = nn.Conv2d(
    in_channels=num_channels,
    out_channels=20,
    kernel_size=patch_size,
    stride=patch_size,
)
embedded_image = patch_embed(images)
print("Dimension of input train batch after conv2d:", embedded_image.shape)

# Collapse the two spatial axes into a single patch axis.
flattened_image = torch.flatten(embedded_image, start_dim=2)
print("Dimension of input train batch after conv2d and flattening:", flattened_image.shape)

# Swap the channel and patch axes so that each patch becomes one token.
transposed_flattened_image = torch.transpose(flattened_image, 1, 2)
print("Dimension of input train batch after conv2d and flattening and transposed:", transposed_flattened_image.shape)

Dimension of input train batch: torch.Size([64, 1, 28, 28])
Dimension of input train batch after conv2d: torch.Size([64, 20, 4, 4])
Dimension of input train batch after conv2d and flattening: torch.Size([64, 20, 16])
Dimension of input train batch after conv2d and flattening and transposed: torch.Size([64, 16, 20])


# Class for Transformer Encoder

In [17]:
class TransformerEncoder(nn.Module):
  """One pre-norm Transformer encoder block.

  Applies LayerNorm -> multi-head self-attention -> residual add, followed by
  LayerNorm -> two-layer GELU MLP -> residual add. Input and output have the
  same shape, (B, N, dim).

  Args:
    dim: Token embedding size. Defaults to the notebook-level ``embed_dim``
      when omitted.
    num_heads: Number of attention heads. Defaults to the notebook-level
      ``attention_heads`` when omitted.
    hidden_dim: Width of the MLP hidden layer. Defaults to the notebook-level
      ``mlp_nodes`` when omitted.
  """

  def __init__(self, dim=None, num_heads=None, hidden_dim=None):
    super().__init__()
    # Fall back to the notebook-level hyperparameters so `TransformerEncoder()`
    # keeps working exactly as before.
    if dim is None:
      dim = embed_dim
    if num_heads is None:
      num_heads = attention_heads
    if hidden_dim is None:
      hidden_dim = mlp_nodes
    self.layer_norm1 = nn.LayerNorm(dim)
    self.multi_head_attention = nn.MultiheadAttention(dim, num_heads, batch_first = True)
    self.layer_norm2 = nn.LayerNorm(dim)
    self.mlp = nn.Sequential(
        nn.Linear(dim, hidden_dim),
        nn.GELU(),
        nn.Linear(hidden_dim, dim)
    )

  def forward(self, x):
    """Run the block on a batch of token sequences of shape (B, N, dim)."""
    # Self-attention sub-layer. The attention weights are never used, so ask
    # the module not to compute them (the second return value is then None).
    normed = self.layer_norm1(x)
    attended, _ = self.multi_head_attention(normed, normed, normed, need_weights=False)
    x = x + attended
    # Position-wise MLP sub-layer, again with a residual connection.
    x = x + self.mlp(self.layer_norm2(x))
    return x

# Class for MLP Head

In [18]:
class MLP_Head(nn.Module):
  """Classification head: LayerNorm followed by one linear layer to class logits.

  The input is the representation to classify, with the feature size as its
  last dimension; selecting which token to classify is left to the caller.

  Args:
    dim: Input feature size. Defaults to the notebook-level ``embed_dim``
      when omitted.
    n_classes: Number of output logits. Defaults to the notebook-level
      ``num_classes`` when omitted.
  """

  def __init__(self, dim=None, n_classes=None):
    super().__init__()
    # Fall back to the notebook-level hyperparameters so `MLP_Head()` keeps
    # working exactly as before.
    if dim is None:
      dim = embed_dim
    if n_classes is None:
      n_classes = num_classes
    self.layer_norm1 = nn.LayerNorm(dim)
    # Kept as an nn.Sequential so parameter names (mlp_head.0.*) stay stable.
    self.mlp_head = nn.Sequential(
        nn.Linear(dim, n_classes)
    )

  def forward(self, x):
    """Return unnormalised class scores of shape (..., n_classes)."""
    x = self.layer_norm1(x)
    x = self.mlp_head(x)
    return x

# Class VisionTransformer

In [19]:
class VisionTransformer(nn.Module):
  """Vision Transformer classifier assembled from the notebook's building blocks.

  Pipeline: patch embedding -> prepend a learnable class token -> add learnable
  position embeddings -> stack of Transformer encoder blocks -> classification
  head applied to the class token.
  """

  def __init__(self):
    super().__init__()
    self.patch_embedding = PatchEmbedding()
    # One learnable class token, shared by (and expanded over) the whole batch.
    self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
    # One learnable position vector per token: all patches plus the class token.
    self.position_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))
    self.transformer_blocks = nn.Sequential(
        *(TransformerEncoder() for _ in range(transformer_blocks))
    )
    self.mlp_head = MLP_Head()

  def forward(self, x):
    """Return class logits for a batch of images."""
    tokens = self.patch_embedding(x)
    batch = tokens.shape[0]
    # expand() broadcasts the single class token across the batch without copying.
    cls = self.cls_token.expand(batch, -1, -1)
    tokens = torch.cat((cls, tokens), 1)
    tokens = tokens + self.position_embedding
    tokens = self.transformer_blocks(tokens)
    # Only the class token (index 0) is passed to the classifier.
    return self.mlp_head(tokens[:, 0])

# Selecting device

In [20]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Defining Model

In [21]:
# Build the network, then move its parameters to the selected device.
model = VisionTransformer()
model = model.to(device)

# Optimizer

In [22]:
# Adam over every model parameter, using the learning rate from the config cell.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=learning_rate,
)

# Loss Function

In [23]:
# Cross-entropy on raw logits, averaged over the batch (the default reduction).
criterion = nn.CrossEntropyLoss(reduction="mean")

# Training

In [24]:
# Train for `epochs` passes over train_data. Every 100th batch logs its own loss
# and accuracy; each epoch ends with a summary line. Note that "Total Loss" is
# the sum of the per-batch losses over the epoch, not their average.
for epoch in range(epochs):
    model.train()
    total_loss, correct_epoch, total_epoch = 0, 0, 0
    print(f"\nEpoch {epoch+1}")

    for batch_idx, (images, labels) in enumerate(train_data):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass, backward pass, parameter update.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Per-batch statistics, computed from the pre-update forward pass.
        total_loss += loss.item()
        preds = torch.argmax(outputs, dim=1)
        correct = preds.eq(labels).sum().item()
        accuracy = 100.0 * correct / labels.size(0)

        # Accumulate epoch-level counts.
        correct_epoch += correct
        total_epoch += labels.size(0)

        if batch_idx % 100 == 0:
            print(f"  Batch {batch_idx+1:3d}: Loss = {loss.item():.4f}, Accuracy = {accuracy:.2f}%")

    epoch_acc = 100.0 * correct_epoch / total_epoch
    print(f"==> Epoch {epoch+1} Summary: Total Loss = {total_loss:.4f}, Accuracy = {epoch_acc:.2f}%")


Epoch 1
  Batch   1: Loss = 2.4827, Accuracy = 6.25%
  Batch 101: Loss = 1.6192, Accuracy = 50.00%
  Batch 201: Loss = 0.7386, Accuracy = 82.81%
  Batch 301: Loss = 0.5873, Accuracy = 81.25%
  Batch 401: Loss = 0.3585, Accuracy = 90.62%
  Batch 501: Loss = 0.4965, Accuracy = 85.94%
  Batch 601: Loss = 0.1839, Accuracy = 95.31%
  Batch 701: Loss = 0.1165, Accuracy = 96.88%
  Batch 801: Loss = 0.4038, Accuracy = 87.50%
  Batch 901: Loss = 0.2681, Accuracy = 93.75%
==> Epoch 1 Summary: Total Loss = 581.2294, Accuracy = 81.23%

Epoch 2
  Batch   1: Loss = 0.2691, Accuracy = 92.19%
  Batch 101: Loss = 0.2304, Accuracy = 96.88%
  Batch 201: Loss = 0.0623, Accuracy = 98.44%
  Batch 301: Loss = 0.2131, Accuracy = 96.88%
  Batch 401: Loss = 0.2257, Accuracy = 92.19%
  Batch 501: Loss = 0.1221, Accuracy = 95.31%
  Batch 601: Loss = 0.3511, Accuracy = 90.62%
  Batch 701: Loss = 0.1896, Accuracy = 92.19%
  Batch 801: Loss = 0.1742, Accuracy = 95.31%
  Batch 901: Loss = 0.0663, Accuracy = 96.88%
=

# Validation Accuracy

In [25]:
# Measure top-1 accuracy over val_data with the model in evaluation mode.
model.eval()
correct, total = 0, 0

# No gradients are needed for evaluation.
with torch.no_grad():
    for images, labels in val_data:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        preds = torch.argmax(outputs, dim=1)
        correct += preds.eq(labels).sum().item()
        total += labels.size(0)

test_acc = 100.0 * correct / total
print(f"\n==> Val Accuracy: {test_acc:.2f}%")



==> Val Accuracy: 97.17%
