In [3]:
# Import necessary libraries for building and training the Vision Transformer
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from dataclasses import dataclass
import math
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

In [4]:
# Define the SiglipVisionConfig dataclass and CIFAR-10 data transformation functions
@dataclass
class SiglipVisionConfig:
    hidden_layers: int = 6
    n_channel: int = 3
    img_size: int = 32
    patch_size: int = 4
    hidden_size: int = 384
    num_heads: int = 8

    mlp_intermediate_size: int = 1536
    num_classes: int = 10
    layer_norm_eps: float = 1e-6
    general_dropout_prob: float = 0.1
    attention_dropout_prob: float = 0.1


def cifar_transform():

    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261])
    ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261])
    ])
    return train_transform, test_transform

In [5]:
# Implement the SiglipVisionEmbeddings class for creating patch and position embeddings
class SiglipVisionEmbeddings(nn.Module):
    def __init__(self, embed_config: SiglipVisionConfig):
        super().__init__()
        self.embed_config = embed_config

        self.input_channels = embed_config.n_channel
        self.embedding_dimension = embed_config.hidden_size
        self.image_input_size = embed_config.img_size
        self.patch_dimensions = embed_config.patch_size

        self.patch_convolution_layer = nn.Conv2d(
            in_channels=self.input_channels,
            out_channels=self.embedding_dimension,
            kernel_size=self.patch_dimensions,
            stride=self.patch_dimensions,
            padding='valid'
        )


        self.total_patches = (self.image_input_size // self.patch_dimensions) ** 2
        self.total_positions = self.total_patches + 1
        self.class_token = nn.Parameter(torch.zeros(1, 1, self.embedding_dimension))

        self.positional_encoder = nn.Embedding(self.total_positions, self.embedding_dimension)

        self.register_buffer(

            "position_indices",
            torch.arange(self.total_positions).expand(1, -1),
            persistent=False
        )
        self.embedding_dropout_layer = nn.Dropout(embed_config.general_dropout_prob)

    def forward(self, input_pixel_data: torch.FloatTensor) -> torch.Tensor:

        batch_size, channels, height, width = input_pixel_data.shape

        patch_features = self.patch_convolution_layer(input_pixel_data);

        combined_embeddings = patch_features.flatten(2).transpose(1, 2)


        expanded_class_token = self.class_token.expand(batch_size, -1, -1)
        combined_embeddings = torch.cat((expanded_class_token, combined_embeddings), dim=1)

        combined_embeddings = combined_embeddings + self.positional_encoder(self.position_indices)
        return self.embedding_dropout_layer(combined_embeddings)

In [6]:
# Implement the SiglipMLP (Multi-Layer Perceptron) block
class SiglipMLP(nn.Module):
    def __init__(self, mlp_config: SiglipVisionConfig):
        super().__init__()
        self.mlp_config = mlp_config

        self.dense_layer1 = nn.Linear(mlp_config.hidden_size, mlp_config.mlp_intermediate_size)

        self.dense_layer2 = nn.Linear(mlp_config.mlp_intermediate_size, mlp_config.hidden_size)

        self.mlp_dropout_layer = nn.Dropout(mlp_config.general_dropout_prob)

    def forward(self, input_features: torch.Tensor) -> torch.Tensor:
        processed_features = self.dense_layer1(input_features)

        processed_features = F.gelu(processed_features)
        processed_features = self.mlp_dropout_layer(processed_features)
        processed_features = self.dense_layer2(processed_features)
        return processed_features

In [7]:
# Implement the SiglipAttention (Multi-head Self-Attention) mechanism
class SiglipAttention(nn.Module):
    def __init__(self, attn_config: SiglipVisionConfig):
        super().__init__();
        self.attn_config = attn_config
        self.attention_embedding_dim = attn_config.hidden_size
        self.attention_heads_count = attn_config.num_heads
        self.single_head_dimension = self.attention_embedding_dim // self.attention_heads_count
        self.attention_dropout_layer = attn_config.attention_dropout_prob

        self.query_projection = nn.Linear(self.attention_embedding_dim, self.attention_embedding_dim, bias=True)
        self.key_projection = nn.Linear(self.attention_embedding_dim, self.attention_embedding_dim, bias=True)
        self.value_projection = nn.Linear(self.attention_embedding_dim, self.attention_embedding_dim, bias=True)

        self.output_projection = nn.Linear(self.attention_embedding_dim, self.attention_embedding_dim, bias=True)

    def forward(self, attention_input):
        batch_size, sequence_length, feature_dim = attention_input.shape

        query_states = self.query_projection(attention_input).view(batch_size, sequence_length, self.attention_heads_count, feature_dim // self.attention_heads_count).transpose(1, 2)
        key_states = self.key_projection(attention_input).view(batch_size, sequence_length, self.attention_heads_count, feature_dim // self.attention_heads_count).transpose(1, 2)
        value_states = self.value_projection(attention_input).view(batch_size, sequence_length, self.attention_heads_count, feature_dim // self.attention_heads_count).transpose(1, 2)

        attention_scores = (query_states @ key_states.transpose(-2, -1)) / math.sqrt(self.single_head_dimension)
        attention_scores = F.softmax(attention_scores, dim=-1)

        attention_scores = F.dropout(attention_scores, p=self.attention_dropout_layer, training=self.training)

        attention_output = attention_scores @ value_states

        attention_output = attention_output.transpose(1, 2).reshape(batch_size, sequence_length, feature_dim)
        return self.output_projection(attention_output)

In [8]:
# Implement a single SiglipEncoderLayer composed of self-attention and MLP blocks
class SiglipEncoderLayer(nn.Module):
    def __init__(self, layer_config: SiglipVisionConfig):
        super().__init__()
        self.layer_embedding_dim = layer_config.hidden_size
        self.self_attention_block = SiglipAttention(layer_config)
        self.norm_layer1 = nn.LayerNorm(self.layer_embedding_dim, eps=layer_config.layer_norm_eps)
        self.feed_forward_block = SiglipMLP(layer_config)
        self.norm_layer2 = nn.LayerNorm(self.layer_embedding_dim, eps=layer_config.layer_norm_eps)

    def forward(self, layer_input_states):
        skip_connection_input = layer_input_states
        current_states = self.norm_layer1(layer_input_states)
        current_states = self.self_attention_block(current_states)
        current_states += skip_connection_input

        skip_connection_input = current_states
        current_states = self.norm_layer2(current_states)
        current_states = self.feed_forward_block(current_states)
        current_states += skip_connection_input
        return current_states

In [9]:
# Implement the SiglipEncoder by stacking multiple SiglipEncoderLayer instances
class SiglipEncoder(nn.Module):
    def __init__(self, encoder_config: SiglipVisionConfig):
        super().__init__()
        self.encoder_config = encoder_config
        self.encoder_blocks = nn.ModuleList([SiglipEncoderLayer(encoder_config) for _ in range(encoder_config.hidden_layers)])

    def forward(self, input_hidden_states):
        for encoder_block in self.encoder_blocks:
            input_hidden_states = encoder_block(input_hidden_states)
        return input_hidden_states

In [10]:
# Combine all components to form the complete CIFAR10VisionTransformer model
class CIFAR10VisionTransformer(nn.Module):
    def __init__(self, transformer_config: SiglipVisionConfig):
        super().__init__()
        self.transformer_config = transformer_config
        self.vision_embeddings_module = SiglipVisionEmbeddings(transformer_config)
        self.transformer_encoder = SiglipEncoder(transformer_config)
        self.final_layer_norm = nn.LayerNorm(transformer_config.hidden_size, eps=transformer_config.layer_norm_eps)

        self.classification_head = nn.Linear(transformer_config.hidden_size, transformer_config.num_classes)

    def forward(self, model_input_pixels):
        encoded_features = self.vision_embeddings_module(model_input_pixels)
        encoded_features = self.transformer_encoder(encoded_features)
        encoded_features = self.final_layer_norm(encoded_features)

        classification_token_feature = encoded_features[:, 0]
        predicted_logits = self.classification_head(classification_token_feature)
        return predicted_logits

In [12]:
# Main function to load data, initialize, train, and evaluate the Vision Transformer model
def main():
    training_data_transform, testing_data_transform = cifar_transform()

    cifar10_train_dataset = torchvision.datasets.CIFAR10(
        root='./data',
        train=True,
        download=True,
        transform=training_data_transform
    )
    cifar10_test_dataset = torchvision.datasets.CIFAR10(
        root='./data',
        train=False,
        download=True,
        transform=testing_data_transform
    )

    training_dataloader = DataLoader(cifar10_train_dataset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)
    testing_dataloader = DataLoader(cifar10_test_dataset, batch_size=128, shuffle=False, num_workers=4, pin_memory=True)

    cifar10_class_names = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    model_config = SiglipVisionConfig()
    vit_model = CIFAR10VisionTransformer(model_config).cuda()
    loss_function = nn.CrossEntropyLoss()
    model_optimizer = optim.AdamW(vit_model.parameters(), lr=3e-4, weight_decay=0.05)
    lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(model_optimizer, T_max=50)



    max_epochs = 25
    best_test_accuracy = 0.0

    for current_epoch in range(max_epochs):
        vit_model.train()
        current_train_loss = 0.0
        correct_train_predictions = 0
        total_train_samples = 0

        progress_bar = tqdm(training_dataloader, desc=f'Epoch {current_epoch+1}/{max_epochs}')
        for batch_index, (batch_inputs, batch_targets) in enumerate(progress_bar):
            batch_inputs, batch_targets = batch_inputs.cuda(), batch_targets.cuda()

            model_optimizer.zero_grad()
            model_outputs = vit_model(batch_inputs)
            batch_loss = loss_function(model_outputs, batch_targets)
            batch_loss.backward()

            torch.nn.utils.clip_grad_norm_(vit_model.parameters(), max_norm=1.0)
            model_optimizer.step()

            current_train_loss += batch_loss.item()
            _, predicted_labels = model_outputs.max(1)
            total_train_samples += batch_targets.size(0);
            correct_train_predictions += predicted_labels.eq(batch_targets).sum().item()

            progress_bar.set_postfix({
                'Loss': f'{batch_loss.item():.3f}',
                'Acc': f'{100.*correct_train_predictions/total_train_samples:.1f}%'
            })

        lr_scheduler.step()




        vit_model.eval()
        current_test_loss = 0.0
        correct_test_predictions = 0
        total_test_samples = 0

        with torch.no_grad():
            for batch_inputs, batch_targets in testing_dataloader:
                batch_inputs, batch_targets = batch_inputs.cuda(), batch_targets.cuda()
                model_outputs = vit_model(batch_inputs)
                batch_loss = loss_function(model_outputs, batch_targets)

                current_test_loss += batch_loss.item()
                _, predicted_labels = model_outputs.max(1)
                total_test_samples += batch_targets.size(0)
                correct_test_predictions += predicted_labels.eq(batch_targets).sum().item()

        training_accuracy = 100. * correct_train_predictions / total_train_samples
        test_accuracy = 100. * correct_test_predictions / total_test_samples


        print(f'Epoch {current_epoch+1}: Train Acc: {training_accuracy:.2f}%, Test Acc: {test_accuracy:.2f}%')




        if test_accuracy > best_test_accuracy:
            best_test_accuracy = test_accuracy
            torch.save(vit_model.state_dict(), 'best_cifar10_vit.pth')

    print(f'Best Test Accuracy: {best_test_accuracy:.2f}%')

if __name__ == "__main__":
    main()

Epoch 1/25: 100%|██████████| 391/391 [01:20<00:00,  4.85it/s, Loss=1.569, Acc=34.9%]


Epoch 1: Train Acc: 34.88%, Test Acc: 46.49%


Epoch 2/25: 100%|██████████| 391/391 [01:22<00:00,  4.74it/s, Loss=1.532, Acc=45.8%]


Epoch 2: Train Acc: 45.79%, Test Acc: 50.73%


Epoch 3/25: 100%|██████████| 391/391 [01:22<00:00,  4.74it/s, Loss=1.384, Acc=50.4%]


Epoch 3: Train Acc: 50.36%, Test Acc: 54.97%


Epoch 4/25: 100%|██████████| 391/391 [01:22<00:00,  4.74it/s, Loss=1.260, Acc=54.0%]


Epoch 4: Train Acc: 53.98%, Test Acc: 56.07%


Epoch 5/25: 100%|██████████| 391/391 [01:22<00:00,  4.74it/s, Loss=1.150, Acc=56.0%]


Epoch 5: Train Acc: 56.00%, Test Acc: 60.14%


Epoch 6/25: 100%|██████████| 391/391 [01:22<00:00,  4.74it/s, Loss=1.196, Acc=58.3%]


Epoch 6: Train Acc: 58.26%, Test Acc: 60.37%


Epoch 7/25: 100%|██████████| 391/391 [01:22<00:00,  4.76it/s, Loss=1.101, Acc=60.5%]


Epoch 7: Train Acc: 60.53%, Test Acc: 63.32%


Epoch 8/25: 100%|██████████| 391/391 [01:22<00:00,  4.76it/s, Loss=1.011, Acc=62.7%]


Epoch 8: Train Acc: 62.72%, Test Acc: 65.18%


Epoch 9/25: 100%|██████████| 391/391 [01:22<00:00,  4.74it/s, Loss=0.950, Acc=65.0%]


Epoch 9: Train Acc: 64.99%, Test Acc: 65.91%


Epoch 10/25: 100%|██████████| 391/391 [01:22<00:00,  4.74it/s, Loss=0.837, Acc=66.5%]


Epoch 10: Train Acc: 66.50%, Test Acc: 68.93%


Epoch 11/25: 100%|██████████| 391/391 [01:22<00:00,  4.75it/s, Loss=0.817, Acc=68.2%]


Epoch 11: Train Acc: 68.18%, Test Acc: 68.74%


Epoch 12/25: 100%|██████████| 391/391 [01:22<00:00,  4.76it/s, Loss=0.998, Acc=69.6%]


Epoch 12: Train Acc: 69.56%, Test Acc: 70.73%


Epoch 13/25: 100%|██████████| 391/391 [01:22<00:00,  4.76it/s, Loss=1.150, Acc=71.0%]


Epoch 13: Train Acc: 70.98%, Test Acc: 70.47%


Epoch 14/25: 100%|██████████| 391/391 [01:22<00:00,  4.74it/s, Loss=0.960, Acc=72.1%]


Epoch 14: Train Acc: 72.12%, Test Acc: 72.41%


Epoch 15/25: 100%|██████████| 391/391 [01:22<00:00,  4.74it/s, Loss=0.805, Acc=73.1%]


Epoch 15: Train Acc: 73.13%, Test Acc: 72.67%


Epoch 16/25: 100%|██████████| 391/391 [01:22<00:00,  4.75it/s, Loss=0.745, Acc=74.1%]


Epoch 16: Train Acc: 74.10%, Test Acc: 73.65%


Epoch 17/25: 100%|██████████| 391/391 [01:22<00:00,  4.76it/s, Loss=0.716, Acc=75.3%]


Epoch 17: Train Acc: 75.25%, Test Acc: 74.15%


Epoch 18/25: 100%|██████████| 391/391 [01:22<00:00,  4.74it/s, Loss=0.664, Acc=76.1%]


Epoch 18: Train Acc: 76.10%, Test Acc: 73.12%


Epoch 19/25: 100%|██████████| 391/391 [01:22<00:00,  4.73it/s, Loss=0.594, Acc=77.0%]


Epoch 19: Train Acc: 76.98%, Test Acc: 74.80%


Epoch 20/25: 100%|██████████| 391/391 [01:22<00:00,  4.75it/s, Loss=0.661, Acc=77.9%]


Epoch 20: Train Acc: 77.94%, Test Acc: 75.66%


Epoch 21/25: 100%|██████████| 391/391 [01:22<00:00,  4.76it/s, Loss=0.577, Acc=78.9%]


Epoch 21: Train Acc: 78.86%, Test Acc: 76.00%


Epoch 22/25: 100%|██████████| 391/391 [01:22<00:00,  4.76it/s, Loss=0.500, Acc=79.7%]


Epoch 22: Train Acc: 79.67%, Test Acc: 77.24%


Epoch 23/25: 100%|██████████| 391/391 [01:22<00:00,  4.75it/s, Loss=0.521, Acc=80.7%]


Epoch 23: Train Acc: 80.68%, Test Acc: 77.14%


Epoch 24/25: 100%|██████████| 391/391 [01:22<00:00,  4.76it/s, Loss=0.438, Acc=81.5%]


Epoch 24: Train Acc: 81.48%, Test Acc: 77.31%


Epoch 25/25: 100%|██████████| 391/391 [01:22<00:00,  4.75it/s, Loss=0.360, Acc=82.5%]


Epoch 25: Train Acc: 82.55%, Test Acc: 78.07%
Best Test Accuracy: 78.07%
