In [3]:
import torch
from torchinfo import summary
from torch import nn
from torch import optim
import torchvision
from torchvision import transforms
import einops

In [38]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
patchsize = 16
latentsize = 768
n_channels = 3
num_heads = 12
dropout = 0.1
num_classes = 10
size = 224
num_encoders = 4

epochs = 10
base_lr = 10e-3
weight_decay = 0.03
batchsize = 4

In [39]:
class InputEmbedding(nn.Module):
    def __init__(self, patchsize=patchsize, n_channels=n_channels, device=device, latentsize=latentsize, batchsize=batchsize):
        super().__init__()
        self.patchsize = patchsize
        self.n_channels = n_channels
        self.device = device
        self.latentsize = latentsize
        self.batchsize = batchsize
        self.inputsize = self.patchsize * self.patchsize * self.n_channels

        # Linear projection
        self.linPro = nn.Linear(self.inputsize, self.latentsize)
        # Class token
        self.class_tok = nn.Parameter(torch.randn(size=[self.batchsize, 1, self.latentsize])).to(self.device)
        # Positional embedding
        self.pos_embedding = nn.Parameter(torch.randn(size=[self.batchsize, 1, self.latentsize])).to(self.device)

    def forward(self, input_data):
        input_data = input_data.to(self.device)

        # Get patches
        patches = einops.rearrange(input_data, 'b (w w1) (h h1) c -> b (w h) (w1 h1 c)', w1=self.patchsize, h1=self.patchsize)
        print(input_data.shape)
        print(patches.shape)
        linear_proj = self.linPro(patches).to(self.device)
        b, n, _ = linear_proj.shape
        linear_proj = torch.cat([self.class_tok, linear_proj], dim=1)
        # print(linear_proj.shape)
        pos_embed = einops.repeat(self.pos_embedding, 'b 1 d -> b m d', m=n+1)
        # print(pos_embed.shape)
        linear_proj += pos_embed
        return linear_proj

In [40]:
test_model = InputEmbedding()
test_data = torch.randn(size=[batchsize, size, size, 3])
test_model(test_data)

torch.Size([4, 224, 224, 3])
torch.Size([4, 196, 768])


tensor([[[ 1.2408, -0.2792,  2.5888,  ...,  0.4907,  0.6831,  1.8021],
         [ 1.4117,  2.2739,  1.2022,  ..., -2.0835,  0.9015,  2.7208],
         [-0.0323,  2.2394,  0.0081,  ..., -0.0370,  1.1567,  2.3511],
         ...,
         [ 0.9875,  2.4814,  1.1338,  ..., -1.5002,  2.0912,  2.8131],
         [ 1.0445,  1.3634,  0.1274,  ...,  0.4798, -0.2692,  2.1926],
         [ 1.1807,  2.2859,  1.1538,  ..., -0.4362,  0.6950,  0.7023]],

        [[ 0.9836,  0.6614, -2.2089,  ..., -0.4724,  1.3556, -0.4839],
         [ 0.6325, -0.8714, -1.7137,  ...,  1.0885, -0.2708, -0.6648],
         [ 1.9617, -0.0686, -0.0846,  ...,  0.4206, -0.1544, -0.6950],
         ...,
         [ 0.5441, -0.9152, -0.4645,  ...,  0.2636, -0.7799,  0.4371],
         [ 1.1925, -0.1778, -0.6451,  ..., -0.4484,  0.0733,  0.2611],
         [ 0.5775,  0.6868, -0.9261,  ..., -0.4567, -0.5782, -0.0293]],

        [[-0.5908, -1.4816,  0.6578,  ...,  2.2974,  0.0587, -1.0836],
         [-0.1214,  2.5687,  0.0375,  ...,  0

In [41]:
embed_test = test_model(test_data)

torch.Size([4, 224, 224, 3])
torch.Size([4, 196, 768])


In [42]:
# Implementing the encoder block
class EncoderBlock(nn.Module):
    def __init__(self, latentsize=latentsize, num_heads=num_heads, dropout=dropout):
        super().__init__()
        self.latentsize = latentsize
        self.num_heads = num_heads
        self.dropout = dropout

        self.norm = nn.LayerNorm(self.latentsize)
        self.multihead = nn.MultiheadAttention(self.latentsize, self.num_heads, dropout)
        self.enc_mlp = nn.Sequential(
            nn.Linear(self.latentsize, self.latentsize*4),
            nn.GELU(),
            nn.Dropout(self.dropout),
            nn.Linear(self.latentsize*4, self.latentsize),
            nn.Dropout(self.dropout)
        )
    def forward(self, embeded_patches):
        firstnum_out = self.norm(embeded_patches)
        attention_out = self.multihead(firstnum_out, firstnum_out, firstnum_out)[0]

        # First residual connection
        first_added = attention_out + embeded_patches

        # Second normalization
        secondnum_out = self.norm(first_added)
        mlp_out = self.enc_mlp(secondnum_out)

        return mlp_out + first_added

In [46]:
enc_test = EncoderBlock()
enc_test(embed_test)

tensor([[[ 1.0016e+00,  6.2812e-02,  3.0892e+00,  ...,  2.0176e-01,
           6.7085e-01,  1.6577e+00],
         [ 1.4674e+00,  2.7044e+00,  1.2845e+00,  ..., -1.9885e+00,
           6.8636e-01,  2.9603e+00],
         [-4.2287e-01,  2.8176e+00,  9.9075e-02,  ...,  1.2024e-01,
           1.3174e+00,  2.3826e+00],
         ...,
         [ 8.8647e-01,  2.7722e+00,  1.1649e+00,  ..., -1.4299e+00,
           2.0114e+00,  3.3571e+00],
         [ 1.0149e+00,  1.9444e+00,  3.4402e-01,  ...,  7.6743e-01,
          -7.3242e-02,  2.6087e+00],
         [ 1.2897e+00,  3.2423e+00,  1.0900e+00,  ..., -8.7784e-02,
           5.7765e-01,  9.1872e-01]],

        [[ 7.3546e-01,  1.2888e+00, -2.0278e+00,  ..., -4.8453e-01,
           1.2954e+00, -6.6643e-01],
         [ 4.2461e-01, -6.6702e-01, -1.7530e+00,  ...,  1.3552e+00,
          -4.8863e-01, -8.9262e-01],
         [ 1.6196e+00,  5.2400e-01,  4.3377e-02,  ...,  4.9845e-01,
          -2.6144e-01, -6.5412e-01],
         ...,
         [ 4.1916e-01, -2

In [47]:
# Put everything together
class ViT(nn.Module):
    def __init__(self, num_encoders=num_encoders, latentsize=latentsize, deice=device, num_classes=num_classes, dropout=dropout):
        super().__init__()
        self.num_encoders = num_encoders
        self.latentsize = self.latentsize
        self.device = device
        self.num_classes = num_classes
        self.dropout = dropout

        self.embedding = InputEmbedding()
        self.encStack = nn.ModuleList([EncoderBlock() for _ in range(self.num_encoders)])
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(self.latentsize),
            nn.Linear(self.latentsize, self.latentsize),
            nn.Linear(self.latentsize, self.num_classes)
        )

    def forward(self, test_input):
        enc_output = self.embedding(test_input)

        for enc_layer in self.encStack:
            enc_output = enc_layer(enc_output)

        cls_tok_embed = enc_output[:, 0]
        return self.mlp_head(cls_tok_embed)