In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import OneCycleLR
from torchmetrics.functional import accuracy, precision
import torchmetrics.functional as tf

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import TQDMProgressBar

from data.cifar100 import CIFAR100DataModule
from data.cifar10 import CIFAR10DataModule

# Vision Transformer Demo
In this notebook, I present a simple implementation of the Vision Transformer (ViT) from "An Image is Worth 16x16 Words: Transformers for Image Recognition At Scale." I have attempted to represent the essential components as faithfully as possible. This demo is much smaller-scale than the original, since I am training it on CIFAR-100 (with 60,000 images) instead of the 300-million JFT-300M dataset owned by Google.

Though originally developed for NLP, the transformer architecture is gradually making its way into many different areas of deep learning, including image classification and labeling and even reinforcement learning. It's an amazingly versatile architecture and very powerful at representing whatever it's being used to model.

As part of my effort to understand fundamental architectures and their applications better, I decided to implement the vision transformer (ViT) from the paper¹ directly, without referencing the official codebase. Here, I'll explain how it works (and how my version is implemented). I'll start with a brief review of how transformers work, but I won't get too deep into the weeds here since there are many other excellent guides to transformers (see The Illustrated Transformer for my favorite one). In addition, I'll cover my basic suggestions for using the process of implementing papers for learning.

In [2]:
# Set this to whatever folder you wish CIFAR-100 to be downloaded into
CIFAR = "/media/curttigges/project-files/datasets/cifar-100/"

## Image Preparation
This function is essential to the functioning of ViT. Essentially, we take images provided by the dataloader and cut them into a series of patches, which are then flattened into a single dimension as shown. This prepares them to be passed through a linear projection layer.

In [3]:
def img_to_patch(x, patch_size):
    '''Transforms image into list of patches of the specified dimensions

    Args:
        x (Tensor): Tensor of dimensions B x C x H x W, representing a batch.
        B=Batch size, C=Channel count.
        patch_size (int): Size of one side of (square) patch.

    Returns:
        patch_seq (Tensor): List of patches of dimension B x N x [C * P ** 2],
        where N is the number of patches and P is patch_size.

    Notes:
        May need to add padding
    '''
    B, C, H, W = x.shape

    # reshape to B x C x H_count x H_patch x W_count x W_patch
    x = x.reshape(B, C, H // patch_size, patch_size, W // patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5)
    x = x.flatten(1,2)
    x = x.flatten(2, 4)
    
    return x

## Transformer Encoder
This is a fairly standard transformer encoder layer, which I have implemented myself (instead of using the PyTorch default) for demonstration purposes. Note that the paper uses this norm-first variant.

In [4]:
class ViTEncoder(nn.Module):
    '''Basic transformer encoder, as specified in the paper

    Args:
        input_dim (int): Dimensions of transformer input (input embed size)
        hidden_dim (int): Size of MLP head
        num_heads (int): Number of self-attention heads
        dropout (float): Probability of dropout
    '''
    def __init__(self, input_dim, hidden_dim, num_heads, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(input_dim)
        self.attn = nn.MultiheadAttention(input_dim, num_heads)
        self.norm2 = nn.LayerNorm(input_dim)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, input_dim)
        self.drop1 = nn.Dropout(dropout)
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x):
        out = self.norm1(x)
        out, _ = self.attn(out, out, out)
        
        # First residual connection
        resid = x + out

        # Pass through MLP layer
        out = self.norm2(resid)
        out = F.gelu(self.fc1(out))
        out = self.drop1(out)
        out = self.fc2(out)
        out = self.drop2(out)

        # Second residual connection
        out = out + resid

        return out

## Complete Model
Here we can see the entirety of the model, which is surprisingly simple. It simply does the following in order:
- Input is cut into patches and flattened.
- Input is then passed through a linear projection layer and an activation to create an embedding.
- A learnable class embedding is concatenated to the embedding.
- Learnable positional embeddings are added to the result from above.
- The result is transposed and sent through the transformer, which consists of a number of encoders.
- Finally, only the class embedding is taken from the transformer output, and this is passed through a linear classification head (with two layers).

In [5]:
class ViTClassifier(nn.Module):
    '''Encoder-only vision transformer

    Args:
        embed_dim (int): Size of embedding output from linear projection layer
        hidden_dim (int): Size of MLP head
        class_head_dim (int): Size of classification head
        num_encoders (int): Number of encoder layers
        num_heads (int): Number of self-attention heads
        patch_size (int): Size of patches
        num_patches (int): Total count of patches (patch sequence size) 
        dropout (float): Probability of dropout
    '''
    def __init__(
        self, embed_size, hidden_size, class_head_dim, num_encoders, 
        num_heads, patch_size, num_patches, dropout):
        super().__init__()

        # Key parameters
        self.patch_size = patch_size
        self.num_patches = num_patches

        # Initial projection of flattened patches into an embedding
        self.input = nn.Linear(3*(patch_size**2), embed_size)
        self.drop = nn.Dropout(dropout)

        # Transformer with arbitrary number of encoders, heads, and hidden size
        self.transformer = nn.Sequential(
            *(ViTEncoder(embed_size, hidden_size, num_heads, dropout) for _ in range(num_encoders))
        )
        
        # Classification head
        self.fc1 = nn.Linear(embed_size, class_head_dim)
        self.fc2 = nn.Linear(class_head_dim, 100)

        # Learnable parameters for class and position embedding
        self.class_embed = nn.Parameter(torch.randn(1, 1, embed_size))
        self.pos_embed = nn.Parameter(torch.randn(1, 1 + num_patches, embed_size))

    def forward(self, x):
        # x will be in the shape B x N x C x P x P
        x = img_to_patch(x, self.patch_size)       

        # pass input through projection layer; shape is B x N x (C * P**2)
        x = F.relu(self.input(x))
        B, N, L = x.shape

        # concatenate class embedding and add positional encoding
        class_embed = self.class_embed.repeat(B, 1, 1)
        x = torch.cat([class_embed, x], dim=1)
        x = x + self.pos_embed[:, :N+1]
        x = self.drop(x)

        # apply transformer
        x = x.transpose(0, 1) # result is N x B x (C * P**2)
        x = self.transformer(x)
        x = x[0] # grab the class embedding
        
        # pass through classification head
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

## Training
We use PyTorch Lightning for training, as this greatly simplifies and organizes the relevant code. The model is trained with the Adam optimizer with β1 = 0.9, β2 = 0.999, and a weight decay of 0.1. I use OneCycleLR with cosine decay to optimize the learning rate.

In [6]:
class ViTTrainModule(pl.LightningModule):
    '''Encoder-only vision transformer

    Args:
        embed_dim (int): Size of embedding output from linear projection layer
        hidden_dim (int): Size of MLP head
        class_head_dim (int): Size of classification head
        num_encoders (int): Number of encoder layers
        num_heads (int): Number of self-attention heads
        patch_size (int): Size of patches
        num_patches (int): Total count of patches (patch sequence size) 
        dropout (float): Probability of dropout
        batch_size (int): Batch size (used for OneCycleLR)
        learning_rate (float): Maximum learning rate
        weight_decay (float): Optimizer weight decay
    '''
    def __init__(
        self, 
        embed_dim, 
        hidden_dim, 
        class_head_dim, 
        num_encoders, 
        num_heads, 
        patch_size, 
        num_patches, 
        dropout, 
        batch_size, 
        learning_rate=0.001,
        weight_decay=0.03):
        super().__init__()

        # Key parameters
        self.save_hyperparameters()

        # Transformer with arbitrary number of encoders, heads, and hidden size
        self.model = ViTClassifier(
            embed_dim,
            hidden_dim,
            class_head_dim,
            num_encoders,
            num_heads,
            patch_size,
            num_patches,
            dropout
        )

    def forward(self, x):
        x = self.model(x)        
        return x

    def evaluate(self, batch, stage=None):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = accuracy(y_hat, y)

        category_prec = precision(y_hat, y.type(torch.int), average='macro', num_classes=100)
        category_recall = tf.recall(y_hat, y.type(torch.int), average='macro', num_classes=100)
        category_f1 = tf.f1_score(y_hat, y.type(torch.int), average='macro', num_classes=100)

        overall_prec = precision(y_hat, y.type(torch.int))
        overall_recall = tf.recall(y_hat, y.type(torch.int))
        overall_f1 = tf.f1_score(y_hat, y.type(torch.int))

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

            self.log(f"{stage}_cat_prec", category_prec, prog_bar=True)
            self.log(f"{stage}_cat_recall", category_recall, prog_bar=True)
            self.log(f"{stage}_cat_f1", category_f1, prog_bar=True)

            self.log(f"{stage}_ovr_prec", overall_prec, prog_bar=True)
            self.log(f"{stage}_ovr_recall", overall_recall, prog_bar=True)
            self.log(f"{stage}_ovr_f1", overall_f1, prog_bar=True)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(), 
            lr=self.hparams.learning_rate,
            betas=(0.9,0.999),
            weight_decay=self.hparams.weight_decay)
        
        steps_per_epoch = 60000 // self.hparams.batch_size
   
        lr_scheduler_dict = {
            "scheduler":OneCycleLR(
                optimizer,
                self.hparams.learning_rate,
                epochs=self.trainer.max_epochs,
                steps_per_epoch=steps_per_epoch,
                anneal_strategy='cos'
            ),
            "interval":"step",
        }
        return {"optimizer":optimizer, "lr_scheduler":lr_scheduler_dict}

## Specifications
Here we specify the size and other attributes of the model. In the paper, the authors trained models of various sizes, but their base model consisted of 12 encoder layers, a hidden (embedding) size of 768, an MLP width of 3072, and 12 heads for the self-attention layer, which came to 86 million parameters in total. For this demo, I initialize a smaller model for ease of training, and because excessively large model don't do well on smaller datasets like CIFAR-100.

In [7]:
model_kwargs = {
    "embed_dim":256, 
    "hidden_dim":512,
    "class_head_dim":512, 
    "num_encoders":24,
    "num_heads":8,
    "patch_size":4,
    "num_patches":64,
    "dropout":0.1,
    "batch_size":256,
    "learning_rate":0.001,
    "weight_decay":0.03
}


## Initialization
Here we initialize our data module and model.

In [8]:
cifar100 = CIFAR100DataModule(
    batch_size=model_kwargs["batch_size"], 
    num_workers=12, 
    data_dir=CIFAR)

pl.seed_everything(42)
model = ViTTrainModule(**model_kwargs)

Global seed set to 42


## Instrumentation
If you have a Weights & Biases account and want to monitor the training progress of the model, you can initialize this with the following cell.

In [9]:
wandb_logger = WandbLogger(project="vit-cifar100")
wandb_logger.watch(model, log="all")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcurt-tigges[0m ([33mascendant[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


## Training
Finally, we initialize our trainer and run it. Be sure to comment out the WandB line if not using Weights & Biases.

In [10]:
trainer = pl.Trainer(
    max_epochs=180,
    accelerator='gpu', 
    devices=1,
    logger=wandb_logger, #comment out if not using WandB
    callbacks=[TQDMProgressBar(refresh_rate=10)])
    
trainer.fit(model, datamodule=cifar100)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


Global seed set to 42
Global seed set to 42


Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type          | Params
----------------------------------------
0 | model | ViTClassifier | 12.9 M
----------------------------------------
12.9 M    Trainable params
0         Non-trainable params
12.9 M    Total params
51.451    Total estimated model params size (MB)


Epoch 179: 100%|██████████| 196/196 [1:41:41<00:00, 31.13s/it, loss=0.456, v_num=b10f, val_loss=3.340, val_acc=0.498, val_cat_prec=0.469, val_cat_recall=0.462, val_cat_f1=0.432, val_ovr_prec=0.498, val_ovr_recall=0.498, val_ovr_f1=0.498]   


## Results
As you can see, the model is able to reach an accuracy of only 49.8%, which is far from SOTA performance. However, vision transformers typically only perform better than CNNs on enormous datasets. These are generally impractical for individuals to train (especially since Google has exclusive access to JFT-300M), so I have limited myself to CIFAR.

It does perform better on CIFAR-10, as we can see below:

In [7]:
model_kwargs = {
    "embed_dim":256, 
    "hidden_dim":512,
    "class_head_dim":512, 
    "num_encoders":6,
    "num_heads":8,
    "patch_size":4,
    "num_patches":64,
    "dropout":0.2,
    "batch_size":256,
    "learning_rate":0.0003,
    "weight_decay":0.03
}


In [8]:
CIFAR10 = "/media/curttigges/project-files/datasets/cifar-10/"

cifar100 = CIFAR10DataModule(
    batch_size=model_kwargs["batch_size"], 
    num_workers=12, 
    data_dir=CIFAR10)

pl.seed_everything(42)
model = ViTTrainModule(**model_kwargs)

Global seed set to 42


In [9]:
wandb_logger = WandbLogger(project="vit-cifar10")
wandb_logger.watch(model, log="all")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcurt-tigges[0m ([33mascendant[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


In [11]:
trainer = pl.Trainer(
    max_epochs=180,
    accelerator='gpu', 
    devices=1,
    logger=wandb_logger, #comment out if not using WandB
    callbacks=[TQDMProgressBar(refresh_rate=10)])
    
trainer.fit(model, datamodule=cifar100)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


Global seed set to 42
Global seed set to 42


Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type          | Params
----------------------------------------
0 | model | ViTClassifier | 3.4 M 
----------------------------------------
3.4 M     Trainable params
0         Non-trainable params
3.4 M     Total params
13.500    Total estimated model params size (MB)


Epoch 179: 100%|██████████| 196/196 [28:17<00:00,  8.66s/it, loss=0.512, v_num=8mhy, val_loss=0.771, val_acc=0.788, val_cat_prec=0.786, val_cat_recall=0.786, val_cat_f1=0.782, val_ovr_prec=0.788, val_ovr_recall=0.788, val_ovr_f1=0.788]  


On CIFAR-10, we get an accuracy of 78.8%. Not bad, but not as good as the 9x% that is SOTA. We will need a much, much larger dataset to exceed CNN models. Nevertheless, this illustrates that the basic approach works, and that transformers can be surprisingly good for computer vision!