In [1]:
import time
import random
import pathlib
from os.path import isfile
import copy
import sys

import numpy as np
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn

from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms

from resnet_mask import *
from utils import *

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

torch.manual_seed(777)
if device =='cuda':
    torch.cuda.manual_seed_all(777)

In [3]:
## args
layers = 56
prune_type = 'structured'
prune_rate = 0.9
prune_imp = 'L2'
epochs = 300
batch_size = 128
lr = 0.2
momentum = 0.9
wd = 1e-4

cfgs = {
    '18':  (BasicBlock, [2, 2, 2, 2]),
    '34':  (BasicBlock, [3, 4, 6, 3]),
    '50':  (Bottleneck, [3, 4, 6, 3]),
    '101': (Bottleneck, [3, 4, 23, 3]),
    '152': (Bottleneck, [3, 8, 36, 3]),
}
cfgs_cifar = {
    '20':  [3, 3, 3],
    '32':  [5, 5, 5],
    '44':  [7, 7, 7],
    '56':  [9, 9, 9],
    '110': [18, 18, 18],
}

In [4]:
train_data_mean = (0.5, 0.5, 0.5)
train_data_std = (0.5, 0.5, 0.5)

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(train_data_mean, train_data_std)
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(train_data_mean, train_data_std)
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True, num_workers=4)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=256, shuffle=False, num_workers=4)

classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [5]:
model = ResNet_CIFAR(BasicBlock, cfgs_cifar['56'], 10).to(device)
image_size = 32

In [6]:
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.SGD(model.parameters(), lr=lr,
                            momentum=momentum, weight_decay=wd) #nesterov=args.nesterov)
lr_sche = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

In [7]:
##### main 함수 보고 train 짜기
best_acc1 = 0.0

for epoch in range(epochs):
    
    acc1_train_cor, acc5_train_cor = train(trainloader, epoch=epoch, model=model, 
                                           prune={'type':'structured','rate':0.5}, reg=reg_cov, odecay=2,
                                           criterion=criterion, optimizer=optimizer, device=device)
    acc1_valid_cor, acc5_valid_cor = validate(testloader, epoch=epoch, model=model, 
                                              criterion=criterion, device=device)
    
    acc1_train = round(acc1_train_cor.item(), 4)
    acc5_train = round(acc5_train_cor.item(), 4)
    acc1_valid = round(acc1_valid_cor.item(), 4)
    acc5_valid = round(acc5_valid_cor.item(), 4)

    # remember best Acc@1 and save checkpoint and summary csv file
#     summary = [epoch, acc1_train, acc5_train, acc1_valid, acc5_valid]

    is_best = acc1_valid > best_acc1
    best_acc1 = max(acc1_valid, best_acc1)
    if is_best:
        summary = [epoch, acc1_train, acc5_train, acc1_valid, acc5_valid]
#         save_model(arch_name, args.dataset, state, args.save)
#     save_summary(arch_name, args.dataset, args.save.split('.pth')[0], summary)

summary

train 0 ====> Acc@1 17.148 Acc@5 66.014
valid 0 ====> Acc@1 16.760 Acc@5 69.480
train 1 ====> Acc@1 33.094 Acc@5 86.082
valid 1 ====> Acc@1 41.350 Acc@5 90.360
train 2 ====> Acc@1 44.942 Acc@5 91.710
valid 2 ====> Acc@1 48.860 Acc@5 94.280
train 3 ====> Acc@1 56.512 Acc@5 94.978
valid 3 ====> Acc@1 61.160 Acc@5 96.450
train 4 ====> Acc@1 66.582 Acc@5 97.204
valid 4 ====> Acc@1 47.070 Acc@5 89.340
train 5 ====> Acc@1 72.636 Acc@5 98.002
valid 5 ====> Acc@1 67.500 Acc@5 97.510
train 6 ====> Acc@1 76.064 Acc@5 98.488
valid 6 ====> Acc@1 70.480 Acc@5 98.010
train 7 ====> Acc@1 79.022 Acc@5 98.820
valid 7 ====> Acc@1 70.120 Acc@5 98.220
train 8 ====> Acc@1 80.780 Acc@5 98.984
valid 8 ====> Acc@1 74.270 Acc@5 97.540
train 9 ====> Acc@1 82.676 Acc@5 99.232
valid 9 ====> Acc@1 77.310 Acc@5 98.360
train 10 ====> Acc@1 83.562 Acc@5 99.230
valid 10 ====> Acc@1 78.530 Acc@5 98.330
train 11 ====> Acc@1 84.676 Acc@5 99.352
valid 11 ====> Acc@1 77.830 Acc@5 98.650
train 12 ====> Acc@1 85.672 Acc@5 99

valid 100 ====> Acc@1 84.660 Acc@5 98.940
train 101 ====> Acc@1 95.786 Acc@5 99.988
valid 101 ====> Acc@1 85.950 Acc@5 99.390
train 102 ====> Acc@1 96.202 Acc@5 99.980
valid 102 ====> Acc@1 86.980 Acc@5 99.590
train 103 ====> Acc@1 96.124 Acc@5 99.972
valid 103 ====> Acc@1 80.510 Acc@5 98.650
train 104 ====> Acc@1 95.704 Acc@5 99.972
valid 104 ====> Acc@1 82.360 Acc@5 99.230
train 105 ====> Acc@1 96.050 Acc@5 99.976
valid 105 ====> Acc@1 84.150 Acc@5 99.320
train 106 ====> Acc@1 95.732 Acc@5 99.984
valid 106 ====> Acc@1 83.530 Acc@5 99.140
train 107 ====> Acc@1 96.038 Acc@5 99.978
valid 107 ====> Acc@1 81.510 Acc@5 98.990
train 108 ====> Acc@1 96.148 Acc@5 99.986
valid 108 ====> Acc@1 86.030 Acc@5 99.300
train 109 ====> Acc@1 96.032 Acc@5 99.976
valid 109 ====> Acc@1 87.160 Acc@5 99.450
train 110 ====> Acc@1 96.106 Acc@5 99.972
valid 110 ====> Acc@1 85.330 Acc@5 99.290
train 111 ====> Acc@1 96.176 Acc@5 99.980
valid 111 ====> Acc@1 86.310 Acc@5 98.860
train 112 ====> Acc@1 95.860 Acc@5

In [None]:
# def filter_prune(model, filter_mask):
#     idx = 0
#     for name, item in model.named_parameters():  
#         #.module.named_parameters():
#         if len(item.size())==4 and 'mask' in name:
#             for i in range(item.size(0)):
#                 item.data[i,:,:,:] = 1 if filter_mask[idx] else 0
#                 idx += 1

In [None]:
# losses = AverageMeter('Loss', ':.4e')
# top1 = AverageMeter('Acc@1', ':6.2f')
# top5 = AverageMeter('Acc@5', ':6.2f')

# prune_freq = 16
# milestones = [150, 225]

# model.train()

# for i, (inputs, targets) in enumerate(trainloader):
#     inputs = inputs.to(device)
#     targets = targets.to(device)

#     ### filter_mask = get_filter_mask(model, 0.5)   
#     ## 여기에서 64가 나와서 문제였음 (2128이 나와야할거 같은데..?)
#     # if (globals()['iterations']+1) % args.prune_freq==0 and (epoch+1) <= args.milestones[1]:
#     # 위 부분을 생략하고 진행해서 그런 것 같음
#     #and epoch+1 <= milestones[1]:
#     #if (i+1) % prune_freq == 0: print('prune start')
        
#     filter_mask = get_filter_mask(model, 0.5)
#     if len(filter_mask) == 2128:
#         print('prune start')
#         ### filter_prune(model, filter_mask)
#         idx = 0
#         for name, item in model.named_parameters():
#             if len(item.size())==4 and 'mask' in name:
#                 for j in range(item.size(0)):
#                     item.data[j,:,:,:] = 1 if filter_mask[idx] else 0
#                     idx += 1
#     if len(filter_mask) != 64:
#         print(len(filter_mask))
#     # else: print('no pruning since', i+1 % prune_freq)

#     outputs = model(inputs)
# #     if reg:
# #         oloss = reg(model)
# #         oloss = odecay * oloss
# #         loss = criterion(outputs, targets) + oloss
# #     else:
#     loss = criterion(outputs, targets)

#     acc1, acc5 = accuracy(outputs, targets, topk=(1,5))
#     losses.update(loss.item(), inputs.size(0))
#     top1.update(acc1[0], inputs.size(0))
#     top5.update(acc5[0], inputs.size(0))

#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()
    
#     if i % 20 == 0: print(i, top1.avg, top5.avg)

In [None]:
# for i in range(10):
#     print(globals()['iterations'])

### 필터 프루닝 함수 확인

In [None]:
# importance_all = None
# for name, item in model.named_parameters():
#     if len(item.size()) == 4 and 'weight' in name:
#         filters = item.data.view(item.size(0), -1).cpu()
#         weight_len = filters.size(1)
#         importance = filters.abs().sum(dim=1).numpy() / weight_len
        
#         if importance_all is None:
#             importance_all = importance
#         else:
#             importance_all = np.append(importance_all, importance)
            
# threshold = np.sort(importance_all)[int(len(importance_all) * 0.5)]
# filter_mask = np.greater(importance_all, threshold)
# print(filter_mask)

In [None]:
# idx = 0
# for name, item in model.named_parameters():  
#     #.module.named_parameters():
#     if len(item.size())==4 and 'mask' in name:
#         for i in range(item.size(0)):
#             item.data[i,:,:,:] = 1 if filter_mask[idx] else 0
#             idx += 1
#             print(item.data[i,:,:,:])
        

In [None]:
# len(filter_mask)