## ViT 结构
![image.png](../add_pic/vit.png)

In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [22]:
def image2emb_naive(image, patch_size, weight):
    """Embed image patches by unfolding the image and projecting each patch.

    Args:
        image: Image tensor laid out as [B, C, H, W].
        patch_size: Side length of each square patch. It is also used as the
            stride, so the extracted patches do not overlap.
        weight: Projection matrix multiplied onto each flattened patch; its
            first dimension must equal C * patch_size * patch_size.

    Returns:
        The projected patches; for a 2-D ``weight`` this has shape
        [B, num_patches, weight.shape[-1]].
    """
    # F.unfold yields [B, C * patch_size**2, num_patches]; swapping the last
    # two axes puts one flattened patch on every row.
    columns = F.unfold(image, kernel_size=patch_size, stride=patch_size)
    flat_patches = columns.transpose(-2, -1)
    return torch.matmul(flat_patches, weight)



def image2emb_conv(image, kernel, stride):
    """Embed image patches with a strided 2-D convolution.

    Args:
        image: Input tensor handed to ``F.conv2d``; the convolution result
            must be 4-D because its shape is unpacked into four values.
        kernel: Convolution weight; each output channel becomes one
            embedding dimension.
        stride: Stride forwarded to ``F.conv2d``.

    Returns:
        Tensor of shape [batch, out_height * out_width, out_channels].
    """
    feature_map = F.conv2d(image, kernel, stride=stride)
    batch, channels, height, width = feature_map.shape
    # Collapse the spatial grid into one token axis, then move channels last.
    tokens = feature_map.reshape(batch, channels, height * width)
    return tokens.permute(0, 2, 1)



# test code for image2emb
bs, ic, ih, iw = 1, 3, 8, 8
patch_size = 4
model_dim = 8
max_num_token = 16
num_classes = 10
# Single source of truth for the batch size: the image batch, the labels and
# the CLS tokens built later must all agree, otherwise torch.cat fails.
batch_size = bs
# Draw labels from [0, num_classes) rather than a hard-coded 10.
label = torch.randint(num_classes, (batch_size,))
# Length of one flattened patch: every channel of a patch_size x patch_size window.
patch_depth = patch_size * patch_size * ic
image = torch.randn(bs, ic, ih, iw)
weight = torch.randn(patch_depth, model_dim)
patch_embedding_navie = image2emb_naive(image, patch_size, weight)
print(patch_embedding_navie.shape)

# Re-express the projection matrix as a conv kernel of shape
# [model_dim, ic, patch_size, patch_size] so both paths use the same weights.
kernel = weight.transpose(0, 1).reshape((-1, ic, patch_size, patch_size))
patch_embedding_conv = image2emb_conv(image, kernel, stride=patch_size)
# Reuse the stored result instead of running the convolution a second time.
print(patch_embedding_conv.shape)

# Both implementations share the same weights, so they must agree up to
# floating-point rounding.
assert torch.allclose(patch_embedding_navie, patch_embedding_conv, atol=1e-4), \
    "unfold-based and conv-based patch embeddings differ"

torch.Size([1, 4, 8])
torch.Size([1, 4, 8])


In [23]:
# Prepend a CLS token embedding ahead of the patch tokens (sequence axis = dim 1).
cls_token_embedding = torch.randn(batch_size, 1, model_dim, requires_grad=True)
token_embedding = torch.cat((cls_token_embedding, patch_embedding_conv), dim=1)

# Add position embeddings: take the first seq_len rows of the table and
# replicate them once per sample so the shapes match token_embedding.
position_embedding_table = torch.randn(max_num_token, model_dim, requires_grad=True)
seq_len = token_embedding.size(1)
position_embedding = position_embedding_table[:seq_len].repeat(token_embedding.size(0), 1, 1)
token_embedding = token_embedding + position_embedding

In [26]:
# Pass the embeddings through the transformer encoder.
# token_embedding is laid out as [batch, seq_len, model_dim], so the encoder
# must be built with batch_first=True. With the default (batch_first=False)
# dim 0 is treated as the sequence axis, so attention would run over a
# length-1 "sequence" and the tokens would never attend to each other.
encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=8, batch_first=True)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
encoder_output = transformer_encoder(token_embedding)

# Do classification from the CLS token, which sits at sequence position 0.
cls_token_output = encoder_output[:, 0, :]
linear_layer = nn.Linear(model_dim, num_classes)
logits = linear_layer(cls_token_output)
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, label)
print(loss)

tensor(2.4045, grad_fn=<NllLossBackward0>)


