### This is my implementation of the Vision Transformer (ViT): https://arxiv.org/pdf/2010.11929.pdf

![Vit image](vit.png)


In [73]:
import torch
import torch.nn as nn
import numpy as np

In [74]:
class ImageToPatchEmbeddings(nn.Module):
    """Split an image batch into square patches and embed them as a token sequence.

    The input is cut into non-overlapping ``patch_size x patch_size`` patches,
    each patch is flattened and linearly projected to ``latent_dim``, a
    positional term is added, and a class token is prepended at index 0.

    Args:
        latent_dim: Width of every output token.
        patch_size: Side length (in pixels) of each square patch.
        in_channels: Number of image channels. Defaults to 3.
    """

    def __init__(self, latent_dim, patch_size, in_channels=3):
        super().__init__()
        self.patch_size = patch_size
        self.latent_dim = latent_dim
        self.in_channels = in_channels

        # Projects a flattened (channels * patch * patch) patch to latent_dim.
        self.lin_projection = nn.Linear(
            in_channels * patch_size * patch_size, latent_dim, bias=False
        )

        # Applied to an all-ones vector in forward(), so the class token is the
        # row-sum of this weight matrix.
        self.class_embedding = nn.Linear(latent_dim, latent_dim, bias=False)

        # Applied to the constant position rows built by positions(); the
        # attribute name (including its spelling) is kept so existing
        # checkpoints still load.
        self.learnable_positional_enbedding = nn.Linear(
            latent_dim, latent_dim, bias=False
        )

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Tensor of shape (batch, channels, height, width).

        Returns:
            Tensor of shape (batch, num_patches + 1, latent_dim); index 0 along
            the sequence axis is the class token.
        """
        size = self.patch_size

        # Unfold height, then width: (B, C, nH, nW, size, size).
        x = x.unfold(-2, size, size)
        x = x.unfold(-2, size, size)
        # Move channels next to the patch pixels: (B, nH, nW, C, size, size).
        x = x.movedim(1, -3)
        # Merge the patch grid, then the per-patch pixels: (B, N, C*size*size).
        x = x.flatten(1, 2)
        x = x.flatten(-3, -1)

        x = self.lin_projection(x)

        # Build the position rows on the same device/dtype as the tokens so the
        # module keeps working after .to(device), .double(), .half(), ...
        pos = self.positions(
            x.shape[1], x.shape[2], device=x.device, dtype=x.dtype
        )
        x = x + self.learnable_positional_enbedding(pos)

        # new_ones inherits device and dtype from x.
        ones = x.new_ones(x.shape[0], 1, self.latent_dim)
        cls_embedding = self.class_embedding(ones)

        return torch.cat((cls_embedding, x), 1)

    def positions(self, num_patch, latent_dim, device=None, dtype=None):
        """Return a (num_patch, latent_dim) tensor whose row ``i`` is filled with ``i + 1``.

        Args:
            num_patch: Number of rows (patch positions).
            latent_dim: Number of columns.
            device: Optional target device; defaults to torch's default device.
            dtype: Optional dtype; defaults to torch's default floating dtype.
        """
        if dtype is None:
            dtype = torch.get_default_dtype()
        index = torch.arange(1, num_patch + 1, device=device, dtype=dtype)
        return index.unsqueeze(1).repeat(1, latent_dim)


In [75]:
class CreateQKV(nn.Module):
    """Produce query, key and value tensors from one input via three bias-free linear maps.

    Args:
        d_model: Input and output feature width of each projection.
    """

    def __init__(self, d_model):
        super().__init__()

        # Registered in the order WQ, WK, WV; each maps d_model -> d_model.
        for attr in ("WQ", "WK", "WV"):
            setattr(self, attr, nn.Linear(d_model, d_model, bias=False))

    def forward(self, x):
        """Return the tuple ``(WQ(x), WK(x), WV(x))``."""
        queries = self.WQ(x)
        keys = self.WK(x)
        vals = self.WV(x)
        return queries, keys, vals

In [76]:
    def Attention(query, key, values):
        """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``.

        Args:
            query: Tensor of shape (..., seq_q, d_k).
            key: Tensor of shape (..., seq_k, d_k).
            values: Tensor of shape (..., seq_k, d_v).

        Returns:
            Tensor of shape (..., seq_q, d_v).
        """
        # d_k is the feature width (last axis). Axis 1 is the sequence length
        # for batched (batch, seq, d_k) inputs, which is the wrong scale.
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / d_k ** 0.5
        weights = nn.functional.softmax(scores, dim=-1)

        return torch.matmul(weights, values)

In [77]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention with one set of bias-free projections per head.

    Each head projects query/key/value from ``d_model`` to
    ``d_model // heads``, runs ``Attention`` on them, and the head outputs are
    concatenated along the last axis and mixed by the output projection ``WO``.

    Args:
        d_model: Feature width of the inputs and of the output.
        heads: Number of attention heads; must divide ``d_model``.

    Raises:
        ValueError: If ``heads`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, d_model, heads):
        super().__init__()
        # Without this check the concatenated head outputs would have width
        # heads * (d_model // heads) != d_model and WO would only fail at
        # forward time with an opaque shape error.
        if heads <= 0 or d_model % heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive number of heads ({heads})"
            )

        self.d_model = d_model
        self.heads = heads
        head_dim = d_model // heads

        self.WQ = nn.ModuleList([nn.Linear(d_model, head_dim, bias=False) for _ in range(heads)])
        self.WK = nn.ModuleList([nn.Linear(d_model, head_dim, bias=False) for _ in range(heads)])
        self.WV = nn.ModuleList([nn.Linear(d_model, head_dim, bias=False) for _ in range(heads)])
        self.WO = nn.Linear(d_model, d_model, bias=False)

    def forward(self, query, key, values):
        """Run every head on the projected inputs and merge the results.

        Args:
            query, key, values: Tensors whose last axis has size ``d_model``.

        Returns:
            Tensor with the same leading shape as ``query`` and last axis ``d_model``.
        """
        head_outputs = [
            Attention(w_q(query), w_k(key), w_v(values))
            for w_q, w_k, w_v in zip(self.WQ, self.WK, self.WV)
        ]

        # Concatenating the heads restores the d_model width expected by WO.
        return self.WO(torch.cat(head_outputs, dim=-1))

In [78]:
class TransformerEncoder(nn.Module):
    """One pre-norm encoder block: attention and a GELU feed-forward step, each with a residual add.

    Args:
        num_heads: Number of heads passed to ``MultiHeadAttention``.
        latent_dim: Token width used by every sub-layer.
    """

    def __init__(self, num_heads, latent_dim):
        super().__init__()

        self.layer_norm1 = nn.LayerNorm(latent_dim)
        self.layer_norm2 = nn.LayerNorm(latent_dim)
        self.qkv = CreateQKV(latent_dim)
        self.MSA = MultiHeadAttention(latent_dim, num_heads)
        self.linear = nn.Linear(latent_dim, latent_dim)
        self.activation = nn.GELU()

    def forward(self, x):
        """Apply the block to ``x`` and return a tensor of the same shape."""
        # Attention sub-layer: normalise, project to q/k/v, attend, add residual.
        queries, keys, vals = self.qkv(self.layer_norm1(x))
        attended = x + self.MSA(queries, keys, vals)

        # Feed-forward sub-layer: normalise, linear + GELU, add residual.
        hidden = self.linear(self.layer_norm2(attended))
        return self.activation(hidden) + attended



In [79]:
class VIT(nn.Module):
    """Vision Transformer classifier: patch embedding, a stack of encoder blocks, and an MLP head.

    Args:
        patch_size: Side length of the square image patches.
        num_heads: Attention heads per encoder block.
        latent_dim: Token width used throughout the model.
        num_classes: Size of the output logit vector.
        num_encoder: Number of stacked ``TransformerEncoder`` blocks.
    """

    def __init__(self,patch_size, num_heads, latent_dim, num_classes, num_encoder):
        super().__init__()

        self.embs = ImageToPatchEmbeddings(latent_dim, patch_size)
        blocks = [TransformerEncoder(num_heads, latent_dim) for _ in range(num_encoder)]
        self.encoder = nn.Sequential(*blocks)
        self.lin1 = nn.Linear(latent_dim, latent_dim)
        self.act = nn.GELU()
        self.lin2 = nn.Linear(latent_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for the image batch ``x``."""
        tokens = self.encoder(self.embs(x))
        # Only the token at sequence index 0 feeds the classification head.
        first_token = tokens[:, 0, :]
        hidden = self.act(self.lin1(first_token))
        return self.lin2(hidden)

In [83]:
# Smoke test: build a small ViT and push a random batch through it to check
# that the output is one logit vector per image.
vit = VIT(patch_size=16, num_heads=8, latent_dim=512, num_classes=10, num_encoder=6)
sample_imgs = torch.rand(128, 3, 128, 128)
logits = vit(sample_imgs)
logits.shape

torch.Size([128, 10])