In [62]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchsummary import summary
from fvcore.nn import FlopCountAnalysis

In [63]:
# Configuration: fix the random seed and select the compute device.
SEED = 42
# Seeds the default CPU generator and all CUDA generators, so weight initialisation,
# DataLoader shuffling and dropout repeat across runs. (Bitwise-identical GPU results
# additionally require deterministic cuDNN settings.)
torch.manual_seed(SEED)

# Prefer a CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [64]:
# Data-loading configuration
DATA_ROOT = '~/.cache'   # where CIFAR-10 is downloaded / cached
BATCH_SIZE = 64
# Decode + resize in background worker processes; the 224x224 resize is CPU-heavy and
# otherwise stalls the GPU. Set to 0 if worker processes fail to start in your environment.
NUM_WORKERS = 2

# Transforms for the CIFAR-10 dataset
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # upscale the images to the ViT input resolution
    transforms.ToTensor(),
    # per channel: (x - 0.5) / 0.5, i.e. map ToTensor's [0, 1] range to [-1, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Page-locked host memory makes host-to-GPU copies faster; only useful when CUDA is present.
pin_memory = torch.cuda.is_available()

# Load the CIFAR-10 dataset
train_dataset = datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=pin_memory)

test_dataset = datasets.CIFAR10(root=DATA_ROOT, train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                         num_workers=NUM_WORKERS, pin_memory=pin_memory)

Files already downloaded and verified
Files already downloaded and verified


In [65]:
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping square patches and linearly embed each one.

    A Conv2d whose kernel size and stride both equal ``patch_size`` applies the same linear
    projection to every patch. Input ``[B, C, H, W]`` becomes ``[B, N, E]`` tokens with
    ``N = (H // patch_size) * (W // patch_size)``.

    Args:
        img_size: side length of the (square) input image.
        patch_size: side length of each square patch.
        in_channels: number of input image channels.
        embed_size: width E of each patch embedding.
    """

    def __init__(self, img_size, patch_size, in_channels, embed_size):
        super().__init__()
        self.patch_size = patch_size
        # number of patches for a square img_size x img_size input
        self.n_patches = (img_size // patch_size) ** 2
        self.projection = nn.Conv2d(in_channels, embed_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.projection(x)  # [B, C, H, W] -> [B, E, H', W']
        x = x.flatten(2)        # [B, E, H', W'] -> [B, E, N]
        x = x.transpose(1, 2)   # [B, E, N] -> [B, N, E]
        return x


class ViT(nn.Module):
    """Vision Transformer classifier.

    Pipeline: patch embedding -> prepend a learnable [CLS] token -> add learnable positional
    embeddings -> Transformer encoder -> linear head on the [CLS] token output.

    Args:
        img_size: side length of the square input image.
        patch_size: side length of each square patch.
        in_channels: number of input image channels.
        embed_size: token width (must be divisible by ``num_heads``).
        num_heads: attention heads per encoder layer.
        num_layers: number of stacked encoder layers.
        num_classes: number of output logits.

    Input ``[B, in_channels, img_size, img_size]``; output ``[B, num_classes]``.
    """

    def __init__(self, img_size=224, patch_size=32, in_channels=3, embed_size=768, num_heads=12, num_layers=12, num_classes=10):
        super().__init__()
        self.patch_embedding = PatchEmbedding(img_size, patch_size, in_channels, embed_size)
        n_patches = self.patch_embedding.n_patches

        # Learnable [CLS] token and positional embeddings (one extra position for [CLS]).
        # Small-std init (0.02, as in the reference ViT) instead of N(0, 1), which would
        # dominate the patch embeddings they are added to.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_size))
        self.positional_encoding = nn.Parameter(torch.zeros(1, 1 + n_patches, embed_size))
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.positional_encoding, std=0.02)

        # batch_first=True because tokens are laid out as [B, N, E]. With the default
        # (batch_first=False) the encoder treats dim 0 as the sequence axis and would
        # attend across the images of a batch instead of across patches.
        encoder_layers = nn.TransformerEncoderLayer(
            embed_size, num_heads, dim_feedforward=embed_size * 4, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)

        self.to_cls_token = nn.Identity()
        self.fc = nn.Linear(embed_size, num_classes)

    def forward(self, x):
        """Map images ``[B, C, H, W]`` to class logits ``[B, num_classes]``."""
        batch_size = x.shape[0]
        x = self.patch_embedding(x)                             # [B, N, E]
        cls_token = self.cls_token.expand(batch_size, -1, -1)   # [B, 1, E]
        x = torch.cat((cls_token, x), dim=1)                    # [B, N + 1, E]
        x = x + self.positional_encoding                        # out-of-place add
        x = self.transformer_encoder(x)                         # [B, N + 1, E]
        x = self.to_cls_token(x[:, 0])                          # [CLS] token output, [B, E]
        return self.fc(x)

In [66]:
# Instantiate the ViT with its default hyper-parameters, move its weights to `device`,
# and print the module tree.
model = ViT()
model = model.to(device)
print(model)



ViT(
  (patch_embedding): PatchEmbedding(
    (projection): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-11): 12 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (linear1): Linear(in_features=768, out_features=3072, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=3072, out_features=768, bias=True)
        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (to_cls_token): Identity()
  (fc): Linear(in_features=768, out_features=10, bias=True)
)


In [67]:
# Layer-by-layer output shapes and parameter counts for one 3x224x224 input.
summary(model, (3, 224, 224))

# Count forward-pass operations with fvcore. fvcore counts one fused multiply-add as one
# "flop" (the convention most ViT papers report as GFLOPs). Operators without a registered
# handler are skipped -- see the "Unsupported operator" warnings, notably
# aten::scaled_dot_product_attention -- so the attention score/value products are NOT
# included and the total is a lower bound.
dummy_input = torch.randn(1, 3, 224, 224, device=device)
flops = FlopCountAnalysis(model, dummy_input)
print("FLOPs: {:.2f} GFLOPs".format(flops.total() / 1e9))

Layer (type:depth-idx)                        Output Shape              Param #
├─PatchEmbedding: 1-1                         [-1, 49, 768]             --
|    └─Conv2d: 2-1                            [-1, 768, 7, 7]           2,360,064
├─TransformerEncoder: 1-2                     [-1, 50, 768]             --
|    └─ModuleList: 2                          []                        --
|    |    └─TransformerEncoderLayer: 3-1      [-1, 50, 768]             7,087,872
|    |    └─TransformerEncoderLayer: 3-2      [-1, 50, 768]             7,087,872
|    |    └─TransformerEncoderLayer: 3-3      [-1, 50, 768]             7,087,872
|    |    └─TransformerEncoderLayer: 3-4      [-1, 50, 768]             7,087,872
|    |    └─TransformerEncoderLayer: 3-5      [-1, 50, 768]             7,087,872
|    |    └─TransformerEncoderLayer: 3-6      [-1, 50, 768]             7,087,872
|    |    └─TransformerEncoderLayer: 3-7      [-1, 50, 768]             7,087,872
|    |    └─TransformerEncoderLayer: 3-

Unsupported operator aten::add_ encountered 1 time(s)
Unsupported operator aten::div encountered 12 time(s)
Unsupported operator aten::unflatten encountered 12 time(s)
Unsupported operator aten::mul encountered 48 time(s)
Unsupported operator aten::scaled_dot_product_attention encountered 12 time(s)
Unsupported operator aten::add encountered 24 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
transformer_encoder.layers.0.self_attn.out_proj, transformer_encoder.layers.1.self_attn.out_proj, transformer_encoder.layers.10.self_attn.out_proj, transformer_encoder.layers.11.self_attn.out_proj, transformer_encoder.layers.2.self_attn.out_proj, transformer_encoder.layers.3.self_attn.out_proj, transformer_encoder.layers.

FLOPs: 4.37GFLOPS


In [68]:
# Loss function and optimizer
LEARNING_RATE = 1e-3
criterion = nn.CrossEntropyLoss()
# NOTE(review): plain SGD without momentum is a weak choice for training a ViT from scratch
# (AdamW is the usual choice); kept here so results stay comparable with earlier runs.
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)


def train_one_epoch(model, loader, criterion, optimizer, device, desc=""):
    """Run one optimisation pass over ``loader``.

    Args:
        model: network to train (switched to train mode).
        loader: iterable of ``(images, labels)`` batches.
        criterion: loss function applied to ``(outputs, labels)``.
        optimizer: optimizer stepping ``model``'s parameters.
        device: device the batches are moved to.
        desc: label for the progress bar.

    Returns:
        Mean loss per sample. Each batch is weighted by its size, so a smaller final batch
        does not skew the average. Returns 0.0 for an empty loader.
    """
    model.train()
    loss_sum, n_samples = 0.0, 0
    with tqdm(loader, desc=desc, unit='batch') as progress:
        for images, labels in progress:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()  # a single device->host sync per step
            loss_sum += batch_loss * labels.size(0)
            n_samples += labels.size(0)
            progress.set_postfix(loss=batch_loss)
    return loss_sum / max(n_samples, 1)


def evaluate(model, loader, device):
    """Return the top-1 accuracy of ``model`` on ``loader``, in percent (0.0 if empty)."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / max(total, 1)


# Training loop: one training pass followed by a test-set evaluation per epoch.
epochs = 10
for epoch in range(epochs):
    avg_loss = train_one_epoch(model, train_loader, criterion, optimizer, device,
                               desc=f"Epoch {epoch+1}/{epochs}")
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")

    # Model evaluation
    accuracy = evaluate(model, test_loader, device)
    print(f"Test Accuracy: {accuracy:.2f}%")

Epoch 1/10:  22%|██▏       | 171/782 [00:32<01:55,  5.30batch/s, loss=2.36]


KeyboardInterrupt: 