In [1]:
import os
import time

# NOTE(review): `plt` is bound to the top-level matplotlib package, not matplotlib.pyplot.
import matplotlib as plt
import torch
import torch.nn as nn
import torch.nn.functional as f
from torch.optim import SGD, lr_scheduler
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10, CIFAR100
from torchvision.models.resnet import BasicBlock, Bottleneck
from torchvision.transforms import ToTensor, Compose, RandomCrop, RandomHorizontalFlip, Normalize
from tqdm import tqdm

from model import CustomResNet

# Run identification: used to build the checkpoint path ./result/{dataset_name}_{student_name}.pt
dataset_name = 'cifar10'
student_name = 'resnet56_baseline'

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch's global RNG so data shuffling, random crop/flip augmentation and
# (presumably) weight initialisation are reproducible across runs.
# cuDNN kernels may still introduce some GPU nondeterminism.
seed = 0
torch.manual_seed(seed)

batch_size = 128
epoch = 180                  # total number of training epochs
gamma = 0.2                  # MultiStepLR decay factor applied at each milestone
milestones = [60, 120, 150]  # epochs at which the learning rate is multiplied by `gamma`

In [2]:
# Per-channel mean/std passed to Normalize; shared by the train and test
# pipelines so the two cannot drift apart.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: tensor conversion, 4-pixel padded random 32x32 crop,
# random horizontal flip (p=0.5), then normalisation.
transform = Compose([
    ToTensor(),
    RandomCrop(size=[32, 32], padding=4),
    RandomHorizontalFlip(p=0.5),
    Normalize(mean=cifar10_mean, std=cifar10_std)
])

# Evaluation pipeline: no augmentation, same normalisation as training.
transform_test = Compose([
    ToTensor(),
    Normalize(mean=cifar10_mean, std=cifar10_std)
])

train_dataset = CIFAR10(root='./data', train=True, transform=transform, download=True)
test_dataset = CIFAR10(root='./data', train=False, transform=transform_test, download=True)

# DataLoader definitions -- use the configured `batch_size` (instead of a
# hard-coded 128) so the config cell actually controls the loaders.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Sanity-check dataset sizes
print(f'Train dataset size: {len(train_dataset)}')
print(f'Validation dataset size: {len(test_dataset)}')

Files already downloaded and verified
Files already downloaded and verified
Train dataset size: 50000
Validation dataset size: 10000


In [3]:
# Baseline student network.
# NOTE(review): `layers=[9, 9, 9]` presumably means 9 BasicBlocks in each of
# three stages (6*9 + 2 = 56 layers, matching `student_name`) -- confirm
# against model.CustomResNet.
model = CustomResNet(
    block=BasicBlock,
    layers=[9, 9, 9],
    num_classes=10,
).to(device)

# Element count over every parameter tensor (trainable or not).
total_params = sum(map(torch.Tensor.numel, model.parameters()))
print(f'Total number of parameters: {total_params}')
print(model)

Total number of parameters: 855770
CustomResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layers): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3),

In [4]:
# Loss, optimiser and LR schedule for the baseline run.
criterion = nn.CrossEntropyLoss()

# SGD with Nesterov momentum and L2 weight decay, starting at lr=0.1.
optimizer = SGD(
    params=model.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=0.0001,
    nesterov=True,
)

# Multiply the learning rate by `gamma` at each epoch listed in `milestones`
# (the training loop calls scheduler.step() once per epoch).
scheduler = lr_scheduler.MultiStepLR(
    optimizer=optimizer,
    milestones=milestones,
    gamma=gamma,
)

In [5]:
def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over `loader`; return the mean per-batch loss.

    Targets are passed to `criterion` as integer class indices. For
    nn.CrossEntropyLoss this yields the same loss as one-hot float targets,
    without hard-coding the number of classes.
    """
    model.train()
    running_loss = 0.0
    for image, target in loader:
        image, target = image.to(device), target.to(device)

        pred = model(image)
        loss = criterion(pred, target)
        running_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return running_loss / len(loader)


@torch.no_grad()  # evaluation needs no autograd graph; saves memory and time
def _evaluate(model, loader, device):
    """Return the top-1 accuracy of `model` over every sample in `loader`."""
    model.eval()
    correct = 0
    for image, target in loader:
        image, target = image.to(device), target.to(device)
        pred = model(image)
        correct += (pred.argmax(dim=1) == target).sum().item()
    return correct / len(loader.dataset)


# torch.save does not create missing directories; without this the first
# checkpoint write would fail after a full epoch of training.
os.makedirs('./result', exist_ok=True)

history = dict(train_loss=[], test_acc=[], train_time=[], test_time=[])
for ep in range(epoch):
    # train step
    s_time = time.time()
    train_loss = _train_one_epoch(model, train_loader, optimizer, criterion, device)
    history['train_loss'].append(train_loss)
    history['train_time'].append(time.time() - s_time)

    # test step
    s_time = time.time()
    test_acc = _evaluate(model, test_loader, device)
    history['test_acc'].append(test_acc)
    history['test_time'].append(time.time() - s_time)
    print(f'epoch={ep:3d}, train_loss={train_loss:.4f}, test_acc={test_acc:.3f}')

    # Overwritten every epoch; includes the scheduler state so training can be resumed.
    checkpoint = dict(
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
        scheduler=scheduler.state_dict(),
        history=history,
        epoch=ep
    )
    torch.save(checkpoint, f'./result/{dataset_name}_{student_name}.pt')
    scheduler.step()

epoch=  0, train_loss=2.1335, test_acc=0.306
epoch=  1, train_loss=1.7103, test_acc=0.412
epoch=  2, train_loss=1.4934, test_acc=0.485
epoch=  3, train_loss=1.2722, test_acc=0.563
epoch=  4, train_loss=1.0840, test_acc=0.630
epoch=  5, train_loss=0.9232, test_acc=0.692
epoch=  6, train_loss=0.7970, test_acc=0.644
epoch=  7, train_loss=0.7057, test_acc=0.753
epoch=  8, train_loss=0.6418, test_acc=0.763
epoch=  9, train_loss=0.5888, test_acc=0.766
epoch= 10, train_loss=0.5534, test_acc=0.793
epoch= 11, train_loss=0.5189, test_acc=0.775
epoch= 12, train_loss=0.4930, test_acc=0.766
epoch= 13, train_loss=0.4690, test_acc=0.777
epoch= 14, train_loss=0.4444, test_acc=0.825
epoch= 15, train_loss=0.4287, test_acc=0.830
epoch= 16, train_loss=0.4172, test_acc=0.828
epoch= 17, train_loss=0.3969, test_acc=0.826
epoch= 18, train_loss=0.3802, test_acc=0.839
epoch= 19, train_loss=0.3707, test_acc=0.853
epoch= 20, train_loss=0.3616, test_acc=0.843
epoch= 21, train_loss=0.3476, test_acc=0.824
epoch= 22,

In [38]:
# Total wall-clock seconds (train + test, summed over all epochs) of the baseline run.
# map_location='cpu' lets this load on a machine without a GPU: the saved
# model/optimizer state may hold CUDA tensors, since the model was moved to `device`.
checkpoint = torch.load('./result/cifar10_resnet56_baseline.pt', map_location='cpu')
sum(checkpoint['history']['train_time']) + sum(checkpoint['history']['test_time'])

4625.111322402954

In [42]:
# (train seconds, test seconds) summed over all epochs of the recast run.
# NOTE(review): 'history' is indexed with [0] here, unlike the baseline's single
# dict -- presumably a collection of per-stage histories; confirm against the
# code that produced this checkpoint.
# map_location='cpu' avoids requiring a GPU to read a checkpoint saved from CUDA tensors.
checkpoint = torch.load('./result/cifar10_resnet56_recast.pt', map_location='cpu')
sum(checkpoint['history'][0]['train_time']), sum(checkpoint['history'][0]['test_time'])

(1033.6733675003052, 185.7128963470459)

In [28]:
from model import ConvBlock
student_name = 'resnet56_recast'

model = CustomResNet(block=BasicBlock,
                     layers=[9, 9, 9],
                     num_classes=10)
# Swap the first residual block BEFORE moving to `device` and switching to eval
# mode. A module assigned after .to()/.eval() keeps its default CPU placement
# and training mode, which previously left layers[0] on the CPU and in train
# mode while the rest of the network was on `device` and in eval mode.
# NOTE(review): meaning of ConvBlock's positional args (16, 16, 2) is defined
# in model.ConvBlock -- confirm.
model.layers[0] = ConvBlock(16, 16, 2)
model = model.to(device).eval()
state = model.state_dict()

In [30]:
# Display the parameter and buffer names (weights, BN running stats, etc.) in the
# recast model's state_dict, e.g. to check which keys the swapped-in layers.0 exposes.
state.keys()

odict_keys(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'layers.0.conv1.weight', 'layers.0.bn1.weight', 'layers.0.bn1.bias', 'layers.0.bn1.running_mean', 'layers.0.bn1.running_var', 'layers.0.bn1.num_batches_tracked', 'layers.1.conv1.weight', 'layers.1.bn1.weight', 'layers.1.bn1.bias', 'layers.1.bn1.running_mean', 'layers.1.bn1.running_var', 'layers.1.bn1.num_batches_tracked', 'layers.1.conv2.weight', 'layers.1.bn2.weight', 'layers.1.bn2.bias', 'layers.1.bn2.running_mean', 'layers.1.bn2.running_var', 'layers.1.bn2.num_batches_tracked', 'layers.2.conv1.weight', 'layers.2.bn1.weight', 'layers.2.bn1.bias', 'layers.2.bn1.running_mean', 'layers.2.bn1.running_var', 'layers.2.bn1.num_batches_tracked', 'layers.2.conv2.weight', 'layers.2.bn2.weight', 'layers.2.bn2.bias', 'layers.2.bn2.running_mean', 'layers.2.bn2.running_var', 'layers.2.bn2.num_batches_tracked', 'layers.3.conv1.weight', 'layers.3.bn1.weight', 'layers.3.bn1.bias', '