In [None]:
import torch
import torch.nn as n
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

In [None]:
import numpy as np
import matplotlib.pyplot as plt

**Split Image into Patches**

In [None]:
class patch_embedding():
  """Transform that cuts an image tensor into flattened square patches.

  Meant to run after the image has been converted to a tensor (it is the
  last step of the transform pipeline in this notebook).
  """

  def __init__(self, patch_size):
    """Remember the side length, in pixels, of the square patches."""
    self.patch_size = patch_size

  def __call__(self, img):
    """Split ``img`` into non-overlapping ``patch_size x patch_size`` patches.

    Args:
      img: 3-D tensor; dimensions 1 and 2 are the ones that get tiled
        (height and width for a ``(channels, height, width)`` image).

    Returns:
      2-D tensor of shape ``(num_patches, patch_size * patch_size)``.  Rows
      are the patches in row-major order (over all channels), each patch
      flattened row by row.
    """
    # Tile along dimension 1 first ...
    patches = img.unfold(1,self.patch_size,self.patch_size)
    # ... then tile that *result* along dimension 2.  Unfolding ``img`` a
    # second time would discard the first step and yield horizontal strips
    # of patch_size**2 pixels instead of square patches.
    patches = patches.unfold(2,self.patch_size,self.patch_size)
    patches = patches.reshape(-1,self.patch_size*self.patch_size)
    return patches

**Load Dataset**

In [None]:
# Side length, in pixels, of the square patches each image is cut into.
patch_size = 4

# Preprocessing pipeline: random augmentation, tensor conversion,
# normalisation with mean 0.5 / std 0.5, then patch splitting.
# NOTE(review): horizontal flips and rotations of up to 90 degrees can make
# one digit resemble another (e.g. 6 / 9) - confirm this augmentation
# strength is intended for MNIST.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop((28,28)),
    transforms.RandomRotation(90),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
    patch_embedding(patch_size),
])

In [None]:
# Load the training dataset (downloaded into ./data on first use); its
# samples go through the augmenting `transform` pipeline.
train_dataset = torchvision.datasets.MNIST(root='./data',
                               train=True,
                               transform=transform,
                               download=True)

# Evaluation has to be deterministic and measured on unmodified digits, so
# the test split gets the same tensor conversion / normalisation / patch
# splitting as training but none of the random augmentation steps.
test_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,)),
     patch_embedding(patch_size)])

# Load the test dataset
test_dataset = torchvision.datasets.MNIST(root='./data',
                              train=False,
                              transform=test_transform,
                              download=True)

In [None]:
# Mini-batches of 32 samples; only the training data is reshuffled each epoch.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

In [None]:
# Sanity check: push the patches of the first training sample (16 values
# per patch) through a throw-away 16 -> 512 projection and show the shape.
probe_layer = torch.nn.Linear(16, 512)
probe_layer(train_dataset[0][0]).shape

torch.Size([49, 512])

**Patch Embedding**

In [None]:
class ImageEmbeddings(n.Module):
  """Turn flattened patches into the token sequence consumed by the encoder.

  Every patch is linearly projected to ``hidden_size`` features, a learnable
  class token is placed in front of the sequence, a learnable position
  embedding is added and dropout (p=0.5) is applied.
  """

  def __init__(self, size,patch_size,hidden_size,num_patches):
    """Create the projection, class token, position embedding and dropout.

    Args:
      size: number of values in one flattened patch (input width of the
        projection).
      patch_size: accepted for interface compatibility; not used here.
      hidden_size: width of every token after projection.
      num_patches: patches per image; the position embedding holds
        ``num_patches + 1`` entries (one extra for the class token).
    """
    super().__init__()
    self.projection = n.Linear(in_features=size, out_features=hidden_size)
    # Both learnable tensors start from uniform [0, 1) samples.
    self.class_token = n.Parameter(torch.rand(1, hidden_size))
    self.position_embedding = n.Parameter(
        torch.rand(1, num_patches + 1, hidden_size))

    self.dropout = n.Dropout(p=0.5)

  def forward(self, x):
    """Embed a batch of patch sequences.

    Args:
      x: tensor of shape ``(batch, num_patches, size)``.

    Returns:
      Tensor of shape ``(batch, num_patches + 1, hidden_size)``; index 0
      along dimension 1 is the class token.
    """
    batch = x.shape[0]
    tokens = self.projection(x)
    # One view of the class token per sample: (batch, 1, hidden_size).
    cls = self.class_token.expand(batch, -1, -1)
    tokens = torch.cat((cls, tokens), dim=1)

    # (1, num_patches + 1, hidden_size) broadcasts over the batch dimension.
    tokens = tokens + self.position_embedding
    return self.dropout(tokens)


**Implementation of Attention**

In [None]:
class HeadAttention(n.Module):
  """A single scaled dot-product self-attention head.

  Queries, keys and values are each produced by a learned linear map from
  ``hidden_size`` to ``hidden_size`` features.
  """

  def __init__(self, hidden_size):
    """Create the query / key / value projections of width ``hidden_size``."""
    super().__init__()
    self.query = n.Linear(hidden_size, hidden_size)
    self.key = n.Linear(hidden_size, hidden_size)
    self.value = n.Linear(hidden_size, hidden_size)


  def forward(self, x):
    """Apply self-attention over the token dimension.

    Args:
      x: tensor of shape ``(batch, tokens, hidden_size)`` (``torch.bmm``
        requires exactly three dimensions).

    Returns:
      Tensor of the same shape: for every token, the softmax-weighted sum
      of all value vectors.
    """
    query = self.query(x)
    key = self.key(x)
    value = self.value(x)
    # 1 / sqrt(d): keeps the magnitude of the logits independent of the
    # feature width so the softmax does not saturate.
    scale = query.shape[-1] ** -0.5

    # `scale` is already the reciprocal, so it must be multiplied in;
    # dividing by it would blow the scores up by sqrt(d) instead.
    attention = torch.bmm(query,key.transpose(-1,-2)) * scale
    attention = torch.softmax(attention,dim=-1)
    attention = torch.bmm(attention,value)

    return attention

In [None]:
class MultiHeadAttention(n.Module):
  """Run several attention heads side by side and merge their outputs.

  Every head works on the full ``hidden_size`` features; the head outputs
  are concatenated along the last dimension and projected back to
  ``hidden_size`` by a linear layer.
  """

  def __init__(self, hidden_size, num_heads):
    """Create ``num_heads`` heads and the output projection."""
    super().__init__()
    self.heads = n.ModuleList()
    for _ in range(num_heads):
      self.heads.append(HeadAttention(hidden_size))
    self.linear = n.Linear(num_heads * hidden_size, hidden_size)

  def forward(self, x):
    """Return the merged output of all heads; the last dimension is ``hidden_size``."""
    per_head = [head(x) for head in self.heads]
    merged = torch.cat(per_head, dim=-1)
    return self.linear(merged)

**The Encoder Block**

In [None]:
class EncoderBlock(n.Module):
  """Transformer encoder block with two residual sub-layers.

  Each sub-layer (multi-head attention, then an MLP) is fed the
  layer-normalised input and its output is added back onto that input.
  """

  def __init__(self, hidden_size, num_heads):
    """Build the attention layer, both layer norms and the MLP."""
    super().__init__()
    mlp_width = 4 * hidden_size
    self.attention = MultiHeadAttention(hidden_size, num_heads)
    self.normAttention = n.LayerNorm(hidden_size)
    self.normMLP = n.LayerNorm(hidden_size)
    # Expand to 4x the token width and project back, with dropout (p=0.5)
    # after each linear layer.
    self.mlp = n.Sequential(
        n.Linear(hidden_size, mlp_width),
        n.Dropout(0.5),
        n.LeakyReLU(),
        n.Linear(mlp_width, hidden_size),
        n.Dropout(0.5),
    )

  def forward(self, x):
    """Apply the attention and MLP sub-layers; the output has the shape of ``x``."""
    attended = self.attention(self.normAttention(x))
    x = x + attended
    transformed = self.mlp(self.normMLP(x))
    return x + transformed

**Build the Transformer Model**

In [None]:
class Vit(n.Module):
  """Minimal Vision Transformer classifier.

  Pipeline: patch embedding -> stack of encoder blocks -> linear
  classification head applied to the class token.
  """

  def __init__(self, hidden_size, patch_size, num_patches, num_heads,num_classes, num_layers=1):
    """Build the embedding, the encoder stack and the classification head.

    Args:
      hidden_size: width of every token inside the model.
      patch_size: forwarded to ``ImageEmbeddings`` as both ``size`` and
        ``patch_size``; it is therefore the number of values in one
        flattened patch (input width of the patch projection), not the
        side length of a patch.
      num_patches: number of patches per image.
      num_heads: attention heads per encoder block.
      num_classes: number of output classes.
      num_layers: number of encoder blocks to stack.  Defaults to 1, the
        depth that used to be hard-coded.
    """
    super().__init__()
    self.image_embedding = ImageEmbeddings(size=patch_size,patch_size=patch_size,hidden_size=hidden_size,num_patches=num_patches)
    self.encoders = n.Sequential(*[EncoderBlock(hidden_size, num_heads) for _ in range(num_layers)])
    self.mlp = n.Linear(hidden_size, num_classes)

  def forward(self, x):
    """Classify a batch of patch sequences.

    Args:
      x: tensor of shape ``(batch, num_patches, patch_size)``.

    Returns:
      Tensor of shape ``(batch, num_classes)`` holding raw class scores
      (no softmax is applied).
    """
    x = self.image_embedding(x)
    x = self.encoders(x)
    # Only the class token (position 0) is fed to the classification head.
    x = self.mlp(x[:,0,:])
    return x


In [None]:
# --- Model hyperparameters ---
hidden_size = 128   # width of every token embedding
num_heads = 6       # attention heads per encoder block
num_patches = 49    # patches per image
num_classes = 10    # output classes
# NOTE: this rebinds `patch_size`.  In the transform cell it is the side
# length of a patch (4); here it is the number of values in one flattened
# patch, i.e. the input width of the patch projection.
patch_size = 16

# --- Training schedule ---
num_epochs = 10     # passes over the training set

**Training the model**

In [None]:
# Loss, model and optimizer shared by the training and evaluation code.
criterion = n.CrossEntropyLoss()
model = Vit(hidden_size, patch_size, num_patches, num_heads,num_classes)
# Use the GPU when one is present; otherwise stay on the CPU instead of
# raising on machines without CUDA.
model.to('cuda' if torch.cuda.is_available() else 'cpu')
optimizer = optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-2)
# 'medium' allows float32 matrix multiplications to use faster,
# lower-precision internal arithmetic.
torch.set_float32_matmul_precision('medium')
def train(epochs):
  """Train the notebook-level ``model`` for ``epochs`` passes over ``train_loader``.

  Uses the notebook-level ``criterion`` and ``optimizer`` and prints the
  mean training loss once per epoch.

  Args:
    epochs: number of full passes over ``train_loader``.
  """
  model.train()
  # Follow the model to whichever device it lives on instead of assuming CUDA.
  device = next(model.parameters()).device
  for epoch in range(epochs):
    running_loss = 0.0
    num_batches = 0
    for data, target in train_loader:
      data, target = data.to(device), target.to(device)
      optimizer.zero_grad()
      output = model(data)
      loss = criterion(output, target)
      loss.backward()
      optimizer.step()
      # .item() yields a plain Python float, detached from the autograd graph.
      running_loss += loss.item()
      num_batches += 1
    # One summary line per epoch instead of one tensor repr per batch;
    # max(..., 1) guards against an empty loader.
    print('epoch {} finished: mean loss {:.4f}'.format(
        epoch, running_loss / max(num_batches, 1)))


In [None]:
# Run the full training schedule: num_epochs passes over the training set.
train(epochs=num_epochs)

**Evaluating the model**

In [None]:
# Switch dropout off for evaluation.
model.eval()
# Evaluate on whichever device the model lives on instead of assuming CUDA.
device = next(model.parameters()).device
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        # The model emits raw logits; the largest one marks the predicted class.
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

accuracy = correct / len(test_loader.dataset)
print('\nTest set: Accuracy: {}/{} ({:.0f}%)\n'.format(
    correct, len(test_loader.dataset),
    100. * accuracy))


Test set: Accuracy: 1770/10000 (18%)

