In [1]:
import torch
import torch.nn as nn
import torch.utils.data as data
from torch.optim import Adam, SGD

import torchvision
import torchvision.datasets as dataset
import torchvision.transforms as transform

import math
import numpy as np

from sklearn.model_selection import train_test_split

In [2]:
# reproducability by fixing the random seed
np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7f45c0256af0>

In [3]:
# define the VIT model
class TransformerEncoder(nn.Module):
  def __init__(self,n_heads, embed_dim, num_patches, hidden_dim ):
    super().__init__()

    self.n_heads = n_heads
    self.hidden_dim = hidden_dim
    self.embed_dim = embed_dim
    self.num_patches = num_patches

    # input size = bs, dim, patches
    self.norm1 = nn.LayerNorm([ num_patches,embed_dim ])
    self.norm2 = nn.LayerNorm([num_patches,embed_dim])

    self.W_Q = nn.Linear(embed_dim, n_heads*hidden_dim)   # Linear function is applied on the last dimension of the features then I need permutation
    self.W_K = nn.Linear(embed_dim, n_heads*hidden_dim)
    self.W_V = nn.Linear(embed_dim, n_heads*hidden_dim)
    self.fc = nn.Linear(n_heads*hidden_dim, embed_dim)

    self.ML_Linear = nn.Linear(hidden_dim*n_heads, hidden_dim)
    self.MLP = nn.Sequential(nn.Linear(embed_dim,embed_dim), nn.GELU(), nn.Linear(embed_dim,embed_dim))



  def forward(self, x):

      x = self.norm1(x)
      x_ = x

      x_q = self.W_Q(x) # bs, n_patches, hidden_dim * n_heads
      x_q = x_q.reshape(x.size(0), x.size(1),self.n_heads,self.hidden_dim,) # bs, n_patches ,n_heads, hidden_dim
      x_q = x_q.permute(0,2,1,3)

      x_k = self.W_K(x)
      x_k = x_k.reshape(x.size(0), x.size(1), self.n_heads,self.hidden_dim,) # bs, n_patches ,n_heads, hidden_dim
      x_k = x_k.permute(0,2,1,3)

      x_v = self.W_Q(x)
      x_v = x_v.reshape(x.size(0), x.size(1), self.n_heads, self.hidden_dim) # bs, n_patches ,n_heads, hidden_dim
      x_v = x_v.permute(0,2,1,3)


      # compute the attention score
#       x_q = x_q.reshape(x_q.size(0)*x_q.size(1),x_q.size(2),x_q.size(3))
#       x_k = x_k.reshape(x_k.size(0)*x_k.size(1),x_k.size(2),x_k.size(3))
#       x_v = x_v.reshape(x_v.size(0)*x_v.size(1),x_v.size(2),x_v.size(3))

      x_q = torch.transpose(x_q,2,3)
#       scores =  torch.einsum("bij, bkj -> bik", x_q, x_k)
      scores = x_q@x_k
      scores = scores/(self.hidden_dim**(0.5))
      scores = torch.softmax(scores, dim=-1)

      
        
#       attention =  torch.einsum("bik, bkj -> bij", scores, x_v) #bs, n_patches, n_score_weights_per_patch= n_patches of k
      scores = torch.transpose(scores,2,3)
      
      
      attention = x_v@scores
      
    
    # bs,n_patches,n_heads,embed_dim
    
      

      attention = attention.permute(0,2,1,3)
      attention = attention.reshape(attention.size(0),attention.size(1),attention.size(2)*attention.size(3))
      x = self.fc(attention)

      x = x + x_

      x_ = self.norm2(x)
      x = self.MLP(x)
      x = x + x_
      return x







In [4]:

class Vit(nn.Module):
  def __init__(self, patch_size, embed_dim, hidden_dim, n_heads,n_classes):
    super().__init__()

    """
    parameters:

    output:
    """
    self.patch_size = patch_size
    self.linear_embedding = nn.Conv2d(in_channels=1, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size, padding=0)
    self.classifier = nn.Linear(embed_dim+3,n_classes)
    self.n_heads = n_heads
    self.hidden_dim = hidden_dim

  def get_positional_embeddings(self, sequence_length, d):
    result = torch.ones(sequence_length, d)
    for i in range(sequence_length):
        for j in range(d):
            result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / d)))
    return result



  def forward(self, x):
    x = x.to(device)
    assert x.size(2) % self.patch_size == 0, "height of the image is not divideable by patch size"
    assert x.size(3) % self.patch_size == 0, "width of the image is not divideable by patch size"
    x = self.linear_embedding(x)
    
    r = self.get_positional_embeddings(x.size(2)*x.size(3)+1, 3) # positional feature
    r = r.to(device)
    
    r = r.transpose(0,1)
    r = r.unsqueeze(0)
    r = r.repeat(x.size(0),1,1) # 100, 1, 65
    
    x = x.flatten(start_dim=2)
    
    # add dummy token
    x_CLS_token = nn.Parameter(torch.ones(x.size(0),x.size(1),1)).to(device)
#     print(x.size(), x_CLS_token.size())
    
    x = torch.concat((x_CLS_token,x), dim=2) # bs, embed_dim, nums_patches + 1_(CLS)
    x = torch.concat((x,r), dim=1)  # bs, dim(position + embed_dim), num_patches + CLS
    
#     print(x.size())
    
    x = x.permute(0, 2, 1) # bs, embed_dim, n_patches ---> bs, n_patches, embed_dim 
    TE_model = TransformerEncoder(n_heads=self.n_heads, embed_dim=x.size(2), num_patches=x.size(1), hidden_dim=self.hidden_dim).to(device)
    
#     TE_model_2 = TransformerEncoder(n_heads=self.n_heads, embed_dim=x.size(2), num_patches=x.size(1), hidden_dim=self.hidden_dim).to(device)

#     y = TE_model_2(TE_model(x))
    y = TE_model(x)

    c = self.classifier(y[:,0,:])

    return c

In [5]:
# Training function
def train(model, train_dataloader, val_dataloader, lr, optimizer, loss_function, epochs, device, PATH):
  model.to(device)
# extract validation and report the accuracy in each epoch
  for epoch in range(epochs+1):
    total_loss = 0
    n_samples = 0

    total_correct = 0
    total_samples = 0

    for batch in train_dataloader:
      img, labels = batch
      img = img.to(device)
      labels = labels.to(device)
      y = model(img)

      loss = loss_function(y, labels)
      n_samples += img.size(0)
      total_loss += loss.item()

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()


    for batch in train_dataloader:
      img, labels = batch
      img = img.to(device)
      labels = labels.to(device)
      outputs = model(img)
      _, predicted = torch.max(outputs, 1)

      # Update the running total of correct predictions and samples
      total_correct += (predicted == labels).sum().item()
      total_samples += labels.size(0)





    # Calculate the accuracy for this epoch
    accuracy = 100 * total_correct / total_samples





    if epoch%10 == 0:
      print("loss: ", total_loss/n_samples, "epoch=", epoch, "acc on val", accuracy )
      PATH_ = PATH + 'checkpoint' + str(epoch) + '.pt'
      torch.save(model.state_dict(), PATH_)


  return model.state_dict()




In [9]:
def test(model, test_dataloader, device):

  total_correct = 0
  total_samples = 0

  for batch in test_dataloader:
    img, labels = batch
    img = img.to(device)
    labels = labels.to(device)
    outputs = model(img)
    
    _, predicted = torch.max(outputs, 1)

    # Update the running total of correct predictions and samples
    total_correct += (predicted == labels).sum().item()
    total_samples += labels.size(0)





    # Calculate the accuracy for this epoch
  accuracy = 100 * total_correct / total_samples
  print("accuracy", accuracy)




In [11]:
# Input size b, c, h, w
# It is a classification task
# For me the challenge is sequencialization of the image patches and channel handeling





# Download the dataset and define the dataloader

trans = transform.Compose([
         transform.ToTensor(),
         transform.Resize((32,32)),
         #transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
         ])


VAL_SIZE = 0.1
train_set = dataset.MNIST(root='./datasets', train=True, download=True, transform=trans)
test_set = dataset.MNIST(root='./datasets', train=False, download=True, transform=trans)




train_indices, val_indices, _, _ = train_test_split(
    range(len(train_set)),
    train_set.targets,
    stratify=train_set.targets,
    test_size=VAL_SIZE,
)

# generate a subset of train and validation based on indices
train_split = data.Subset(train_set, train_indices)
val_split = data.Subset(train_set, val_indices)


# Dataloader
train_dataloader = data.DataLoader(train_split, shuffle=True, batch_size=64)
val_dataloader = data.DataLoader(val_split, shuffle=False, batch_size=64)
test_dataloader = data.DataLoader(test_set, shuffle=False, batch_size=64)

# define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("The code is running on device: ", device)


PATH = "./model/"  # folder path to save the checkpoints

param = {"patch_size":8, "embed_dim":30, "hidden_dim":50, "n_heads":3, "n_classes":10}
# define the mode
model = Vit(**param)

# learning rate
lr = 0.02

# training iteration
epochs = 1500

# optimizer, we can also define  a scheduler for the optimizer

optimizer = SGD(model.parameters(), lr)
loss_function = nn.CrossEntropyLoss()


# checkpoints = train(model, train_dataloader, val_dataloader,lr, optimizer, loss_function, epochs, device, PATH)

# save checkpoints

# torch.save(checkpoints, PATH)

# load checkpoints
# model.load_state_dict(torch.load(PATH))
# model.eval()



# CHECK TRANSFORMER IMPLEMENTATION FOR REPEATING THE RESNET BOCK AND FOR MATMUL N_HEADS AND FEATURE MANAGEMENT


The code is running on device:  cuda


In [13]:
PATH = "./model/checkpoint10.pt"
model.load_state_dict(torch.load(PATH))
model.eval()
model = model.to(device)

test(model, test_dataloader, device)

accuracy 25.08
