In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

In [6]:
class PatchEmbedding(nn.Module):
    """
    Turn a batch of images into a sequence of patch embeddings.

    ``forward`` does the following:
        1. Splits the image grid into non-overlapping patches of size
           ``patch_height x patch_width`` and flattens each patch. The number
           of patches is
           ``(image_height // patch_height) * (image_width // patch_width)``.
        2. Projects every flattened patch to ``d_model`` features
           (LayerNorm -> Linear -> LayerNorm).
        3. Prepends the learnable cls token at the first sequence position.
        4. Adds a learned positional embedding to every position
           (cls token included).

    No dropout is applied by this layer.

    Args:
        config: Mapping with the keys ``image_height``, ``image_width``,
            ``image_channels``, ``d_model``, ``patch_height`` and
            ``patch_width``. The misspelled keys ``image_widht`` and
            ``d_mode`` are still accepted as a fallback for configs written
            against the earlier version of this class.
    """

    def __init__(self, config):
        super().__init__()

        image_height = config['image_height']
        image_width = self._config_value(config, 'image_width', 'image_widht')
        image_channels = config['image_channels']
        d_model = self._config_value(config, 'd_model', 'd_mode')

        self.patch_height = config['patch_height']
        self.patch_width = config['patch_width']

        num_patches = (image_height // self.patch_height) * (image_width // self.patch_width)

        patch_dim = image_channels * self.patch_height * self.patch_width

        self.patch_embedding = nn.Sequential(
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, d_model),
            nn.LayerNorm(d_model)
        )

        # One extra position for the cls token that is prepended in forward().
        self.pos_emd = nn.Embedding(num_patches + 1, d_model)
        self.cls_token = nn.Parameter(torch.randn(d_model))

    @staticmethod
    def _config_value(config, key, legacy_key):
        """Return ``config[key]``, falling back to the misspelled ``legacy_key``."""
        return config[key] if key in config else config[legacy_key]

    def forward(self, x):
        """
        Embed a batch of images.

        Args:
            x: Tensor laid out as (batch, channels, height, width); height and
                width must be multiples of the patch height and width.

        Returns:
            Tensor of shape (batch, num_patches + 1, d_model) with the cls
            token at sequence index 0.
        """
        batch_size = x.shape[0]

        out = rearrange(x, 'b c (nh ph) (nw pw) -> b (nh nw) (ph pw c)',
                    ph=self.patch_height,
                    pw=self.patch_width)
        out = self.patch_embedding(out)

        cls_tokens = repeat(self.cls_token, 'd -> b 1 d', b=batch_size)
        out = torch.cat((cls_tokens, out), dim=1)

        # nn.Embedding is a lookup table over integer indices, so look up the
        # position ids and add the result; it broadcasts over the batch dim.
        positions = torch.arange(out.shape[1], device=out.device)
        out = out + self.pos_emd(positions)

        return out


In [None]:
class Attention(nn.Module):
    """
    Multi-head self-attention over a sequence of token embeddings.

    Args:
        config: Mapping with the keys ``n_heads`` (number of attention heads),
            ``head_dim`` (features per head) and ``d_model`` (input and output
            feature size). ``n_heads * head_dim`` does not have to equal
            ``d_model``; the output projection maps back to ``d_model``.
    """

    def __init__(self, config):
        super().__init__()

        self.n_heads = config['n_heads']
        self.head_dim = config['head_dim']
        self.d_model = config['d_model']

        # Total width of the concatenated heads.
        self.inner_dim = self.n_heads * self.head_dim

        self.qkv_proj = nn.Linear(self.d_model, 3 * self.inner_dim)
        self.out_proj = nn.Linear(self.inner_dim, self.d_model)

    def forward(self, x):
        """
        Apply self-attention.

        Args:
            x: Tensor of shape (batch, num_patches, d_model).

        Returns:
            Tensor of shape (batch, num_patches, d_model).
        """
        batch_size, num_patches = x.shape[:2]

        # The fused projection yields three chunks of inner_dim features each.
        Q, K, V = self.qkv_proj(x).split(self.inner_dim, dim=-1)

        # (batch, num_patches, inner_dim) -> (batch, n_heads, num_patches, head_dim)
        Q = Q.view(batch_size, num_patches, self.n_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, num_patches, self.n_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, num_patches, self.n_heads, self.head_dim).transpose(1, 2)

        scores = Q @ K.transpose(-2, -1) / (self.head_dim ** 0.5)
        scores = F.softmax(scores, dim=-1)

        out = torch.matmul(scores, V)

        # Move the head axis back next to head_dim BEFORE flattening so each
        # token keeps its own features; reshape (not view) because the
        # transposed tensor is no longer contiguous.
        out = out.transpose(1, 2).reshape(batch_size, num_patches, self.inner_dim)
        out = self.out_proj(out)

        return out

In [8]:
class TransformerLayer(nn.Module):
    """
    Single transformer encoder block: self-attention followed by an MLP.

    Each sub-layer is applied to a LayerNorm-ed copy of its input, passed
    through dropout, and added back onto the input (residual connection).

    Args:
        config: Mapping with the keys ``d_model`` (token feature size) and
            ``d_ff`` (hidden size of the MLP), plus whatever ``Attention``
            reads from it. An optional ``dropout`` key gives the dropout
            probability in [0, 1]; it defaults to 0.0 (no dropout).
    """

    def __init__(self, config):
        super().__init__()

        self.d_model = config['d_model']
        self.d_ff = config['d_ff']

        # nn.Dropout takes a probability, not a feature size.
        dropout_p = config.get('dropout', 0.0)

        self.attention_layer = Attention(config)

        self.mlp = nn.Sequential(
            nn.Linear(self.d_model, self.d_ff),
            nn.GELU(),
            nn.Linear(self.d_ff, self.d_model)
        )

        self.norm1 = nn.LayerNorm(self.d_model)
        self.norm2 = nn.LayerNorm(self.d_model)

        self.dropout1 = nn.Dropout(dropout_p)
        self.dropout2 = nn.Dropout(dropout_p)

    def forward(self, x):
        """
        Run the block on ``x`` and return a tensor of the same shape.

        Args:
            x: Tensor whose last dimension is ``d_model``.
        """
        out = x

        out = out + self.dropout1(self.attention_layer(self.norm1(out)))
        out = out + self.dropout2(self.mlp(self.norm2(out)))

        return out

In [9]:
class VIT(nn.Module):
    """
    Vision Transformer classifier.

    Images are embedded into a token sequence by ``PatchEmbedding``, passed
    through ``n_layers`` ``TransformerLayer`` blocks, layer-normalised, and the
    token at sequence index 0 is fed to a linear classification head.

    Args:
        config: Mapping with the keys ``n_layers``, ``d_model`` and
            ``num_classes``, plus whatever ``PatchEmbedding`` and
            ``TransformerLayer`` read from it.
    """

    def __init__(self, config):
        super().__init__()
        n_layers = config['n_layers']
        # The final norm and the head consume the transformer output, whose
        # feature size is d_model, so they are sized from the same key.
        d_model = config['d_model']
        num_classes = config['num_classes']
        self.patch_embed_layer = PatchEmbedding(config)
        self.layers = nn.ModuleList([
            TransformerLayer(config) for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.fc_number = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """
        Classify a batch of images.

        Args:
            x: Image batch in the layout expected by ``PatchEmbedding``.

        Returns:
            Output of the classification head applied to the token at
            sequence index 0; its last dimension is ``num_classes``.
        """
        out = self.patch_embed_layer(x)

        for layer in self.layers:
            out = layer(out)
        out = self.norm(out)

        # Only the first token of each sequence is used for classification.
        return self.fc_number(out[:, 0])