
# Vision Transformer (ViT)

Vision Transformer (ViT) represents a paradigm shift in computer vision by applying the transformer architecture, originally designed for natural language processing, to image recognition tasks.

## Introduction

Introduced in the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" (Dosovitskiy et al., 2020), Vision Transformers challenged the long-standing dominance of convolutional neural networks (CNNs) in computer vision.

The key insight of ViT is to treat an image as a sequence of patches, much as NLP models treat a sentence as a sequence of word tokens, and to process that sequence with a standard Transformer encoder. This approach largely removes the image-specific inductive biases built into CNNs (locality and translation equivariance); instead, relationships between patches are learned through self-attention.
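
To make the analogy concrete, the sketch below (assuming PyTorch; `patchify` is a hypothetical helper, not part of any library) cuts a 224×224 RGB image into 16×16 patches and flattens each one. The result is a sequence of (224/16)² = 196 "words", each holding 16·16·3 = 768 values. The within-patch ordering (rows, then columns, then channels) is an illustrative choice and need not match the layout of any particular implementation.

```python
import torch

def patchify(images: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Split (B, C, H, W) images into a (B, N, P*P*C) sequence of flattened patches."""
    b, c, h, w = images.shape
    assert h % patch_size == 0 and w % patch_size == 0, "H and W must be divisible by patch_size"
    gh, gw = h // patch_size, w // patch_size  # size of the patch grid
    # (B, C, gh, P, gw, P): split each spatial axis into (patch index, offset inside patch)
    x = images.reshape(b, c, gh, patch_size, gw, patch_size)
    # (B, gh, gw, P, P, C): patch-grid axes first, then the pixels of one patch
    x = x.permute(0, 2, 4, 3, 5, 1)
    # (B, N, P*P*C) with N = gh * gw patches in row-major (raster) order
    return x.reshape(b, gh * gw, patch_size * patch_size * c)

images = torch.randn(2, 3, 224, 224)
patches = patchify(images, patch_size=16)
print(patches.shape)  # torch.Size([2, 196, 768])
```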

## Architecture

![Vision Transformer Architecture](https://storage.googleapis.com/tfjs-examples/assets/vit/vit.png)

The Vision Transformer architecture consists of the following components:

1. **Patch Embedding**:
   - The input image is divided into fixed-size patches (typically 16×16 pixels)
   - Each patch is flattened and linearly projected to obtain patch embeddings

2. **Position Embeddings**:
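   - Learnable position embeddings are added to the token sequence to retain spatial information, since self-attention by itself is order-agnostic
   - A learnable [CLS] token is prepended to the sequence before the position embeddings are added; its final output is used for classification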

3. **Transformer Encoder**:
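   - A stack of standard Transformer encoder blocks, each with:
     - Layer normalization applied before each sub-block (pre-norm)
     - Multi-head self-attention (MSA)
     - An MLP block: two linear layers with a GELU nonlinearity
     - Residual connections around both sub-blocks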

4. **Classification Head**:
   - The final output embedding of the [CLS] token is used for classification
   - An MLP head maps this embedding to class logits (in the original paper, one hidden layer during pre-training and a single linear layer during fine-tuning)
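
Item 1 describes the patch embedding as "flatten each patch, then apply a linear projection", while the PyTorch implementation in this notebook uses a `Conv2d` whose kernel size and stride both equal the patch size. The sketch below (PyTorch assumed, ViT-B/16 sizes) checks numerically that these are the same operation once the convolution weights are copied into an `nn.Linear` layer. It is a verification aid, not part of the model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
patch_size, in_channels, embed_dim = 16, 3, 768

# Patch embedding as a strided convolution
conv = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

# The same projection as a linear layer acting on flattened patches
linear = nn.Linear(in_channels * patch_size * patch_size, embed_dim)
with torch.no_grad():
    # conv.weight is (D, C, P, P); flatten it in (C, P, P) order to match the patches below
    linear.weight.copy_(conv.weight.reshape(embed_dim, -1))
    linear.bias.copy_(conv.bias)

x = torch.randn(2, in_channels, 224, 224)

# Conv path: (B, D, 14, 14) -> (B, 196, D)
out_conv = conv(x).flatten(2).transpose(1, 2)

# Linear path: cut non-overlapping patches -> (B, C, 14, 14, P, P)
patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
# Reorder to (B, 14, 14, C, P, P) and flatten each patch in (C, P, P) order
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(2, -1, in_channels * patch_size * patch_size)
out_linear = linear(patches)

print(out_conv.shape, torch.allclose(out_conv, out_linear, atol=1e-4))
```

The convolutional form is simply a convenient, efficient way to apply one shared linear map to every non-overlapping patch.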

## Implementation Example

Below is a simplified implementation of Vision Transformer in PyTorch:

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2

        # Linear projection of flattened patches: a Conv2d whose kernel size and
        # stride both equal patch_size applies one shared linear map to each patch
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x: (batch_size, in_channels, img_size, img_size)
        x = self.proj(x)  # (batch_size, embed_dim, h, w) where h = w = img_size // patch_size
        x = x.flatten(2)  # (batch_size, embed_dim, n_patches)
        x = x.transpose(1, 2)  # (batch_size, n_patches, embed_dim)
        return x

class TransformerEncoder(nn.Module):
    """A single pre-norm encoder block: x + MSA(LN(x)), then x + MLP(LN(x))."""
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        # batch_first=True because tokens are (B, N, D); the default layout is (N, B, D)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # Self-attention block (pre-norm + residual); attention weights are not needed here
        ln_x = self.ln1(x)
        attn_output, _ = self.attn(ln_x, ln_x, ln_x, need_weights=False)
        x = x + attn_output

        # MLP block (pre-norm + residual)
        ln_x = self.ln2(x)
        x = x + self.mlp(ln_x)
        return x

class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0,
                 num_classes=1000, dropout=0.1):
        super().__init__()

        # Patch embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.n_patches

        # Class token and position embeddings (one position per patch, plus one for [CLS])
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(dropout)

        # Transformer encoder layers
        self.blocks = nn.ModuleList([
            TransformerEncoder(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        # Classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        # Only the [CLS] token and position embeddings are initialized explicitly here;
        # patch_embed and the linear layers keep PyTorch's default initialization
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        # Input must be (B, C, img_size, img_size) so the patch count matches pos_embed
        x = self.patch_embed(x)  # (B, N, D)

        # Add class token
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # (B, 1, D)
        x = torch.cat((cls_token, x), dim=1)  # (B, N+1, D)

        # Add position embedding
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # Apply transformer blocks
        for block in self.blocks:
            x = block(x)

        # Classification from [CLS] token
        x = self.norm(x)
        x = x[:, 0]  # Take only the [CLS] token embedding: (B, D)
        x = self.head(x)  # (B, num_classes)

        return x
```
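
The hand-written `TransformerEncoder` class implements one pre-norm block. As a cross-check, PyTorch's built-in `nn.TransformerEncoderLayer` can be configured the same way: `norm_first=True` for pre-norm, `activation="gelu"` for the MLP, and `batch_first=True` because PyTorch attention modules otherwise expect `(sequence, batch, embedding)` inputs. The sketch below uses small, hypothetical sizes. The built-in layer also applies dropout to the attention output, so it closely mirrors, but is not guaranteed to be identical to, the class in this notebook.

```python
import torch
import torch.nn as nn

# Small, hypothetical sizes for a quick shape check (ViT-B would be 768 / 12 / 12)
embed_dim, num_heads, depth, mlp_ratio = 192, 3, 4, 4.0

# Pre-norm block: x + MSA(LN(x)), then x + MLP(LN(x)), GELU in the MLP, (B, N, D) tensors
layer = nn.TransformerEncoderLayer(
    d_model=embed_dim,
    nhead=num_heads,
    dim_feedforward=int(embed_dim * mlp_ratio),
    dropout=0.1,
    activation="gelu",
    batch_first=True,
    norm_first=True,
)
encoder = nn.TransformerEncoder(layer, num_layers=depth)  # stacks `depth` copies of the layer
final_norm = nn.LayerNorm(embed_dim)

tokens = torch.randn(8, 197, embed_dim)  # (B, N+1, D): 196 patch tokens + [CLS]
encoder.eval()
with torch.no_grad():
    out = final_norm(encoder(tokens))
cls_embedding = out[:, 0]  # (B, D), the input to the classification head
print(out.shape, cls_embedding.shape)
```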

## Using a Pre-trained Vision Transformer (ViT)

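```python
# Using the timm library (PyTorch Image Models)
import timm
import torch
from PIL import Image
import torchvision.transforms as transforms

# Load a pretrained ViT-B/16 model (weights are downloaded on first use)
model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.eval()

# Image preprocessing: 224x224 input, normalized with mean = std = 0.5 per channel.
# timm.data.resolve_data_config({}, model=model) reports the exact settings for a checkpoint.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Example inference (requires an image file)
# img = Image.open('example.jpg').convert('RGB')  # ensure 3 channels
# img_tensor = transform(img).unsqueeze(0)  # add batch dimension -> (1, 3, 224, 224)
# with torch.no_grad():
#     logits = model(img_tensor)  # (1, 1000) ImageNet-1k class logits
# probs = logits.softmax(dim=-1)
# predicted_class = probs.argmax(dim=-1).item()
```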

## Key Innovations and Advantages

1. **Global Context**: CNNs enlarge their receptive field gradually through stacked local convolutions, whereas in ViT every patch can attend to every other patch from the first layer onward.

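2. **Scalability**: ViT benefits strongly from more data and larger models; when pre-trained on large datasets such as ImageNet-21k or JFT-300M, it matches or exceeds state-of-the-art CNNs while requiring fewer computational resources to pre-train.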

3. **Transfer Learning**: Models pre-trained on large datasets (such as JFT-300M) transfer well to downstream tasks through fine-tuning.

4. **Architectural Simplicity**: Apart from the patch embedding, ViT uses a standard Transformer encoder with no vision-specific components.

5. **Improving Data Efficiency**: The original ViT needs large pre-training datasets, but later variants such as DeiT reach competitive accuracy when trained on ImageNet-1K alone.
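
The "global context" point can be seen directly in the attention weights of a single `nn.MultiheadAttention` layer. In the sketch below (PyTorch assumed; the embedding size, head count, and random tokens are illustrative), each head produces a 197×197 matrix whose rows are softmax distributions over all tokens. Every patch therefore mixes information from the whole image in the very first layer, whereas a 3×3 convolution only sees its immediate neighborhood.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, num_heads, num_tokens = 64, 4, 197  # 196 patches + [CLS]; sizes are illustrative

attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
tokens = torch.randn(1, num_tokens, embed_dim)

# need_weights=True returns the attention matrix; average_attn_weights=False keeps one per head
_, weights = attn(tokens, tokens, tokens, need_weights=True, average_attn_weights=False)
print(weights.shape)  # torch.Size([1, 4, 197, 197])

# Each row is a softmax over ALL tokens: it sums to 1 and gives every patch a nonzero weight
print(torch.allclose(weights.sum(dim=-1), torch.ones(1, num_heads, num_tokens)))
print(bool((weights > 0).all()))
```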

## Limitations

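1. **Data Hunger**: When trained from scratch on mid-sized datasets such as ImageNet-1K without strong regularization, the original ViT falls a few percentage points behind ResNets of comparable size; it needs pre-training on larger datasets (ImageNet-21k, JFT-300M) to excel.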

2. **Computational Cost**: The cost of self-attention grows quadratically with the number of patches, and the number of patches itself grows with image area, which makes high input resolutions and small patch sizes expensive.

3. **Inductive Bias**: ViT lacks the inductive biases of CNNs (locality, translation equivariance), requiring more data to learn these patterns.
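
The quadratic cost is easy to quantify. With image side length H, patch size P, and N = (H/P)² patches, each attention head scores (N + 1)² query-key pairs per layer (the +1 is the [CLS] token). The back-of-the-envelope sketch below assumes ViT-B/16-like sizes (P = 16, D = 768) and counts only the two attention matrix products, ignoring the projections and the MLP. `vit_attention_cost` is a hypothetical helper written for this illustration.

```python
def vit_attention_cost(img_size: int, patch_size: int = 16, embed_dim: int = 768):
    """Rough per-layer attention cost for a square image (attention matmuls only)."""
    num_patches = (img_size // patch_size) ** 2
    num_tokens = num_patches + 1          # + [CLS]
    attn_entries = num_tokens ** 2        # one score per (query, key) pair, per head
    # Q @ K^T and (attention @ V) each cost about N^2 * D multiply-adds summed over heads
    attn_macs = 2 * num_tokens ** 2 * embed_dim
    return num_patches, attn_entries, attn_macs

for size in (224, 384, 448):
    n, entries, macs = vit_attention_cost(size)
    print(f"{size}px: {n} patches, {entries:,} scores per head, ~{macs / 1e9:.2f} GMACs per layer")
```

Doubling the resolution from 224 to 448 pixels quadruples the patch count (196 to 784) and multiplies the number of attention scores by roughly 16.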

## Variants and Evolution

Since the original ViT paper, several variants have been proposed:

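- **DeiT** (Data-efficient Image Transformers): Trains competitively on ImageNet-1K alone using strong data augmentation and regularization, plus distillation from a CNN teacher through a dedicated distillation token.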
- **Swin Transformer**: Introduces hierarchical structure with shifted windows, combining CNN and Transformer strengths.
- **MViT** (Multiscale Vision Transformers): Uses a pyramidal structure with pooling attention.
- **ViT-G/14**: Scales ViT up to a giant model with about 1.8B parameters, which set a new ImageNet state of the art at the time of publication.
- **MAE** (Masked Autoencoders): Self-supervised pre-training by reconstructing masked image patches.
- **ViTDet**: Adaptation of ViT for object detection tasks.
- **Segment Anything Model (SAM)**: Uses an MAE-pre-trained ViT as its image encoder for promptable, general-purpose image segmentation.
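
To illustrate the MAE idea from the list above, the sketch below shows only the masking step: per sample, a random 75% of patch tokens (the default mask ratio in the MAE paper) are hidden and only the visible 25% would be passed to the encoder. It follows a shuffle-by-random-noise approach similar in spirit to the MAE reference code, but `random_masking` is a hypothetical helper here, and the encoder, decoder, and reconstruction loss are omitted.

```python
import torch

def random_masking(patch_tokens: torch.Tensor, mask_ratio: float = 0.75):
    """Keep a random subset of patch tokens per sample (MAE-style sketch).

    patch_tokens: (B, N, D). Returns the visible tokens (B, N_keep, D), a mask (B, N)
    with 1 = masked, and the indices that restore the original patch order.
    """
    b, n, d = patch_tokens.shape
    n_keep = int(n * (1 - mask_ratio))
    noise = torch.rand(b, n, device=patch_tokens.device)  # one random score per patch
    ids_shuffle = torch.argsort(noise, dim=1)              # lowest-noise patches come first
    ids_restore = torch.argsort(ids_shuffle, dim=1)        # inverse permutation
    ids_keep = ids_shuffle[:, :n_keep]
    visible = torch.gather(patch_tokens, 1, ids_keep.unsqueeze(-1).expand(-1, -1, d))
    # Build the mask in shuffled order, then map it back to the original patch order
    mask = torch.ones(b, n, device=patch_tokens.device)
    mask[:, :n_keep] = 0
    mask = torch.gather(mask, 1, ids_restore)
    return visible, mask, ids_restore

tokens = torch.randn(2, 196, 768)  # e.g. 196 patch embeddings per image
visible, mask, ids_restore = random_masking(tokens)
print(visible.shape, mask.sum(dim=1))  # 49 visible tokens, 147 masked per image
```

Because the encoder only processes the visible quarter of the tokens, pre-training is considerably cheaper than running the full sequence.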

## Comparison: CNNs vs. Vision Transformers

| Aspect | CNNs | Vision Transformers |
|--------|------|--------------------|
| **Inductive Bias** | Strong spatial locality and translation equivariance | Minimal inductive bias, learns spatial relationships |
| **Data Efficiency** | Better with limited data | Requires more data unless special techniques are used |
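| **Scalability** | Improves with scale, but was overtaken by ViT at the largest pre-training scales in the original ViT study | Did not saturate within the model sizes tested in the original ViT paper |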
| **Global Context** | Requires deep networks to capture global context | Global context from first layer via self-attention |
| **Computational Cost** | Roughly linear in the number of pixels | Quadratic in the number of tokens (patches) |
| **Interpretability** | Feature maps have spatial correspondence | Attention maps can show patch relationships |
| **Architecture** | Domain-specific design (kernels, pooling) | Generic transformer blocks |
| **State-of-the-art** | Historically dominated computer vision | Increasingly competitive, especially at scale |

## Applications

Vision Transformers have been successfully applied to numerous computer vision tasks:

- **Image Classification**: Strong results on benchmarks such as ImageNet, especially with large-scale pre-training
- **Object Detection**: ViTDet adapts Vision Transformers for detection tasks
- **Semantic Segmentation**: Models like SETR use ViT for pixel-level predictions
- **Image Generation**: Combined with diffusion models for high-quality image synthesis
- **Video Understanding**: Extended to process video frames with space-time attention
- **Multi-modal Learning**: Image encoder in vision-language models such as CLIP, which offers both ResNet and ViT image encoders
- **Self-supervised Learning**: Core component in masked autoencoding approaches like MAE

## References

1. Dosovitskiy, A., et al. (2020). [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929). ICLR 2021.

2. Touvron, H., et al. (2021). [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877). ICML 2021.

3. Liu, Z., et al. (2021). [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030). ICCV 2021.

4. He, K., et al. (2022). [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377). CVPR 2022.

5. Zhai, X., et al. (2022). [Scaling Vision Transformers](https://arxiv.org/abs/2106.04560). CVPR 2022.

6. Radford, A., et al. (2021). [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020). ICML 2021.