In [14]:
import numpy as np

from tqdm import tqdm, trange

import math
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader

from torchvision.transforms import ToTensor
from torchvision.datasets.mnist import MNIST

np.random.seed(0)
torch.manual_seed(0)

def main():
    # Loading data
    transform = ToTensor()

    train_set = MNIST(root='/content', train=True, download=True, transform=transform)
    test_set = MNIST(root='/content', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_set, shuffle=True, batch_size=128)
    test_loader = DataLoader(test_set, shuffle=False, batch_size=128)

    # Defining model and training options
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")
    model = MyViT((1, 28, 28), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10).to(device)
    N_EPOCHS = 5
    LR = 0.005

    # Training loop
    optimizer = Adam(model.parameters(), lr=LR)
    criterion = CrossEntropyLoss()
    for epoch in trange(N_EPOCHS, desc="Training"):
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)

            train_loss += loss.detach().cpu().item() / len(train_loader)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}")

    # Test loop
    with torch.no_grad():
        correct, total = 0, 0
        test_loss = 0.0
        for batch in tqdm(test_loader, desc="Testing"):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)
            test_loss += loss.detach().cpu().item() / len(test_loader)

            correct += torch.sum(torch.argmax(y_hat, dim=1) == y).detach().cpu().item()
            total += len(x)
        print(f"Test loss: {test_loss:.2f}")
        print(f"Test accuracy: {correct / total * 100:.2f}%")

In [11]:
# Vision Transformer (ViT) Model
class MyViT(nn.Module):
    def __init__(self, chw, n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10):
        super(MyViT, self).__init__()
        self.chw = chw  # (C, H, W)
        self.n_patches = n_patches
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.hidden_d = hidden_d

        # Input and patches sizes
        assert chw[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        assert chw[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        self.patch_size = (chw[1] // n_patches, chw[2] // n_patches)

        # 1) Linear mapper
        self.input_d = int(chw[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

        # 2) Learnable classification token
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

        # 3) Positional embedding
        self.register_buffer('positional_embeddings', self.get_positional_embeddings(n_patches ** 2 + 1, hidden_d), persistent=False)

        # 4) Transformer encoder blocks
        self.blocks = nn.ModuleList([MyViTBlock(hidden_d, n_heads) for _ in range(n_blocks)])

        # 5) Classification MLP
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_d, out_d),
            nn.Softmax(dim=-1)
        )

    def forward(self, images):
        n, c, h, w = images.shape
        patches = self.patchify(images)
        tokens = self.linear_mapper(patches)

        # Add classification token to the tokens
        tokens = torch.cat((self.class_token.expand(n, 1, -1), tokens), dim=1)

        # Add positional embedding
        out = tokens + self.positional_embeddings.repeat(n, 1, 1)

        # Transformer blocks
        for block in self.blocks:
            out = block(out)

        # Get the classification token only
        out = out[:, 0]
        return self.mlp(out)

    def patchify(self, images):
        # Dividing images into patches
        n, c, h, w = images.shape
        patches = images.unfold(2, self.patch_size[0], self.patch_size[0]).unfold(3, self.patch_size[1], self.patch_size[1])
        patches = patches.contiguous().view(n, c, -1, self.patch_size[0] * self.patch_size[1])
        patches = patches.permute(0, 2, 1, 3).reshape(n, -1, self.input_d)
        return patches

    def get_positional_embeddings(self, seq_len, hidden_d):
        pe = torch.zeros(seq_len, hidden_d)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, hidden_d, 2).float() * -(math.log(10000.0) / hidden_d))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe

# Vision Transformer Block
class MyViTBlock(nn.Module):
    def __init__(self, hidden_d, n_heads, mlp_ratio=4):
        super(MyViTBlock, self).__init__()
        self.hidden_d = hidden_d
        self.n_heads = n_heads

        self.norm1 = nn.LayerNorm(hidden_d)
        self.mhsa = MyMSA(hidden_d, n_heads)
        self.norm2 = nn.LayerNorm(hidden_d)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d, mlp_ratio * hidden_d),
            nn.GELU(),
            nn.Linear(mlp_ratio * hidden_d, hidden_d)
        )

    def forward(self, x):
        out = x + self.mhsa(self.norm1(x))
        out = out + self.mlp(self.norm2(out))
        return out

# Multi-Head Self Attention
class MyMSA(nn.Module):
    def __init__(self, d, n_heads=2):
        super(MyMSA, self).__init__()
        self.d = d
        self.n_heads = n_heads
        assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"
        self.d_head = d // n_heads

        self.q_mappings = nn.ModuleList([nn.Linear(self.d_head, self.d_head) for _ in range(self.n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(self.d_head, self.d_head) for _ in range(self.n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(self.d_head, self.d_head) for _ in range(self.n_heads)])
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        result = []
        for sequence in sequences:
            seq_result = []
            for head in range(self.n_heads):
                q_mapping = self.q_mappings[head]
                k_mapping = self.k_mappings[head]
                v_mapping = self.v_mappings[head]

                seq = sequence[:, head * self.d_head: (head + 1) * self.d_head]
                q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)

                attention = self.softmax(q @ k.T / (self.d_head ** 0.5))
                seq_result.append(attention @ v)
            result.append(torch.hstack(seq_result))
        return torch.cat([torch.unsqueeze(r, dim=0) for r in result])


In [15]:
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST
from tqdm import tqdm, trange

# Main function for training and testing
def main():
    # Loading data
    transform = ToTensor()

    train_set = MNIST(root='/content', train=True, download=True, transform=transform)
    test_set = MNIST(root='/content', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_set, shuffle=True, batch_size=128)
    test_loader = DataLoader(test_set, shuffle=False, batch_size=128)

    # Define model and training options
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")

    # Initialize the ViT model with the correct input dimensions and parameters
    model = MyViT((1, 28, 28), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10).to(device)

    # Hyperparameters
    N_EPOCHS = 5
    LR = 0.005

    # Loss function and optimizer
    optimizer = Adam(model.parameters(), lr=LR)
    criterion = CrossEntropyLoss()

    # Training loop
    for epoch in trange(N_EPOCHS, desc="Training"):
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
            x, y = batch
            x, y = x.to(device), y.to(device)

            # Forward pass
            y_hat = model(x)
            loss = criterion(y_hat, y)

            # Track training loss
            train_loss += loss.detach().cpu().item() / len(train_loader)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch + 1}/{N_EPOCHS} - Training loss: {train_loss:.2f}")

    # Testing loop
    with torch.no_grad():
        correct, total = 0, 0
        test_loss = 0.0
        for batch in tqdm(test_loader, desc="Testing"):
            x, y = batch
            x, y = x.to(device), y.to(device)

            # Forward pass for test data
            y_hat = model(x)
            loss = criterion(y_hat, y)
            test_loss += loss.detach().cpu().item() / len(test_loader)

            # Track correct predictions
            correct += torch.sum(torch.argmax(y_hat, dim=1) == y).detach().cpu().item()
            total += len(x)

        print(f"Test loss: {test_loss:.2f}")
        print(f"Test accuracy: {correct / total * 100:.2f}%")

# Call main function to train and evaluate
if __name__ == '__main__':
    main()


Using device:  cpu 


Training:   0%|          | 0/5 [00:00<?, ?it/s]
Epoch 1 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 1 in training:   0%|          | 1/469 [00:00<02:51,  2.73it/s][A
Epoch 1 in training:   0%|          | 2/469 [00:00<02:50,  2.74it/s][A
Epoch 1 in training:   1%|          | 3/469 [00:01<03:09,  2.45it/s][A
Epoch 1 in training:   1%|          | 4/469 [00:01<03:13,  2.40it/s][A
Epoch 1 in training:   1%|          | 5/469 [00:02<03:13,  2.40it/s][A
Epoch 1 in training:   1%|▏         | 6/469 [00:02<03:17,  2.35it/s][A
Epoch 1 in training:   1%|▏         | 7/469 [00:02<03:27,  2.23it/s][A
Epoch 1 in training:   2%|▏         | 8/469 [00:03<03:23,  2.27it/s][A
Epoch 1 in training:   2%|▏         | 9/469 [00:03<03:05,  2.47it/s][A
Epoch 1 in training:   2%|▏         | 10/469 [00:04<02:51,  2.67it/s][A
Epoch 1 in training:   2%|▏         | 11/469 [00:04<02:49,  2.69it/s][A
Epoch 1 in training:   3%|▎         | 12/469 [00:04<02:41,  2.82it/s][A
Epoch 1 in training: 

Epoch 1/5 - Training loss: 2.11



Epoch 2 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 2 in training:   0%|          | 1/469 [00:00<02:32,  3.07it/s][A
Epoch 2 in training:   0%|          | 2/469 [00:00<02:28,  3.15it/s][A
Epoch 2 in training:   1%|          | 3/469 [00:01<02:38,  2.94it/s][A
Epoch 2 in training:   1%|          | 4/469 [00:01<02:38,  2.93it/s][A
Epoch 2 in training:   1%|          | 5/469 [00:01<02:34,  3.01it/s][A
Epoch 2 in training:   1%|▏         | 6/469 [00:01<02:32,  3.03it/s][A
Epoch 2 in training:   1%|▏         | 7/469 [00:02<02:28,  3.10it/s][A
Epoch 2 in training:   2%|▏         | 8/469 [00:02<02:29,  3.09it/s][A
Epoch 2 in training:   2%|▏         | 9/469 [00:02<02:30,  3.06it/s][A
Epoch 2 in training:   2%|▏         | 10/469 [00:03<02:29,  3.07it/s][A
Epoch 2 in training:   2%|▏         | 11/469 [00:03<02:32,  3.00it/s][A
Epoch 2 in training:   3%|▎         | 12/469 [00:03<02:32,  3.00it/s][A
Epoch 2 in training:   3%|▎         | 13/469 [00:04<02:29,  3.06it/s

Epoch 2/5 - Training loss: 1.83



Epoch 3 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 3 in training:   0%|          | 1/469 [00:00<02:34,  3.02it/s][A
Epoch 3 in training:   0%|          | 2/469 [00:00<02:49,  2.75it/s][A
Epoch 3 in training:   1%|          | 3/469 [00:01<03:17,  2.36it/s][A
Epoch 3 in training:   1%|          | 4/469 [00:01<03:40,  2.11it/s][A
Epoch 3 in training:   1%|          | 5/469 [00:02<03:14,  2.39it/s][A
Epoch 3 in training:   1%|▏         | 6/469 [00:02<03:01,  2.54it/s][A
Epoch 3 in training:   1%|▏         | 7/469 [00:02<02:52,  2.69it/s][A
Epoch 3 in training:   2%|▏         | 8/469 [00:03<02:42,  2.83it/s][A
Epoch 3 in training:   2%|▏         | 9/469 [00:03<02:37,  2.93it/s][A
Epoch 3 in training:   2%|▏         | 10/469 [00:03<02:32,  3.01it/s][A
Epoch 3 in training:   2%|▏         | 11/469 [00:04<02:40,  2.86it/s][A
Epoch 3 in training:   3%|▎         | 12/469 [00:04<02:45,  2.77it/s][A
Epoch 3 in training:   3%|▎         | 13/469 [00:04<02:59,  2.54it/s

Epoch 3/5 - Training loss: 1.75



Epoch 4 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 4 in training:   0%|          | 1/469 [00:00<03:45,  2.07it/s][A
Epoch 4 in training:   0%|          | 2/469 [00:00<03:29,  2.23it/s][A
Epoch 4 in training:   1%|          | 3/469 [00:01<03:39,  2.12it/s][A
Epoch 4 in training:   1%|          | 4/469 [00:01<03:53,  1.99it/s][A
Epoch 4 in training:   1%|          | 5/469 [00:02<03:50,  2.02it/s][A
Epoch 4 in training:   1%|▏         | 6/469 [00:02<03:30,  2.20it/s][A
Epoch 4 in training:   1%|▏         | 7/469 [00:03<03:10,  2.43it/s][A
Epoch 4 in training:   2%|▏         | 8/469 [00:03<02:56,  2.61it/s][A
Epoch 4 in training:   2%|▏         | 9/469 [00:03<02:55,  2.63it/s][A
Epoch 4 in training:   2%|▏         | 10/469 [00:04<02:58,  2.57it/s][A
Epoch 4 in training:   2%|▏         | 11/469 [00:04<03:26,  2.22it/s][A
Epoch 4 in training:   3%|▎         | 12/469 [00:05<03:24,  2.23it/s][A
Epoch 4 in training:   3%|▎         | 13/469 [00:05<03:05,  2.46it/s

Epoch 4/5 - Training loss: 1.73



Epoch 5 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 5 in training:   0%|          | 1/469 [00:00<02:35,  3.01it/s][A
Epoch 5 in training:   0%|          | 2/469 [00:00<02:38,  2.95it/s][A
Epoch 5 in training:   1%|          | 3/469 [00:01<02:39,  2.93it/s][A
Epoch 5 in training:   1%|          | 4/469 [00:01<02:34,  3.01it/s][A
Epoch 5 in training:   1%|          | 5/469 [00:01<02:32,  3.04it/s][A
Epoch 5 in training:   1%|▏         | 6/469 [00:02<02:33,  3.01it/s][A
Epoch 5 in training:   1%|▏         | 7/469 [00:02<02:32,  3.03it/s][A
Epoch 5 in training:   2%|▏         | 8/469 [00:02<02:32,  3.01it/s][A
Epoch 5 in training:   2%|▏         | 9/469 [00:02<02:31,  3.03it/s][A
Epoch 5 in training:   2%|▏         | 10/469 [00:03<02:31,  3.03it/s][A
Epoch 5 in training:   2%|▏         | 11/469 [00:03<02:29,  3.06it/s][A
Epoch 5 in training:   3%|▎         | 12/469 [00:03<02:30,  3.03it/s][A
Epoch 5 in training:   3%|▎         | 13/469 [00:04<02:32,  2.99it/s

Epoch 5/5 - Training loss: 1.71


Testing: 100%|██████████| 79/79 [00:08<00:00,  8.81it/s]

Test loss: 1.70
Test accuracy: 76.29%



