In [1]:
import time
import random
import pathlib
from os.path import isfile
import copy
import sys

import numpy as np
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn

from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms

from resnet_mask import *
from utils import *

In [2]:
torch.cuda.is_available()

True

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

torch.manual_seed(777)
if device =='cuda':
    torch.cuda.manual_seed_all(777)

In [4]:
## args
layers = 56
prune_type = 'structured'
prune_rate = 0.9
prune_imp = 'L2'
epochs = 300
batch_size = 128
lr = 0.2
momentum = 0.9
wd = 1e-4

cfgs = {
    '18':  (BasicBlock, [2, 2, 2, 2]),
    '34':  (BasicBlock, [3, 4, 6, 3]),
    '50':  (Bottleneck, [3, 4, 6, 3]),
    '101': (Bottleneck, [3, 4, 23, 3]),
    '152': (Bottleneck, [3, 8, 36, 3]),
}
cfgs_cifar = {
    '20':  [3, 3, 3],
    '32':  [5, 5, 5],
    '44':  [7, 7, 7],
    '56':  [9, 9, 9],
    '110': [18, 18, 18],
}

In [5]:
train_data_mean = (0.5, 0.5, 0.5)
train_data_std = (0.5, 0.5, 0.5)

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(train_data_mean, train_data_std)
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(train_data_mean, train_data_std)
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True, num_workers=4)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=256, shuffle=False, num_workers=4)

classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [6]:
model = ResNet_CIFAR(BasicBlock, cfgs_cifar['56'], 10).to(device)
image_size = 32

In [7]:
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.SGD(model.parameters(), lr=lr,
                            momentum=momentum, weight_decay=wd) #nesterov=args.nesterov)
lr_sche = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

In [8]:
##### main 함수 보고 train 짜기
best_acc1 = 0.0

for epoch in range(epochs):
    
    acc1_train_cor, acc5_train_cor = train(trainloader, epoch=epoch, model=model, 
                                       criterion=criterion, optimizer=optimizer, reg=reg_cov, odecay=2)
    acc1_valid_cor, acc5_valid_cor = validate(testloader, epoch=epoch, model=model, criterion=criterion)
    
    acc1_train = round(acc1_train_cor.item(), 4)
    acc5_train = round(acc5_train_cor.item(), 4)
    acc1_valid = round(acc1_valid_cor.item(), 4)
    acc5_valid = round(acc5_valid_cor.item(), 4)

    # remember best Acc@1 and save checkpoint and summary csv file
#     summary = [epoch, acc1_train, acc5_train, acc1_valid, acc5_valid]

    is_best = acc1_valid > best_acc1
    best_acc1 = max(acc1_valid, best_acc1)
    if is_best:
        summary = [epoch, acc1_train, acc5_train, acc1_valid, acc5_valid]
#         save_model(arch_name, args.dataset, state, args.save)
#     save_summary(arch_name, args.dataset, args.save.split('.pth')[0], summary)

summary

train 0 ====> Acc@1 22.296 Acc@5 73.584
valid 0 ====> Acc@1 30.330 Acc@5 83.780
train 1 ====> Acc@1 38.758 Acc@5 89.150
valid 1 ====> Acc@1 43.130 Acc@5 91.770
train 2 ====> Acc@1 52.182 Acc@5 94.006
valid 2 ====> Acc@1 56.500 Acc@5 95.490
train 3 ====> Acc@1 62.226 Acc@5 96.368
valid 3 ====> Acc@1 60.310 Acc@5 96.110
train 4 ====> Acc@1 68.914 Acc@5 97.608
valid 4 ====> Acc@1 62.990 Acc@5 95.930
train 5 ====> Acc@1 73.496 Acc@5 98.144
valid 5 ====> Acc@1 68.550 Acc@5 96.470
train 6 ====> Acc@1 76.722 Acc@5 98.614
valid 6 ====> Acc@1 74.650 Acc@5 98.380
train 7 ====> Acc@1 79.262 Acc@5 98.970
valid 7 ====> Acc@1 76.580 Acc@5 98.220
train 8 ====> Acc@1 80.964 Acc@5 99.014
valid 8 ====> Acc@1 78.320 Acc@5 98.590
train 9 ====> Acc@1 82.358 Acc@5 99.174
valid 9 ====> Acc@1 72.310 Acc@5 98.380
train 10 ====> Acc@1 83.164 Acc@5 99.254
valid 10 ====> Acc@1 80.530 Acc@5 99.070
train 11 ====> Acc@1 84.740 Acc@5 99.384
valid 11 ====> Acc@1 79.450 Acc@5 98.740
train 12 ====> Acc@1 85.248 Acc@5 99

valid 100 ====> Acc@1 79.900 Acc@5 98.910
train 101 ====> Acc@1 95.872 Acc@5 99.970
valid 101 ====> Acc@1 84.900 Acc@5 98.900
train 102 ====> Acc@1 96.238 Acc@5 99.974
valid 102 ====> Acc@1 86.300 Acc@5 99.300
train 103 ====> Acc@1 96.138 Acc@5 99.984
valid 103 ====> Acc@1 87.400 Acc@5 99.390
train 104 ====> Acc@1 96.138 Acc@5 99.980
valid 104 ====> Acc@1 81.920 Acc@5 99.130
train 105 ====> Acc@1 95.940 Acc@5 99.978
valid 105 ====> Acc@1 86.040 Acc@5 99.380
train 106 ====> Acc@1 95.820 Acc@5 99.976
valid 106 ====> Acc@1 82.450 Acc@5 98.480
train 107 ====> Acc@1 95.896 Acc@5 99.974
valid 107 ====> Acc@1 85.770 Acc@5 99.110
train 108 ====> Acc@1 96.138 Acc@5 99.976
valid 108 ====> Acc@1 85.680 Acc@5 99.170
train 109 ====> Acc@1 95.802 Acc@5 99.972
valid 109 ====> Acc@1 81.070 Acc@5 98.850
train 110 ====> Acc@1 95.968 Acc@5 99.974
valid 110 ====> Acc@1 87.160 Acc@5 99.430
train 111 ====> Acc@1 96.042 Acc@5 99.980
valid 111 ====> Acc@1 84.730 Acc@5 99.310
train 112 ====> Acc@1 96.026 Acc@5

In [None]:
def save_model(arch_name, dataset, state, ckpt_name='ckpt_best.pth'):
    r"""Save the model (checkpoint) at the training time
    """
    dir_ckpt = pathlib.Path('checkpoint')
    dir_path = dir_ckpt / arch_name / dataset
    dir_path.mkdir(parents=True, exist_ok=True)

    if ckpt_name is None:
        ckpt_name = 'ckpt_best.pth'
    model_file = dir_path / ckpt_name
    torch.save(state, model_file)


def save_summary(arch_name, dataset, name, summary):
    r"""Save summary i.e. top-1/5 validation accuracy in each epoch
    under `summary` directory
    """
    dir_summary = pathlib.Path('summary')
    dir_path = dir_summary / 'csv'
    dir_path.mkdir(parents=True, exist_ok=True)

    file_name = '{}_{}_{}.csv'.format(arch_name, dataset, name)
    file_summ = dir_path / file_name

    if summary[0] == 0:
        with open(file_summ, 'w', newline='') as csv_out:
            writer = csv.writer(csv_out)
            header_list = ['Epoch', 'Acc@1_train', 'Acc@5_train', 'Acc@1_valid', 'Acc@5_valid']
            writer.writerow(header_list)
            writer.writerow(summary)
    else:
        file_temp = dir_path / 'temp.csv'
        shutil.copyfile(file_summ, file_temp)
        with open(file_temp, 'r', newline='') as csv_in:
            with open(file_summ, 'w', newline='') as csv_out:
                reader = csv.reader(csv_in)
                writer = csv.writer(csv_out)
                for row_list in reader:
                    writer.writerow(row_list)
                writer.writerow(summary)
        remove(file_temp)
