In [None]:

import math
from dataclasses import dataclass

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from tqdm import tqdm


In [None]:

@dataclass
class SiglipVisionConfig:
    """Hyperparameters shared by every module of the Vision Transformer in this notebook.

    Field order, names and defaults define the generated constructor, so new
    fields should only be appended.
    """

    # Number of SiglipEncoderLayer blocks stacked by SiglipEncoder.
    num_hidden_layers: int = 6
    # Input channels expected by the patch-embedding convolution.
    num_channels: int = 3
    # Side length of the input image; together with patch_size it fixes the
    # number of patches as (image_size // patch_size) ** 2.
    image_size: int = 32
    # Side length of one patch; used as both kernel size and stride of the
    # patch-embedding convolution.
    patch_size: int = 4
    # Number of attention heads; the per-head width is
    # hidden_size // num_attention_heads.
    num_attention_heads: int = 8
    # Width of every token embedding throughout the model.
    hidden_size: int = 384
    # Width of the hidden layer inside SiglipMLP.
    intermediate_size: int = 1536
    # Output size of the final classifier layer.
    num_classes: int = 10
    # Epsilon passed to every nn.LayerNorm.
    layer_norm_eps: float = 1e-6
    # Dropout probability applied to the attention weights.
    attention_dropout: float = 0.1
    # Dropout probability used after the embeddings and inside the MLP.
    dropout: float = 0.1


def get_cifar10_transforms():
    """Build the preprocessing pipelines for the training and test splits.

    Returns:
        A ``(train_transform, test_transform)`` tuple of ``transforms.Compose``
        objects. The training pipeline applies a random horizontal flip and a
        padded random 32x32 crop before tensor conversion and per-channel
        normalization; the test pipeline only converts and normalizes.
    """
    channel_mean = [0.4914, 0.4822, 0.4465]
    channel_std = [0.247, 0.243, 0.261]

    def tensor_and_normalize():
        # Called once per pipeline so each Compose owns its own transform objects.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=channel_mean, std=channel_std),
        ]

    # Augmentation is applied to the training split only.
    augmentation = [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomCrop(size=32, padding=4),
    ]

    return (
        transforms.Compose(augmentation + tensor_and_normalize()),
        transforms.Compose(tensor_and_normalize()),
    )


In [None]:

class SiglipVisionEmbeddings(nn.Module):
    """Convert an image batch into a token sequence with a leading class token.

    The image is cut into non-overlapping square patches by a strided
    convolution, a learned class token is prepended, a learned position
    embedding is added to every token and dropout is applied to the result.
    """

    def __init__(self, config: SiglipVisionConfig):
        """Create the patch projection, class token and position table from ``config``."""
        super().__init__()
        self.config = config

        self.num_channels = config.num_channels
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Kernel size == stride, so every patch is projected exactly once.
        self.patch_embedding = nn.Conv2d(
            self.num_channels,
            self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding=0,
        )

        grid_side = self.image_size // self.patch_size
        self.num_patches = grid_side ** 2
        # One extra position for the class token.
        self.num_positions = 1 + self.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)

        # Registered as a non-persistent buffer: it follows the module across
        # devices but is left out of the state dict.
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions)[None, :],
            persistent=False,
        )

        self.dropout = nn.Dropout(config.dropout)

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Embed ``pixel_values`` of shape (batch, channels, height, width).

        Returns:
            Tensor of shape (batch, num_positions, embed_dim).
        """
        batch_size = pixel_values.shape[0]

        # (B, D, H', W') -> (B, D, H'*W') -> (B, H'*W', D)
        patch_tokens = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)

        class_tokens = self.cls_token.expand(batch_size, 1, self.embed_dim)
        tokens = torch.cat((class_tokens, patch_tokens), dim=1)

        tokens = tokens + self.position_embedding(self.position_ids)
        return self.dropout(tokens)


In [None]:

class SiglipMLP(nn.Module):
    """Token-wise feed-forward block: Linear -> GELU -> dropout -> Linear."""

    def __init__(self, config: SiglipVisionConfig):
        """Create the two projections and the dropout layer from ``config``."""
        super().__init__()
        self.config = config

        model_width = config.hidden_size
        inner_width = config.intermediate_size

        self.fc1 = nn.Linear(model_width, inner_width)
        self.fc2 = nn.Linear(inner_width, model_width)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward block; the last dimension is preserved."""
        activated = F.gelu(self.fc1(hidden_states))
        # The dropout module is bypassed entirely when its probability is zero.
        activated = self.dropout(activated) if self.dropout.p > 0.0 else activated
        return self.fc2(activated)


In [None]:

class SiglipAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Queries, keys and values are separate linear projections of the input.
    The embedding dimension is split evenly across the heads, attention is
    computed per head, and the concatenated heads go through a final output
    projection.
    """

    def __init__(self, config: SiglipVisionConfig):
        """Create the projections from ``config``.

        Raises:
            ValueError: if ``config.hidden_size`` is not a multiple of
                ``config.num_attention_heads``. Without this check the
                mismatch only surfaces later as an opaque reshape error
                inside ``forward``.
        """
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads

        if self.embed_dim % self.num_heads != 0:
            raise ValueError(
                f"hidden_size ({self.embed_dim}) must be divisible by "
                f"num_attention_heads ({self.num_heads})"
            )

        self.head_dim = self.embed_dim // self.num_heads
        # Scaling factor 1 / sqrt(head_dim); constant, so computed once here.
        self.scale = 1.0 / math.sqrt(self.head_dim)
        # Dropout probability (a plain float) applied to the attention weights.
        self.dropout = config.attention_dropout

        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

    def forward(self, hidden_states):
        """Apply self-attention.

        Args:
            hidden_states: tensor of shape (batch, tokens, embed_dim).

        Returns:
            Tensor of the same shape as ``hidden_states``.
        """
        B, T, C = hidden_states.size()

        def split_heads(projected):
            # (B, T, C) -> (B, num_heads, T, head_dim)
            return projected.reshape(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        q_states = split_heads(self.q_proj(hidden_states))
        k_states = split_heads(self.k_proj(hidden_states))
        v_states = split_heads(self.v_proj(hidden_states))

        # (B, heads, T, T) similarity scores, normalised over the key axis.
        attn_scores = torch.matmul(q_states, k_states.transpose(-2, -1)) * self.scale
        attn_weights = F.softmax(attn_scores, dim=-1)

        # Functional dropout is used because self.dropout is a float, not a
        # module; it is only active in training mode.
        if self.dropout > 0.0:
            attn_weights = F.dropout(
                attn_weights,
                p=self.dropout,
                training=self.training
            )

        attn_output = torch.matmul(attn_weights, v_states)

        # Merge heads back: (B, heads, T, head_dim) -> (B, T, C). The permute
        # yields a non-contiguous tensor, so make it contiguous before view().
        attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
        attn_output = attn_output.view(B, T, C)

        return self.out_proj(attn_output)


In [None]:

class SiglipEncoderLayer(nn.Module):
    """One pre-norm Transformer block.

    Each sub-layer (self-attention, then MLP) receives a layer-normalised copy
    of the running hidden states, and its output is added back to the
    un-normalised hidden states as a residual.
    """

    def __init__(self, config: SiglipVisionConfig):
        """Create the attention, MLP and the two layer norms from ``config``."""
        super().__init__()
        self.embed_dim = config.hidden_size

        self.self_attn = SiglipAttention(config)
        self.mlp = SiglipMLP(config)

        norm_options = {"eps": config.layer_norm_eps}
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, **norm_options)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, **norm_options)

    def forward(self, hidden_states):
        """Run attention and MLP sub-layers with residual connections."""
        # Attention sub-layer: norm -> attention -> residual add.
        after_attention = hidden_states + self.self_attn(self.layer_norm1(hidden_states))
        # MLP sub-layer: norm -> MLP -> residual add.
        return after_attention + self.mlp(self.layer_norm2(after_attention))


In [None]:

class SiglipEncoder(nn.Module):
    """A stack of ``config.num_hidden_layers`` encoder layers applied in order."""

    def __init__(self, config: SiglipVisionConfig):
        """Instantiate one ``SiglipEncoderLayer`` per configured hidden layer."""
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        )

    def forward(self, hidden_states):
        """Feed ``hidden_states`` through every layer, first to last."""
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states


In [None]:

class CIFAR10VisionTransformer(nn.Module):
    """Image classifier: embeddings -> encoder -> layer norm -> linear head.

    The prediction is read from the first token of the sequence, which is the
    class token that ``SiglipVisionEmbeddings`` prepends to the patch tokens.
    """

    def __init__(self, config: SiglipVisionConfig):
        """Assemble the backbone and the classification head from ``config``."""
        super().__init__()
        self.config = config

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(config)

        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.classifier = nn.Linear(config.hidden_size, config.num_classes)

    def forward(self, pixel_values):
        """Return class logits of shape (batch, num_classes) for an image batch."""
        hidden_states = self.encoder(self.embeddings(pixel_values))
        # Normalise the whole sequence, then keep only the token at index 0.
        class_representation = self.layer_norm(hidden_states)[:, 0, :]
        return self.classifier(class_representation)


In [None]:

def _train_one_epoch(model, loader, criterion, optimizer, device, desc):
    """Run one optimisation pass over ``loader``; return training accuracy in percent."""
    model.train()

    correct = 0
    total = 0

    pbar = tqdm(loader, desc=desc)
    for inputs, targets in pbar:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()

        # Clip the global gradient norm before the parameter update.
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            max_norm=1.0
        )
        optimizer.step()

        _, predicted = torch.max(outputs, dim=1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

        # Running accuracy over the batches seen so far in this epoch.
        pbar.set_postfix(
            Loss=f'{loss.item():.3f}',
            Acc=f'{100.0 * correct / total:.1f}%'
        )

    # max(..., 1) keeps an empty loader from dividing by zero.
    return 100.0 * correct / max(total, 1)


def _evaluate_accuracy(model, loader, device):
    """Return the accuracy in percent of ``model`` on ``loader`` without tracking gradients."""
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            _, predicted = torch.max(model(inputs), dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    return 100.0 * correct / max(total, 1)


def main(
    num_epochs: int = 50,
    batch_size: int = 128,
    data_root: str = './data',
    checkpoint_path: str = 'best_cifar10_vit.pth',
    device=None,
):
    """Train the Vision Transformer on CIFAR-10 and save the best checkpoint.

    Calling ``main()`` with no arguments reproduces the original run
    (50 epochs, batch size 128, AdamW with cosine annealing), except that the
    device is now picked automatically instead of always being CUDA.

    Args:
        num_epochs: number of training epochs; also used as the cosine
            schedule length.
        batch_size: batch size for both data loaders.
        data_root: directory where CIFAR-10 is downloaded and cached.
        checkpoint_path: file that receives the state dict of the model with
            the highest test accuracy seen so far.
        device: device (or device string) to train on. ``None`` selects CUDA
            when available and falls back to CPU otherwise.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(device)

    # Unpack directly so the imported ``transforms`` module is not shadowed.
    train_transform, test_transform = get_cifar10_transforms()

    trainset = torchvision.datasets.CIFAR10(
        root=data_root,
        train=True,
        download=True,
        transform=train_transform
    )
    testset = torchvision.datasets.CIFAR10(
        root=data_root,
        train=False,
        download=True,
        transform=test_transform
    )

    # Pinned host memory only helps host-to-GPU copies, so enable it for CUDA only.
    loader_options = dict(
        batch_size=batch_size,
        num_workers=4,
        pin_memory=(device.type == 'cuda'),
    )
    trainloader = DataLoader(trainset, shuffle=True, **loader_options)
    testloader = DataLoader(testset, shuffle=False, **loader_options)

    config = SiglipVisionConfig()
    model = CIFAR10VisionTransformer(config).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(
        model.parameters(),
        lr=3e-4,
        weight_decay=0.05
    )
    # One scheduler step per epoch, so the cosine period equals the epoch count.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=num_epochs
    )

    best_acc = 0.0

    for epoch in range(num_epochs):
        train_acc = _train_one_epoch(
            model,
            trainloader,
            criterion,
            optimizer,
            device,
            desc=f'Epoch {epoch + 1}/{num_epochs}',
        )
        scheduler.step()

        test_acc = _evaluate_accuracy(model, testloader, device)

        print(
            f'Epoch {epoch + 1}: '
            f'Train Acc: {train_acc:.2f}%, '
            f'Test Acc: {test_acc:.2f}%'
        )

        # NOTE(review): the checkpoint is selected on the test split, so the
        # reported best accuracy is optimistically biased; a held-out
        # validation split would give an unbiased estimate.
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), checkpoint_path)

    print(f'Best Test Accuracy: {best_acc:.2f}%')


# Entry point: start the full training run only when this code is executed as
# the main program, not when it is imported from another module.
if __name__ == "__main__":
    main()


100%|██████████| 170M/170M [00:03<00:00, 46.6MB/s]
Epoch 1/50: 100%|██████████| 391/391 [01:13<00:00,  5.31it/s, Loss=1.693, Acc=34.8%]


Epoch 1: Train Acc: 34.82%, Test Acc: 46.01%


Epoch 2/50: 100%|██████████| 391/391 [01:16<00:00,  5.11it/s, Loss=1.284, Acc=45.5%]


Epoch 2: Train Acc: 45.51%, Test Acc: 53.47%


Epoch 3/50: 100%|██████████| 391/391 [01:16<00:00,  5.12it/s, Loss=1.534, Acc=50.1%]


Epoch 3: Train Acc: 50.15%, Test Acc: 54.90%


Epoch 4/50: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, Loss=1.234, Acc=53.5%]


Epoch 4: Train Acc: 53.47%, Test Acc: 55.70%


Epoch 5/50: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, Loss=1.034, Acc=56.0%]


Epoch 5: Train Acc: 56.03%, Test Acc: 59.75%


Epoch 6/50: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, Loss=1.091, Acc=58.4%]


Epoch 6: Train Acc: 58.37%, Test Acc: 61.69%


Epoch 7/50: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, Loss=0.946, Acc=60.3%]


Epoch 7: Train Acc: 60.27%, Test Acc: 62.74%


Epoch 8/50: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, Loss=1.056, Acc=62.1%]


Epoch 8: Train Acc: 62.07%, Test Acc: 64.19%


Epoch 9/50: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, Loss=0.948, Acc=63.7%]


Epoch 9: Train Acc: 63.74%, Test Acc: 65.36%


Epoch 10/50: 100%|██████████| 391/391 [01:16<00:00,  5.11it/s, Loss=1.034, Acc=65.3%]


Epoch 10: Train Acc: 65.30%, Test Acc: 67.13%


Epoch 11/50: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, Loss=0.935, Acc=67.1%]


Epoch 11: Train Acc: 67.05%, Test Acc: 69.02%


Epoch 12/50: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, Loss=0.986, Acc=68.6%]


Epoch 12: Train Acc: 68.60%, Test Acc: 69.77%


Epoch 13/50: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, Loss=0.766, Acc=69.5%]


Epoch 13: Train Acc: 69.52%, Test Acc: 69.67%


Epoch 14/50: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, Loss=0.691, Acc=71.0%]


Epoch 14: Train Acc: 70.97%, Test Acc: 71.99%


Epoch 15/50: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, Loss=0.807, Acc=72.2%]


Epoch 15: Train Acc: 72.21%, Test Acc: 72.57%


Epoch 16/50: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, Loss=0.559, Acc=73.0%]


Epoch 16: Train Acc: 73.04%, Test Acc: 73.45%


Epoch 17/50: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, Loss=0.966, Acc=74.0%]


Epoch 17: Train Acc: 73.95%, Test Acc: 73.20%


Epoch 18/50: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, Loss=0.640, Acc=75.2%]


Epoch 18: Train Acc: 75.25%, Test Acc: 74.46%


Epoch 19/50: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, Loss=0.633, Acc=76.3%]


Epoch 19: Train Acc: 76.33%, Test Acc: 74.16%


Epoch 20/50: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, Loss=0.638, Acc=77.3%]


Epoch 20: Train Acc: 77.29%, Test Acc: 75.63%


Epoch 21/50: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, Loss=0.639, Acc=78.3%]


Epoch 21: Train Acc: 78.32%, Test Acc: 75.50%


Epoch 22/50: 100%|██████████| 391/391 [01:16<00:00,  5.08it/s, Loss=0.459, Acc=79.0%]


Epoch 22: Train Acc: 78.95%, Test Acc: 76.23%


Epoch 23/50: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, Loss=0.395, Acc=80.1%]


Epoch 23: Train Acc: 80.10%, Test Acc: 76.37%


Epoch 24/50: 100%|██████████| 391/391 [01:16<00:00,  5.08it/s, Loss=0.552, Acc=81.1%]


Epoch 24: Train Acc: 81.06%, Test Acc: 76.32%


Epoch 25/50: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, Loss=0.328, Acc=81.8%]


Epoch 25: Train Acc: 81.84%, Test Acc: 77.05%


Epoch 26/50: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, Loss=0.527, Acc=82.6%]


Epoch 26: Train Acc: 82.65%, Test Acc: 77.64%


Epoch 27/50: 100%|██████████| 391/391 [01:17<00:00,  5.07it/s, Loss=0.684, Acc=83.8%]


Epoch 27: Train Acc: 83.85%, Test Acc: 77.46%


Epoch 28/50: 100%|██████████| 391/391 [01:16<00:00,  5.11it/s, Loss=0.362, Acc=84.6%]


Epoch 28: Train Acc: 84.60%, Test Acc: 77.93%


Epoch 29/50: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, Loss=0.468, Acc=85.5%]


Epoch 29: Train Acc: 85.52%, Test Acc: 78.45%


Epoch 30/50: 100%|██████████| 391/391 [01:16<00:00,  5.08it/s, Loss=0.423, Acc=86.2%]


Epoch 30: Train Acc: 86.21%, Test Acc: 78.28%


Epoch 31/50: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, Loss=0.396, Acc=87.1%]


Epoch 31: Train Acc: 87.12%, Test Acc: 78.38%


Epoch 32/50: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, Loss=0.428, Acc=88.0%]


Epoch 32: Train Acc: 88.01%, Test Acc: 78.64%


Epoch 33/50: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, Loss=0.300, Acc=89.0%]


Epoch 33: Train Acc: 88.98%, Test Acc: 78.31%


Epoch 34/50: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, Loss=0.259, Acc=89.4%]


Epoch 34: Train Acc: 89.43%, Test Acc: 78.66%


Epoch 35/50: 100%|██████████| 391/391 [01:17<00:00,  5.08it/s, Loss=0.247, Acc=90.2%]


Epoch 35: Train Acc: 90.19%, Test Acc: 78.56%


Epoch 36/50: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, Loss=0.301, Acc=90.8%]


Epoch 36: Train Acc: 90.78%, Test Acc: 78.32%


Epoch 37/50: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, Loss=0.259, Acc=91.5%]


Epoch 37: Train Acc: 91.45%, Test Acc: 78.75%


Epoch 38/50: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, Loss=0.170, Acc=91.7%]


Epoch 38: Train Acc: 91.67%, Test Acc: 78.79%


Epoch 39/50: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, Loss=0.270, Acc=92.5%]


Epoch 39: Train Acc: 92.52%, Test Acc: 78.92%


Epoch 40/50: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, Loss=0.151, Acc=92.9%]


Epoch 40: Train Acc: 92.87%, Test Acc: 79.26%


Epoch 41/50: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, Loss=0.256, Acc=93.1%]


Epoch 41: Train Acc: 93.13%, Test Acc: 79.23%


Epoch 42/50: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, Loss=0.090, Acc=93.7%]


Epoch 42: Train Acc: 93.69%, Test Acc: 78.74%


Epoch 43/50: 100%|██████████| 391/391 [01:16<00:00,  5.08it/s, Loss=0.038, Acc=93.9%]


Epoch 43: Train Acc: 93.93%, Test Acc: 78.98%


Epoch 44/50: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, Loss=0.233, Acc=94.2%]


Epoch 44: Train Acc: 94.23%, Test Acc: 78.93%


Epoch 45/50: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, Loss=0.223, Acc=94.4%]


Epoch 45: Train Acc: 94.45%, Test Acc: 79.15%


Epoch 46/50: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, Loss=0.133, Acc=94.6%]


Epoch 46: Train Acc: 94.60%, Test Acc: 79.12%


Epoch 47/50: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, Loss=0.165, Acc=94.6%]


Epoch 47: Train Acc: 94.61%, Test Acc: 79.23%


Epoch 48/50: 100%|██████████| 391/391 [01:17<00:00,  5.07it/s, Loss=0.180, Acc=94.6%]


Epoch 48: Train Acc: 94.61%, Test Acc: 79.26%


Epoch 49/50: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s, Loss=0.235, Acc=95.0%]


Epoch 49: Train Acc: 95.03%, Test Acc: 79.30%


Epoch 50/50: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s, Loss=0.106, Acc=94.8%]


Epoch 50: Train Acc: 94.84%, Test Acc: 79.28%
Best Test Accuracy: 79.30%


The Vision Transformer implemented from scratch was successfully trained on the CIFAR-10 dataset.

The model achieved a ***Training accuracy*** of ***~95%*** and a ***Best Test accuracy*** of ***~79%***.

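The gap of roughly 16 percentage points between training and test accuracy shows that the model overfits the training data. This is expected when a Vision Transformer is trained from scratch on a small dataset such as CIFAR-10: it lacks the locality and translation-equivariance priors built into convolutional networks, so it needs considerably more data, or stronger augmentation and regularization, to generalize well.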

