<a href="https://colab.research.google.com/github/fgokmenoglu/PYTORCH/blob/main/Vision_Transformer_Ex1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [26]:
import numpy as np

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader

from torchvision.datasets.mnist import MNIST
from torchvision.transforms import ToTensor

np.random.seed(0)
torch.manual_seed(0)

class MyViT(nn.Module):
  def __init__(self, input_shape, n_patches=7, hidden_d=8, n_heads=2, out_d=10):
    # Super constructor
    super(MyViT, self).__init__()

    # Input and patches sizes
    self.input_shape = input_shape # w.r.t images --> (N, C, H, W)
    self.n_patches = n_patches # breaking image in n_patches x n_patches
    self.n_heads= n_heads
    assert input_shape[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
    assert input_shape[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"
    self.patch_size = (input_shape[1] / n_patches, input_shape[2] / n_patches)
    self.hidden_d = hidden_d

    # Step 1: Linear mapper
    self.input_d = int(input_shape[0] * self.patch_size[0] * self.patch_size[1])
    self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

    # Step 2: Classification token
    self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

    # Step 3: Positional embedding
    # Check inside forward function
    
    # Step 4a: Layer normalization 1
    self.ln1 = nn.LayerNorm((self.n_patches ** 2 + 1, self.hidden_d))

    # Step 4b: Multi-head Seşf Attention (MSA) and classification token
    self.msa = MyMSA(self.hidden_d, n_heads)

    # Step 5a: Layer normalization 2
    self.ln2 = nn.LayerNorm((self.n_patches ** 2 + 1, self.hidden_d))

    # Step 5b: Encoder MLP
    self.enc_mlp = nn.Sequential(nn.Linear(self.hidden_d, self.hidden_d), nn.ReLU())

    # Step 6: Classification MLP
    self.mlp = nn.Sequential(nn.Linear(self.hidden_d, out_d), nn.Softmax(dim=-1))

  def forward(self, images):
    # Divide images into patches
    n, c, w, h = images.shape
    patches = images.reshape(n, self.n_patches ** 2, self.input_d)

    # Run linear layer for tokenization
    tokens = self.linear_mapper(patches)

    # Add classification token to the tokens
    tokens = torch.stack([torch.vstack((self.class_token, tokens[ii])) for ii in range(len(tokens))])

    # Add positional embedding
    tokens += get_positional_embeddings(self.n_patches ** 2 + 1, self.hidden_d).repeat(n, 1, 1)

    # TRANFORMER ENCODING BEGINS #
    # Run layer normalization, MSA and residual connection
    out = tokens + self.msa(self.ln1(tokens))

    # Run layer normalization, MSA and residual connection
    out = out + self.enc_mlp(self.ln2(out))
    # TRANSFORMER ENCODING ENDS

    # Get the classification
    out = out[:, 0]

    return self.mlp(out)

class MyMSA(nn.Module):
  def __init__(self, d, n_heads=2):
    super(MyMSA, self).__init__()
    self.d = d
    self.n_heads = n_heads

    assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"

    d_head = int(d / n_heads)
    self.q_mappings = [nn.Linear(d_head, d_head) for _ in range(self.n_heads)]
    self.k_mappings = [nn.Linear(d_head, d_head) for _ in range(self.n_heads)]
    self.v_mappings = [nn.Linear(d_head, d_head) for _ in range(self.n_heads)]
    self.d_head = d_head
    self.softmax = nn.Softmax(dim=-1)

  def forward(self, sequences):
    # sequence had shape --> (N, seq_length, token_dim)
    # reshaped           --> (N, seq_length, n_heads, token_dim / n_heads)
    # came back to       --> (N, seq_length, item_dim) through concatenation

    result = []

    for sequence in sequences:
      seq_result = []

      for head in range(self.n_heads):
        q_mapping = self.q_mappings[head]
        k_mapping = self.k_mappings[head]
        v_mapping = self.v_mappings[head]

        seq = sequence[:, head * self.d_head: (head + 1) * self.d_head]
        q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)

        attention = self.softmax(q @ k.T / (self.d_head ** 0.5))
        seq_result.append(attention @ v)
      
      result.append(torch.hstack(seq_result))

    return torch.cat([torch.unsqueeze(r, dim=0) for r in result])

def get_positional_embeddings(sequence_length, d):
  result = torch.ones(sequence_length, d)

  for ii in range(sequence_length):
    
    for jj in range(d):
      result[ii][jj] = np.sin(ii / (10000 ** (jj / d))) if jj % 2 == 0 else np.cos(ii / (10000 ** ((jj - 1) / d)))

  return result

def main():
  # Loading data
  transform = ToTensor()

  train_set = MNIST(root='./../datasets', train=True, download=True, transform=transform)
  test_set = MNIST(root='./../datasets', train=False, download=True, transform=transform)

  train_loader = DataLoader(train_set, shuffle=True, batch_size=16)
  test_loader = DataLoader(test_set, shuffle=False, batch_size=16)

  # Defining model and training options
  model = MyViT((1, 28, 28), n_patches=7, hidden_d=8, n_heads=2, out_d=10)
  N_EPOCHS = 5
  LR = 0.01

  # Training loop
  optimizer = Adam(model.parameters(), lr=LR)
  criterion = CrossEntropyLoss()
  
  for epoch in range(N_EPOCHS):
    train_loss = 0.0

    for batch in train_loader:
      x, y = batch
      y_hat = model(x)
      loss = criterion(y_hat, y) / len(x)
      train_loss += loss.item()
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

    print(f"Epoch {epoch + 1} / {N_EPOCHS} loss: {train_loss:.2f}")

    # Testing loop
  correct, total = 0, 0
  test_loss = 0.0

  for batch in test_loader:
    x, y = batch
    y_hat = model(x)
    loss = criterion(y_hat, y)
    test_loss += loss / len(x)
    correct += torch.sum(torch.argmax(y_hat, dim=1) == y).item()
    total += len(x)

  print(f"Test loss: {test_loss:.2f}")
  print(f"Test accuracy: {correct / total * 100:.2f}%")

if __name__ == '__main__':
    main()


Epoch 1 / 5 loss: 410.23
Epoch 2 / 5 loss: 390.20
Epoch 3 / 5 loss: 384.24
Epoch 4 / 5 loss: 374.24
Epoch 5 / 5 loss: 372.16
Test loss: 61.88
Test accuracy: 87.67%
