#### Task VIII: Vision Transformer ####

This task implements a simple Vision Transformer (ViT) that classifies MNIST handwritten digits into the ten digit classes. The images are downscaled from their native 28x28 to 16x16 pixels.

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np

## Data Extraction ##
The data is loaded in much the same way as for the diffusion model. The difference is that instead of a single dataset and data loader, I use MNIST's separate training and test splits, each with its own loader. This way the model is evaluated on images it never saw during training.

In [4]:
transform = transforms.Compose([
    transforms.Resize((16, 16)),
    transforms.ToTensor(),
])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

## The Model ## 
The model architecture largely follows this website: https://itp.uni-frankfurt.de/~gros/StudentProjects/WS22_23_VisualTransformer/

They use TensorFlow, but I chose PyTorch to be consistent with the earlier tasks. I omitted some of their steps, such as data augmentation (random transforms, in PyTorch terms). MNIST is already fairly large, and I wanted training to be faster, even at some cost in accuracy. I also simplified the architecture considerably, because an implementation closer to the website's was taking a long time to run.

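In the end, the model consists only of a `PatchEmbedding` module feeding a `VisionTransformer`. A strided convolution cuts each 16x16 image into sixteen 4x4 patches and projects each patch to a 96-dimensional token. A learnable class token is prepended, the sequence passes through a stack of PyTorch transformer encoder layers, and a small head classifies the class-token output. There is no positional embedding, so the encoder is given no information about where each patch came from.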
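The strided convolution in `PatchEmbedding` is a compact way of writing the usual ViT patch projection. With `kernel_size` and `stride` both equal to the patch size, each output position sees exactly one non-overlapping 4x4 patch. The convolution is therefore the same as flattening every patch and multiplying it by a shared weight matrix. The sketch below checks this numerically with `torch.nn.functional.unfold` on random input. For a 16x16 image and 4x4 patches it yields (16/4)^2 = 16 tokens of size 96, and the class token then brings the sequence length to 17.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
patch_size, embed_dim = 4, 96
conv = nn.Conv2d(1, embed_dim, kernel_size=patch_size, stride=patch_size)
x = torch.randn(2, 1, 16, 16)  # two random 16x16 single-channel "images"

# Conv path, as in PatchEmbedding.forward: (B, E, 4, 4) -> (B, 16, E)
conv_tokens = conv(x).flatten(2).transpose(1, 2)

# Explicit path: cut out non-overlapping 4x4 patches, flatten each one,
# and apply the convolution's weights as an ordinary linear layer.
patches = F.unfold(x, kernel_size=patch_size, stride=patch_size)  # (B, C*p*p, num_patches)
patches = patches.transpose(1, 2)                                 # (B, num_patches, C*p*p)
weight = conv.weight.reshape(embed_dim, -1)                       # (E, C*p*p)
linear_tokens = patches @ weight.T + conv.bias                    # (B, num_patches, E)

print(conv_tokens.shape)
print(torch.allclose(conv_tokens, linear_tokens, atol=1e-5))
```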

In [None]:
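class PatchEmbedding(nn.Module):
    """Splits an image into non-overlapping patches and projects each patch to an embed_dim vector."""
    def __init__(self, image_size=16, patch_size=4, in_channels=1, embed_dim=96):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        # kernel_size == stride == patch_size: each output position covers exactly one patch
        self.projection = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.projection(x)              # (B, embed_dim, H/patch, W/patch)
        x = x.flatten(2).transpose(1, 2)    # (B, num_patches, embed_dim)
        return x

class VisionTransformer(nn.Module):
    def __init__(self, image_size=16, patch_size=4, num_classes=10, embed_dim=96, num_heads=4, mlp_dim=2048, num_layers=16):
        super().__init__()
        # Pass embed_dim through so the patch tokens always match the encoder width
        self.patch_embedding = PatchEmbedding(image_size, patch_size, embed_dim=embed_dim)
        # Learnable class token, prepended to every image's patch sequence
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        # NOTE: batch_first is left at its default (False), so PyTorch expects
        # (tokens, batch, embed_dim) input here, not the (batch, tokens, embed_dim)
        # tensor built in forward().
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=mlp_dim),
            num_layers=num_layers
        )
        # Classification head applied to the class-token output
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )

    def forward(self, x):
        B = x.size(0)
        x = self.patch_embedding(x)                    # (B, 16, embed_dim) for 16x16 inputs
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, embed_dim)
        x = torch.cat((cls_tokens, x), dim=1)          # (B, 17, embed_dim); no positional embedding
        x = self.transformer_encoder(x)
        x = x[:, 0]                                    # class-token position
        x = self.mlp_head(x)                           # (B, num_classes) logits
        return x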
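The model above builds a `(batch, tokens, embed_dim)` tensor but leaves `nn.TransformerEncoderLayer` at its default `batch_first=False`. Under that default, PyTorch expects `(tokens, batch, embed_dim)`. The model also has no positional embedding. A minimal sketch of a revised model that addresses both issues follows. The class names `PatchEmbeddingSketch` and `ViTSketch`, the learnable positional embedding, and the smaller `mlp_dim=256` / `num_layers=4` are assumptions for illustration, and how accurately this version trains on MNIST is not established here.

```python
import torch
import torch.nn as nn


class PatchEmbeddingSketch(nn.Module):
    """Non-overlapping patch projection, same idea as PatchEmbedding above."""

    def __init__(self, patch_size=4, in_channels=1, embed_dim=96):
        super().__init__()
        self.projection = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        return self.projection(x).flatten(2).transpose(1, 2)  # (B, num_patches, E)


class ViTSketch(nn.Module):
    """Hypothetical revision: batch-first encoder plus a learnable positional embedding."""

    def __init__(self, image_size=16, patch_size=4, in_channels=1, num_classes=10,
                 embed_dim=96, num_heads=4, mlp_dim=256, num_layers=4):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        self.patch_embedding = PatchEmbeddingSketch(patch_size, in_channels, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One learnable position vector per token (class token + patches)
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim) * 0.02)
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=mlp_dim,
            batch_first=True,  # inputs are (batch, tokens, embed_dim)
        )
        self.transformer_encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.mlp_head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, num_classes))

    def forward(self, x):
        x = self.patch_embedding(x)                                  # (B, 16, E)
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)        # (B, 1, E)
        x = torch.cat((cls_tokens, x), dim=1) + self.pos_embedding   # (B, 17, E)
        x = self.transformer_encoder(x)
        return self.mlp_head(x[:, 0])                                # logits from the class token


model = ViTSketch()
model.eval()
with torch.no_grad():
    logits = model(torch.rand(8, 1, 16, 16))  # 8 random 16x16 images
print(logits.shape)
```

Because attention now runs over the 17 tokens of each image, different images produce different class-token outputs and therefore different logits.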

## Training ## 
I scaled training down to 10 epochs (the original code used 100). The test function only reports the classification accuracy on the test set.
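A cheap sanity check before a full training run is to train on one fixed batch for a few hundred steps. A model that actually uses its input should be able to memorise a handful of examples and push the loss far below ln 10 ≈ 2.303. A model whose output does not depend on the image will stay near that value. The helper `overfit_one_batch`, the toy linear model, and the random data below are hypothetical, for illustration. In this notebook one would pass `VisionTransformer()` and a single batch from `train_loader` instead.

```python
import math

import torch
import torch.nn as nn


def overfit_one_batch(model, inputs, labels, steps=300, lr=1e-2):
    """Train repeatedly on a single fixed batch and return (first_loss, last_loss)."""
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    first_loss = None
    for _ in range(steps):
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        if first_loss is None:
            first_loss = loss.item()
    return first_loss, loss.item()


torch.manual_seed(0)
inputs = torch.rand(32, 1, 16, 16)             # stand-in for one batch of 16x16 images
labels = torch.randint(0, 10, (32,))           # random digit labels
tiny_model = nn.Sequential(nn.Flatten(), nn.Linear(16 * 16, 10))

first, last = overfit_one_batch(tiny_model, inputs, labels)
print(f"first loss {first:.3f}, last loss {last:.3f}, ln(10) = {math.log(10):.3f}")
```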

In [1]:
model = VisionTransformer()

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.0001)

def train(model, train_loader, criterion, optimizer, num_epochs=10):
    model.train()  # enables dropout in the encoder layers
    for epoch in range(num_epochs):
        print(f"Epoch [{epoch+1}/{num_epochs}]")
        running_loss = 0.0
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)              # (batch, 10) logits
            loss = criterion(outputs, labels)    # mean cross-entropy over the batch
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch average is per sample
            running_loss += loss.item() * inputs.size(0)
        epoch_loss = running_loss / len(train_loader.dataset)
        print(f'Loss: {epoch_loss:.4f}')

def test(model, test_loader):
    model.eval()  # disables dropout
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs)
            predicted = outputs.argmax(dim=1)    # predicted digit per image
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    print(f'Test Accuracy: {accuracy:.4f}')



## Result ##
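The test accuracy is 11.35%, barely above random guessing (10%). The loss stays at about 2.30 for the whole run, which is essentially ln 10 ≈ 2.3026, the cross-entropy of a uniform prediction over ten classes. Moreover, 11.35% is exactly the share of 1s in the MNIST test set (1,135 of 10,000 images). Together, these suggest the model predicts the same digit for every image, i.e. it is not using the input at all.

Scaling down the model and the number of epochs may cost some accuracy, but it would not by itself explain a completely flat loss. A more likely cause is the encoder's input layout. `nn.TransformerEncoderLayer` defaults to `batch_first=False`, so it reads the `(batch, tokens, embed_dim)` tensor built in `forward()` as `(tokens, batch, embed_dim)`. Attention then runs across the 64 images of a batch at each token position, instead of across the patches of a single image. The class token is identical for every image, so its output never receives any information from the patches, and the head classifies every input the same way. Setting `batch_first=True` and adding a positional embedding would be the first things to try.

I didn't have much time to research ViT techniques, but I think it's an interesting topic to explore with QML. To make this model quantum, I would consider implementing a quantum transformer within the vision transformer architecture. Unfortunately, I didn't get to experiment with the self-attention mechanism in this model, but that is potentially where a quantum advantage could come in.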
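`nn.TransformerEncoderLayer` defaults to `batch_first=False`, so an encoder built without that flag interprets a `(batch, tokens, embed_dim)` input as `(tokens, batch, embed_dim)`. The sketch below isolates the effect with random tokens. It feeds the same tensor to an encoder with the default setting and to one with `batch_first=True`, then compares the rows at index `[:, 0]`, which is what `VisionTransformer.forward` uses as the class-token output. The sizes (8 images, 16 patches, 2 layers) are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, num_patches, E = 8, 16, 96

# Identical class token for every image, random "patch embeddings" per image
cls_tokens = torch.randn(1, 1, E).expand(B, -1, -1)
patch_tokens = torch.randn(B, num_patches, E)
tokens = torch.cat((cls_tokens, patch_tokens), dim=1)  # (B, 17, E), as built in forward()


def class_token_outputs(batch_first):
    torch.manual_seed(1)  # same initial weights for both settings
    layer = nn.TransformerEncoderLayer(d_model=E, nhead=4, batch_first=batch_first)
    encoder = nn.TransformerEncoder(layer, num_layers=2)
    encoder.eval()  # disable dropout so the comparison is deterministic
    with torch.no_grad():
        out = encoder(tokens)
    return out[:, 0]  # same indexing as VisionTransformer.forward: one row per image


spreads = {}
for batch_first in (False, True):
    rows = class_token_outputs(batch_first)
    spreads[batch_first] = (rows - rows[0]).abs().max().item()
    print(f"batch_first={batch_first}: max difference between images = {spreads[batch_first]:.2e}")
```

With the default layout, every image gets the same class-token output, up to floating-point rounding. With `batch_first=True`, the outputs differ from image to image.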

In [2]:
# Train and test the model
train(model, train_loader, criterion, optimizer)
test(model, test_loader)


Epoch [1/10]
Loss: 2.3173
Epoch [2/10]
Loss: 2.3057
Epoch [3/10]
Loss: 2.3036
Epoch [4/10]
Loss: 2.3028
Epoch [5/10]
Loss: 2.3026
Epoch [6/10]
Loss: 2.3023
Epoch [7/10]
Loss: 2.3022
Epoch [8/10]
Loss: 2.3019
Epoch [9/10]
Loss: 2.3019
Epoch [10/10]
Loss: 2.3018
Test Accuracy: 0.1135
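These numbers closely match what a model that ignores its input would score. The sketch below computes that baseline from a label tensor. It returns the most frequent class, the accuracy of always predicting it, and ln(num_classes), which is the cross-entropy of a uniform prediction. The function name `constant_prediction_baseline` is hypothetical.

```python
import math

import torch


def constant_prediction_baseline(labels, num_classes=10):
    """Scores of a model that ignores its input.

    Returns the most frequent class, the accuracy of always predicting it,
    and the cross-entropy of a uniform prediction, ln(num_classes).
    """
    counts = torch.bincount(labels, minlength=num_classes)
    majority_class = int(counts.argmax())
    accuracy = counts[majority_class].item() / labels.numel()
    uniform_loss = math.log(num_classes)
    return majority_class, accuracy, uniform_loss


# Toy labels: 4 of 10 examples are class 1
toy_labels = torch.tensor([1, 1, 1, 1, 0, 2, 3, 4, 5, 6])
result = constant_prediction_baseline(toy_labels)
print(result)
```

Applied to this notebook's test set as `constant_prediction_baseline(test_dataset.targets)`, it should return digit 1 with an accuracy of 0.1135, because the MNIST test split contains 1,135 images of the digit 1. It should also return a uniform-prediction loss of about 2.3026, in line with the recorded run.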
