# Vision Transformer (ViT)

![Structure](images/visual_transformer/visual_transformer.webp)

In [12]:
import numpy as np
from tqdm import tqdm, trange
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.datasets.mnist import MNIST
import matplotlib.pyplot as plt


In [13]:
# Seed NumPy's and PyTorch's global RNGs so weight initialisation and
# data shuffling are reproducible across runs.
np.random.seed(seed=0)
torch.manual_seed(seed=0)

<torch._C.Generator at 0x1224e1f50>

In [14]:
def patchify(images, n_patches):
    """Split a batch of images into a row-major grid of flattened patches.

    Args:
        images: tensor of shape (N, C, H, W).
        n_patches: number of patches along each spatial axis, so every image
            yields ``n_patches ** 2`` patches.

    Returns:
        Tensor of shape (N, n_patches ** 2, C * (H // n_patches) * (W // n_patches))
        in the default floating-point dtype, on the same device as ``images``.
        Patch ``i * n_patches + j`` is the block at grid row ``i`` and grid
        column ``j``, flattened in (C, patch_h, patch_w) order.

    Raises:
        ValueError: if H or W is not divisible by ``n_patches``.
    """
    n, c, h, w = images.shape
    if h % n_patches or w % n_patches:
        raise ValueError(
            f"Image size {h}x{w} is not divisible into "
            f"{n_patches}x{n_patches} patches")
    patch_h, patch_w = h // n_patches, w // n_patches

    # Split each spatial axis into (grid index, offset inside the patch):
    # (N, C, H, W) -> (N, C, rows, patch_h, cols, patch_w)
    grid = images.reshape(n, c, n_patches, patch_h, n_patches, patch_w)
    # Bring the grid axes in front of the per-patch axes:
    # -> (N, rows, cols, C, patch_h, patch_w)
    grid = grid.permute(0, 2, 4, 1, 3, 5)
    # Merge (rows, cols) into the patch index and (C, patch_h, patch_w) into
    # the feature axis; this matches the old nested-loop ordering exactly.
    patches = grid.reshape(n, n_patches ** 2, c * patch_h * patch_w)
    # The previous implementation filled a torch.zeros(...) buffer, i.e. the
    # default dtype; keep that so integer inputs still come out as floats.
    return patches.to(torch.get_default_dtype())

![Patch](images/visual_transformer/visual_transformer_patch.webp)

In [15]:
def get_positional_embeddings(sequence_length, d):
    """Return fixed sinusoidal positional embeddings (Vaswani et al., 2017).

    Entry ``[pos, j]`` is ``sin(pos / 10000 ** (j / d))`` for even ``j`` and
    ``cos(pos / 10000 ** ((j - 1) / d))`` for odd ``j``, so every
    (sin, cos) column pair shares a single frequency.

    Args:
        sequence_length: number of positions (rows).
        d: embedding dimension (columns).

    Returns:
        Tensor of shape (sequence_length, d) in the default floating-point dtype.
    """
    # Compute in float64 (as the old per-element NumPy version did), then cast.
    positions = torch.arange(sequence_length, dtype=torch.float64).unsqueeze(1)
    columns = torch.arange(d)
    # Even column 2i and odd column 2i + 1 both use the exponent 2i / d.
    exponents = (columns - columns % 2).to(torch.float64) / d
    angles = positions / (10000.0 ** exponents)  # (sequence_length, d)
    embeddings = torch.where(columns % 2 == 0, torch.sin(angles), torch.cos(angles))
    return embeddings.to(torch.get_default_dtype())

### Positional Encoding
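* x-axis: embedding (token) dimension, one column per feature of `hidden_d`
* y-axis: sequence position, one row per patch plus the class token
* the values change from row to row, so every position receives a distinct encoding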

![Positional Encoding](images/visual_transformer/positional_encoding.png)

Simply put: we want, for a single image, each patch to get updated based on some similarity measure with the other patches. We do so by linearly mapping each patch (that is now an 8-dimensional vector in our example) to 3 distinct vectors: q, k, and v (query, key, value).

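Then, for a single patch, we compute the dot product between its q vector and all of the k vectors, divide by the square root of the dimensionality of these vectors, apply a softmax to obtain the so-called attention cues, and finally multiply each attention cue with the v vector associated with the corresponding k vector and sum everything up. In matrix form: $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$. Because this is done per head (see below), $d_k$ is the per-head dimension, so in the code the scaling factor is $\sqrt{8/2} = 2$ rather than $\sqrt{8}$.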

In this way, each patch takes on a new value that is based on its similarity (after the linear mapping to q, k, and v) with the other patches. This whole procedure, however, is carried out H times on H sub-vectors of our current 8-dimensional patches, where H is the number of heads; the H per-head results are then concatenated back into an 8-dimensional vector.

See https://data-science-blog.com/blog/2021/04/07/multi-head-attention-mechanism/

In [16]:
class MyMSA(nn.Module):
    """Multi-head self-attention with a separate q/k/v projection per head.

    The token dimension ``d`` is split into ``n_heads`` contiguous slices of
    width ``d // n_heads``. Each head projects its own slice to queries,
    keys and values, runs scaled dot-product attention along the sequence
    axis, and the per-head outputs are concatenated back to width ``d``.
    """

    def __init__(self, d, n_heads=2):
        super().__init__()
        self.d = d  # token dimension (8 in this notebook)
        self.n_heads = n_heads
        assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"

        d_head = d // n_heads  # 4 in this notebook
        self.q_mappings = nn.ModuleList(
            [nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = nn.ModuleList(
            [nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = nn.ModuleList(
            [nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        """Apply multi-head self-attention.

        Args:
            sequences: tensor of shape (..., seq_len, d); in this notebook
                (N, 50, 8). All leading dimensions are processed in one
                batched matmul instead of a Python loop over samples.

        Returns:
            Tensor with the same shape as ``sequences``.
        """
        head_outputs = []
        for head in range(self.n_heads):
            # This head's feature slice: (..., seq_len, d_head)
            seq = sequences[..., head * self.d_head:(head + 1) * self.d_head]
            q = self.q_mappings[head](seq)
            k = self.k_mappings[head](seq)
            v = self.v_mappings[head](seq)

            # (..., seq_len, d_head) @ (..., d_head, seq_len) -> (..., seq_len, seq_len)
            attention = self.softmax(q @ k.transpose(-2, -1) / self.d_head ** 0.5)

            # (..., seq_len, seq_len) @ (..., seq_len, d_head) -> (..., seq_len, d_head)
            head_outputs.append(attention @ v)
        # Concatenate heads back to (..., seq_len, d)
        return torch.cat(head_outputs, dim=-1)


In [17]:
class MyViTBlock(nn.Module):
    """One pre-norm transformer encoder block.

    Computes ``y = x + MSA(LayerNorm(x))`` followed by
    ``y + MLP(LayerNorm(y))``, where the MLP widens the token dimension by
    ``mlp_ratio`` with a GELU in between.
    """

    def __init__(self, hidden_d, n_heads, mlp_ratio=4):
        super().__init__()
        self.hidden_d = hidden_d
        self.n_heads = n_heads
        # Submodule attribute names (normal, mhsa, norm2, mlp) determine the
        # state_dict keys, so they are kept unchanged.
        self.normal = nn.LayerNorm(hidden_d)
        self.mhsa = MyMSA(hidden_d, n_heads)
        self.norm2 = nn.LayerNorm(hidden_d)
        widened = mlp_ratio * hidden_d
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d, widened),
            nn.GELU(),
            nn.Linear(widened, hidden_d),
        )

    def forward(self, x):
        # Attention sub-layer with residual connection.
        attended = x + self.mhsa(self.normal(x))
        # Feed-forward sub-layer with residual connection.
        return attended + self.mlp(self.norm2(attended))


In [18]:
class MyViT(nn.Module):
    """Minimal Vision Transformer classifier.

    Pipeline: patchify -> linear projection -> prepend a learnable class
    token -> add fixed sinusoidal positional embeddings -> ``n_blocks``
    encoder blocks -> classification head applied to the class token only.
    """

    def __init__(self, chw=(1, 28, 28), n_patches=7, hidden_d=8, n_blocks=2, n_heads=2, out_d=10):
        super().__init__()
        self.chw = chw  # (channels, height, width) of the input images
        self.n_patches = n_patches  # patches per spatial axis
        self.hidden_d = hidden_d
        self.n_blocks = n_blocks  # number of encoder blocks
        self.n_heads = n_heads

        assert chw[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        assert chw[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        # The asserts above guarantee exact integer division.
        self.patch_size = (chw[1] // n_patches, chw[2] // n_patches)

        # 1) Linear mapper: flattened patch (C * patch_h * patch_w) -> hidden_d
        self.input_d = chw[0] * self.patch_size[0] * self.patch_size[1]
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

        # 2) Learnable classification token, prepended to the patch tokens
        #    ((N, 49, 8) -> (N, 50, 8) with the defaults). Through attention it
        #    aggregates information from all other tokens, so the image can be
        #    classified from this single token.
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

        # 3) Fixed positional embeddings: one row per patch plus the class token.
        #    persistent=False keeps them out of the state_dict.
        self.register_buffer('positional_embeddings', get_positional_embeddings(
            self.n_patches ** 2 + 1, self.hidden_d), persistent=False)

        # 4) Transformer encoder
        self.blocks = nn.ModuleList(
            [MyViTBlock(hidden_d, n_heads) for _ in range(n_blocks)])

        # 5) Classification head on the class token only.
        #    NOTE: ends in Softmax, so forward() returns class probabilities
        #    rather than logits.
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_d, out_d), nn.Softmax(dim=-1))

    def forward(self, images):
        """Classify a batch of images of shape (N, C, H, W) -> (N, out_d) probabilities."""
        n = images.shape[0]
        patches = patchify(images, self.n_patches).to(self.positional_embeddings.device)

        # Tokenize: (N, n_patches ** 2, hidden_d)
        tokens = self.linear_mapper(patches)

        # Prepend the class token: (N, n_patches ** 2 + 1, hidden_d)
        tokens = torch.cat((self.class_token.expand(n, 1, -1), tokens), dim=1)

        # Positional embeddings broadcast over the batch axis; no copy needed.
        out = tokens + self.positional_embeddings

        # Transformer encoder, shape preserved.
        for block in self.blocks:
            out = block(out)

        # Classify from the class token (position 0).
        return self.mlp(out[:, 0])


In [9]:
def main():
    """Train MyViT on MNIST, then report test loss and accuracy.

    MNIST is downloaded to ./datasets on first use.
    """
    transform = ToTensor()
    train_set = MNIST(root='./datasets', train=True,
                      download=True, transform=transform)
    test_set = MNIST(root='./datasets', train=False,
                     download=True, transform=transform)
    train_loader = DataLoader(train_set, shuffle=True, batch_size=512)
    test_loader = DataLoader(test_set, shuffle=False, batch_size=512)

    # Pick an available accelerator instead of hard-coding "mps", which
    # fails on machines without Apple-silicon GPU support.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    model = MyViT((1, 28, 28), n_patches=7, n_blocks=2,
                  hidden_d=8, n_heads=2, out_d=10).to(device)
    N_EPOCHS = 5
    LR = 0.005

    optimizer = optim.Adam(model.parameters(), lr=LR)

    # MyViT's head ends in nn.Softmax, so the model outputs probabilities.
    # nn.CrossEntropyLoss expects raw logits and would apply a second softmax,
    # squashing the gradients. Take the log of the probabilities and use NLLLoss.
    nll = nn.NLLLoss()

    def criterion(probs, targets):
        # clamp avoids log(0) = -inf if a probability underflows to zero
        return nll(torch.log(probs.clamp_min(1e-12)), targets)

    model.train()
    for epoch in trange(N_EPOCHS, desc="training"):
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f'Epoch {epoch + 1} in training', leave=False):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)
            train_loss += loss.item() / len(train_loader)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}")

    model.eval()
    with torch.no_grad():
        correct, total = 0, 0
        test_loss = 0.0
        for batch in tqdm(test_loader, desc="Testing"):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)
            test_loss += loss.item() / len(test_loader)
            correct += torch.sum(torch.argmax(y_hat, dim=1) == y).item()
            total += len(x)
        print(f"Test loss: {test_loss:.2f}")
        print(f"Test accuracy: {correct / total * 100:.2f}%")

In [10]:
def test():
    """Manual smoke check of the building blocks.

    Shows a heatmap of the positional embeddings for 49 patches plus one
    class token (hidden_d = 8), then runs one MyViTBlock on random input
    and checks that the block preserves the (batch, tokens, hidden_d) shape.
    """
    plt.imshow(get_positional_embeddings(49 + 1, 8),
               cmap="hot", interpolation="nearest")
    plt.show()

    block = MyViTBlock(hidden_d=8, n_heads=2)
    y = torch.randn(7, 50, 8)
    out = block(y)
    # Previously the output was discarded, so only crashes were detected;
    # also verify that the block keeps the token layout intact.
    assert out.shape == y.shape, f"expected {tuple(y.shape)}, got {tuple(out.shape)}"


# Sources
* [Vision Transformers from Scratch (PyTorch): A step-by-step guide](https://medium.com/@brianpulfer/vision-transformers-from-scratch-pytorch-a-step-by-step-guide-96c3313c2e0c)