In [1]:
!pip install einops

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import einops
import torch
from torch import nn

In [3]:
# Use the GPU when CUDA is visible to torch; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

In [4]:
# --- model architecture (these values match the ViT-Base/16 configuration) ---
patch_size = 16       # side length of each square patch, in pixels
latent_size = 768     # width of every token embedding
n_channels = 3        # image channels
num_heads = 12        # attention heads per encoder block
num_encoders = 12     # number of stacked encoder blocks
dropout = 0.1
num_classes = 10
img_size = 224        # inputs are img_size x img_size images

# --- training ---
epochs = 30
base_lr = 1e-2        # identical value to the former `10e-3` (0.01); NOTE(review): confirm 1e-3 was not intended
weight_decay = 0.03
batch_size = 4

# Seed torch's RNGs so parameter init, random test inputs and dropout are reproducible.
seed = 42
torch.manual_seed(seed)

# Fail fast on configurations the model cannot represent.
assert img_size % patch_size == 0, "img_size must be a multiple of patch_size"
assert latent_size % num_heads == 0, "latent_size must be divisible by num_heads"

In [5]:
class InputEmbedding(nn.Module):
    """Turn a batch of images into the ViT token sequence.

    Each image is cut into non-overlapping ``patch_size x patch_size`` patches.
    Every flattened patch is linearly projected to ``latent_size``, a learnable
    [CLS] token is prepended, and a learnable per-position embedding is added.

    Args:
        patch_size: side length (pixels) of each square patch.
        n_channels: number of image channels.
        latent_size: width of each output token.
        batch_size: retained for backward compatibility only. The embedding
            works for any batch size because the [CLS] token and positional
            embeddings are broadcast over the batch.
        device: device the module's parameters are placed on at construction.
        img_size: side length of the square input images. It fixes the number
            of patches and therefore the number of positional embeddings.

    Raises:
        ValueError: if ``img_size`` is not a multiple of ``patch_size``.
    """

    def __init__(self, patch_size=patch_size, n_channels=n_channels, latent_size=latent_size,
                 batch_size=batch_size, device=device, img_size=img_size):
        super(InputEmbedding, self).__init__()

        if img_size % patch_size != 0:
            raise ValueError(f"img_size ({img_size}) must be a multiple of patch_size ({patch_size})")

        self.patch_size = patch_size
        self.n_channels = n_channels
        self.latent_size = latent_size
        self.batch_size = batch_size  # unused; kept so existing attribute access keeps working
        self.device = device
        self.img_size = img_size
        self.num_patches = (img_size // patch_size) ** 2
        self.input_size = self.patch_size * self.patch_size * self.n_channels

        self.linear_projection = nn.Linear(self.input_size, self.latent_size)

        # Leading dimension of 1 so these broadcast over any batch size.
        # They are registered as nn.Parameter with no trailing `.to(...)`:
        # `Parameter.to(other_device)` returns a plain tensor that the module
        # would neither register nor train.
        self.class_token = nn.Parameter(torch.randn(1, 1, self.latent_size))
        # One embedding per position ([CLS] + every patch), so tokens carry their location.
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, self.latent_size))

        # Module.to keeps parameters registered while moving them.
        self.to(self.device)

    def forward(self, img):
        """Embed a batch of images.

        Args:
            img: image tensor laid out as (batch, channels, height, width), with
                height == width == ``img_size``.

        Returns:
            Tensor of shape (batch, num_patches + 1, latent_size). Position 0 along
            dim 1 is the [CLS] token.

        Raises:
            ValueError: if the image does not yield ``num_patches`` patches.
        """
        # Follow the parameters' current device, so a later `module.to(...)` is respected.
        img = img.to(self.class_token.device)

        # (b, c, H, W) -> (b, num_patches, patch_size * patch_size * c)
        patches = einops.rearrange(
            img, 'b c (h h1) (w w1) -> b (h w) (h1 w1 c)', h1=self.patch_size, w1=self.patch_size
        )

        tokens = self.linear_projection(patches)
        b, n, _ = tokens.shape
        if n != self.num_patches:
            raise ValueError(
                f"expected {self.num_patches} patches for img_size={self.img_size}, got {n}; "
                "positional embeddings are sized for a fixed image size"
            )

        # Share the single learnable [CLS] token across every image in the batch.
        cls_tokens = self.class_token.expand(b, -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)

        # pos_embedding is (1, n + 1, d) and broadcasts over the batch dimension.
        return tokens + self.pos_embedding

In [6]:
# Smoke-test the embedding on a random batch shaped like the configured inputs.
test_input = torch.randn((batch_size, n_channels, img_size, img_size))
test_class = InputEmbedding().to(device)
embed_test = test_class(test_input)
# One token per patch plus the [CLS] token, each latent_size wide.
assert embed_test.shape == (batch_size, (img_size // patch_size) ** 2 + 1, latent_size)
embed_test.shape

torch.Size([4, 197, 768])

In [7]:
class EncoderBlock(nn.Module):
    """One pre-norm Transformer encoder layer.

    The layer applies LayerNorm -> multi-head self-attention -> residual add,
    then LayerNorm -> MLP -> residual add.

    Args:
        latent_size: token width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability used inside attention and the MLP.
        device: stored for reference; this block does not use it.
    """

    def __init__(self, latent_size=latent_size, num_heads=num_heads, dropout=dropout, device=device):
        super(EncoderBlock, self).__init__()

        self.latent_size = latent_size
        self.num_heads = num_heads
        self.dropout = dropout
        self.device = device

        # Each sub-layer gets its own LayerNorm, so each learns its own scale/shift.
        self.norm1 = nn.LayerNorm(self.latent_size)
        self.norm2 = nn.LayerNorm(self.latent_size)

        # batch_first=True because inputs are (batch, seq, dim). With the default
        # (batch_first=False), dim 0 would be read as the sequence axis. Tokens would
        # then attend across the images of a batch instead of within one image.
        self.multihead = nn.MultiheadAttention(
            self.latent_size, self.num_heads, dropout=self.dropout, batch_first=True
        )

        # Position-wise MLP with a 4x hidden expansion.
        self.enc_MLP = nn.Sequential(
            nn.Linear(self.latent_size, self.latent_size * 4),
            nn.GELU(),
            nn.Dropout(self.dropout),
            nn.Linear(self.latent_size * 4, self.latent_size),
            nn.Dropout(self.dropout)
        )

    def forward(self, embedded_patches):
        """Apply the encoder layer.

        Args:
            embedded_patches: token tensor of shape (batch, seq, latent_size).

        Returns:
            Tensor with the same shape as ``embedded_patches``.
        """
        # Self-attention sub-layer with residual connection.
        normed = self.norm1(embedded_patches)
        # need_weights=False: only the attended values are used, so skip computing weights.
        attention_out, _ = self.multihead(normed, normed, normed, need_weights=False)
        first_added_out = attention_out + embedded_patches

        # MLP sub-layer with residual connection.
        encMLP_out = self.enc_MLP(self.norm2(first_added_out))
        return encMLP_out + first_added_out

In [8]:
test_encoder = EncoderBlock().to(device)
test_encoder_out = test_encoder(embed_test)
# An encoder block must preserve the (batch, tokens, latent_size) shape of its input.
assert test_encoder_out.shape == embed_test.shape
test_encoder_out.shape

torch.Size([4, 197, 768])

In [9]:
class ViT(nn.Module):
    """Vision Transformer classifier.

    The pipeline is: patch embedding -> stack of encoder blocks -> MLP head
    applied to the [CLS] token.

    Args:
        num_encoders: number of stacked EncoderBlocks.
        latent_size: token width, forwarded to the embedding and every encoder block.
        num_classes: number of output logits.
        dropout: dropout probability, forwarded to every encoder block.
        device: device forwarded to the sub-modules.
        num_heads: attention heads per encoder block.
    """

    def __init__(self, num_encoders=num_encoders, latent_size=latent_size, num_classes=num_classes,
                 dropout=dropout, device=device, num_heads=num_heads):
        super(ViT, self).__init__()

        self.num_encoders = num_encoders
        self.latent_size = latent_size
        self.num_classes = num_classes
        self.dropout = dropout
        self.device = device
        self.num_heads = num_heads

        # Forward the constructor arguments so every sub-module agrees on width,
        # dropout and device. Otherwise a non-default latent_size would feed an
        # encoder of the notebook-default width into a head of a different width.
        self.embedding = InputEmbedding(latent_size=self.latent_size, device=self.device)

        self.enc_stack = nn.ModuleList(
            [
                EncoderBlock(
                    latent_size=self.latent_size,
                    num_heads=self.num_heads,
                    dropout=self.dropout,
                    device=self.device,
                )
                for _ in range(self.num_encoders)
            ]
        )

        self.MLP_head = nn.Sequential(
            nn.LayerNorm(self.latent_size),
            nn.Linear(self.latent_size, self.latent_size),
            # Without a nonlinearity the two Linear layers would collapse into one linear map.
            nn.Tanh(),
            nn.Linear(self.latent_size, self.num_classes)
        )

    def forward(self, inp):
        """Classify a batch of images.

        Args:
            inp: image batch accepted by ``InputEmbedding``.

        Returns:
            Logits of shape (batch, num_classes).
        """
        enc_output = self.embedding(inp)

        for enc_layer in self.enc_stack:
            # Call the module (not .forward) so any registered hooks run.
            enc_output = enc_layer(enc_output)

        # Token 0 is the [CLS] token that the embedding prepends.
        cls_token_embedding = enc_output[:, 0]

        return self.MLP_head(cls_token_embedding)

In [10]:
model = ViT().to(device)
vit_output = model(test_input)
# One logit per class for every image in the batch.
assert vit_output.shape == (test_input.shape[0], num_classes)
test_input.shape, vit_output.shape

torch.Size([4, 3, 224, 224])
torch.Size([4, 10])
