In [46]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [47]:
from dataclasses import dataclass


@dataclass
class Config:
    """Hyper-parameters for the Vision Transformer defined in this notebook.

    Every field carries a type annotation so that ``@dataclass`` registers it
    as a real field. Without annotations the decorator generates an
    ``__init__`` that takes no arguments, so values could not be overridden
    per instance (``Config(image_size=128)`` raised ``TypeError``) and every
    instance compared equal. ``Config()`` with no arguments still yields the
    same defaults as before.
    """

    # Image embedding
    image_size: int = 256  # side length of the (square) input image, in pixels
    patch_size: int = 16  # side length of one square patch; must divide image_size
    num_channels: int = 3  # channels of the input image
    hidden_size: int = 768  # width of every token embedding

    # Dropout probability passed to every nn.Dropout in the model
    hidden_dropout: float = 0.1

    # Encoder / head
    num_layers: int = 12  # number of stacked encoder blocks
    num_heads: int = 8  # attention heads per block; must divide hidden_size
    num_classes: int = 10  # size of the classification head output

In [48]:
class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A single ``nn.Conv2d`` whose kernel size and stride both equal the patch
    size performs the patch extraction and the linear projection in one step.

    Args:
        config: provides ``image_size``, ``patch_size``, ``num_channels`` and
            ``hidden_size``. ``image_size`` must be a multiple of
            ``patch_size``.
    """

    def __init__(self, config: Config):
        super().__init__()

        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.num_channels = config.num_channels
        self.hidden_size = config.hidden_size

        patches_per_side, leftover = divmod(self.image_size, self.patch_size)
        assert leftover == 0, "Image dimensions must be divisible by the patch size."

        # The image is treated as square, so the patch grid is square too.
        self.num_patches = patches_per_side**2

        self.patch_and_projection = nn.Conv2d(
            in_channels=self.num_channels,
            out_channels=self.hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding=0,
        )

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: tensor of shape ``(B, C, H, W)``.

        Returns:
            Tensor of shape ``(B, (H // patch_size) * (W // patch_size), hidden_size)``.
        """
        # (B, C, H, W) -> (B, hidden_size, H // patch_size, W // patch_size)
        patch_grid = self.patch_and_projection(x)
        # Collapse the two spatial axes into one token axis ...
        token_last = patch_grid.flatten(start_dim=2)
        # ... then move it in front of the embedding axis.
        return token_last.transpose(1, 2)

In [49]:
class PositionEmbedding(nn.Module):
    """Prepend a learnable CLS token, add learnable position embeddings, then dropout.

    Args:
        config: provides ``image_size``, ``patch_size``, ``hidden_size`` and
            ``hidden_dropout``.
    """

    def __init__(self, config: Config):
        super().__init__()

        width = config.hidden_size
        # One position per patch, plus one for the CLS token placed in front.
        sequence_length = (config.image_size // config.patch_size) ** 2 + 1

        self.cls_token = nn.Parameter(torch.randn(1, 1, width))
        self.position_embeddings = nn.Parameter(
            torch.randn(1, sequence_length, width)
        )
        self.dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, x):
        """Turn patch embeddings into the encoder input sequence.

        Args:
            x: patch embeddings; the CLS token is concatenated along dim 1, so
                ``x`` must match it on the other dimensions.

        Returns:
            The sequence with one extra leading token, position embeddings
            added, and dropout applied.
        """
        batch_size = x.size(0)
        # expand() broadcasts the single CLS token across the batch without copying.
        cls_per_sample = self.cls_token.expand(batch_size, -1, -1)
        sequence = torch.cat([cls_per_sample, x], dim=1)
        return self.dropout(sequence + self.position_embeddings)

In [50]:
class ImageEmbedding(nn.Module):
    """Full input pipeline: patch embedding followed by CLS/position embedding.

    Args:
        config: forwarded unchanged to ``PatchEmbedding`` and
            ``PositionEmbedding``.
    """

    def __init__(self, config: Config):
        super().__init__()

        self.patch_embedding = PatchEmbedding(config)
        self.position_embedding = PositionEmbedding(config)

    def forward(self, x):
        """Map a batch of images to the token sequence fed to the encoder."""
        return self.position_embedding(self.patch_embedding(x))

In [51]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values come from one fused linear layer. The same
    dropout module is applied to the attention weights and to the projected
    output.

    Args:
        config: provides ``hidden_size``, ``num_heads`` and ``hidden_dropout``.
            ``hidden_size`` must be a multiple of ``num_heads``.
    """

    def __init__(self, config: Config):
        super().__init__()

        self.num_heads = config.num_heads
        assert (
            config.hidden_size % self.num_heads == 0
        ), f"Hidden size ({config.hidden_size}) must be divisible by the number of heads ({self.num_heads})"

        self.attention_head_size = config.hidden_size // config.num_heads

        # Single projection producing q, k and v side by side on the last axis.
        self.qkv_linear = nn.Linear(config.hidden_size, 3 * config.hidden_size)

        self.out_linear = nn.Linear(config.hidden_size, config.hidden_size)

        self.dropout = nn.Dropout(config.hidden_dropout)

    def _split_heads(self, t):
        """Reshape ``(B, N, heads * head_size)`` to ``(B, heads, N, head_size)``."""
        per_head = t.view(
            t.shape[0], t.shape[1], self.num_heads, self.attention_head_size
        )
        return per_head.transpose(1, 2)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: tensor of shape ``(B, N, hidden_size)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch_size, num_tokens = x.shape[0], x.shape[1]

        fused = self.qkv_linear(x)
        q, k, v = (self._split_heads(part) for part in fused.chunk(3, dim=-1))

        # Scale the logits by sqrt(head_size) before the softmax.
        logits = (q @ k.transpose(-2, -1)) / (self.attention_head_size**0.5)
        weights = F.softmax(logits, dim=-1)

        context = self.dropout(weights) @ v

        # (B, heads, N, head_size) -> (B, N, heads * head_size); the transpose
        # makes the tensor non-contiguous, so copy before view().
        merged = (
            context.transpose(1, 2)
            .contiguous()
            .view(batch_size, num_tokens, self.num_heads * self.attention_head_size)
        )

        return self.dropout(self.out_linear(merged))

In [52]:
class GELU(nn.Module):
    """GELU activation computed with the tanh approximation.

    Evaluates ``0.5 * x * (1 + tanh(c * (x + 0.044715 * x**3)))`` element-wise,
    where ``c = 0.7978845608`` (approximately ``sqrt(2 / pi)``).
    """

    def forward(self, x):
        """Return the activation of ``x``; the output has the same shape."""
        cubic_term = 0.044715 * x**3
        tanh_arg = 0.7978845608 * (x + cubic_term)
        return 0.5 * x * (1 + torch.tanh(tanh_arg))

In [53]:
class MLP(nn.Module):
    """Position-wise feed-forward network: expand 4x, GELU, project back, dropout.

    Note that ``ln1`` / ``ln2`` are linear layers, not layer norms.

    Args:
        config: provides ``hidden_size`` and ``hidden_dropout``.
    """

    def __init__(self, config: Config):
        super().__init__()

        width = config.hidden_size
        expanded_width = config.hidden_size * 4

        self.ln1 = nn.Linear(width, expanded_width)
        self.act = GELU()
        self.ln2 = nn.Linear(expanded_width, width)
        self.dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, x):
        """Apply the feed-forward network to the last dimension of ``x``."""
        expanded = self.ln1(x)
        activated = self.act(expanded)
        projected = self.ln2(activated)
        return self.dropout(projected)

In [54]:
# Layer Normalization
class LayerNormalization(nn.Module):
    """Layer normalization over the last dimension with learnable scale and shift.

    Computes ``alpha * (x - mean) / sqrt(var + eps) + bias`` where ``mean`` and
    ``var`` are taken over the last dimension and ``var`` is the biased
    (population) variance. This matches ``torch.nn.LayerNorm`` numerically.

    The previous version divided by ``x.std() + eps``. ``Tensor.std`` defaults
    to the unbiased (Bessel-corrected) estimator, so outputs were scaled
    differently from standard layer norm, and the result was NaN whenever the
    normalized dimension had a single element.

    Args:
        features: size of the last dimension of the input.
        eps: small constant added to the variance for numerical stability.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))

    def forward(self, x):
        """Normalize ``x`` along its last dimension.

        Args:
            x: tensor whose last dimension has size ``features``,
                e.g. ``(batch_size, seq_len, features)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        mean = x.mean(dim=-1, keepdim=True)
        centered = x - mean
        # Biased variance: mean of squared deviations (divides by N, not N - 1).
        variance = centered.pow(2).mean(dim=-1, keepdim=True)

        # eps goes inside the square root so a zero variance stays finite.
        return self.alpha * centered / torch.sqrt(variance + self.eps) + self.bias

In [55]:
class LayerNorm(nn.Module):
    """Hand-written layer normalization over the last dimension.

    Applies ``alpha * (x - mean) / sqrt(var + eps) + bias`` using the biased
    variance of the last dimension, which is the same arithmetic as
    ``torch.nn.LayerNorm``.

    The earlier implementation used ``x.std()`` (unbiased by default) and added
    ``eps`` to the standard deviation instead of the variance. That rescaled
    the output relative to standard layer norm and produced NaN for inputs
    whose last dimension has size 1.

    Args:
        features: number of elements in the normalized (last) dimension.
        eps: stabilizer added to the variance before taking the square root.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()

        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))

    def forward(self, x):
        """Return ``x`` normalized along its last dimension, same shape as ``x``."""
        mean = x.mean(dim=-1, keepdim=True)
        deviation = x - mean
        # Population variance (divide by N); well-defined even when N == 1.
        variance = (deviation * deviation).mean(dim=-1, keepdim=True)
        inv_std = 1.0 / torch.sqrt(variance + self.eps)

        return self.alpha * deviation * inv_std + self.bias

In [56]:
class EncoderBlock(nn.Module):
    """Pre-norm transformer encoder block.

    Computes ``x = x + Attention(LN_1(x))`` followed by
    ``x = x + MLP(LN_2(x))``.

    Two things differ from the earlier version of this block:

    * Each sub-layer has its own ``nn.LayerNorm``. Previously one instance was
      reused for both, which tied the scale/shift parameters of the
      pre-attention and pre-MLP normalizations together.
    * No extra dropout is applied on the residual branches.
      ``MultiHeadAttention`` and ``MLP`` already apply dropout to their
      outputs, so the additional ``nn.Dropout`` here dropped activations a
      second time.

    Args:
        config: provides ``hidden_size``; also forwarded to
            ``MultiHeadAttention`` and ``MLP``.
    """

    def __init__(self, config: Config):
        super().__init__()

        self.attention_norm = nn.LayerNorm(config.hidden_size)
        self.attention = MultiHeadAttention(config)
        self.mlp_norm = nn.LayerNorm(config.hidden_size)
        self.mlp = MLP(config)

    def forward(self, x):
        """Apply the attention and MLP sub-layers, each with a residual connection.

        Args:
            x: token sequence whose last dimension is ``hidden_size``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = x + self.attention(self.attention_norm(x))
        x = x + self.mlp(self.mlp_norm(x))
        return x

In [57]:
class Encoder(nn.Module):
    """Stack of ``num_layers`` encoder blocks followed by a final LayerNorm.

    The blocks normalize their *inputs* (pre-norm), so the residual stream
    leaving the last block has not been normalized. The final ``nn.LayerNorm``
    normalizes it before it reaches the classification head, as the ViT
    reference architecture does. The earlier version returned the
    un-normalized stream.

    Args:
        config: provides ``num_layers`` and ``hidden_size``; also forwarded to
            every ``EncoderBlock``.
    """

    def __init__(self, config: Config):
        super().__init__()

        self.blocks = nn.ModuleList(
            [EncoderBlock(config) for _ in range(config.num_layers)]
        )
        # Final normalization of the residual stream.
        self.norm = nn.LayerNorm(config.hidden_size)

    def forward(self, x):
        """Run ``x`` through every block in order, then normalize the result.

        Args:
            x: token sequence whose last dimension is ``hidden_size``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        for block in self.blocks:
            x = block(x)
        return self.norm(x)

In [58]:
class MLPProjector(nn.Module):
    """Classification head: one linear layer applied to the first token.

    Note that ``ln`` is a linear layer, not a layer norm.

    Args:
        config: provides ``hidden_size`` (input width) and ``num_classes``
            (output width).
    """

    def __init__(self, config: Config):
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_classes = config.num_classes

        self.ln = nn.Linear(
            in_features=self.hidden_size, out_features=self.num_classes
        )

    def forward(self, x):
        """Project the token at index 0 of dim 1 (the CLS token) to class logits."""
        cls_representation = x[:, 0]
        return self.ln(cls_representation)

In [59]:
class ViT(nn.Module):
    """Vision Transformer classifier: image embedding -> encoder -> linear head.

    Args:
        config: forwarded unchanged to ``ImageEmbedding``, ``Encoder`` and
            ``MLPProjector``.
    """

    def __init__(self, config: Config):
        super().__init__()

        self.image_embedding = ImageEmbedding(config)
        self.encoder = Encoder(config)
        self.mlp_head = MLPProjector(config)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: image batch of shape ``(B, C, H, W)``.

        Returns:
            Logits of shape ``(B, num_classes)``.
        """
        tokens = self.image_embedding(x)
        encoded = self.encoder(tokens)
        logits = self.mlp_head(encoded)
        return logits

In [60]:
def init_weight(module):
    """Initialise one module in place; intended for ``nn.Module.apply``.

    ``nn.Linear`` and ``nn.Conv2d`` modules get Xavier-uniform weights and, if
    they have a bias, a zero bias. Every other module type is left untouched.

    Args:
        module: the module to (possibly) re-initialise.
    """
    if not isinstance(module, (nn.Linear, nn.Conv2d)):
        return

    nn.init.xavier_uniform_(module.weight)

    bias = module.bias
    if bias is not None:
        nn.init.constant_(bias, 0)

In [61]:
# Smoke test: build the model with the default config, initialise it, and check
# that a forward pass on random data yields one logit vector per image.
config = Config()
model = ViT(config)
model.apply(init_weight)

# The dummy batch takes its channel count and image size from the config, so
# the test stays valid if the Config defaults change.
x = torch.randn(8, config.num_channels, config.image_size, config.image_size)
output = model(x)

assert output.shape == (
    x.shape[0],
    config.num_classes,
), f"Output shape is not as expected: {output.shape}"

print("Test passed.")

Test passed.
