In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from dataclasses import dataclass
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import math

@dataclass
class SiglipVisionConfig:
    """Hyperparameters for the SigLIP-style vision transformer defined in this notebook.

    The defaults describe 32x32 RGB inputs cut into 4x4 patches and a 10-way
    classification head.

    Raises:
        ValueError: if ``model_hidden_size`` is not a positive multiple of
            ``attention_heads_count``. The attention module splits the hidden
            dimension evenly across heads, so an incompatible pair would
            otherwise only surface later as a tensor reshape error.
    """

    # Number of stacked encoder layers.
    num_transformer_layers: int = 6
    # Channels of the input image (3 for RGB).
    input_channels: int = 3
    # Side length, in pixels, of the (square) input image.
    input_image_size: int = 32
    # Side length, in pixels, of each (square) patch.
    image_patch_size: int = 4
    # Number of attention heads per encoder layer.
    attention_heads_count: int = 8
    # Width of the token embeddings throughout the model.
    model_hidden_size: int = 384
    # Width of the hidden layer inside each MLP block.
    mlp_intermediate_size: int = 1536
    # Number of logits produced by the classification head.
    output_classes_count: int = 10
    # Epsilon passed to every LayerNorm.
    layer_norm_epsilon: float = 1e-6
    # Dropout probability applied to the attention weights.
    attention_dropout_rate: float = 0.1
    # Dropout probability used by the embedding and MLP blocks.
    main_dropout_rate: float = 0.1

    def __post_init__(self):
        # Check the head count first so a zero value cannot reach the modulo.
        if (
            self.attention_heads_count <= 0
            or self.model_hidden_size % self.attention_heads_count != 0
        ):
            raise ValueError(
                f"model_hidden_size ({self.model_hidden_size}) must be a positive "
                f"multiple of attention_heads_count ({self.attention_heads_count})"
            )

def get_cifar10_transforms():
    """Build the preprocessing pipelines for the training and test splits.

    Returns:
        A ``(train_transform, test_transform)`` tuple of ``transforms.Compose``
        objects. The training pipeline applies a random horizontal flip and a
        padded random 32x32 crop before tensor conversion and normalization;
        the test pipeline only converts and normalizes.
    """
    # Per-channel mean/std handed to Normalize in both pipelines.
    channel_mean = [0.4914, 0.4822, 0.4465]
    channel_std = [0.247, 0.243, 0.261]

    augmentation_steps = [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomCrop(32, padding=4),
    ]
    train_steps = augmentation_steps + [
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    test_steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]

    return transforms.Compose(train_steps), transforms.Compose(test_steps)

In [None]:
class SiglipVisionEmbeddings(nn.Module):
    """Convert an image batch into a token sequence: [CLS] + patches + positions.

    NOTE(review): this notebook defines a class named ``SiglipVisionEmbeddings``
    twice; when the cells run top to bottom, the later definition (same layers,
    different attribute names) replaces this one.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config

        self.input_channels = config.input_channels
        self.embed_dim = config.model_hidden_size
        self.input_image_size = config.input_image_size
        self.image_patch_size = config.image_patch_size

        # Square grid of non-overlapping patches, plus one slot for the CLS token.
        patches_per_side = self.input_image_size // self.image_patch_size
        self.num_patches = patches_per_side ** 2
        self.num_positions = self.num_patches + 1

        # A strided convolution whose kernel equals its stride embeds each
        # patch independently.
        self.patch_embedding = nn.Conv2d(
            self.input_channels,
            self.embed_dim,
            kernel_size=self.image_patch_size,
            stride=self.image_patch_size,
            padding='valid',
        )

        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)

        # Fixed index row (1, num_positions); non-persistent, so it is not
        # written to the state_dict.
        position_ids = torch.arange(self.num_positions).unsqueeze(0)
        self.register_buffer("position_ids", position_ids, persistent=False)

        self.dropout = nn.Dropout(config.main_dropout_rate)

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Embed ``pixel_values`` of shape (B, C, H, W) into (B, num_positions, embed_dim)."""
        batch_size, _, _, _ = pixel_values.shape

        # (B, D, H', W') -> (B, D, H'*W') -> (B, H'*W', D)
        patch_tokens = self.patch_embedding(pixel_values).flatten(start_dim=2).transpose(1, 2)

        # Prepend one copy of the learned CLS token to every sequence.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat([cls_tokens, patch_tokens], dim=1)

        tokens = tokens + self.position_embedding(self.position_ids)
        return self.dropout(tokens)

In [None]:
class SiglipMLP(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Dropout -> Linear.

    NOTE(review): this notebook defines a class named ``SiglipMLP`` twice; when
    the cells run top to bottom, the later definition (same layers, different
    attribute names) replaces this one.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config

        hidden_size = config.model_hidden_size
        intermediate_size = config.mlp_intermediate_size

        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.fc2 = nn.Linear(intermediate_size, hidden_size)
        self.dropout = nn.Dropout(config.main_dropout_rate)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the block over the last dimension; the input shape is preserved."""
        expanded = F.gelu(self.fc1(hidden_states))
        # Dropout sits between the activation and the output projection only.
        return self.fc2(self.dropout(expanded))

In [None]:
class SiglipAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with separate Q/K/V projections.

    Raises:
        ValueError: if ``config.model_hidden_size`` is not a positive multiple
            of ``config.attention_heads_count`` (the hidden dimension is split
            evenly across heads).

    NOTE(review): this notebook defines a class named ``SiglipAttention`` twice;
    when the cells run top to bottom, the later definition (same layers,
    different attribute names) replaces this one.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.model_hidden_size
        self.num_heads = config.attention_heads_count

        # Fail at construction time instead of with an opaque `.view` shape
        # error during the first forward pass.
        if self.num_heads <= 0 or self.embed_dim % self.num_heads != 0:
            raise ValueError(
                f"model_hidden_size ({self.embed_dim}) must be a positive multiple "
                f"of attention_heads_count ({self.num_heads})"
            )

        self.head_dim = self.embed_dim // self.num_heads
        # Plain float: dropout is applied functionally to the attention weights.
        self.dropout = config.attention_dropout_rate

        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

    def _split_heads(self, projected):
        """Reshape (B, T, embed_dim) into per-head layout (B, num_heads, T, head_dim)."""
        batch_size, seq_len, _ = projected.shape
        return projected.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, hidden_states):
        """Attend over the sequence; input and output are both (B, T, embed_dim)."""
        B, T, C = hidden_states.shape

        q_states = self._split_heads(self.q_proj(hidden_states))
        k_states = self._split_heads(self.k_proj(hidden_states))
        v_states = self._split_heads(self.v_proj(hidden_states))

        # (B, heads, T, T) similarity scores, scaled by sqrt(head_dim).
        attn_weights = (q_states @ k_states.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn_weights = F.softmax(attn_weights, dim=-1)

        # Only active while the module is in training mode.
        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_outs = attn_weights @ v_states

        # Merge heads back: (B, heads, T, head_dim) -> (B, T, embed_dim).
        attn_outs = attn_outs.transpose(1, 2).reshape(B, T, C)
        return self.out_proj(attn_outs)

In [None]:
class SiglipEncoderLayer(nn.Module):
    """Pre-norm transformer block: LayerNorm -> attention and LayerNorm -> MLP, each with a residual.

    NOTE(review): this notebook defines a class named ``SiglipEncoderLayer``
    twice; when the cells run top to bottom, the later definition (same layers,
    different attribute names) replaces this one.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.embed_dim = config.model_hidden_size
        self.self_attn = SiglipAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.mlp = SiglipMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(self, hidden_states):
        """Apply both residual sub-blocks; the output has the same shape as the input.

        The residual additions are out-of-place: no tensor returned by a
        sub-module, and never the caller's input, is modified in place.
        """
        attention_output = self.self_attn(self.layer_norm1(hidden_states))
        hidden_states = hidden_states + attention_output

        mlp_output = self.mlp(self.layer_norm2(hidden_states))
        return hidden_states + mlp_output

In [None]:
class SiglipEncoder(nn.Module):
    """Sequential stack of ``config.num_transformer_layers`` encoder layers.

    NOTE(review): this notebook defines a class named ``SiglipEncoder`` twice;
    when the cells run top to bottom, the later definition (which stores the
    stack under a different attribute name) replaces this one.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config

        layer_stack = []
        for _ in range(config.num_transformer_layers):
            layer_stack.append(SiglipEncoderLayer(config))
        self.layers = nn.ModuleList(layer_stack)

    def forward(self, hidden_states):
        """Feed ``hidden_states`` through every layer in order and return the result."""
        output = hidden_states
        for encoder_layer in self.layers:
            output = encoder_layer(output)
        return output

In [None]:
class CIFAR10VisionTransformer(nn.Module):
    """Image classifier: embeddings -> encoder -> final LayerNorm -> linear head on the CLS token.

    NOTE(review): this notebook defines a class named ``CIFAR10VisionTransformer``
    twice; when the cells run top to bottom, the later definition (which names
    the head ``classification_head``) replaces this one.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config

        hidden_size = config.model_hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(config)
        self.layer_norm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        self.classifier = nn.Linear(hidden_size, config.output_classes_count)

    def forward(self, pixel_values):
        """Return class logits of shape (B, output_classes_count) for an image batch."""
        token_sequence = self.encoder(self.embeddings(pixel_values))
        normalized = self.layer_norm(token_sequence)

        # Position 0 holds the CLS token, which serves as the image summary.
        return self.classifier(normalized[:, 0])

In [None]:
class SiglipVisionEmbeddings(nn.Module):
    """Convert an image batch into a token sequence: [CLS] + patches + positions.

    NOTE(review): this redefines the ``SiglipVisionEmbeddings`` class from an
    earlier cell. The layers are the same, but the attribute names (and
    therefore the state_dict keys) differ.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config

        self.input_channels = config.input_channels
        self.embedding_dimension = config.model_hidden_size
        self.input_image_size = config.input_image_size
        self.image_patch_size = config.image_patch_size

        # Square grid of non-overlapping patches, plus one slot for the CLS token.
        patches_per_side = self.input_image_size // self.image_patch_size
        self.total_patches = patches_per_side ** 2
        self.total_positions = self.total_patches + 1

        # A strided convolution whose kernel equals its stride embeds each
        # patch independently.
        self.patch_embedding = nn.Conv2d(
            self.input_channels,
            self.embedding_dimension,
            kernel_size=self.image_patch_size,
            stride=self.image_patch_size,
            padding='valid',
        )

        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embedding_dimension))
        self.positional_embedding_layer = nn.Embedding(self.total_positions, self.embedding_dimension)

        # Fixed index row (1, total_positions); non-persistent, so it is not
        # written to the state_dict.
        position_indices = torch.arange(self.total_positions).unsqueeze(0)
        self.register_buffer("position_indices", position_indices, persistent=False)

        self.dropout = nn.Dropout(config.main_dropout_rate)

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Embed ``pixel_values`` of shape (B, C, H, W) into (B, total_positions, embedding_dimension)."""
        batch_size, _, _, _ = pixel_values.shape

        # (B, D, H', W') -> (B, D, H'*W') -> (B, H'*W', D)
        patch_tokens = self.patch_embedding(pixel_values).flatten(start_dim=2).transpose(1, 2)

        # Prepend one copy of the learned CLS token to every sequence.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat([cls_tokens, patch_tokens], dim=1)

        tokens = tokens + self.positional_embedding_layer(self.position_indices)
        return self.dropout(tokens)

In [None]:
class SiglipMLP(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Dropout -> Linear.

    NOTE(review): this redefines the ``SiglipMLP`` class from an earlier cell.
    The layers are the same, but the attribute names (and therefore the
    state_dict keys) differ.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config

        hidden_size = config.model_hidden_size
        intermediate_size = config.mlp_intermediate_size

        self.feed_forward_layer1 = nn.Linear(hidden_size, intermediate_size)
        self.feed_forward_layer2 = nn.Linear(intermediate_size, hidden_size)
        self.dropout = nn.Dropout(config.main_dropout_rate)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the block over the last dimension; the input shape is preserved."""
        expanded = F.gelu(self.feed_forward_layer1(hidden_states))
        # Dropout sits between the activation and the output projection only.
        return self.feed_forward_layer2(self.dropout(expanded))

In [None]:
class SiglipAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with separate Q/K/V projections.

    Raises:
        ValueError: if ``config.model_hidden_size`` is not a positive multiple
            of ``config.attention_heads_count`` (the hidden dimension is split
            evenly across heads).

    NOTE(review): this redefines the ``SiglipAttention`` class from an earlier
    cell. The layers are the same, but the attribute names (and therefore the
    state_dict keys) differ.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.model_hidden_size
        self.num_heads = config.attention_heads_count

        # Fail at construction time instead of with an opaque `.view` shape
        # error during the first forward pass.
        if self.num_heads <= 0 or self.embed_dim % self.num_heads != 0:
            raise ValueError(
                f"model_hidden_size ({self.embed_dim}) must be a positive multiple "
                f"of attention_heads_count ({self.num_heads})"
            )

        self.attention_head_dimension = self.embed_dim // self.num_heads
        # Plain float: dropout is applied functionally to the attention weights.
        self.dropout = config.attention_dropout_rate

        self.query_projection = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.key_projection = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.value_projection = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

        self.output_projection = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

    def _split_heads(self, projected):
        """Reshape (B, T, embed_dim) into per-head layout (B, num_heads, T, head_dim)."""
        batch_size, seq_len, _ = projected.shape
        return projected.view(
            batch_size, seq_len, self.num_heads, self.attention_head_dimension
        ).transpose(1, 2)

    def forward(self, hidden_states):
        """Attend over the sequence; input and output are both (B, T, embed_dim)."""
        B, T, C = hidden_states.shape

        q_states = self._split_heads(self.query_projection(hidden_states))
        k_states = self._split_heads(self.key_projection(hidden_states))
        v_states = self._split_heads(self.value_projection(hidden_states))

        # (B, heads, T, T) similarity scores, scaled by sqrt(head_dim).
        attention_weights = (q_states @ k_states.transpose(-2, -1)) / math.sqrt(self.attention_head_dimension)
        attention_weights = F.softmax(attention_weights, dim=-1)

        # Only active while the module is in training mode.
        attention_weights = F.dropout(attention_weights, p=self.dropout, training=self.training)

        attention_outputs = attention_weights @ v_states

        # Merge heads back: (B, heads, T, head_dim) -> (B, T, embed_dim).
        attention_outputs = attention_outputs.transpose(1, 2).reshape(B, T, C)
        return self.output_projection(attention_outputs)

In [None]:
class SiglipEncoderLayer(nn.Module):
    """Pre-norm transformer block: LayerNorm -> attention and LayerNorm -> MLP, each with a residual.

    NOTE(review): this redefines the ``SiglipEncoderLayer`` class from an
    earlier cell. The layers are the same, but the attribute names (and
    therefore the state_dict keys) differ.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.embed_dim = config.model_hidden_size
        self.self_attention_block = SiglipAttention(config)
        self.first_layer_normalization = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.mlp = SiglipMLP(config)
        self.second_layer_normalization = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(self, hidden_states):
        """Apply both residual sub-blocks; the output has the same shape as the input.

        The residual additions are out-of-place: no tensor returned by a
        sub-module, and never the caller's input, is modified in place.
        """
        attention_output = self.self_attention_block(self.first_layer_normalization(hidden_states))
        hidden_states = hidden_states + attention_output

        mlp_output = self.mlp(self.second_layer_normalization(hidden_states))
        return hidden_states + mlp_output

In [None]:
class SiglipEncoder(nn.Module):
    """Sequential stack of ``config.num_transformer_layers`` encoder layers.

    NOTE(review): this redefines the ``SiglipEncoder`` class from an earlier
    cell. The stack is stored as ``transformer_layers`` here, so the state_dict
    keys differ from the earlier definition.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config

        layer_stack = []
        for _ in range(config.num_transformer_layers):
            layer_stack.append(SiglipEncoderLayer(config))
        self.transformer_layers = nn.ModuleList(layer_stack)

    def forward(self, hidden_states):
        """Feed ``hidden_states`` through every layer in order and return the result."""
        output = hidden_states
        for encoder_layer in self.transformer_layers:
            output = encoder_layer(output)
        return output

In [None]:
class CIFAR10VisionTransformer(nn.Module):
    """Image classifier: embeddings -> encoder -> final LayerNorm -> linear head on the CLS token.

    NOTE(review): this redefines the ``CIFAR10VisionTransformer`` class from an
    earlier cell. The head is named ``classification_head`` here, so the
    state_dict keys differ from the earlier definition.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config

        hidden_size = config.model_hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(config)
        self.layer_norm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        self.classification_head = nn.Linear(hidden_size, config.output_classes_count)

    def forward(self, pixel_values):
        """Return class logits of shape (B, output_classes_count) for an image batch."""
        token_sequence = self.encoder(self.embeddings(pixel_values))
        normalized = self.layer_norm(token_sequence)

        # Position 0 holds the CLS token, which serves as the image summary.
        return self.classification_head(normalized[:, 0])

In [None]:
def _run_epoch(model, data_loader, loss_function, device, optimizer=None, progress_desc=None):
    """Run one pass over ``data_loader`` and return ``(mean_loss, accuracy_percent)``.

    Trains (backward pass, gradient clipping, optimizer step) when ``optimizer``
    is given; otherwise evaluates with gradients disabled. A tqdm progress bar
    is shown only when ``progress_desc`` is given.
    """
    is_training = optimizer is not None
    model.train(is_training)

    loss_total = 0.0
    correct_predictions = 0
    samples_seen = 0

    show_progress = progress_desc is not None
    batches = tqdm(data_loader, desc=progress_desc) if show_progress else data_loader

    with torch.set_grad_enabled(is_training):
        for batch_images, batch_labels in batches:
            batch_images = batch_images.to(device, non_blocking=True)
            batch_labels = batch_labels.to(device, non_blocking=True)

            if is_training:
                optimizer.zero_grad()

            model_outputs = model(batch_images)
            batch_loss = loss_function(model_outputs, batch_labels)

            if is_training:
                batch_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            batch_size = batch_labels.size(0)
            batch_loss_value = batch_loss.item()
            # Weight by batch size so a smaller final batch does not skew the mean.
            loss_total += batch_loss_value * batch_size
            _, predicted_labels = model_outputs.max(1)
            samples_seen += batch_size
            correct_predictions += predicted_labels.eq(batch_labels).sum().item()

            if show_progress:
                batches.set_postfix({
                    'Loss': f'{batch_loss_value:.3f}',
                    'Acc': f'{100.*correct_predictions/samples_seen:.1f}%'
                })

    if samples_seen == 0:
        # Empty loader: report zeros rather than dividing by zero.
        return 0.0, 0.0
    return loss_total / samples_seen, 100. * correct_predictions / samples_seen


def main(total_epochs=50, batch_size=128, learning_rate=3e-4, weight_decay=0.05,
         num_workers=4, data_root='./data', checkpoint_path='best_cifar10_vit.pth',
         seed=None):
    """Train the vision transformer on CIFAR-10 and checkpoint the best epoch.

    Runs on CUDA when available and falls back to CPU otherwise. After every
    epoch the model is evaluated on the test split; whenever test accuracy
    improves, the model's state_dict is written to ``checkpoint_path``.

    Args:
        total_epochs: Number of training epochs; also the cosine schedule length.
        batch_size: Batch size for both data loaders.
        learning_rate: Initial AdamW learning rate.
        weight_decay: AdamW weight decay.
        num_workers: Worker processes per data loader.
        data_root: Directory where CIFAR-10 is downloaded / read from.
        checkpoint_path: File that receives the best model's state_dict.
        seed: If not None, passed to ``torch.manual_seed`` before model creation.

    Returns:
        The best test accuracy (in percent) seen across all epochs.
    """
    if seed is not None:
        torch.manual_seed(seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Pinned host memory only helps host-to-GPU copies.
    use_pinned_memory = device.type == 'cuda'

    train_transform, test_transform = get_cifar10_transforms()

    cifar10_train_dataset = torchvision.datasets.CIFAR10(
        root=data_root,
        train=True,
        download=True,
        transform=train_transform
    )
    cifar10_test_dataset = torchvision.datasets.CIFAR10(
        root=data_root,
        train=False,
        download=True,
        transform=test_transform
    )

    train_data_loader = DataLoader(
        cifar10_train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=use_pinned_memory
    )
    test_data_loader = DataLoader(
        cifar10_test_dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=use_pinned_memory
    )

    config = SiglipVisionConfig()
    model = CIFAR10VisionTransformer(config).to(device)
    loss_function = nn.CrossEntropyLoss()
    model_optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # One full cosine decay over the run; the scheduler is stepped once per epoch.
    learning_rate_scheduler = optim.lr_scheduler.CosineAnnealingLR(model_optimizer, T_max=total_epochs)

    best_test_accuracy = 0.0

    for epoch in range(total_epochs):
        current_train_loss, current_train_accuracy = _run_epoch(
            model, train_data_loader, loss_function, device,
            optimizer=model_optimizer,
            progress_desc=f'Epoch {epoch+1}/{total_epochs}'
        )
        learning_rate_scheduler.step()

        current_test_loss, current_test_accuracy = _run_epoch(
            model, test_data_loader, loss_function, device
        )

        print(f'Epoch {epoch+1}: Train Acc: {current_train_accuracy:.2f}%, Test Acc: {current_test_accuracy:.2f}%')
        print(f'Epoch {epoch+1}: Train Loss: {current_train_loss:.4f}, Test Loss: {current_test_loss:.4f}')

        if current_test_accuracy > best_test_accuracy:
            best_test_accuracy = current_test_accuracy
            torch.save(model.state_dict(), checkpoint_path)

    print(f'Best Test Accuracy: {best_test_accuracy:.2f}%')
    return best_test_accuracy

if __name__ == "__main__":
    main()

100%|██████████| 170M/170M [00:03<00:00, 42.6MB/s]
Epoch 1/50: 100%|██████████| 391/391 [01:14<00:00,  5.26it/s, Loss=1.500, Acc=34.9%]


Epoch 1: Train Acc: 34.94%, Test Acc: 46.56%


Epoch 2/50: 100%|██████████| 391/391 [01:21<00:00,  4.82it/s, Loss=1.414, Acc=46.2%]


Epoch 2: Train Acc: 46.16%, Test Acc: 51.23%


Epoch 3/50: 100%|██████████| 391/391 [01:22<00:00,  4.72it/s, Loss=1.213, Acc=50.8%]


Epoch 3: Train Acc: 50.80%, Test Acc: 52.67%


Epoch 4/50:   0%|          | 1/391 [00:00<02:24,  2.70it/s, Loss=1.449, Acc=53.1%]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7d3b2ffed8a0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1618, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/usr/lib/python3.12/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 1136, in wait
    ready = selector.select(timeout)
            ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/select

KeyboardInterrupt: 