In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

In [2]:
# 定义 transform，包括缩放、中心裁剪、随机水平翻转、归一化
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

# 只需要归一化和中心裁剪
transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

# 加载 CIFAR10 数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform_test)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'

@torch.no_grad()
def validate(net:nn.Module, dataloader:DataLoader, loss_fn:nn.Module):
    net.eval()
    metrics = {'loss':0, 'acc':0, 'num_samples':0, 'num_batches':0}
    pbar = tqdm(total=len(dataloader), desc=f"Eval test processing", leave=False)
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        output = net(images)
        loss = loss_fn(output, labels)
        metrics['loss'] += loss.item()
        pred = output.argmax(dim=1)
        metrics['acc'] += (pred==labels).sum().item()
        metrics['num_samples'] += len(labels)
        metrics['num_batches'] += 1
        pbar.update(1)
    pbar.close()
    net.train()
    metrics['loss'] /= metrics['num_batches']
    metrics['acc'] /= metrics['num_samples']
    return metrics

def save_model(net:nn.Module, path:str):
    torch.save(net.state_dict(), path)

def train(net:nn.Module, train_loader:DataLoader, test_loader:DataLoader,
            loss_fn:nn.Module,
            optimizer:torch.optim.Optimizer=None,
            num_epochs:int=5):
    n_batches = len(train_loader)
    best_acc = 0
    pbar = tqdm(total=num_epochs * n_batches, desc="Training batches", leave=True, unit="batch")
    for epoch in range(num_epochs):
        for i, (images,labels) in enumerate(train_loader):
            i_batch = epoch * n_batches + i
            if i_batch % 100 == 0 or i_batch == num_epochs * n_batches - 1:
                metrics = validate(net, test_loader, loss_fn)
                print(f"\n---step {i_batch} {metrics}---")
                if i_batch > 0.5 * epoch * n_batches and metrics['acc'] > best_acc:
                    best_acc = metrics['acc']
                    save_model(net, 'best_model.pth')
            images, labels = images.to(device), labels.to(device)
            output = net(images.to(device))
            loss = loss_fn(output, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            pbar.set_description(f"loss: {loss.item()}")
            pbar.update(1)
    pbar.close()

In [4]:
import torchvision.models as models
resnet50 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
resnet50 = resnet50.to(device)

In [5]:
n_classes = 10
in_features = resnet50.fc.in_features
resnet50.fc = nn.Linear(in_features, n_classes).to(device)

In [6]:
for x in resnet50.parameters():
    x.requires_grad = False

for x in resnet50.fc.parameters():
    x.requires_grad = True

In [7]:
resnet50.load_state_dict(torch.load('best_model_1.pth'))
for x in resnet50.parameters():
    x.requires_grad = True

In [8]:
learning_rate = 1e-3
optimizer = torch.optim.Adam(resnet50.parameters(), learning_rate)
loss_fn = nn.CrossEntropyLoss()
num_epochs = 2
batch_size = 128
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
train(
  resnet50,
  trainloader,
  testloader,
  loss_fn=loss_fn,
  num_epochs=num_epochs,
  optimizer=optimizer,
)

Training batches:   0%|          | 0/782 [00:00<?, ?batch/s]

  self.pid = os.fork()


Eval test processing:   0%|          | 0/79 [00:00<?, ?it/s]


---step 0 {'loss': 0.29998534314240083, 'acc': 0.8997, 'num_samples': 10000, 'num_batches': 79}---


Eval test processing:   0%|          | 0/79 [00:00<?, ?it/s]


---step 100 {'loss': 0.5224341342720804, 'acc': 0.8342, 'num_samples': 10000, 'num_batches': 79}---


Eval test processing:   0%|          | 0/79 [00:00<?, ?it/s]


---step 200 {'loss': 0.341302950359598, 'acc': 0.8863, 'num_samples': 10000, 'num_batches': 79}---


Eval test processing:   0%|          | 0/79 [00:00<?, ?it/s]


---step 300 {'loss': 0.32832302718977385, 'acc': 0.8889, 'num_samples': 10000, 'num_batches': 79}---


  self.pid = os.fork()


Eval test processing:   0%|          | 0/79 [00:00<?, ?it/s]


---step 400 {'loss': 0.2973980626350717, 'acc': 0.8959, 'num_samples': 10000, 'num_batches': 79}---


Eval test processing:   0%|          | 0/79 [00:00<?, ?it/s]


---step 500 {'loss': 0.2973129209838336, 'acc': 0.9029, 'num_samples': 10000, 'num_batches': 79}---


Eval test processing:   0%|          | 0/79 [00:00<?, ?it/s]


---step 600 {'loss': 0.29825972822270813, 'acc': 0.8988, 'num_samples': 10000, 'num_batches': 79}---


Eval test processing:   0%|          | 0/79 [00:00<?, ?it/s]


---step 700 {'loss': 0.36646155164211613, 'acc': 0.8824, 'num_samples': 10000, 'num_batches': 79}---


Eval test processing:   0%|          | 0/79 [00:00<?, ?it/s]


---step 781 {'loss': 0.24061085943934285, 'acc': 0.9173, 'num_samples': 10000, 'num_batches': 79}---
