# Assignment: Vision Transformers on CIFAR10

In [1]:
#imports
from __future__ import print_function
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils


In [14]:
#loading the dataset
dataset = dset.CIFAR10(root="./data", download=True,
                           transform=transforms.Compose([
                               transforms.Resize(224),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
nc=3

dataloader = torch.utils.data.DataLoader(dataset, batch_size=128,
                                         shuffle=True, num_workers=2)


In [3]:
#checking the availability of cuda devices
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Tasks:
* try to get the best test Accuracy on Cifar10 using a transformer model
* pre-trained models allowed - see [here](https://docs.pytorch.org/vision/main/models/vision_transformer.html) for list of models in TorchVision
* **hint**: just like with the CNN in Week 5 - wee need to change the classification layer to fit our 10 class CIFAR-10 problem before we can fine-tune it...
* **hint**: Transformers need a lot of compute + memory - use the A100 GPU



In [7]:
from torchvision.models import vit_b_16, ViT_B_16_Weights

In [11]:
# Define Vision Transformer Model
weights = ViT_B_16_Weights.DEFAULT
model = vit_b_16(weights=weights)
model.heads.head = nn.Linear(model.heads.head.in_features, 10)
model = model.to(device)

# Loss Function and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

In [10]:
def train_model(model, train_loader, criterion, optimizer, device, epochs=10):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        correct, total = 0, 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        print(f'Epoch [{epoch + 1}/{epochs}], Loss: {running_loss / len(train_loader)}, Accuracy: {100 * correct / total:.2f}%')

In [12]:
def evaluate_model(model, test_loader, device):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Accuracy on test data: {100 * correct / total:.2f}%')


In [15]:

# Train and Evaluate the Model
train_model(model, dataloader, criterion, optimizer, device, epochs=5)
evaluate_model(model, dataloader, device)

Epoch [1/1], Loss: 2.2075450789288182, Accuracy: 17.24%
Accuracy on test data: 16.52%
