# Image Classification Using Vision Transformer

[[Paper](https://arxiv.org/abs/2010.11929)] [[Notebook](https://github.com/fastestimator/fastestimator/blob/master/apphub/image_classification/vit/vit.ipynb)] [[TF Implementation](https://github.com/fastestimator/fastestimator/blob/master/apphub/image_classification/vit/vit_tf.py)] [[Torch Implementation](https://github.com/fastestimator/fastestimator/blob/master/apphub/image_classification/vit/vit_torch.py)]

Vision Transformer (ViT) is a new alternative to Convolution Neural Networks (CNNs) in the field of computer vision. The idea of <b>ViT</b> was inspired from the success of the [Transformer](https://arxiv.org/abs/1706.03762) and [BERT](https://arxiv.org/abs/1810.04805) architectures in NLP applications. In this example, we will implement a ViT in PyTorch and showcase how to pre-train a ViT and then fine-tune it on a downstream task for good results with minimal downstream training time. 

## ViT Model
The ViT model is almost the same as the original Transformer except for the following differences:
1. Input image is broken down into small patches, which are used as sequences similar to language. The patching and embedding are implemented by a Convolution2D operation in the `patch_embedding`.
2. Different from original Transformer, the positional embedding is now a trainable parameter.
3. Similar to BERT, a `CLS` token is added before the patch sequence. But in contrast to BERT, the value of the `CLS` token is trainable.
4. After the Transformer encoding, only the embedding corresponding to the `CLS` token will be used as feature for the classification layer.

In [1]:
import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class ViTEmbeddings(nn.Module):
    def __init__(self, image_size=224, patch_size=16, num_channels=3, em_dim=768, drop=0.1) -> None:
        super().__init__()
        assert image_size % patch_size == 0, "image size must be an integer multiply of patch size"
        self.patch_embedding = nn.Conv2d(num_channels, em_dim, kernel_size=patch_size, stride=patch_size, bias=False)
        self.position_embedding = nn.Parameter(torch.zeros(1, (image_size // patch_size)**2 + 1, em_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, em_dim))
        self.dropout = nn.Dropout(drop)

    def forward(self, x):
        x = self.patch_embedding(x).flatten(2).transpose(1, 2)  # [B,C, H, W] -> [B, num_patches, em_dim]
        x = torch.cat([self.cls_token.expand(x.size(0), -1, -1), x], dim=1)  # [B, num_patches+1, em_dim]
        x = x + self.position_embedding
        x = self.dropout(x)
        return x

    
class ViTEncoder(nn.Module):
    def __init__(self, num_layers, image_size, patch_size, num_channels, em_dim, drop, num_heads, ff_dim):
        super().__init__()
        self.embedding = ViTEmbeddings(image_size, patch_size, num_channels, em_dim, drop)
        encoder_layer = TransformerEncoderLayer(em_dim,
                                                nhead=num_heads,
                                                dim_feedforward=ff_dim,
                                                activation='gelu',
                                                dropout=drop)
        self.encoder = TransformerEncoder(encoder_layer=encoder_layer, num_layers=num_layers)
        self.layernorm = nn.LayerNorm(em_dim, eps=1e-6)

    def forward(self, x):
        x = self.embedding(x)
        x = x.transpose(0, 1)  # Switch batch and sequence length dimension for pytorch convention
        x = self.encoder(x)
        x = self.layernorm(x[0])
        return x
    

class ViTModel(nn.Module):
    def __init__(self,
                 num_classes,
                 num_layers=12,
                 image_size=224,
                 patch_size=16,
                 num_channels=3,
                 em_dim=768,
                 drop=0.1,
                 num_heads=12,
                 ff_dim=3072):
        super().__init__()
        self.vit_encoder = ViTEncoder(num_layers=num_layers,
                                      image_size=image_size,
                                      patch_size=patch_size,
                                      num_channels=num_channels,
                                      em_dim=em_dim,
                                      drop=drop,
                                      num_heads=num_heads,
                                      ff_dim=ff_dim)
        self.linear_classifier = nn.Linear(em_dim, num_classes)

    def forward(self, x):
        x = self.vit_encoder(x)
        x = self.linear_classifier(x)
        return x

Now let's define some parameters that will be used later:

In [2]:
batch_size=128
pretrain_epochs=100
finetune_epochs=1
train_steps_per_epoch=None
eval_steps_per_epoch=None

## Upstream Pre-training
We will use CIFAIR 100 as our upstream dataset. The data preprocessing and augmentation is the standard Padded Crop + Dropout used in [this example](https://github.com/fastestimator/fastestimator/blob/master/apphub/image_classification/cifar10_fast/cifar10_fast_torch.py).

In [3]:
import tempfile

import fastestimator as fe
from fastestimator.dataset.data import cifair10, cifair100
from fastestimator.op.numpyop.meta import Sometimes
from fastestimator.op.numpyop.multivariate import HorizontalFlip, PadIfNeeded, RandomCrop
from fastestimator.op.numpyop.univariate import ChannelTranspose, CoarseDropout, Normalize
from fastestimator.op.tensorop.loss import CrossEntropy
from fastestimator.op.tensorop.model import ModelOp, UpdateOp
from fastestimator.trace.metric import Accuracy


def pretrain(batch_size,
             epochs,
             model_dir=tempfile.mkdtemp(),
             train_steps_per_epoch=None,
             eval_steps_per_epoch=None):
    train_data, eval_data = cifair100.load_data()
    pipeline = fe.Pipeline(
        train_data=train_data,
        eval_data=eval_data,
        batch_size=batch_size,
        ops=[
            Normalize(inputs="x", outputs="x", mean=(0.4914, 0.4822, 0.4465), std=(0.2471, 0.2435, 0.2616)),
            PadIfNeeded(min_height=40, min_width=40, image_in="x", image_out="x", mode="train"),
            RandomCrop(32, 32, image_in="x", image_out="x", mode="train"),
            Sometimes(HorizontalFlip(image_in="x", image_out="x", mode="train")),
            CoarseDropout(inputs="x", outputs="x", mode="train", max_holes=1),
            ChannelTranspose(inputs="x", outputs="x")
        ])
    model = fe.build(
        model_fn=lambda: ViTModel(num_classes=100,
                                  image_size=32,
                                  patch_size=4,
                                  num_layers=6,
                                  num_channels=3,
                                  em_dim=256,
                                  num_heads=8,
                                  ff_dim=512),
        optimizer_fn=lambda x: torch.optim.SGD(x, lr=0.01, momentum=0.9, weight_decay=1e-4))
    network = fe.Network(ops=[
        ModelOp(model=model, inputs="x", outputs="y_pred"),
        CrossEntropy(inputs=("y_pred", "y"), outputs="ce", from_logits=True),
        UpdateOp(model=model, loss_name="ce")
    ])
    traces = [
        Accuracy(true_key="y", pred_key="y_pred")
    ]
    estimator = fe.Estimator(pipeline=pipeline,
                             network=network,
                             epochs=epochs,
                             traces=traces,
                             train_steps_per_epoch=train_steps_per_epoch,
                             eval_steps_per_epoch=eval_steps_per_epoch,
                             log_steps=0)
    estimator.fit(warmup=False)
    return model

## Start Pre-training

Let's train the ViT model for 100 epochs, and get the pre-trained weight. This would take ~40 minutes on single GTX 1080 TI GPU. 

Here we are only training a mini version of the actual ViT model, and the CIFAR100 performance after 100 epochs is similar to the 55% top-1 performance [reported in the community](https://keras.io/examples/vision/image_classification_with_vision_transformer/). However, training the official `ViTModel` model with its original parameters on the JFT-300M dataset would produce much better encoder weights at the cost of a much longer training time. The paper used this strategy to reach near 81% ImageNet downstream top-1 accuracy.

In [4]:
pretrained_model = pretrain(batch_size=batch_size,
                            epochs=pretrain_epochs,
                            train_steps_per_epoch=train_steps_per_epoch,
                            eval_steps_per_epoch=eval_steps_per_epoch)

    ______           __  ______     __  _                 __            
   / ____/___ ______/ /_/ ____/____/ /_(_)___ ___  ____ _/ /_____  _____
  / /_  / __ `/ ___/ __/ __/ / ___/ __/ / __ `__ \/ __ `/ __/ __ \/ ___/
 / __/ / /_/ (__  ) /_/ /___(__  ) /_/ / / / / / / /_/ / /_/ /_/ / /    
/_/    \__,_/____/\__/_____/____/\__/_/_/ /_/ /_/\__,_/\__/\____/_/     
                                                                        

FastEstimator-Warn: No ModelSaver Trace detected. Models will not be saved.
FastEstimator-Start: step: 1; logging_interval: 0; num_device: 1;
FastEstimator-Eval: step: 391; epoch: 1; accuracy: 0.1088; ce: 3.8288884;
FastEstimator-Eval: step: 782; epoch: 2; accuracy: 0.158; ce: 3.5546496;
FastEstimator-Eval: step: 1173; epoch: 3; accuracy: 0.1876; ce: 3.3546138;
FastEstimator-Eval: step: 1564; epoch: 4; accuracy: 0.225; ce: 3.1463547;
FastEstimator-Eval: step: 1955; epoch: 5; accuracy: 0.2515; ce: 3.0236564;
FastEstimator-Eval: step: 2346; epoch: 6; accura

## Downstream Fine-tuning

A general rule-of-thumb to ensure successful downstream fine-tuning is to choose a downstream task with less variety and complexity than the upstream training. In this example, given that we used CIFAIR100 as our upstream task, a good candidate for the downstream dataset is CIFAIR10.  The official implementation mapped this practice to a larger scale, using JFT-300M as their upstream task and then ImageNet as their downstream task.

Given the similarity between our downstream and upstream datasets, the fine-tuning configuration is almost the same as before.

In [7]:
def finetune(pretrained_model,
             batch_size,
             epochs,
             model_dir=tempfile.mkdtemp(),
             train_steps_per_epoch=None,
             eval_steps_per_epoch=None):
    train_data, eval_data = cifair10.load_data()
    pipeline = fe.Pipeline(
        train_data=train_data,
        eval_data=eval_data,
        batch_size=batch_size,
        ops=[
            Normalize(inputs="x", outputs="x", mean=(0.4914, 0.4822, 0.4465), std=(0.2471, 0.2435, 0.2616)),
            PadIfNeeded(min_height=40, min_width=40, image_in="x", image_out="x", mode="train"),
            RandomCrop(32, 32, image_in="x", image_out="x", mode="train"),
            Sometimes(HorizontalFlip(image_in="x", image_out="x", mode="train")),
            CoarseDropout(inputs="x", outputs="x", mode="train", max_holes=1),
            ChannelTranspose(inputs="x", outputs="x")
        ])
    model = fe.build(
        model_fn=lambda: ViTModel(num_classes=100,
                                  image_size=32,
                                  patch_size=4,
                                  num_layers=6,
                                  num_channels=3,
                                  em_dim=256,
                                  num_heads=8,
                                  ff_dim=512),
        optimizer_fn=lambda x: torch.optim.SGD(x, lr=0.01, momentum=0.9, weight_decay=1e-4))
    # load the encoder's weight
    if hasattr(model, "module"):
        model.module.vit_encoder.load_state_dict(pretrained_model.module.vit_encoder.state_dict())
    else:
        model.vit_encoder.load_state_dict(pretrained_model.vit_encoder.state_dict())
    network = fe.Network(ops=[
        ModelOp(model=model, inputs="x", outputs="y_pred"),
        CrossEntropy(inputs=("y_pred", "y"), outputs="ce", from_logits=True),
        UpdateOp(model=model, loss_name="ce")
    ])
    traces = [
        Accuracy(true_key="y", pred_key="y_pred")
    ]
    estimator = fe.Estimator(pipeline=pipeline,
                             network=network,
                             epochs=epochs,
                             traces=traces,
                             train_steps_per_epoch=train_steps_per_epoch,
                             eval_steps_per_epoch=eval_steps_per_epoch)
    estimator.fit(warmup=False)

## Start the Fine-tuning

The downstream ViT is re-using the ViT encoder pre-trained on the CIFAR100 dataset. To illustrate the effect of using the pre-trained encoder, we will only train the downstream task for a **single** epoch.

In [8]:
finetune(pretrained_model,
         batch_size=batch_size,
         epochs=finetune_epochs,
         train_steps_per_epoch=train_steps_per_epoch,
         eval_steps_per_epoch=eval_steps_per_epoch)

    ______           __  ______     __  _                 __            
   / ____/___ ______/ /_/ ____/____/ /_(_)___ ___  ____ _/ /_____  _____
  / /_  / __ `/ ___/ __/ __/ / ___/ __/ / __ `__ \/ __ `/ __/ __ \/ ___/
 / __/ / /_/ (__  ) /_/ /___(__  ) /_/ / / / / / / /_/ / /_/ /_/ / /    
/_/    \__,_/____/\__/_____/____/\__/_/_/ /_/ /_/\__,_/\__/\____/_/     
                                                                        

FastEstimator-Warn: No ModelSaver Trace detected. Models will not be saved.
FastEstimator-Start: step: 1; logging_interval: 100; num_device: 1;
FastEstimator-Train: step: 1; ce: 4.801615;
FastEstimator-Train: step: 100; ce: 1.0262994; steps/sec: 17.72;
FastEstimator-Train: step: 200; ce: 0.74568576; steps/sec: 17.57;
FastEstimator-Train: step: 300; ce: 0.7660386; steps/sec: 17.54;
FastEstimator-Train: step: 391; epoch: 1; epoch_time: 22.4 sec;
FastEstimator-Eval: step: 391; epoch: 1; accuracy: 0.7426; ce: 0.7396317;
FastEstimator-Finish: step: 391; model2

With only one epoch of training, we are able to get 74% top-1 accuracy on the CIFAIR 10 test set.  Not bad huh?