In [1]:
import einops
from tqdm.notebook import tqdm

import torch
from torch import nn
import torchvision
import torch.optim as optim
from torchvision.transforms import Compose,Resize,ToTensor,Normalize,RandomHorizontalFlip,RandomCrop

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

patch_size = 16
latent_size = 768
n_channels = 3
num_heads = 12
num_encoders = 12
dropout = 0.1
num_classes = 10
size = 224

epochs = 10
base_lr = 10e-3
weight_decay = 0.03
batch_size = 4

cpu


In [3]:
class InputEmbedding(nn.Module):

    def __init__(self,
                 patch_size = patch_size,
                 latent_size = latent_size,
                 n_channels = n_channels,
                 batch_size = batch_size,
                 device = device):
        super(InputEmbedding,self).__init__()

        self.patch_size = patch_size
        self.latent_size = latent_size
        self.n_channels = n_channels
        self.batch_size = batch_size
        self.device = device

        self.input_size = self.patch_size * self.patch_size * self.n_channels

        # Linear Projection Layer
        self.LinearProjection = nn.Linear(self.input_size,
                                       self.latent_size)
        
        # Class Token
        self.ClassToken = nn.Parameter(torch.randn(self.batch_size,
                                                   1,
                                                   self.latent_size))  # Creates a class token vector of size (4,1,768)
        
        # Positional Embedding
        self.PositionalEmbedding = nn.Parameter(torch.randn(self.batch_size,
                                                            1,
                                                            self.latent_size)) # Creates a positional embedding of size (4,1,768)
        
    def forward(self,
                input_data):
        
        # Patchify the input data --> Convert the image of size (224,224,3) to (196,768) [ where 196 = (224*224)/(16*16) Num_patches, 768 = (16*16*3) Flattened vector of every patch 
        patches = einops.rearrange(
            input_data, 'b c (h h1) (w w1) -> b (h w) (h1 w1 c)', h1=self.patch_size, w1=self.patch_size
            ) 
        
        linear_projection = self.LinearProjection(patches)  # Convert the (4,196,768) to (4,196,768) :: the first 768 is just by chance and depend on patch size later one is latent_dim
        b,n,p = linear_projection.shape

        linear_projection = torch.cat([self.ClassToken,linear_projection],dim=1) # Concatentate classtoken to linear project (4,197,768)
        
        positional_embedding = einops.repeat(self.PositionalEmbedding, 'b 1 d -> b m d', m=n+1) # Converted pos_embed (4,1,768) --> (4,197,768) 

        linear_projection += positional_embedding # Added pos_embed to linear_proj (4,196,768)

        return linear_projection

In [4]:
test_input = torch.randn(batch_size,3,224,224).to(device)
test_class = InputEmbedding().to(device)
embed_patches = test_class(test_input)

In [5]:
class EncoderBlock(nn.Module):

    def __init__(self,
                 latent_size = latent_size,
                 num_heads = num_heads,
                 dropout = dropout,
                 device = device):
           super(EncoderBlock,self).__init__()

           self.latent_size = latent_size
           self.num_heads = num_heads
           self.dropout = dropout
           self.device = device

           # Normalization Layer
           self.Normalization = nn.LayerNorm(self.latent_size)

           # Multi-Headed Attention
           self.MultiHead = nn.MultiheadAttention(
                 embed_dim = self.latent_size,
                 num_heads = self.num_heads,
                 dropout = self.dropout
           )

           # MLP Layer
           self.encoder_MLP = nn.Sequential(
                 nn.Linear(self.latent_size,self.latent_size*4),
                 nn.GELU(),
                 nn.Dropout(self.dropout),
                 nn.Linear(self.latent_size*4,self.latent_size),
                 nn.Dropout(self.dropout)
           )

    def forward(self,embed_patches):
          
          # First Normalization Layer output
          first_normout = self.Normalization(embed_patches)

          # Output from the MultiHeadedAttention  
          attention_out = self.MultiHead(first_normout,first_normout,first_normout)[0]

          # First Residual Connection
          residual_out = attention_out + embed_patches

          # Second Norm output  
          second_normout = self.Normalization(residual_out)

          # Output from MLP
          mlp_out = self.encoder_MLP(second_normout)
          
          # Second residual connection
          output = mlp_out + residual_out
          
          return output


In [6]:
enc = EncoderBlock()
output = enc(embed_patches)

In [7]:
class ViT(nn.Module):

    def __init__(self,
                 num_encoders = num_encoders,
                 latent_size = latent_size,
                 num_classes = num_classes,
                 dropout = dropout,
                 device = device,
                 ):
        super(ViT,self).__init__()

        self.num_encoders = num_encoders
        self.num_classes = num_classes
        self.dropout = dropout
        self.device = device
        self.latent_size = latent_size

        # Embedding Layer
        self.embedding = InputEmbedding()

        # Encoder Stack
        self.enc_stack = nn.ModuleList([EncoderBlock() for i in range(self.num_encoders)])

        # VIT-MLP
        self.MLP = nn.Sequential(
            nn.LayerNorm(self.latent_size),
            nn.Linear(self.latent_size,self.latent_size),
            nn.Linear(self.latent_size,self.num_classes)
        )

    def forward(self, test_input):

        # Finding Patch Embeddings
        encoder_output = self.embedding(test_input)

        # Looping through the encoder stack
        for encoder_layer in self.enc_stack:
            encoder_output = encoder_layer(encoder_output)

        # Extracting the class token embedding
        class_token_embedding = encoder_output[:,0]

        # Sending it through the MLP
        output = self.MLP(class_token_embedding)

        return output

In [8]:
model = ViT()

test_input = torch.randn(4,3,224,224)
output = model(test_input)

print(output.shape)
print(output)

torch.Size([4, 10])
tensor([[ 0.0715,  0.5755, -0.0244,  0.1811,  0.4734, -0.2432, -0.4321,  0.3122,
         -0.2350, -0.0411],
        [-0.1588, -0.1795, -0.0497,  0.2420, -0.6304, -0.0514, -0.8325, -0.0392,
         -0.3193,  0.4923],
        [-0.0328,  0.1791, -0.4711,  0.2433, -0.0438, -0.4839, -0.3521,  0.0707,
          0.1584, -0.0077],
        [-0.2587,  0.0469,  0.4143,  0.5155, -0.3633, -0.3819, -0.6325,  0.0724,
          0.0907,  0.6130]], grad_fn=<AddmmBackward0>)


In [9]:
total_params = sum(p.numel() for p in model.parameters())

print(f"Total Trainable Parameters: {total_params}")

Total Trainable Parameters: 1196554


In [10]:
model

ViT(
  (embedding): InputEmbedding(
    (LinearProjection): Linear(in_features=768, out_features=768, bias=True)
  )
  (MLP): Sequential(
    (0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=768, out_features=768, bias=True)
    (2): Linear(in_features=768, out_features=10, bias=True)
  )
)