In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from PIL import Image
import numpy as np

In [2]:
from resnet import ResNet, BasicBlock  # Assuming you have a ResNet18 implementation in resnet.py

准备数据

In [3]:
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False, num_workers=2)


100%|██████████| 170M/170M [00:38<00:00, 4.37MB/s] 


训练与评估

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet(BasicBlock, [2, 2, 2, 2]).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

def train(model, loader, optimizer, criterion, epoch):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    acc = 100. * correct / total
    print(f"[Train] Epoch {epoch} | Loss: {running_loss:.3f} | Acc: {acc:.2f}%")

def test(model, loader, criterion):
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    acc = 100. * correct / total
    print(f"[Test]  Loss: {total_loss:.3f} | Acc: {acc:.2f}%")

In [5]:
for epoch in range(1, 11):
    train(model, train_loader, optimizer, criterion, epoch)
    test(model, test_loader, criterion)

[Train] Epoch 1 | Loss: 806.490 | Acc: 28.78%
[Test]  Loss: 160.204 | Acc: 40.71%
[Train] Epoch 2 | Loss: 569.843 | Acc: 46.69%
[Test]  Loss: 129.082 | Acc: 52.76%
[Train] Epoch 3 | Loss: 450.840 | Acc: 58.52%
[Test]  Loss: 111.535 | Acc: 60.92%
[Train] Epoch 4 | Loss: 366.694 | Acc: 66.64%
[Test]  Loss: 100.469 | Acc: 65.50%
[Train] Epoch 5 | Loss: 306.437 | Acc: 72.22%
[Test]  Loss: 79.877 | Acc: 72.43%
[Train] Epoch 6 | Loss: 256.547 | Acc: 77.02%
[Test]  Loss: 71.332 | Acc: 75.51%
[Train] Epoch 7 | Loss: 220.182 | Acc: 80.30%
[Test]  Loss: 71.426 | Acc: 75.09%
[Train] Epoch 8 | Loss: 198.480 | Acc: 82.47%
[Test]  Loss: 68.716 | Acc: 76.20%
[Train] Epoch 9 | Loss: 182.594 | Acc: 83.81%
[Test]  Loss: 62.708 | Acc: 78.57%
[Train] Epoch 10 | Loss: 171.886 | Acc: 84.76%
[Test]  Loss: 67.419 | Acc: 77.59%


疑似有点过拟合了

In [8]:
torch.save(model.state_dict(), "resnet18_cifar10.pth")
print("Model saved as resnet18_cifar10.pth")

Model saved as resnet18_cifar10.pth


In [6]:
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


In [9]:
inference_model = ResNet(BasicBlock, [2, 2, 2, 2]).to(device)
inference_model.load_state_dict(torch.load("resnet18_cifar10.pth"))
inference_model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, aff

In [10]:
# Inference transform (no flip!)
inference_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load image and predict
def predict_image(image_path):
    img = Image.open(image_path).convert('RGB')
    img_tensor = inference_transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = inference_model(img_tensor)
        _, predicted = torch.max(outputs, 1)
        print(f"Predicted: {classes[predicted.item()]}")

predict_image("./data/cifar10_test_example.png")

In [13]:
predict_image("D:\model_project\cifar10\data\images\plane.png")  # Replace with your image path

Predicted: plane
