# Vision Transformer (ViT) 从零开始实现
论文：[An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/pdf/2010.11929)

## 网络结构
<img src="./img/ViT-1.png" width = "600" height = "400" alt="ViT 网络结构" align=center />

In [1]:
import torch
from torch import nn

## Split Image
<img src="./img/ViT-2.png" width = "600" height = "300" align=center />

In [None]:
class Patches(nn.Module):
    """Split a batch of images into flattened, non-overlapping patches.

    Args:
        patch_h: Patch height in pixels.
        patch_w: Patch width in pixels.
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, patch_h, patch_w, **kwargs):
        super(Patches, self).__init__(**kwargs)
        self.patch_h = patch_h
        self.patch_w = patch_w

    def forward(self, images):
        """Cut every image into a grid of patches and flatten each patch.

        Args:
            images: tensor (batch_size, channels, height, width)

        Returns:
            tensor (batch_size, num_patches, channels * patch_h * patch_w).
            Patches are ordered row by row over the grid; inside a patch the
            values are ordered by channel, then patch row, then patch column.

        Raises:
            ValueError: if ``height`` is not a multiple of ``patch_h`` or
                ``width`` is not a multiple of ``patch_w``.
        """
        batch, channels, height, width = images.shape

        # Fail early with a readable message instead of an opaque reshape error.
        if height % self.patch_h != 0 or width % self.patch_w != 0:
            raise ValueError(
                f"image size ({height}, {width}) is not divisible by "
                f"patch size ({self.patch_h}, {self.patch_w})"
            )

        num_h = height // self.patch_h
        num_w = width // self.patch_w

        # (B, C, H, W) -> (B, C, num_h, patch_h, num_w, patch_w)
        grid = images.reshape(batch, channels, num_h, self.patch_h, num_w, self.patch_w)
        # Move the two grid axes in front of the per-patch axes:
        # (B, num_h, num_w, C, patch_h, patch_w). This reorders the elements.
        grid = grid.permute(0, 2, 4, 1, 3, 5)

        # Explicit sizes instead of -1: -1 cannot be inferred for a tensor with
        # zero elements, so an empty batch would otherwise raise.
        return grid.reshape(batch, num_h * num_w, channels * self.patch_h * self.patch_w)

  
# Demo: a 1x2x3x4 tensor holding the values 0..23 is cut into patches of
# height 1 and width 2, giving 6 patches of 4 values each.
images = torch.arange(24, dtype=torch.float).reshape(1, 2, 3, 4) # B, C, H, W
patches_module = Patches(1, 2)
patches_module.eval()
patches_1 = patches_module(images)
# Bare expression so the notebook cell displays the resulting tensor.
patches_1

tensor([[[ 0.,  1., 12., 13.],
         [ 2.,  3., 14., 15.],
         [ 4.,  5., 16., 17.],
         [ 6.,  7., 18., 19.],
         [ 8.,  9., 20., 21.],
         [10., 11., 22., 23.]]])

## Linear Projection

In [25]:
class LinearProjection(nn.Module):
    """Linear projection applied to every flattened patch.

    Args:
        input_dim: Size of each incoming patch vector.
        projection_dim: Size of each projected patch vector.
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, input_dim, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.projection = nn.Linear(
            in_features=input_dim,
            out_features=projection_dim,
        )

    def forward(self, patches):
        """Project patches into the embedding space.

        Args:
            patches: tensor (batch_size, num_patches, input_dim)

        Returns:
            tensor (batch_size, num_patches, projection_dim)
        """
        projected = self.projection(patches)
        return projected

# Demo: project the 4-value patches produced in the previous cell
# (patches_1) to 10 features each.
proj_module = LinearProjection(4, 10)
proj_module.eval()
patches_2 = proj_module(patches_1)
# Bare expression so the notebook cell displays the output shape.
patches_2.shape
    

torch.Size([1, 6, 10])

## Position Encoding

In [26]:
class PositionEmbedding(nn.Module):
    """Add a learnable position embedding to an input sequence.

    The embedding table has shape (1, max_seq_len, hidden_dim); the first
    ``seq_len`` positions are added to the input.

    Args:
        max_seq_len: Longest sequence the table can cover.
        hidden_dim: Optional feature size. When given, the table is created
            in ``__init__`` so it is visible to ``parameters()``,
            ``state_dict()`` and ``.to(device)`` right away. When omitted,
            the table is created on the first ``forward`` call from the
            input's feature size.
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.

    Note:
        With lazy creation the parameter does not exist until the first
        forward pass, so an optimizer constructed before that call will not
        contain it. Pass ``hidden_dim`` (or run one forward pass first) when
        the embedding is meant to be trained.
    """

    def __init__(self, max_seq_len=2048, hidden_dim=None, **kwargs):
        super(PositionEmbedding, self).__init__(**kwargs)
        self.max_seq_len = max_seq_len
        if hidden_dim is not None:
            self.position_embedding = self._new_table(hidden_dim)

    def _new_table(self, hidden_dim, device=None):
        """Create the (1, max_seq_len, hidden_dim) table, initialised N(0, 0.02)."""
        table = nn.Parameter(torch.zeros(1, self.max_seq_len, hidden_dim, device=device))
        nn.init.normal_(table, std=0.02)
        return table

    def forward(self, inputs):
        """Add position embeddings to ``inputs``.

        Args:
            inputs: tensor (batch_size, seq_len, hidden_dim)

        Returns:
            tensor (batch_size, seq_len, hidden_dim)

        Raises:
            ValueError: if ``seq_len`` exceeds the table length, or if
                ``hidden_dim`` differs from the table's feature size.
        """
        _, seq_len, hidden_dim = inputs.shape

        # Lazily create the table on the input's device so the addition below
        # does not fail with a device mismatch for non-CPU inputs.
        if not hasattr(self, 'position_embedding'):
            self.position_embedding = self._new_table(hidden_dim, device=inputs.device)

        _, table_len, table_dim = self.position_embedding.shape
        # Without these checks a mismatch surfaces as an opaque broadcasting
        # error, or is silently broadcast when the table dimension is 1.
        if seq_len > table_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {table_len}"
            )
        if hidden_dim != table_dim:
            raise ValueError(
                f"hidden_dim {hidden_dim} does not match position embedding "
                f"dimension {table_dim}"
            )

        pos_emb = self.position_embedding[:, :seq_len]

        return inputs + pos_emb


# Demo: add position embeddings to the projected patches from the previous
# cell (patches_2); the table is created on this first call.
pos_module = PositionEmbedding()
pos_module.eval()
patches_3 = pos_module(patches_2)
# Bare expression so the notebook cell displays the resulting tensor.
patches_3

tensor([[[  1.5023,   4.7399,   2.4235,  -7.3212,   3.8686,   1.4247,   1.0773,
            5.7908,  -1.7574,   1.4652],
         [  2.3151,   5.2620,   3.9575,  -8.3688,   4.9957,   1.6796,   1.5285,
            6.7154,  -0.4849,   2.1983],
         [  3.0486,   5.8617,   5.5067,  -9.4144,   6.1236,   1.8503,   1.9986,
            7.6157,   0.8165,   2.9525],
         [  3.8512,   6.3575,   7.0050, -10.4415,   7.2433,   2.0758,   2.4740,
            8.4977,   2.1733,   3.6683],
         [  4.5799,   6.9025,   8.5117, -11.5410,   8.3848,   2.3060,   2.9845,
            9.3924,   3.4485,   4.4218],
         [  5.3206,   7.4151,  10.0249, -12.5665,   9.4877,   2.5473,   3.4275,
           10.3151,   4.7673,   5.1656]]], grad_fn=<AddBackward0>)

In [28]:
class ClassToken(nn.Module):
    """Prepend a learnable classification token to every input sequence.

    Args:
        hidden_dim: Feature size of the token (must match the inputs).
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, hidden_dim, **kwargs):
        super().__init__(**kwargs)
        token = torch.zeros(1, 1, hidden_dim)
        self.cls_token = nn.Parameter(token)
        nn.init.normal_(self.cls_token, std=0.02)

    def forward(self, inputs):
        """Insert the class token at position 0 of each sequence.

        Args:
            inputs: tensor (batch_size, seq_len, hidden_dim)

        Returns:
            tensor (batch_size, seq_len + 1, hidden_dim)
        """
        # Broadcast the single token across the batch; expand makes no copy.
        per_sample_token = self.cls_token.expand(inputs.shape[0], -1, -1)
        return torch.cat((per_sample_token, inputs), dim=1)

# Demo: prepend the class token to the sequence from the previous cell
# (patches_3), growing it from 6 to 7 entries.
cls_module = ClassToken(10)
cls_module.eval()
patches_4 = cls_module(patches_3)
# Bare expression so the notebook cell displays the output shape.
patches_4.shape

torch.Size([1, 7, 10])