# Import Libraries

In [1]:
import torch
import torchvision
import torch.nn as nn
import einops
from torchsummary import summary    
import torch.optim as optim
from torchvision.transforms import ToTensor, Compose, Resize, Normalize, RandomHorizontalFlip, RandomCrop



import tqdm.notebook as tqdm

In [2]:
# to make progrss bar work
!jupyter nbextension enable --py widgetsnbextension   

Enabling notebook extension jupyter-js-widgets/extension...
      - Validating: ok


# Set Hyper parameters

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

patch_size=16
latent_size=768 
n_channels = 3
num_heads = 12
num_encoders = 12
dropout = 0.1
num_classes = 10
size = 224


epochs = 10
base_lr = 0.001
weight_decay = 0.03
batch_size = 1


# Input Embedding


In [10]:
class InputEmbedding(nn.Module):
    def __init__(self, device=device, patch_size=patch_size, n_channels=n_channels, latent_size=latent_size,batch_size=batch_size):
        super(InputEmbedding,self).__init__()
        self.device = device
        self.patch_size = patch_size
        self.n_channels = n_channels
        self.latent_size = latent_size
        self.batch_size = batch_size
        self.input_size = self.patch_size * self.patch_size * self.n_channels #16*16*3
        
        #Linear Projection
        self.linearProjection = nn.Linear(self.input_size, self.latent_size)
        
        #class token
        self.class_token = nn.Parameter(torch.randn(self.batch_size, 1,self.latent_size)).to(self.device)

        # pos embedding
        self.pos_embedding= nn.Parameter(torch.randn(self.batch_size, 1, self.latent_size)).to(self.device)
        
    
    def forward(self,input_data):
        input_data = input_data.to(self.device)
        
        #patchify the image
        patches = einops.rearrange(input_data, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size)
        """
            Image is represented with 4 dimensions ( batch, channels, height, width ). Now height = x and patch_size similarly for width
            So image dimensions are (batch, channels, height = x*patch_size, width = y*patch_size)
            patches is represented with 3 dimensions ( batch, num_patches = x*y, channels*patch_size*patch_size )
            For each patch we are going to apply linear projection
        
        """
        
        print(f"Size of input data is :{input_data.size()}")
        print(f"Size of patches is :{patches.size()}")
        
        # Linear Projection
        linear_projection = self.linearProjection(patches).to(self.device)
        print(f"Size of Linear Projection is :{linear_projection.size()}")
        b,n,_ = linear_projection.size()
        # pre pend the class token to the linear projections
        linear_projection = torch.cat([self.class_token, linear_projection], dim=1) # we are inserting across the patch dimension which is numbered one
        print(f"Size of linear projection with class token is :{linear_projection.size()}")
        
        # add the positional embedding to all patches. so repeat pos embbedding == number of patches
        pos_embeddings=einops.repeat(self.pos_embedding, 'b 1 d -> b m d',m=n+1)
        print(f"Size of positional embedding is :{self.pos_embedding.size()}")
        
        # add the positional embedding to the linear projection
        linear_projection = linear_projection + pos_embeddings
        print(f"Size of linear projection with positional embedding is :{linear_projection.size()}")
        
        return linear_projection
        
        
                

In [11]:
test_data = torch.randn((1,3,224,224))
test_class = InputEmbedding().to(device)

embedding_patches = test_class(test_data)
embedding_patches

Size of input data is :torch.Size([1, 3, 224, 224])
Size of patches is :torch.Size([1, 196, 768])
Size of Linear Projection is :torch.Size([1, 196, 768])
Size of linear projection with class token is :torch.Size([1, 197, 768])
Size of positional embedding is :torch.Size([1, 1, 768])
Size of linear projection with positional embedding is :torch.Size([1, 197, 768])


tensor([[[ 0.6310, -0.3207,  1.7034,  ...,  1.5155,  1.5132,  0.8933],
         [ 0.1953,  1.3033,  0.5022,  ...,  0.8689, -0.1009,  1.1851],
         [ 0.6021,  1.6609,  1.0280,  ..., -0.1888,  0.4152,  0.1821],
         ...,
         [ 1.0463,  0.5505,  1.1236,  ...,  1.0360, -0.0582,  0.5389],
         [ 1.8190,  0.1885,  0.8947,  ...,  0.3344,  0.7989,  1.1417],
         [-0.4104,  0.9053,  1.7362,  ...,  0.3034,  0.2745,  1.4105]]],
       grad_fn=<AddBackward0>)

# Encoder Block


In [12]:
class EncoderBlock(nn.Module):
    def __init__(self, device=device, latent_size=latent_size, num_heads=num_heads, dropout=dropout):
        super(EncoderBlock,self).__init__()
        self.device = device
        self.latent_size = latent_size
        self.num_heads = num_heads
        self.dropout = dropout
        
        #Normalization Layer
        self.norm = nn.LayerNorm(self.latent_size)
        
        # multi head attention
        self.multihead = nn.MultiheadAttention(self.latent_size, self.num_heads,self.dropout)
        
        # FFL or MLP
        self.mlp = nn.Sequential(
            nn.Linear(self.latent_size, self.latent_size*4),
            nn.GELU(),
            nn.Dropout(self.dropout),
            nn.Linear(self.latent_size*4, self.latent_size), 
            nn.Dropout(self.dropout)
            )  
    
    def forward(self, embedded_patches):
        # 1. Norm
        firstNormOut = self.norm(embedded_patches)
        # 2. multi head attention
        attention_out = self.multihead(firstNormOut,firstNormOut,firstNormOut)[0] 
        # 3. first residual connection
        firstResidual_out = attention_out+embedded_patches
        # 4. second norm
        secondNorm_out = self.norm(firstResidual_out)
        # 5. mlp
        mlp_out = self.mlp(secondNorm_out)
        # 6. second residual connection
        output = mlp_out + firstResidual_out
        
        print(f"Embedded Pathces size is :{embedded_patches.size()}")
        print(f"Output size is: {output.size()}")   
        
        return output

In [13]:
test_encoder = EncoderBlock().to(device)
test_encoder(embedding_patches)

Embedded Pathces size is :torch.Size([1, 197, 768])
Output size is: torch.Size([1, 197, 768])


tensor([[[ 0.3052, -0.1552,  1.3714,  ...,  1.5523,  1.8341,  0.7944],
         [ 0.0802,  0.3924,  1.1418,  ...,  0.8873,  0.1584,  2.0419],
         [ 1.1035,  1.4225,  1.2099,  ...,  0.4294,  1.0284,  1.1855],
         ...,
         [ 1.4826,  0.3611,  0.9429,  ...,  1.2505, -0.1727,  1.3183],
         [ 2.4398,  0.0811,  1.1461,  ...,  0.6752,  1.1753,  1.5884],
         [-0.0375,  0.5218,  2.0174,  ...,  0.0156,  0.5107,  1.5428]]],
       grad_fn=<AddBackward0>)

# Assembling a Vit


In [29]:
class ViT(nn.Module):
    def __init__(self,num_encoders = num_encoders, latent_size=latent_size,patch_size=patch_size,n_channels=n_channels,batch_size = batch_size, device = device, num_classes = num_classes, dropout=dropout, num_heads = num_heads):
        super(ViT,self).__init__()
        self.device = device
        self.num_encoders = num_encoders
        self.latent_size = latent_size
        self.patch_size = patch_size
        self.n_channels = n_channels
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.dropout = dropout
        self.num_heads = num_heads
        
        # input embedding
        self.inputEmbedding = InputEmbedding(self.device, self.patch_size, self.n_channels, self.latent_size,self.batch_size)
        
        # encoder
        self.encStack = nn.ModuleList([EncoderBlock(self.device, self.latent_size, self.num_heads, self.dropout) for _ in range(self.num_encoders)])
        
        # MLP head
        self.mlpHead=nn.Sequential(
            nn.LayerNorm(self.latent_size),
            nn.Linear(  self.latent_size, self.latent_size),
            nn.Linear(self.latent_size, self.num_classes)
        )
        
    def forward(self,input_data):
        input_data = input_data.to(self.device)
        # get the embedding
        embedding_out = self.inputEmbedding(input_data)
        # pass embedding through encoder stack
        for encoder in self.encStack:
            embedding_out= encoder(embedding_out)
        
        #get the cls token
        cls = embedding_out[:,0]
        # pass through the mlp head
        output = self.mlpHead(cls)
        
        return output

In [31]:
test_ViT = ViT().to(device)
output=test_ViT(test_data)
print(output,output.size())

Size of input data is :torch.Size([1, 3, 224, 224])
Size of patches is :torch.Size([1, 196, 768])
Size of Linear Projection is :torch.Size([1, 196, 768])
Size of linear projection with class token is :torch.Size([1, 197, 768])
Size of positional embedding is :torch.Size([1, 1, 768])
Size of linear projection with positional embedding is :torch.Size([1, 197, 768])
Embedded Pathces size is :torch.Size([1, 197, 768])
Output size is: torch.Size([1, 197, 768])
Embedded Pathces size is :torch.Size([1, 197, 768])
Output size is: torch.Size([1, 197, 768])
Embedded Pathces size is :torch.Size([1, 197, 768])
Output size is: torch.Size([1, 197, 768])
Embedded Pathces size is :torch.Size([1, 197, 768])
Output size is: torch.Size([1, 197, 768])
Embedded Pathces size is :torch.Size([1, 197, 768])
Output size is: torch.Size([1, 197, 768])
Embedded Pathces size is :torch.Size([1, 197, 768])
Output size is: torch.Size([1, 197, 768])
Embedded Pathces size is :torch.Size([1, 197, 768])
Output size is: to