# Vision Transformers

Paper: [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/pdf/2010.11929)


In [None]:
# %cd "Deep-Learning-From-Scratch/Generative Models/Pix2Pix"

In [None]:
from rich import print
from tqdm.notebook import tqdm

%load_ext rich

In [None]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision.datasets import CIFAR100
from torchvision.transforms import v2
from torchvision.utils import make_grid

# import wandb

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# wandb.require('core')
# wandb.login()

In [None]:
# Hyperparameters

# --- Model architecture ---
IMAGE_SHAPE = (3, 32, 32)  # unpacked as (channels, height, width) by PatchEmbedding
NUM_CLASSES = 100
PATCH_SIZE = 4
NUM_BLOCKS = 12
NUM_HEADS = 12
DIM_EMBEDDING = 768
DIM_HIDDEN = 4 * DIM_EMBEDDING  # hidden width of each encoder MLP
DROPOUT_RATE = 0.1

# --- Optimisation ---
BATCH_SIZE = 16
N_EPOCHS = 30
# NOTE(review): this was written as `10e-3`, which is 1e-2 (0.01) and not 1e-3.
# The value is unchanged here, only spelled unambiguously; confirm that 0.01 is
# really the intended Adam step size.
LEARNING_RATE = 1e-2

## Load the dataset


In [None]:
# Human-readable class names, indexed by the integer label returned by the
# dataset (100 entries, five per row so that row r starts at index 5 * r).
# fmt: off
true_labels = [
    "apple", "aquarium_fish", "baby", "bear", "beaver",
    "bed", "bee", "beetle", "bicycle", "bottle",
    "bowl", "boy", "bridge", "bus", "butterfly",
    "camel", "can", "castle", "caterpillar", "cattle",
    "chair", "chimpanzee", "clock", "cloud", "cockroach",
    "couch", "crab", "crocodile", "cup", "dinosaur",
    "dolphin", "elephant", "flatfish", "forest", "fox",
    "girl", "hamster", "house", "kangaroo", "keyboard",
    "lamp", "lawn_mower", "leopard", "lion", "lizard",
    "lobster", "man", "maple_tree", "motorcycle", "mountain",
    "mouse", "mushroom", "oak_tree", "orange", "orchid",
    "otter", "palm_tree", "pear", "pickup_truck", "pine_tree",
    "plain", "plate", "poppy", "porcupine", "possum",
    "rabbit", "raccoon", "ray", "road", "rocket",
    "rose", "sea", "seal", "shark", "shrew",
    "skunk", "skyscraper", "snail", "snake", "spider",
    "squirrel", "streetcar", "sunflower", "sweet_pepper", "table",
    "tank", "telephone", "television", "tiger", "tractor",
    "train", "trout", "tulip", "turtle", "wardrobe",
    "whale", "willow_tree", "wolf", "woman", "worm",
]
# fmt: on


In [92]:
# Un-normalised CIFAR-100 splits (downloaded into ./data/ on first use).
# Their `.data` arrays are what the per-channel pixel statistics are computed
# from; both names are re-bound later in the notebook.
train_dataset = CIFAR100(
    root="./data/",
    train=True,
    download=True,
    transform=v2.ToImage(),
)
test_dataset = CIFAR100(
    root="./data/",
    train=False,
    download=True,
    transform=v2.ToImage(),
)


In [None]:
def _per_channel_stats(images):
    """Return ``(mean, std)`` of ``images`` reduced over axes 0, 1 and 2, each divided by 255.

    NOTE(review): assumes ``images`` is laid out (N, H, W, C) with 8-bit pixel
    values, so the result is one value per colour channel in [0, 1] — confirm.
    """
    reduce_axes = (0, 1, 2)
    return images.mean(axis=reduce_axes) / 255, images.std(axis=reduce_axes) / 255


mean_pixel_train, std_pixel_train = _per_channel_stats(train_dataset.data)
mean_pixel_test, std_pixel_test = _per_channel_stats(test_dataset.data)

# Displayed as the cell output for a quick sanity check.
(mean_pixel_train, std_pixel_train), (mean_pixel_test, std_pixel_test)


In [None]:
def _make_transform(mean, std, augment):
    """Build the preprocessing pipeline for one split.

    Steps, in order: resize to 32, optional random horizontal flip, conversion
    to an image tensor, conversion to float32 with value scaling, and
    per-channel normalisation with ``mean`` / ``std``.
    """
    steps = [v2.Resize((32))]
    if augment:
        steps.append(v2.RandomHorizontalFlip())
    steps.extend(
        [
            v2.ToImage(),
            v2.ToDtype(dtype=torch.float32, scale=True),
            v2.Normalize(mean, std),
        ]
    )
    return v2.Compose(steps)


# Training pipeline: includes the random-flip augmentation.
train_transform = _make_transform(mean_pixel_train, std_pixel_train, augment=True)

# Evaluation pipeline: deterministic.
# NOTE(review): this normalises with statistics computed on the test split;
# reusing the training statistics is the more common convention — confirm.
test_transform = _make_transform(mean_pixel_test, std_pixel_test, augment=False)


In [None]:
def reverse_transform(x, train=True):
    """Undo the per-channel normalisation of ``x`` and return it as a PIL image.

    Args:
        x: Normalised image tensor. The statistics are reshaped to (3, 1, 1),
            so ``x`` must broadcast against that shape.
        train: Use the training-split statistics when True, otherwise the
            test-split statistics.

    Returns:
        The result of ``v2.ToPILImage()`` applied to ``x * std + mean``.
    """
    if train:
        mean_pixel, std_pixel = mean_pixel_train, std_pixel_train
    else:
        mean_pixel, std_pixel = mean_pixel_test, std_pixel_test

    # Build the statistics on the same device as the input so the arithmetic
    # works for both CPU and GPU tensors.
    scale = torch.tensor(std_pixel, device=x.device).view(3, 1, 1)
    shift = torch.tensor(mean_pixel, device=x.device).view(3, 1, 1)
    denormalised = x * scale + shift

    return v2.ToPILImage()(denormalised)

In [None]:
# Official training split with the augmenting/normalising training pipeline.
full_train_dataset = CIFAR100(
    root="./data/",
    train=True,
    download=True,
    transform=train_transform,
)

# Official test split with the deterministic evaluation pipeline.
# This re-binds `test_dataset`, replacing the raw copy loaded earlier.
test_dataset = CIFAR100(
    root="./data/",
    train=False,
    download=True,
    transform=test_transform,
)

In [None]:
# Hold out 30% of the official training split for validation (70/30 split).
train_size = int(0.7 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size

# Split with a dedicated, fixed-seed generator so the same samples land in the
# validation set on every "Restart & Run All"; without it the split (and hence
# every validation metric) changes from run to run.
# NOTE(review): both subsets wrap `full_train_dataset`, so validation samples
# also go through its training transform — confirm that is acceptable.
train_dataset, val_dataset = random_split(
    full_train_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)


In [None]:
# Only the training loader is shuffled. Validation and test metrics do not
# depend on sample order, so those loaders iterate in a fixed order to keep
# evaluation deterministic.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [None]:
# Number of batches per split (train, validation, test), shown as the cell output.
len(train_loader), len(val_loader), len(test_loader)

In [None]:
# Preview eight randomly chosen training samples in a single row, each
# de-normalised for display and titled with its class name.
fig, axs = plt.subplots(1, 8, figsize=(16, 2))
indices = torch.randperm(len(train_dataset))[:8]
for ax, idx in zip(axs, indices):
    img, label = train_dataset[idx]
    ax.imshow(reverse_transform(img, train=True))
    ax.set_title(f"{true_labels[label]}")
    ax.axis("off")

plt.show()


## Build the architecture

In [None]:
# One fixed training sample, reused by the shape smoke tests of the modules
# defined below.
sample_image, sample_target = train_dataset[132]

sample_fig, sample_ax = plt.subplots(figsize=(2, 5))
sample_ax.imshow(reverse_transform(sample_image))
sample_ax.set_title(f"Label: {true_labels[sample_target]}")
sample_ax.axis("off")
plt.show()

In [None]:
class PatchEmbedding(nn.Module):
    """Convert an image batch into the token sequence consumed by the encoder.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches. Each flattened patch is layer-normalised, linearly projected to
    ``d_embd`` and layer-normalised again. A learnable [CLASS] token is then
    prepended, a learnable positional embedding is added, and dropout is
    applied.

    Args:
        patch_size: Side length of a square patch, in pixels.
        d_embd: Embedding width of every output token.
        image_shape: ``(channels, height, width)`` of the expected input.
        dropout: Dropout probability applied to the final token sequence.
        debug: When True, print the shape of each intermediate tensor.
    """

    def __init__(self, patch_size, d_embd, image_shape, dropout=0.0, debug=False):
        super().__init__()

        n_channels, img_height, img_width = image_shape
        grid_rows = img_height // patch_size
        grid_cols = img_width // patch_size

        self.patch_size = patch_size
        self.d_embd = d_embd
        self.n_channels = n_channels
        self.n_patches = grid_rows * grid_cols
        # Length of one flattened patch.
        self.input_dim = n_channels * patch_size * patch_size
        self.debug = debug

        self.projection = nn.Linear(self.input_dim, d_embd)

        # Learnable [CLASS] token, shared by every sample in the batch.
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_embd))

        # Learnable positional embedding: one vector per patch plus one for [CLASS].
        self.pos_embedding = nn.Parameter(torch.randn(1, self.n_patches + 1, d_embd))

        self.dropout = nn.Dropout(dropout)

        self.pre_norm = nn.LayerNorm(self.input_dim)
        self.post_norm = nn.LayerNorm(d_embd)

    def _debug_print(self, tensor, name):
        """Print ``name`` and the shape of ``tensor`` when debugging is enabled."""
        if not self.debug:
            return
        print(f"{name}: {tensor.shape}")

    def forward(self, x):
        """Embed ``x`` of shape (B, C, H, W) into tokens of shape (B, n_patches + 1, d_embd).

        Raises:
            AssertionError: If H or W is not a multiple of the patch size, or
                if C differs from the channel count given at construction.
        """
        self._debug_print(x, "[EMBEDDING] Input")
        batch, channels, height, width = x.shape
        p = self.patch_size

        assert height % p == 0 and width % p == 0, "Invalid patch size"
        assert channels == self.n_channels, "Invalid number of channels"

        # (B, C, H, W)
        #   -> (B, C, H/p, p, W/p, p)   split both spatial axes into (grid, within-patch)
        #   -> (B, H/p, W/p, C, p, p)   bring the patch-grid axes to the front
        #   -> (B, H/p * W/p, C*p*p)    one flattened row per patch, row-major over the grid
        patches = (
            x.reshape(batch, channels, height // p, p, width // p, p)
            .permute(0, 2, 4, 1, 3, 5)
            .reshape(batch, -1, self.input_dim)
        )
        self._debug_print(patches, "[EMBEDDING] Flattened")

        tokens = self.projection(self.pre_norm(patches))
        self._debug_print(tokens, "[EMBEDDING] Projected")
        tokens = self.post_norm(tokens)

        # Prepend one copy of the [CLASS] token per sample.
        cls_tokens = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat([cls_tokens, tokens], dim=1)
        self._debug_print(tokens, "[EMBEDDING] CLS token")

        tokens = tokens + self.pos_embedding
        self._debug_print(tokens, "[EMBEDDING] POS embedding")

        return self.dropout(tokens)

In [None]:
# Smoke test: push the sample image (with a batch axis added) through a freshly
# initialised embedding with debug=True so every intermediate shape is printed.
PatchEmbedding(patch_size=4, d_embd=512, image_shape=sample_image.shape, debug=True)(
    sample_image.unsqueeze(0)
)


In [None]:
class EncoderMLP(nn.Module):
    """Feed-forward sub-layer of an encoder block.

    Applies LayerNorm, expands to ``d_hidden``, applies GELU and dropout,
    projects back to ``d_embd`` and applies dropout again. The residual
    connection is NOT part of this module; the caller adds it.

    Args:
        d_embd: Width of the input and output tokens.
        d_hidden: Width of the hidden layer.
        dropout: Dropout probability used after both linear layers.
        debug: When True, print input and output shapes.
    """

    def __init__(self, d_embd, d_hidden, dropout=0.0, debug=False):
        super().__init__()
        self.debug = debug

        layers = [
            nn.LayerNorm(d_embd),
            nn.Linear(d_embd, d_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_hidden, d_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def _debug_print(self, tensor, name):
        """Print ``name`` and the shape of ``tensor`` when debugging is enabled."""
        if not self.debug:
            return
        print(f"{name}: {tensor.shape}")

    def forward(self, x):
        """Return ``self.net(x)``; the last dimension stays ``d_embd``."""
        self._debug_print(x, "[MLP] Input")
        out = self.net(x)
        self._debug_print(out, "[MLP] Output")
        return out

In [None]:
# Smoke test: random (1, 197, 512) token batch through an untrained MLP, with
# debug=True so the input/output shapes are printed.
EncoderMLP(d_embd=512, d_hidden=2048, debug=True)(torch.randn(1, 197, 512))

In [None]:
class EncoderAttention(nn.Module):
    """Pre-norm multi-head self-attention sub-layer of an encoder block.

    The input is layer-normalised, projected to queries/keys/values with a
    single bias-free linear layer, split into ``n_heads`` heads, combined with
    scaled dot-product attention over the token axis, merged back to
    ``d_embd`` and passed through an output projection followed by dropout.
    The residual connection is NOT part of this module; the caller adds it.

    Args:
        d_embd: Width of the input and output tokens.
        n_heads: Number of attention heads; must divide ``d_embd``.
        dropout: Dropout probability for the attention weights and the output.
        debug: When True, print the shape of each intermediate tensor.

    Raises:
        ValueError: If ``d_embd`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_embd, n_heads, dropout=0.0, debug=False):
        super().__init__()

        if d_embd % n_heads != 0:
            raise ValueError(
                f"d_embd ({d_embd}) must be divisible by n_heads ({n_heads})"
            )

        self.d_embd = d_embd
        self.n_heads = n_heads
        self.d_head = d_embd // n_heads

        self.norm = nn.LayerNorm(d_embd)
        self.dropout = nn.Dropout(dropout)

        self.qkv = nn.Linear(d_embd, 3 * d_embd, bias=False)
        self.projection = nn.Sequential(nn.Linear(d_embd, d_embd), nn.Dropout(dropout))

        self.debug = debug

    def _debug_print(self, tensor, name):
        """Print ``name`` and the shape of ``tensor`` when debugging is enabled."""
        if self.debug:
            print(f"{name}: {tensor.shape}")

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (B, N, d_embd); the shape is preserved."""
        b, n, _ = x.shape

        self._debug_print(x, "[ATTENTION] Input")
        x = self.norm(x)

        # Split the fused projection into heads WITHOUT mixing tokens:
        # (B, N, 3*D) -> (B, N, 3, H, d_head) -> (3, B, H, N, d_head).
        # Viewing the (B, N, 3*D) tensor directly as (B, H, N, -1) would
        # reinterpret the memory so that rows of different tokens end up in the
        # same "head", restricting each token to a small contiguous group of
        # tokens instead of attending over the whole sequence.
        qkv = self.qkv(x).reshape(b, n, 3, self.n_heads, self.d_head)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        self._debug_print(q, "[ATTENTION] Q")
        self._debug_print(k, "[ATTENTION] K")
        self._debug_print(v, "[ATTENTION] V")

        # (B, H, N, d_head) @ (B, H, d_head, N) -> (B, H, N, N), scaled by 1/sqrt(d_head)
        score = torch.matmul(q, k.transpose(-1, -2)) * (self.d_head**-0.5)
        self._debug_print(score, "[ATTENTION] Score")

        attn = F.softmax(score, dim=-1)
        attn = self.dropout(attn)

        # Merge heads: (B, H, N, d_head) -> (B, N, H, d_head) -> (B, N, D).
        # The transpose is required so each token's row is the concatenation of
        # its own per-head outputs.
        out = torch.matmul(attn, v).transpose(1, 2).reshape(b, n, self.d_embd)
        self._debug_print(out, "[ATTENTION] Reshaped Scores")

        out = self.projection(out)
        self._debug_print(out, "[ATTENTION] Output")

        return out

In [None]:
# Smoke test: random (1, 197, 512) token batch through an untrained attention
# layer, with debug=True so every intermediate shape is printed.
EncoderAttention(d_embd=512, n_heads=8, debug=True)(torch.randn(1, 197, 512))

In [None]:
class Transformer(nn.Module):
    """One encoder block: self-attention followed by an MLP, each with a residual connection.

    Args:
        d_embd: Token width.
        d_hidden: Hidden width of the MLP sub-layer.
        n_heads: Number of attention heads.
        dropout: Dropout probability forwarded to both sub-layers.
        debug: When True, this block and both sub-layers print tensor shapes.
    """

    def __init__(self, d_embd, d_hidden, n_heads, dropout=0.0, debug=False):
        super().__init__()

        self.d_embd, self.d_hidden, self.n_heads = d_embd, d_hidden, n_heads

        self.attention = EncoderAttention(
            d_embd=d_embd, n_heads=n_heads, dropout=dropout, debug=debug
        )
        self.mlp = EncoderMLP(
            d_embd=d_embd, d_hidden=d_hidden, dropout=dropout, debug=debug
        )

        self.debug = debug

    def _debug_print(self, tensor, name):
        """Print ``name`` and the shape of ``tensor`` when debugging is enabled."""
        if not self.debug:
            return
        print(f"{name}: {tensor.shape}")

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates (shape unchanged)."""
        self._debug_print(x, "[TRANSFORMER] Input")

        # Both sub-layers normalise their own input, so the skip path carries
        # the un-normalised tokens.
        attended = x + self.attention(x)
        self._debug_print(attended, "[TRANSFORMER] Attention")

        out = attended + self.mlp(attended)
        self._debug_print(out, "[TRANSFORMER] MLP")

        return out

In [None]:
# Smoke test: embed the sample image (embedding debug output off) and run the
# tokens through one untrained encoder block with debug=True.
Transformer(d_embd=512, d_hidden=2048, n_heads=8, debug=True)(
    PatchEmbedding(patch_size=4, d_embd=512, image_shape=sample_image.shape)(
        sample_image.unsqueeze(0)
    )
)

In [None]:
class VisionTransformer(nn.Module):
    """Vision Transformer classifier.

    Pipeline: patch embedding -> ``n_blocks`` encoder blocks -> LayerNorm ->
    take the first ([CLASS]) token -> classification head producing one score
    per class. No softmax is applied to the output.

    Args:
        n_classes: Number of output classes.
        patch_size: Side length of a square patch, in pixels.
        n_blocks: Number of stacked encoder blocks.
        d_embd: Token width.
        d_hidden: Hidden width of each encoder MLP.
        n_heads: Number of attention heads per block.
        image_shape: ``(channels, height, width)`` of the expected input.
        dropout: Dropout probability forwarded to the embedding and the blocks.
        debug: When True, every sub-module prints intermediate tensor shapes.
    """

    def __init__(
        self,
        n_classes,
        patch_size,
        n_blocks,
        d_embd,
        d_hidden,
        n_heads,
        image_shape,
        dropout=0.0,
        debug=False,
    ):
        super().__init__()

        self.patch_size = patch_size
        self.d_embd = d_embd
        self.d_hidden = d_hidden
        self.n_heads = n_heads

        self.patch_embedding = PatchEmbedding(
            patch_size=patch_size,
            d_embd=d_embd,
            image_shape=image_shape,
            dropout=dropout,
            debug=debug,
        )

        blocks = [
            Transformer(
                d_embd=d_embd,
                d_hidden=d_hidden,
                n_heads=n_heads,
                dropout=dropout,
                debug=debug,
            )
            for _ in range(n_blocks)
        ]
        self.encoder = nn.ModuleList(blocks)

        # NOTE(review): the two Linear layers have no activation between them,
        # so together they are equivalent to a single linear map; the head's
        # LayerNorm also directly follows `self.norm`. Confirm both are intended.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(d_embd),
            nn.Linear(d_embd, d_embd),
            nn.Linear(d_embd, n_classes),
        )
        self.norm = nn.LayerNorm(d_embd)

        self.debug = debug

    def _debug_print(self, tensor, name):
        """Print ``name`` and the shape of ``tensor`` when debugging is enabled."""
        if not self.debug:
            return
        print(f"{name}: {tensor.shape}")

    def forward(self, x):
        """Map an image batch ``x`` to class scores of shape (B, n_classes)."""
        self._debug_print(x, "[VISION TRANSFORMER] Input")

        tokens = self.patch_embedding(x)
        self._debug_print(tokens, "[VISION TRANSFORMER] Patch Embedding")

        for i, block in enumerate(self.encoder):
            tokens = block(tokens)
            self._debug_print(tokens, f"[VISION TRANSFORMER] Encoder Block {i}")

        tokens = self.norm(tokens)
        self._debug_print(tokens, "[VISION TRANSFORMER] LayerNorm")

        # The [CLASS] token sits at sequence position 0 (it is prepended by the
        # patch embedding) and serves as the image representation.
        cls_repr = tokens[:, 0]
        self._debug_print(cls_repr, "[VISION TRANSFORMER] CLS Token")

        scores = self.mlp_head(cls_repr)
        self._debug_print(scores, "[VISION TRANSFORMER] Output")

        return scores

In [None]:
# Smoke test: a small, untrained two-block model on the sample image, with
# debug=True so the shape of every intermediate tensor is printed.
VisionTransformer(
    n_classes=100,
    patch_size=4,
    n_blocks=2,
    d_embd=512,
    d_hidden=2048,
    n_heads=4,
    image_shape=sample_image.shape,
    debug=True,
)(sample_image.unsqueeze(0))

In [None]:
# wandb.init(
#     project="pix2pix",
#     config={
#         "Generator Optimizer": g_optimizer.__class__.__name__,
#         "Discriminator Optimizer": d_optimizer.__class__.__name__,
#         "Loss Function": loss_fn.__class__.__name__,
#         "L1 Loss Function": l1_loss.__class__.__name__,
#         "Batch Size": BATCH_SIZE,
#         "Epochs": N_EPOCHS,
#         "Learning Rate": LEARNING_RATE,
#         "L1 Lambda": L1_LAMBDA,
#         "Generator Feature Map Size": GEN_FEATURE_MAP,
#         "Discriminator Feature Map Size": DISC_FEATURE_MAPS,
#         "Total Parameters": total_param_count,
#     },
# )

In [None]:
# Full-size model built from the hyperparameter cell and moved to `device`.
model = VisionTransformer(
    n_classes=NUM_CLASSES,
    patch_size=PATCH_SIZE,
    n_blocks=NUM_BLOCKS,
    d_embd=DIM_EMBEDDING,
    d_hidden=DIM_HIDDEN,
    n_heads=NUM_HEADS,
    dropout=DROPOUT_RATE,
    image_shape=IMAGE_SHAPE,
    debug=False,
).to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Count trainable parameters only.
total_param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)

# True division: floor division (`//`) would discard the fractional millions
# before formatting, so the value would always print with ".00".
print(f"Total Parameters: {total_param_count / 1e6:.2f}M")

In [None]:
# Checkpoints are written below; create the target directory up front so that
# torch.save does not fail when ./data/models does not exist yet.
os.makedirs("./data/models", exist_ok=True)

# NOTE(review): the epoch count is hard-coded to 2 while N_EPOCHS is defined but
# unused, and validation runs only when `epoch % 10 == 0` (i.e. epoch 0 here).
# Left unchanged to keep the runtime of this cell; confirm the intended schedule.
for epoch in tqdm(range(2), desc="Epochs"):
    # -------------------
    # Training loop
    # -------------------
    # Validation below switches the model to eval mode, so training mode
    # (active dropout) has to be restored at the start of every epoch.
    model.train()
    train_loss = 0.0
    for imgs, targets in tqdm(train_loader, desc=f"Training Epoch {epoch}"):
        imgs, targets = imgs.to(device), targets.to(device)

        optimizer.zero_grad()

        # nn.CrossEntropyLoss applies log-softmax itself and expects raw,
        # unnormalised scores; applying softmax first would squash the
        # gradients and make the loss inconsistent with the validation loss.
        predictions = model(imgs)
        loss = loss_fn(predictions, targets)

        train_loss += loss.item()

        loss.backward()
        optimizer.step()

    train_loss /= len(train_loader)
    print(f"Epoch {epoch} | Train Loss: {train_loss:.4f}")

    # -------------------
    # Validation loop
    # -------------------

    if epoch % 10 == 0:
        # Disable dropout so the validation metrics are deterministic.
        model.eval()

        val_loss = 0.0
        val_correct_top1 = 0
        val_correct_top5 = 0
        total = 0

        with torch.no_grad():
            for imgs, targets in tqdm(val_loader, desc=f"Validation Epoch {epoch}"):
                imgs, targets = imgs.to(device), targets.to(device)

                predictions = model(imgs)

                loss = loss_fn(predictions, targets)

                val_loss += loss.item()

                # (B, 5) class indices, best first.
                pred_top5 = predictions.topk(5, dim=-1).indices
                # (B, 5) hit mask: targets are compared as a (B, 1) column so
                # each row is matched against its own label only. Comparing a
                # (B,) prediction vector with a (B, 1) column would broadcast
                # to (B, B) and count cross-sample matches.
                hits = pred_top5.eq(targets.view(-1, 1))
                val_correct_top1 += hits[:, 0].sum().item()
                val_correct_top5 += hits.sum().item()
                total += targets.size(0)

        val_loss /= len(val_loader)
        val_top1_acc = val_correct_top1 / total
        val_top5_acc = val_correct_top5 / total

        print(
            f"Epoch {epoch} | Validation Loss: {val_loss:.4f} | Validation Acc@1: {val_top1_acc:.2%} | Validation Acc@5: {val_top5_acc:.2%}"
        )

        # Save checkpoint
        torch.save(model.state_dict(), f"./data/models/vit_epoch_{epoch}.pth")

# Save final model
torch.save(model.state_dict(), "./data/models/vit_final.pth")

In [None]:
# Classify one fixed test image and display it with predicted and true labels.
test_img, test_target = test_dataset[563]

model.eval()
with torch.no_grad():
    # Add a batch axis and move to the model's device.
    test_img = test_img.unsqueeze(0).to(device)
    test_pred = F.softmax(model(test_img), dim=1).cpu().numpy().squeeze()

test_pred_label = np.argmax(test_pred)
predicted_name = true_labels[test_pred_label]
true_name = true_labels[test_target]

print(f"Predicted Label: {predicted_name}")
print(f"True Label: {true_name}")

# Drop the batch axis and undo the test-split normalisation for display.
test_img = reverse_transform(test_img.squeeze(), train=False)
plt.imshow(test_img)
plt.title(f"Predicted: {predicted_name} | True: {true_name}")
plt.axis("off")

plt.show()
