# 28. Vision Transformer (ViT): Principles and a Line-by-Line PyTorch Implementation

Paper: <https://arxiv.org/abs/2010.11929>
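## The idea behind ViT
ViT brings the Transformer to image recognition. In NLP the modeling unit is a word (token). Using individual pixels as tokens would make the sequence extremely long, and self-attention cost grows quadratically with sequence length. Each pixel also carries very little information on its own. ViT therefore uses image regions (patches) as its tokens.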
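A quick token count makes the sequence-length argument concrete. The 224×224 resolution and 16×16 patch size below are illustrative assumptions. `num_tokens` is a hypothetical helper written only for this sketch.

```python
def num_tokens(height: int, width: int, patch: int = 1) -> int:
    """Number of tokens when an image is cut into non-overlapping patch x patch squares."""
    assert height % patch == 0 and width % patch == 0
    return (height // patch) * (width // patch)

pixel_tokens = num_tokens(224, 224)            # every pixel is a token
patch_tokens = num_tokens(224, 224, patch=16)  # every 16x16 patch is a token
print(pixel_tokens, patch_tokens)  # 50176 196

# Self-attention compares every pair of tokens, so its cost scales with length squared.
print((pixel_tokens / patch_tokens) ** 2)  # 65536.0
```

Moving from pixels to 16×16 patches shortens the sequence by a factor of 256. This makes the pairwise attention work roughly 65536 times smaller.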
![](http://assets.hypervoid.top/img/2025/07/05/202507051748682-b9aa.png)
### DNN perspective

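1. Split the image into non-overlapping patches (image to patch).
2. Flatten each patch and apply a shared affine (linear) transformation to turn it into an embedding.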
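The two steps above can be sketched with plain tensor reshapes and an `nn.Linear` layer. This is a minimal sketch, not the paper's code, with these assumptions:

- `patchify` is a hypothetical helper.
- The random 28×28 single-channel batch stands in for FashionMNIST.
- The last line checks that the patch layout (channel, then row, then column inside each patch) matches `F.unfold`.

```python
import torch
import torch.nn.functional as F
from torch import nn

def patchify(img: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Step 1: split (N, C, H, W) images into flattened patches -> (N, num_patches, C*p*p)."""
    n, c, h, w = img.shape
    p = patch_size
    assert h % p == 0 and w % p == 0, "H and W must be divisible by the patch size"
    x = img.reshape(n, c, h // p, p, w // p, p)  # split H and W into (grid index, offset in patch)
    x = x.permute(0, 2, 4, 1, 3, 5)              # (N, H/p, W/p, C, p, p)
    return x.reshape(n, (h // p) * (w // p), c * p * p)

patch_size, model_dim = 14, 8
# Step 2: one affine map (weight + bias) shared by all patches.
proj = nn.Linear(patch_size * patch_size * 1, model_dim)

img = torch.randn(2, 1, 28, 28)  # stand-in for a FashionMNIST batch
patches = patchify(img, patch_size)
emb = proj(patches)
print(patches.shape, emb.shape)  # torch.Size([2, 4, 196]) torch.Size([2, 4, 8])

# Same values and ordering as the unfold-based approach.
print(torch.equal(patches, F.unfold(img, patch_size, stride=patch_size).transpose(1, 2)))  # True
```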

### CNN perspective

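Turning the image into embeddings can also be viewed as a convolution with `kernel_size=stride`, both equal to the patch size. Each kernel position then covers exactly one patch with no overlap, and the number of output channels equals the embedding dimension. The resulting feature map is then flattened into a sequence.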
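In module form, the CNN view is a single `nn.Conv2d` whose `kernel_size` and `stride` both equal the patch size. `ConvPatchEmbed` is a hypothetical name, and the sizes mirror the FashionMNIST experiment later in this section.

```python
import torch
from torch import nn

class ConvPatchEmbed(nn.Module):
    """Patch embedding as a convolution with kernel_size == stride == patch_size (a sketch)."""

    def __init__(self, in_channels: int, model_dim: int, patch_size: int):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, model_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        x = self.proj(img)                   # (N, model_dim, H/p, W/p): one output pixel per patch
        return x.flatten(2).transpose(1, 2)  # flatten the grid -> (N, num_patches, model_dim)

embed = ConvPatchEmbed(in_channels=1, model_dim=8, patch_size=14)
out = embed(torch.randn(2, 1, 28, 28))
print(out.shape)  # torch.Size([2, 4, 8])
```

Because the stride equals the kernel size, the convolution is just a per-patch linear map. This is why it produces the same result as the unfold-plus-matmul route.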



## Borrowed ideas

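### From BERT: the class token embedding

As with BERT's `[CLS]` token, a learnable embedding is prepended to the sequence of patch embeddings. The encoder output at this position serves as the image representation fed to the classification head.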
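A minimal sketch of the class token as a PyTorch module. It assumes a single learnable vector that all images share and that is expanded over the batch. `ClsTokenPrepend` is a hypothetical name, and the zero initialization is an arbitrary choice for the sketch.

```python
import torch
from torch import nn

class ClsTokenPrepend(nn.Module):
    """Prepend one learnable class token, shared by all images in the batch (a sketch)."""

    def __init__(self, model_dim: int):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, model_dim))  # learned during training

    def forward(self, patch_emb: torch.Tensor) -> torch.Tensor:
        # patch_emb: (N, num_patches, model_dim) -> (N, 1 + num_patches, model_dim)
        cls = self.cls_token.expand(patch_emb.shape[0], -1, -1)  # same token for every sample
        return torch.cat([cls, patch_emb], dim=1)

prepend = ClsTokenPrepend(model_dim=8)
patch_emb = torch.randn(2, 4, 8)
tokens = prepend(patch_emb)
print(tokens.shape)  # torch.Size([2, 5, 8])
```

`expand` creates a view, not a copy, so gradients from every sample flow back into the one shared parameter.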



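## Position embedding
The paper compares several variants and finds that a learnable 1-D position embedding works well. More elaborate 2-D-aware embeddings bring no significant gain.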
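A learnable 1-D position embedding can be sketched as a trainable table with one row per sequence index, added element-wise to the tokens. `LearnedPosEmbed1D`, the table size of 32 and the 0.02 initialization scale are assumptions for illustration.

```python
import torch
from torch import nn

class LearnedPosEmbed1D(nn.Module):
    """Learnable 1-D position embedding: one trainable vector per sequence index (a sketch)."""

    def __init__(self, max_len: int, model_dim: int):
        super().__init__()
        self.table = nn.Parameter(torch.randn(max_len, model_dim) * 0.02)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        seq_len = tokens.shape[1]
        assert seq_len <= self.table.shape[0], "sequence longer than the position table"
        # Token i receives row i; unsqueeze(0) broadcasts the rows over the batch.
        return tokens + self.table[:seq_len].unsqueeze(0)

pos = LearnedPosEmbed1D(max_len=32, model_dim=8)
x = torch.randn(2, 5, 8)  # e.g. 1 class token + 4 patch tokens
y = pos(x)
print(y.shape)  # torch.Size([2, 5, 8])
```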

## Encoder
Only the Transformer encoder is used; there is no decoder.
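An encoder-only classifier can be sketched with PyTorch's built-in `nn.TransformerEncoderLayer`. Pre-LayerNorm (`norm_first=True`) and a GELU MLP roughly follow the paper's encoder block. The head reads only the class-token output, after a final LayerNorm. This approximates the design with the library layer and is not the paper's exact implementation. `EncoderOnlyClassifier` and all sizes are hypothetical.

```python
import torch
from torch import nn

class EncoderOnlyClassifier(nn.Module):
    """Stack of Transformer encoder layers + linear head on the class-token output (a sketch)."""

    def __init__(self, model_dim: int, num_heads: int, num_layers: int, num_classes: int):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=model_dim, nhead=num_heads, dim_feedforward=4 * model_dim,
            activation="gelu", batch_first=True, norm_first=True,  # pre-LayerNorm, GELU MLP
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.norm = nn.LayerNorm(model_dim)
        self.head = nn.Linear(model_dim, num_classes)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: (N, 1 + num_patches, model_dim); position 0 holds the class token
        h = self.encoder(tokens)
        return self.head(self.norm(h[:, 0]))  # logits: (N, num_classes)

model = EncoderOnlyClassifier(model_dim=8, num_heads=2, num_layers=2, num_classes=10)
logits = model(torch.randn(2, 5, 8))
print(logits.shape)  # torch.Size([2, 10])
```

There is no decoder and no causal mask: every token, including the class token, attends to every other token.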

```python
from typing import Optional
import torch
from torch import Tensor
import torch.nn.functional as F
```

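```python
# Implement image -> embedding in two ways: with F.unfold and with a convolution
def image2emb_naive(img: Tensor, weight: Tensor, patch_size: int = 16) -> Tensor:
    """Cut (N, C, H, W) images into non-overlapping patches and project them with `weight`.

    weight: (C * patch_size**2, model_dim)
    """
    assert img.dim() == 4
    # F.unfold returns (batch_size, channel * patch_size**2, num_patches)
    region = F.unfold(img, patch_size, stride=patch_size)
    region = region.transpose(-1, -2)  # -> (batch_size, num_patches, channel * patch_size**2)
    print(region.shape)
    return region @ weight  # (batch_size, num_patches, model_dim)


def image2emb_conv(img: Tensor, kernel: Tensor, patch_size: int = 16) -> Tensor:
    """Same mapping as a convolution with kernel_size == stride == patch_size.

    kernel: (model_dim, C, patch_size, patch_size)
    """
    conv = F.conv2d(img, kernel, stride=patch_size)  # (bs, model_dim, H/p, W/p)
    bs, oc, oh, ow = conv.shape
    patch_emb = conv.reshape((bs, oc, oh * ow)).transpose(-1, -2)  # (bs, oh*ow, oc)
    return patch_emb
```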
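To see what the `region` tensor inside `image2emb_naive` holds, it helps to run `F.unfold` on a tiny image with known values. Here a 4×4 single-channel image filled with 0..15 is cut into 2×2 patches; the sizes are chosen only for readability.

```python
import torch
import torch.nn.functional as F

img = torch.arange(16.0).reshape(1, 1, 4, 4)   # rows: 0-3, 4-7, 8-11, 12-15
cols = F.unfold(img, kernel_size=2, stride=2)  # (N, C*p*p, num_patches) = (1, 4, 4)
patches = cols.transpose(-1, -2)               # (N, num_patches, C*p*p)
print(patches[0].tolist())
# [[0.0, 1.0, 4.0, 5.0], [2.0, 3.0, 6.0, 7.0], [8.0, 9.0, 12.0, 13.0], [10.0, 11.0, 14.0, 15.0]]
```

Each row is one patch, flattened row by row. Patches are ordered left to right, then top to bottom. This ordering matches a `conv2d` kernel of shape `(model_dim, C, p, p)`, which is why the same `weight` can be reshaped into a convolution kernel.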

```python
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms

# FashionMNIST training split; each image is a (1, 28, 28) tensor in [0, 1]
ds = torchvision.datasets.FashionMNIST("./fashion_minist_dataset", train=True, download=True, transform=transforms.ToTensor())
dl = DataLoader(ds, batch_size=2, shuffle=False)

patch_size, model_dim = 14, 8  # a 28x28 image gives (28 / 14)**2 = 4 patches
for img, _ in dl:
    batch_size, channel, img_height, img_width = img.shape
    weight = torch.randn(patch_size**2 * channel, model_dim)
    patch_token_emb = image2emb_naive(img, weight, patch_size)
    print("emb.shape=", patch_token_emb.shape)
    print(patch_token_emb)
    # Reuse the same weights as a conv kernel: (model_dim, channel, patch_size, patch_size)
    kernel = weight.transpose(0, 1).reshape((-1, channel, patch_size, patch_size))
    emb2 = image2emb_conv(img, kernel, patch_size)
    print("emb.shape=", emb2.shape)
    print(emb2)
    # Both methods agree up to floating-point rounding
    assert torch.allclose(patch_token_emb, emb2, atol=1e-5)
    break
```

```text
torch.Size([2, 4, 196])
emb.shape= torch.Size([2, 4, 8])
tensor([[[  0.5155,  -0.3687,  -3.6338,  -1.2878,   1.2945,   1.8334,  -3.7690,
           -0.6200],
         [  7.3336, -15.6543,   6.0121,   4.6137,   4.7995, -10.7871,   0.3487,
            6.9605],
         [ -1.3895,   1.5816,  -2.5958,   3.9162,  -5.0130,   2.0175,   6.0092,
            3.8195],
         [  3.8014,  -2.9465,  10.2140,   9.6521,   4.3162, -10.4472,  11.1373,
            8.1564]],

        [[ -0.9922, -11.4510,  -7.3893,   0.5645,  -0.6928,   0.1476,  -5.3571,
            0.7916],
         [  9.9496,  -5.7659,   4.8640,   3.0778,   2.8867, -14.5847,   5.3720,
           16.2563],
         [ -2.6126,  -3.1197,  -7.0267,  -8.0693,   0.0918,   6.7568,   6.4510,
            3.3721],
         [  7.1353,  -5.7266,  12.7986,   9.8556,   5.0186, -14.1588,   5.5587,
           10.3287]]])
emb.shape= torch.Size([2, 4, 8])
```

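```python
# Prepend a classification (class) token for the classification task

patch_size, model_dim = 14, 8
max_num_token = 32  # size of the position-embedding table (maximum number of tokens)
# Learnable parameters: created once and shared by every image
cls_token = torch.randn((1, 1, model_dim), requires_grad=True)
pos_emb_table = torch.randn(max_num_token, model_dim, requires_grad=True)
for img, _ in dl:
    batch_size, channel, img_height, img_width = img.shape
    weight = torch.randn(patch_size**2 * channel, model_dim)
    # 1. patch embedding
    patch_token_emb = image2emb_naive(img, weight, patch_size)
    # 2. prepend the cls token embedding (expanded over the batch)
    cls_token_emb = cls_token.expand(batch_size, -1, -1)
    emb = torch.cat([cls_token_emb, patch_token_emb], dim=1)  # (bs, 1 + num_patches, model_dim)
    # 3. add the position embedding: row i of the table goes to token i
    seq_len = emb.shape[1]
    emb = emb + pos_emb_table[:seq_len].unsqueeze(0)
    break
```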