<a href="https://colab.research.google.com/github/minhhieu132005/pytorch-resnet50-CIFAR10/blob/main/Resnet50_pretrain_model_Cifar10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision import models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split, Subset


In [None]:
# Training hyper-parameters. Learning rates are configured per parameter
# group in the optimizer cell, not here.
batch_size = 64
epochs = 50
num_classes = 10
seed = 42

# Seed torch's global RNG so that draws from it are repeatable across runs.
# These draws include the initialisation of the new classifier head,
# DataLoader shuffling and the random flips.
# Exact GPU determinism may additionally require deterministic cuDNN settings.
torch.manual_seed(seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [None]:
# ImageNet channel mean/std. The pretrained ResNet-50 weights were trained on
# inputs normalised with these statistics, so CIFAR-10 is normalised the same way.
stats = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

# Training pipeline: upsample to 224x224 (the usual ImageNet input size)
# and apply random horizontal flips for augmentation.
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(*stats)
])

# Evaluation pipeline: same resize/normalisation, no augmentation.
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(*stats)
])

# The training split is opened twice over the same files:
# - once with augmentation, for training;
# - once without, so the validation subset is never randomly flipped.
full_dataset_aug = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)

full_dataset_clean = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_test)

test_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)

# Fixed-seed generator so the train/val split is identical on every run.
generator = torch.Generator().manual_seed(42)

val_size = 5000
# Derive the training size from the dataset so the split always covers every
# sample exactly once (45000 for the standard 50000-image training split).
train_size = len(full_dataset_aug) - val_size

train_subset, val_subset_temp = random_split(
    full_dataset_aug, [train_size, val_size], generator=generator
)

# Re-use the validation indices on the non-augmented copy of the dataset.
val_indices = val_subset_temp.indices

val_subset = Subset(full_dataset_clean, val_indices)

# Sanity check: no image may appear in both training and validation.
assert not set(train_subset.indices) & set(val_indices), "train/val subsets overlap"

train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
def get_model(num_classes=10, pretrained=True):
    """Build a ResNet-50 prepared for fine-tuning.

    Every backbone parameter is frozen except those in ``layer4``. The final
    fully connected layer is replaced by a freshly initialised ``nn.Linear``,
    which is trainable because new modules default to ``requires_grad=True``.

    Args:
        num_classes: number of output logits of the new classifier head.
            Defaults to 10, matching CIFAR-10.
        pretrained: load torchvision's default ImageNet weights when True.
            Pass False to skip the download, e.g. when a full state dict will
            be loaded into the model afterwards.

    Returns:
        The constructed model. It is not moved to any device; callers do
        that with ``.to(device)``.
    """
    weights = models.ResNet50_Weights.DEFAULT if pretrained else None
    model = models.resnet50(weights=weights)

    # Freeze the whole network first...
    for param in model.parameters():
        param.requires_grad = False

    # ...then unfreeze only the last residual stage for fine-tuning.
    for param in model.layer4.parameters():
        param.requires_grad = True

    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

model = get_model().to(device)

criterion = nn.CrossEntropyLoss()

# Discriminative learning rates.
backbone_lr = 1e-4  # slow: fine-tunes the pretrained layer4 body
head_lr = 1e-2      # fast: trains the freshly initialised classifier head
param_groups = [
    {'params': model.layer4.parameters(), 'lr': backbone_lr},
    {'params': model.fc.parameters(), 'lr': head_lr},
]
optimizer = torch.optim.SGD(param_groups, momentum=0.9, weight_decay=5e-4)

# Number of mini-batches (optimisation steps) per training epoch.
total_step = len(train_loader)


In [None]:
def run_epoch(net, loader, loss_fn, device, optimizer=None):
    """Make one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Training vs. evaluation:
        - If ``optimizer`` is given, ``net`` is put in training mode and its
          parameters are updated after every batch.
        - Otherwise ``net`` is put in eval mode and gradient tracking is
          disabled.

    Args:
        net: the model to run.
        loader: iterable of ``(images, labels)`` batches.
        loss_fn: criterion called as ``loss_fn(outputs, labels)``.
        device: device the batches are moved to.
        optimizer: optimizer to step; ``None`` means evaluate only.

    Returns:
        Sample-weighted mean loss and the fraction of correct predictions.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    training = optimizer is not None
    # NOTE(review): train() also switches the BatchNorm layers of the frozen
    # stages into training mode. Their running mean/var therefore keep updating
    # even though their weights are frozen — confirm this is intended.
    net.train(training)

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = net(images)
            loss = loss_fn(outputs, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # loss.item() is the batch mean; weight it by batch size so the
            # epoch average is exact even when the last batch is smaller.
            n = labels.size(0)
            loss_sum += loss.item() * n
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()
            n_seen += n

    if n_seen == 0:
        raise ValueError("loader yielded no samples")
    return loss_sum / n_seen, n_correct / n_seen


print("Start Training...")
history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
best_val_acc = 0.0
for epoch in range(epochs):
    epoch_train_loss, epoch_train_acc = run_epoch(
        model, train_loader, criterion, device, optimizer)
    epoch_val_loss, epoch_val_acc = run_epoch(
        model, val_loader, criterion, device)

    history["train_loss"].append(epoch_train_loss)
    history["train_acc"].append(epoch_train_acc)
    history["val_loss"].append(epoch_val_loss)
    history["val_acc"].append(epoch_val_acc)

    print(f'Epoch [{epoch+1}/{epochs}] '
          f'Train Loss: {epoch_train_loss:.4f} | Train Acc: {epoch_train_acc:.4f} | '
          f'Val Loss: {epoch_val_loss:.4f} | Val Acc: {epoch_val_acc:.4f}')

    # Keep only the checkpoint with the best validation accuracy so far.
    if epoch_val_acc > best_val_acc:
        best_val_acc = epoch_val_acc
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'best_val_acc': best_val_acc,
        }, 'best_cifar10_resnet50.pth')

Epoch [1/50] Train Loss: 0.7924 | Train Acc: 0.7433 | Val Loss: 0.6565 | Val Acc: 0.7870


In [None]:
# Reload the best checkpoint (selected by validation accuracy) into a fresh
# model and measure its accuracy on the held-out test set.
# weights_only=True restricts unpickling to tensors and plain containers and
# numbers, which is all this checkpoint holds.
checkpoint = torch.load('best_cifar10_resnet50.pth', map_location=device, weights_only=True)
model_test = get_model().to(device)
model_test.load_state_dict(checkpoint['model_state_dict'])
model_test.eval()

with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        # Evaluate the reloaded best model, not the last-epoch `model`.
        outputs = model_test(images)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print('Accuracy of the network on the {} test images: {} %'.format(total, 100 * correct / total))