In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as f
import matplotlib as plt
import time

from model import ConvBlock, CustomResNet, initialize_weights
from torchvision.models.resnet import BasicBlock, Bottleneck

from torchvision.datasets import CIFAR10, CIFAR100
from torchvision.transforms import ToTensor, Compose, RandomCrop, RandomHorizontalFlip, Normalize
from torch.utils.data import DataLoader, random_split
from torch.optim import SGD, lr_scheduler
from tqdm import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 128
epoch = 180
gamma = 0.1
milestones = [90, 120] # from resnet paper
temp = 5
alpha = 0.5

In [2]:
transform = Compose([
    ToTensor(),
    RandomCrop(size=[32, 32], padding=4),
    RandomHorizontalFlip(p=0.5),
    Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
])

transform_test = Compose([
    ToTensor(),
    Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
])

train_dataset = CIFAR10(root='./data', train=True, transform=transform, download=True)
test_dataset = CIFAR10(root='./data', train=False, transform=transform_test, download=True)

# DataLoader 정의
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# 데이터셋 확인
print(f'Train dataset size: {len(train_dataset)}')
print(f'Validation dataset size: {len(test_dataset)}')

Files already downloaded and verified
Files already downloaded and verified
Train dataset size: 50000
Validation dataset size: 10000


In [3]:
model_name = 'resnet56_kd'

model = CustomResNet(block=ConvBlock,
                   layers=[9, 9, 9],
                   num_classes=10).to(device)
model.apply(initialize_weights)

total_params = sum(p.numel() for p in model.parameters())
print(f'Total number of parameters: {total_params}')

t_model = CustomResNet(block=BasicBlock,
                   layers=[9, 9, 9],
                   num_classes=10).to(device).eval()
t_model.load_state_dict(torch.load('./result/resnet56_baseline.pt')['model'])

Total number of parameters: 415546


<All keys matched successfully>

In [4]:
optimizer = SGD(params=model.parameters(), lr=0.1, nesterov=True, momentum=0.9, weight_decay=0.0001)
scheduler = lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=milestones, gamma=gamma, verbose=True)
criterion = nn.CrossEntropyLoss()

history = dict(train_loss=[], test_acc=[], train_time=[], test_time=[])
for ep in range(epoch):
    # train phase
    train_loss = 0.0
    model.train()
    s_time = time.time()
    for image, target in train_loader:
        image = image.to(device)
        target = f.one_hot(target, 10).float().to(device)

        pred = model(image)
        t_pred = t_model(image)
        distill_loss = f.kl_div(torch.softmax(pred / temp, dim=1), torch.softmax(t_pred / temp, dim=1))
        student_loss = criterion(pred, target)
        loss = alpha * distill_loss + (1-alpha) * student_loss
        train_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    e_time = time.time()
    history['train_loss'].append(train_loss/len(train_loader))
    history['train_time'].append(e_time - s_time)

    # test phase
    test_acc = 0.0

    model.eval()
    s_time = time.time()
    for image, target in test_loader:
        image = image.to(device)
        target = f.one_hot(target, 10).float().to(device)

        pred = model(image)
        test_acc += torch.sum(torch.argmax(pred, dim=1) == torch.argmax(target, dim=1)).item()
    e_time = time.time()
    history['test_acc'].append(test_acc/len(test_dataset))
    history['test_time'].append(e_time - s_time)
    print(f'epoch={ep}, train_loss={train_loss/len(train_loader):.3f}, test_acc={test_acc/len(test_dataset):.3f}')

    checkpoint = dict(
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
        history=history,
        epoch=ep
    )
    torch.save(checkpoint, f'./result/{model_name}.pt')
    
    scheduler.step()



epoch=0, train_loss=0.839, test_acc=0.383
epoch=1, train_loss=0.704, test_acc=0.460
epoch=2, train_loss=0.612, test_acc=0.487
epoch=3, train_loss=0.517, test_acc=0.552
epoch=4, train_loss=0.438, test_acc=0.631
epoch=5, train_loss=0.377, test_acc=0.686
epoch=6, train_loss=0.334, test_acc=0.686
epoch=7, train_loss=0.302, test_acc=0.690
epoch=8, train_loss=0.281, test_acc=0.700
epoch=9, train_loss=0.262, test_acc=0.758
epoch=10, train_loss=0.245, test_acc=0.746
epoch=11, train_loss=0.232, test_acc=0.784
epoch=12, train_loss=0.220, test_acc=0.785
epoch=13, train_loss=0.207, test_acc=0.747
epoch=14, train_loss=0.201, test_acc=0.769
epoch=15, train_loss=0.195, test_acc=0.789
epoch=16, train_loss=0.188, test_acc=0.746
epoch=17, train_loss=0.179, test_acc=0.793
epoch=18, train_loss=0.175, test_acc=0.765
epoch=19, train_loss=0.172, test_acc=0.798
epoch=20, train_loss=0.163, test_acc=0.769
epoch=21, train_loss=0.162, test_acc=0.747
epoch=22, train_loss=0.158, test_acc=0.730
epoch=23, train_loss=