# Implementing Classical Vision Transformer on MNIST dataset

## For this task, we would be implementing the ViT architecture from 'AN IMAGE IS WORTH 16X16 WORDS' paper' ( arXiv:2010.11929v2)

In [1]:
!pip install torch torchvision
!pip install einops

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.6/41.6 KB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.6.0


In [2]:
import torch
import torch.nn.functional as F

from torch import nn
from einops import rearrange

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """
        Forward pass for Residual layer.

        Args:
        - x: input tensor.
        - **kwargs: keyword arguments.

        Returns:
        - output tensor of the Residual layer.
        """
        
        
        return self.fn(x, **kwargs) + x

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """
        Forward pass for PreNorm layer.

        Args:
        - x: input tensor.
        - **kwargs: keyword arguments.

        Returns:
        - output tensor of the PreNorm layer.
        """
        
        
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim)
        )

    def forward(self, x):
        """
        Forward pass for FeedForward layer.

        Args:
        - x: input tensor.

        Returns:
        - output tensor of the FeedForward layer.
        """
        
        
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads=8):
        super().__init__()
        self.heads = heads
        self.scale = dim ** -0.5

        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x, mask = None):
        """
        Forward pass for Attention layer.

        Args:
        - x: input tensor.
        - mask: optional mask tensor.

        Returns:
        - output tensor of the Attention layer.
        """
        
        
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=h)
        # q,k and v representing queries, keys and values respectively
        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale

        if mask is not None:
            mask = F.pad(mask.flatten(1), (1, 0), value = True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            mask = mask[:, None, :] * mask[:, :, None]
            dots.masked_fill_(~mask, float('-inf'))
            del mask

        attn = dots.softmax(dim=-1)

        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out =  self.to_out(out)
        return out

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, mlp_dim):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads = heads))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim)))
            ]))

    def forward(self, x, mask=None):
        for attn, ff in self.layers:
            x = attn(x, mask=mask)
            x = ff(x)
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels=3):
        super().__init__()
        assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2

        self.patch_size = patch_size

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.transformer = Transformer(dim, depth, heads, mlp_dim)

        self.to_cls_token = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, num_classes)
        )

    def forward(self, img, mask=None):
        p = self.patch_size

        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
        x = self.patch_to_embedding(x)

        cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding
        x = self.transformer(x, mask)

        x = self.to_cls_token(x[:, 0])
        return self.mlp_head(x)

In [3]:
def train_epoch(model, optimizer, data_loader, loss_history):
    total_samples = len(data_loader.dataset)
    model.train()

    for i, (data, target) in enumerate(data_loader):
        optimizer.zero_grad()
        output = F.log_softmax(model(data), dim=1)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print('[' +  '{:5}'.format(i * len(data)) + '/' + '{:5}'.format(total_samples) +
                  ' (' + '{:3.0f}'.format(100 * i / len(data_loader)) + '%)]  Loss: ' +
                  '{:6.4f}'.format(loss.item()))
            loss_history.append(loss.item())

In [4]:
import torch
import torchvision

torch.manual_seed(42)

DOWNLOAD_PATH = '/data/mnist'
BATCH_SIZE_TRAIN = 100
BATCH_SIZE_TEST = 1000

transform_mnist = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize((0.1307,), (0.3081,))])

train_set = torchvision.datasets.MNIST(DOWNLOAD_PATH, train=True, download=True,
                                       transform=transform_mnist)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE_TRAIN, shuffle=True)

test_set = torchvision.datasets.MNIST(DOWNLOAD_PATH, train=False, download=True,
                                      transform=transform_mnist)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE_TEST, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /data/mnist/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting /data/mnist/MNIST/raw/train-images-idx3-ubyte.gz to /data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting /data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to /data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting /data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to /data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting /data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to /data/mnist/MNIST/raw



In [5]:
def evaluate(model, data_loader, loss_history):
    model.eval()
    
    total_samples = len(data_loader.dataset)
    correct_samples = 0
    total_loss = 0

    with torch.no_grad():
        for data, target in data_loader:
            output = F.log_softmax(model(data), dim=1)
            loss = F.nll_loss(output, target, reduction='sum')
            _, pred = torch.max(output, dim=1)
            
            total_loss += loss.item()
            correct_samples += pred.eq(target).sum()

    avg_loss = total_loss / total_samples
    loss_history.append(avg_loss)
    print('\nAverage test loss: ' + '{:.4f}'.format(avg_loss) +
          '  Accuracy:' + '{:5}'.format(correct_samples) + '/' +
          '{:5}'.format(total_samples) + ' (' +
          '{:4.2f}'.format(100.0 * correct_samples / total_samples) + '%)\n')

In [6]:
import time
N_EPOCHS = 25

start_time = time.time()
model = ViT(image_size=28, patch_size=7, num_classes=10, channels=1,
            dim=64, depth=6, heads=8, mlp_dim=128)
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)

train_loss_history, test_loss_history = [], []
for epoch in range(1, N_EPOCHS + 1):
    print('Epoch:', epoch)
    train_epoch(model, optimizer, train_loader, train_loss_history)
    evaluate(model, test_loader, test_loss_history)

print('Execution time:', '{:5.2f}'.format(time.time() - start_time), 'seconds')

Epoch: 1

Average test loss: 0.1433  Accuracy: 9537/10000 (95.37%)

Epoch: 2

Average test loss: 0.1115  Accuracy: 9663/10000 (96.63%)

Epoch: 3

Average test loss: 0.0978  Accuracy: 9720/10000 (97.20%)

Epoch: 4

Average test loss: 0.1086  Accuracy: 9668/10000 (96.68%)

Epoch: 5

Average test loss: 0.0898  Accuracy: 9721/10000 (97.21%)

Epoch: 6

Average test loss: 0.0946  Accuracy: 9732/10000 (97.32%)

Epoch: 7

Average test loss: 0.0826  Accuracy: 9754/10000 (97.54%)

Epoch: 8

Average test loss: 0.0807  Accuracy: 9765/10000 (97.65%)

Epoch: 9

Average test loss: 0.0697  Accuracy: 9819/10000 (98.19%)

Epoch: 10

Average test loss: 0.0752  Accuracy: 9789/10000 (97.89%)

Epoch: 11

Average test loss: 0.1158  Accuracy: 9712/10000 (97.12%)

Epoch: 12

Average test loss: 0.0767  Accuracy: 9807/10000 (98.07%)

Epoch: 13

Average test loss: 0.0773  Accuracy: 9811/10000 (98.11%)

Epoch: 14

Average test loss: 0.0832  Accuracy: 9789/10000 (97.89%)

Epoch: 15

Average test loss: 0.0863  Accur

# Potential Quantum Vision Transformer Architecture

### The concept of a quantum vision transformer (QVT) is an exciting prospect in the field of quantum machine learning. However,like all quantum machine learning algorithms, we lack in the computational power to realize these models to their fullest for large datasets. 

### To make a quantum vision transformer,we discuss about how the layers of a classical vision transformer can be converted to their corresponding quantum counterparts (layers like convolutional,attention.etc.. and also the feature maps). The main idea to map these concepts to quantum machine learning architectures is to use quantum circuits and quantum gates and manipulate them to mimic the classical layers on quantum data.

### In classical vision transformers, the input image is broken down into patches and then processed using linear self-attention layers. For a Quantum Vision Transformer, these layers could be implemented using quantum circuits that mimic the behavior of classical convolutional filters. In a quantum convolutional layer, the input image is first encoded as a quantum state. This encoding can be achieved using various quantum encoding techniques, such as amplitude encoding (which was implemented in previous tasks) or angle encoding. The encoded image is then processed using a quantum circuit that applies a set of quantum gates, such as rotation gates, controlled-NOT (CNOT) gates, and controlled-phase (CPHASE) gates, to perform convolutional operations on the input image. For example,  a rotation gate can be used to apply a filter that detects edges in the image, a CNOT gate can be used to apply a 3x3 kernel to the input image, where the kernel weights are encoded as the amplitudes of the control qubits and a CPHASE gate can be used to apply a max-pooling operation to the input image, where the maximum amplitude of the target qubit is selected as the output of the pooling operation.

### In a QVT, one could use quantum feature maps to extract features from images. These feature maps could be implemented using quantum circuits that map the input image to a higher-dimensional Hilbert space. This would allow the QVT to extract more complex features from the input image.Quantum self-attention could be implemented using quantum circuits that perform unitary transformations on the quantum states. Similiar to classical self-attention layers, in quantum self attention, the input quantum state is first transformed into a set of query, key, and value quantum states using unitary transformations. The dot product of the query and key quantum states is then computed, and the result is used to weight the value quantum states. The output quantum state is obtained by applying another unitary transformation to the weighted values. These unitary transormations can once again be done by gates like single qubit rotations. The output obtained could be processed to classical output and passed through a FFN to make predictions. (Architecture is made assuming we are working on MNIST classification problems. The post-processing would could also be used for applications like object detection,where the post processing would be different)

## Overview of QVT Architecture:-

### 1) Classical preprocessor : For extracting initial features of the image
### 2) QConv Layer using quantum circuits or gates to perform convolution operation on these features to obtain the feature map.
### 3) Quantum Self-Attention Layer : To capture the relationship between different parts of an input quantum state.
### 4) Post-processing output of QVT into classical information flattened.
### 5) Fully-connected network (FFN) : Flattened array taken as input into the feed forward network to recognize different features in the image and make a prediction about the object in the image.