In [None]:
import torch
import torchvision
import torch.nn as nn
from torchvision.transforms import v2

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime


# Preprocessing applied to every sample: PIL image -> tensor -> scaled float32
# -> normalized with mean 0.5 and std 0.5 (single channel).
transform = v2.Compose(
    [
        v2.PILToTensor(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize((0.5,), (0.5,)),
    ]
)

# FashionMNIST train / validation splits; both share the same preprocessing
# and are downloaded into ./data when not already present.
shared_dataset_kwargs = dict(transform=transform, download=True)
training_set = torchvision.datasets.FashionMNIST(
    "./data", train=True, **shared_dataset_kwargs
)
validation_set = torchvision.datasets.FashionMNIST(
    "./data", train=False, **shared_dataset_kwargs
)

# Mini-batch loaders: the training data is reshuffled, the validation data
# keeps its original order.
training_loader = torch.utils.data.DataLoader(
    dataset=training_set, batch_size=4, shuffle=True
)
validation_loader = torch.utils.data.DataLoader(
    dataset=validation_set, batch_size=4, shuffle=False
)

# Human-readable class names, indexed by the integer label
classes = (
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle Boot",
)

# Report how many samples each split contains
print(f"Training set has {len(training_set)} instances")
print(f"Validation set has {len(validation_set)} instances")

In [None]:
import matplotlib.pyplot as plt
import numpy as np


# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    """Draw an image tensor on the current matplotlib axes.

    Args:
        img: Tensor whose first dimension is the channel axis (it is averaged
            over dim 0 or transposed to channel-last before plotting).
        one_channel: When True, collapse the channel axis by averaging and
            draw the result with the "Greys" colormap; otherwise draw the
            tensor as a channel-last image.
    """
    pixels = img.mean(dim=0) if one_channel else img
    # Map the values back through x / 2 + 0.5 (inverse of mean 0.5 / std 0.5)
    array = (pixels / 2 + 0.5).numpy()
    if not one_channel:
        # matplotlib expects (H, W, C)
        plt.imshow(np.transpose(array, (1, 2, 0)))
        return
    plt.imshow(array, cmap="Greys")


# Pull a single batch from the training loader for a quick visual sanity check
data_iter = iter(training_loader)
images, labels = next(data_iter)

# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)

# Print one class name per image; iterate over the labels actually returned
# instead of assuming a fixed batch size.
print("  ".join(classes[label] for label in labels.tolist()))

---

## Vision Transformer Model

We will implement:
- Positional Encoding ([see the reference implementation in the PyTorch translation tutorial](https://pytorch.org/tutorials/beginner/translation_transformer.html#seq2seq-network-using-transformer))
- Patch Embedding
- Vision Transformer Model (using PyTorch Encoder Layer)

---

## Positional Encoding

<img src="./assets/positional-encoding.webp" width="900" height="500"/>

In [None]:
import math


# Helper module that adds a fixed sinusoidal positional encoding to a sequence
# of embeddings so the model can tell sequence positions apart.
class PositionalEncoding(nn.Module):
    """Adds a non-learnable sinusoidal position table to its input.

    Args:
        emb_size: Size of the embedding (last) dimension; must be even because
            sine and cosine values are interleaved on even / odd indices.
        dropout: Dropout probability applied to the sum.
        maxlen: Number of positions pre-computed in the table, i.e. the
            longest sequence that can be encoded.

    Raises:
        ValueError: If ``emb_size`` is odd.
    """

    def __init__(self, emb_size: int, dropout: float = 0.0, maxlen: int = 5000):
        if emb_size % 2 != 0:
            raise ValueError("Embedding size must be even")
        super().__init__()

        # exp(-2i * ln(10000) / emb_size) == 10000^(-2i / emb_size)
        den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))

        # Even embedding indices hold the sine, odd indices the cosine
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)

        self.dropout = nn.Dropout(dropout)
        # Buffer: stored with the module and moved by .to(device), but it is
        # not a parameter and is therefore not updated by the optimizer.
        self.register_buffer("pos_embedding", pos_embedding)

    def forward(self, token_embedding):
        """Return ``dropout(token_embedding + positions)``.

        Args:
            token_embedding: Tensor shaped ``(..., seq_len, emb_size)`` with
                ``seq_len <= maxlen``.

        Returns:
            Tensor with the same shape as ``token_embedding``.
        """
        # Use only the rows for the positions that are present; the (seq_len,
        # emb_size) slice broadcasts over any leading batch dimensions.
        seq_len = token_embedding.size(-2)
        return self.dropout(token_embedding + self.pos_embedding[:seq_len])

---

## Patch Embedding

<img src="./assets/patch-embedding.png" width="700px" height="400px">

In [None]:
class PatchEmbedding(nn.Module):
    """Turns an image batch into a sequence of patch embeddings plus a CLS token.

    Args:
        embed_dim: Size of each patch embedding (output channels of the conv).
        patch_size: Side length of a patch; used as both kernel size and
            stride, so patches do not overlap.
        num_patches: Number of patches per image; the positional table is
            built for ``num_patches + 1`` positions (patches + CLS token).
        dropout: Dropout probability.
        in_channels: Number of channels of the input images.
    """

    def __init__(self, embed_dim, patch_size, num_patches, dropout, in_channels):
        super().__init__()
        self.patcher = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=embed_dim,
                kernel_size=patch_size,
                stride=patch_size,  # one output position per patch
            ),
            nn.Flatten(2),  # flatten from the 2nd dimension to the end
        )

        # Special classification token: exactly ONE learnable token of size
        # embed_dim. The previous shape (1, in_channels, embed_dim) prepended
        # in_channels tokens, which is only correct when in_channels == 1.
        self.special_classification_token = nn.Parameter(
            torch.randn(size=(1, 1, embed_dim)), requires_grad=True
        )

        self.position_embeddings = PositionalEncoding(
            emb_size=embed_dim, dropout=dropout, maxlen=num_patches + 1
        )

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Image batch; dim 0 is the batch dimension and the channel count
                must match ``in_channels``.

        Returns:
            Tensor of shape ``(B, 1 + patches, embed_dim)`` with the CLS token
            at sequence index 0.
        """
        cls_token = self.special_classification_token.expand(
            x.shape[0], -1, -1
        )  # (B, 1, E)

        # conv + flatten gives (B, E, patches); move the patch axis to dim 1
        x = self.patcher(x).permute(0, 2, 1)
        x = torch.cat([cls_token, x], dim=1)
        x = self.position_embeddings(x)
        # NOTE: position_embeddings already applies dropout; the second
        # dropout is kept to preserve the existing behavior.
        x = self.dropout(x)
        return x

---

## Vision Transformer Model

<img src="./assets/ViT.png" width="700px" height="400px">

In [None]:
class ViT(nn.Module):
    """Vision Transformer classifier built on PyTorch's Transformer encoder.

    Args:
        num_patches: Number of patches per image (forwarded to PatchEmbedding).
        img_size: Image side length. Not read by this class; kept so the
            constructor signature stays unchanged.
        num_classes: Number of output classes.
        patch_size: Side length of a square patch.
        embed_dim: Embedding size, used as the encoder's ``d_model``.
        num_encoders: Number of stacked encoder layers.
        num_heads: Number of attention heads per encoder layer.
        hidden_dim: Width of the feed-forward network inside each encoder layer.
        dropout: Dropout probability used in the embeddings and the encoder.
        activation: Activation of the encoder feed-forward network.
        in_channels: Number of input image channels.
    """

    def __init__(
        self,
        num_patches,
        img_size,
        num_classes,
        patch_size,
        embed_dim,
        num_encoders,
        num_heads,
        hidden_dim,
        dropout,
        activation,
        in_channels,
    ):
        super().__init__()
        self.embeddings_block = PatchEmbedding(
            embed_dim, patch_size, num_patches, dropout, in_channels
        )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            # hidden_dim used to be accepted but ignored, which silently left
            # the feed-forward width at the library default.
            dim_feedforward=hidden_dim,
            dropout=dropout,
            activation=activation,
            batch_first=True,
            norm_first=True,
        )
        self.encoder_blocks = nn.TransformerEncoder(
            encoder_layer, num_layers=num_encoders
        )

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(normalized_shape=embed_dim),
            nn.Linear(in_features=embed_dim, out_features=num_classes),
        )

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Image batch accepted by the patch-embedding block.

        Returns:
            Logits of shape ``(B, num_classes)``.
        """
        x = self.embeddings_block(x)
        x = self.encoder_blocks(x)
        x = self.mlp_head(x[:, 0, :])  # Apply MLP on the CLS token only
        return x

In [None]:
# --- Experiment configuration -------------------------------------------------
RANDOM_SEED = 42
BATCH_SIZE = 512  # NOTE: the data loaders above are built with batch_size=4
EPOCHS = 40  # NOTE: overridden in the training cell below
LEARNING_RATE = 1e-4
NUM_CLASSES = 10
PATCH_SIZE = 4
IMG_SIZE = 28
IN_CHANNELS = 1
NUM_HEADS = 8
DROPOUT = 0.001
HIDDEN_DIM = 768
ADAM_WEIGHT_DECAY = 0
ADAM_BETAS = (0.9, 0.999)
ACTIVATION = "gelu"
NUM_ENCODERS = 4
EMBED_DIM = (PATCH_SIZE**2) * IN_CHANNELS  # 16
NUM_PATCHES = (IMG_SIZE // PATCH_SIZE) ** 2  # 49

# RANDOM_SEED used to be defined but never applied. Seed PyTorch's global
# generator here so that everything executed after this cell (e.g. the model's
# weight initialization) is reproducible across runs.
torch.manual_seed(RANDOM_SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"

---

## Dummy data and testing the model

In [None]:
# Build the model with keyword arguments: the constructor takes eleven
# parameters, so positional passing is easy to get subtly wrong.
model = ViT(
    num_patches=NUM_PATCHES,
    img_size=IMG_SIZE,
    num_classes=NUM_CLASSES,
    patch_size=PATCH_SIZE,
    embed_dim=EMBED_DIM,
    num_encoders=NUM_ENCODERS,
    num_heads=NUM_HEADS,
    hidden_dim=HIDDEN_DIM,
    dropout=DROPOUT,
    activation=ACTIVATION,
    in_channels=IN_CHANNELS,
).to(device)

# Dummy batch shaped from the configuration instead of repeated magic numbers
x = torch.randn(BATCH_SIZE, IN_CHANNELS, IMG_SIZE, IMG_SIZE).to(device)

# Shape check only, so skip building the autograd graph
with torch.no_grad():
    print(model(x).shape)  # BATCH_SIZE X NUM_CLASSES

---

## Now, let's train the model

In [None]:
import torch.optim as optim

# Classification objective applied to the logits returned by the model
loss_fn = nn.CrossEntropyLoss()

# Adam hyper-parameters are taken from the configuration constants
# (the original note describes them as the values from the paper).
adam_settings = dict(
    lr=LEARNING_RATE,
    betas=ADAM_BETAS,
    weight_decay=ADAM_WEIGHT_DECAY,
)
optimizer = optim.Adam(model.parameters(), **adam_settings)

In [None]:
import random
import timeit
from tqdm import tqdm


def train_one_epoch(report_every=1000):
    """Run one optimization pass over ``training_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``loss_fn`` and
    ``device``.

    Args:
        report_every: Number of batches per reporting window. The mean loss of
            each completed window is printed.

    Returns:
        The mean per-batch loss of the most recent reporting window. A
        trailing partial window (including the case where the loader yields
        fewer than ``report_every`` batches) is averaged as well, so the
        result is only 0.0 when the loader is empty.
    """
    running_loss = 0.0
    last_loss = 0.0
    pending = 0  # batches accumulated in the current reporting window

    # enumerate() gives us the batch index for intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs.to(device))

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels.to(device))
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        pending += 1
        if pending == report_every:
            last_loss = running_loss / report_every  # loss per batch
            print("  batch {} loss: {}".format(i + 1, last_loss))
            running_loss = 0.0
            pending = 0

    # Do not drop the batches left over after the last full window
    if pending:
        last_loss = running_loss / pending
        print("  batch {} loss: {}".format(i + 1, last_loss))

    return last_loss

In [None]:
import os

# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
epoch_number = 0

EPOCHS = 5  # good enough for testing, also dataset is small
my_loss_curve = []  # one (train_loss, valid_loss) tuple per epoch

best_v_loss = 1_000_000.0

# torch.save does not create missing directories, so make sure the target
# folder exists before the first checkpoint is written.
model_path = "models/best-model-trained.pth"
os.makedirs(os.path.dirname(model_path), exist_ok=True)

for epoch in range(EPOCHS):
    print("EPOCH {}:".format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch()

    running_v_loss = 0.0
    v_batches = 0
    # Evaluation mode disables dropout for the validation pass
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for v_inputs, v_labels in validation_loader:
            v_outputs = model(v_inputs.to(device))
            v_loss = loss_fn(v_outputs, v_labels.to(device))
            running_v_loss += v_loss.item()
            v_batches += 1

    # Count the batches explicitly instead of relying on a leaked loop index;
    # an empty loader yields NaN, which never counts as a new best.
    avg_v_loss = running_v_loss / v_batches if v_batches else float("nan")
    print("LOSS train {} valid {}".format(avg_loss, avg_v_loss))

    my_loss_curve.append((avg_loss, avg_v_loss))

    # Track best performance, and save the model's state
    if avg_v_loss < best_v_loss:
        print("-" * 80)
        print("  New best loss! Saving model.")
        best_v_loss = avg_v_loss
        torch.save(model.state_dict(), model_path)
        print("\n" + "-" * 80 + "\n")

    epoch_number += 1

---

### Let's plot the loss curve

In [None]:
import matplotlib.pyplot as plt

# Plot the loss curve: one (train, valid) pair was recorded per epoch
train_losses = [x[0] for x in my_loss_curve]
valid_losses = [x[1] for x in my_loss_curve]

plt.plot(train_losses, label="Training loss")
plt.plot(valid_losses, label="Validation loss")
# Label the axes so the figure can be read on its own
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.title("Loss curve")
plt.show()

---

## Visualize prediction

<img src="./assets/prediction-by-ViT.png" height="400" width="400"/>

In [None]:
def show_my_img(my_img, my_label, pred, ax, j):
    """Draw one image into cell ``j`` of a grid of axes and title it.

    Args:
        my_img: Image tensor in (C, H, W) layout.
        my_label: Ground-truth class index.
        pred: Predicted class index.
        ax: 2-D array of matplotlib axes (rows x columns), e.g. as returned by
            ``plt.subplots`` for a multi-row, multi-column grid.
        j: Flat, row-major index of the cell to draw into.
    """
    # Derive the grid width from the axes array instead of hard-coding 2
    row, col = divmod(j, ax.shape[1])
    cell = ax[row, col]

    cell.imshow(my_img.permute(1, 2, 0))  # permute to (H, W, C)
    cell.axis("off")
    if my_label == pred:
        # Correct prediction: green title with the class name
        cell.set_title(classes[my_label], color="green")
    else:
        # Wrong prediction: red title reading "<predicted>=> <actual>"
        cell.set_title(classes[pred] + "=> " + classes[my_label], color="red")


# Predict on the first validation batch and show every image with its label.
model.eval()  # make sure dropout is disabled for inference
v_inputs, v_labels = next(iter(validation_loader))

# Inference only, so skip building the autograd graph
with torch.no_grad():
    v_outputs = model(v_inputs.to(device))
my_preds = v_outputs.argmax(dim=1).tolist()

# Two images per row; the row count follows the batch size instead of
# assuming exactly four images. squeeze=False keeps `ax` 2-D for any size.
n_cols = 2
n_rows = (len(my_preds) + n_cols - 1) // n_cols
fig, ax = plt.subplots(nrows=n_rows, ncols=n_cols, squeeze=False)
for j in range(len(my_preds)):
    show_my_img(v_inputs[j], v_labels[j], my_preds[j], ax, j)

# Hide any unused cell of the grid (odd batch sizes)
for k in range(len(my_preds), n_rows * n_cols):
    ax[k // n_cols, k % n_cols].axis("off")
plt.show()