In [1]:
!pip install torch


Collecting torch
  Downloading torch-2.2.0-cp311-none-macosx_11_0_arm64.whl (59.4 MB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m24.4 MB/s[0m eta [36m0:00:00[0mm eta [36m0:00:01[0m0:01[0m:01[0m
Collecting typing-extensions>=4.8.0
  Using cached typing_extensions-4.9.0-py3-none-any.whl (32 kB)
Collecting sympy
  Using cached sympy-1.12-py3-none-any.whl (5.7 MB)
Collecting networkx
  Downloading networkx-3.2.1-py3-none-any.whl (1.6 MB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m32.4 MB/s[0m eta [36m0:00:00[0m MB/s[0m eta [36m0:00:01[0m
Collecting fsspec
  Downloading fsspec-2024.2.0-py3-none-any.whl (170 kB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m170.9/170.9 kB[0m [31m12.7 MB/s[0m eta [36m0:00:00[0m
Collecting mpmath>=0.19
  Using cached mpmath-1.3.0-py3-none-any.whl (536 kB)
Installing collected packages: mpmath, typing-exten

In [4]:
import math

import torch
from torch import nn

We are going to implement the Vision Transformer (ViT) architecture from scratch.

We need the following modules for ViT
(reference: https://towardsdatascience.com/implementing-vision-transformer-vit-from-scratch-3e192c6155f0):
1. PatchEmbeddings
2. Embeddings
3. Attention
4. MHA
5. MLP
6. Block
7. Encoder

In [5]:
class PatchEmbeddings(nn.Module):
    """
    Cuts an image into non-overlapping square patches and projects each
    patch to a vector of size ``config.hidden_size``.

    The cut and the projection are done in one step by a strided
    convolution whose kernel size and stride both equal the patch size.

    Reads from ``config``: ``img_size``, ``patch_size``,
    ``input_channels`` and ``hidden_size``.
    """
    def __init__(self, config):
        super().__init__()
        self.img_size = config.img_size
        self.patch_size = config.patch_size
        self.input_channels = config.input_channels
        self.output_channels = config.hidden_size

        # Patches per image side, squared to get the total patch count.
        patches_per_side = self.img_size // self.patch_size
        self.num_patches = patches_per_side ** 2

        self.proj = nn.Conv2d(
            self.input_channels,
            self.output_channels,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

    def forward(self, x):
        # (BS, input_channels, img_h, img_w) -> (BS, output_channels, h', w')
        patch_grid = self.proj(x)
        # Collapse the spatial grid into one axis: (BS, output_channels, num_patches)
        patch_seq = patch_grid.flatten(start_dim=2)
        # Put the patch axis before the channel axis: (BS, num_patches, output_channels)
        return patch_seq.transpose(1, 2)


        

In [None]:
class Embeddings(nn.Module):
    """
    Builds the encoder input sequence from a batch of images.

    The images are turned into patch embeddings, a learnable [CLS] token is
    prepended to every sequence, and learnable position embeddings are added.

    Reads from ``config``: ``hidden_size`` (plus whatever
    ``PatchEmbeddings`` reads).
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.patch_embeddings = PatchEmbeddings(config)
        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        # One position vector per patch, plus one for the [CLS] token.
        self.pos_embeddings = nn.Parameter(
            torch.randn(1, self.patch_embeddings.num_patches + 1, config.hidden_size)
        )

    def forward(self, x):
        # -> (bs, num_patches, hidden_size)
        x = self.patch_embeddings(x)
        bs = x.size(0)
        # Broadcast the single learned token across the batch (no copy).
        cls_token = self.cls_token.expand(bs, -1, -1)
        # -> (bs, num_patches + 1, hidden_size)
        x = torch.cat((cls_token, x), dim=1)
        x = x + self.pos_embeddings
        return x

In [None]:
class Attention(nn.Module):
    """
    A single scaled dot-product self-attention head.

    Args:
        hidden_size: size of the last dimension of the input.
        attn_head_size: size of the query/key/value projections and of the
            last dimension of the output.
        bias: whether the query/key/value projections have a bias term.
    """
    def __init__(self, hidden_size, attn_head_size, bias=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.attn_head_size = attn_head_size
        self.query = nn.Linear(hidden_size, attn_head_size, bias=bias)
        self.key = nn.Linear(hidden_size, attn_head_size, bias=bias)
        self.value = nn.Linear(hidden_size, attn_head_size, bias=bias)

    def forward(self, x):
        # (bs, seq_len, hidden_size) -> (bs, seq_len, attn_head_size)
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)
        # Pairwise query/key similarity, scaled by sqrt(head size):
        # (bs, seq_len, seq_len)
        attn_scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.attn_head_size)
        # Normalise over the key axis so each query's weights sum to 1.
        attn_probs = nn.functional.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_probs, value)
        return attn_output
        

In [None]:
class MHA(nn.Module):
    """
    Multi-head self-attention built from independent ``Attention`` heads.

    Every head sees the full input; the head outputs are concatenated along
    the last dimension and passed through a final linear projection.

    Reads from ``config``: ``hidden_size``, ``num_attn_heads`` and
    ``qkv_bias``.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_attn_heads``
            (the concatenated heads would not match ``out_proj``'s input size).
    """
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_attn_heads = config.num_attn_heads
        if self.hidden_size % self.num_attn_heads != 0:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must be divisible by "
                f"num_attn_heads ({self.num_attn_heads})"
            )
        self.attn_head_size = self.hidden_size // self.num_attn_heads
        self.qkv_bias = config.qkv_bias

        self.heads = nn.ModuleList([
            Attention(self.hidden_size, self.attn_head_size, self.qkv_bias)
            for _ in range(self.num_attn_heads)
        ])

        self.out_proj = nn.Linear(self.hidden_size, self.hidden_size)

    def forward(self, x):
        # Each head returns one tensor of shape (bs, seq_len, attn_head_size).
        attn_outputs = [head(x) for head in self.heads]
        # Concatenate heads on the feature axis -> (bs, seq_len, hidden_size).
        attn_output = torch.cat(attn_outputs, dim=-1)
        attn_output = self.out_proj(attn_output)
        return attn_output
        

In [None]:
class MLP(nn.Module):
    """
    Position-wise feed-forward network: Linear -> GELU -> Linear.

    Expands the last dimension from ``config.hidden_size`` to
    ``config.intermediate_size`` and projects it back to
    ``config.hidden_size``.
    """
    def __init__(self, config):
        super().__init__()
        self.dense1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.act = nn.GELU()
        self.dense2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, x):
        x = self.dense1(x)
        x = self.act(x)
        x = self.dense2(x)
        return x
        

In [None]:
class Block(nn.Module):
    """
    One transformer encoder block.

    Applies LayerNorm followed by multi-head attention, then LayerNorm
    followed by the MLP; each sub-layer's output is added back to its input
    (residual connection).

    Reads from ``config``: ``hidden_size`` (plus whatever ``MHA`` and ``MLP``
    read).
    """
    def __init__(self, config):
        super().__init__()
        self.attn = MHA(config)
        self.ln1 = nn.LayerNorm(config.hidden_size)
        self.mlp = MLP(config)
        self.ln2 = nn.LayerNorm(config.hidden_size)

    def forward(self, x):
        # Attention sub-layer with residual connection.
        ln1_op = self.ln1(x)
        attn_out = self.attn(ln1_op)
        x = x + attn_out

        # Feed-forward sub-layer with residual connection.
        ln2_op = self.ln2(x)
        mlp_out = self.mlp(ln2_op)
        x = x + mlp_out
        return x

In [None]:
class Encoder(nn.Module):
    """
    A stack of ``config.num_blocks`` transformer ``Block`` modules applied
    one after another.
    """
    def __init__(self, config):
        super().__init__()
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_blocks)])

    def forward(self, x):
        # Feed the output of each block into the next.
        for block in self.blocks:
            x = block(x)
        return x