In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from dataclasses import dataclass
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import math

In [None]:
@dataclass
class SiglipVisionConfig:
    """Hyper-parameters for the SigLIP-style Vision Transformer trained on CIFAR-10.

    Raises:
        ValueError: (from ``__post_init__``) if ``num_attention_heads`` or
            ``patch_size`` is not positive, if ``hidden_size`` is not divisible
            by ``num_attention_heads`` (each head gets an equal slice of the
            embedding), or if ``image_size`` is not divisible by ``patch_size``
            (patches must tile the image exactly, otherwise the strided patch
            convolution silently drops the border pixels).
    """
    num_hidden_layers: int = 6          # number of stacked encoder layers
    num_channels: int = 3               # channels in input image - 3 - RGB
    image_size: int = 32                # CIFAR10 dataset images are 32x32
    patch_size: int = 4                 # 4x4 patches for 32x32 -> 8x8 = 64 patches
    num_attention_heads: int = 8        # number of attention heads
    hidden_size: int = 384              # width of every token embedding
    intermediate_size: int = 1536       # hidden width of the MLP (4 x hidden_size here)
    num_classes: int = 10               # since CIFAR10 has 10 classes
    layer_norm_eps: float = 1e-6
    attention_dropout: float = 0.1      # dropout applied to the attention probabilities
    dropout: float = 0.1                # dropout applied to embeddings and inside the MLP

    def __post_init__(self):
        # Guard the divisors first so the modulo checks below cannot raise ZeroDivisionError.
        if self.num_attention_heads <= 0:
            raise ValueError(f"num_attention_heads must be positive, got {self.num_attention_heads}")
        if self.patch_size <= 0:
            raise ValueError(f"patch_size must be positive, got {self.patch_size}")
        # Attention reshapes hidden_size into (num_heads, head_dim); it must split evenly.
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must be divisible by "
                f"num_attention_heads ({self.num_attention_heads})"
            )
        # The patch grid (and so the position-embedding table) assumes an exact tiling.
        if self.image_size % self.patch_size != 0:
            raise ValueError(
                f"image_size ({self.image_size}) must be divisible by patch_size ({self.patch_size})"
            )

def get_cifar10_transforms():
    """Build the CIFAR-10 preprocessing pipelines.

    Returns:
        A ``(train_transform, test_transform)`` pair of ``transforms.Compose``
        objects. Both end with ``ToTensor`` followed by per-channel
        normalization. The training pipeline first applies a random horizontal
        flip (p=0.5) and a random 32x32 crop from the image padded by 4 pixels.
    """
    # Per-channel normalization constants commonly used for CIFAR-10.
    channel_mean = [0.4914, 0.4822, 0.4465]
    channel_std = [0.247, 0.243, 0.261]

    def tensor_and_normalize():
        # Fresh instances per pipeline: image -> tensor, then (x - mean) / std.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=channel_mean, std=channel_std),
        ]

    augmentations = [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomCrop(32, padding=4),
    ]

    train_pipeline = transforms.Compose(augmentations + tensor_and_normalize())
    eval_pipeline = transforms.Compose(tensor_and_normalize())
    return train_pipeline, eval_pipeline


In [None]:
class SiglipVisionEmbeddings(nn.Module):
    """Turn a batch of images into the token sequence consumed by the encoder.

    The pipeline has four steps:

    1. Each non-overlapping ``patch_size x patch_size`` patch is projected to
       ``hidden_size`` dimensions by a convolution whose kernel size equals
       its stride.
    2. A learnable CLS token (initialized to zeros) is prepended.
    3. Learned position embeddings are added.
    4. Dropout is applied.

    Input:  ``pixel_values`` of shape (B, num_channels, image_size, image_size).
    Output: tensor of shape (B, num_patches + 1, hidden_size).

    Raises:
        ValueError: if ``pixel_values`` is not 4-D or its (C, H, W) does not
            match ``(num_channels, image_size, image_size)`` from the config.
            The position-embedding table is sized for exactly that grid.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config

        # basic parameters -
        self.num_channels = config.num_channels
        self.embed_dim = config.hidden_size  # each image patch becomes a vector of this width (384 by default)
        self.image_size = config.image_size  # expected square input side length (32 by default)
        self.patch_size = config.patch_size  # square patch side length (4 by default)

        # kernel_size == stride, so every output position covers exactly one patch
        self.patch_embedding = nn.Conv2d(
            in_channels=self.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding='valid'
        )

        # number of patches we are going to get -
        self.num_patches = (self.image_size // self.patch_size) ** 2  # 32 // 4 = 8 -> 8x8 = 64 patches
        self.num_positions = self.num_patches + 1                     # +1 for CLS token

        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        # learned lookup table of shape (num_positions, embed_dim): one row per token position, CLS included
        # NOTE(review): nn.Embedding initializes its weights from N(0, 1); ViT implementations usually use a
        # much smaller std (e.g. 0.02) for position embeddings - consider re-initializing.
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)

        # fixed position indices 0..num_positions-1, shape (1, num_positions);
        # persistent=False keeps them out of the state_dict
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand(1, -1),
            persistent=False
        )
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        if pixel_values.dim() != 4:
            raise ValueError(
                f"expected a 4-D (B, C, H, W) tensor, got shape {tuple(pixel_values.shape)}"
            )
        B, C, H, W = pixel_values.shape
        expected_chw = (self.num_channels, self.image_size, self.image_size)
        # Reject mismatched sizes explicitly. Otherwise sizes that floor to the same
        # patch grid (e.g. 33x33) are silently cropped by the 'valid' convolution,
        # and other sizes fail with an opaque broadcast error against position_ids.
        if (C, H, W) != expected_chw:
            raise ValueError(
                f"expected images of shape (B, {expected_chw[0]}, {expected_chw[1]}, {expected_chw[2]}), "
                f"got {tuple(pixel_values.shape)}"
            )

        patch_embeds = self.patch_embedding(pixel_values)        # (B, embed_dim, H/p, W/p)
        embeddings = patch_embeds.flatten(2).transpose(1, 2)     # (B, num_patches, embed_dim)

        # prepend the CLS token at position 0
        cls_tokens = self.cls_token.expand(B, -1, -1)            # (B, 1, embed_dim)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)  # (B, num_positions, embed_dim)

        # (1, num_positions, embed_dim) position table broadcasts over the batch
        embeddings = embeddings + self.position_embedding(self.position_ids)
        return self.dropout(embeddings)

In [None]:
# MLP LAYER

class SiglipMLP(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Dropout -> Linear.

    Widens the last dimension from ``hidden_size`` to ``intermediate_size``
    and projects it back to ``hidden_size``. All other dimensions are
    left untouched.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        # expansion and contraction projections (registered in this order)
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
        # regularization on the widened activations
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        widened = F.gelu(self.fc1(hidden_states))
        return self.fc2(self.dropout(widened))

In [None]:
# ATTENTION LAYER (part of Encoder Layer)

class SiglipAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (no masking).

    Input and output have shape (B, T, hidden_size). The width is split into
    ``num_attention_heads`` heads of ``hidden_size // num_attention_heads``
    channels each. Dropout with probability ``attention_dropout`` is applied
    to the attention probabilities during training only.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads  # channels per head
        self.dropout = config.attention_dropout           # probability, used with F.dropout

        # query / key / value projections, then the output projection
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

    def forward(self, hidden_states):
        batch, seq_len, width = hidden_states.shape

        def split_heads(projected):
            # (B, T, C) -> (B, heads, T, head_dim)
            return projected.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        queries = split_heads(self.q_proj(hidden_states))
        keys = split_heads(self.k_proj(hidden_states))
        values = split_heads(self.v_proj(hidden_states))

        # (B, heads, T, T) similarity scores, scaled by 1/sqrt(head_dim)
        scores = queries @ keys.transpose(-2, -1) / math.sqrt(self.head_dim)
        probs = F.dropout(scores.softmax(dim=-1), p=self.dropout, training=self.training)

        # weighted sum of values, then merge heads back to (B, T, C)
        context = (probs @ values).transpose(1, 2).reshape(batch, seq_len, width)
        return self.out_proj(context)

In [None]:
# ENCODER LAYER

class SiglipEncoderLayer(nn.Module):
    """One pre-LayerNorm transformer block.

    The block computes::

        x = x + Attention(LayerNorm1(x))
        x = x + MLP(LayerNorm2(x))

    The input shape (B, T, hidden_size) is preserved.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        # submodules registered in the same order as the forward pass uses them
        self.self_attn = SiglipAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        # attention sub-block with residual connection
        attn_branch = self.self_attn(self.layer_norm1(hidden_states))
        hidden_states = hidden_states + attn_branch

        # feed-forward sub-block with residual connection
        mlp_branch = self.mlp(self.layer_norm2(hidden_states))
        return hidden_states + mlp_branch

In [None]:
# FULL ENCODER

class SiglipEncoder(nn.Module):
    """Stack of ``num_hidden_layers`` encoder blocks, applied one after another.

    Each block maps (B, T, hidden_size) to the same shape, so the stack does too.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        )

    def forward(self, hidden_states):
        output = hidden_states
        for encoder_layer in self.layers:
            output = encoder_layer(output)  # feed each block's output into the next
        return output

In [None]:
# complete VISION TRANSFORMER for CIFAR10 classification

class CIFAR10VisionTransformer(nn.Module):
    """Vision Transformer classifier for CIFAR-10.

    The pipeline is patch/CLS/position embeddings -> encoder stack -> final
    LayerNorm -> linear head on the CLS token. ``forward`` returns raw
    (unnormalized) logits of shape (B, num_classes).
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # linear classification head, fed only the CLS token's representation
        self.classifier = nn.Linear(config.hidden_size, config.num_classes)

    def forward(self, pixel_values):
        encoded = self.encoder(self.embeddings(pixel_values))
        normalized = self.layer_norm(encoded)
        # the embeddings module prepends the CLS token at sequence position 0
        cls_representation = normalized[:, 0]
        return self.classifier(cls_representation)



In [None]:
# DATA LOADING AND TRAINING

def _build_dataloaders(batch_size, num_workers, pin_memory):
    """Load CIFAR-10 train/test splits from ./data (downloading if needed) and wrap them in DataLoaders."""
    train_transform, test_transform = get_cifar10_transforms()

    trainset = torchvision.datasets.CIFAR10(
        root='./data',
        train=True,
        download=True,
        transform=train_transform
    )
    testset = torchvision.datasets.CIFAR10(
        root='./data',
        train=False,
        download=True,
        transform=test_transform
    )

    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                             num_workers=num_workers, pin_memory=pin_memory)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=pin_memory)
    return trainloader, testloader


def _train_one_epoch(model, loader, criterion, optimizer, device, desc):
    """Run one optimization pass over ``loader``; return (mean batch loss, accuracy in %)."""
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    pbar = tqdm(loader, desc=desc)
    for inputs, targets in pbar:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()

        # clip the global gradient norm to keep transformer updates stable
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        running_loss += loss.item()
        total += targets.size(0)
        correct += outputs.argmax(dim=1).eq(targets).sum().item()

        pbar.set_postfix({
            'Loss': f'{loss.item():.3f}',
            'Acc': f'{100. * correct / total:.1f}%'
        })

    # max(..., 1) guards against an empty loader
    return running_loss / max(len(loader), 1), 100. * correct / max(total, 1)


def _evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients; return (mean batch loss, accuracy in %)."""
    model.eval()
    running_loss, correct, total = 0.0, 0, 0

    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            running_loss += criterion(outputs, targets).item()
            total += targets.size(0)
            correct += outputs.argmax(dim=1).eq(targets).sum().item()

    return running_loss / max(len(loader), 1), 100. * correct / max(total, 1)


def main(num_epochs=50, batch_size=128, lr=3e-4, weight_decay=0.05,
         num_workers=4, seed=None, checkpoint_path='best_cifar10_vit.pth'):
    """Train the CIFAR-10 Vision Transformer and keep the best checkpoint.

    The model is trained with AdamW, cross-entropy loss and a cosine LR
    schedule spanning the whole run, and is evaluated on the test split after
    every epoch. Whenever test accuracy improves, the ``state_dict`` is saved
    to ``checkpoint_path``. Calling ``main()`` with no arguments reproduces
    the original setup (50 epochs, batch 128, lr 3e-4, weight decay 0.05).

    Args:
        num_epochs: number of training epochs (also the cosine schedule length).
        batch_size: batch size for both the train and test DataLoaders.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        num_workers: DataLoader worker processes.
        seed: if not None, passed to ``torch.manual_seed`` for a reproducible run.
        checkpoint_path: where the best model ``state_dict`` is written.

    Returns:
        The best test accuracy (in %) seen during training.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Fall back to CPU instead of crashing when no GPU is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # pinned host memory only helps host-to-GPU copies
    trainloader, testloader = _build_dataloaders(
        batch_size, num_workers, pin_memory=(device.type == 'cuda')
    )

    # Model, loss, optimizer
    config = SiglipVisionConfig()
    model = CIFAR10VisionTransformer(config).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # T_max is tied to num_epochs so the cosine cycle always spans the full run.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    best_acc = 0.0
    for epoch in range(num_epochs):
        train_loss, train_acc = _train_one_epoch(
            model, trainloader, criterion, optimizer, device,
            desc=f'Epoch {epoch+1}/{num_epochs}'
        )
        scheduler.step()  # one scheduler step per epoch

        test_loss, test_acc = _evaluate(model, testloader, criterion, device)

        print(f'Epoch {epoch+1}: Train Loss: {train_loss:.3f}, Train Acc: {train_acc:.2f}%, '
              f'Test Loss: {test_loss:.3f}, Test Acc: {test_acc:.2f}%')

        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), checkpoint_path)

    print(f'Best Test Accuracy: {best_acc:.2f}%')
    return best_acc

if __name__ == "__main__":
    main()


100%|██████████| 170M/170M [00:03<00:00, 46.1MB/s]
Epoch 1/50:  60%|█████▉    | 234/391 [00:47<00:31,  5.00it/s, Loss=1.607, Acc=30.4%]

The Vision Transformer, implemented from scratch above, was trained on the CIFAR-10 dataset for 50 epochs.

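The model achieved a ***Training accuracy*** of ***~95%*** and a ***Best Test accuracy*** of ***~79%***. (The saved cell output above stops partway through epoch 1, so these final figures are not visible in this notebook's output.)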

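The roughly 16-percentage-point gap between training and test accuracy indicates that the model overfits the training set. This is typical for Vision Transformers trained from scratch on small datasets such as CIFAR-10 (50,000 training images). Unlike CNNs, ViTs have few built-in inductive biases such as locality and translation equivariance. They therefore need more data, or stronger regularization and augmentation, to generalize as well.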

