In [1]:
import einops
from tqdm.notebook import tqdm
from torchsummary import summary
import torch
from torch import nn
import torchvision
from torchvision import datasets
from torchvision.transforms import Compose,Resize,ToTensor,Normalize,RandomHorizontalFlip,RandomCrop


In [5]:
# set hyperparameters
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# Model architecture (used as default arguments by the model classes below).
patch_size = 16      # side length, in pixels, of each square image patch
latent_size = 768    # width of every token embedding
n_channels = 3       # number of input image channels
num_heads = 12       # attention heads per encoder block
num_encoders = 12    # number of stacked encoder blocks
dropout = 0.1        # dropout probability used in attention and the MLPs
num_classes = 10     # number of output logits of the classification head
size = 224           # presumably the input image height/width in pixels (matches the 224x224 smoke-test tensor) - confirm


# Training settings (presumably consumed by a training loop outside this section - confirm).
epochs = 10
# 10e-3 evaluates to 0.01 (i.e. 1e-2).
# NOTE(review): if 1e-3 was intended, write it as 1e-3 - confirm.
base_lr = 10e-3
weight_decay = 0.03
batch_size = 8       # images per batch; also forwarded to InputEmbedding as a default argument

cuda


Implementation of the Input Embedding (linear patch projection, class token, and positional embedding)

In [6]:
class InputEmbedding(nn.Module):
    """Convert a batch of images into the token sequence fed to the encoder stack.

    The image is cut into non-overlapping ``patch_size x patch_size`` patches,
    each patch is flattened and linearly projected to ``latent_size``, a learned
    class token is prepended, and a learned positional embedding is added.

    Args:
        patch_size: Side length of a square patch, in pixels.
        n_channels: Number of channels of the input images.
        device: Device the module's parameters are placed on at construction.
        latent_size: Width of every output token.
        batch_size: Kept for backward compatibility and stored on the instance;
            the learned tensors no longer depend on it, so batches of any size
            (e.g. a smaller final batch) are accepted.
        image_size: Side length of the (square) input images, in pixels. It
            fixes the number of patches and therefore the number of learned
            positional embeddings.

    Raises:
        ValueError: If ``image_size`` is not a multiple of ``patch_size``.
    """

    def __init__(self,patch_size=patch_size,n_channels=n_channels,device=device,latent_size=latent_size,batch_size=batch_size,image_size=size):
        super(InputEmbedding,self).__init__()
        if image_size % patch_size != 0:
            raise ValueError(
                f"image_size ({image_size}) must be divisible by patch_size ({patch_size})"
            )

        self.latent_size = latent_size
        self.patch_size = patch_size
        self.n_channels = n_channels
        self.batch_size = batch_size
        self.image_size = image_size
        self.device = device
        self.input_size = self.patch_size * self.patch_size * self.n_channels
        self.num_patches = (image_size // patch_size) ** 2

        # Linear Projection
        self.linearProjection = nn.Linear(self.input_size,self.latent_size)

        # Class token.
        # Assign the nn.Parameter itself: `nn.Parameter(...).to(device)` returns a
        # plain non-leaf tensor when the device differs, which nn.Module does not
        # register, so it would be missing from parameters()/state_dict() and
        # never be updated by an optimizer.
        # A single token is learned and shared by every sample in the batch.
        self.class_token = nn.Parameter(torch.randn(1,1,self.latent_size))

        # Positional embedding: one learned vector per position (class token +
        # every patch), so that tokens at different positions are distinguishable.
        self.pos_embedding = nn.Parameter(torch.randn(1,self.num_patches + 1,self.latent_size))

        # Honour the `device` argument for all parameters of this module.
        self.to(device)

    def forward(self,input_data):
        """Embed ``input_data`` of shape ``(batch, channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, num_patches + 1, latent_size)``; index 0
            along dimension 1 is the class token.

        Raises:
            ValueError: If the input does not yield ``num_patches`` patches,
                i.e. its spatial size does not match ``image_size``.
        """
        # Follow the module's current device instead of the notebook-global one,
        # so the layer keeps working after `.to(...)` moves it.
        input_data = input_data.to(self.linearProjection.weight.device)

        # Patchify input image: (b, c, H, W) -> (b, num_patches, patch_pixels * c)
        patches = einops.rearrange(
            input_data, 'b c (h h1) (w w1) -> b (h w) (h1 w1 c)',h1 = self.patch_size,w1 = self.patch_size)

        tokens = self.linearProjection(patches)
        b, n, _ = tokens.shape
        if n != self.num_patches:
            raise ValueError(
                f"expected {self.num_patches} patches per image "
                f"(image_size={self.image_size}, patch_size={self.patch_size}), got {n}"
            )

        # Broadcast the single learned class token over the actual batch size.
        cls_tokens = self.class_token.expand(b,-1,-1)
        tokens = torch.cat((cls_tokens,tokens),dim=1)

        # pos_embedding has a leading dimension of 1 and broadcasts over the batch.
        return tokens + self.pos_embedding
        

In [7]:
# Smoke test for InputEmbedding on a random batch.
# Build the dummy batch from the config constants instead of repeating the
# literals 8/3/224, so this cell stays valid when the hyperparameters change.
test_input = torch.randn((batch_size,n_channels,size,size))
test_class = InputEmbedding().to(device)
embed_test = test_class(test_input)

# One class token plus one token per patch, each of width latent_size.
assert embed_test.shape == (batch_size,(size // patch_size) ** 2 + 1,latent_size)

In [8]:
class Encoder(nn.Module):
    """One pre-norm Transformer encoder block: self-attention followed by an MLP.

    Both sub-layers are wrapped in a residual connection and preceded by their
    own LayerNorm.

    Args:
        latent_size: Width of every token (embedding dimension).
        num_heads: Number of attention heads.
        device: Stored on the instance; the module itself is moved by the
            caller via ``.to(device)``.
        dropout: Dropout probability for the attention weights and the MLP.
    """

    def __init__(self,latent_size=latent_size,num_heads=num_heads,device=device,dropout=dropout):
        super(Encoder,self).__init__()

        self.latent_size = latent_size
        self.num_heads = num_heads
        self.device = device
        self.dropout = dropout

        # Normalization layers: each sub-layer gets its own LayerNorm so the two
        # do not share a single set of scale/shift parameters.
        self.norm = nn.LayerNorm(self.latent_size)        # before attention
        self.norm_mlp = nn.LayerNorm(self.latent_size)    # before the MLP

        # batch_first=True: inputs arrive as (batch, tokens, latent). With the
        # default (batch_first=False) the first dimension is treated as the
        # sequence, so attention would mix tokens across different images of the
        # batch instead of across the patches of one image.
        self.multihead = nn.MultiheadAttention(
            self.latent_size,self.num_heads,dropout=self.dropout,batch_first=True
        )

        self.enc_MLP = nn.Sequential(
            nn.Linear(self.latent_size,self.latent_size * 4),
            nn.GELU(),
            nn.Dropout(self.dropout),
            nn.Linear(self.latent_size * 4,self.latent_size),
            nn.Dropout(self.dropout)
        )

    def forward(self,embedded_pathes):
        """Apply the block to tokens of shape ``(batch, tokens, latent_size)``.

        Returns:
            Tensor with the same shape as ``embedded_pathes``.
        """
        attn_in = self.norm(embedded_pathes)
        # The attention map is never used, so skip computing/returning it.
        attention_out = self.multihead(attn_in,attn_in,attn_in,need_weights=False)[0]

        # first residual connection
        first_added = attention_out + embedded_pathes

        ff_out = self.enc_MLP(self.norm_mlp(first_added))

        # second residual connection
        return ff_out + first_added
        

In [9]:
# Smoke test: run a single encoder block on the embedded tokens produced by
# the InputEmbedding smoke-test cell (`embed_test`).
test_encoder = Encoder().to(device)
# Bare last expression, so the notebook displays the returned tensor.
test_encoder(embed_test)

tensor([[[-0.5248, -0.0131, -0.8531,  ...,  0.0661,  0.4559, -0.1604],
         [-1.3620, -0.3457,  0.7769,  ...,  0.9863,  0.3427,  0.2748],
         [-0.3905,  0.5816, -0.5273,  ...,  2.2097, -0.2858,  0.6337],
         ...,
         [ 0.0492,  1.3259, -1.1039,  ...,  0.7285, -0.4145, -0.5071],
         [-0.3874,  0.5646, -0.2843,  ...,  1.9877, -0.5334, -0.6394],
         [-0.0396, -0.5293,  0.6363,  ...,  0.4858,  0.4059,  0.2588]],

        [[-1.4182, -2.7925,  2.5594,  ..., -0.7089,  0.3096,  2.5393],
         [-0.8160, -1.4472,  1.5562,  ...,  0.2870,  0.8329,  1.3774],
         [-2.2243, -1.3191,  2.2585,  ...,  0.6834,  2.5762,  1.8160],
         ...,
         [-0.9010, -2.0125,  2.2293,  ..., -0.4387,  1.7208,  2.1022],
         [ 0.0475, -1.7603, -0.0063,  ..., -0.4887,  1.7121,  1.7676],
         [-1.3061, -3.3139,  1.8482,  ..., -0.6609,  1.3545,  1.2455]],

        [[-1.6195,  2.5127, -0.4175,  ..., -2.7350, -0.4364,  1.1066],
         [-1.4524,  1.9297, -0.4958,  ..., -0

In [10]:
class ViT(nn.Module):
    """Vision Transformer classifier: input embedding, encoder stack, MLP head.

    Args:
        num_encoders: Number of stacked ``Encoder`` blocks.
        latent_size: Width of every token; forwarded to the embedding, every
            encoder block and the classification head.
        device: Forwarded to the embedding and the encoder blocks.
        num_classes: Number of output logits.
        dropout: Dropout probability forwarded to every encoder block.
    """

    def __init__(self,num_encoders=num_encoders,latent_size=latent_size,device=device,num_classes=num_classes,dropout=dropout):
        super(ViT,self).__init__()
        self.num_encoder = num_encoders
        self.latent_size = latent_size
        self.device = device
        self.num_classes = num_classes
        self.dropout = dropout

        # Forward the constructor arguments: building the sub-modules with their
        # own global defaults would ignore e.g. ViT(latent_size=384) and leave
        # the head with a different width than the encoder output.
        self.embedding = InputEmbedding(latent_size=self.latent_size,device=self.device)

        # create the stack of encoders
        self.encStack = nn.ModuleList([
            Encoder(latent_size=self.latent_size,device=self.device,dropout=self.dropout)
            for _ in range(self.num_encoder)
        ])

        self.MLP_head = nn.Sequential(
            nn.LayerNorm(self.latent_size),
            nn.Linear(self.latent_size,self.latent_size),
            # Without a non-linearity the two Linear layers collapse into a
            # single affine map; GELU matches the encoder MLP.
            nn.GELU(),
            nn.Linear(self.latent_size,self.num_classes)
        )

    def forward(self,test_input):
        """Classify a batch of images.

        Args:
            test_input: Image batch accepted by ``InputEmbedding``.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        enc_output = self.embedding(test_input)

        for enc_layer in self.encStack:
            enc_output = enc_layer(enc_output)

        # The class token sits at index 0 of the token dimension.
        cls_token_embed = enc_output[:,0]

        return self.MLP_head(cls_token_embed)

In [11]:
# Smoke test: build the full model and run one forward pass on the random
# batch created in the InputEmbedding smoke-test cell (`test_input`).
model = ViT().to(device)
vit_output = model(test_input)
# Print the logits and their shape (one row per image, one column per class).
print(vit_output)
print(vit_output.size())

tensor([[ 0.3254, -0.1509,  0.2518,  0.1005, -0.0672,  0.1720, -0.2237,  0.0961,
         -0.2804, -0.3799],
        [-0.0868, -0.1790, -0.0209,  0.0729,  0.2386,  0.1178, -0.1998, -0.1146,
         -0.1288,  0.1116],
        [ 0.1914,  0.3783,  0.4936,  0.1537, -0.2746, -0.4081, -0.0298,  0.3294,
         -0.1753, -0.1010],
        [ 0.0350, -0.3818, -0.4114, -0.0839,  0.0546, -0.0980, -0.1402,  0.2382,
         -0.1924,  0.0042],
        [ 0.1217,  0.3642,  0.0122, -0.1125,  0.5475, -0.3069,  0.3059,  0.2368,
         -0.6082, -0.1529],
        [ 0.0226, -0.1642,  0.1978, -0.5831, -0.3268, -0.3327, -0.3731,  0.4667,
         -0.2083, -0.0684],
        [ 0.2539,  0.1225,  0.7804,  0.5862,  0.0316, -0.3312,  0.1216,  0.5738,
         -0.6443,  0.0989],
        [ 0.4382,  0.3348, -0.1298,  0.0220,  0.5685,  0.4743,  0.1789,  0.3369,
          0.0847,  0.1542]], device='cuda:0', grad_fn=<AddmmBackward0>)
torch.Size([8, 10])
