In [23]:
##IMPORT 

import numpy as np
from tqdm import tqdm, trange
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.datasets.mnist import MNIST
import matplotlib.pyplot as plt
from patchify import patchify


In [56]:
##PATCHFY FUNCTION & POSITIONAL_EMB

def patching_func(images,patch_size):

    n,c,h_iamge,w_image=images.shape
    n_patch=0
    n_patch_per_img = int((h_iamge/patch_size)**2)
    patch_vector=int(patch_size**2)
    
    images=images.numpy()           #needs patchify function
    images=images.squeeze()

    total_patches = np.zeros((n,n_patch_per_img,patch_vector))
 
    for idx,image in enumerate(images):
        patching = patchify(image, (patch_size,patch_size),step=4) # split image into 7x7 small 4x4 patches.patchify(image, (3, 3), step=1)
        for i in range(patching.shape[0]):
            for j in range (patching.shape[1]):
                single_patch=patching[i,j,:,:]
                total_patches[idx,n_patch,:]=single_patch.flatten()
                n_patch=n_patch+1
                if n_patch==49:
                    n_patch=0
    return torch.tensor(total_patches)

def get_positional_embeddings(sequence_length, d):
    result = torch.ones(sequence_length, d)
    for i in range(sequence_length):
        for j in range(d):
            result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / d)))
    return result

In [25]:
#IMPORT DATA
# 
transform = ToTensor()
train_set = MNIST(root='./../datasets', train=True, download=True, transform=transform)
test_set = MNIST(root='./../datasets', train=False, download=False, transform=transform)

train_loader = DataLoader(train_set, shuffle=True, batch_size=128)
test_loader = DataLoader(test_set, shuffle=False, batch_size=128)

In [57]:

#DATA TO PATCH
for batch in train_loader:
    images, targets = batch 

#plt.imshow(images[0].squeeze())      #1x28x28--> 28x28
patches=patching_func(images,4)

print(patches.shape)
print(patches.dtype)
#print(images.shape)
#print(patching_func(images,4).shape)

torch.Size([96, 49, 16])
torch.float64


In [79]:
# CLASSES VIT

class ViT (nn.Module):
    def __init__(self, images_shape, patch_size=4, t_blocks=2, token_dim=8, n_heads=2, output_dim=10, mlp_layer_size=8):
        super().__init__()
        
        self.c,self.h_image,self.w_image=images_shape
        
        self.patch_size = patch_size
        self.t_blocks   = t_blocks
        self.n_heads    = n_heads
        self.token_dim  = token_dim
        self.output_dim = output_dim
        self.mlp_layer_size     = mlp_layer_size
        self.token_length       = int(self.h_image/self.patch_size**2)

        self.linear_map         =   nn.Linear(patch_size**2,token_dim)
        self.class_token        =   nn.Parameter(torch.rand(1, token_dim))
        self.blocks             =   nn.ModuleList([ViTBlock(token_dim, mlp_layer_size,n_heads) for _ in range(t_blocks)])
        self.linear_classifier  =   nn.Linear(token_dim, output_dim)
        self.output_pr          =   nn.Softmax()
        

    def forward(self, images):
        

        self.n_images,self.c,self.h_image,self.w_image = images.shape
        
        all_class_token = np.zeros((self.n_images,self.token_length,self.token_dim))
        all_class_token[self.n_images,:,:]= self.class_token
        
        patches     = patching_func(images, self.patch_size)
        linear_emb  = self.linear_map(patches)

        tokens      = torch.cat((all_class_token,linear_emb),dim=1)
        out         = tokens # positional embeddings will be added

        for block in self.blocks:
            out = block(out)
            
        out = out[:, 0, :]
        out = self.output_pr(self.linear_classifier(out))
        return out 


class ViTBlock(nn.Module):

    def __init__(self, token_dim, mlp_layer_size=8, num_heads=2):
        super().__init__() 
        self.token_dim      = token_dim
        self.num_heads      = num_heads
        self.mlp_layer_size = mlp_layer_size

        self.layer_norms    = nn.ModuleList([nn.LayerNorm(token_dim) for _ in range(num_heads)])    
        self.msa            = MSA_Module(token_dim, num_heads)
        self.act_layer      = nn.GELU()
        
        self.mlp            = nn.Sequential(
            nn.Linear(token_dim, mlp_layer_size), self.act_layer,
            nn.Linear(mlp_layer_size, token_dim)
            )

    def forward(self, x):
        out = x + self.msa(self.layer_norms[0](x))
        out = self.layer_norms[1](out) + self.mlp(self.layer_norms[1](out))
        return out
        
class MSA_Module(nn.Module):
    def __init__(self, token_dim, n_heads=2):
        super().__init__() 
        self.n_heads    = n_heads
        self.token_dim  = token_dim

        self.q_layers   = nn.ModuleList([nn.Linear(token_dim,token_dim) for _ in range(n_heads)])
        self.k_layers   = nn.ModuleList([nn.Linear(token_dim,token_dim) for _ in range(n_heads)])
        self.v_layers   = nn.ModuleList([nn.Linear(token_dim,token_dim) for _ in range(n_heads)])
        self.softmax    = nn.Softmax()

    def forward (self, tokens):
        
        self.n,self.number_tokens,self.patch_size = tokens.shape

        result = torch.zeros(self.n,self.number_tokens,self.patch_size)
        
        for idx,token in enumerate(tokens):   # 128 batch. each of 50x8, token size : 50x8   --> 50x8
            
            concat      = torch.zeros(self.n_heads,self.number_tokens,self.patch_size)

            for head in range(self.n_heads):        # number of heads : 2
                q_linear = self.q_layers[head]      # linear (8x8)  == 50x8 --> 50x8
                k_linear = self.k_layers[head]
                v_linear = self.v_layers[head]

                q  = q_linear(token)
                k  = k_linear(token)
                v  = v_linear(token)

                mat_mul = (torch.matmul(q, k.T)) / ((self.number_tokens-1)**0.5)   # 50x8 x 8x50 = 50x50 
                attention_mask  = self.softmax(mat_mul)
                attention       = torch.matmul(attention_mask,v)
                concat[head,:,:]= attention
            result[idx,:,:] = torch.flatten(input=concat, start_dim=0, end_dim=1)


In [82]:
def main():
    # Loading data
    transform = ToTensor()

    train_set = MNIST(root='./../datasets', train=True, download=True, transform=transform)
    test_set = MNIST(root='./../datasets', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_set, shuffle=True, batch_size=128)
    test_loader = DataLoader(test_set, shuffle=False, batch_size=128)

    model = ViT((1, 28, 28), patch_size=4, t_blocks=2, token_dim=64, n_heads=2, output_dim=10)
    N_EPOCHS = 5
    LR = 0.005
    optimizer = Adam(model.parameters(), lr=LR)
    criterion = CrossEntropyLoss()
    for epoch in trange(N_EPOCHS, desc="Training"):
        train_loss = 0.0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
            x, y = batch
            print(x.shape)
            y_hat = model(x)
            
            loss = criterion(y_hat, y)

            train_loss += loss.detach().cpu().item() / len(train_loader)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}")

    # Test loop
    with torch.no_grad():
        correct, total = 0, 0
        test_loss = 0.0
        for batch in tqdm(test_loader, desc="Testing"):
            x, y = batch
            
            y_hat = model(x)
            loss = criterion(y_hat, y)
            test_loss += loss.detach().cpu().item() / len(test_loader)

            correct += torch.sum(torch.argmax(y_hat, dim=1) == y).detach().cpu().item()
            total += len(x)
        print(f"Test loss: {test_loss:.2f}")
        print(f"Test accuracy: {correct / total * 100:.2f}%")


In [83]:

if __name__ == '__main__':
    main()

Training:   0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([128, 1, 28, 28])





IndexError: index 128 is out of bounds for axis 0 with size 128

In [70]:

# Linear Embeddings
token_dim=8
patch_size=4

class_token = nn.Parameter(torch.rand(1, token_dim))
linear_map=nn.Linear(patch_size**2,token_dim)

#print(class_token.shape)
#print(patches.dtype)
#print(patches.shape)
linear_emb = linear_map(patches.float())
#print(linear_emb.shape)
#print(class_token.shape)
        
n_images,c,h_image,w_image= images.shape

token_size=int(h_image/patch_size**2)

all_class_token = torch.zeros((n_images,token_size,token_dim))
tokens      = torch.cat((all_class_token,linear_emb),dim=1)

# Layer Norm 
token_dim=8
layer_norms = nn.ModuleList([nn.LayerNorm(token_dim) for _ in range(2)])
print(layer_norms[0](tokens).shape)


# Multi Self Attention 
n,number_tokens,patch_size = tokens.shape
result = torch.zeros(n,number_tokens,patch_size)
print('result: ',result.shape)

n_heads    = 2

q_layers   = nn.ModuleList([nn.Linear(token_dim,token_dim) for _ in range(n_heads)])
k_layers   = nn.ModuleList([nn.Linear(token_dim,token_dim) for _ in range(n_heads)])
v_layers   = nn.ModuleList([nn.Linear(token_dim,token_dim) for _ in range(n_heads)])
softmax    = nn.Softmax()

n,number_tokens,patch_size = tokens.shape
result = torch.zeros(n,number_tokens*n_heads,patch_size)

for idx,token in enumerate(tokens):   # 128 batch. each of 50x8, token size : 50x8   --> 50x8

    concat      = torch.zeros(n_heads,number_tokens,patch_size)

    for head in range(n_heads):        # number of heads : 2
        q_linear = q_layers[head]      # linear (8x8)  == 50x8 --> 50x8
        k_linear = k_layers[head]
        v_linear = v_layers[head]

        q  = q_linear(token)
        k  = k_linear(token)
        v  = v_linear(token)

        mat_mul = (torch.matmul(q, k.T)) / ((number_tokens-1)**0.5)   # 50x8 x 8x50 = 50x50 
        attention_mask  = softmax(mat_mul)
        attention       = torch.matmul(attention_mask,v)
        concat[head,:,:]= attention
    result[idx,:,:] = torch.flatten(input=concat, start_dim=0, end_dim=1)


        

torch.Size([96, 1, 8])
torch.Size([96, 49, 8])
torch.Size([96, 50, 8])
result:  torch.Size([96, 50, 8])


  attention_mask  = softmax(mat_mul)
