In [21]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from vit_cifar import ViT
import os

In [22]:
import torch
print(torch.cuda.is_available())  # True if CUDA is available
print(torch.cuda.current_device())  # Should give a device id, usually 0


True
0


In [23]:
# 配置参数
batch_size = 64
epochs = 10
lr = 3e-4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [24]:
# 数据预处理
transform = transforms.Compose([  # ViT 需要较大输入，但是这个任务不需要
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

In [25]:
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

In [26]:
# 加载模型
model = ViT(img_size=224, patch_size=16, num_classes=10)
model.to(device)



ViT(
  (patch_embed): PatchEmbedding(
    (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=192, out_features=192, bias=True)
        )
        (linear1): Linear(in_features=192, out_features=384, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=384, out_features=192, bias=True)
        (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (mlp_head): Sequential(
    (0): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=192, out_features=10, b

In [27]:
# 损失函数与优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

In [28]:
# 训练函数
def train():
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        correct = 0
        total = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # 解决 UserWarning 的修改
            total_loss += loss.detach().item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss:.4f}, Accuracy: {100*correct/total:.2f}%")

    os.makedirs("checkpoints", exist_ok=True)
    torch.save(model.state_dict(), "checkpoints/vit_cifar10.pth")
    print("模型已保存为 checkpoints/vit_cifar10.pth")


In [29]:
# 测试函数
def test():
    # 将模型设置为评估模式，这样在进行前向传播时不会计算梯度，也不会更新权重
    model.eval()
    # 初始化正确预测的计数器
    correct = 0
    # 初始化测试样本的总数计数器
    total = 0
    # 使用torch.no_grad()上下文管理器，确保在前向传播过程中不计算梯度，节省内存和计算资源
    with torch.no_grad():
        # 遍历测试数据加载器中的所有批次
        for images, labels in test_loader:
            # 将图像和标签移动到指定的设备（如GPU）
            images, labels = images.to(device), labels.to(device)
            # 通过模型前向传播得到输出
            outputs = model(images)
            # 使用torch.max()函数找到每个样本的预测类别，返回最大值的索引
            _, predicted = torch.max(outputs.data, 1)
            # 更新测试样本的总数
            total += labels.size(0)
            # 计算预测正确的样本数，并转换为Python标准整数
            correct += (predicted == labels).sum().item()
    # 打印测试集的准确率，格式化为两位小数
    print(f"测试集准确率: {100 * correct / total:.2f}%")

In [30]:
if __name__ == '__main__':
    train()
    test()

RuntimeError: The size of tensor a (5) must match the size of tensor b (197) at non-singleton dimension 1

In [None]:
from PIL import Image
# CIFAR-10 类别名称
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 
           'dog', 'frog', 'horse', 'ship', 'truck']
# 加载模型
model = ViT(img_size=224, patch_size=16, num_classes=10)
model.load_state_dict(torch.load("checkpoints/vit_cifar10.pth", map_location=device))
model.to(device)
model.eval()
def predict(img_path):
    image = Image.open(img_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(image)
        _, predicted = torch.max(outputs.data, 1)
        print(f"预测结果: {classes[predicted.item()]}")

predict("path")调用

In [None]:
predict("D:\model_project\cifar10\data\images\plane.png")  # Replace with your image path