In [None]:
import torch
import torchvision
import torch
import torchvision
import torch.utils.data as dataloader
from torch.nn import attention
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms


In [None]:
# Preprocessing pipeline handed to both MNIST datasets below. The single
# ToTensor step converts each loaded image into a torch tensor.
transformation_opertaion = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)

In [None]:
# MNIST training split (train=True); files are downloaded into ./data if absent.
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transformation_opertaion,
)
# MNIST held-out split (train=False), used as the validation set in this notebook.
val_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transformation_opertaion,
)

100%|██████████| 9.91M/9.91M [00:00<00:00, 16.4MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 572kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.53MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.34MB/s]


In [None]:
# Mini-batch iterators over the two splits. Only the training data is
# shuffled; the validation data is always visited in the same order.
# NOTE: the batch size is written as a literal here because the `batch_size`
# constant is only defined in the configuration cell further down.
train_loader = dataloader.DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
)
val_loader = dataloader.DataLoader(
    dataset=val_dataset,
    batch_size=64,
    shuffle=False,
)

In [None]:
# --- Input geometry ---------------------------------------------------------
img_size = 28        # side length used to derive the patch grid
num_channels = 1     # input channels fed to the patch-embedding convolution
patch_size = 7       # side length of one patch (conv kernel size and stride)
# Patches per image: a (28 // 7) x (28 // 7) grid, i.e. 16.
patches_grid = img_size // patch_size
num_patches = patches_grid * patches_grid
del patches_grid     # helper only; keep the notebook namespace unchanged

# --- Model size -------------------------------------------------------------
embedding_dim = 64       # width of every token vector
attention_heads = 4      # heads per multi-head attention layer
transformer_blocks = 4   # number of stacked encoder blocks
mlp_hidden_nodes = 128   # hidden width of the MLP inside each encoder block
num_classes = 10         # size of the classifier output

# --- Optimisation -----------------------------------------------------------
batch_size = 64
learning_rate = 1e-3
epochs = 10

In [None]:
# Patch Embedding

class PatchEmbedding(nn.Module):
  """Cut an image batch into non-overlapping patches and embed each patch.

  A single convolution whose kernel size and stride both equal ``patch_size``
  yields one ``embedding_dim``-channel output position per patch.
  """

  def __init__(self, num_channels, patch_size, embedding_dim):
    super(PatchEmbedding, self).__init__()
    self.patch_embed = nn.Conv2d(
        in_channels=num_channels,
        out_channels=embedding_dim,
        kernel_size=patch_size,
        stride=patch_size,
    )

  def forward(self, x):
    """Return the patch embeddings with the patch axis second and features last."""
    feature_map = self.patch_embed(x)
    # Merge every axis from index 2 onward (the patch grid) into one patch
    # axis, then swap it with the channel axis so features come last.
    patch_tokens = feature_map.flatten(start_dim=2).permute(0, 2, 1)
    return patch_tokens

In [None]:
 # Transformer Encoder
class TransformerEncoder(nn.Module):
  """One Transformer encoder block with layer norm applied before each sub-layer.

  The block runs self-attention and then a two-layer GELU MLP; each sub-layer
  reads a layer-normalised copy of its input and its result is added back to
  the un-normalised input (residual connection).

  Args:
    embedding_dim: width of each token vector (input and output).
    attention_heads: number of heads in the multi-head attention layer.
    mlp_hidden_nodes: hidden width of the MLP.
  """

  def __init__(self, embedding_dim, attention_heads, mlp_hidden_nodes):
    super(TransformerEncoder, self).__init__()
    self.layer_norm1 = nn.LayerNorm(embedding_dim)
    self.layer_norm2 = nn.LayerNorm(embedding_dim)
    # batch_first=True: inputs are laid out as (batch, sequence, embedding).
    self.multihead_attention = nn.MultiheadAttention(embedding_dim, attention_heads, batch_first=True)
    self.mlp = nn.Sequential(
        nn.Linear(embedding_dim, mlp_hidden_nodes),
        nn.GELU(),
        nn.Linear(mlp_hidden_nodes, embedding_dim)
    )

  def forward(self, x):
    """Apply the attention and MLP sub-layers; the output has the same shape as ``x``."""
    normed = self.layer_norm1(x)
    # Only the attended values are used, so skip building the averaged
    # attention-weight matrix; this also lets PyTorch pick its optimized
    # attention implementation.
    attended, _ = self.multihead_attention(normed, normed, normed, need_weights=False)
    x = x + attended
    x = x + self.mlp(self.layer_norm2(x))
    return x

In [None]:
# MLP Head
class MLPHead(nn.Module):
  """Classification head: layer-normalise the features, then project them to class scores."""

  def __init__(self, embedding_dim, num_classes):
    super().__init__()
    self.layer_norm1 = nn.LayerNorm(embedding_dim)
    self.mlp_head = nn.Linear(in_features=embedding_dim, out_features=num_classes)

  def forward(self, x):
    """Return one score per class for every feature vector in ``x``."""
    # No token selection happens here: VisionTransformer.forward already
    # slices out the classification token before calling this head.
    normalised = self.layer_norm1(x)
    logits = self.mlp_head(normalised)
    return logits

In [None]:
# Vision Transformer

class VisionTransformer(nn.Module):
  """Vision Transformer classifier built from the notebook's configuration globals.

  Pipeline: patch embedding -> prepend a learnable classification token ->
  add learnable position embeddings -> stacked encoder blocks -> MLP head
  applied to the classification token only.

  All sizes (``num_channels``, ``patch_size``, ``embedding_dim``,
  ``num_patches``, ``attention_heads``, ``mlp_hidden_nodes``,
  ``transformer_blocks``, ``num_classes``) are read from module-level
  variables at construction time.
  """

  def __init__(self):
    super().__init__()
    self.patch_embedding = PatchEmbedding(num_channels, patch_size, embedding_dim)
    # Learnable classification token; a single copy shared by every sample.
    self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
    # One learnable position vector per token: all patches plus the class token.
    self.position_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embedding_dim))
    encoder_layers = []
    for _ in range(transformer_blocks):
      encoder_layers.append(TransformerEncoder(embedding_dim, attention_heads, mlp_hidden_nodes))
    self.transformer_blocks = nn.Sequential(*encoder_layers)
    self.mlp_head = MLPHead(embedding_dim, num_classes)

  def forward(self, x):
    """Return the class scores produced by the head for each image in ``x``."""
    batch = x.shape[0]
    patches = self.patch_embedding(x)
    # Broadcast the single class token across the batch (no copy is made).
    cls_tokens = self.cls_token.expand(batch, -1, -1)
    tokens = torch.cat([cls_tokens, patches], dim=1)
    tokens = tokens + self.position_embedding
    encoded = self.transformer_blocks(tokens)
    # Position 0 along the token axis is the classification token.
    cls_features = encoded[:, 0]
    return self.mlp_head(cls_features)


In [None]:
# Seed PyTorch's global RNG before any weights are drawn so that a fresh
# Restart & Run All starts from the same initial parameters. The shuffled
# batch order should follow the same seed as well, since the loaders are not
# given their own generator. (GPU kernels may still add small run-to-run
# differences.)
seed = 42
torch.manual_seed(seed)

# Prefer the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VisionTransformer().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Cross-entropy over the raw class scores returned by the model.
criterion = nn.CrossEntropyLoss()

In [None]:
# training

for epoch in range(epochs):
  model.train()
  train_loss = 0.0     # running sum of per-sample losses for this epoch
  train_correct = 0    # running count of correctly classified samples
  print(f"Epoch {epoch + 1}/{epochs}")
  for images, labels in train_loader:
    images, labels = images.to(device), labels.to(device)
    optimizer.zero_grad()
    outputs = model(images)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    # `criterion` returns the batch mean, so scale by the batch size to get a
    # per-sample sum; the final (shorter) batch is then weighted correctly.
    train_loss += loss.item() * labels.size(0)
    train_correct += (outputs.argmax(1) == labels).sum().item()

  # Normalise by the number of samples, not the number of batches, so that
  # accuracy is a fraction in [0, 1] and loss is the true per-sample mean.
  num_train_samples = len(train_loader.dataset)
  print(f"Train Loss: {train_loss / num_train_samples:.4f}")
  print(f"Train Accuracy: {train_correct / num_train_samples:.4f}")

Epoch 1/10
Epoch %1/10
Train Loss: 0.0751417811934167
Train Accuracy: 62.424307036247335
Epoch 2/10
Epoch %2/10
Train Loss: 0.06157587955260018
Train Accuracy: 62.68017057569296
Epoch 3/10
Epoch %3/10
Train Loss: 0.05446083324195356
Train Accuracy: 62.850746268656714
Epoch 4/10
Epoch %4/10
Train Loss: 0.04575147776259308
Train Accuracy: 63.00319829424307
Epoch 5/10
Epoch %5/10
Train Loss: 0.043470071275472276
Train Accuracy: 63.09808102345416
Epoch 6/10
Epoch %6/10
Train Loss: 0.0397842470206507
Train Accuracy: 63.137526652452024
Epoch 7/10
Epoch %7/10
Train Loss: 0.03604383460061847
Train Accuracy: 63.16098081023454
Epoch 8/10
Epoch %8/10
Train Loss: 0.030236760621951488
Train Accuracy: 63.34434968017057
Epoch 9/10
Epoch %9/10
Train Loss: 0.03164076991826058
Train Accuracy: 63.31556503198294
Epoch 10/10
Epoch %10/10
Train Loss: 0.026201935721931656
Train Accuracy: 63.4136460554371


In [None]:
# NOTE: `model` and `optimizer` are not re-created in this cell, so when it is
# run after the training cell above it continues training the same weights
# rather than starting from scratch.
for epoch in range(epochs):
  # ---- Training pass ----
  model.train()
  train_loss = 0.0     # running sum of per-sample training losses
  train_correct = 0    # running count of correct training predictions
  print(f"Epoch {epoch + 1}/{epochs}")
  for images, labels in train_loader:
    images, labels = images.to(device), labels.to(device)
    optimizer.zero_grad()
    outputs = model(images)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    # `criterion` returns the batch mean; scale by the batch size so the
    # shorter final batch does not get the same weight as a full one.
    train_loss += loss.item() * labels.size(0)
    train_correct += (outputs.argmax(1) == labels).sum().item()

  # ---- Validation pass (no gradient tracking, no parameter updates) ----
  model.eval()
  val_loss = 0.0       # running sum of per-sample validation losses
  val_correct = 0      # running count of correct validation predictions
  with torch.no_grad():
    for images, labels in val_loader:
      images, labels = images.to(device), labels.to(device)
      outputs = model(images)
      loss = criterion(outputs, labels)
      val_loss += loss.item() * labels.size(0)
      val_correct += (outputs.argmax(1) == labels).sum().item()

  # ---- Report per-sample averages for both splits ----
  num_train_samples = len(train_loader.dataset)
  num_val_samples = len(val_loader.dataset)
  print(f"Train Loss: {train_loss / num_train_samples:.4f}")
  print(f"Train Accuracy: {train_correct / num_train_samples:.4f}")
  print(f"Validation Loss: {val_loss / num_val_samples:.4f}")
  print(f"Validation Accuracy: {val_correct / num_val_samples:.4f}")

  # Leave the model in training mode, as the original cell did.
  model.train()

Epoch 1/10
Train Loss: 0.0183
Train Accuracy: 0.9939
Validation Loss: 0.0926
Validation Accuracy: 0.9773
Epoch 2/10
Train Loss: 0.0193
Train Accuracy: 0.9933
Validation Loss: 0.0688
Validation Accuracy: 0.9810
Epoch 3/10
Train Loss: 0.0162
Train Accuracy: 0.9946
Validation Loss: 0.0657
Validation Accuracy: 0.9828
Epoch 4/10
Train Loss: 0.0189
Train Accuracy: 0.9938
Validation Loss: 0.0676
Validation Accuracy: 0.9813
Epoch 5/10
Train Loss: 0.0139
Train Accuracy: 0.9952
Validation Loss: 0.0749
Validation Accuracy: 0.9815
Epoch 6/10
Train Loss: 0.0160
Train Accuracy: 0.9945
Validation Loss: 0.0741
Validation Accuracy: 0.9803
Epoch 7/10
Train Loss: 0.0174
Train Accuracy: 0.9940
Validation Loss: 0.0797
Validation Accuracy: 0.9784
Epoch 8/10
Train Loss: 0.0127
Train Accuracy: 0.9957
Validation Loss: 0.0644
Validation Accuracy: 0.9833
Epoch 9/10
Train Loss: 0.0138
Train Accuracy: 0.9953
Validation Loss: 0.0828
Validation Accuracy: 0.9791
Epoch 10/10
Train Loss: 0.0121
Train Accuracy: 0.9959
V