In [1]:
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os

class ImagenetteDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = sorted(os.listdir(root_dir))
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}
        self.images = self._load_images()

    def _load_images(self):
        images = []
        for cls_name in self.classes:
            cls_dir = os.path.join(self.root_dir, cls_name)
            for img_name in os.listdir(cls_dir):
                images.append((os.path.join(cls_dir, img_name), self.class_to_idx[cls_name]))
        return images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path, label = self.images[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

# 定义预处理步骤
def get_transform(is_train=True):
    # ImageNet数据集的均值和标准差
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    
    if is_train:
        transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
    else:
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
    
    return transform

# 使用示例
train_transform = get_transform(is_train=True)
val_transform = get_transform(is_train=False)

# 假设您的Imagenette数据集路径
train_dataset = ImagenetteDataset(root_dir='/scratch/wenjie/imagenette2-320/train', transform=train_transform)
val_dataset = ImagenetteDataset(root_dir='/scratch/wenjie/imagenette2-320/val', transform=val_transform)

# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

# 打印一个批次的形状和值范围
for images, labels in train_loader:
    print(f"Batch shape: {images.shape}")
    print(f"Value range: [{images.min():.2f}, {images.max():.2f}]")
    break

Batch shape: torch.Size([32, 3, 224, 224])
Value range: [-2.12, 2.64]


In [2]:
# MobileNetV2, 用来训练ImageNet12数据集
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision import models
from PIL import Image
import matplotlib.pyplot as plt


# 加载预训练的 MobileNetV2 模型
model = models.mobilenet_v2(pretrained=False)
# 修改最后的全连接层以输出 12 个类别
num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_ftrs, 12)



In [3]:
# 在gpu上训练
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")

In [4]:
import torch.nn as nn
# 训练模型
num_epochs = 50
learning_rate = 0.01


momentum = 0.9
weight_decay = 0
# weight_decay = 1e-4 # L2正则化系数

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=weight_decay, momentum=momentum)
criterion = nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    model.train()
    model.to(device)
    train_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    train_loss /= len(train_loader)
    print('Epoch: {}, Loss: {:.4f}'.format(epoch, train_loss))

    # 在测试集上评估模型
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in val_loader:
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f"Epoch {epoch+1}, Test Accuracy: {100 * correct / total}%")

Epoch: 0, Loss: 2.1903
Epoch 1, Test Accuracy: 36.968152866242036%
Epoch: 1, Loss: 1.8983
Epoch 2, Test Accuracy: 47.77070063694268%
Epoch: 2, Loss: 1.6967
Epoch 3, Test Accuracy: 54.72611464968153%
Epoch: 3, Loss: 1.5630
Epoch 4, Test Accuracy: 57.197452229299365%
Epoch: 4, Loss: 1.4634
Epoch 5, Test Accuracy: 54.955414012738856%
Epoch: 5, Loss: 1.3681
Epoch 6, Test Accuracy: 67.03184713375796%
Epoch: 6, Loss: 1.2756
Epoch 7, Test Accuracy: 68.94267515923566%
Epoch: 7, Loss: 1.2061
Epoch 8, Test Accuracy: 68.73885350318471%
Epoch: 8, Loss: 1.1612
Epoch 9, Test Accuracy: 71.43949044585987%
Epoch: 9, Loss: 1.1113
Epoch 10, Test Accuracy: 73.09554140127389%
Epoch: 10, Loss: 1.0719
Epoch 11, Test Accuracy: 76.28025477707007%
Epoch: 11, Loss: 0.9991
Epoch 12, Test Accuracy: 76.28025477707007%
Epoch: 12, Loss: 0.9972
Epoch 13, Test Accuracy: 77.07006369426752%
Epoch: 13, Loss: 0.9809
Epoch 14, Test Accuracy: 75.94904458598727%
Epoch: 14, Loss: 0.9200
Epoch 15, Test Accuracy: 77.910828025477