In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision import transforms, datasets

In [2]:
class ViT(nn.Module):
    """Vision Transformer (ViT) image classifier.

    Pipeline: a strided Conv2d splits the image into non-overlapping patches and
    projects each to a ``dim``-wide token; a learnable CLS token is prepended;
    learnable position embeddings are added; dropout is applied; the sequence is
    encoded by ``depth`` ``nn.TransformerEncoderLayer`` s; finally the CLS token's
    output goes through LayerNorm and a Linear classification head.

    Args:
        image_size: side length of the square input image.
        patch_size: side length of each square patch.
        num_classes: number of output logits.
        dim: token embedding width.
        depth: number of Transformer encoder layers.
        heads: number of attention heads per layer.
        mlp_dim: hidden width of each encoder layer's feed-forward block.
            NOTE(review): ViT-Base uses 3072; the default 3872 may be a typo —
            confirm before changing (it alters parameter count / checkpoints).
        dropout_rate: dropout applied to the token sequence after adding
            position embeddings (the encoder layers keep PyTorch's own default).
    """

    def __init__(self, image_size=224, patch_size=16, num_classes=100, dim=768, depth=12, heads=12, mlp_dim=3872, dropout_rate=0.3):
        super(ViT, self).__init__()
        self.image_size = image_size  # side length of the input image
        self.patch_size = patch_size  # side length of each patch
        # Total number of patches: the conv below yields image_size // patch_size
        # positions along each spatial axis.
        self.num_patches = (image_size // patch_size) ** 2
        # Flattened size of one RGB patch (kept as an attribute; forward() does not use it).
        self.patch_dim = 3 * patch_size ** 2
        # Patchify + linear projection in one op: kernel == stride == patch_size maps each
        # non-overlapping RGB patch to a dim-dimensional token.
        self.conv = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)
        # Learnable classification token, prepended to every image's patch sequence.
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        # Learnable position embeddings for the CLS token plus every patch token.
        # NOTE(review): both parameters are drawn from N(0, 1); reference ViT code uses a much
        # smaller scale (truncated normal, std 0.02) — consider switching.
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, dim))
        # Dropout regularization on the embedded token sequence.
        self.dropout = nn.Dropout(dropout_rate)
        # Transformer encoder over the token sequence.
        # batch_first=True is required: forward() passes (batch, seq, dim) tensors, and with the
        # default batch_first=False PyTorch reads them as (seq, batch, dim) — attention would then
        # mix tokens across *different images of the batch* instead of across patches of one image.
        # Submodule/parameter names are unchanged, so existing state_dicts still load.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=heads,
            dim_feedforward=mlp_dim,
            activation=nn.GELU(),
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        # LayerNorm applied to the CLS token's final representation.
        self.layer_norm = nn.LayerNorm(dim)
        # Fully connected classification head.
        self.fc = nn.Linear(dim, num_classes)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: image batch of shape (B, 3, image_size, image_size).

        Returns:
            Logits of shape (B, num_classes).
        """
        # (B, 3, H, W) -> (B, dim, H // p, W // p): split into patches and project each one.
        x = self.conv(x)
        # -> (B, N, dim): flatten the patch grid into a token sequence.
        x = x.flatten(2).transpose(1, 2)
        # Prepend one CLS token per image -> (B, N + 1, dim).
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        # Add position embeddings (broadcast over the batch), then dropout.
        x = x + self.pos_embedding
        x = self.dropout(x)
        # Encode the sequence; dim 0 is the batch (batch_first=True).
        x = self.transformer_encoder(x)
        # Classify from the CLS token's final representation.
        x = x[:, 0]
        x = self.layer_norm(x)
        return self.fc(x)

In [3]:
# CIFAR-100 training split: each image is resized to the 224x224 ViT input
# resolution, converted to a tensor and normalized per channel with the
# ImageNet mean/std statistics.
batch_size = 64
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
train_dataset = datasets.CIFAR100(
    root="./data/cifar-100",
    train=True,
    download=True,
    transform=transform,
)
# Shuffled mini-batches of `batch_size` images for training.
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

Files already downloaded and verified


In [4]:
def plot_weights(model_weight):
    """Show a histogram of all trainable parameters whose name contains ``'weight'``.

    Args:
        model_weight: an ``nn.Module``; its ``named_parameters()`` are scanned.

    The filter is a plain substring match, so it also includes tensors such as
    attention ``in_proj_weight`` and LayerNorm scales (``norm1.weight`` ...).
    """
    chunks = [
        # Cast to float32 so low-precision dtypes (e.g. bfloat16) convert to NumPy.
        param.detach().flatten().float().cpu()
        for name, param in model_weight.named_parameters()
        if param.requires_grad and 'weight' in name
    ]
    # Concatenate once instead of extending a Python list element by element: for a model
    # of this size that list would hold tens of millions of boxed Python floats.
    weights = torch.cat(chunks).numpy() if chunks else []

    _, ax = plt.subplots()
    ax.hist(weights, bins=100)
    ax.set(title="Distribution of weight parameters", xlabel="value", ylabel="count")
    plt.show()

In [5]:
def weights_init(m):
    """Re-initialize ``nn.Linear`` weights with Kaiming-normal; intended for ``Module.apply``.

    Only the ``weight`` of ``nn.Linear`` modules (including subclasses) is
    touched. Biases, and every module that is not an ``nn.Linear``, keep their
    current values.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.kaiming_normal_(m.weight)

In [6]:
# Seed torch's RNG before any parameters are created (also fixes the data-loader shuffle order).
SEED = 42
torch.manual_seed(SEED)

# Instantiate the ViT model and re-initialize its nn.Linear weights.
model = ViT()
model.apply(weights_init)
# Move the model to its device *before* building the optimizer, so the optimizer is
# constructed over the final, device-resident parameters.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Loss function and optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Training length.
num_epochs = 10
# Learning-rate scheduler. The training loop calls scheduler.step() once per *batch*, so the
# schedule is measured in iterations. The previous StepLR(step_size=1000, gamma=0.1) cut the LR
# 10x every 1000 batches. CIFAR-100's 50,000 training images at batch_size 64 give 782
# batches/epoch, so 10 epochs is ~7,800 steps: 7 decays, ending at lr = 1e-11, and learning
# effectively stopped after 2-3 epochs. Instead, anneal smoothly towards 0 over the whole run.
total_steps = num_epochs * len(train_dataloader)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1, total_steps))



In [7]:
# Display the module tree (the cell's last expression is rendered as its output).
model

ViT(
  (conv): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (dropout): Dropout(p=0.3, inplace=False)
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-11): 12 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (linear1): Linear(in_features=768, out_features=3872, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=3872, out_features=768, bias=True)
        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
        (activation): GELU(approximate='none')
      )
    )
  )
  (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (fc): Linear(in_features=768, out_features=100, 

In [8]:
# Print a progress line every LOG_EVERY steps (and on the last step of each epoch) instead of
# on every batch, which would flood the cell output with thousands of lines.
LOG_EVERY = 50

model.train()  # make the mode explicit: dropout must be active while training
steps_per_epoch = len(train_dataloader)
for epoch in range(num_epochs):
    epoch_loss_sum, epoch_correct, epoch_seen = 0.0, 0, 0
    for idx, (images, labels) in enumerate(train_dataloader):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Stepped once per batch: the scheduler's schedule is expressed in iterations.
        scheduler.step()

        # Use the real batch size: the final batch of an epoch can be smaller than `batch_size`.
        n = labels.size(0)
        correct = (outputs.argmax(dim=1) == labels).sum().item()
        batch_loss = loss.item()
        epoch_loss_sum += batch_loss * n
        epoch_correct += correct
        epoch_seen += n

        if idx % LOG_EVERY == 0 or idx == steps_per_epoch - 1:
            lr = optimizer.param_groups[0]["lr"]
            print(f"epoch: {epoch+1}/{num_epochs},\tstep: {idx},\tloss: {batch_loss:.4f},\tacc: {correct / n * 100:.3f}%,\tlr: {lr}")

    if epoch_seen:
        print(f"epoch {epoch+1}/{num_epochs} done,\tmean loss: {epoch_loss_sum / epoch_seen:.4f},\tacc: {epoch_correct / epoch_seen * 100:.3f}%")