# **Vision Transformer**
Reference implementation: https://github.com/BrianPulfer/PapersReimplementations/blob/master/vit/vit_torch.py




In [92]:
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn 
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.datasets.mnist import MNIST

# Seed both RNGs. Weight init, torch.rand and DataLoader shuffling all draw from
# torch's generator, so seeding NumPy alone does not make runs reproducible.
np.random.seed(0)
torch.manual_seed(0)

def patchify(images, n_patches):
    """Split a batch of square images into flattened, non-overlapping patches.

    Args:
        images: tensor of shape (N, C, H, W) with H == W.
        n_patches: number of patches along each side; the image is cut into an
            n_patches x n_patches grid of squares with side H // n_patches.

    Returns:
        Tensor of shape (N, n_patches**2, C * patch_size**2) on the same device
        as ``images`` and in the default float dtype. Patch k = i * n_patches + j
        is grid cell (row i, column j), flattened in (channel, row, column)
        order. Rows/columns beyond n_patches * patch_size are dropped.

    Raises:
        AssertionError: if the images are not square.
    """
    n, c, h, w = images.shape  # e.g. 50 x 1 x 28 x 28

    assert h == w, "Patchify method is implemented for square images only"
    patch_size = h // n_patches
    covered = n_patches * patch_size

    # Crop any remainder pixels that do not fill a whole patch.
    grid = images[:, :, :covered, :covered]
    # (N, C, H, W) -> (N, C, i, r, j, col): split each spatial axis into
    # (patch index, offset inside the patch).
    grid = grid.reshape(n, c, n_patches, patch_size, n_patches, patch_size)
    # -> (N, i, j, C, r, col) -> (N, i*n_patches + j, C*r*col), e.g. 50 x 49 x 16
    patches = grid.permute(0, 2, 4, 1, 3, 5).reshape(n, n_patches ** 2, c * patch_size ** 2)
    # Keep the default float dtype the original zero-filled buffer had.
    return patches.to(dtype=torch.get_default_dtype())

def get_positional_embeddings(sequence_length, d):
    """Build the fixed sinusoidal position table used by the ViT.

    Entry (i, j) is sin(i / 10000**(j / d)) for even j and
    cos(i / 10000**((j - 1) / d)) for odd j, so each (sin, cos) column pair
    shares one frequency.

    Args:
        sequence_length: number of positions (rows).
        d: embedding width (columns).

    Returns:
        A (sequence_length, d) tensor in the default float dtype.
    """
    def entry(pos, dim):
        pair_start = dim - dim % 2  # even column that opens this sin/cos pair
        angle = pos / (10000 ** (pair_start / d))
        wave = np.cos if dim % 2 else np.sin
        return float(wave(angle))

    table = torch.empty(sequence_length, d)
    for pos in range(sequence_length):
        # Compute the row in double precision, then store it (cast once) into the table.
        table[pos] = torch.tensor([entry(pos, dim) for dim in range(d)], dtype=torch.float64)
    return table

In [105]:

# Scratch check of the patch-buffer layout: 50 images x 49 patches x 16 pixels.
patches = torch.zeros((50, 49, 16))
print(patches[1][7])  # 8th patch of the 2nd image (all zeros here)

torch.Size([49, 16])


In [106]:
from matplotlib import pyplot as plt
import numpy as np
import torch
import torch.nn as nn

# Toy configuration: 50 single-channel 28x28 "images", cut into a 7x7 grid of
# patches and projected to 8-dimensional tokens.
hidth = witdh = 28
chann = 1
number_of_mages = 50
desired_token_lenght = 8
number_patches = 7
patch_size = hidth // number_patches  # 4 pixels per patch side

fake_images = torch.rand(number_of_mages, chann, hidth, witdh)
cls_t = nn.Parameter(torch.rand(1, desired_token_lenght))  # learnable class token, 1 x 8

patches = patchify(fake_images, number_patches)  # 50 x 49 x 16
# Randomly initialised projection from a patch's flattened pixels to a token.
linear_mapp = nn.Linear(patch_size * patch_size, desired_token_lenght)
tokens = linear_mapp(patches)  # 50 x 49 x 8

print(patches.shape)
print(tokens.shape)


torch.Size([50, 49, 16])
torch.Size([50, 49, 8])


In [None]:
class MyViT(nn.Module):
    """Vision Transformer encoder: patch embedding, class token, positions, blocks.

    Args:
        chw: (channels, height, width) of the input images.
        n_patches: patches per side; height and width must be divisible by it.
        n_blocks: number of MyViTBlock layers.
        token_lenght: width of every token (hidden size).
        n_heads: attention heads per block.

    ``forward`` returns the full token sequence of shape
    (N, n_patches**2 + 1, token_lenght). Position 0 along dim 1 is the class
    token. No classification head is attached here.
    """

    def __init__(self, chw=(1, 28, 28), n_patches=7, n_blocks=2, token_lenght=8, n_heads=2):
        super(MyViT, self).__init__()
        self.chw = chw
        self.n_patches = n_patches
        self.token_lenght = token_lenght
        self.n_heads = n_heads
        self.n_blocks = n_blocks

        assert chw[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        assert chw[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"

        # Integer division: nn.Linear needs an int in_features (28 / 7 would be 4.0).
        self.patch_size = (chw[1] // n_patches, chw[2] // n_patches)  # (4, 4) for 28x28 / 7

        # Linear mapper: flattened patch pixels -> token
        self.input_d = chw[0] * self.patch_size[0] * self.patch_size[1]  # 1*4*4 = 16
        self.linear_mapper = nn.Linear(self.input_d, self.token_lenght)  # 16 -> 8
        # Learnable classification token, prepended to every sequence (1 x token_lenght)
        self.class_token = nn.Parameter(torch.rand(1, self.token_lenght))
        # Fixed (non-trainable) sinusoidal positions for the patch tokens + class token.
        self.pos_embed = nn.Parameter(
            get_positional_embeddings(self.n_patches ** 2 + 1, self.token_lenght),
            requires_grad=False,
        )
        self.blocks = nn.ModuleList([MyViTBlock(token_lenght, n_heads) for _ in range(n_blocks)])

    def forward(self, images):
        """Encode a batch of images of shape (N, C, H, W) into token sequences."""
        n = images.shape[0]

        # (N, n_patches**2, C*P*P). Move onto the parameters' device in case
        # patchify allocated its output elsewhere.
        patches = patchify(images, self.n_patches).to(self.class_token.device)
        tokens = self.linear_mapper(patches)  # (N, n_patches**2, token_lenght)

        # Prepend the class token to every sequence -> (N, n_patches**2 + 1, token_lenght)
        cls = self.class_token.unsqueeze(0).expand(n, -1, -1)
        tokens = torch.cat((cls, tokens), dim=1)

        # The (L, D) positional table broadcasts over the batch dimension.
        out = tokens + self.pos_embed

        for block in self.blocks:
            out = block(out)

        return out

In [None]:
class MyMSA(nn.Module):
    """Multi-head self-attention where each head sees its own slice of the features.

    The token width ``d`` is split into ``n_heads`` contiguous chunks of
    ``d_head = d // n_heads`` features. Head h applies its own q/k/v linear
    maps (d_head -> d_head) to chunk h only and computes scaled dot-product
    attention. The per-head outputs are then concatenated back to width ``d``.
    """

    def __init__(self, d, n_heads=2):
        super().__init__()
        self.d = d
        self.n_heads = n_heads

        assert d % n_heads == 0, f"Can not divide dimension {d} into {n_heads}"

        d_head = d // n_heads
        self.q_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        # Normalise over the key axis. dim=-1 is that axis for both (L, L) and
        # batched (N, L, L) score tensors.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        """Apply self-attention to a batch of sequences of shape (N, L, d).

        Returns a tensor of shape (N, L, d).
        """
        head_outputs = []
        for head in range(self.n_heads):
            # This head's features for the whole batch at once: (N, L, d_head)
            seq = sequences[..., head * self.d_head:(head + 1) * self.d_head]
            q = self.q_mappings[head](seq)
            k = self.k_mappings[head](seq)
            v = self.v_mappings[head](seq)

            # Batched matmul: (N, L, L) attention weights, no per-sample Python loop.
            attention = self.softmax(q @ k.transpose(-2, -1) / (self.d_head ** 0.5))
            head_outputs.append(attention @ v)
        return torch.cat(head_outputs, dim=-1)

In [None]:
class MyViTBlock(nn.Module):
    """Pre-norm transformer encoder block.

    Runs LayerNorm -> MyMSA and LayerNorm -> MLP (GELU, hidden width
    ``mlp_ratio * hidden_d``), each wrapped in a residual connection. Input and
    output have the same shape, with last dimension ``hidden_d``.
    """

    def __init__(self, hidden_d, n_heads, mlp_ratio=4):
        super(MyViTBlock, self).__init__()
        self.hidden_d = hidden_d
        self.n_heads = n_heads

        expanded_d = mlp_ratio * hidden_d  # width of the MLP's inner layer
        self.norm1 = nn.LayerNorm(hidden_d)
        self.mhsa = MyMSA(hidden_d, n_heads)
        self.norm2 = nn.LayerNorm(hidden_d)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d, expanded_d),
            nn.GELU(),
            nn.Linear(expanded_d, hidden_d),
        )

    def forward(self, x):
        attended = self.mhsa(self.norm1(x))
        residual = x + attended
        return residual + self.mlp(self.norm2(residual))

# **Trying the pipeline on fake images**

In [None]:

# Manual walk-through of the embedding pipeline on the fake batch built in the
# previous cell (uses fake_images, patches, tokens and cls_t from there).

# Prepend cls_t to every image's tokens: vstack of 1x8 and 49x8 => 50x8 per image.
token_stack = torch.stack([torch.vstack((cls_t, image_tokens)) for image_tokens in tokens])

# Size the positional table from the actual sequence length / token width and
# repeat it once per image, so the add below matches token_stack exactly
# (a hard-coded batch count breaks as soon as the batch size changes).
n_images, seq_len, token_dim = token_stack.shape
pos_embed = nn.Parameter(get_positional_embeddings(seq_len, token_dim))
pos_embed = pos_embed.repeat(n_images, 1, 1)

out = token_stack + pos_embed

out_numpy = out.cpu().detach().numpy()

#plt.imshow(fake_images[2,0,:,:])
#plt.imshow(out_numpy[2,:,:])

print(f"fake images shape     : {fake_images.shape}")  # torch.Size([50, 1, 28, 28])
print(f"number of fake images : {len(fake_images)}")   # 50
print(f"shape of patches      : {patches.shape}")      # torch.Size([50, 49, 16])
print(f"shape of tokens       : {tokens.shape}")       # torch.Size([50, 49, 8])
print(f"shape of token_stack  : {token_stack.shape}")  # torch.Size([50, 50, 8])
print(f"shape of pos_embed    : {pos_embed.shape}")    # torch.Size([50, 50, 8])
print(f"shape of out          : {out.shape}")          # torch.Size([50, 50, 8])

# Smoke test of one transformer block on random tokens (70 sequences x 50 tokens x 8).
x = torch.randn(70, 50, 8)
models = MyViTBlock(hidden_d=8, n_heads=2)
print('model output shape', models(x).shape)

# Scratch: an unrolled version of the per-sequence multi-head attention loop,
# applied to `out`. It is kept inside a string for reference and is never executed.
'''
q_mappings=nn.ModuleList([nn.Linear(4,4) for _ in range(2)])
k_mappings=nn.ModuleList([nn.Linear(4,4) for _ in range(2)])
v_mappings=nn.ModuleList([nn.Linear(4,4) for _ in range(2)])
softmax=nn.Softmax(dim=1)

result=[]
for sequence in out:
  seq_result=[]
  for head in range (2):            # 2 head for each 50x8
    q_mapping = q_mappings[head]
    k_mapping = k_mappings[head]
    v_mapping = v_mappings[head]

    seq = sequence[:, head*4: (head + 1) * 4]               # seq=50x4
    q,k,v = q_mapping(seq), k_mapping(seq), v_mapping(seq)  # q = 50x4 after linear(4x4)
    attention = softmax(q @ k.T / (2**0.5))
    seq_result.append(attention @ v )                       # 2 heads x50x4  
  result.append(torch.hstack(seq_result))                   # storing all number of images (nx50x8) n=5
 
print(f'sequcens_result : ', len(seq_result[0][0]))
print(f'result',len(result[0][0]))

sonc=torch.cat([torch.unsqueeze(r, dim=0) for r in result])
sonc.size()
'''

In [None]:
class _ViTClassifier(nn.Module):
    """Wrap a MyViT encoder with a linear classification head on its class token."""

    def __init__(self, backbone, hidden_d, n_classes):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Linear(hidden_d, n_classes)

    def forward(self, images):
        tokens = self.backbone(images)  # (N, n_patches**2 + 1, hidden_d)
        # Raw logits from the class token (position 0). CrossEntropyLoss applies
        # log-softmax itself, so no Softmax layer is added here.
        return self.head(tokens[:, 0])


def main():
    """Train a small ViT classifier on MNIST, then print test loss and accuracy.

    MNIST is downloaded into ./datasets if it is not already there.
    """
    transform = ToTensor()
    train_set = MNIST(root='./datasets', train=True, download=True, transform=transform)
    test_set = MNIST(root='./datasets', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_set, shuffle=True, batch_size=128)
    test_loader = DataLoader(test_set, shuffle=False, batch_size=128)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    hidden_d = 8
    backbone = MyViT((1, 28, 28), n_patches=7, n_blocks=2, token_lenght=hidden_d, n_heads=2)
    # 10 digit classes. The model must sit on the same device as the batches.
    model = _ViTClassifier(backbone, hidden_d, n_classes=10).to(device)
    N_EPOCHS = 5
    LR = 0.005

    # training loop
    optimizer = Adam(model.parameters(), lr=LR)
    criterion = CrossEntropyLoss()

    model.train()
    for epoch in tqdm(range(N_EPOCHS), desc='Training'):
        train_loss = 0.00
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1} in training", leave=False):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)

            # Running mean of the batch losses over the epoch.
            train_loss += loss.detach().cpu().item() / len(train_loader)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch +1}/{N_EPOCHS}, loss : {train_loss:.2f}")

    # test loop
    model.eval()
    with torch.no_grad():
        correct, total = 0, 0
        test_loss = 0
        for batch in tqdm(test_loader, desc='Testing'):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)
            test_loss += loss.detach().cpu().item() / len(test_loader)

            correct += torch.sum(torch.argmax(y_hat, dim=1) == y).detach().cpu().item()
            total += len(x)

        print(f"Test loss : {test_loss:.2f}")
        print(f"Test accuracy : {correct/total*100:.2f}")


In [None]:
if __name__ == '__main__':
    # Entry point: run training + evaluation only when this module is executed
    # directly, not when it is imported.
    main()