In [1]:
#orig: https://github.com/BrianPulfer/PapersReimplementations/blob/master/vit/vit_torch.py

import matplotlib.pyplot as plt
import numpy as np

from tqdm import tqdm

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader

from torchvision.datasets.mnist import MNIST
from torchvision.transforms import ToTensor

np.random.seed(0)
torch.manual_seed(0)


def patchify(images, n_patches):
    n, c, h, w = images.shape

    assert h == w, "Patchify method is implemented for square images only"

    patches = torch.zeros(n, n_patches ** 2, h * w // n_patches ** 2)
    patch_size = h // n_patches

    for idx, image in enumerate(images):
        for i in range(h // n_patches):
            for j in range(w // n_patches):
                patch = image[:, i * patch_size: (i + 1) * patch_size, j * patch_size: (j + 1) * patch_size]
                patches[idx, i * n_patches + j] = patch.flatten()
    return patches


class MyViT(nn.Module):
    def __init__(self, input_shape, n_patches=7, hidden_d=8, n_heads=2, out_d=10, device=None):
        # Super constructor
        super(MyViT, self).__init__()
        self.device = device

        # Input and patches sizes
        self.input_shape = input_shape
        self.n_patches = n_patches
        self.n_heads = n_heads
        assert input_shape[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        assert input_shape[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        self.patch_size = (input_shape[1] / n_patches, input_shape[2] / n_patches)
        self.hidden_d = hidden_d

        # 1) Linear mapper
        self.input_d = int(input_shape[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

        # 2) Classification token
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

        # 3) Positional embedding
        # (In forward method)

        # 4a) Layer normalization 1
        self.ln1 = nn.LayerNorm((self.n_patches ** 2 + 1, self.hidden_d))

        # 4b) Multi-head Self Attention (MSA) and classification token
        self.msa = MyMSA(self.hidden_d, n_heads)

        # 5a) Layer normalization 2
        self.ln2 = nn.LayerNorm((self.n_patches ** 2 + 1, self.hidden_d))

        # 5b) Encoder MLP
        self.enc_mlp = nn.Sequential(
            nn.Linear(self.hidden_d, self.hidden_d),
            nn.ReLU()
        )

        # 6) Classification MLP
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_d, out_d),
            nn.Softmax(dim=-1)
        )

    def forward(self, images):
        # Dividing images into patches
        n, c, w, h = images.shape
        patches = patchify(images, self.n_patches)

        # Running linear layer for tokenization
        tokens = self.linear_mapper(patches)

        # Adding classification token to the tokens
        tokens = torch.stack([torch.vstack((self.class_token, tokens[i])) for i in range(len(tokens))])

        # Adding positional embedding
        tokens += get_positional_embeddings(self.n_patches ** 2 + 1, self.hidden_d).repeat(n, 1, 1).to(self.device)

        # TRANSFORMER ENCODER BEGINS ###################################
        # NOTICE: MULTIPLE ENCODER BLOCKS CAN BE STACKED TOGETHER ######
        # Running Layer Normalization, MSA and residual connection
        out = tokens + self.msa(self.ln1(tokens))

        # Running Layer Normalization, MLP and residual connection
        out = out + self.enc_mlp(self.ln2(out))
        # TRANSFORMER ENCODER ENDS   ###################################

        # Getting the classification token only
        out = out[:, 0]

        return self.mlp(out)


class MyMSA(nn.Module):
    def __init__(self, d, n_heads=2):
        super(MyMSA, self).__init__()
        self.d = d
        self.n_heads = n_heads

        assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"

        d_head = int(d / n_heads)
        self.q_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        # Sequences has shape (N, seq_length, token_dim)
        # We go into shape    (N, seq_length, n_heads, token_dim / n_heads)
        # And come back to    (N, seq_length, item_dim)  (through concatenation)
        result = []
        for sequence in sequences:
            seq_result = []
            for head in range(self.n_heads):
                q_mapping = self.q_mappings[head]
                k_mapping = self.k_mappings[head]
                v_mapping = self.v_mappings[head]

                seq = sequence[:, head * self.d_head: (head + 1) * self.d_head]
                q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)

                attention = self.softmax(q @ k.T / (self.d_head ** 0.5))
                seq_result.append(attention @ v)
            result.append(torch.hstack(seq_result))
        return torch.cat([torch.unsqueeze(r, dim=0) for r in result])


def get_positional_embeddings(sequence_length, d):
    result = torch.ones(sequence_length, d)
    for i in range(sequence_length):
        for j in range(d):
            result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / d)))
    return result

In [2]:
def main():
    # Loading data
    transform = ToTensor()

    train_set = MNIST(root='./../datasets', train=True, download=True, transform=transform)
    test_set = MNIST(root='./../datasets', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_set, shuffle=True, batch_size=16)
    test_loader = DataLoader(test_set, shuffle=False, batch_size=16)

    # Defining model and training options
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MyViT((1, 28, 28), n_patches=7, hidden_d=20, n_heads=2, out_d=10, device=device).to(device)
    N_EPOCHS = 5
    LR = 0.01

    # Training loop
    optimizer = Adam(model.parameters(), lr=LR)
    criterion = CrossEntropyLoss()
    for epoch in tqdm(range(N_EPOCHS), desc="Training"):
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1}", leave=False):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y) / len(x)

            train_loss += loss.detach().cpu().item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}")

    # Test loop
    correct, total = 0, 0
    test_loss = 0.0
    for batch in tqdm(test_loader, desc="Testing"):
        x, y = batch
        x, y = x.to(device), y.to(device)
        y_hat = model(x)
        loss = criterion(y_hat, y) / len(x)
        test_loss += loss.detach().cpu().item()

        correct += torch.sum(torch.argmax(y_hat, dim=1) == y).detach().cpu().item()
        total += len(x)
    print(f"Test loss: {test_loss:.2f}")
    print(f"Test accuracy: {correct / total * 100:.2f}%")


if __name__ == '__main__':
    main()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./../datasets/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./../datasets/MNIST/raw/train-images-idx3-ubyte.gz to ./../datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./../datasets/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./../datasets/MNIST/raw/train-labels-idx1-ubyte.gz to ./../datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./../datasets/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./../datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to ./../datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./../datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./../datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./../datasets/MNIST/raw



Training:   0%|          | 0/5 [00:00<?, ?it/s]
Epoch 1:   0%|          | 0/3750 [00:00<?, ?it/s][A
Epoch 1:   0%|          | 1/3750 [00:00<12:02,  5.19it/s][A
Epoch 1:   0%|          | 4/3750 [00:00<04:06, 15.20it/s][A
Epoch 1:   0%|          | 7/3750 [00:00<03:10, 19.68it/s][A
Epoch 1:   0%|          | 10/3750 [00:00<02:51, 21.78it/s][A
Epoch 1:   0%|          | 13/3750 [00:00<02:42, 22.96it/s][A
Epoch 1:   0%|          | 16/3750 [00:00<02:32, 24.54it/s][A
Epoch 1:   1%|          | 19/3750 [00:00<02:26, 25.55it/s][A
Epoch 1:   1%|          | 22/3750 [00:00<02:26, 25.39it/s][A
Epoch 1:   1%|          | 25/3750 [00:01<02:24, 25.74it/s][A
Epoch 1:   1%|          | 28/3750 [00:01<02:22, 26.21it/s][A
Epoch 1:   1%|          | 31/3750 [00:01<02:22, 26.13it/s][A
Epoch 1:   1%|          | 34/3750 [00:01<02:23, 25.81it/s][A
Epoch 1:   1%|          | 37/3750 [00:01<02:24, 25.61it/s][A
Epoch 1:   1%|          | 40/3750 [00:01<02:22, 26.01it/s][A
Epoch 1:   1%|          | 43/3750 

Epoch 1/5 loss: 464.71



Epoch 2:   0%|          | 0/3750 [00:00<?, ?it/s][A
Epoch 2:   0%|          | 1/3750 [00:00<06:31,  9.59it/s][A
Epoch 2:   0%|          | 3/3750 [00:00<04:55, 12.67it/s][A
Epoch 2:   0%|          | 5/3750 [00:00<04:34, 13.64it/s][A
Epoch 2:   0%|          | 7/3750 [00:00<04:22, 14.23it/s][A
Epoch 2:   0%|          | 9/3750 [00:00<04:36, 13.53it/s][A
Epoch 2:   0%|          | 11/3750 [00:00<04:23, 14.18it/s][A
Epoch 2:   0%|          | 13/3750 [00:00<04:24, 14.15it/s][A
Epoch 2:   0%|          | 15/3750 [00:01<04:19, 14.38it/s][A
Epoch 2:   0%|          | 17/3750 [00:01<04:21, 14.30it/s][A
Epoch 2:   1%|          | 19/3750 [00:01<04:12, 14.78it/s][A
Epoch 2:   1%|          | 21/3750 [00:01<04:05, 15.16it/s][A
Epoch 2:   1%|          | 23/3750 [00:01<04:07, 15.05it/s][A
Epoch 2:   1%|          | 25/3750 [00:01<04:14, 14.65it/s][A
Epoch 2:   1%|          | 27/3750 [00:01<04:01, 15.42it/s][A
Epoch 2:   1%|          | 30/3750 [00:02<03:32, 17.50it/s][A
Epoch 2:   1%|       

Epoch 2/5 loss: 479.39



Epoch 3:   0%|          | 0/3750 [00:00<?, ?it/s][A
Epoch 3:   0%|          | 2/3750 [00:00<04:14, 14.72it/s][A
Epoch 3:   0%|          | 5/3750 [00:00<03:12, 19.49it/s][A
Epoch 3:   0%|          | 8/3750 [00:00<02:54, 21.41it/s][A
Epoch 3:   0%|          | 11/3750 [00:00<02:53, 21.49it/s][A
Epoch 3:   0%|          | 14/3750 [00:00<02:50, 21.90it/s][A
Epoch 3:   0%|          | 17/3750 [00:00<02:47, 22.33it/s][A
Epoch 3:   1%|          | 20/3750 [00:00<02:45, 22.51it/s][A
Epoch 3:   1%|          | 23/3750 [00:01<02:49, 21.96it/s][A
Epoch 3:   1%|          | 26/3750 [00:01<02:45, 22.52it/s][A
Epoch 3:   1%|          | 29/3750 [00:01<02:45, 22.47it/s][A
Epoch 3:   1%|          | 32/3750 [00:01<02:43, 22.81it/s][A
Epoch 3:   1%|          | 35/3750 [00:01<02:45, 22.38it/s][A
Epoch 3:   1%|          | 38/3750 [00:01<02:41, 23.01it/s][A
Epoch 3:   1%|          | 41/3750 [00:01<02:39, 23.19it/s][A
Epoch 3:   1%|          | 44/3750 [00:01<02:40, 23.13it/s][A
Epoch 3:   1%|▏    

Epoch 3/5 loss: 524.14



Epoch 4:   0%|          | 0/3750 [00:00<?, ?it/s][A
Epoch 4:   0%|          | 2/3750 [00:00<03:42, 16.85it/s][A
Epoch 4:   0%|          | 5/3750 [00:00<03:05, 20.18it/s][A
Epoch 4:   0%|          | 8/3750 [00:00<02:55, 21.35it/s][A
Epoch 4:   0%|          | 11/3750 [00:00<02:48, 22.25it/s][A
Epoch 4:   0%|          | 14/3750 [00:00<02:42, 22.95it/s][A
Epoch 4:   0%|          | 17/3750 [00:00<02:40, 23.21it/s][A
Epoch 4:   1%|          | 20/3750 [00:00<02:40, 23.23it/s][A
Epoch 4:   1%|          | 23/3750 [00:01<02:39, 23.30it/s][A
Epoch 4:   1%|          | 26/3750 [00:01<02:40, 23.24it/s][A
Epoch 4:   1%|          | 29/3750 [00:01<02:39, 23.30it/s][A
Epoch 4:   1%|          | 32/3750 [00:01<02:44, 22.57it/s][A
Epoch 4:   1%|          | 35/3750 [00:01<02:42, 22.80it/s][A
Epoch 4:   1%|          | 38/3750 [00:01<02:49, 21.92it/s][A
Epoch 4:   1%|          | 41/3750 [00:01<02:55, 21.14it/s][A
Epoch 4:   1%|          | 44/3750 [00:01<02:45, 22.38it/s][A
Epoch 4:   1%|▏    

Epoch 4/5 loss: 534.27



Epoch 5:   0%|          | 0/3750 [00:00<?, ?it/s][A
Epoch 5:   0%|          | 2/3750 [00:00<03:31, 17.76it/s][A
Epoch 5:   0%|          | 5/3750 [00:00<02:50, 21.93it/s][A
Epoch 5:   0%|          | 8/3750 [00:00<02:47, 22.36it/s][A
Epoch 5:   0%|          | 11/3750 [00:00<02:39, 23.43it/s][A
Epoch 5:   0%|          | 14/3750 [00:00<02:37, 23.66it/s][A
Epoch 5:   0%|          | 17/3750 [00:00<02:35, 24.05it/s][A
Epoch 5:   1%|          | 20/3750 [00:00<02:38, 23.56it/s][A
Epoch 5:   1%|          | 23/3750 [00:00<02:39, 23.38it/s][A
Epoch 5:   1%|          | 26/3750 [00:01<02:36, 23.85it/s][A
Epoch 5:   1%|          | 29/3750 [00:01<02:34, 24.16it/s][A
Epoch 5:   1%|          | 32/3750 [00:01<02:32, 24.46it/s][A
Epoch 5:   1%|          | 35/3750 [00:01<02:31, 24.52it/s][A
Epoch 5:   1%|          | 38/3750 [00:01<02:30, 24.65it/s][A
Epoch 5:   1%|          | 41/3750 [00:01<02:30, 24.66it/s][A
Epoch 5:   1%|          | 44/3750 [00:01<02:29, 24.79it/s][A
Epoch 5:   1%|▏    

Epoch 5/5 loss: 544.59


Testing: 100%|██████████| 625/625 [00:17<00:00, 36.62it/s]


Test loss: 91.47
Test accuracy: 11.94%
