In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [4]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.optim as optim
from torch.utils.data import DataLoader

# # Hyperparameters
# num_classes = 10
# patch_size = 16
# embedding_dim = 768
# num_heads = 12
# num_layers = 12
# learning_rate = 1e-4
# num_epochs = 10
# batch_size = 32

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [5]:
print(device)

cuda


In [7]:
# Clear CUDA cache
torch.cuda.empty_cache()


In [8]:
class VisionTransformer(nn.Module):
    def __init__(self, num_classes, patch_size, embedding_dim, num_heads, num_layers):
        super(VisionTransformer, self).__init__()
        self.patch_embedding = nn.Conv2d(3, embedding_dim, kernel_size=patch_size, stride=patch_size)
        self.positional_encoding = nn.Parameter(torch.randn(1, 14 * 14 + 1, embedding_dim))
        self.transformer_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads) for _ in range(num_layers)
        ])
        self.fc = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        batch_size = x.size(0)
        x = self.patch_embedding(x)
        x = x.flatten(2).transpose(1, 2)
        x = torch.cat((x, self.positional_encoding.repeat(batch_size, 1, 1)), dim=1)
        for layer in self.transformer_layers:
            x = layer(x)
        x = x.mean(dim=1)
        x = self.fc(x)
        return x

In [9]:
# Initialize the model
model = VisionTransformer(num_classes, patch_size, embedding_dim, num_heads, num_layers).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


In [None]:
import itertools
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define the hyperparameter space
patch_sizes = [8, 16]
embedding_dims = [256, 512]
num_heads = [4, 8]
num_layers = [4, 6]
learning_rates = [1e-4, 5e-5]
batch_sizes = [16, 32]

# CIFAR-10 dataset
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)

# Function to train and evaluate a model
def train_and_evaluate_model(patch_size, embedding_dim, num_heads, num_layers, learning_rate, batch_size):
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
    
    # Define the Vision Transformer model
    class VisionTransformer(nn.Module):
        def __init__(self, num_classes, patch_size, embedding_dim, num_heads, num_layers):
            super(VisionTransformer, self).__init__()
            self.patch_embedding = nn.Conv2d(3, embedding_dim, kernel_size=patch_size, stride=patch_size)
            self.positional_encoding = nn.Parameter(torch.randn(1, 14 * 14 + 1, embedding_dim))
            self.transformer_layers = nn.ModuleList([
                nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads) for _ in range(num_layers)
            ])
            self.fc = nn.Linear(embedding_dim, num_classes)

        def forward(self, x):
            batch_size = x.size(0)
            x = self.patch_embedding(x)
            x = x.flatten(2).transpose(1, 2)
            x = torch.cat((x, self.positional_encoding.repeat(batch_size, 1, 1)), dim=1)
            for layer in self.transformer_layers:
                x = layer(x)
            x = x.mean(dim=1)
            x = self.fc(x)
            return x

    # Initialize the model
    model = VisionTransformer(num_classes=10, patch_size=patch_size, embedding_dim=embedding_dim, 
                              num_heads=num_heads, num_layers=num_layers).to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Training loop
    num_epochs = 5  # Using a smaller number of epochs for hyperparameter search
    best_accuracy = 0.0
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        correct = 0
        total_samples = 0
        
        # Use tqdm for progress bar
        progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f'Epoch {epoch + 1}/{num_epochs}')
        
        for i, (images, labels) in progress_bar:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total_samples += labels.size(0)
            
            total_loss += loss.item()

            progress_bar.set_postfix({'Loss': loss.item(), 'Accuracy': (correct / total_samples) * 100})

        model.eval()
        correct = 0
        total_samples = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                total_samples += labels.size(0)

        # Calculate accuracy
        accuracy = correct / total_samples * 100

        # Print validation accuracy for each epoch
        print(f'Epoch [{epoch + 1}/{num_epochs}], Validation Accuracy: {accuracy:.2f}%')

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            # Save model to /kaggle/working/
            torch.save(model.state_dict(), '/kaggle/working/best_model.pth')
            print(f'Saving model with validation accuracy: {best_accuracy:.2f}%')

    return best_accuracy
best_hyperparameters = None
best_validation_accuracy = 0.0

for patch_size, embedding_dim, num_heads, num_layers, learning_rate, batch_size in itertools.product(
        patch_sizes, embedding_dims, num_heads, num_layers, learning_rates, batch_sizes):
    
    print(f'Running with hyperparameters: patch_size={patch_size}, embedding_dim={embedding_dim}, '
          f'num_heads={num_heads}, num_layers={num_layers}, learning_rate={learning_rate}, batch_size={batch_size}')
    
    validation_accuracy = train_and_evaluate_model(patch_size, embedding_dim, num_heads, num_layers, learning_rate, batch_size)
    
    if validation_accuracy > best_validation_accuracy:
        best_validation_accuracy = validation_accuracy
        best_hyperparameters = {
            'patch_size': patch_size,
            'embedding_dim': embedding_dim,
            'num_heads': num_heads,
            'num_layers': num_layers,
            'learning_rate': learning_rate,
            'batch_size': batch_size
        }

print(f'Best hyperparameters found: {best_hyperparameters}')
print(f'Best validation accuracy found: {best_validation_accuracy:.2f}%')


Files already downloaded and verified
Files already downloaded and verified
Running with hyperparameters: patch_size=8, embedding_dim=256, num_heads=4, num_layers=4, learning_rate=0.0001, batch_size=16


Epoch 1/5: 100%|██████████| 3125/3125 [05:13<00:00,  9.98it/s, Loss=1.58, Accuracy=31.3]


Epoch [1/5], Validation Accuracy: 39.06%
Saving model with validation accuracy: 39.06%


Epoch 2/5: 100%|██████████| 3125/3125 [05:16<00:00,  9.86it/s, Loss=1.72, Accuracy=43.6] 


Epoch [2/5], Validation Accuracy: 46.37%
Saving model with validation accuracy: 46.37%


Epoch 3/5: 100%|██████████| 3125/3125 [05:15<00:00,  9.91it/s, Loss=1.55, Accuracy=50.4] 


Epoch [3/5], Validation Accuracy: 51.36%
Saving model with validation accuracy: 51.36%


Epoch 4/5: 100%|██████████| 3125/3125 [05:14<00:00,  9.93it/s, Loss=1.33, Accuracy=54.5] 


Epoch [4/5], Validation Accuracy: 55.74%
Saving model with validation accuracy: 55.74%


Epoch 5/5: 100%|██████████| 3125/3125 [05:13<00:00,  9.98it/s, Loss=1.2, Accuracy=57.2]  


Epoch [5/5], Validation Accuracy: 56.83%
Saving model with validation accuracy: 56.83%
Running with hyperparameters: patch_size=8, embedding_dim=256, num_heads=4, num_layers=4, learning_rate=0.0001, batch_size=32


Epoch 1/5: 100%|██████████| 1563/1563 [05:01<00:00,  5.19it/s, Loss=1.85, Accuracy=30.6]


Epoch [1/5], Validation Accuracy: 37.70%
Saving model with validation accuracy: 37.70%


Epoch 2/5: 100%|██████████| 1563/1563 [04:59<00:00,  5.21it/s, Loss=1.7, Accuracy=42.4] 


Epoch [2/5], Validation Accuracy: 47.31%
Saving model with validation accuracy: 47.31%


Epoch 3/5: 100%|██████████| 1563/1563 [04:58<00:00,  5.23it/s, Loss=1.36, Accuracy=49.3] 


Epoch [3/5], Validation Accuracy: 51.49%
Saving model with validation accuracy: 51.49%


Epoch 4/5: 100%|██████████| 1563/1563 [04:57<00:00,  5.25it/s, Loss=1.2, Accuracy=53.7]  


Epoch [4/5], Validation Accuracy: 53.94%
Saving model with validation accuracy: 53.94%


Epoch 5/5: 100%|██████████| 1563/1563 [04:57<00:00,  5.26it/s, Loss=1.62, Accuracy=56.5] 


Epoch [5/5], Validation Accuracy: 58.24%
Saving model with validation accuracy: 58.24%
Running with hyperparameters: patch_size=8, embedding_dim=256, num_heads=4, num_layers=4, learning_rate=5e-05, batch_size=16


Epoch 1/5: 100%|██████████| 3125/3125 [05:16<00:00,  9.86it/s, Loss=1.79, Accuracy=29.3]


Epoch [1/5], Validation Accuracy: 34.76%
Saving model with validation accuracy: 34.76%


Epoch 2/5: 100%|██████████| 3125/3125 [05:16<00:00,  9.87it/s, Loss=1.58, Accuracy=39.4]


Epoch [2/5], Validation Accuracy: 42.70%
Saving model with validation accuracy: 42.70%


Epoch 3/5: 100%|██████████| 3125/3125 [05:16<00:00,  9.87it/s, Loss=1.95, Accuracy=46]   


Epoch [3/5], Validation Accuracy: 48.83%
Saving model with validation accuracy: 48.83%


Epoch 4/5: 100%|██████████| 3125/3125 [05:15<00:00,  9.89it/s, Loss=1.72, Accuracy=50.3] 


Epoch [4/5], Validation Accuracy: 50.17%
Saving model with validation accuracy: 50.17%


Epoch 5/5: 100%|██████████| 3125/3125 [05:15<00:00,  9.91it/s, Loss=1.42, Accuracy=53.7] 


Epoch [5/5], Validation Accuracy: 55.21%
Saving model with validation accuracy: 55.21%
Running with hyperparameters: patch_size=8, embedding_dim=256, num_heads=4, num_layers=4, learning_rate=5e-05, batch_size=32


Epoch 1/5: 100%|██████████| 1563/1563 [05:01<00:00,  5.19it/s, Loss=1.78, Accuracy=29]  


Epoch [1/5], Validation Accuracy: 35.30%
Saving model with validation accuracy: 35.30%


Epoch 2/5: 100%|██████████| 1563/1563 [05:00<00:00,  5.20it/s, Loss=1.89, Accuracy=37.8]


Epoch [2/5], Validation Accuracy: 42.42%
Saving model with validation accuracy: 42.42%


Epoch 3/5: 100%|██████████| 1563/1563 [05:00<00:00,  5.21it/s, Loss=1.75, Accuracy=43.9]


Epoch [3/5], Validation Accuracy: 45.20%
Saving model with validation accuracy: 45.20%


Epoch 4/5: 100%|██████████| 1563/1563 [04:59<00:00,  5.22it/s, Loss=1.72, Accuracy=48.1]


Epoch [4/5], Validation Accuracy: 49.12%
Saving model with validation accuracy: 49.12%


Epoch 5/5: 100%|██████████| 1563/1563 [04:59<00:00,  5.21it/s, Loss=1.47, Accuracy=51.1] 


Epoch [5/5], Validation Accuracy: 53.18%
Saving model with validation accuracy: 53.18%
Running with hyperparameters: patch_size=8, embedding_dim=256, num_heads=4, num_layers=6, learning_rate=0.0001, batch_size=16


Epoch 1/5: 100%|██████████| 3125/3125 [07:39<00:00,  6.80it/s, Loss=1.58, Accuracy=31.8]


Epoch [1/5], Validation Accuracy: 40.34%
Saving model with validation accuracy: 40.34%


Epoch 2/5: 100%|██████████| 3125/3125 [07:38<00:00,  6.81it/s, Loss=1.69, Accuracy=44.8] 


Epoch [2/5], Validation Accuracy: 46.74%
Saving model with validation accuracy: 46.74%


Epoch 3/5: 100%|██████████| 3125/3125 [07:36<00:00,  6.85it/s, Loss=1.19, Accuracy=51.2] 


Epoch [3/5], Validation Accuracy: 54.17%
Saving model with validation accuracy: 54.17%


Epoch 4/5: 100%|██████████| 3125/3125 [07:35<00:00,  6.86it/s, Loss=1.47, Accuracy=55.8] 


Epoch [4/5], Validation Accuracy: 57.09%
Saving model with validation accuracy: 57.09%


Epoch 5/5: 100%|██████████| 3125/3125 [07:34<00:00,  6.88it/s, Loss=0.975, Accuracy=58.8]


Epoch [5/5], Validation Accuracy: 58.91%
Saving model with validation accuracy: 58.91%
Running with hyperparameters: patch_size=8, embedding_dim=256, num_heads=4, num_layers=6, learning_rate=0.0001, batch_size=32


Epoch 1/5: 100%|██████████| 1563/1563 [07:19<00:00,  3.56it/s, Loss=1.99, Accuracy=31.5]


Epoch [1/5], Validation Accuracy: 38.95%
Saving model with validation accuracy: 38.95%


Epoch 2/5: 100%|██████████| 1563/1563 [07:17<00:00,  3.57it/s, Loss=1.04, Accuracy=43.8]


Epoch [2/5], Validation Accuracy: 45.62%
Saving model with validation accuracy: 45.62%


Epoch 3/5: 100%|██████████| 1563/1563 [07:16<00:00,  3.58it/s, Loss=0.746, Accuracy=50.5]


Epoch [3/5], Validation Accuracy: 53.58%
Saving model with validation accuracy: 53.58%


Epoch 4/5: 100%|██████████| 1563/1563 [07:16<00:00,  3.58it/s, Loss=1.27, Accuracy=55]   


Epoch [4/5], Validation Accuracy: 56.24%
Saving model with validation accuracy: 56.24%


Epoch 5/5: 100%|██████████| 1563/1563 [07:15<00:00,  3.59it/s, Loss=1.03, Accuracy=58.3] 


Epoch [5/5], Validation Accuracy: 60.51%
Saving model with validation accuracy: 60.51%
Running with hyperparameters: patch_size=8, embedding_dim=256, num_heads=4, num_layers=6, learning_rate=5e-05, batch_size=16


Epoch 1/5: 100%|██████████| 3125/3125 [07:40<00:00,  6.78it/s, Loss=1.63, Accuracy=30.3]


Epoch [1/5], Validation Accuracy: 38.15%
Saving model with validation accuracy: 38.15%


Epoch 2/5: 100%|██████████| 3125/3125 [07:40<00:00,  6.79it/s, Loss=1.63, Accuracy=41.4]


Epoch [2/5], Validation Accuracy: 45.31%
Saving model with validation accuracy: 45.31%


Epoch 3/5: 100%|██████████| 3125/3125 [07:39<00:00,  6.80it/s, Loss=1.8, Accuracy=48.2]  


Epoch [3/5], Validation Accuracy: 50.94%
Saving model with validation accuracy: 50.94%


Epoch 4/5: 100%|██████████| 3125/3125 [07:39<00:00,  6.80it/s, Loss=1.25, Accuracy=52]   


Epoch [4/5], Validation Accuracy: 54.43%
Saving model with validation accuracy: 54.43%


Epoch 5/5: 100%|██████████| 3125/3125 [07:38<00:00,  6.82it/s, Loss=1.53, Accuracy=55.2] 


Epoch [5/5], Validation Accuracy: 55.46%
Saving model with validation accuracy: 55.46%
Running with hyperparameters: patch_size=8, embedding_dim=256, num_heads=4, num_layers=6, learning_rate=5e-05, batch_size=32


Epoch 1/5: 100%|██████████| 1563/1563 [07:18<00:00,  3.57it/s, Loss=1.5, Accuracy=30.1] 


Epoch [1/5], Validation Accuracy: 34.47%
Saving model with validation accuracy: 34.47%


Epoch 2/5: 100%|██████████| 1563/1563 [07:18<00:00,  3.56it/s, Loss=1.72, Accuracy=40.4]


Epoch [2/5], Validation Accuracy: 42.93%
Saving model with validation accuracy: 42.93%


Epoch 3/5: 100%|██████████| 1563/1563 [07:18<00:00,  3.57it/s, Loss=1.48, Accuracy=46.6]


Epoch [3/5], Validation Accuracy: 49.26%
Saving model with validation accuracy: 49.26%


Epoch 4/5: 100%|██████████| 1563/1563 [07:18<00:00,  3.56it/s, Loss=1.53, Accuracy=50.6] 


Epoch [4/5], Validation Accuracy: 53.18%
Saving model with validation accuracy: 53.18%


Epoch 5/5: 100%|██████████| 1563/1563 [07:18<00:00,  3.57it/s, Loss=1.19, Accuracy=54]   


Epoch [5/5], Validation Accuracy: 54.97%
Saving model with validation accuracy: 54.97%
Running with hyperparameters: patch_size=8, embedding_dim=256, num_heads=8, num_layers=4, learning_rate=0.0001, batch_size=16


Epoch 1/5: 100%|██████████| 3125/3125 [05:30<00:00,  9.44it/s, Loss=1.66, Accuracy=31.6]


Epoch [1/5], Validation Accuracy: 38.37%
Saving model with validation accuracy: 38.37%


Epoch 2/5: 100%|██████████| 3125/3125 [05:29<00:00,  9.49it/s, Loss=1.19, Accuracy=43.8] 


Epoch [2/5], Validation Accuracy: 49.49%
Saving model with validation accuracy: 49.49%


Epoch 3/5: 100%|██████████| 3125/3125 [05:27<00:00,  9.54it/s, Loss=1.66, Accuracy=50.8] 


Epoch [3/5], Validation Accuracy: 54.21%
Saving model with validation accuracy: 54.21%


Epoch 4/5: 100%|██████████| 3125/3125 [05:27<00:00,  9.56it/s, Loss=1.28, Accuracy=54.9] 


Epoch [4/5], Validation Accuracy: 54.36%
Saving model with validation accuracy: 54.36%


Epoch 5/5: 100%|██████████| 3125/3125 [05:27<00:00,  9.55it/s, Loss=1.02, Accuracy=57.8] 


Epoch [5/5], Validation Accuracy: 58.67%
Saving model with validation accuracy: 58.67%
Running with hyperparameters: patch_size=8, embedding_dim=256, num_heads=8, num_layers=4, learning_rate=0.0001, batch_size=32


Epoch 1/5: 100%|██████████| 1563/1563 [05:11<00:00,  5.02it/s, Loss=1.62, Accuracy=31]  


Epoch [1/5], Validation Accuracy: 39.31%
Saving model with validation accuracy: 39.31%


Epoch 2/5: 100%|██████████| 1563/1563 [05:09<00:00,  5.05it/s, Loss=2.26, Accuracy=42.5]


Epoch [2/5], Validation Accuracy: 47.72%
Saving model with validation accuracy: 47.72%


Epoch 3/5: 100%|██████████| 1563/1563 [05:09<00:00,  5.06it/s, Loss=1.21, Accuracy=49.4] 


Epoch [3/5], Validation Accuracy: 52.34%
Saving model with validation accuracy: 52.34%


Epoch 4/5: 100%|██████████| 1563/1563 [05:08<00:00,  5.07it/s, Loss=1.16, Accuracy=54.1] 


Epoch [4/5], Validation Accuracy: 53.88%
Saving model with validation accuracy: 53.88%


Epoch 5/5: 100%|██████████| 1563/1563 [05:07<00:00,  5.08it/s, Loss=1.75, Accuracy=57.1] 


Epoch [5/5], Validation Accuracy: 57.22%
Saving model with validation accuracy: 57.22%
Running with hyperparameters: patch_size=8, embedding_dim=256, num_heads=8, num_layers=4, learning_rate=5e-05, batch_size=16


Epoch 1/5: 100%|██████████| 3125/3125 [05:30<00:00,  9.45it/s, Loss=1.94, Accuracy=29.6]


Epoch [1/5], Validation Accuracy: 37.16%
Saving model with validation accuracy: 37.16%


Epoch 2/5: 100%|██████████| 3125/3125 [05:30<00:00,  9.46it/s, Loss=1.23, Accuracy=40.2] 


Epoch [2/5], Validation Accuracy: 44.97%
Saving model with validation accuracy: 44.97%


Epoch 3/5: 100%|██████████| 3125/3125 [05:29<00:00,  9.47it/s, Loss=1.53, Accuracy=46.8] 


Epoch [3/5], Validation Accuracy: 47.76%
Saving model with validation accuracy: 47.76%


Epoch 4/5: 100%|██████████| 3125/3125 [05:29<00:00,  9.49it/s, Loss=1.37, Accuracy=50.9] 


Epoch [4/5], Validation Accuracy: 51.91%
Saving model with validation accuracy: 51.91%


Epoch 5/5:  77%|███████▋  | 2421/3125 [04:14<01:14,  9.50it/s, Loss=1.34, Accuracy=53.7] 

# As it can been seen, the best hyperparamether is 
> **Validation Accuracy: 60.51%**
> 
> **patch_size=8, embedding_dim=256, num_heads=4, num_layers=6, learning_rate=5e-05, batch_size=16**