<a href="https://www.kaggle.com/code/mohamedastitou/vit-pytroch?scriptVersionId=209794644" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [None]:
pip install idx2numpy

In [None]:
import numpy as np

from tqdm import tqdm, trange
import idx2numpy
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, Subset, Dataset



In [None]:
from torchvision.transforms import ToTensor
from torchvision.datasets.mnist import MNIST

In [None]:
# Loading data

# transform = ToTensor()

# train_set = MNIST(root='/kaggle/input/', train=True, download=True, transform=transform)
# test_set = MNIST(root='/kaggle/input/', train=False, download=True, transform=transform)

In [None]:
#train_subset = Subset(train_set, indices=list(range(500)))
#test_subset = Subset(test_set, indices=list(range(100)))

In [None]:
# Load the IDX data: paths to the raw MNIST files in the Kaggle input directory.
MNIST_DIR = "/kaggle/input/mnist-dataset"
train_images_path = f"{MNIST_DIR}/train-images.idx3-ubyte"
train_labels_path = f"{MNIST_DIR}/train-labels.idx1-ubyte"
test_images_path = f"{MNIST_DIR}/t10k-images.idx3-ubyte"
test_labels_path = f"{MNIST_DIR}/t10k-labels.idx1-ubyte"

In [None]:
# Decode every IDX file into an array, in order: train images, train labels,
# test images, test labels.
train_images, train_labels, test_images, test_labels = (
    idx2numpy.convert_from_file(path)
    for path in (train_images_path, train_labels_path, test_images_path, test_labels_path)
)

In [None]:
# Sanity check on the decoded training array; the model below is built for (1, 28, 28) inputs.
print(f"{train_images.shape}")


In [None]:
# Custom map-style Dataset over the in-memory MNIST arrays.
class mnist_dataset(Dataset):
    """Dataset pairing each image with its class label.

    Args:
        images: indexable collection of images (here, the arrays decoded from
            the IDX files).
        labels: indexable collection of class labels aligned with ``images``.
        transform: optional callable applied to each image before returning it.

    Raises:
        ValueError: if ``images`` and ``labels`` have different lengths.
    """

    def __init__(self, images, labels, transform=None):
        if len(images) != len(labels):
            raise ValueError(
                f"images and labels must have the same length, got {len(images)} and {len(labels)}"
            )
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; ``label`` is a Python int."""
        image = self.images[idx]
        # Convert to a plain int: NumPy uint8 label scalars would otherwise be
        # collated into a uint8 tensor, and CrossEntropyLoss requires int64
        # class-index targets.
        label = int(self.labels[idx])
        if self.transform is not None:
            image = self.transform(image)
        return image, label

In [None]:
# Wrap train and test arrays; each dataset gets its own ToTensor image transform.
train_dataset, test_dataset = (
    mnist_dataset(images, labels, transform=ToTensor())
    for images, labels in ((train_images, train_labels), (test_images, test_labels))
)

In [None]:
# Batches of 128; only the training set is shuffled each epoch.
train_loader, test_loader = (
    DataLoader(dataset, batch_size=128, shuffle=reshuffle)
    for dataset, reshuffle in ((train_dataset, True), (test_dataset, False))
)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

In [None]:
def patchify(images, n_patches):
    """Split square images into a flattened grid of non-overlapping patches.

    Args:
        images: tensor of shape (N, C, H, W) with H == W.
        n_patches: number of patches along each side; H must be divisible by it.

    Returns:
        Tensor of shape (N, n_patches ** 2, C * P * P), where P = H // n_patches,
        in the default float dtype and on the same device as ``images``.
        Patches are ordered row-major over the grid, and each patch is
        flattened channel-first (equivalent to
        ``image[:, rows, cols].flatten()``).

    Raises:
        AssertionError: if the images are not square.
        ValueError: if H is not divisible by ``n_patches``.
    """
    n, c, h, w = images.shape

    assert h == w, "Patchify method is implemented for square images only"
    if h % n_patches != 0:
        raise ValueError(f"Image size {h} is not divisible by n_patches={n_patches}")

    patch_size = h // n_patches
    # (N, C, gy, P, gx, P) -> (N, gy, gx, C, P, P) -> (N, gy * gx, C * P * P)
    grid = images.reshape(n, c, n_patches, patch_size, n_patches, patch_size)
    patches = grid.permute(0, 2, 4, 1, 3, 5).reshape(
        n, n_patches * n_patches, c * patch_size * patch_size
    )
    # Same dtype as the original zero-filled buffer (the default float dtype).
    return patches.to(torch.get_default_dtype())

In [None]:
def get_positional_embeddings(sequence_length, d):
    """Return fixed sinusoidal positional embeddings of shape (sequence_length, d).

    Entry (i, j) is ``sin(i / 10000 ** (j / d))`` for even ``j`` and
    ``cos(i / 10000 ** ((j - 1) / d))`` for odd ``j``, so each sin/cos column
    pair shares one frequency. The result uses the default float dtype.
    """
    positions = torch.arange(sequence_length, dtype=torch.float64).unsqueeze(1)
    dims = torch.arange(d)
    # Odd column j reuses the exponent of the even column j - 1.
    exponents = (dims - dims % 2).to(torch.float64) / d
    angles = positions / torch.pow(10000.0, exponents)
    embeddings = torch.where(dims % 2 == 0, torch.sin(angles), torch.cos(angles))
    return embeddings.to(torch.get_default_dtype())

In [None]:
class MyViT(nn.Module):
    """Minimal Vision Transformer classifier.

    Each image is split into an ``n_patches x n_patches`` grid. Every patch is
    linearly projected to ``hidden_d``, a learnable classification token is
    prepended, and fixed sinusoidal positional embeddings are added. The
    sequence then passes through ``n_blocks`` transformer blocks, and the
    classification token's final state is mapped to ``out_d`` scores by an
    MLP head.

    Args:
        chw: input shape as (channels, height, width).
        n_patches: patches per side; must divide both height and width.
        n_blocks: number of MyViTBlock layers.
        hidden_d: token embedding dimension.
        n_heads: attention heads per block.
        out_d: number of output classes.
    """

    def __init__(self, chw, n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10):
        super(MyViT, self).__init__()

        # Attributes
        self.chw = chw  # (C, H, W)
        self.n_patches = n_patches
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.hidden_d = hidden_d

        assert chw[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        assert chw[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        # Patch (height, width) in pixels; exact integers thanks to the asserts above.
        self.patch_size = (chw[1] // n_patches, chw[2] // n_patches)

        # 1) Linear mapper: flattened patch -> hidden_d token.
        self.input_d = int(chw[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

        # 2) Learnable classification token, prepended to every sequence.
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

        # 3) Fixed positional embeddings, one row per token (patches + class token).
        #    A non-persistent buffer: it follows .to(device) but is not trained
        #    and not stored in the state_dict.
        self.register_buffer(
            "positional_embeddings",
            get_positional_embeddings(n_patches ** 2 + 1, hidden_d),
            persistent=False,
        )

        # 4) Transformer encoder blocks.
        self.blocks = nn.ModuleList([MyViTBlock(hidden_d, n_heads) for _ in range(n_blocks)])

        # 5) Classification MLP head. It outputs raw logits: CrossEntropyLoss
        #    applies log-softmax itself, so a trailing Softmax would normalize
        #    twice and flatten the gradients.
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_d, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, out_d),
        )

    def forward(self, images):
        """Classify a batch of images.

        Args:
            images: tensor of shape (N, C, H, W) matching ``chw``.

        Returns:
            Tensor of shape (N, out_d) with unnormalized class scores (logits).
            Apply ``torch.softmax(..., dim=-1)`` to get probabilities.
        """
        n = images.shape[0]

        patches = patchify(images, self.n_patches).to(self.positional_embeddings.device)
        tokens = self.linear_mapper(patches)

        # Prepend the classification token: (1, d) is expanded to (N, 1, d).
        tokens = torch.cat((self.class_token.expand(n, 1, -1), tokens), dim=1)

        # The (seq, d) positional embeddings broadcast over the batch dimension.
        out = tokens + self.positional_embeddings

        for block in self.blocks:
            out = block(out)

        # Keep only the classification token and map it to class scores.
        return self.mlp(out[:, 0])
    

In [None]:
class MyMSA(nn.Module):
    """Multi-head self-attention with a separate Q/K/V projection per head.

    The token dimension ``d`` is split into ``n_heads`` contiguous chunks of
    size ``d // n_heads``. Head ``h`` computes scaled dot-product attention
    using only its own chunk, and the per-head outputs are concatenated back
    to dimension ``d``.
    """

    def __init__(self, d, n_heads=2):
        super(MyMSA, self).__init__()
        self.d = d
        self.n_heads = n_heads

        assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"
        d_head = d // n_heads

        self.q_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        """Apply multi-head self-attention.

        Args:
            sequences: tensor of shape (N, seq_length, d).

        Returns:
            Tensor of shape (N, seq_length, d): the per-head outputs
            concatenated along the last dimension.
        """
        head_outputs = []
        for head in range(self.n_heads):
            # This head's slice of every token, for the whole batch at once.
            chunk = sequences[..., head * self.d_head:(head + 1) * self.d_head]
            q = self.q_mappings[head](chunk)
            k = self.k_mappings[head](chunk)
            v = self.v_mappings[head](chunk)
            scores = q @ k.transpose(-2, -1) / (self.d_head ** 0.5)
            head_outputs.append(self.softmax(scores) @ v)
        return torch.cat(head_outputs, dim=-1)

In [None]:
class MyViTBlock(nn.Module):
    """Pre-norm transformer encoder block: self-attention then an MLP, each with a residual.

    Args:
        hidden_d: token embedding dimension.
        n_heads: number of attention heads given to MyMSA.
        mlp_ratio: width of the MLP hidden layer as a multiple of ``hidden_d``.
    """

    def __init__(self, hidden_d, n_heads, mlp_ratio=4):
        super(MyViTBlock, self).__init__()
        self.hidden_d = hidden_d
        self.n_heads = n_heads

        self.norm1 = nn.LayerNorm(hidden_d)
        self.mhsa = MyMSA(hidden_d, n_heads)
        self.norm2 = nn.LayerNorm(hidden_d)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d, mlp_ratio * hidden_d),
            nn.GELU(),
            nn.Linear(mlp_ratio * hidden_d, hidden_d)
        )

    def forward(self, x):
        """Transform tokens ``x`` (same shape in and out) through attention and MLP."""
        # LayerNorm is applied before each sub-layer; the residual adds its output back.
        out = x + self.mhsa(self.norm1(x))
        out = out + self.mlp(self.norm2(out))
        # Return the transformed tokens; returning the input `x` would make the
        # whole block an identity and leave its parameters untrained.
        return out

In [None]:
# if __name__ == '__main__':
#     model = MyViTBlock(hidden_d=4, n_heads=2)

#     x = torch.randn(3, 4, 4)  # Dummy sequences
#     # print(x)
#     print(model(x).shape)      

In [None]:
# 28x28 grayscale inputs split into a 7x7 grid of 4x4 patches, 10 output classes.
model = MyViT(
    chw=(1, 28, 28),
    n_patches=7,
    n_blocks=2,
    hidden_d=16,
    n_heads=2,
    out_d=10,
)
model = model.to(device)

In [None]:
# Training hyperparameters, loss and optimizer (the loop itself is in the next cell).
N_EPOCHS = 30
LR = 0.01

criterion = CrossEntropyLoss()
optimizer = Adam(params=model.parameters(), lr=LR)

In [None]:
# Train for N_EPOCHS epochs, printing the mean training loss of each epoch.
# trange (not range) is needed for the `desc` progress-bar label.
for epoch in trange(N_EPOCHS, desc="Training"):
    model.train()
    train_loss = 0.0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        outputs = model(x)
        loss = criterion(outputs, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() already returns a detached Python float on the host.
        train_loss += loss.item()

    print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss/len(train_loader):.6f}")