### 1. We implement the Vision Transformer (ViT) in the following steps:

1. Patch embedding
<br> The inputs are 2D images, while self-attention operates on 1D sequences. So we divide every input image into non-overlapping patches, flatten them, and embed each patch as one token of a sequence.
2. Multi-Head Attention
<br> The embedded sequence is passed through a self-attention mechanism, which establishes the attention (dependencies) between the tokens of the sequence.
3. Build the Transformer block, which contains the Multi-Head Attention, layer normalization, and an MLP with a GELU activation
4. Build the Encoder class, which contains multiple Transformer blocks
5. Finally, build the ViT model from the patch embedding and the Encoder.

### 2. We test our model on CIFAR-10 and evaluate its performance.

In [14]:
# In this tutorial, we use einops for tensor rearrangement (reshaping and permuting axes)
# pip install einops

In [None]:
# Import libraries
import torch
from torch import nn
import math
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import torch.optim as optim
from torchsummary import summary
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
import os, csv, time
import matplotlib.pyplot as plt
import numpy as np

import collections
try:
    from collections import OrderedDict
except ImportError:
    OrderedDict = dict

### Patch embedding

In [None]:
class Patching(nn.Module):
    """Split images into non-overlapping patches and embed them as a token sequence.

    The returned sequence is ``[class_token, patch_1, ..., patch_N]`` with a
    learnable positional embedding added. Its shape is
    ``[batch, (img_size // patch_size) ** 2 + 1, embed_size]``.

    Args:
        in_channels: Number of channels of the input images.
        img_size: Height/width of the square input images. It must be
            divisible by ``patch_size``.
        patch_size: Height/width of each square patch.
        embed_size: Dimension of every output token. Each flattened patch has
            length ``in_channels * patch_size ** 2`` and is linearly projected
            to ``embed_size``. The two values no longer need to be equal. When
            they are equal, the parameter shapes match the original design.

    Raises:
        ValueError: If ``img_size`` is not divisible by ``patch_size``.
    """

    def __init__(self, in_channels=3, img_size=224, patch_size=16, embed_size=768):
        super(Patching, self).__init__()
        if img_size % patch_size != 0:
            raise ValueError(
                f"img_size ({img_size}) must be divisible by patch_size ({patch_size})."
            )

        self.patch_size = patch_size
        self.num_path = int(img_size // patch_size) ** 2  # number of patches
        # Length of one flattened patch: the "(p1 p2 c)" axis produced below.
        patch_dim = in_channels * patch_size ** 2
        self.projection = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, embed_size),
            nn.LayerNorm(embed_size),
        )

        self.class_token = nn.Parameter(torch.randn(1, 1, embed_size))
        # One positional vector per patch plus one for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(self.num_path + 1, embed_size))

    def forward(self, x):
        """Embed a ``[B, C, H, W]`` image batch into ``[B, num_patches + 1, embed_size]``."""
        batch_size = x.shape[0]
        tokens = self.projection(x)  # [B, num_patches, embed_size]
        # Share the single learnable class token across the whole batch.
        class_token = self.class_token.expand(batch_size, -1, -1)
        tokens = torch.cat([class_token, tokens], dim=1)
        # Add the position embedding; it broadcasts over the batch dimension.
        return tokens + self.pos_embedding

# x = torch.rand(1,3,224,224)   
# Patching()(x).shape
# model = Patching()
# summary(model, (3,224,224))

### Multi-Head Attention

In [None]:

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Input and output both have shape ``[batch, seq_len, embed_size]``.

    Args:
        embed_size: Token dimension. It must be divisible by ``num_heads``.
        num_heads: Number of attention heads. Each head works on
            ``embed_size // num_heads`` features.
        dropout: Dropout probability applied to the softmax attention weights.

    Raises:
        ValueError: If ``embed_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_size, num_heads, dropout=0):
        super(MultiHeadAttention, self).__init__()
        # Use an explicit exception: `assert` is stripped under `python -O`.
        if embed_size % num_heads != 0:
            raise ValueError("embed_size % num_heads should be zero.")

        self.emb_size = embed_size
        self.num_heads = num_heads
        self.head_size = embed_size // num_heads

        # Projection matrices W_k, W_q and W_v for all heads at once.
        # Creation order (keys, queries, values) is kept so seeded init is unchanged.
        self.keys = nn.Linear(embed_size, self.head_size * num_heads)
        self.queries = nn.Linear(embed_size, self.head_size * num_heads)
        self.values = nn.Linear(embed_size, self.head_size * num_heads)

        self.att_drop = nn.Dropout(dropout)
        # Output projection mixing the concatenated heads.
        self.dense = nn.Linear(embed_size, embed_size)

    def _split_heads(self, t, batch_size):
        """Reshape ``[B, n, h * s]`` into ``[B, h, n, s]``."""
        return t.view(batch_size, -1, self.num_heads, self.head_size).transpose(1, 2)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``[B, seq_len, embed_size]``."""
        batch_size, seq_len, _ = x.shape
        query = self._split_heads(self.queries(x), batch_size)  # [B, h, n, s]
        key = self._split_heads(self.keys(x), batch_size)       # [B, h, n, s]
        value = self._split_heads(self.values(x), batch_size)   # [B, h, n, s]

        # Q x K^T scaled by sqrt(head_size): [B, h, n, n]
        score = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_size)
        weights = F.softmax(score, dim=-1)
        # Standard attention dropout: drop attention probabilities, not the context.
        weights = self.att_drop(weights)
        context = torch.matmul(weights, value)  # [B, h, n, s]

        # Back to [B, n, h * s] with heads concatenated along the feature axis.
        context = context.transpose(1, 2).reshape(batch_size, seq_len, self.num_heads * self.head_size)
        return self.dense(context)  # [B, seq_len, embed_size]
  

# x = torch.rand(1,4 ,32)
# attention = MultiHeadAttention( embed_size=32, num_heads=2)
# summary(attention, (4, 32))

'''
with embed_size=32, num_heads=2
[1,4,32] x W (linear)---> [1,4,32] ---(divide into 2 heads)----> [1 2 4 16] shape of Q, K, V
Scores = Q x K^T / sqrt(16): [1 2 4 16] x [1 2 16 4] = [1 2 4 4], then softmax over the last dimension
attention = [1 2 4 4] x [1 2 4 16]---> [1 2 4 16] ---rearrange---> [1, 4, 32] ---(dense)---> [1, 4, 32]
'''

### Transformer Block

In [None]:
class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder block with two residual sub-layers.

    ``out = x + Drop(MHA(LN(x)))`` followed by ``out = out + Drop(MLP(LN(out)))``.

    Args:
        embed_size: Token dimension.
        num_heads: Number of attention heads.
        expansion: Width multiplier of the MLP hidden layer
            (hidden size = ``expansion * embed_size``).
        dropout: Dropout probability used on both residual branches and
            inside the MLP.
        attn_dropout: Passed to ``MultiHeadAttention`` as its ``dropout``
            argument. The default 0 matches that module's own default, so
            existing callers get the same model as before.
    """

    def __init__(self, embed_size, num_heads, expansion, dropout=0, attn_dropout=0):
        super(TransformerBlock, self).__init__()

        self.norm1 = nn.LayerNorm(embed_size)
        self.mul_attention = MultiHeadAttention(embed_size, num_heads, dropout=attn_dropout)
        self.drop = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(embed_size)
        self.mlp = nn.Sequential(
            nn.Linear(embed_size, expansion * embed_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(expansion * embed_size, embed_size),
        )

    def forward(self, x):
        """Apply the attention sub-layer, then the MLP sub-layer, each with a residual add."""
        out = x + self.drop(self.mul_attention(self.norm1(x)))  # attention sub-layer
        out = out + self.drop(self.mlp(self.norm2(out)))        # feed-forward sub-layer
        return out

# x = torch.rand(1,4 ,32)
# block = TransformerBlock(embed_size =32, num_heads=2, expansion=2)
# print(block(x).shape)
# summary(block, (4, 32))


### Encoder block

In [26]:

class Encoder(nn.Module):
    """Apply ``depth`` TransformerBlock layers one after another.

    The blocks live in ``self.layers`` (an ``nn.Sequential``) under the names
    ``encoder_layer_0`` ... ``encoder_layer_{depth - 1}``.
    """

    def __init__(self, embed_size, num_heads, expansion, dropout, depth):
        super(Encoder, self).__init__()
        named_blocks = OrderedDict(
            (f"encoder_layer_{idx}", TransformerBlock(embed_size, num_heads, expansion, dropout))
            for idx in range(depth)
        )
        self.layers = nn.Sequential(named_blocks)

    def forward(self, x):
        """Feed ``x`` through every block in order and return the final output."""
        return self.layers(x)

# x = torch.rand(1,4 ,32)   
# encoder = Encoder(embed_size=32, num_heads=2, expansion=2, dropout=0.2, depth=2)
# print(encoder)
# print(encoder(x).shape)
# summary(encoder, (4, 32)) 

### Vision Transformer (ViT)

In [None]:
class VIT(nn.Module):
    """Vision Transformer image classifier.

    Pipeline: patch embedding -> Transformer encoder -> mean over all tokens
    (class token included) -> LayerNorm -> linear layer producing
    ``num_classes`` logits. The defaults describe 32x32 three-channel inputs
    split into 4x4 patches, with 10 output classes.
    """

    def __init__(self, in_channels=3, img_size=32, patch_size=4, embed_size=48,
                 num_heads=2, expansion=4, dropout=0.2, depth=4, num_classes=10):
        super(VIT, self).__init__()
        # Attribute names are part of the state_dict keys, so they stay as-is.
        self.path_embedding = Patching(in_channels, img_size, patch_size, embed_size)
        self.encoder = Encoder(embed_size, num_heads, expansion, dropout, depth)
        token_mean = Reduce('b n e -> b e', reduction='mean')
        head_norm = nn.LayerNorm(embed_size)
        classifier = nn.Linear(embed_size, num_classes)
        self.num_class = nn.Sequential(token_mean, head_norm, classifier)

    def forward(self, x):
        """Map an image batch ``[B, C, H, W]`` to class logits ``[B, num_classes]``."""
        tokens = self.path_embedding(x)
        encoded = self.encoder(tokens)
        return self.num_class(encoded)


model = VIT(in_channels=3, img_size=32, patch_size=4,
            embed_size=48, num_heads=2, expansion=4,
            dropout=0.2, depth=4, num_classes=10)

# Quick forward-pass sanity check on one random 32x32 RGB image.
x = torch.rand(1, 3, 32, 32)
with torch.no_grad():
    print(f"Output shape for a dummy input: {tuple(model(x).shape)}")

# torchsummary's `summary` defaults to device="cuda", but `model` was just built
# on the CPU. Request a CPU summary so this also works on machines without CUDA.
summary(model, (3, 32, 32), device="cpu")


### The ViT is now complete, so let's train and test it on the CIFAR-10 dataset

In [None]:
# Pick the compute backend: CUDA if present, otherwise Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using: {device} to train the models")

net = VIT().to(device)
final_epoch = 2000  # not read by the training loop in this notebook
batch_size = 32

# Dataset: convert to tensors in [0, 1], then map each RGB channel to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset, testset = (
    torchvision.datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

train_dataloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
test_dataloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

### Data visualization

In [None]:
# Grab one shuffled training mini-batch and print the image and label tensor shapes.
image, label = next(iter(train_dataloader))
print(*(tensor.shape for tensor in (image, label)))

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def imshow(img):
    """Display a ``[C, H, W]`` image tensor with matplotlib.

    Assumes the per-channel ``Normalize((0.5, ...), (0.5, ...))`` used for the
    datasets in this notebook, i.e. pixel values in [-1, 1]. The tensor is
    detached and moved to the CPU first, so GPU tensors and tensors that
    require grad also work. Values are clamped to [0, 1], the float range
    matplotlib displays.
    """
    img = img.detach().cpu() / 2 + 0.5  # unnormalize: [-1, 1] -> [0, 1]
    npimg = img.clamp(0, 1).numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # CHW -> HWC for matplotlib
    plt.show()

# get some random training images
dataiter = iter(train_dataloader)
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images))

# print one label per displayed image. Iterating `labels` (instead of
# range(batch_size)) stays correct when a batch holds fewer items.
print(' '.join(f'{classes[label]:5s}' for label in labels))

### Train model

In [None]:
import torch.optim as optim

# Cross-entropy on the class logits, minimized with momentum SGD.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=net.parameters(), lr=0.001, momentum=0.9)

In [None]:
num_epochs = 3
# Mini-batches between progress reports. This must be smaller than the number
# of batches per epoch (len(train_dataloader)); otherwise nothing is ever
# printed. With batch_size = 32, CIFAR-10 has 1563 training batches per epoch.
log_every = 200

net.train()  # make sure dropout is active while training
for epoch in range(num_epochs):  # loop over the dataset multiple times

    running_loss = 0.0
    images_seen = 0
    for i, data in enumerate(train_dataloader):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].to(device), data[1].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics: mean loss over the last `log_every` mini-batches
        running_loss += loss.item()
        images_seen += inputs.size(0)
        if (i + 1) % log_every == 0:
            print(f'[Epoch: {epoch + 1}, Number of images: {images_seen:5d}] '
                  f'loss: {running_loss / log_every:.3f}')
            running_loss = 0.0

print('Finished Training')

### Save the model and run it on the test dataset

In [None]:
PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)
dataiter = iter(test_dataloader)
images, labels = next(dataiter)

# Show only the first 4 test images so the picture matches the 4 labels
# printed below (a full test batch holds batch_size images).
num_shown = 4
imshow(torchvision.utils.make_grid(images[:num_shown]))
print('GroundTruth: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(num_shown)))

In [None]:
net = VIT()
# Inference mode: disable the dropout layers (VIT uses dropout=0.2 by default).
net.eval()
# The checkpoint may have been saved from a CUDA/MPS model. Map it onto the CPU,
# where this freshly built model and the test tensors live.
net.load_state_dict(torch.load(PATH, map_location="cpu"))

In [None]:
net.eval()  # make sure dropout is disabled for prediction
# Inference only: no_grad avoids building the autograd graph.
with torch.no_grad():
    outputs = net(images)
_, predicted = torch.max(outputs, 1)

print('Predicted: ', '  '.join(f'{classes[predicted[j]]:5s}'
                              for j in range(4)))

### Evaluate on all images of the test dataset

In [None]:
net.eval()  # disable dropout so the evaluation is deterministic
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_dataloader:
        # calculate outputs by running images through the network
        outputs = net(images)
        # the class with the highest score is what we choose as prediction
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# True division keeps the fractional part (`//` truncated it to an integer).
print(f'Accuracy of the network on the {total} test images: {100 * correct / total:.2f} %')

In [None]:
net.eval()  # disable dropout so the evaluation is deterministic
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

with torch.no_grad():
    for images, labels in test_dataloader:
        outputs = net(images)
        _, predictions = torch.max(outputs, 1)
        # collect the correct predictions for each class, comparing plain ints
        for label, prediction in zip(labels.tolist(), predictions.tolist()):
            classname = classes[label]
            total_pred[classname] += 1
            if label == prediction:
                correct_pred[classname] += 1


# print accuracy for each class (skip classes with no test images)
for classname, correct_count in correct_pred.items():
    seen = total_pred[classname]
    if seen == 0:
        print(f'Accuracy for class: {classname:5s} is n/a (no test images)')
        continue
    accuracy = 100 * float(correct_count) / seen
    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')

### 