In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as f
import matplotlib.pyplot as plt  # `plt` must be pyplot; the bare `matplotlib` package has no plotting API
import time

from model import ConvBlock, CustomResNet, initialize_weights
from torchvision.models.resnet import BasicBlock, Bottleneck

from torchvision.datasets import CIFAR10, CIFAR100
from torchvision.transforms import ToTensor, Compose, RandomCrop, RandomHorizontalFlip, Normalize
from torch.utils.data import DataLoader, random_split
from torch.optim import SGD, lr_scheduler
from tqdm import tqdm

# ---- configuration ----
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 128
epoch = 180             # total number of training epochs
gamma = 0.1             # multiplicative LR decay applied at each milestone
milestones = [90, 120]  # from resnet paper
temp = 5                # NOTE(review): unused in the cells shown; presumably a distillation temperature
alpha = 0.5             # NOTE(review): unused in the cells shown; presumably a distillation loss weight

# Seed torch's RNG (all devices) so weight init, data shuffling and
# augmentation repeat across Restart & Run All. cuDNN kernels may still
# introduce some nondeterminism on GPU.
seed = 0
torch.manual_seed(seed)

In [2]:
# CIFAR-10 per-channel normalisation statistics, shared by train and test.
# NOTE(review): this std tuple is the widely copied (0.2023, 0.1994, 0.2010);
# the per-pixel std over the training set is usually quoted as roughly
# (0.2470, 0.2435, 0.2616). Kept as-is so results stay comparable — confirm.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2023, 0.1994, 0.2010)

# Training augmentation: random 32x32 crop from a 4-pixel padded image and a
# random horizontal flip, followed by normalisation.
transform = Compose([
    ToTensor(),
    RandomCrop(size=[32, 32], padding=4),
    RandomHorizontalFlip(p=0.5),
    Normalize(mean=cifar10_mean, std=cifar10_std),
])

# Test-time preprocessing: no augmentation, same normalisation.
transform_test = Compose([
    ToTensor(),
    Normalize(mean=cifar10_mean, std=cifar10_std),
])

train_dataset = CIFAR10(root='./data', train=True, transform=transform, download=True)
test_dataset = CIFAR10(root='./data', train=False, transform=transform_test, download=True)

# DataLoader definitions. Use the configured `batch_size` instead of a
# duplicated literal so editing the config cell actually takes effect.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Check dataset sizes (the CIFAR-10 test split is what is reported as "Validation").
print(f'Train dataset size: {len(train_dataset)}')
print(f'Validation dataset size: {len(test_dataset)}')

Files already downloaded and verified
Files already downloaded and verified
Train dataset size: 50000
Validation dataset size: 10000


In [4]:
model_name = 'resnet83_baseline'

# Three stages of 9 torchvision Bottleneck blocks each, with 10 output classes.
# NOTE(review): "83" presumably counts 3*9*3 = 81 bottleneck convs plus a stem
# conv and the classifier — confirm against CustomResNet.
model = CustomResNet(
    block=Bottleneck,
    layers=[9, 9, 9],
    num_classes=10,
)
model = model.to(device)
# nn.Module.apply calls the project's initialiser on every submodule.
model.apply(initialize_weights)

# Count every parameter element (trainable or not).
total_params = sum(map(torch.Tensor.numel, model.parameters()))
print(f'Total number of parameters: {total_params}')

Total number of parameters: 869530


In [5]:
import os  # local import keeps this cell self-contained; used to create the checkpoint dir

# SGD with Nesterov momentum; LR starts at 0.1 and is multiplied by `gamma`
# at each epoch listed in `milestones`.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases — confirm
# against the installed version.
optimizer = SGD(params=model.parameters(), lr=0.1, nesterov=True, momentum=0.9, weight_decay=0.0001)
scheduler = lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=milestones, gamma=gamma, verbose=True)
# CrossEntropyLoss takes integer class indices directly, so labels are not
# one-hot encoded; for hard labels the loss value is identical.
criterion = nn.CrossEntropyLoss()

# torch.save does not create missing directories. Create it up front so a
# fresh checkout does not crash after the first (expensive) epoch.
result_dir = './result'
os.makedirs(result_dir, exist_ok=True)
checkpoint_path = f'{result_dir}/{model_name}.pt'

history = dict(train_loss=[], test_acc=[], train_time=[], test_time=[])
for ep in range(epoch):
    # ---- train phase ----
    model.train()
    train_loss = 0.0
    s_time = time.time()
    for image, target in train_loader:
        image = image.to(device)
        target = target.to(device)

        pred = model(image)
        loss = criterion(pred, target)
        train_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    e_time = time.time()
    mean_train_loss = train_loss / len(train_loader)
    history['train_loss'].append(mean_train_loss)
    history['train_time'].append(e_time - s_time)

    # ---- test phase ----
    # no_grad: evaluation needs no autograd graph; building one only costs
    # memory and time.
    model.eval()
    test_acc = 0.0  # running count of correct predictions
    s_time = time.time()
    with torch.no_grad():
        for image, target in test_loader:
            image = image.to(device)
            target = target.to(device)

            pred = model(image)
            test_acc += (pred.argmax(dim=1) == target).sum().item()
    e_time = time.time()
    mean_test_acc = test_acc / len(test_dataset)
    history['test_acc'].append(mean_test_acc)
    history['test_time'].append(e_time - s_time)
    print(f'epoch={ep}, train_loss={mean_train_loss:.3f}, test_acc={mean_test_acc:.3f}')

    # Step the LR schedule *before* checkpointing, so the saved optimizer and
    # scheduler states are exactly what is needed to resume at epoch ep + 1.
    scheduler.step()

    checkpoint = dict(
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
        scheduler=scheduler.state_dict(),
        history=history,
        epoch=ep,  # last completed epoch
    )
    torch.save(checkpoint, checkpoint_path)



epoch=0, train_loss=2.403, test_acc=0.278
epoch=1, train_loss=1.805, test_acc=0.332
epoch=2, train_loss=1.611, test_acc=0.421
epoch=3, train_loss=1.425, test_acc=0.519
epoch=4, train_loss=1.253, test_acc=0.562
epoch=5, train_loss=1.109, test_acc=0.629
epoch=6, train_loss=1.003, test_acc=0.629
epoch=7, train_loss=0.914, test_acc=0.658
epoch=8, train_loss=0.833, test_acc=0.643
epoch=9, train_loss=0.762, test_acc=0.726
epoch=10, train_loss=0.706, test_acc=0.744
epoch=11, train_loss=0.654, test_acc=0.748
epoch=12, train_loss=0.617, test_acc=0.732
epoch=13, train_loss=0.583, test_acc=0.778
epoch=14, train_loss=0.551, test_acc=0.755
epoch=15, train_loss=0.521, test_acc=0.775
epoch=16, train_loss=0.501, test_acc=0.797
epoch=17, train_loss=0.479, test_acc=0.806
epoch=18, train_loss=0.455, test_acc=0.810
epoch=19, train_loss=0.444, test_acc=0.785
epoch=20, train_loss=0.426, test_acc=0.789
epoch=21, train_loss=0.410, test_acc=0.806
epoch=22, train_loss=0.403, test_acc=0.833
epoch=23, train_loss=