In [1]:
import time
import random
import pathlib
from os.path import isfile
import copy
import sys

import numpy as np
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn

from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms

from resnet_mask import *
from utils import *

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

torch.manual_seed(777)
if device =='cuda':
    torch.cuda.manual_seed_all(777)

In [3]:
## args
layers = 56
prune_type = 'structured'
prune_rate = 0.9
prune_imp = 'L2'
epochs = 300
batch_size = 128
lr = 0.2
momentum = 0.9
wd = 1e-4

cfgs = {
    '18':  (BasicBlock, [2, 2, 2, 2]),
    '34':  (BasicBlock, [3, 4, 6, 3]),
    '50':  (Bottleneck, [3, 4, 6, 3]),
    '101': (Bottleneck, [3, 4, 23, 3]),
    '152': (Bottleneck, [3, 8, 36, 3]),
}
cfgs_cifar = {
    '20':  [3, 3, 3],
    '32':  [5, 5, 5],
    '44':  [7, 7, 7],
    '56':  [9, 9, 9],
    '110': [18, 18, 18],
}

In [4]:
train_data_mean = (0.5, 0.5, 0.5)
train_data_std = (0.5, 0.5, 0.5)

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(train_data_mean, train_data_std)
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(train_data_mean, train_data_std)
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True, num_workers=4)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=256, shuffle=False, num_workers=4)

classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/cifar-10-python.tar.gz to ./data


Files already downloaded and verified


In [5]:
model = ResNet_CIFAR(BasicBlock, cfgs_cifar['56'], 10).to(device)
image_size = 32

In [6]:
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.SGD(model.parameters(), lr=lr,
                            momentum=momentum, weight_decay=wd) #nesterov=args.nesterov)
lr_sche = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

In [7]:
##### main 함수 보고 train 짜기
best_acc1 = 0.0

for epoch in range(epochs):
    
    acc1_train_cor, acc5_train_cor = train(trainloader, epoch=epoch, model=model, 
                                           prune={'type':'structured','rate':0.5}, reg=reg_cov, odecay=2,
                                           criterion=criterion, optimizer=optimizer, device=device)
    acc1_valid_cor, acc5_valid_cor = validate(testloader, epoch=epoch, model=model, 
                                              criterion=criterion, device=device)
    
    acc1_train = round(acc1_train_cor.item(), 4)
    acc5_train = round(acc5_train_cor.item(), 4)
    acc1_valid = round(acc1_valid_cor.item(), 4)
    acc5_valid = round(acc5_valid_cor.item(), 4)

    # remember best Acc@1 and save checkpoint and summary csv file
#     summary = [epoch, acc1_train, acc5_train, acc1_valid, acc5_valid]

    is_best = acc1_valid > best_acc1
    best_acc1 = max(acc1_valid, best_acc1)
    if is_best:
        summary = [epoch, acc1_train, acc5_train, acc1_valid, acc5_valid]
#         save_model(arch_name, args.dataset, state, args.save)
#     save_summary(arch_name, args.dataset, args.save.split('.pth')[0], summary)

summary

train 0 ====> Acc@1 17.148 Acc@5 66.014


valid 0 ====> Acc@1 16.760 Acc@5 69.480


train 1 ====> Acc@1 33.094 Acc@5 86.082


valid 1 ====> Acc@1 41.350 Acc@5 90.360


train 2 ====> Acc@1 44.942 Acc@5 91.710


valid 2 ====> Acc@1 48.860 Acc@5 94.280


train 3 ====> Acc@1 56.512 Acc@5 94.978


valid 3 ====> Acc@1 61.160 Acc@5 96.450


train 4 ====> Acc@1 66.582 Acc@5 97.204


valid 4 ====> Acc@1 47.070 Acc@5 89.340


train 5 ====> Acc@1 72.636 Acc@5 98.002


valid 5 ====> Acc@1 67.500 Acc@5 97.510


train 6 ====> Acc@1 76.064 Acc@5 98.488


valid 6 ====> Acc@1 70.480 Acc@5 98.010


train 7 ====> Acc@1 79.022 Acc@5 98.820


valid 7 ====> Acc@1 70.120 Acc@5 98.220


train 8 ====> Acc@1 80.780 Acc@5 98.984


valid 8 ====> Acc@1 74.270 Acc@5 97.540


train 9 ====> Acc@1 82.676 Acc@5 99.232


valid 9 ====> Acc@1 77.310 Acc@5 98.360


train 10 ====> Acc@1 83.562 Acc@5 99.230


valid 10 ====> Acc@1 78.530 Acc@5 98.330


train 11 ====> Acc@1 84.676 Acc@5 99.352


valid 11 ====> Acc@1 77.830 Acc@5 98.650


train 12 ====> Acc@1 85.672 Acc@5 99.432


valid 12 ====> Acc@1 78.770 Acc@5 98.810


train 13 ====> Acc@1 86.406 Acc@5 99.516


valid 13 ====> Acc@1 81.040 Acc@5 99.100


train 14 ====> Acc@1 87.082 Acc@5 99.578


valid 14 ====> Acc@1 82.030 Acc@5 99.070


train 15 ====> Acc@1 87.758 Acc@5 99.638


valid 15 ====> Acc@1 81.150 Acc@5 99.020


train 16 ====> Acc@1 88.168 Acc@5 99.642


valid 16 ====> Acc@1 79.520 Acc@5 98.810


train 17 ====> Acc@1 88.780 Acc@5 99.710


valid 17 ====> Acc@1 81.790 Acc@5 98.960


train 18 ====> Acc@1 89.314 Acc@5 99.724


valid 18 ====> Acc@1 79.610 Acc@5 98.850


train 19 ====> Acc@1 89.876 Acc@5 99.724


valid 19 ====> Acc@1 81.280 Acc@5 98.690


train 20 ====> Acc@1 90.332 Acc@5 99.776


valid 20 ====> Acc@1 82.410 Acc@5 98.780


train 21 ====> Acc@1 90.382 Acc@5 99.762


valid 21 ====> Acc@1 81.630 Acc@5 99.090


train 22 ====> Acc@1 90.674 Acc@5 99.802


valid 22 ====> Acc@1 77.080 Acc@5 98.180


train 23 ====> Acc@1 91.024 Acc@5 99.810


valid 23 ====> Acc@1 80.050 Acc@5 98.640


train 24 ====> Acc@1 91.578 Acc@5 99.838


valid 24 ====> Acc@1 82.200 Acc@5 98.580


train 25 ====> Acc@1 91.772 Acc@5 99.846


valid 25 ====> Acc@1 84.290 Acc@5 99.270


train 26 ====> Acc@1 91.818 Acc@5 99.842


valid 26 ====> Acc@1 81.680 Acc@5 99.150


train 27 ====> Acc@1 91.826 Acc@5 99.850


valid 27 ====> Acc@1 82.260 Acc@5 99.020


train 28 ====> Acc@1 92.406 Acc@5 99.902


valid 28 ====> Acc@1 78.220 Acc@5 98.620


train 29 ====> Acc@1 92.518 Acc@5 99.886


valid 29 ====> Acc@1 77.580 Acc@5 98.630


train 30 ====> Acc@1 92.574 Acc@5 99.878


valid 30 ====> Acc@1 85.670 Acc@5 99.440


train 31 ====> Acc@1 92.876 Acc@5 99.902


valid 31 ====> Acc@1 82.320 Acc@5 99.230


train 32 ====> Acc@1 92.766 Acc@5 99.878


valid 32 ====> Acc@1 84.790 Acc@5 99.380


train 33 ====> Acc@1 92.892 Acc@5 99.932


valid 33 ====> Acc@1 83.730 Acc@5 99.410


train 34 ====> Acc@1 93.390 Acc@5 99.928


valid 34 ====> Acc@1 79.700 Acc@5 99.020


train 35 ====> Acc@1 93.278 Acc@5 99.912


valid 35 ====> Acc@1 81.110 Acc@5 99.130


train 36 ====> Acc@1 93.724 Acc@5 99.942


valid 36 ====> Acc@1 85.550 Acc@5 99.270


train 37 ====> Acc@1 93.784 Acc@5 99.912


valid 37 ====> Acc@1 82.850 Acc@5 99.190


train 38 ====> Acc@1 93.824 Acc@5 99.922


valid 38 ====> Acc@1 83.640 Acc@5 99.210


train 39 ====> Acc@1 93.804 Acc@5 99.936


valid 39 ====> Acc@1 82.620 Acc@5 99.220


train 40 ====> Acc@1 93.826 Acc@5 99.928


valid 40 ====> Acc@1 81.350 Acc@5 98.820


train 41 ====> Acc@1 94.070 Acc@5 99.942


valid 41 ====> Acc@1 84.470 Acc@5 99.150


train 42 ====> Acc@1 93.916 Acc@5 99.930


valid 42 ====> Acc@1 82.590 Acc@5 98.540


train 43 ====> Acc@1 94.002 Acc@5 99.942


valid 43 ====> Acc@1 84.120 Acc@5 99.290


train 44 ====> Acc@1 94.170 Acc@5 99.948


valid 44 ====> Acc@1 85.200 Acc@5 99.200


train 45 ====> Acc@1 94.742 Acc@5 99.956


valid 45 ====> Acc@1 83.630 Acc@5 99.160


train 46 ====> Acc@1 94.112 Acc@5 99.926


valid 46 ====> Acc@1 83.040 Acc@5 98.900


train 47 ====> Acc@1 94.684 Acc@5 99.960


valid 47 ====> Acc@1 80.970 Acc@5 99.100


train 48 ====> Acc@1 94.206 Acc@5 99.958


valid 48 ====> Acc@1 85.150 Acc@5 99.340


train 49 ====> Acc@1 94.590 Acc@5 99.950


valid 49 ====> Acc@1 80.160 Acc@5 98.960


train 50 ====> Acc@1 94.478 Acc@5 99.972


valid 50 ====> Acc@1 84.860 Acc@5 99.190


train 51 ====> Acc@1 94.634 Acc@5 99.956


valid 51 ====> Acc@1 83.270 Acc@5 99.010


train 52 ====> Acc@1 94.620 Acc@5 99.950


valid 52 ====> Acc@1 77.950 Acc@5 98.700


train 53 ====> Acc@1 94.862 Acc@5 99.954


valid 53 ====> Acc@1 83.950 Acc@5 99.150


train 54 ====> Acc@1 95.000 Acc@5 99.974


valid 54 ====> Acc@1 82.640 Acc@5 98.790


train 55 ====> Acc@1 94.878 Acc@5 99.962


valid 55 ====> Acc@1 83.240 Acc@5 99.150


train 56 ====> Acc@1 95.194 Acc@5 99.972


valid 56 ====> Acc@1 85.570 Acc@5 99.430


train 57 ====> Acc@1 94.820 Acc@5 99.962


valid 57 ====> Acc@1 85.330 Acc@5 99.110


train 58 ====> Acc@1 95.228 Acc@5 99.964


valid 58 ====> Acc@1 84.050 Acc@5 99.000


train 59 ====> Acc@1 95.132 Acc@5 99.966


valid 59 ====> Acc@1 84.160 Acc@5 99.010


train 60 ====> Acc@1 94.710 Acc@5 99.950


valid 60 ====> Acc@1 84.420 Acc@5 99.320


train 61 ====> Acc@1 94.980 Acc@5 99.968


valid 61 ====> Acc@1 84.870 Acc@5 99.250


train 62 ====> Acc@1 95.234 Acc@5 99.974


valid 62 ====> Acc@1 84.140 Acc@5 99.410


train 63 ====> Acc@1 95.068 Acc@5 99.968


valid 63 ====> Acc@1 84.850 Acc@5 99.160


train 64 ====> Acc@1 95.424 Acc@5 99.960


valid 64 ====> Acc@1 84.060 Acc@5 99.390


train 65 ====> Acc@1 95.244 Acc@5 99.974


valid 65 ====> Acc@1 83.310 Acc@5 98.910


train 66 ====> Acc@1 95.396 Acc@5 99.958


valid 66 ====> Acc@1 85.950 Acc@5 99.410


train 67 ====> Acc@1 95.436 Acc@5 99.966


valid 67 ====> Acc@1 82.150 Acc@5 98.900


train 68 ====> Acc@1 95.394 Acc@5 99.970


valid 68 ====> Acc@1 80.650 Acc@5 98.440


train 69 ====> Acc@1 95.308 Acc@5 99.968


valid 69 ====> Acc@1 84.040 Acc@5 99.250


train 70 ====> Acc@1 95.700 Acc@5 99.980


valid 70 ====> Acc@1 84.210 Acc@5 99.030


train 71 ====> Acc@1 95.296 Acc@5 99.970


valid 71 ====> Acc@1 83.460 Acc@5 99.250


train 72 ====> Acc@1 95.614 Acc@5 99.976


valid 72 ====> Acc@1 83.870 Acc@5 99.310


train 73 ====> Acc@1 95.448 Acc@5 99.968


valid 73 ====> Acc@1 83.510 Acc@5 99.160


train 74 ====> Acc@1 95.838 Acc@5 99.970


valid 74 ====> Acc@1 84.320 Acc@5 99.330


train 75 ====> Acc@1 95.784 Acc@5 99.978


valid 75 ====> Acc@1 82.340 Acc@5 99.230


train 76 ====> Acc@1 95.512 Acc@5 99.976


valid 76 ====> Acc@1 85.500 Acc@5 99.310


train 77 ====> Acc@1 95.602 Acc@5 99.968


valid 77 ====> Acc@1 86.060 Acc@5 99.280


train 78 ====> Acc@1 95.626 Acc@5 99.974


valid 78 ====> Acc@1 82.350 Acc@5 99.270


train 79 ====> Acc@1 95.682 Acc@5 99.980


valid 79 ====> Acc@1 86.090 Acc@5 99.300


train 80 ====> Acc@1 95.458 Acc@5 99.974


valid 80 ====> Acc@1 81.490 Acc@5 99.260


train 81 ====> Acc@1 95.374 Acc@5 99.968


valid 81 ====> Acc@1 84.940 Acc@5 99.300


train 82 ====> Acc@1 95.418 Acc@5 99.978


valid 82 ====> Acc@1 80.100 Acc@5 98.360


train 83 ====> Acc@1 95.586 Acc@5 99.976


valid 83 ====> Acc@1 85.990 Acc@5 99.370


train 84 ====> Acc@1 95.578 Acc@5 99.982


valid 84 ====> Acc@1 84.650 Acc@5 99.390


train 85 ====> Acc@1 95.450 Acc@5 99.966


valid 85 ====> Acc@1 83.690 Acc@5 99.280


train 86 ====> Acc@1 95.704 Acc@5 99.982


valid 86 ====> Acc@1 82.990 Acc@5 99.310


train 87 ====> Acc@1 95.820 Acc@5 99.968


valid 87 ====> Acc@1 85.780 Acc@5 99.190


train 88 ====> Acc@1 95.796 Acc@5 99.964


valid 88 ====> Acc@1 83.930 Acc@5 99.020


train 89 ====> Acc@1 95.518 Acc@5 99.968


valid 89 ====> Acc@1 84.600 Acc@5 99.170


train 90 ====> Acc@1 95.898 Acc@5 99.982


valid 90 ====> Acc@1 83.770 Acc@5 98.950


train 91 ====> Acc@1 95.642 Acc@5 99.978


valid 91 ====> Acc@1 85.140 Acc@5 99.290


train 92 ====> Acc@1 95.816 Acc@5 99.980


valid 92 ====> Acc@1 84.200 Acc@5 99.120


train 93 ====> Acc@1 95.944 Acc@5 99.970


valid 93 ====> Acc@1 84.620 Acc@5 99.200


train 94 ====> Acc@1 96.052 Acc@5 99.992


valid 94 ====> Acc@1 84.190 Acc@5 99.380


train 95 ====> Acc@1 96.004 Acc@5 99.988


valid 95 ====> Acc@1 85.040 Acc@5 99.210


train 96 ====> Acc@1 95.984 Acc@5 99.980


valid 96 ====> Acc@1 84.380 Acc@5 99.280


train 97 ====> Acc@1 95.980 Acc@5 99.982


valid 97 ====> Acc@1 82.970 Acc@5 98.690


train 98 ====> Acc@1 95.854 Acc@5 99.974


valid 98 ====> Acc@1 85.240 Acc@5 99.330


train 99 ====> Acc@1 96.268 Acc@5 99.986


valid 99 ====> Acc@1 84.140 Acc@5 99.460


train 100 ====> Acc@1 95.684 Acc@5 99.974


valid 100 ====> Acc@1 84.660 Acc@5 98.940


train 101 ====> Acc@1 95.786 Acc@5 99.988


valid 101 ====> Acc@1 85.950 Acc@5 99.390


train 102 ====> Acc@1 96.202 Acc@5 99.980


valid 102 ====> Acc@1 86.980 Acc@5 99.590


train 103 ====> Acc@1 96.124 Acc@5 99.972


valid 103 ====> Acc@1 80.510 Acc@5 98.650


train 104 ====> Acc@1 95.704 Acc@5 99.972


valid 104 ====> Acc@1 82.360 Acc@5 99.230


train 105 ====> Acc@1 96.050 Acc@5 99.976


valid 105 ====> Acc@1 84.150 Acc@5 99.320


train 106 ====> Acc@1 95.732 Acc@5 99.984


valid 106 ====> Acc@1 83.530 Acc@5 99.140


train 107 ====> Acc@1 96.038 Acc@5 99.978


valid 107 ====> Acc@1 81.510 Acc@5 98.990


train 108 ====> Acc@1 96.148 Acc@5 99.986


valid 108 ====> Acc@1 86.030 Acc@5 99.300


train 109 ====> Acc@1 96.032 Acc@5 99.976


valid 109 ====> Acc@1 87.160 Acc@5 99.450


train 110 ====> Acc@1 96.106 Acc@5 99.972


valid 110 ====> Acc@1 85.330 Acc@5 99.290


train 111 ====> Acc@1 96.176 Acc@5 99.980


valid 111 ====> Acc@1 86.310 Acc@5 98.860


train 112 ====> Acc@1 95.860 Acc@5 99.986


valid 112 ====> Acc@1 81.040 Acc@5 98.800


train 113 ====> Acc@1 96.086 Acc@5 99.982


valid 113 ====> Acc@1 84.320 Acc@5 99.230


train 114 ====> Acc@1 95.934 Acc@5 99.986


valid 114 ====> Acc@1 85.500 Acc@5 99.240


train 115 ====> Acc@1 95.954 Acc@5 99.978


valid 115 ====> Acc@1 84.250 Acc@5 99.140


train 116 ====> Acc@1 95.906 Acc@5 99.984


valid 116 ====> Acc@1 84.640 Acc@5 99.340


train 117 ====> Acc@1 95.650 Acc@5 99.976


valid 117 ====> Acc@1 83.500 Acc@5 99.230


train 118 ====> Acc@1 96.284 Acc@5 99.988


valid 118 ====> Acc@1 87.090 Acc@5 99.540


train 119 ====> Acc@1 95.804 Acc@5 99.978


valid 119 ====> Acc@1 85.560 Acc@5 99.250


train 120 ====> Acc@1 96.524 Acc@5 99.970


valid 120 ====> Acc@1 86.270 Acc@5 99.180


train 121 ====> Acc@1 96.272 Acc@5 99.976


valid 121 ====> Acc@1 85.650 Acc@5 99.280


train 122 ====> Acc@1 96.152 Acc@5 99.976


valid 122 ====> Acc@1 84.150 Acc@5 98.880


train 123 ====> Acc@1 96.230 Acc@5 99.986


valid 123 ====> Acc@1 84.990 Acc@5 99.150


train 124 ====> Acc@1 96.056 Acc@5 99.984


valid 124 ====> Acc@1 84.950 Acc@5 99.420


train 125 ====> Acc@1 96.234 Acc@5 99.978


valid 125 ====> Acc@1 85.670 Acc@5 99.270


train 126 ====> Acc@1 95.906 Acc@5 99.980


valid 126 ====> Acc@1 86.080 Acc@5 99.420


train 127 ====> Acc@1 96.192 Acc@5 99.980


valid 127 ====> Acc@1 86.560 Acc@5 99.450


train 128 ====> Acc@1 96.326 Acc@5 99.994


valid 128 ====> Acc@1 85.000 Acc@5 99.380


train 129 ====> Acc@1 96.316 Acc@5 99.986


valid 129 ====> Acc@1 85.670 Acc@5 99.490


train 130 ====> Acc@1 96.200 Acc@5 99.972


valid 130 ====> Acc@1 85.950 Acc@5 99.380


train 131 ====> Acc@1 95.874 Acc@5 99.974


valid 131 ====> Acc@1 83.810 Acc@5 99.290


train 132 ====> Acc@1 96.398 Acc@5 99.982


valid 132 ====> Acc@1 85.200 Acc@5 99.310


train 133 ====> Acc@1 96.374 Acc@5 99.984


valid 133 ====> Acc@1 86.050 Acc@5 99.430


train 134 ====> Acc@1 96.208 Acc@5 99.970


valid 134 ====> Acc@1 83.400 Acc@5 99.070


train 135 ====> Acc@1 96.290 Acc@5 99.978


valid 135 ====> Acc@1 85.970 Acc@5 99.330


train 136 ====> Acc@1 96.206 Acc@5 99.992


valid 136 ====> Acc@1 85.790 Acc@5 99.430


train 137 ====> Acc@1 96.488 Acc@5 99.988


valid 137 ====> Acc@1 85.600 Acc@5 99.250


train 138 ====> Acc@1 96.246 Acc@5 99.986


valid 138 ====> Acc@1 86.050 Acc@5 99.440


train 139 ====> Acc@1 96.106 Acc@5 99.984


valid 139 ====> Acc@1 85.030 Acc@5 99.070


train 140 ====> Acc@1 96.194 Acc@5 99.990


valid 140 ====> Acc@1 84.480 Acc@5 99.230


train 141 ====> Acc@1 96.200 Acc@5 99.986


valid 141 ====> Acc@1 84.700 Acc@5 98.880


train 142 ====> Acc@1 96.350 Acc@5 99.988


valid 142 ====> Acc@1 81.540 Acc@5 99.080


train 143 ====> Acc@1 95.868 Acc@5 99.978


valid 143 ====> Acc@1 83.350 Acc@5 99.110


train 144 ====> Acc@1 96.268 Acc@5 99.990


valid 144 ====> Acc@1 80.070 Acc@5 98.620


train 145 ====> Acc@1 96.458 Acc@5 99.982


valid 145 ====> Acc@1 81.550 Acc@5 98.880


train 146 ====> Acc@1 95.814 Acc@5 99.986


valid 146 ====> Acc@1 85.990 Acc@5 99.310


train 147 ====> Acc@1 96.484 Acc@5 99.986


valid 147 ====> Acc@1 85.430 Acc@5 99.300


train 148 ====> Acc@1 96.024 Acc@5 99.974


valid 148 ====> Acc@1 85.560 Acc@5 99.280


train 149 ====> Acc@1 96.286 Acc@5 99.986


valid 149 ====> Acc@1 84.530 Acc@5 98.870


train 150 ====> Acc@1 96.204 Acc@5 99.970


valid 150 ====> Acc@1 86.230 Acc@5 99.420


train 151 ====> Acc@1 96.298 Acc@5 99.974


valid 151 ====> Acc@1 85.540 Acc@5 99.230


train 152 ====> Acc@1 96.220 Acc@5 99.980


valid 152 ====> Acc@1 85.440 Acc@5 99.290


train 153 ====> Acc@1 96.622 Acc@5 99.976


valid 153 ====> Acc@1 87.250 Acc@5 99.130


train 154 ====> Acc@1 96.012 Acc@5 99.976


valid 154 ====> Acc@1 85.350 Acc@5 99.220


train 155 ====> Acc@1 96.182 Acc@5 99.978


valid 155 ====> Acc@1 80.110 Acc@5 99.220


train 156 ====> Acc@1 96.364 Acc@5 99.986


valid 156 ====> Acc@1 85.330 Acc@5 99.290


train 157 ====> Acc@1 96.192 Acc@5 99.978


valid 157 ====> Acc@1 85.300 Acc@5 99.350


train 158 ====> Acc@1 96.034 Acc@5 99.974


valid 158 ====> Acc@1 82.260 Acc@5 98.970


train 159 ====> Acc@1 96.850 Acc@5 99.994


valid 159 ====> Acc@1 80.780 Acc@5 98.480


train 160 ====> Acc@1 96.020 Acc@5 99.980


valid 160 ====> Acc@1 85.240 Acc@5 99.200


train 161 ====> Acc@1 96.316 Acc@5 99.978


valid 161 ====> Acc@1 84.940 Acc@5 99.150


train 162 ====> Acc@1 96.400 Acc@5 99.988


valid 162 ====> Acc@1 85.960 Acc@5 99.130


train 163 ====> Acc@1 96.764 Acc@5 99.988


valid 163 ====> Acc@1 84.120 Acc@5 99.150


train 164 ====> Acc@1 96.272 Acc@5 99.978


valid 164 ====> Acc@1 86.210 Acc@5 99.190


train 165 ====> Acc@1 96.076 Acc@5 99.968


valid 165 ====> Acc@1 86.040 Acc@5 99.230


train 166 ====> Acc@1 96.696 Acc@5 99.992


valid 166 ====> Acc@1 83.660 Acc@5 99.280


train 167 ====> Acc@1 96.392 Acc@5 99.984


valid 167 ====> Acc@1 84.100 Acc@5 99.090


train 168 ====> Acc@1 96.374 Acc@5 99.980


valid 168 ====> Acc@1 84.770 Acc@5 99.470


train 169 ====> Acc@1 96.100 Acc@5 99.982


valid 169 ====> Acc@1 84.650 Acc@5 98.970


train 170 ====> Acc@1 96.338 Acc@5 99.978


valid 170 ====> Acc@1 84.810 Acc@5 99.370


train 171 ====> Acc@1 96.340 Acc@5 99.992


valid 171 ====> Acc@1 86.520 Acc@5 99.190


train 172 ====> Acc@1 96.416 Acc@5 99.986


valid 172 ====> Acc@1 84.580 Acc@5 99.180


train 173 ====> Acc@1 96.560 Acc@5 99.986


valid 173 ====> Acc@1 84.750 Acc@5 99.200


train 174 ====> Acc@1 96.442 Acc@5 99.988


valid 174 ====> Acc@1 85.100 Acc@5 99.150


train 175 ====> Acc@1 96.562 Acc@5 99.982


valid 175 ====> Acc@1 87.010 Acc@5 99.380


train 176 ====> Acc@1 96.230 Acc@5 99.970


valid 176 ====> Acc@1 86.650 Acc@5 99.180


train 177 ====> Acc@1 96.138 Acc@5 99.980


valid 177 ====> Acc@1 81.170 Acc@5 99.080


train 178 ====> Acc@1 96.496 Acc@5 99.982


valid 178 ====> Acc@1 86.790 Acc@5 99.440


train 179 ====> Acc@1 96.598 Acc@5 99.990


valid 179 ====> Acc@1 85.870 Acc@5 99.390


train 180 ====> Acc@1 96.314 Acc@5 99.992


valid 180 ====> Acc@1 81.420 Acc@5 99.260


train 181 ====> Acc@1 96.484 Acc@5 99.984


valid 181 ====> Acc@1 84.320 Acc@5 98.750


train 182 ====> Acc@1 96.566 Acc@5 99.978


In [None]:
# def filter_prune(model, filter_mask):
#     idx = 0
#     for name, item in model.named_parameters():  
#         #.module.named_parameters():
#         if len(item.size())==4 and 'mask' in name:
#             for i in range(item.size(0)):
#                 item.data[i,:,:,:] = 1 if filter_mask[idx] else 0
#                 idx += 1

In [None]:
# losses = AverageMeter('Loss', ':.4e')
# top1 = AverageMeter('Acc@1', ':6.2f')
# top5 = AverageMeter('Acc@5', ':6.2f')

# prune_freq = 16
# milestones = [150, 225]

# model.train()

# for i, (inputs, targets) in enumerate(trainloader):
#     inputs = inputs.to(device)
#     targets = targets.to(device)

#     ### filter_mask = get_filter_mask(model, 0.5)   
#     ## 여기에서 64가 나와서 문제였음 (2128이 나와야할거 같은데..?)
#     # if (globals()['iterations']+1) % args.prune_freq==0 and (epoch+1) <= args.milestones[1]:
#     # 위 부분을 생략하고 진행해서 그런 것 같음
#     #and epoch+1 <= milestones[1]:
#     #if (i+1) % prune_freq == 0: print('prune start')
        
#     filter_mask = get_filter_mask(model, 0.5)
#     if len(filter_mask) == 2128:
#         print('prune start')
#         ### filter_prune(model, filter_mask)
#         idx = 0
#         for name, item in model.named_parameters():
#             if len(item.size())==4 and 'mask' in name:
#                 for j in range(item.size(0)):
#                     item.data[j,:,:,:] = 1 if filter_mask[idx] else 0
#                     idx += 1
#     if len(filter_mask) != 64:
#         print(len(filter_mask))
#     # else: print('no pruning since', i+1 % prune_freq)

#     outputs = model(inputs)
# #     if reg:
# #         oloss = reg(model)
# #         oloss = odecay * oloss
# #         loss = criterion(outputs, targets) + oloss
# #     else:
#     loss = criterion(outputs, targets)

#     acc1, acc5 = accuracy(outputs, targets, topk=(1,5))
#     losses.update(loss.item(), inputs.size(0))
#     top1.update(acc1[0], inputs.size(0))
#     top5.update(acc5[0], inputs.size(0))

#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()
    
#     if i % 20 == 0: print(i, top1.avg, top5.avg)

In [None]:
# for i in range(10):
#     print(globals()['iterations'])

### 필터 프루닝 함수 확인

In [None]:
# importance_all = None
# for name, item in model.named_parameters():
#     if len(item.size()) == 4 and 'weight' in name:
#         filters = item.data.view(item.size(0), -1).cpu()
#         weight_len = filters.size(1)
#         importance = filters.abs().sum(dim=1).numpy() / weight_len
        
#         if importance_all is None:
#             importance_all = importance
#         else:
#             importance_all = np.append(importance_all, importance)
            
# threshold = np.sort(importance_all)[int(len(importance_all) * 0.5)]
# filter_mask = np.greater(importance_all, threshold)
# print(filter_mask)

In [None]:
# idx = 0
# for name, item in model.named_parameters():  
#     #.module.named_parameters():
#     if len(item.size())==4 and 'mask' in name:
#         for i in range(item.size(0)):
#             item.data[i,:,:,:] = 1 if filter_mask[idx] else 0
#             idx += 1
#             print(item.data[i,:,:,:])
        

In [None]:
# len(filter_mask)