In [1]:
import os

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image

from resnet import ResNet18, ResNet34
from utils import OptimOneCycleLR, LabelSmoothingCrossEntropy

%matplotlib inline

# Dataset

In [2]:
# Per-channel mean / std handed to Normalize; both pipelines use the same values.
norm_mean = (0.5070754, 0.48655024, 0.44091907)
norm_std = (0.26733398, 0.25643876, 0.2761503)

# Training pipeline: image-space augmentations run before ToTensor, then the
# tensor is normalized, and RandomErasing is applied last. Because erasing
# happens after Normalize, value=0 fills the patch with 0 in normalized space.
train_steps = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(
        brightness=0.25, contrast=0.25, hue=0.1, saturation=0.1),
    transforms.RandomRotation(degrees=30.),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
    transforms.RandomErasing(
        p=0.75, scale=(0.02, 0.1), ratio=(0.2, 5), value=0),
]
transform_train = transforms.Compose(train_steps)

# Evaluation pipeline: no augmentation, only tensor conversion + normalization.
test_steps = [
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
]
transform_test = transforms.Compose(test_steps)

In [3]:
# CIFAR-100 train and test splits, each bound to its own transform pipeline.
# Both splits live under the same root directory and are downloaded on demand.
data_root = '../data'

trainset = torchvision.datasets.CIFAR100(
    root=data_root,
    train=True,
    download=True,
    transform=transform_train,
)
testset = torchvision.datasets.CIFAR100(
    root=data_root,
    train=False,
    download=True,
    transform=transform_test,
)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Batch sizes used by the two data loaders.
bs_train, bs_test = 1024, 2048

# Number of samples in each split (used later when averaging the loss).
num_train, num_test = len(trainset), len(testset)

In [5]:
# Training batches are reshuffled every epoch.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=bs_train, shuffle=True, num_workers=4, pin_memory=True)

# Evaluation gains nothing from shuffling. A fixed order keeps the batch
# composition stable, so the per-batch-averaged test loss is reproducible
# from epoch to epoch (the last test batch is smaller than the others).
testloader = torch.utils.data.DataLoader(
    testset, batch_size=bs_test, shuffle=False, num_workers=4, pin_memory=True)

# Training

In [6]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [7]:
# Build the network and place it on the selected device.
# NOTE(review): the meaning of (64, 100) is defined in resnet.py, which is not
# visible here — presumably a base width and the class count; confirm.
net = ResNet18(64, 100).to(device)

if device == 'cuda':
    # Let cuDNN auto-tune its kernels, and wrap the model in DataParallel.
    cudnn.benchmark = True
    net = torch.nn.DataParallel(net)

In [8]:
# Length of the run and the training objective.
epochs = 100
criterion = nn.CrossEntropyLoss()
# criterion = LabelSmoothingCrossEntropy()

# The scheduler receives both the epoch count and the number of batches per
# epoch; the training loop calls scheduler.step() once per batch.
steps_per_epoch = len(trainloader)

# SGD
optimizer = optim.SGD(
    net.parameters(),
    lr=0.08,
    momentum=0.9,
    weight_decay=1e-5,
    nesterov=True,
)
# NOTE(review): the positional arguments follow OptimOneCycleLR's signature in
# utils.py, which is not visible here — check them against that definition.
scheduler = OptimOneCycleLR(
    optimizer, 0.1, 1., 0.001,
    epochs, steps_per_epoch, 0.25, 0.25, 0., 0., 'linear',
)

# AdamW
# optimizer = torch.optim.AdamW(net.parameters(), lr=0.001, betas=[0.9, 0.999], eps=1e-08, weight_decay=1e-4, amsgrad=False)
# scheduler = OptimOneCycleLR(optimizer, 1e-3, 1e-2, 1e-5,
#                             epochs, len(trainloader), 0.1, 0.5, 0.2, 0.15)

In [9]:
print(f"Begin Training for {epochs} epoch")

# Directory that receives both checkpoints. It is created up front so the
# first save cannot fail on a fresh checkout.
checkpoint_dir = 'models/CIFAR100/ResNet18'
os.makedirs(checkpoint_dir, exist_ok=True)

# Best test accuracy observed so far. It starts at 0 so that the first epoch
# always writes model.pt and later epochs overwrite it only on improvement.
test_best_acc = 0.

for epoch in range(1, epochs + 1):
    # ---- training pass ----
    train_loss = 0
    train_correct = 0
    train_total = 0

    net.train()
    for x, y in trainloader:
        x, y = x.to(device), y.to(device)

        pred = net(x)
        loss = criterion(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # The scheduler is advanced once per batch, not once per epoch.
        scheduler.step()

        preds = torch.argmax(pred, dim=1)
        train_loss += loss.item()
        train_correct += (preds == y).sum().item()
        train_total += y.size(0)

    # ---- evaluation pass ----
    test_loss = 0
    test_correct = 0
    test_total = 0

    net.eval()
    # No gradients are needed for evaluation; disable them for the whole pass.
    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(device), y.to(device)

            pred = net(x)
            loss = criterion(pred, y)

            preds = torch.argmax(pred, dim=1)
            test_loss += loss.item()
            test_correct += (preds == y).sum().item()
            test_total += y.size(0)

    # The accumulated losses are sums of per-batch values, so they are
    # averaged over the number of batches each loader produced.
    mean_train_loss = train_loss / len(trainloader)
    mean_test_loss = test_loss / len(testloader)
    train_acc = train_correct / train_total
    test_acc = test_correct / test_total

    print(f'Epoch: {epoch} || train_loss: {mean_train_loss:.5f} || train_acc : {train_acc:.5f} || test_loss: {mean_test_loss:.5f} || test_acc : {test_acc:.5f} ')

    # Keep only the weights of the best-performing epoch in model.pt.
    if test_acc > test_best_acc:
        test_best_acc = test_acc
        torch.save(net.state_dict(), f'{checkpoint_dir}/model.pt')

# Final checkpoint with the state of the last completed epoch.
# NOTE(review): when net is wrapped in DataParallel, the saved parameter keys
# carry a 'module.' prefix — confirm the loading code expects that.
torch.save({
    'epoch': epoch,
    'model_state_dict': net.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': mean_train_loss,
}, f"{checkpoint_dir}/last_checkpoints.pt")

Begin Training for 100 epoch
Epoch: 1 || train_loss: 4.22012 || train_acc : 0.06078 || test_loss: 4.49792 || test_acc : 0.07590 
Epoch: 2 || train_loss: 3.80961 || train_acc : 0.11108 || test_loss: 3.54655 || test_acc : 0.14840 
Epoch: 3 || train_loss: 3.60165 || train_acc : 0.14504 || test_loss: 3.34925 || test_acc : 0.19080 
Epoch: 4 || train_loss: 3.42955 || train_acc : 0.17522 || test_loss: 3.18163 || test_acc : 0.22350 
Epoch: 5 || train_loss: 3.27281 || train_acc : 0.20198 || test_loss: 3.00147 || test_acc : 0.25230 
Epoch: 6 || train_loss: 3.10431 || train_acc : 0.23290 || test_loss: 2.82128 || test_acc : 0.29650 
Epoch: 7 || train_loss: 2.95424 || train_acc : 0.25932 || test_loss: 2.78559 || test_acc : 0.31370 
Epoch: 8 || train_loss: 2.82819 || train_acc : 0.28712 || test_loss: 2.49408 || test_acc : 0.36090 
Epoch: 9 || train_loss: 2.69828 || train_acc : 0.31514 || test_loss: 2.49864 || test_acc : 0.37260 
Epoch: 10 || train_loss: 2.59103 || train_acc : 0.33710 || test_loss: 2

In [10]:
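# Disabled experiment, kept for reference: it repeatedly trains on one fixed
# batch taken from trainloader, using AdamW and its own OptimOneCycleLR schedule.
# x, y = next(iter(trainloader))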
# x, y = x.to(device), y.to(device)

# epochs = 100
# optimizer = torch.optim.AdamW(net.parameters(), lr=0.001, betas=[0.9, 0.999], eps=1e-08, weight_decay=1e-4, amsgrad=False)
# scheduler = OptimOneCycleLR(optimizer, 1e-3, 1e-2, 1e-5,
#                             epochs, 1, 0.2, 0.4)

# print(f"Begin Training for {epochs} epoch")

# for epoch in range(1, epochs + 1):
#     train_loss = 0
#     train_correct = 0
#     train_total = 0
    
#     net.train()
        
#     pred = net(x)
#     loss = criterion(pred, y)

#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()
#     scheduler.step()

#     preds = torch.argmax(pred, dim=1)
#     train_loss += loss.item()
#     train_correct += (preds == y).sum().item()
#     train_total += y.size(0)

#     print(f'Epoch: {epoch} || train_loss: {train_loss / ((num_train - 1) // bs_train + 1):.5f} || train_acc : {train_correct / train_total:.5f}')