In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchsummary import summary
from fvcore.nn import FlopCountAnalysis
import numpy as np

In [2]:
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=32, in_channels=3, embed_size=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2

        self.projection = nn.Conv2d(in_channels, embed_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.projection(x)  # [B, C, H, W] -> [B, E, H', W']
        x = x.flatten(2)        # [B, E, H', W'] -> [B, E, N]
        x = x.transpose(1, 2)   # [B, E, N] -> [B, N, E]
        return x

class MLPBlock(nn.Module):
    def __init__(self, embed_size):
        super(MLPBlock, self).__init__()
        self.fc1 = nn.Linear(embed_size, embed_size * 4)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(embed_size * 4, embed_size)
        self.dropout = nn.Dropout(0.0)

    def forward(self, x):
        x = self.fc1(x)
        x = self.gelu(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

class EncoderBlock(nn.Module):
    def __init__(self, embed_size, num_heads):
        super(EncoderBlock, self).__init__()
        self.norm1 = nn.LayerNorm(embed_size)
        self.self_attention = nn.MultiheadAttention(embed_dim=embed_size, num_heads=num_heads)
        self.norm2 = nn.LayerNorm(embed_size)
        self.mlp = MLPBlock(embed_size)
        self.dropout = nn.Dropout(0.0)

    def forward(self, x):
        # Self-attention layer
        x = self.norm1(x)
        attention_output = self.self_attention(x, x, x)[0]
        x = x + attention_output

        # MLP block
        x = self.norm2(x)
        mlp_output = self.mlp(x)
        x = x + mlp_output

        return x

class Encoder(nn.Module):
    def __init__(self, embed_size=768, num_heads=8, num_layers=6):
        super().__init__()
        # sequential version is faster than nn.ModuleList
        self.layers = nn.Sequential(*[EncoderBlock(embed_size, num_heads) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(embed_size)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        return x
    
class ViT(nn.Module):
    def __init__(self, img_size=224, patch_size=32, in_channels=3, embed_size=768, num_heads=8, num_layers=6, num_classes=10):
        super().__init__()
        self.patch_embedding = PatchEmbedding(img_size, patch_size, in_channels, embed_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_size))
        self.positional_encoding = self.create_positional_encoding(1 + self.patch_embedding.n_patches, embed_size)
        self.transformer_encoder = Encoder(embed_size, num_heads, num_layers)
        self.fc = nn.Linear(embed_size, num_classes)

    def create_positional_encoding(self, num_positions, embed_size):
        position = torch.arange(num_positions).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_size, 2) * -(np.log(10000.0) / embed_size))
        positional_encoding = torch.zeros(num_positions, embed_size)
        positional_encoding[:, 0::2] = torch.sin(position * div_term)
        positional_encoding[:, 1::2] = torch.cos(position * div_term)
        positional_encoding = positional_encoding.unsqueeze(0)
        return nn.Parameter(positional_encoding, requires_grad=False)

    def forward(self, x):
        batch_size = x.shape[0]
        x = self.patch_embedding(x)
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x += self.positional_encoding
        x = self.transformer_encoder(x)
        x = x[:, 0]
        x = self.fc(x)
        return x

In [3]:
# 创建模型、损失函数和优化器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ViT(img_size = 224,
            patch_size = 32,
            in_channels = 3,
            embed_size = 768,
            num_heads = 12,
            num_layers = 12,
            num_classes = 10).to(device)

In [4]:
print(model)

ViT(
  (patch_embedding): PatchEmbedding(
    (projection): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
  )
  (transformer_encoder): Encoder(
    (layers): Sequential(
      (0): EncoderBlock(
        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): MLPBlock(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (gelu): GELU(approximate='none')
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (1): EncoderBlock(
        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDyna

In [5]:
summary(model, torch.zeros((1, 3, 224, 224)).to(device))
flops = FlopCountAnalysis(model, torch.zeros((1, 3, 224, 224)).to(device))
print('FLOPs: {:.2f}G'.format(flops.total()/1e9))

Layer (type:depth-idx)                        Output Shape              Param #
├─PatchEmbedding: 1-1                         [-1, 49, 768]             --
|    └─Conv2d: 2-1                            [-1, 768, 7, 7]           2,360,064
├─Encoder: 1-2                                [-1, 50, 768]             --
|    └─Sequential: 2                          []                        --
|    |    └─EncoderBlock: 3-1                 [-1, 50, 768]             7,087,872
|    |    └─EncoderBlock: 3-2                 [-1, 50, 768]             7,087,872
|    |    └─EncoderBlock: 3-3                 [-1, 50, 768]             7,087,872
|    |    └─EncoderBlock: 3-4                 [-1, 50, 768]             7,087,872
|    |    └─EncoderBlock: 3-5                 [-1, 50, 768]             7,087,872
|    |    └─EncoderBlock: 3-6                 [-1, 50, 768]             7,087,872
|    |    └─EncoderBlock: 3-7                 [-1, 50, 768]             7,087,872
|    |    └─EncoderBlock: 3-8          

Unsupported operator aten::add_ encountered 1 time(s)
Unsupported operator aten::div encountered 24 time(s)
Unsupported operator aten::unflatten encountered 12 time(s)
Unsupported operator aten::mul encountered 48 time(s)
Unsupported operator aten::softmax encountered 12 time(s)
Unsupported operator aten::add encountered 24 time(s)
Unsupported operator aten::gelu encountered 12 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
transformer_encoder.layers.0.dropout, transformer_encoder.layers.0.self_attention.out_proj, transformer_encoder.layers.1.dropout, transformer_encoder.layers.1.self_attention.out_proj, transformer_encoder.layers.10.dropout, transformer_encoder.layers.10.self_attention.out_proj, transformer

FLOPs: 4.37G


In [6]:
# 数据预处理
transform = transforms.Compose([transforms.Resize((224, 224)),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])

# 加载CIFAR-10数据集
train_dataset = datasets.CIFAR10(root='~/.cache', train=True, transform=transform, download=True)
test_dataset = datasets.CIFAR10(root='~/.cache', train=False, transform=transform, download=True)

# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [7]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    loop = tqdm(train_loader, leave=True)
    for images, labels in loop:
        images, labels = images.to(device), labels.to(device)

        # 前向传播和计算损失
        outputs = model(images)
        loss = criterion(outputs, labels)

        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        loop.set_description(f"Epoch [{epoch+1}/{num_epochs}]")
        loop.set_postfix(loss=loss.item())

    # 评估模型
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Accuracy after epoch {epoch+1}: {accuracy}%')

  0%|          | 0/782 [00:00<?, ?it/s]

Epoch [1/10]:  23%|██▎       | 179/782 [00:53<03:01,  3.32it/s, loss=2.41]


KeyboardInterrupt: 