# Vision Transformer (ViT)

## Модель

In [None]:
import torch
from torch import nn

In [None]:
# Simulate a small batch of input data.
# Fixing the seed makes the random data (and every later weight init) reproducible.
SEED = 42
torch.manual_seed(SEED)

n_features = 10  # number of features per sample
n_classes = 3  # number of target classes
batch_size = 5

data = torch.randn((batch_size, n_features))
print(data.shape)
print(data)

torch.Size([5, 10])
tensor([[-1.1127,  0.6661,  1.0794,  0.0196,  0.5441,  1.1445,  0.3761, -1.4601,
         -1.4382, -0.0627],
        [ 0.1915,  1.4360,  0.6060, -0.6318,  0.0714,  0.2601, -0.3320,  0.3403,
         -0.6557, -0.3798],
        [-0.4686, -0.2286, -1.2541,  0.5025,  0.5169, -1.3900,  0.1105,  0.8309,
         -0.5409,  0.1406],
        [-0.6098, -0.2844, -0.0605,  1.4607, -0.4396, -0.7302, -1.7419, -0.4694,
         -2.1952, -0.5802],
        [-1.8152,  2.1991,  1.0109,  0.4256,  0.4993,  0.9191, -0.4266, -1.4751,
          0.9440, -0.6904]])


In [None]:
# A single fully connected layer maps n_features inputs to n_classes outputs.
model = nn.Linear(in_features=n_features, out_features=n_classes)

In [None]:
# Run the layer on the whole batch at once: (batch_size, n_features) -> (batch_size, n_classes).
answer = model(data)
for shown in (answer.shape, answer):
    print(shown)

torch.Size([5, 3])
tensor([[-0.3545,  0.2211, -0.1282],
        [-0.0493, -0.0335, -0.1272],
        [ 0.3399,  0.1628, -0.8176],
        [-0.8173,  0.4748, -0.1106],
        [-0.8670,  1.0389,  0.9653]], grad_fn=<AddmmBackward0>)


In [None]:
# The same linear model, wrapped in an nn.Module subclass.
class SimpleNN(nn.Module):
    """Single-layer classifier: a thin nn.Module wrapper around one nn.Linear.

    Args:
        n_features: size of each input vector.
        n_classes: number of output scores per input.
    """

    def __init__(self, n_features, n_classes):
        super(SimpleNN, self).__init__()
        linear = nn.Linear(n_features, n_classes)
        # Assigning the layer to an attribute registers its parameters with the module.
        self.lin = linear

    def forward(self, x):
        """Return the raw class scores produced by the linear layer for ``x``."""
        scores = self.lin(x)
        return scores

In [None]:
# Apply the class-based model to the same data.
model = SimpleNN(n_features=n_features, n_classes=n_classes)
answer = model(data)
print(answer.shape, answer, sep='\n')

torch.Size([5, 3])
tensor([[-0.7035,  0.2531,  0.7247],
        [-0.3799, -0.1155,  0.6153],
        [ 0.5077,  0.1399, -0.8757],
        [ 0.7817,  0.9204, -0.3860],
        [-1.5674,  0.1022,  1.7614]], grad_fn=<AddmmBackward0>)


In [None]:
!pip install torchsummary
from torchsummary import summary

model = SimpleNN(n_features, n_classes).cuda()

# 5, 10
input_size = (batch_size, n_features)
print(summary(model, input_size))

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                 [-1, 5, 3]              33
Total params: 33
Trainable params: 33
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
----------------------------------------------------------------
None


In [None]:
# The same layer wrapped in an nn.Sequential container.
layers = [nn.Linear(n_features, n_classes)]
model = nn.Sequential(*layers)

answer = model(data)
print(answer.shape, answer, sep='\n')

torch.Size([5, 3])
tensor([[-1.5693,  0.9656,  0.4190],
        [-0.1636, -0.6284, -0.5807],
        [ 0.3587, -0.1695, -0.4409],
        [-1.4528,  0.1908,  0.1061],
        [-0.8333, -0.1868,  0.3930]], grad_fn=<AddmmBackward0>)


In [None]:
# The same layer stored in an nn.ModuleList.
# ModuleList only registers sub-modules and defines no forward(),
# so calling the container itself raises NotImplementedError.
model = nn.ModuleList([nn.Linear(n_features, n_classes)])

try:
    model(data)
except NotImplementedError as err:
    print(f'model(data) -> NotImplementedError: {err}')

# The layers inside have to be indexed (or iterated over) and called one by one.
answer = model[0](data)
print(answer.shape)
print(answer)


torch.Size([5, 3])
tensor([[-0.8756, -0.2555,  0.2025],
        [-0.2768,  0.1934,  0.4638],
        [ 0.3753,  0.2621,  0.2184],
        [ 0.0898,  0.0724,  0.3421],
        [-1.1695,  0.7667,  0.3793]], grad_fn=<AddmmBackward0>)


In [None]:
# Check which attributes actually register parameters with the module.
class ParametersCheck(nn.Module):
    """Demo module: which ways of storing layers register their parameters?

    ``lin``, ``seq`` and ``module_list`` are registered sub-modules, so their
    weights and biases show up in ``parameters()``. ``list_of_layers`` is a plain
    Python list. The layer inside it is NOT registered, so its weight and bias do not
    appear in ``parameters()``, and an optimizer built from ``parameters()`` would
    never train them.
    """

    def __init__(self, n_features, n_classes):
        super().__init__()

        def make_linear():
            return nn.Linear(n_features, n_classes)

        # Creation/assignment order is kept on purpose: it fixes the order of parameters().
        self.lin = make_linear()
        self.seq = nn.Sequential(make_linear())
        self.module_list = nn.ModuleList([make_linear()])
        self.list_of_layers = [make_linear()]


In [None]:
model = ParametersCheck(n_features, n_classes)

# Only parameters of registered sub-modules are listed; the plain-list layer is missing.
for i, param in enumerate(list(model.parameters())):
    print(f'Параметр #{i + 1}.', f'\t{param.shape}', sep='\n')

Параметр #1.
	torch.Size([3, 10])
Параметр #2.
	torch.Size([3])
Параметр #3.
	torch.Size([3, 10])
Параметр #4.
	torch.Size([3])
Параметр #5.
	torch.Size([3, 10])
Параметр #6.
	torch.Size([3])


## ViT

![Схема архитектуры ViT](https://drive.google.com/uc?export=view&id=1J5TvycDPs8pzfvlXvtO5MCFBy64yp9Fa)

In [None]:
!pip install einops

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.6/41.6 KB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.6.0


In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary

![](https://amaarora.github.io/images/vit-01.png)

## Часть 1. Patch Embedding, CLS Token, Position Encoding

![](https://amaarora.github.io/images/vit-02.png)

In [None]:
# Patch embedding via a strided convolution: with kernel_size == stride == PATCH_SIZE
# every non-overlapping patch is seen exactly once.
IMG_SIZE, PATCH_SIZE, EMBED_DIM = 224, 16, 768
num_patches = (IMG_SIZE // PATCH_SIZE) ** 2  # 14 * 14 = 196

# input image `B, C, H, W`
x = torch.randn(1, 3, IMG_SIZE, IMG_SIZE)
# 2D conv
conv = nn.Conv2d(3, EMBED_DIM, kernel_size=PATCH_SIZE, stride=PATCH_SIZE)
# (B, EMBED_DIM, 14, 14) -> (B, EMBED_DIM, num_patches) -> (B, num_patches, EMBED_DIM);
# works for any batch size, unlike a hard-coded reshape(-1, 196).
patch_tokens = conv(x).flatten(2).transpose(1, 2)
patch_tokens[0].shape

torch.Size([196, 768])

In [None]:
class PatchEmbedding(nn.Module):
    """Image to patch embedding.

    A Conv2d with ``kernel_size == stride == patch_size`` cuts the image into
    non-overlapping patches and applies one shared linear map to every patch.

    Args:
        img_size: expected input size; an int for square images or an (H, W) pair.
        patch_size: patch size; an int for square patches or an (H, W) pair.
        in_chans: number of input channels.
        embed_dim: size of each patch embedding.

    Input: ``(B, in_chans, H, W)`` with ``(H, W) == img_size``.
    Output: ``(B, num_patches, embed_dim)``.

    Raises:
        ValueError: if the spatial size of the input does not match ``img_size``.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = (img_size, img_size) if isinstance(img_size, int) else tuple(img_size)
        self.patch_size = (patch_size, patch_size) if isinstance(patch_size, int) else tuple(patch_size)
        self.num_patches = (self.img_size[1] // self.patch_size[1]) * (self.img_size[0] // self.patch_size[0])
        self.patch_embeddings = nn.Conv2d(in_chans, embed_dim,
                                          kernel_size=self.patch_size, stride=self.patch_size)

    def forward(self, image):
        # A different input size would silently yield a different number of patches
        # than num_patches (and break the position embedding that relies on it).
        height, width = image.shape[-2:]
        if (height, width) != self.img_size:
            raise ValueError(
                f'Input image size ({height}x{width}) does not match the configured '
                f'img_size ({self.img_size[0]}x{self.img_size[1]}).')
        # (B, C, H, W) -> (B, embed_dim, H/p, W/p) -> (B, embed_dim, N) -> (B, N, embed_dim)
        return self.patch_embeddings(image).flatten(2).transpose(1, 2)

In [None]:
# Embed a random 224x224 RGB image: expect 196 patch tokens of size 768.
patch_embed = PatchEmbedding(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
x = torch.randn(1, 3, 224, 224)
tokens = patch_embed(x)
tokens.shape

torch.Size([1, 196, 768])

![](https://amaarora.github.io/images/vit-03.png)

## Часть 2. Transformer Encoder

![](https://amaarora.github.io/images/ViT.png)

![](https://amaarora.github.io/images/vit-07.png)

In [None]:
class MLP(nn.Module):
    """Position-wise feed-forward network of a Transformer block.

    Linear -> GELU -> Dropout -> Linear -> Dropout, as in the original ViT.
    There is deliberately no activation after the second Linear: its output is
    added to the residual stream and must be able to take any real value.

    Args:
        in_features: size of the input (last) dimension.
        hidden_features: hidden size; defaults to ``in_features``.
        out_features: output size; defaults to ``in_features``.
        drop: dropout probability applied after the activation and after the output layer.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
        super().__init__()
        # nn.Linear cannot take None, so fall back to the input size.
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.drop(self.act(self.fc1(x)))
        x = self.drop(self.fc2(x))
        return x

In [None]:
x = torch.randn(1, 197, 768)  # (batch, 1 CLS + 196 patch tokens, embed_dim)
# Hidden size is 4x the embedding size: 3072 = 4 * 768.
mlp = MLP(in_features=768, hidden_features=3072, out_features=768)
out = mlp(x)
out.shape

torch.Size([1, 197, 768])

In [None]:
class Attention(nn.Module):
    """Multi-head self-attention over a sequence of tokens.

    Args:
        dim: token embedding size; must be divisible by ``num_heads``.
        num_heads: number of heads; each head works on ``dim // num_heads`` channels.
        qkv_bias: whether the joint q/k/v projection has a bias term.
        attn_drop: dropout probability applied to the attention weights.
        out_drop: dropout probability applied after the output projection.

    Input and output: ``(B, N, dim)``.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., out_drop=0.):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(f'dim ({dim}) must be divisible by num_heads ({num_heads})')
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # 1 / sqrt(head_dim) keeps the dot products from growing with the head size.
        self.scale = head_dim ** -0.5

        # One Linear produces q, k and v for all heads at once: dim -> 3 * dim.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        # Mixes the concatenated head outputs back into one dim-sized token.
        self.out = nn.Linear(dim, dim)
        self.out_drop = nn.Dropout(out_drop)

    def forward(self, x):
        B, N, C = x.shape

        # Attention
        # (B, N, 3 * C) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # (B, heads, N, N): similarity of every token pair, normalized over the keys.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Weighted sum of values, then concatenate the heads back: (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)

        # Out projection
        x = self.out(x)
        x = self.out_drop(x)

        return x


![](https://amaarora.github.io/images/vit-08.png)

In [None]:
# Scaled dot-product attention formula (per head):
# attn = (q @ k.transpose(-2, -1)) * self.scale
# attn = attn.softmax(dim=-1)

In [None]:
# Self-attention maps a (1, 197, 768) token sequence to the same shape.
x = torch.randn(1, 197, 768)
attention = Attention(dim=768, num_heads=8)
out = attention(x)
out.shape

In [None]:
class Block(nn.Module):
    """Pre-norm Transformer encoder block, as in ViT.

        x = x + Dropout(Attention(LayerNorm(x)))
        x = x + MLP(LayerNorm(x))

    Args:
        dim: token embedding size.
        num_heads: number of attention heads.
        mlp_ratio: MLP hidden size as a multiple of ``dim``.
        drop_rate: dropout probability for the attention branch and inside the MLP.

    Input and output: ``(B, N, dim)``.
    """
    def __init__(self, dim, num_heads=8, mlp_ratio=4, drop_rate=0.):
        super().__init__()

        # Normalization (eps=1e-6 as in the reference ViT)
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)

        # Attention
        self.attn = Attention(dim, num_heads=num_heads)

        # Dropout on the attention branch before it joins the residual stream
        self.drop = nn.Dropout(drop_rate)

        # Normalization
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)

        # MLP (applies its own dropout via `drop`)
        self.mlp = MLP(dim, hidden_features=int(dim * mlp_ratio), out_features=dim, drop=drop_rate)

    def forward(self, x):
        # Attention sub-layer with a residual connection
        x = x + self.drop(self.attn(self.norm1(x)))

        # MLP sub-layer with a residual connection
        x = x + self.mlp(self.norm2(x))
        return x

In [None]:
# A Transformer block keeps the token shape: (1, 197, 768) -> (1, 197, 768).
x = torch.randn(1, 197, 768)
block = Block(768, 8)
out = block(x)
out.shape

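В оригинальной реализации (timm) вместо обычного Dropout на остаточных связях теперь используется [DropPath](https://github.com/rwightman/pytorch-image-models/blob/e98c93264cde1657b188f974dc928b9d73303b18/timm/layers/drop.py) (stochastic depth): во время обучения он целиком отключает остаточную ветку для случайных примеров батча.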

In [None]:
class Transformer(nn.Module):
    """Transformer encoder: ``depth`` Blocks applied one after another.

    Every block is built with the same ``dim``, ``num_heads``, ``mlp_ratio``
    and ``drop_rate``.
    """
    def __init__(self, depth, dim, num_heads=8, mlp_ratio=4, drop_rate=0.):
        super().__init__()
        layers = []
        for _ in range(depth):
            layers.append(Block(dim, num_heads, mlp_ratio, drop_rate))
        self.blocks = nn.ModuleList(layers)

    def forward(self, x):
        out = x
        for layer in self.blocks:
            out = layer(out)
        return out

In [None]:
# A 12-block encoder keeps the token shape: (1, 197, 768) -> (1, 197, 768).
x = torch.randn(1, 197, 768)
transformer = Transformer(12, 768)
out = transformer(x)
out.shape

![](https://amaarora.github.io/images/vit-06.png)

In [None]:
from torch.nn.modules.normalization import LayerNorm

class ViT(nn.Module):
    """Vision Transformer for image classification.

    Pipeline: patch embedding -> prepend CLS token -> add position embeddings
    -> Transformer encoder -> LayerNorm -> linear head on the CLS token.

    Args:
        img_size, patch_size, in_chans: input image / patch geometry (see PatchEmbedding).
        num_classes: number of output logits; 0 returns the normalized CLS embedding.
        embed_dim: token embedding size.
        depth: number of Transformer blocks.
        num_heads: attention heads per block.
        mlp_ratio: MLP hidden size as a multiple of ``embed_dim``.
        qkv_bias: accepted for API compatibility; currently not forwarded, because
            Transformer/Block take no such argument.
        drop_rate: dropout probability.

    Input: ``(B, in_chans, img_size, img_size)``; output: ``(B, num_classes)``.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
                 qkv_bias=False, drop_rate=0.,):
        super().__init__()

        # Variable assignment
        self.num_classes = num_classes
        self.embed_dim = embed_dim

        # Patch Embeddings, CLS Token, Position Encoding
        self.patch_embed = PatchEmbedding(img_size=img_size,
                                          patch_size=patch_size,
                                          in_chans=in_chans,
                                          embed_dim=embed_dim)

        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One position vector per patch plus one for the CLS token.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(drop_rate)
        # Small truncated-normal init for the learned tokens, as in the reference ViT.
        nn.init.trunc_normal_(self.pos_embed, std=.02)
        nn.init.trunc_normal_(self.cls_token, std=.02)

        # Transformer Encoder
        self.transformer = Transformer(depth, embed_dim, num_heads=num_heads,
                                       mlp_ratio=mlp_ratio, drop_rate=drop_rate)

        # Classifier
        self.norm = LayerNorm(embed_dim, eps=1e-6)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward(self, x):
        B = x.shape[0]

        # Patch Embeddings, CLS Token, Position Encoding
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.pos_drop(x + self.pos_embed)

        # Transformer Encoder
        x = self.transformer(x)

        # Classifier: only the CLS token (position 0) is used for the prediction.
        x = self.norm(x)
        x = self.head(x[:, 0])

        return x

In [None]:
x = torch.randn(1, 3, 224, 224)  # (batch, channels, height, width)
vit = ViT(img_size=224, patch_size=16, in_chans=3)
out = vit(x)
out.shape

torch.Size([1, 197, 768])

# Домашнее задание


1. Выбрать датасет для классификации изображений с разрешением не меньше 64×64.
2. Обучить ViT на таком датасете.
3. Попробовать поменять размерности и посмотреть, что поменяется при обучении.


Примечание:
- Датасеты можно взять [тут](https://pytorch.org/vision/stable/datasets.html#built-in-datasets) или найти в другом месте.
- Из-за того, что ViT обучается медленно, количество примеров в датасете можно ограничить 1–5 тысячами.