**An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale (ViT)**
```python 
class PatchEmbed(nn.Module)         # 对输入图像进行分块和展平操作
class Attention(nn.Module)          # 实现多头自注意力机制
class MLP(nn.Module)                # 实现Transformer编码器层中的MLP（前馈网络）模块
class Block(nn.Module)              # 实现单个Transformer编码器层（多头注意力 + MLP，均带残差连接）
class VisionTransformer(nn.Module)  # 实现ViT整体的最终架构
```

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Reference sizes used throughout this ViT-B/16 walkthrough.
B = 1                    # batch size
C = 768                  # channel / embed_dim
N = (224 // 16) ** 2     # patches: 224x224 image cut into 16x16 tiles -> 14 * 14 = 196

# Random stand-in for a single RGB 224x224 input image.
image = torch.randn(B, 3, 224, 224)

PatchEmbed实际上利用kernel_size和stride都等于patch_size的二维卷积将图像切分成互不重叠的块，并同时把每个块线性映射为embed_dim维的向量（224/16=14，共14×14=196个块）  
```python
                [1, 3, 224, 224]
Conv2d      ->  [1, 768, 14, 14]
flatten     ->  [1, 768, 196]
transpose   ->  [1, 196, 768]
```

In [6]:
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and linearly embed each one.

    A Conv2d whose ``kernel_size`` and ``stride`` both equal ``patch_size`` is
    equivalent to cutting the image into ``patch_size x patch_size`` tiles and
    applying one shared linear projection to every flattened tile.

    Args:
        img_size: Expected (square) input resolution. Used to derive
            ``grid_size`` and ``num_patches``; ``forward`` does not enforce it.
        patch_size: Side length of each square patch.
        in_c: Number of input image channels.
        embed_dim: Dimension of each patch embedding.

    Attributes:
        grid_size: Patches per side for an ``img_size`` input (``img_size // patch_size``).
        num_patches: Total patches for an ``img_size`` input (``grid_size ** 2``).
    """

    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        # img_size used to be accepted but ignored; expose the derived patch
        # count so callers (e.g. positional-embedding sizing) can rely on it.
        self.grid_size = img_size // patch_size
        self.num_patches = self.grid_size ** 2
        self.proj = nn.Conv2d(in_channels=in_c, out_channels=embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Image batch of shape ``[B, in_c, H, W]``.

        Returns:
            Tensor of shape ``[B, (H // patch_size) * (W // patch_size), embed_dim]``.
        """
        x = self.proj(x)          # [B, embed_dim, H // P, W // P]
        x = x.flatten(2)          # [B, embed_dim, num_patches]
        return x.transpose(1, 2)  # [B, num_patches, embed_dim]

In [7]:
# Run the random image through a PatchEmbed with the ViT-B/16 defaults spelled out.
patchembed = PatchEmbed(img_size=224, patch_size=16, in_c=3, embed_dim=768)
image = patchembed(image)
print(image.size())  # torch.Size([1, 196, 768])

torch.Size([1, 196, 768])


把embed_dim拆分成多个头分别计算注意力，然后把各头的输出拼接起来。在完整模型中序列长度应为196+1=197，因为前面还要拼接一个cls_token；此处演示尚未加入cls_token，所以仍是196。  
用一个线性层一次性算出q、k、v，然后就是常规的缩放点积注意力操作
```python
                [1, 196, 768]
linear      ->  [1, 196, 2304]
reshape     ->  [1, 196, 3, 8, 96]
permute     ->  [3, 1, 8, 196, 96]
q @ k^T     ->  [1, 8, 196, 196]
attn @ v    ->  [1, 8, 196, 96]
transpose   ->  [1, 196, 8, 96]
reshape     ->  [1, 196, 768]
proj        ->  [1, 196, 768]
```

In [24]:
class Attention(nn.Module):
    """Multi-head self-attention as used in the ViT encoder.

    Args:
        dim: Token embedding dimension ``C``; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads=8):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.qkv = nn.Linear(in_features=dim, out_features=dim * 3)
        self.proj = nn.Linear(in_features=dim, out_features=dim)
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Scaled dot-product attention divides by sqrt(d_k), where d_k is the
        # per-head dimension -- not the full embedding dimension C.
        self.scale = self.head_dim ** -0.5

    def forward(self, x):
        """Apply self-attention over the token axis.

        Args:
            x: Tokens of shape ``[B, N, C]`` with ``C == dim``.

        Returns:
            Tensor of shape ``[B, N, C]``.
        """
        B, N, C = x.shape
        # [B, N, 3C] -> [B, N, 3, heads, head_dim] -> [3, B, heads, N, head_dim]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = (q @ k.transpose(-2, -1)) * self.scale  # [B, heads, N, N]
        attn = attn.softmax(dim=-1)
        # [B, heads, N, head_dim] -> [B, N, heads, head_dim] -> [B, N, C]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)

In [25]:
# Apply one multi-head self-attention layer (8 heads); the token shape is preserved.
attention = Attention(dim=C, num_heads=8)
image = attention(image)
print(image.size())  # torch.Size([1, 196, 768])

torch.Size([1, 196, 768])


In [None]:
class MLP(nn.Module):
    """Position-wise feed-forward network of a Transformer encoder layer.

    Computes ``Dropout(fc2(Dropout(act(fc1(x)))))`` independently for every token.

    Args:
        in_features: Input feature dimension.
        hidden_features: Hidden dimension; defaults to ``4 * in_features``
            (the MLP ratio used by ViT).
        out_features: Output dimension; defaults to ``in_features``.
        act_layer: Activation module class, instantiated with no arguments.
        drop: Dropout probability applied after the activation and after ``fc2``.
    """

    def __init__(self, in_features=768, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0.0):
        super().__init__()
        if hidden_features is None:
            hidden_features = 4 * in_features
        if out_features is None:
            out_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Map ``[..., in_features]`` to ``[..., out_features]``."""
        x = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(x))


# Alias so the ``Mlp`` spelling (as in the outline and in timm) also resolves.
Mlp = MLP

In [None]:
class Block(nn.Module):
    """One pre-norm Transformer encoder layer.

    x = x + DropPath(Attention(LayerNorm(x)))
    x = x + DropPath(FeedForward(LayerNorm(x)))

    Args:
        dim: Token embedding dimension.
        num_heads: Number of attention heads (``dim`` must be divisible by it).
        mlp_ratio: Hidden size of the feed-forward network relative to ``dim``.
        drop: Dropout probability inside the feed-forward network.
        drop_path: Stochastic-depth rate for both residual branches, in ``[0, 1)``.
        norm_layer: Normalization module class, called as ``norm_layer(dim)``.

    Raises:
        ValueError: If ``drop_path`` is outside ``[0, 1)``.
    """

    def __init__(self, dim=768, num_heads=8, mlp_ratio=4.0, drop=0.0,
                 drop_path=0.0, norm_layer=nn.LayerNorm):
        super().__init__()
        if not 0.0 <= drop_path < 1.0:
            raise ValueError(f"drop_path must be in [0, 1), got {drop_path}")
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim=dim, num_heads=num_heads)
        self.norm2 = norm_layer(dim)
        hidden = int(dim * mlp_ratio)
        # Position-wise feed-forward: Linear -> GELU -> Dropout -> Linear -> Dropout.
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden, dim),
            nn.Dropout(drop),
        )
        self.drop_path_rate = drop_path

    def drop_path(self, x):
        """Stochastic depth on a residual branch.

        In training mode, zero the branch for a random subset of samples and
        rescale the survivors by ``1 / keep_prob``. In eval mode, or when the
        rate is 0, return ``x`` unchanged.
        """
        if self.drop_path_rate == 0.0 or not self.training:
            return x
        keep_prob = 1.0 - self.drop_path_rate
        # One Bernoulli draw per sample, broadcast over all remaining dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(shape).bernoulli_(keep_prob)
        return x * mask / keep_prob

    def forward(self, x):
        """Apply the encoder layer to tokens of shape ``[B, N, dim]``."""
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

In [None]:
class VisionTransformer(nn.Module):
    """Vision Transformer (ViT); the defaults match ViT-B/16.

    Pipeline: patch embedding -> prepend a learnable [class] token -> add a
    learnable positional embedding -> dropout -> ``depth`` encoder Blocks ->
    LayerNorm -> take the [class] token -> optional linear classification head.

    Args:
        img_size: Square input resolution; fixes the positional-embedding length.
        patch_size: Side length of each square patch.
        in_c: Number of input image channels.
        num_classes: Size of the classification head. ``0`` (default) means no
            head: ``forward`` returns the normalized [class]-token embedding.
        embed_dim: Token embedding dimension.
        depth: Number of stacked encoder Blocks.
        num_heads: Attention heads per Block.
        mlp_ratio: Feed-forward hidden size relative to ``embed_dim``.
        drop_rate: Dropout after the positional embedding and inside each Block.
        drop_path_rate: Stochastic-depth rate of the last Block; it increases
            linearly from 0 at the first Block.
        norm_layer: Normalization module class, called as ``norm_layer(embed_dim)``.
    """

    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=0,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0,
                 drop_rate=0.0, drop_path_rate=0.0, norm_layer=nn.LayerNorm):
        super().__init__()
        self.num_tokens = 1  # only the [class] token is prepended
        self.embed_dim = embed_dim
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size,
                                      in_c=in_c, embed_dim=embed_dim)
        num_patches = (img_size // patch_size) ** 2
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(drop_rate)
        # Stochastic-depth decay rule: deeper Blocks get a larger drop-path rate.
        dpr = torch.linspace(0, drop_path_rate, depth).tolist()
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                  drop=drop_rate, drop_path=rate, norm_layer=norm_layer)
            for rate in dpr
        ])
        # The paper applies LayerNorm to the final [class] token (Eq. 4).
        self.norm = norm_layer(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        # Small random init for the learnable tokens instead of all zeros.
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        """Encode images of shape ``[B, in_c, img_size, img_size]``.

        Returns:
            ``[B, embed_dim]`` [class]-token features when ``num_classes == 0``,
            otherwise ``[B, num_classes]`` logits.
        """
        x = self.patch_embed(x)                                  # [B, num_patches, C]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)    # [B, 1, C]
        x = torch.cat((cls_token, x), dim=1)  # concatenate the class token along the token (N) axis -> [B, 1 + num_patches, C]
        x = self.pos_drop(x + self.pos_embed)  # add positional embedding to every token
        x = self.blocks(x)                     # stacked Transformer encoder layers
        x = self.norm(x)
        return self.head(x[:, 0])              # keep only the [class] token