In [1]:
from models import VGG_A, VGG_A_BN
import torch
from torch import optim
import torchvision
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torch.utils.data import random_split, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import os
import pickle
from tqdm import trange


# Seed torch's RNGs so the shuffling order and weight initialisation are
# reproducible across Restart & Run All.
SEED = 0
BATCH_SIZE = 64
torch.manual_seed(SEED)

# ToTensor yields values in [0, 1]; Normalize((0.5,)*3, (0.5,)*3) maps each RGB
# channel to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# download=True is a no-op when ./data already holds a verified copy, and lets a
# fresh checkout run end-to-end instead of failing with "Dataset not found".
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
def train(model, optimizer, criterion, epochs=20, save_dir="results/", loader=None):
    """Train ``model`` and persist its weights, per-batch losses and gradients.

    Runs ``epochs`` passes over ``loader`` (default: the module-level
    ``trainloader``). After every optimizer step it records the mini-batch loss
    and a CPU copy of the gradient of ``model.classifier[4].weight``.

    Side effects: creates ``save_dir`` if needed and writes into it
    ``model.pth`` (the state dict), ``training_results.pkl`` (a dict with keys
    "train_losses" and "grads") and ``loss_plot.png``; also displays the plot.

    Args:
        model: network with an indexable ``classifier`` whose element 4 has a
            ``weight`` parameter (presumably a Linear layer — TODO confirm
            against models.py).
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss called as ``criterion(outputs, labels)``.
        epochs: number of passes over the data.
        save_dir: output directory.
        loader: iterable of ``(images, labels)`` batches; ``None`` selects the
            global ``trainloader``.

    Returns:
        ``(train_losses, grads)``: one float loss and one CPU gradient tensor
        per optimizer step.
    """
    if loader is None:
        loader = trainloader

    train_losses = []
    grads = []

    os.makedirs(save_dir, exist_ok=True)

    model.train()
    for _ in trange(epochs, desc="Training", unit="epoch"):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Snapshot on the CPU: keeping one copy per step on the GPU grows
            # device memory for the entire run, and pickling CUDA tensors makes
            # training_results.pkl unloadable on a machine without a GPU.
            # copy=True guarantees an independent copy even when already on CPU.
            grads.append(model.classifier[4].weight.grad.detach().to("cpu", copy=True))

            train_losses.append(loss.item())

    torch.save(model.state_dict(), os.path.join(save_dir, "model.pth"))

    results = {
        "train_losses": train_losses,
        "grads": grads
    }

    with open(os.path.join(save_dir, "training_results.pkl"), "wb") as f:
        pickle.dump(results, f)

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(train_losses, label="Training Loss", color="blue")
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Loss")
    ax.set_title("Training Loss Over Time")
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(os.path.join(save_dir, "loss_plot.png"))  # save the figure
    plt.show()
    # Release the figure explicitly; train() is called repeatedly in a sweep.
    plt.close(fig)

    return train_losses, grads

def evaluate(model, loader=None):
    """Compute, print and return the top-1 accuracy of ``model``.

    Switches ``model`` to eval mode (and leaves it there), runs it under
    ``torch.no_grad()`` over ``loader`` (default: the module-level
    ``testloader``), and counts how often the arg-max over dim 1 of the output
    equals the label.

    Args:
        model: classifier whose output is indexed as (batch, classes) — assumed
            from the ``argmax(dim=1)`` below.
        loader: iterable of ``(images, labels)`` batches; ``None`` selects the
            global ``testloader``.

    Returns:
        Accuracy as a float in [0, 1].
    """
    if loader is None:
        loader = testloader

    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total
    print(f"Test Accuracy: {accuracy * 100:.2f}%")

    return accuracy

In [None]:
learning_rates = [1e-3, 2e-3, 1e-4, 5e-4]
EPOCHS = 20
criterion = nn.CrossEntropyLoss()

# Identical sweep for both architectures; the prefix names each run's output
# directory, e.g. "results/VGG-1e-03" or "results/VGG_BN-5e-04".
architectures = [("VGG", VGG_A), ("VGG_BN", VGG_A_BN)]
test_accuracies = {}

for arch_name, model_cls in architectures:
    for learning_rate in learning_rates:
        run_model = model_cls().to(device)
        optimizer = optim.Adam(run_model.parameters(), lr=learning_rate, weight_decay=1e-4)

        run_name = f"{arch_name}-{learning_rate:.0e}"
        train(run_model, optimizer, criterion, EPOCHS, f"results/{run_name}")
        test_accuracies[run_name] = evaluate(run_model)

test_accuracies

In [None]:
criterion = nn.CrossEntropyLoss()

# Architecture must match the checkpoint: VGG_A for "results/VGG-*" runs,
# VGG_A_BN for "results/VGG_BN-*" runs.
MODEL_CLS = VGG_A
model = MODEL_CLS().to(device)

# Checkpoint written by train() in the sweep cell; change the run name as needed.
save_path = os.path.join("results", "VGG-1e-03", "model.pth")
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load(save_path, map_location=device))

evaluate(model)