<a href="https://colab.research.google.com/github/bitfromit2byte/Vision-Transformer/blob/main/Vision_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

**Transform dataset & Create DataLoader**

In [None]:
# Per-channel mean / std used to standardise the CIFAR-10 image tensors
cifar10_mean = [0.4914, 0.4822, 0.4465]
cifar10_std = [0.2470, 0.2435, 0.2616]

# Resize -> tensor -> standardise; applied to every image when it is accessed
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar10_mean, std=cifar10_std),
])

# Both splits share the same storage location and preprocessing
dataset_options = dict(root='./data', download=True, transform=transform)
train_dataset = datasets.CIFAR10(train=True, **dataset_options)
test_dataset = datasets.CIFAR10(train=False, **dataset_options)

# Only the training loader shuffles its samples
loader_batch_size = 32
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=loader_batch_size)
test_loader = DataLoader(test_dataset, batch_size=loader_batch_size)

In [None]:
# Number of batches the training loader yields per epoch
len(train_loader)

In [None]:
# Peek at one batch but display only the tensor shapes; echoing the batch itself
# floods the cell output with raw tensor values.
sample_images, sample_labels = next(iter(train_loader))
sample_images.shape, sample_labels.shape

In [None]:
# Draw a single (images, labels) batch from the training loader
image_batch, label_batch = next(iter(train_loader))

# CIFAR-10 class names, ordered so that list position == integer label
cifar10_classes = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]

# Translate each integer label of the batch into its class name
label_names = [cifar10_classes[class_idx] for class_idx in label_batch.tolist()]

In [None]:
# Class name of the first image in the sampled batch
label_names[0]

In [None]:
# Keep the first image tensor of the batch together with its class name,
# then display the tensor shape next to the name
image, label = image_batch[0], label_names[0]
image.shape, label

In [None]:
# permute moves the channel axis last (C, H, W -> H, W, C) as imshow expects.
# The tensor went through Normalize, so values fall outside [0, 1]; clamp them
# explicitly (as the patch-grid cell does) instead of relying on matplotlib's
# implicit clipping and its warning.
plt.imshow(image.permute(1, 2, 0).clamp(0, 1))
plt.title(label)
plt.axis(False);

**Reproducing the Vision Transformer Architecture**

Based on the paper [*An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale*](https://arxiv.org/abs/2010.11929).


In [None]:
# Example dimensions of a single input image
height, width = 32, 32
color_channels = 3
patch_size = 4

# N = (H * W) / P^2 : how many non-overlapping P x P patches tile the image
patch_area = patch_size**2
number_of_patches = int((height * width) / patch_area)
print(f'Number of patches (N) with image height (H={height}), width (W={width}) and patch size (P={patch_size}): {number_of_patches}')

In [None]:
# Shape of one image before patching: (H, W, C)
embedding_input = (height, width, color_channels)

# Shape after patching: one row per patch, each row holding P * P * C values
flattened_patch_length = patch_size**2 * color_channels
embedding_output = (number_of_patches, flattened_patch_length)

print(f'Input shape (single 2d image): {embedding_input}')
print(f'Output shape (flattened 2d image into patches): {embedding_output}')

In [None]:
import numpy as np
# Work channels-last (H, W, C), the layout imshow expects
image_permuted = image.permute(1, 2, 0)
image_size = 32
patch_size = 4
num_patches = image_size // patch_size  # patches along one side of the image

# One subplot per patch on a num_patches x num_patches grid
fig, axs = plt.subplots(
    num_patches,
    num_patches,
    figsize=(num_patches, num_patches),
    sharex=True,
    sharey=True,
)

# Top-left pixel offsets of the patches along one axis
patch_origins = range(0, image_size, patch_size)

for i, patch_height in enumerate(patch_origins):
    row_slice = slice(patch_height, patch_height + patch_size)
    for j, patch_width in enumerate(patch_origins):
        col_slice = slice(patch_width, patch_width + patch_size)
        patch = image_permuted[row_slice, col_slice, :]

        ax = axs[i, j]
        # Values are clipped into [0, 1] before display
        ax.imshow(np.clip(patch, 0, 1))
        ax.set_ylabel(i + 1,
                      rotation='horizontal',
                      horizontalalignment='right',
                      verticalalignment='center')
        ax.set_xlabel(j + 1)
        ax.set_xticks([])
        ax.set_yticks([])
        # Keep row / column numbers only on the outer edge of the grid
        ax.label_outer()

fig.suptitle(f'{label_names[0]}', fontsize=16)
plt.show()
plt.close(fig)

In [None]:
from torch import nn

patch_size = 4

# Kernel size and stride both equal the patch size, so the convolution visits each
# non-overlapping patch exactly once and emits 128 values per patch.
conv2d = nn.Conv2d(3, 128, kernel_size=patch_size, stride=patch_size, padding=0)

In [None]:
# unsqueeze(0) adds a leading batch dimension so the single image can go through the convolution
convolution_output = conv2d(image.unsqueeze(0))
print(convolution_output.shape)

In [None]:
import random
# Pick five distinct channel indices out of the 128 convolution output channels
random_indexes = random.sample(range(128), k=5)
print(f'Showing random convolutional feature maps from indexes: {random_indexes}')

# One row of five panels, one per selected channel
fig, axs = plt.subplots(nrows=1, ncols=5, figsize=(12,12))

for ax, idx in zip(axs, random_indexes):
    # Select channel idx of the convolution output
    convolution_feature_map = convolution_output[:, idx, :, :]
    # Drop size-1 dims and detach from autograd so numpy can read the values
    feature_map_array = convolution_feature_map.squeeze().detach().numpy()
    ax.imshow(feature_map_array)
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]);

In [None]:
# Merge dims 2 and 3 into a single axis, leaving dims 0 and 1 untouched
flatten = nn.Flatten(2, 3)

In [None]:
# Collapse dims 2 and 3 of the convolution output into one axis
flattened_image = flatten(convolution_output)

In [None]:
# Swap the last two axes so the layout becomes (batch_size, num_patches, embedding_size)
batch_patches_emb = flattened_image.permute(0,2,1)

In [None]:
# Take index 0 of the last (embedding) axis for every patch
flattened_feature_map = batch_patches_emb[:, :, 0]

# Plot flattened feature map visually
plt.figure(figsize=(22,22))
# detach() drops the autograd graph so the values can be converted to numpy
plt.imshow(flattened_feature_map.detach().numpy())
plt.title(f'Flattened feature map shape: {flattened_feature_map.shape}')
# Trailing ';' keeps the return value of plt.axis out of the cell output,
# matching the other plotting cells
plt.axis(False);

**Patch Embedding Module**

In [None]:
import torch.nn as nn
class Embedding(nn.Module):
    '''Turns a batch of 2d input images into a sequence of learnable patch embeddings.

    Args:
        in_channels: Number of channels of the input images.
        patch_size: Side length of the square patches; also used as the
            convolution kernel size and stride.
        embedding_dim: Number of values produced for every patch.
    '''
    def __init__(self,
                 in_channels:int=3,
                 patch_size:int=4,
                 embedding_dim:int=128):
        super().__init__()

        # Remembered so forward() can validate the incoming resolution
        self.patch_size = patch_size

        # kernel_size == stride == patch_size: each output position sees exactly one patch
        self.patcher = nn.Conv2d(in_channels=in_channels,
                                 out_channels=embedding_dim,
                                 kernel_size=patch_size,
                                 stride=patch_size,
                                 padding=0)

        # Merge the two spatial dims of the convolution output into one patch axis
        self.flatten = nn.Flatten(start_dim=2,
                                  end_dim=3)

    def forward(self, x):
        '''Embeds a (batch, channels, height, width) tensor.

        Returns:
            Tensor of shape (batch, num_patches, embedding_dim).

        Raises:
            ValueError: If height or width is not a multiple of the patch size.
                Without this check the strided convolution would silently
                drop the leftover border pixels.
        '''
        image_height, image_width = x.shape[-2], x.shape[-1]
        if image_height % self.patch_size != 0 or image_width % self.patch_size != 0:
            raise ValueError(
                f'Input image size ({image_height}x{image_width}) must be divisible '
                f'by patch size ({self.patch_size})'
            )

        x_patch = self.patcher(x)
        x_flattened = self.flatten(x_patch)
        # (batch, embedding_dim, num_patches) -> (batch, num_patches, embedding_dim)
        return x_flattened.permute(0,2,1)

In [None]:
class MultiheadSelfAttention(nn.Module):
    '''Layer-normalised multi-head self-attention block.

    The input is normalised first and then attends to itself
    (query, key and value are the same tensor).
    '''
    def __init__(self,
                 embedding_dim:int=128,
                 num_heads:int=4,
                 attn_dropout:float=0):
        super().__init__()

        # Normalisation applied before attention
        self.layer_norm = nn.LayerNorm(embedding_dim)

        # batch_first=True: tensors are laid out as (batch, sequence, feature)
        self.multihead_attn = nn.MultiheadAttention(embedding_dim,
                                                    num_heads,
                                                    dropout=attn_dropout,
                                                    batch_first=True)

    def forward(self, x):
        normed = self.layer_norm(x)
        # The attention module returns (output, weights); only the output is kept
        attended = self.multihead_attn(normed, normed, normed, need_weights=False)
        return attended[0]

In [None]:
class MLP(nn.Module):
    '''Layer-normalised feed-forward block.

    LayerNorm followed by Linear -> GELU -> Dropout -> Linear -> Dropout,
    mapping embedding_dim -> mlp_size -> embedding_dim.
    '''
    def __init__(self,
                 embedding_dim:int=128,
                 mlp_size:int=256,
                 dropout:float=0.1):
        super().__init__()

        # Normalisation applied before the feed-forward layers
        self.layer_norm = nn.LayerNorm(embedding_dim)

        # Expand to mlp_size, apply the non-linearity, project back to embedding_dim
        feed_forward_layers = [
            nn.Linear(embedding_dim, mlp_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_size, embedding_dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*feed_forward_layers)

    def forward(self, x):
        return self.mlp(self.layer_norm(x))

In [None]:
class TransformerEncoder(nn.Module):
    '''One Transformer encoder block: self-attention then MLP, each with a residual connection.

    Args:
        embedding_dim: Size of every token embedding.
        num_heads: Number of attention heads.
        mlp_size: Hidden size of the MLP sub-block.
        mlp_dropout: Dropout rate used inside the MLP sub-block.
        attn_dropout: Dropout rate used inside the attention sub-block.
    '''
    def __init__(self,
                 embedding_dim:int=128,
                 num_heads:int=4,
                 mlp_size:int=256,
                 mlp_dropout:float=0.1,
                 attn_dropout:float=0):
        super().__init__()

        self.msa = MultiheadSelfAttention(embedding_dim=embedding_dim,
                                          num_heads=num_heads,
                                          attn_dropout=attn_dropout)

        self.mlp = MLP(embedding_dim=embedding_dim,
                       mlp_size=mlp_size,
                       dropout=mlp_dropout)

    def forward(self, x):
        '''Applies both sub-blocks; the output has the same shape as the input.'''
        # Residual connection around the attention sub-block
        x = self.msa(x) + x

        # Residual connection around the MLP sub-block. The skip path was
        # previously missing here, so the attention result was replaced by the
        # MLP output instead of being added to it.
        x = self.mlp(x) + x

        return x

In [None]:
# Instantiate the custom encoder block with its default hyperparameters
transformer_encoder = TransformerEncoder()

In [None]:
# PyTorch's built-in encoder layer, configured with the same embedding size,
# head count and feed-forward width as the custom block's defaults
encoder_layer_config = dict(
    d_model=128,
    nhead=4,
    dim_feedforward=256,
    dropout=0.1,
    activation='gelu',
    batch_first=True,
    norm_first=True,
)
torch_transformer_encoder_layer = nn.TransformerEncoderLayer(**encoder_layer_config)

**Vision Transformer Architecture**

In [None]:
class ViT(nn.Module):
    '''Vision Transformer classifier.

    Pipeline: patch embedding -> prepend a learnable class token -> add learnable
    position embeddings -> dropout -> stack of pre-norm Transformer encoder
    layers -> LayerNorm + Linear head applied to the class token.

    Args:
        image_size: Side length of the (square) input images.
        channels: Number of input image channels.
        patch_size: Side length of the square patches.
        transformer_layers: Number of stacked encoder layers.
        embedding_dim: Size of every token embedding.
        mlp_size: Feed-forward width inside every encoder layer.
        num_heads: Number of attention heads.
        attn_dropout: Accepted but not used by this implementation (see the
            note in __init__).
        mlp_dropout: Dropout rate passed to every encoder layer.
        embedding_dropout: Dropout rate applied to the embedded token sequence.
        num_classes: Number of output logits.

    Raises:
        ValueError: If image_size is not divisible by patch_size.
    '''

    def __init__(self,
                 image_size=32,
                 channels=3,
                 patch_size=4,
                 transformer_layers=6,
                 embedding_dim=128,
                 mlp_size=256,
                 num_heads=4,
                 attn_dropout=0,
                 mlp_dropout=0.1,
                 embedding_dropout=0.1,
                 num_classes=10):
        super().__init__()

        # Fail early: otherwise the number of patches produced by the strided
        # convolution differs from patch_count and forward() only breaks later
        # with a shape mismatch when the position embedding is added.
        if image_size % patch_size != 0:
            raise ValueError(
                f'image_size ({image_size}) must be divisible by patch_size ({patch_size})'
            )

        # number of patches
        self.patch_count = (image_size // patch_size) ** 2

        # learnable class embeddings
        self.class_emb = nn.Parameter(data=torch.randn(1, 1, embedding_dim),
                                     requires_grad=True)

        # learnable position embedding (one slot per patch plus one for the class token)
        self.position_emb = nn.Parameter(data=torch.randn(1, self.patch_count+1, embedding_dim),
                                             requires_grad=True)

        # embedding dropout
        self.emb_dropout = nn.Dropout(p=embedding_dropout)

        # patch embedding layer
        self.patch_emb = Embedding(in_channels=channels,
                                   patch_size=patch_size,
                                   embedding_dim=embedding_dim)

        # Transformer Encoder blocks
        # NOTE(review): attn_dropout is never referenced below; mlp_dropout is the
        # only dropout rate handed to nn.TransformerEncoderLayer. Confirm whether
        # a separate attention dropout rate is still wanted.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=mlp_size,
            dropout=mlp_dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=transformer_layers
        )

        # Classifier Head
        self.classifier = nn.Sequential(
            nn.LayerNorm(normalized_shape=embedding_dim),
            nn.Linear(in_features=embedding_dim,
                      out_features=num_classes)
        )

    def forward(self, x):
        '''Maps a (batch, channels, height, width) image batch to (batch, num_classes) logits.'''

        batch_size = x.shape[0]

        # One copy of the class token per image (expand creates a view, not new storage)
        class_token = self.class_emb.expand(batch_size, -1, -1)

        x = self.patch_emb(x)

        # Class token goes first in the sequence
        x = torch.cat((class_token, x), dim=1)

        x = self.position_emb + x

        x = self.emb_dropout(x)

        x = self.transformer_encoder(x)

        # Classify from the class token, i.e. sequence position 0
        x = self.classifier(x[:, 0])

        return x

I chose to use PyTorch's built-in `nn.TransformerEncoder` inside `ViT` instead of the custom `TransformerEncoder` block defined above, since it is less error-prone.

In [None]:
# %pip (instead of !pip) installs into the environment of the running kernel
%pip install -q torchinfo

In [None]:
from torchinfo import summary

# Build the model with its default hyperparameters
model = ViT()

# (batch_size, color_channels, height, width) of the dummy input used for the summary
summary_input_size = (32, 3, 32, 32)
summary_columns = ["input_size", "output_size", "num_params", "trainable"]

summary(
    model=model,
    input_size=summary_input_size,
    col_names=summary_columns,
    col_width=20,
    row_settings=["var_names"],
)

**Training the Vision Transformer**

In [None]:
import torch
from tqdm.auto import tqdm

def train_epoch(model: torch.nn.Module,
                dataloader: torch.utils.data.DataLoader,
                loss_fn: torch.nn.Module,
                optimizer: torch.optim.Optimizer,
                device: torch.device):
    '''Trains the model for one pass over the dataloader.

    Args:
        model: Model to optimise; switched to train mode.
        dataloader: Yields (inputs, targets) batches.
        loss_fn: Loss taking (predictions, targets); assumed to return the
            mean over the batch (the default reduction of CrossEntropyLoss).
        optimizer: Optimizer updating the model parameters.
        device: Device each batch is moved to.

    Returns:
        Dict with the per-sample average 'loss' and 'accuracy' over the epoch.
    '''
    model.train()
    total_loss, total_correct, total_samples = 0.0, 0, 0

    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        batch_samples = targets.size(0)

        # Forward pass
        y_pred = model(inputs)

        # Calculate loss
        loss = loss_fn(y_pred, targets)

        # Optimizer zero grad
        optimizer.zero_grad()

        # Loss backward
        loss.backward()

        optimizer.step()

        # Accumulate per-sample totals rather than per-batch means, so a smaller
        # final batch does not get the same weight as a full one.
        total_loss += loss.item() * batch_samples
        # argmax over the logits directly; softmax does not change the argmax
        total_correct += (y_pred.argmax(dim=1) == targets).sum().item()
        total_samples += batch_samples

    train_loss = total_loss / total_samples
    train_acc = total_correct / total_samples
    return { 'loss': train_loss,
             'accuracy': train_acc
           }

def evaluate(model: torch.nn.Module,
             dataloader: torch.utils.data.DataLoader,
             loss_fn: torch.nn.Module,
             device: torch.device):
    '''Evaluates the model on every batch of the dataloader without updating it.

    Args:
        model: Model to evaluate; switched to eval mode.
        dataloader: Yields (inputs, targets) batches.
        loss_fn: Loss taking (predictions, targets); assumed to return the
            mean over the batch (the default reduction of CrossEntropyLoss).
        device: Device each batch is moved to.

    Returns:
        Dict with the per-sample average "loss" and "accuracy".
    '''
    model.eval()

    total_loss, total_correct, total_samples = 0.0, 0, 0

    # inference_mode disables gradient tracking for the whole loop
    with torch.inference_mode():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            batch_samples = targets.size(0)

            outputs = model(inputs)
            loss = loss_fn(outputs, targets)

            # Accumulate per-sample totals rather than per-batch means, so a
            # smaller final batch does not get the same weight as a full one.
            total_loss += loss.item() * batch_samples
            total_correct += (outputs.argmax(dim=1) == targets).sum().item()
            total_samples += batch_samples

    test_loss = total_loss / total_samples
    test_acc = total_correct / total_samples
    return {
        "loss": test_loss,
        "accuracy": test_acc
    }


def train(model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          test_dataloader: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer,
          loss_fn: torch.nn.Module,
          epochs: int,
          device: torch.device):
    '''Runs the full train / evaluate loop and records the metrics of every epoch.

    Args:
        model: Model to train.
        train_dataloader: Batches used by train_epoch.
        test_dataloader: Batches used by evaluate after every epoch.
        optimizer: Optimizer updating the model parameters.
        loss_fn: Loss function shared by training and evaluation.
        epochs: Number of passes over the training data.
        device: Device the batches are moved to.

    Returns:
        Dict with the lists 'train_loss', 'train_acc', 'test_loss' and
        'test_acc', one entry per epoch.
    '''
    history = {'train_loss': [], 'train_acc': [],
               'test_loss': [], 'test_acc': []}

    for epoch in tqdm(range(epochs)):

        train_metrics = train_epoch(model, train_dataloader, loss_fn, optimizer, device)
        test_metrics = evaluate(model, test_dataloader, loss_fn, device)

        history['train_loss'].append(train_metrics['loss'])
        history['train_acc'].append(train_metrics['accuracy'])
        history['test_loss'].append(test_metrics['loss'])
        history['test_acc'].append(test_metrics['accuracy'])

        # epoch is 0-based; report it 1-based so the final line reads "epochs/epochs"
        print(f'Epoch {epoch + 1}/{epochs}')
        print(f"Train Loss: {train_metrics['loss']:.4f} | "
              f"Train Acc: {train_metrics['accuracy'] * 100:.2f}%")
        print(f"Test Loss: {test_metrics['loss']:.4f} | "
              f"Test Acc: {test_metrics['accuracy'] * 100:.2f}%")
        print('-' * 50)

    return history

In [None]:
import torch

# seed for consistent output
torch.manual_seed(42)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Move the model to the training device before the optimizer is built:
# train_epoch / evaluate move every batch to `device`, so a model left on the
# CPU fails with a device-mismatch error as soon as a GPU is available.
model = model.to(device)

# Create optimizer
# NOTE(review): lr=3e-3 together with weight_decay=0.3 on plain Adam looks
# aggressive for a model of this size - confirm these values are intended.
optimizer = torch.optim.Adam(params=model.parameters(),
                             lr=3e-3,
                             betas=(0.9, 0.999),
                             weight_decay=0.3)

# Loss function for multi-class classification
loss_fn = torch.nn.CrossEntropyLoss()

results = train(model=model,
                train_dataloader=train_loader,
                test_dataloader=test_loader,
                optimizer=optimizer,
                loss_fn=loss_fn,
                epochs=8,
                device=device)

In [None]:
def plot_loss_curves(results):
    '''Plots training and test curves side by side: loss on the left, accuracy on the right.

    Args:
        results: Dict holding the per-epoch lists 'train_loss', 'test_loss',
            'train_acc' and 'test_acc'.
    '''
    # Gather every curve up front, paired with the legend label it is drawn with
    panels = {
        'Loss': [(results['train_loss'], 'train_loss'),
                 (results['test_loss'], 'test_loss')],
        'Accuracy': [(results['train_acc'], 'train_accuracy'),
                     (results['test_acc'], 'test_accuracy')],
    }

    # The x axis is the 0-based epoch index
    epochs = range(len(results['train_loss']))

    plt.figure(figsize=(15, 7))

    for position, (panel_title, curves) in enumerate(panels.items(), start=1):
        plt.subplot(1, 2, position)
        for values, legend_label in curves:
            plt.plot(epochs, values, label=legend_label)
        plt.title(panel_title)
        plt.xlabel('Epochs')
        plt.legend()

In [None]:
# Loss / accuracy curves of the ViT trained from scratch
plot_loss_curves(results)

**Pretrained model of ViT**

In [None]:
import torch
import torchvision
# Report the installed torch and torchvision versions
print(torch.__version__)
print(torchvision.__version__)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU;
# the bare name on the last line displays the choice
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

In [None]:
# 1. Default pretrained weights for ViT-B/16 (both names refer to the same weights entry)
pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
weights = pretrained_vit_weights

# 2. Build the model with those weights and place it on the target device
pretrained_vit = torchvision.models.vit_b_16(weights=weights)
pretrained_vit = pretrained_vit.to(device)

# 3. Freeze every existing parameter
for frozen_parameter in pretrained_vit.parameters():
    frozen_parameter.requires_grad = False

# 4. Swap in a fresh linear head; seeding first makes its initialisation repeatable
torch.manual_seed(42)
new_head = nn.Linear(in_features=768,   # input width expected by the replaced head
                     out_features=10)   # CIFAR-10 classes
pretrained_vit.heads = new_head.to(device)
# pretrained_vit # uncomment for model output


In [None]:
# Preprocessing pipeline that ships with the pretrained weights
pretrained_vit_transforms = pretrained_vit_weights.transforms()
print(pretrained_vit_transforms)

In [None]:

# Rebuild the datasets with the preprocessing that ships with the pretrained
# weights. The earlier 32x32 `transform` cannot be reused here: the pretrained
# ViT-B/16 expects the input resolution produced by its own transforms.
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=pretrained_vit_transforms)
test_dataset = datasets.CIFAR10(root='./data',train=False, download=True, transform=pretrained_vit_transforms)

# Only the training loader shuffles its samples
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)


In [None]:
# Only the parameters of the replaced head are handed to the optimizer
head_parameters = pretrained_vit.heads.parameters()
optimizer = torch.optim.Adam(params=head_parameters, lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()

# Seed right before training for a repeatable run
torch.manual_seed(42)
pretrained_vit_results = train(
    model=pretrained_vit,
    train_dataloader=train_loader,
    test_dataloader=test_loader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=8,
    device=device,
)

In [None]:
# Loss / accuracy curves of the pretrained ViT run
plot_loss_curves(pretrained_vit_results)

**Saving the pretrained ViT and predicting on a CIFAR-10 image of a frog**

In [None]:
import torch
from pathlib import Path

target_dir = 'models'

target_dir_path = Path(target_dir)

# Create the output directory (and any missing parents) if it is not there yet
target_dir_path.mkdir(parents=True,
                      exist_ok=True)

model_name='pretrained_vit_feature_extractor'

model_save_path = target_dir_path / model_name

# Persist only the state_dict (the learned tensors), not the whole module object
torch.save(obj=pretrained_vit.state_dict(),
           f=model_save_path)

# Report success only after the save has completed
print(f"Model saved in: {model_save_path}")

In [None]:
from torchvision import datasets

# Load CIFAR-10 without any transform; the image object is saved directly below
dataset = datasets.CIFAR10(root="./data", train=True, download=True)

# Index of the sample to export (change it to pick another image)
sample_index = 0
image, label = dataset[sample_index]

# Write the image to disk
sample_path = "cifar10_sample.png"
image.save(sample_path)

print("Saved cifar10_sample.png")


In [None]:
from google.colab import files
# Hand the exported sample image to the browser as a download (Colab helper)
files.download("cifar10_sample.png")

In [None]:
from google.colab import files

# Ask the user to upload an image; the result is iterated for file names in the next cell
uploaded = files.upload()

In [None]:
from PIL import Image

# Take the first uploaded file name and open it as a 3-channel RGB image
custom_image_path = next(iter(uploaded))
image = Image.open(custom_image_path).convert('RGB')

In [None]:
# Use the preprocessing that ships with the pretrained weights. A bare
# Resize(224) only rescales the shorter side, so a non-square upload would
# reach the model at a resolution it does not accept.
image_transform = pretrained_vit_transforms

pretrained_vit.to(device)

pretrained_vit.eval()
with torch.inference_mode():
    # Add a leading batch dimension: the model expects a batch of images
    transformed_image = image_transform(image).unsqueeze(dim=0)

    target_image_pred = pretrained_vit(transformed_image.to(device))

    # Logits -> class probabilities
    target_image_pred_probs = torch.softmax(target_image_pred, dim=1)

    # Index of the most probable class
    target_image_pred_label = torch.argmax(target_image_pred_probs, dim=1)

# Look the class index up in the list of class names. `label_names` holds the
# names of one sampled training batch, so indexing it with a class index
# produced an unrelated label.
predicted_class_name = cifar10_classes[target_image_pred_label.item()]

plt.figure()
plt.imshow(image)
plt.title(f"Pred: {predicted_class_name} | Prob: {target_image_pred_probs.max():.3f}")
plt.axis(False);

In [None]:
# Clone the project repository into the current working directory
!git clone https://github.com/bitfromit2byte/Vision-Transformer.git