In [None]:
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary
import math
import torchvision
from torchvision import datasets, transforms

In [None]:
class LightViT(nn.Module):
  """Small Vision Transformer classifier.

  Pipeline: split images into a grid of patches -> project each patch to a
  ``d``-dim token -> prepend a learnable classification token -> add fixed
  sin/cos positional embeddings -> ``n_blocks`` encoder blocks -> linear
  classifier applied to the classification token.

  Args:
    image_dim: shape of an input batch as read from the DataLoader
      (e.g. ``(64, 1, 28, 28)``); forwarded to ``linearmap``.
    n_patches: number of patches along each image side (7 -> 7x7 = 49 patches).
    n_blocks: number of stacked encoder blocks (previously ignored: only one
      block was ever built; note this changes state_dict keys).
    d: token embedding size.
    n_heads: number of attention heads per encoder block.
    num_classes: number of output classes.
    pixel_patch: currently unused; kept for backward compatibility.
  """
  def __init__(self,image_dim, n_patches=7, n_blocks=2, d=8, n_heads=2, num_classes=10,pixel_patch=4):
    super(LightViT, self).__init__()

    self.image_dim = image_dim
    self.n_patches = n_patches
    self.n_blocks = n_blocks
    self.d = d
    self.n_heads = n_heads
    self.num_classes = num_classes
    self.pixel_patch = pixel_patch

    # 1B) Linear mapping of each flattened patch to a d-dimensional token.
    self.linear_map = linearmap(self.image_dim, self.n_patches, self.d)

    # 2A) Learnable classification token, shape (1, 1, d).
    self.special_token = nn.Parameter(torch.randn(1, 1, self.d))

    # 2B) Positional embeddings are not learned; they are computed in forward().

    # 3) Encoder: stack n_blocks blocks (the original built exactly one).
    self.encoder = nn.Sequential(
        *[ViTEncoder(self.d, self.n_heads) for _ in range(self.n_blocks)]
    )

    # 5) Classification head applied to the classification token.
    self.classifier = nn.Linear(self.d, self.num_classes)

  def forward(self, images):
    """Return class logits of shape (batch, num_classes) for ``images``."""
    patches = create_patches(images, self.n_patches)
    tokens = self.linear_map(patches)

    # Prepend one classification token per sample.
    batch_size = images.shape[0]
    cls_tokens = self.special_token.expand(batch_size, -1, -1)
    tokens = torch.cat([cls_tokens, tokens], dim=1)

    tokens = tokens + get_pos_embeddings(self.d, tokens)

    encoded = self.encoder(tokens)

    # Index rather than squeeze: squeeze() would also drop the batch
    # dimension when the batch holds a single image.
    cls_output = encoded[:, 0]
    return self.classifier(cls_output)


In [None]:
def create_patches(input_image,patch_size):
  """Split a batch of images into a grid of contiguous, flattened patches.

  Despite its name, ``patch_size`` is the number of patches along each side
  (``LightViT`` passes ``n_patches`` here): a 28x28 image with ``patch_size=7``
  becomes 7x7 = 49 patches of 4x4 pixels each.

  The previous einops pattern ``'b c (h p1) (w p2) -> b (p1 p2) (h w c)'`` put
  the grid index on the inner axis. Each "patch" therefore gathered pixels
  spaced ``patch_size`` apart across the whole image instead of a contiguous
  block. The output shape is unchanged.

  Args:
    input_image: tensor of shape (batch, channels, height, width); height and
      width must be divisible by ``patch_size``.
    patch_size: number of patches along each spatial side.

  Returns:
    Tensor of shape (batch, patch_size**2, patch_h * patch_w * channels), with
    patches in row-major grid order, each flattened as (row, col, channel).

  Raises:
    ValueError: if height or width is not divisible by ``patch_size``.
  """
  batch, channels, height, width = input_image.shape
  if height % patch_size or width % patch_size:
    raise ValueError(
        "image size {}x{} is not divisible into a {}x{} patch grid".format(
            height, width, patch_size, patch_size))
  patch_h, patch_w = height // patch_size, width // patch_size
  # (b, c, grid_y, patch_h, grid_x, patch_w) -> (b, grid_y, grid_x, patch_h, patch_w, c)
  grid = input_image.reshape(batch, channels, patch_size, patch_h, patch_size, patch_w)
  grid = grid.permute(0, 2, 4, 3, 5, 1)
  return grid.reshape(batch, patch_size * patch_size, patch_h * patch_w * channels)

In [None]:
class linearmap(nn.Module):
  """Project each flattened patch to an ``emb_size``-dimensional token.

  Args:
    image_dim: shape of an input batch ``(batch, channels, height, width)``;
      a 3-D ``(channels, height, width)`` shape is also accepted since only
      the last three entries are read.
    n_patches: number of patches along each image side.
    emb_size: output embedding size per patch.
  """
  def __init__(self,image_dim,n_patches,emb_size):
    super().__init__()
    channels, height, width = image_dim[-3], image_dim[-2], image_dim[-1]
    # A patch holds (height // n_patches) x (width // n_patches) pixels per channel.
    # The original used (image_dim[2] // n_patches) ** 2, which ignored the channel
    # count (and width), so any multi-channel input failed at the first forward.
    patch_features = channels * (height // n_patches) * (width // n_patches)
    self.projection = nn.Sequential(nn.Linear(patch_features, emb_size))

  def forward(self,x : Tensor) -> Tensor:
    """Map the last dimension (patch_features) of ``x`` to ``emb_size``."""
    return self.projection(x)

In [None]:
def get_pos_embeddings(emb_size, patches_with_clstokens):
    """Build fixed sinusoidal positional embeddings for a token sequence.

    PE[pos, 2i] = sin(pos / 10000^(2i/d)), PE[pos, 2i+1] = cos(pos / 10000^(2i/d)).

    Args:
        emb_size: embedding size d; must be even.
        patches_with_clstokens: tensor whose first two dims are (batch, seq_len);
            only its shape, device and (floating) dtype are used.

    Returns:
        Tensor of shape (batch, seq_len, emb_size) on the same device as the
        input, in the input's dtype when that is floating point (else float32).
        Previously it was always a CPU float32 tensor, so adding it to GPU
        activations raised a device-mismatch error.

    Raises:
        ValueError: if emb_size is odd.
    """
    d_model = emb_size
    length = patches_with_clstokens.shape[1]
    b = patches_with_clstokens.shape[0]
    if d_model % 2 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "odd dim (got dim={:d})".format(d_model))

    device = patches_with_clstokens.device
    out_dtype = (patches_with_clstokens.dtype
                 if patches_with_clstokens.is_floating_point() else torch.float32)

    # Compute in float32 (as before) on the target device, then cast.
    pe = torch.zeros(length, d_model, device=device)
    position = torch.arange(0, length, device=device, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2, device=device, dtype=torch.float32)
                         * -(math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    pe = pe.to(out_dtype)

    # One materialised copy per batch element, as the original einops.repeat did.
    return pe.unsqueeze(0).repeat(b, 1, 1)

In [None]:
class MHSA(nn.Module):
    """Multi-head self-attention over a token sequence.

    Args:
        d: embedding size of each token; must be divisible by ``n_heads``.
        n_heads: number of attention heads; each head attends over
            ``d // n_heads`` features.

    Raises:
        ValueError: if ``d`` is not divisible by ``n_heads``.
    """
    def __init__(self,d=8, n_heads=2):
        super(MHSA, self).__init__()
        if d % n_heads != 0:
            raise ValueError("d={} must be divisible by n_heads={}".format(d, n_heads))

        self.d = d
        self.n_heads = n_heads
        self.head_dim = d // n_heads

        self.keys = nn.Linear(d, d)
        self.queries = nn.Linear(d, d)
        self.values = nn.Linear(d, d)

        self.projection = nn.Linear(d, d)

    def _split_heads(self, t: Tensor) -> Tensor:
        """Reshape (batch, tokens, d) into (batch, heads, tokens, head_dim)."""
        batch, tokens, _ = t.shape
        return t.reshape(batch, tokens, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self,x: Tensor) -> Tensor:
        """Self-attend over ``x`` of shape (batch, tokens, d); returns the same shape."""
        queries = self._split_heads(self.queries(x))
        keys = self._split_heads(self.keys(x))
        values = self._split_heads(self.values(x))

        # Scaled dot-product attention: scale the logits by sqrt(head_dim) BEFORE
        # the softmax so every row of attention weights still sums to 1. The
        # original divided the softmax output by sqrt(d), which just shrank every
        # attention output by that constant factor.
        energy = queries @ keys.transpose(-2, -1) / math.sqrt(self.head_dim)
        attention = F.softmax(energy, dim=-1)  # (batch, heads, q_len, k_len)

        out = attention @ values  # (batch, heads, tokens, head_dim)
        batch, _, tokens, _ = out.shape
        out = out.transpose(1, 2).reshape(batch, tokens, self.d)
        return self.projection(out)



In [None]:
class ViTEncoder(nn.Module):
    """One pre-norm Transformer encoder block.

    out = x + MHSA(LayerNorm(x)); out = out + MLP(LayerNorm(out)).

    Args:
        hidden_d: token embedding size (last dimension of the input).
        n_heads: number of attention heads passed to ``MHSA``.
        mlp_ratio: hidden width of the MLP as a multiple of ``hidden_d``;
            defaults to 4, the value previously hard-coded.
    """
    def __init__(self, hidden_d, n_heads, mlp_ratio=4):
        super(ViTEncoder, self).__init__()
        self.hidden_d = hidden_d
        self.n_heads = n_heads

        self.norm1 = nn.LayerNorm(hidden_d)  # normalises tokens before attention
        self.mhsa = MHSA(hidden_d, n_heads)
        self.norm2 = nn.LayerNorm(hidden_d)  # normalises tokens before the MLP
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d, mlp_ratio * hidden_d),
            nn.GELU(),
            nn.Linear(mlp_ratio * hidden_d, hidden_d),
        )

    def forward(self, x):
        """Apply the block to token embeddings ``x`` (last dim == hidden_d)."""
        # Call the sub-module itself (not .forward) so registered hooks run.
        out = x + self.mhsa(self.norm1(x))
        out = out + self.mlp(self.norm2(out))
        return out

In [1]:
if __name__ == '__main__':
  # Smoke test: a single encoder block should keep the (7, 50, 8) input shape.
  model = ViTEncoder(hidden_d=8, n_heads=2)
  x = torch.randn(7, 50, 8)
  encoded = model(x)
  print(encoded.shape)

In [None]:
def load_mnist_dataset(batch_size=64, root='./data'):
    """Download MNIST (if not already under ``root``) and wrap both splits in DataLoaders.

    Images are converted with ``transforms.ToTensor()`` only.

    Args:
        batch_size: samples per batch for both loaders (default 64, as before).
        root: directory where torchvision stores/looks for the dataset.

    Returns:
        (train_data_loader, test_data_loader). The training loader reshuffles
        each epoch; the test loader keeps a fixed order.
    """
    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    train_dataset = datasets.MNIST(root=root, train=True, download=True, transform=transform)
    # Shuffle the training split: without it every epoch saw the same fixed order.
    train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    test_dataset = datasets.MNIST(root=root, train=False, download=True, transform=transform)
    test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)

    return train_data_loader, test_data_loader

# NOTE: despite their names, these two are DataLoaders, not Dataset objects.
train_dataset, test_dataset = load_mnist_dataset()

In [None]:
# Peek at one batch to learn the input shape fed to the model. This is the
# full batch shape, e.g. (64, 1, 28, 28) for the MNIST loaders.
# Fail fast on an empty loader instead of leaving `image_dim` undefined
# (the old for/break form silently skipped the assignment).
first_batch = next(iter(train_dataset), None)
if first_batch is None:
  raise RuntimeError("train loader yielded no batches; cannot infer image_dim")
img, lbl = first_batch
image_dim = img.shape

In [None]:
# Configuration: seed the global torch RNG before building the model so weight
# initialisation (and later random draws) is reproducible across runs.
SEED = 0
LEARNING_RATE = 0.005
torch.manual_seed(SEED)

# Define model
my_model = LightViT(image_dim)

# Define optimizer
optimizer_used = torch.optim.Adam(my_model.parameters(), lr=LEARNING_RATE)

# Define loss (expects raw logits and integer class labels)
loss_criteria = nn.CrossEntropyLoss()

In [None]:
# Module-level history lists are kept for backward compatibility only:
# model_training now returns fresh per-call lists instead of appending here,
# so repeated calls no longer mix histories (and restarted epoch numbers).
epochs = []
train_losses = []
test_losses = []
train_accuracies = []
test_accuracies = []


def _train_one_epoch(model, criteria, optimizer, train_data):
  """Run one optimisation pass over train_data; return (accuracy, mean batch loss)."""
  model.train()
  total_loss = 0.0
  correct = 0
  seen = 0
  for images, labels in train_data:
    optimizer.zero_grad()
    logits = model(images)
    loss = criteria(logits, labels)
    loss.backward()
    optimizer.step()

    total_loss += loss.item()
    correct += (logits.argmax(dim=1) == labels).sum().item()
    seen += labels.size(0)
  if seen == 0:
    raise ValueError("train_data yielded no samples")
  return correct / seen, total_loss / len(train_data)


def _evaluate(model, criteria, test_data):
  """Score the model on test_data without gradients; return (accuracy, mean batch loss)."""
  model.eval()
  total_loss = 0.0
  correct = 0
  seen = 0
  with torch.no_grad():
    for images, labels in test_data:
      logits = model(images)
      total_loss += criteria(logits, labels).item()
      correct += (logits.argmax(dim=1) == labels).sum().item()
      seen += labels.size(0)
  if seen == 0:
    raise ValueError("test_data yielded no samples")
  return correct / seen, total_loss / len(test_data)


def model_training(epoches,model,criteria,optimizer,train_data,test_data):
  """Train ``model`` for ``epoches`` epochs, evaluating on ``test_data`` after each.

  Args:
    epoches: number of epochs to run.
    model: module mapping an image batch to class logits.
    criteria: loss function taking (logits, labels).
    optimizer: optimizer over ``model``'s parameters.
    train_data, test_data: sized iterables (e.g. DataLoaders) of (images, labels).

  Returns:
    (train_accuracies, train_losses, test_accuracies, test_losses, epochs):
    fresh lists with one entry per epoch of this call. Losses are means over
    batches; accuracies are fractions of correctly classified samples.

  Raises:
    ValueError: if either data source yields no samples.
  """
  run_epochs = []
  run_train_losses = []
  run_test_losses = []
  run_train_accuracies = []
  run_test_accuracies = []

  for epoch in range(epoches):
    epoch_train_accuracy, epoch_train_loss = _train_one_epoch(model, criteria, optimizer, train_data)
    epoch_test_accuracy, epoch_test_loss = _evaluate(model, criteria, test_data)

    print('\nEpoch = ', epoch , 'Training Accuracy = ',epoch_train_accuracy, 'Training loss = ',epoch_train_loss )
    print('\nEpoch = ', epoch , 'Testing Accuracy = ',epoch_test_accuracy, 'Testing loss = ',epoch_test_loss )

    run_test_accuracies.append(epoch_test_accuracy)
    run_test_losses.append(epoch_test_loss)
    run_train_losses.append(epoch_train_loss)
    run_train_accuracies.append(epoch_train_accuracy)
    run_epochs.append(epoch)

  return run_train_accuracies, run_train_losses, run_test_accuracies, run_test_losses, run_epochs




In [2]:
# Train for 10 epochs (long-running). The returned histories feed the plot cell;
# this cell was half commented out, leaving a dangling line that was a SyntaxError.
train_accuracies, train_losses, test_accuracy, test_losses, epoch = model_training(
    epoches=10, model=my_model, criteria=loss_criteria, optimizer=optimizer_used, train_data=train_dataset,
    test_data=test_dataset)


In [3]:
# Accuracy per epoch for both splits, drawn on the current axes.
ax = plt.gca()
for series, caption in ((train_accuracies, 'Train Accuracy'), (test_accuracy, 'Test Accuracy')):
    ax.plot(epoch, series, marker='o', label=caption)
ax.set_title('Train Accuracy and Test Accuracy vs. Epoch')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.legend()
plt.show()