In [53]:
##IMPORT 

import numpy as np
from tqdm import tqdm, trange
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.datasets.mnist import MNIST
import matplotlib.pyplot as plt
from patchify import patchify


In [54]:
##PATCHFY FUNCTION & POSITIONAL_EMB

def patching_func(images,patch_size):

    n,c,h_iamge,w_image=images.shape
    n_patch=0
    n_patch_per_img = int((h_iamge/patch_size)**2)
    patch_vector=int(patch_size**2)
    images=images.squeeze()

    total_patches = np.zeros((n,n_patch_per_img,patch_vector))
 
    for idx,image in enumerate(images):
        patching = patchify(image, (patch_size,patch_size),step=4) # split image into 7x7 small 4x4 patches.patchify(image, (3, 3), step=1)
        for i in range(patching.shape[0]):
            for j in range (patching.shape[1]):
                single_patch=patching[i,j,:,:]
                total_patches[idx,n_patch,:]=single_patch.flatten()
                n_patch=n_patch+1
                if n_patch==49:
                    n_patch=0
    return total_patches

def get_positional_embeddings(sequence_length, d):
    result = torch.ones(sequence_length, d)
    for i in range(sequence_length):
        for j in range(d):
            result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / d)))
    return result

In [55]:
#IMPORT DATA
# 
transform = ToTensor()
train_set = MNIST(root='./../datasets', train=True, download=True, transform=transform)
test_set = MNIST(root='./../datasets', train=False, download=False, transform=transform)

train_loader = DataLoader(train_set, shuffle=True, batch_size=128)
test_loader = DataLoader(test_set, shuffle=False, batch_size=128)



In [58]:
#DATA TO PATCH
for batch in train_loader:
    images, targets = batch 

#plt.imshow(images[0].squeeze())      #1x28x28--> 28x28
images=images.numpy()

print(images.shape)
print(patching_func(images,4).shape)
patches=torch.tensor(patching_func(images,4))


(96, 1, 28, 28)
(96, 49, 16)


In [None]:
x=torch.rand((4,3,5))
print(x)
print(x[:,0])
print(x[:,0,:])

In [None]:
# CLASSES VIT

class ViT (nn.module):
    def __init__(self, images_shape, patch_size=4, t_blocks=2, token_dim=64, n_heads=2, output_dim=10):
        super().__init__()
        
        n,c,h_image,w_image=images_shape
        self.patch_size = patch_size
        self.t_blocks   = t_blocks
        self.n_heads    = n_heads
        self.token_dim  = token_dim
        self.output_dim = output_dim

        linear_map=nn.Linear(self.patch_size**2,token_dim)
        self.class_token = nn.Parameter(torch.rand(1, self.token_dim))
        self.blocks = nn.ModuleList([ViTBlock(token_dim, n_heads) for _ in range(t_blocks)])

        self.linear_classifier  = nn.Linear(self.token_dim, output_dim),
        output_pr = nn.Softmax(dim=-1)

    def forward(self, images):

        patches = patching_func(images, self.patc_size)

        for block in self.blocks:
            out = block(out)
            
        out = out[:, 0, :]
        out = self.output_pr(self.linear_classifier(out))
        return out 

class ViTBlock(nn.Module):

    def __init__(self, token_dim, num_heads, mlp_ratio):
        super().__init__()
        
        self.mlp_ratio=mlp_ratio
        self.norm1 = norm_layer(token_dim)
        self.msa = MSA_Module(token_dim, num_heads)
        self.norm2 = norm_layer(token_dim)
        self.mlp   = nn.Sequential(
            nn.Linear(token_dim, self.mlp_ratio * token_dim),
            nn.GELU(),
            nn.Linear(self.mlp_ratio * token_dim, token_dim)
            )

        act_layer=nn.GELU()
        norm_layer=nn.LayerNorm()

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x