In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import math

class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping square patches and embed each patch.

    A ``Conv2d`` whose kernel size and stride both equal ``patch_size`` applies
    one shared linear projection to every patch.

    Args:
        in_channels: number of channels of the input image.
        patch_size: side length, in pixels, of each square patch.
        emb_size: width of each patch embedding.
        img_size: accepted for signature compatibility; this module does not use it.
    """

    def __init__(self, in_channels=3, patch_size=16, emb_size=768, img_size=100):
        super().__init__()
        self.patch_size = patch_size
        self.projection = nn.Conv2d(
            in_channels,
            emb_size,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Turn a ``(B, C, H, W)`` batch into ``(B, num_patches, emb_size)`` tokens."""
        feature_map = self.projection(x)                 # (B, emb_size, H // P, W // P)
        patch_tokens = feature_map.flatten(start_dim=2)  # (B, emb_size, num_patches)
        return patch_tokens.transpose(1, 2)              # (B, num_patches, emb_size)

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values. Each projection is
    split into ``num_heads`` heads of width ``head_dim = emb_size // num_heads``.
    Every head then attends over the token (sequence) axis independently. The
    head outputs are concatenated back to ``emb_size`` and mixed by ``fc_out``.

    Args:
        emb_size: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
    """

    def __init__(self, emb_size=768, num_heads=8):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.head_dim = emb_size // num_heads
        assert self.head_dim * num_heads == emb_size, "emb_size must be divisible by num_heads"

        # Registration order is unchanged so seeded initialisation and
        # state_dict key order stay the same.
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.fc_out = nn.Linear(emb_size, emb_size)

    def _split_heads(self, t):
        """Reshape ``(B, N, emb_size)`` into ``(B, num_heads, N, head_dim)``."""
        B, N, _ = t.shape
        return t.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``(B, N, emb_size)``.

        Returns:
            Tensor of shape ``(B, N, emb_size)``.
        """
        B, N, _ = x.shape
        # The head axis must come before the token axis; otherwise the
        # einsum below would attend across heads instead of across tokens.
        queries = self._split_heads(self.queries(x))
        keys = self._split_heads(self.keys(x))
        values = self._split_heads(self.values(x))

        # Per-head scores between every pair of tokens: (B, num_heads, N, N).
        energy = torch.einsum("bhqd,bhkd->bhqk", queries, keys) / math.sqrt(self.head_dim)
        attention = torch.softmax(energy, dim=-1)

        out = torch.einsum("bhqk,bhkd->bhqd", attention, values)  # (B, num_heads, N, head_dim)
        # Put heads back next to head_dim before merging them into emb_size.
        out = out.transpose(1, 2).reshape(B, N, self.emb_size)
        return self.fc_out(out)

class TransformerBlock(nn.Module):
    """Post-norm Transformer encoder layer.

    The layer has two sub-layers: multi-head self-attention, then a two-layer
    ReLU MLP. Each is wrapped as ``LayerNorm(x + sublayer(x))``.

    Args:
        emb_size: token width.
        num_heads: attention heads.
        forward_expansion: hidden width of the MLP as a multiple of ``emb_size``.
    """

    def __init__(self, emb_size=768, num_heads=8, forward_expansion=4):
        super().__init__()
        hidden_size = forward_expansion * emb_size
        # Submodules are created in the same order as before, so seeded
        # parameter initialisation is unchanged.
        self.attention = MultiHeadAttention(emb_size, num_heads)
        self.norm1 = nn.LayerNorm(emb_size)
        self.norm2 = nn.LayerNorm(emb_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(emb_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, emb_size),
        )

    def forward(self, x):
        """Run both residual sub-layers; the output has the same shape as ``x``."""
        attended = self.norm1(x + self.attention(x))
        return self.norm2(attended + self.feed_forward(attended))

class VisionTransformer(nn.Module):
    """ViT-style image classifier.

    The forward pass runs these steps in order:

    1. ``PatchEmbedding`` splits the image into patch tokens.
    2. A learnable class token is prepended.
    3. Learnable position embeddings are added.
    4. The sequence passes through ``depth`` ``TransformerBlock`` layers.
    5. ``mlp_head`` classifies the final class-token vector.

    Args:
        img_size: side length of the square input images. It is used only to
            derive ``num_patches`` when that argument is omitted.
        patch_size: side length of each square patch.
        num_patches: number of patch tokens per image. ``None`` (the default)
            derives ``(img_size // patch_size) ** 2``, which equals the number
            of positions produced by the stride-``patch_size`` convolution.
            For the default 100/16 setting this is 36. An explicit value is
            used as given.
        emb_size: token embedding width.
        depth: number of transformer blocks.
        num_heads: attention heads per block.
        num_classes: number of output logits.
    """

    def __init__(self, img_size=100, patch_size=16, num_patches=None, emb_size=768,
                 depth=12, num_heads=8, num_classes=9):
        super().__init__()
        if num_patches is None:
            # Conv2d with kernel == stride == patch_size yields
            # img_size // patch_size positions along each side.
            num_patches = (img_size // patch_size) ** 2
        self.num_patches = num_patches

        # Creation order is unchanged so seeded initialisation stays reproducible.
        self.patch_embedding = PatchEmbedding(patch_size=patch_size, emb_size=emb_size, img_size=img_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        self.position_embeddings = nn.Parameter(torch.randn(1, 1 + num_patches, emb_size))

        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(emb_size, num_heads)
            for _ in range(depth)
        ])

        self.to_cls_token = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, num_classes)
        )

    def forward(self, x):
        """Return class logits of shape ``(B, num_classes)`` for the image batch ``x``."""
        x = self.patch_embedding(x)  # (B, num_patches, emb_size)
        # expand() broadcasts the single class token across the batch without copying.
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        # Out-of-place add; position embeddings must match 1 + num_patches tokens.
        x = x + self.position_embeddings

        for block in self.transformer_blocks:
            x = block(x)

        cls_token_final = self.to_cls_token(x[:, 0])
        return self.mlp_head(cls_token_final)

# Build a ViT with every hyper-parameter at its default and display the module
# tree (nn.Module defines __repr__, so this matches print(model)).
model = VisionTransformer()
print(repr(model))


VisionTransformer(
  (patch_embedding): PatchEmbedding(
    (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (transformer_blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (attention): MultiHeadAttention(
        (keys): Linear(in_features=768, out_features=768, bias=True)
        (queries): Linear(in_features=768, out_features=768, bias=True)
        (values): Linear(in_features=768, out_features=768, bias=True)
        (fc_out): Linear(in_features=768, out_features=768, bias=True)
      )
      (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (feed_forward): Sequential(
        (0): Linear(in_features=768, out_features=3072, bias=True)
        (1): ReLU()
        (2): Linear(in_features=3072, out_features=768, bias=True)
      )
    )
  )
  (to_cls_token): Identity()
  (mlp_head): Sequential(
    (0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (