<a href="https://colab.research.google.com/github/bhagatpandey369/vision-transformer-from-scratch/blob/main/Vision_Transformer_from_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Load the demo image as 224x224 RGB. convert('RGB') guarantees exactly C = 3 channels
# (RGBA / palette / grayscale PNGs would otherwise break the reshape below), and the
# `with` block closes the underlying file once the pixels have been read.
with Image.open('alihassan.png') as img:
    image = img.convert('RGB').resize((224, 224))
x = np.array(image)  # (H, W, C) uint8
P = 16  # patch side length in pixels
C = 3   # colour channels
# (H, W, C) -> (H/P, P, W/P, P, C) -> swap axes 1,2 -> (H/P, W/P, P, P, C) -> (N, P, P, C)
patch = x.reshape(x.shape[0] // P, P, x.shape[1] // P, P, C).swapaxes(1, 2).reshape(-1, P, P, C)
# Flatten every patch into one P*P*C vector: (N, P*P*C).
x_p = np.reshape(patch, (-1, P * P * C))
N = x_p.shape[0]  # number of patches
N

In [None]:
D = 768  # embedding size
B = 1    # batch size
# Flattened patches as a float32 tensor with a leading batch axis: (1, N, P*P*C).
x_p = torch.Tensor(x_p).unsqueeze(0)
# Learnable patch-projection matrix E mapping P*P*C -> D.
E = nn.Parameter(torch.randn(1, P*P*C, D))


In [None]:
(x_p.shape, E.shape)  # show input and projection shapes side by side

In [None]:
# Project each flattened patch: (1, N, P*P*C) @ (1, P*P*C, D) -> (1, N, D).
patch_embedding = x_p @ E
patch_embedding.shape

In [None]:
# Learnable [class] token, later prepended to the patch sequence.
class_token = nn.Parameter(torch.randn(size=(1, 1, D)))
class_token.shape

In [None]:
# Prepend the class token along the sequence axis: (1, N, D) -> (1, N + 1, D).
patch_embeddings = torch.cat([class_token, patch_embedding], dim=1)
patch_embeddings.shape

In [None]:
# One learnable position vector per token (N patches + 1 class token).
E_pos = nn.Parameter(torch.randn(size=(1, N + 1, D)))
E_pos.shape

In [None]:
# Encoder input z0: token embeddings plus position embeddings.
z0 = torch.add(patch_embeddings, E_pos)
z0.shape

In [None]:
class SelfAttention(nn.Module):
  """Single-head scaled dot-product self-attention.

  One weight matrix ``W`` of shape ``(embedding_dim, 3 * key_dim)`` projects the
  input into concatenated query / key / value vectors.

  Args:
    embedding_dim: size of the last dimension of the input.
    key_dim: size of each query / key / value vector.
  """

  def __init__(self, embedding_dim, key_dim=64):
    super().__init__()
    self.embedding_dim = embedding_dim
    self.key_dim = key_dim
    self.W = nn.Parameter(torch.randn(embedding_dim, 3 * key_dim))

  def forward(self, x):
    """Attend over the token axis of ``x``.

    Args:
      x: tensor of shape ``(..., num_tokens, embedding_dim)``.

    Returns:
      Tensor of shape ``(..., num_tokens, key_dim)``: for each query token, the
      attention-weighted average of all value vectors.
    """
    key_dim = self.key_dim
    qkv = torch.matmul(x, self.W)
    # Split the fused projection along the feature axis; works for any number of
    # leading (batch) dimensions, not only 3-D input.
    q, k, v = torch.split(qkv, key_dim, dim=-1)
    scaled_dot_products = torch.matmul(q, k.transpose(-2, -1)) / key_dim ** 0.5
    # Normalise over the KEY axis (last dim) so each query's weights sum to 1.
    # Normalising over dim=1 would instead mix weights across different queries.
    attention_weights = F.softmax(scaled_dot_products, dim=-1)
    return torch.matmul(attention_weights, v)

In [None]:
D_h = 64  # per-head query/key/value size
self_attention = SelfAttention(embedding_dim=D, key_dim=D_h)
attention_scores = self_attention(patch_embeddings)
attention_scores.shape

In [None]:
class MultiHeadSelfAttention(nn.Module):
  """Multi-head self-attention built from independent ``SelfAttention`` heads.

  Each of the ``num_heads`` heads produces ``embedding_dim // num_heads`` features.
  Head outputs are concatenated on the last axis and mapped back to
  ``embedding_dim`` by the output matrix ``W``.
  """

  def __init__(self, embedding_dim=768, num_heads=12):
    super().__init__()
    self.embedding_dim = embedding_dim
    self.num_heads = num_heads
    assert embedding_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"
    head_dim = embedding_dim // num_heads
    self.key_dim = head_dim
    heads = [SelfAttention(embedding_dim, head_dim) for _ in range(num_heads)]
    self.attention_list = heads
    # ModuleList registers the heads so their parameters belong to this module.
    self.multihead_attention = nn.ModuleList(heads)
    self.W = nn.Parameter(torch.randn(num_heads * head_dim, embedding_dim))

  def forward(self, x):
    """Run every head on ``x``, concatenate their outputs, and project with ``W``."""
    concatenated = torch.cat([head(x) for head in self.multihead_attention], -1)
    return concatenated @ self.W

In [None]:
n_head = 12  # number of attention heads
multi_head_attention = MultiHeadSelfAttention(embedding_dim=D, num_heads=n_head)
attention_scores = multi_head_attention(patch_embeddings)
attention_scores.shape

In [None]:
class MultiLayerPerceptron(nn.Module):
  """Token-wise feed-forward network: Linear -> GELU -> Linear.

  Expands the last dimension from ``embedding_dim`` to ``hidden_dim`` and projects
  it back to ``embedding_dim``.
  """

  def __init__(self, embedding_dim=768, hidden_dim=3072):
    super().__init__()
    expand = nn.Linear(embedding_dim, hidden_dim)
    activation = nn.GELU()
    project = nn.Linear(hidden_dim, embedding_dim)
    self.mlp = nn.Sequential(expand, activation, project)

  def forward(self, x):
    """Apply the MLP to the last dimension of ``x``."""
    out = self.mlp(x)
    return out


In [None]:
hidden_dim = 3072  # MLP hidden width
mlp = MultiLayerPerceptron(embedding_dim=D, hidden_dim=hidden_dim)
output = mlp(patch_embeddings)
output.shape

In [None]:
class LayerNormalization(nn.Module):
  """Normalise over the last dimension with a single learnable gain and bias.

  Computes ``alpha * (x - mean) / (std + eps) + bias``. ``mean`` and ``std`` are
  taken over the last dimension, and ``std`` uses torch's default (unbiased)
  estimator. ``alpha`` and ``bias`` are one scalar each, shared by all features.

  NOTE: the only constructor argument is ``eps``. It is not a feature size, so a
  positional first argument is interpreted as ``eps``.

  Args:
    eps: added to the standard deviation to avoid division by zero.
  """

  def __init__(self, eps:float=10**-6) -> None:
    super().__init__()
    self.eps = eps
    self.alpha = nn.Parameter(torch.ones(1))
    self.bias = nn.Parameter(torch.zeros(1))

  def forward(self, x):
    centered = x - x.mean(dim=-1, keepdim=True)
    scale = x.std(dim=-1, keepdim=True) + self.eps
    return self.alpha * centered / scale + self.bias

In [None]:
class TransformerEncoder(nn.Module):
  """One pre-norm Transformer encoder block.

  Computes
    ``h   = x + Dropout(MSA(LN1(Dropout(x))))``
    ``out = h + Dropout(MLP(LN2(h)))``
  Input and output have the same shape.

  Args:
    embedding_dim: token feature size.
    num_heads: number of attention heads.
    hidden_dim: MLP hidden width.
    dropout: dropout probability used by all three dropout layers.
  """

  def __init__(self, embedding_dim=768, num_heads=12, hidden_dim=3072, dropout=0.1):
    super().__init__()
    self.MSA = MultiHeadSelfAttention(embedding_dim, num_heads)
    self.MLP = MultiLayerPerceptron(embedding_dim, hidden_dim)
    # LayerNormalization's only argument is eps. It normalises over the last dim
    # and needs no feature size, so passing embedding_dim would make eps == 768.
    self.layer_norm1 = LayerNormalization()
    self.layer_norm2 = LayerNormalization()
    self.dropout1 = nn.Dropout(p=dropout)
    self.dropout2 = nn.Dropout(p=dropout)
    self.dropout3 = nn.Dropout(p=dropout)

  def forward(self, x):
    # NOTE(review): dropout on the attention-branch input (dropout1) is unusual for
    # ViT; kept to preserve this block's existing structure — confirm it is intended.
    attn_in = self.layer_norm1(self.dropout1(x))
    res_out = x + self.dropout2(self.MSA(attn_in))
    mlp_out = self.MLP(self.layer_norm2(res_out))
    return res_out + self.dropout3(mlp_out)


In [None]:
dropout_prob = 0.1  # dropout probability inside the encoder block
transformer_encoder = TransformerEncoder(embedding_dim=D, num_heads=n_head, hidden_dim=hidden_dim, dropout=dropout_prob)
output = transformer_encoder(patch_embeddings)
output.shape

In [None]:
class MLPHead(nn.Module):
  """Classification head mapping a feature vector to class logits.

  Args:
    embedding_dim: size of the input feature vector.
    num_classes: number of output logits.
    fine_tune: if True, the head is a single ``Linear`` layer. Otherwise it is
      ``Linear(embedding_dim -> 3072) -> Tanh -> Linear(3072 -> num_classes)``.
  """

  def __init__(self, embedding_dim=768, num_classes=10, fine_tune=False):
    super().__init__()
    self.num_classes = num_classes
    if fine_tune:
      self.mlp_head = nn.Linear(embedding_dim, num_classes)
      return
    self.mlp_head = nn.Sequential(
        nn.Linear(embedding_dim, 3072),
        nn.Tanh(),
        nn.Linear(3072, num_classes),
    )

  def forward(self, x):
    """Return logits whose last dimension has size ``num_classes``."""
    logits = self.mlp_head(x)
    return logits



In [None]:
class VisionTransformer(nn.Module):
  """Vision Transformer classifier built from the encoder blocks and head in this notebook.

  Args:
    patch_size: side length P of the square, non-overlapping patches.
    image_size: side length of the (square) input image.
    channel_size: number of image channels C.
    num_layers: number of stacked ``TransformerEncoder`` blocks.
    embedding_dim: token feature size D.
    num_heads: attention heads per encoder block.
    hidden_dim: MLP hidden width inside each encoder block.
    dropout_prob: dropout probability inside each encoder block.
    num_classed: number of output classes. The misspelt name is kept so existing
      keyword callers keep working.
    fine_tune: forwarded to ``MLPHead``. True gives a single linear head.
  """

  def __init__(self, patch_size=16, image_size=224, channel_size=3, num_layers=12, embedding_dim=768, num_heads=12, hidden_dim=3072, dropout_prob=0.1, num_classed=10, fine_tune=True):
    super().__init__()
    self.patch_size = patch_size
    self.channel_size = channel_size
    self.num_layers = num_layers
    self.embedding_dim = embedding_dim
    self.num_heads = num_heads
    self.hidden_dim = hidden_dim
    self.dropout_prob = dropout_prob
    # Use the constructor argument, not a notebook-global `num_classes`.
    self.num_classes = num_classed
    # Integer patch count matches what unfold() produces in forward(), even when
    # image_size is not an exact multiple of patch_size.
    self.num_patches = (image_size // patch_size) ** 2
    self.W = nn.Parameter(torch.randn(patch_size * patch_size * channel_size, embedding_dim))
    self.pos_embedding = nn.Parameter(torch.randn(self.num_patches + 1, embedding_dim))
    # Sized from embedding_dim, not from a notebook-global `D`.
    self.class_token = nn.Parameter(torch.randn(1, embedding_dim))
    self.transformer_encoder_layers = nn.Sequential(
        *[TransformerEncoder(embedding_dim, num_heads, hidden_dim, dropout_prob) for _ in range(num_layers)]
    )
    self.mlp_head = MLPHead(embedding_dim, num_classed, fine_tune)

  def forward(self, x):
    """Classify a batch of images.

    Args:
      x: image batch. The unfold calls treat dim 1 as channels and dims 2-3 as
        height/width, i.e. ``(batch, channel_size, H, W)``. The patches are cast to float.

    Returns:
      Logits of shape ``(batch, num_classes)``.
    """
    P, C = self.patch_size, self.channel_size
    # (B, C, H, W) -> (B, 1, H/P, W/P, C, P, P): non-overlapping P x P windows.
    patches = x.unfold(1, C, C).unfold(2, P, P).unfold(3, P, P)
    # -> (B, num_patches, C*P*P); each patch vector is ordered (channel, row, col).
    patches = patches.contiguous().view(patches.size(0), -1, C * P * P).float()
    patch_embeddings = torch.matmul(patches, self.W)
    batch_size = patch_embeddings.shape[0]
    # (1, D) -> (B, 1, D) without copying, then prepend to the patch sequence.
    class_tokens = self.class_token.expand(batch_size, 1, -1)
    tokens = torch.cat((class_tokens, patch_embeddings), 1) + self.pos_embedding
    encoded = self.transformer_encoder_layers(tokens)
    # Classify from the class-token position only.
    return self.mlp_head(encoded[:, 0])



In [None]:
image_size = 224
channel_size = 3
num_classes = 10
dropout_prob = 0.1
n_layer = 12
embedding_dim = 768
n_head = 12
hidden_dim = 3072
# convert('RGB') guarantees channel_size == 3 channels; RGBA, palette or grayscale
# PNGs would otherwise give PILToTensor a different channel count. The `with`
# block closes the file once the pixels are read.
with Image.open('alihassan.png') as img:
    image = img.convert('RGB').resize((image_size, image_size))
x = T.PILToTensor()(image)  # (C, H, W) uint8
x = x[None, ...]            # add batch axis -> (1, C, H, W)
patch_size = 16
vision_transformer = VisionTransformer(patch_size, image_size, channel_size, n_layer, embedding_dim, n_head, hidden_dim, dropout_prob, num_classes)
vit_output = vision_transformer(x)
vit_output.shape


In [None]:
vit_output  # display the raw logits produced by the demo forward pass

In [None]:
# Softmax over the class logits of the first (only) image in the batch.
probalilities = vit_output[0].softmax(dim=0)
probalilities

In [None]:
print(probalilities.sum())  # sanity check: softmax output should sum to 1