In [1]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import os
import collections
import argparse
import sys
import pickle
import numpy as np
import time, datetime
import copy
from thop import profile
from collections import OrderedDict
import shutil
import torch.utils
import torch.utils.data.distributed
from torchvision import datasets, transforms

In [2]:
import math
import pdb


def conv_bn(inp, oup, stride):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )

def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )

def make_divisible(x, divisible_by=8):
    import numpy as np
    return int(np.ceil(x * 1. / divisible_by) * divisible_by)


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup

        if expand_ratio == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):
    def __init__(self, compress_rate, n_class=1000, input_size=224, width_mult=1.):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        interverted_residual_setting = [
            # t-ex, c-channel, n-blocknum, s-stride
            [1, 16, 1, 1],
            [6, 24, 2, 2], # NOTE: change stride 1 -> 2 for CIFAR100
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]
        self.compress_rate=compress_rate[:]

        # building first layer
        assert input_size % 32 == 0
        # input_channel = make_divisible(input_channel * width_mult)  # first channel is always 32!
        self.last_channel = make_divisible(last_channel * width_mult) if width_mult > 1.0 else last_channel
        self.features = [conv_bn(3, input_channel, 2)]
        # building inverted residual blocks
        cnt=1
        for t, c, n, s in interverted_residual_setting:
            output_channel = make_divisible(c * width_mult) if t > 1 else c
            output_channel = int((1-self.compress_rate[cnt])*output_channel)
            for i in range(n):
                if i == 0:
                    self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
                else:
                    self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
                input_channel = output_channel
            cnt+=1

        # building last several layers
        self.features.append(conv_1x1_bn(input_channel, self.last_channel))
        # make it nn.Sequential
        self.features = nn.Sequential(*self.features)

        # building classifier
        #self.classifier = nn.Linear(self.last_channel, n_class)
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, n_class),
        )

        self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = x.mean(3).mean(2)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


def mobilenet_v2(compress_rate,n_class=1000):
    model = MobileNetV2(compress_rate=compress_rate,n_class=n_class,width_mult=1)
    return model

In [3]:
print("prepare resNet_50 model...")
def adapt_channel(compress_rate, num_layers):

    if num_layers==56:
        stage_repeat = [9, 9, 9]
        stage_out_channel = [16] + [16] * 9 + [32] * 9 + [64] * 9
    elif num_layers==110:
        stage_repeat = [18, 18, 18]
        stage_out_channel = [16] + [16] * 18 + [32] * 18 + [64] * 18

    stage_oup_cprate = []
    stage_oup_cprate += [compress_rate[0]]
    for i in range(len(stage_repeat)-1):
        stage_oup_cprate += [compress_rate[i+1]] * stage_repeat[i]
    stage_oup_cprate +=[0.] * stage_repeat[-1]
    mid_cprate = compress_rate[len(stage_repeat):]

    overall_channel = []
    mid_channel = []
    for i in range(len(stage_out_channel)):
        if i == 0 :
            overall_channel += [int(stage_out_channel[i] * (1-stage_oup_cprate[i]))]
        else:
            overall_channel += [int(stage_out_channel[i] * (1-stage_oup_cprate[i]))]
            mid_channel += [int(stage_out_channel[i] * (1-mid_cprate[i-1]))]

    return overall_channel, mid_channel


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, midplanes, inplanes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.inplanes = inplanes
        self.planes = planes
        self.conv1 = conv3x3(inplanes, midplanes, stride)
        self.bn1 = nn.BatchNorm2d(midplanes)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = conv3x3(midplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.stride = stride

        self.shortcut = nn.Sequential()
        if stride != 1 or inplanes != planes:
            if stride!=1:
                self.shortcut = LambdaLayer(
                    lambda x: F.pad(x[:, :, ::2, ::2],
                                    (0, 0, 0, 0, (planes-inplanes)//2, planes-inplanes-(planes-inplanes)//2), "constant", 0))
            else:
                self.shortcut = LambdaLayer(
                    lambda x: F.pad(x[:, :, :, :],
                                    (0, 0, 0, 0, (planes-inplanes)//2, planes-inplanes-(planes-inplanes)//2), "constant", 0))
            #self.shortcut = LambdaLayer(
            #    lambda x: F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4),"constant", 0))

            '''self.shortcut = nn.Sequential(
                conv1x1(inplanes, planes, stride=stride),
                #nn.BatchNorm2d(planes),
            )#'''

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)

        #print(self.stride, self.inplanes, self.planes, out.size(), self.shortcut(x).size())
        out += self.shortcut(x)
        out = self.relu2(out)

        return out


class ResNet(nn.Module):
    def __init__(self, block, num_layers, compress_rate, num_classes=100):
        super(ResNet, self).__init__()
        assert (num_layers - 2) % 6 == 0, 'depth should be 6n+2'
        n = (num_layers - 2) // 6

        self.num_layer = num_layers
        self.overall_channel, self.mid_channel = adapt_channel(compress_rate, num_layers)

        self.layer_num = 0
        self.conv1 = nn.Conv2d(3, self.overall_channel[self.layer_num], kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(self.overall_channel[self.layer_num])
        self.relu = nn.ReLU(inplace=True)
        self.layers = nn.ModuleList()
        self.layer_num += 1

        #self.layers = nn.ModuleList()
        self.layer1 = self._make_layer(block, blocks_num=n, stride=1)
        self.layer2 = self._make_layer(block, blocks_num=n, stride=2)
        self.layer3 = self._make_layer(block, blocks_num=n, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        if self.num_layer == 56:
            self.fc = nn.Linear(64 * BasicBlock.expansion, num_classes)
        else:
            self.linear = nn.Linear(64 * BasicBlock.expansion, num_classes)


    def _make_layer(self, block, blocks_num, stride):
        layers = []
        layers.append(block(self.mid_channel[self.layer_num - 1], self.overall_channel[self.layer_num - 1],
                                 self.overall_channel[self.layer_num], stride))
        self.layer_num += 1

        for i in range(1, blocks_num):
            layers.append(block(self.mid_channel[self.layer_num - 1], self.overall_channel[self.layer_num - 1],
                                     self.overall_channel[self.layer_num]))
            self.layer_num += 1

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        for i, block in enumerate(self.layer1):
            x = block(x)
        for i, block in enumerate(self.layer2):
            x = block(x)
        for i, block in enumerate(self.layer3):
            x = block(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)

        if self.num_layer == 56:
            x = self.fc(x)
        else:
            x = self.linear(x)

        return x


def resnet_56(compress_rate):
    return ResNet(BasicBlock, 56, compress_rate=compress_rate)

def resnet_110(compress_rate):
    return ResNet(BasicBlock, 110, compress_rate=compress_rate)

prepare resNet_50 model...


In [4]:
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
cudnn.benchmark = True
cudnn.enabled=True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
def load_mobilenetv2_model(model, oristate_dict):

    state_dict = model.state_dict()

    last_select_index = None

    all_honey_conv_weight = []

    bn_part_name=['.weight','.bias','.running_mean','.running_var']
    prefix = rank_conv_prefix+'rank_conv'
    subfix = ".npy"

    layer_cnt=1
    conv_cnt=1
    cfg=[1,2,3,4,3,3,1,1]
    for layer, num in enumerate(cfg):
        if layer_cnt==1:
            conv_id=[0,3]
        elif layer_cnt==18:
            conv_id=[0]
        else:
            conv_id=[0,3,6]

        for k in range(num):
            if layer_cnt==18:
                block_name = 'features.' + str(layer_cnt) + '.'
            else:
                block_name = 'features.'+str(layer_cnt)+'.conv.'

            for l in conv_id:
                conv_cnt += 1
                conv_name = block_name + str(l)
                bn_name = block_name + str(l+1)

                conv_weight_name = conv_name + '.weight'
                all_honey_conv_weight.append(conv_weight_name)
                oriweight = oristate_dict[conv_weight_name]
                curweight = state_dict[name_base+conv_weight_name]
                orifilter_num = oriweight.size(0)
                currentfilter_num = curweight.size(0)

                if orifilter_num != currentfilter_num:
                    print('loading rank from: ' + prefix + str(conv_cnt) + subfix)
                    rank = np.load(prefix + str(conv_cnt) + subfix)
                    select_index = np.argsort(rank)[orifilter_num - currentfilter_num:]  # preserved filter id
                    select_index.sort()

                    if (l==6 or (l==0 and layer_cnt!=1) or (l==3 and layer_cnt==1)) and last_select_index is not None:
                        for index_i, i in enumerate(select_index):
                            for index_j, j in enumerate(last_select_index):
                                state_dict[name_base+conv_weight_name][index_i][index_j] = \
                                    oristate_dict[conv_weight_name][i][j]
                            for bn_part in bn_part_name:
                                state_dict[name_base + bn_name + bn_part][index_i] = \
                                    oristate_dict[bn_name + bn_part][i]
                    else:
                        for index_i, i in enumerate(select_index):
                            state_dict[name_base+conv_weight_name][index_i] = \
                                oristate_dict[conv_weight_name][i]
                            for bn_part in bn_part_name:
                                state_dict[name_base + bn_name + bn_part][index_i] = \
                                    oristate_dict[bn_name + bn_part][i]

                    last_select_index = select_index

                elif  (l==6 or (l==0 and layer_cnt!=1) or (l==3 and layer_cnt==1)) and last_select_index is not None:
                    for index_i in range(orifilter_num):
                        for index_j, j in enumerate(last_select_index):
                            state_dict[name_base+conv_weight_name][index_i][index_j] = \
                                oristate_dict[conv_weight_name][index_i][j]
                    for bn_part in bn_part_name:
                        state_dict[name_base + bn_name + bn_part] = \
                            oristate_dict[bn_name + bn_part]
                    last_select_index = None

                else:
                    state_dict[name_base+conv_weight_name] = oriweight
                    for bn_part in bn_part_name:
                        state_dict[name_base + bn_name + bn_part] = \
                            oristate_dict[bn_name + bn_part]
                    last_select_index = None

                state_dict[name_base + bn_name + '.num_batches_tracked'] = oristate_dict[bn_name + '.num_batches_tracked']

            layer_cnt+=1

    for name, module in model.named_modules():
        name = name.replace('module.', '')
        if isinstance(module, nn.Conv2d):
            conv_name = name + '.weight'
            bn_name = list(name[:])
            bn_name[-1] = str(int(name[-1])+1)
            bn_name = ''.join(bn_name)
            if conv_name not in all_honey_conv_weight:
                state_dict[name_base+conv_name] = oristate_dict[conv_name]
                for bn_part in bn_part_name:
                    state_dict[name_base + bn_name + bn_part] = \
                        oristate_dict[bn_name + bn_part]
                state_dict[name_base + bn_name + '.num_batches_tracked'] = oristate_dict[bn_name + '.num_batches_tracked']

        elif isinstance(module, nn.Linear):
            state_dict[name_base+name + '.weight'] = oristate_dict[name + '.weight']
            state_dict[name_base+name + '.bias'] = oristate_dict[name + '.bias']

    model.load_state_dict(state_dict,strict=False)
    print("finish pruning")
    
def adjust_learning_rate(optimizer, epoch, step, len_iter):

    if lr_type == 'step':
        factor = epoch // 30
        if epoch >= 80:
            factor = factor + 1
        lr = learning_rate * (0.1 ** factor)

    elif lr_type == 'cos':  # cos without warm-up
        lr = 0.5 * learning_rate * (1 + math.cos(math.pi * (epoch - 5) / (epochs - 5)))

    elif lr_type == 'exp':
        step = 1
        decay = 0.96
        lr = learning_rate * (decay ** (epoch // step))

    elif lr_type == 'fixed':
        lr = learning_rate
    else:
        raise NotImplementedError

    #Warmup
    if epoch < 5:
        lr = lr * float(1 + step + epoch * len_iter) / (5. * len_iter)

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    if step == 0:
        print('learning_rate: ' + str(lr))

def load_resnet_model(model, oristate_dict, layer):
    cfg = {
        56: [9, 9, 9],
        110: [18, 18, 18],
    }

    state_dict = model.state_dict()

    current_cfg = cfg[layer]
    last_select_index = None

    all_conv_weight = []

    prefix = rank_conv_prefix+'rank_conv'
    subfix = ".npy"

    cnt=1
    for layer, num in enumerate(current_cfg):
        layer_name = 'layer' + str(layer + 1) + '.'
        for k in range(num):
            for l in range(2):

                cnt+=1
                cov_id=cnt

                conv_name = layer_name + str(k) + '.conv' + str(l + 1)
                conv_weight_name = conv_name + '.weight'
                all_conv_weight.append(conv_weight_name)
                oriweight = oristate_dict[conv_weight_name]
                curweight =state_dict[name_base+conv_weight_name]
                orifilter_num = oriweight.size(0)
                currentfilter_num = curweight.size(0)

                if orifilter_num != currentfilter_num:
                    print('loading rank from: ' + prefix + str(cov_id) + subfix)
                    rank = np.load(prefix + str(cov_id) + subfix)
                    select_index = np.argsort(rank)[orifilter_num - currentfilter_num:]  # preserved filter id
                    select_index.sort()

                    if last_select_index is not None:
                        for index_i, i in enumerate(select_index):
                            for index_j, j in enumerate(last_select_index):
                                state_dict[name_base+conv_weight_name][index_i][index_j] = \
                                    oristate_dict[conv_weight_name][i][j]
                    else:
                        for index_i, i in enumerate(select_index):
                            state_dict[name_base+conv_weight_name][index_i] = \
                                oristate_dict[conv_weight_name][i]

                    last_select_index = select_index

                elif last_select_index is not None:
                    for index_i in range(orifilter_num):
                        for index_j, j in enumerate(last_select_index):
                            state_dict[name_base+conv_weight_name][index_i][index_j] = \
                                oristate_dict[conv_weight_name][index_i][j]
                    last_select_index = None

                else:
                    state_dict[name_base+conv_weight_name] = oriweight
                    last_select_index = None

    for name, module in model.named_modules():
        name = name.replace('module.', '')

        if isinstance(module, nn.Conv2d):
            conv_name = name + '.weight'
            if 'shortcut' in name:
                continue
            if conv_name not in all_conv_weight:
                state_dict[name_base+conv_name] = oristate_dict[conv_name]

        elif isinstance(module, nn.Linear):
            state_dict[name_base+name + '.weight'] = oristate_dict[name + '.weight']
            state_dict[name_base+name + '.bias'] = oristate_dict[name + '.bias']

    model.load_state_dict(state_dict)

In [6]:
print("超参数")
CLASSES = 100
lr_type = 'step'
epochs = 300
batch_size=  256

learning_rate = 0.1
momentum = 0.9
# 0.006
weight_decay = 0.0001
label_smooth = 0.1 
lr_decay_step = '150,225'

best_acc = 0
# "mobilenet_v2" "resnet_56"
arch = "mobilenet_v2"
# "[0.]+[0.4]*2+[0.5]*9+[0.6]*9+[0.7]*9" “[0.]+[0.15]*2+[0.4]*27”"'[0.]+[0.18]*29'"
compress_rate = "[0.]+[0.18]*29"
# "./data/model/Hrank_preTrain/cifar-10/resnet_56.pt.pt"
pretrain_dir = "./data/model/Hrank_preTrain/cifar-100/resnet_56_cifar_100.t7"
save_dir = "PruneGraft_cifar100_press1_"+arch

rank_conv_dir = "./data/model/rank_conv/cifar-100/resnet_56_limit5/"
print(save_dir)

超参数
PruneGraft_cifar100_press1_resnet_56


In [7]:
start_t = time.time()
print("prepare compress_rate")
def process_compress_rate(compress_rate):
    if compress_rate:# 处理args.compress_rate
        import re
        cprate_str = compress_rate
        cprate_str_list = cprate_str.split('+')
        pat_cprate = re.compile(r'\d+\.\d*')
        pat_num = re.compile(r'\*\d+')
        cprate = []
        for x in cprate_str_list:
            num = 1
            find_num = re.findall(pat_num, x)
            if find_num:
                assert len(find_num) == 1
                num = int(find_num[0].replace('*', ''))
            find_cprate = re.findall(pat_cprate, x)
            assert len(find_cprate) == 1
            cprate += [float(find_cprate[0])] * num

        compress_rate = cprate
        return compress_rate
compress_rate = process_compress_rate(compress_rate)
print('compress_rate:' + str(compress_rate))

prepare compress_rate
compress_rate:[0.0, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18, 0.18]


In [8]:
print("=====> Building model")
if arch == "mobilenet_v2":
    net_1 = mobilenet_v2(compress_rate,n_class=10)
    net_2 = mobilenet_v2(compress_rate,n_class=10)
    net_1.to(device)
    net_2.to(device)
    
    rank_conv_prefix = "./data/model/rank_conv/mobileNetV2_limit5/"
    name_base=''

    print('resuming from pretrain model')
    origin_model = mobilenet_v2(compress_rate=[0.] * 100,n_class=10).cuda()
    ckpt = torch.load("./data/model/PruneGraft_cifar10_MobileNetV2_preTrain/best_9.t7")
    # ckpt = {k:v for k,v in ckpt.items() if "classifier" not in k}
    ckpt = ckpt["net"]

    origin_model.load_state_dict(ckpt)
    oristate_dict = origin_model.state_dict()
    load_mobilenetv2_model(net_1,oristate_dict)
    load_mobilenetv2_model(net_2,oristate_dict)
if arch == "resnet_56":
    net_1 = resnet_56(compress_rate)
    net_2 = resnet_56(compress_rate)
    net_1.to(device)
    net_2.to(device)
    

    
    input_image_size=32
    input_image = torch.randn(1, 3, input_image_size, input_image_size).cuda()
    flops, params = profile(net_2, inputs=(input_image,))
    print('Params: %.2f' % (params))
    print('Flops: %.2f' % (flops))

    input_image_size=32
    input_image = torch.randn(1, 3, input_image_size, input_image_size).cuda()
    flops, params = profile(net_1, inputs=(input_image,))
    print('Params: %.2f' % (params))
    print('Flops: %.2f' % (flops))
    rank_conv_prefix = rank_conv_dir
    name_base=''
    
    origin_model = resnet_56(compress_rate=[0.] * 100).cuda()
    
    input_image_size=32
    input_image = torch.randn(1, 3, input_image_size, input_image_size).cuda()
    flops, params = profile(origin_model, inputs=(input_image,))
    print('Params: %.2f' % (params))
    print('Flops: %.2f' % (flops))
    
    ckpt = torch.load(pretrain_dir, map_location='cuda:0')
    origin_model.load_state_dict(ckpt['net'])

    
    oristate_dict = origin_model.state_dict()
    load_resnet_model(net_1, oristate_dict, 56)
    print("-------------------")
    load_resnet_model(net_2, oristate_dict, 56)
    

=====> Building model
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_bn() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[91m[WARN] Cannot find rule for <class 'torch.nn.modules.container.ModuleList'>. Treat it as zero Macs and zero Params.[00m
[91m[WARN] Cannot find rule for <class '__main__.LambdaLayer'>. Treat it as zero Macs and zero Params.[00m
[91m[WARN] Cannot find rule for <class '__main__.BasicBlock'>. Treat it as zero Macs and zero Params.[00m
[91m[WARN] Cannot find rule for <class 'torch.nn.modules.container.Sequential'>. Treat it as zero Macs and zero Params.[00m
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[91m[WARN] Cannot find rule for <class '__main__.ResNet'>. Treat it as zero Macs and zero Params.[00m
Pa

loading rank from: ./data/model/rank_conv/cifar-100/resnet_56_limit5/rank_conv22.npy
loading rank from: ./data/model/rank_conv/cifar-100/resnet_56_limit5/rank_conv23.npy
loading rank from: ./data/model/rank_conv/cifar-100/resnet_56_limit5/rank_conv24.npy
loading rank from: ./data/model/rank_conv/cifar-100/resnet_56_limit5/rank_conv25.npy
loading rank from: ./data/model/rank_conv/cifar-100/resnet_56_limit5/rank_conv26.npy
loading rank from: ./data/model/rank_conv/cifar-100/resnet_56_limit5/rank_conv27.npy
loading rank from: ./data/model/rank_conv/cifar-100/resnet_56_limit5/rank_conv28.npy
loading rank from: ./data/model/rank_conv/cifar-100/resnet_56_limit5/rank_conv29.npy
loading rank from: ./data/model/rank_conv/cifar-100/resnet_56_limit5/rank_conv30.npy
loading rank from: ./data/model/rank_conv/cifar-100/resnet_56_limit5/rank_conv31.npy
loading rank from: ./data/model/rank_conv/cifar-100/resnet_56_limit5/rank_conv32.npy
loading rank from: ./data/model/rank_conv/cifar-100/resnet_56_lim

In [9]:
class CrossEntropyLabelSmooth(nn.Module):

  def __init__(self, num_classes, epsilon):
    super(CrossEntropyLabelSmooth, self).__init__()
    self.num_classes = num_classes
    self.epsilon = epsilon
    self.logsoftmax = nn.LogSoftmax(dim=1)

  def forward(self, inputs, targets):
    log_probs = self.logsoftmax(inputs)
    targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
    targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
    loss = (-targets * log_probs).mean(0).sum()
    return loss   


criterion_1 = nn.CrossEntropyLoss()
criterion_1 = criterion_1.cuda()
criterion_smooth_1 = CrossEntropyLabelSmooth(CLASSES, 0.1)
criterion_smooth_1 = criterion_smooth_1.cuda()

criterion_2 = nn.CrossEntropyLoss()
criterion_2 = criterion_2.cuda()
criterion_smooth_2 = CrossEntropyLabelSmooth(CLASSES, 0.1)
criterion_smooth_2 = criterion_smooth_2.cuda()

optimizer_1 = torch.optim.SGD(net_1.parameters(), learning_rate, momentum=momentum, weight_decay=weight_decay)
optimizer_2 = torch.optim.SGD(net_2.parameters(), 0.1, momentum=momentum, weight_decay=0.0001)
lr_decay_step = list(map(int, lr_decay_step.split(',')))
scheduler_1 = torch.optim.lr_scheduler.MultiStepLR(optimizer_1, milestones=lr_decay_step, gamma=0.1)
scheduler_2 = torch.optim.lr_scheduler.MultiStepLR(optimizer_2, milestones=lr_decay_step, gamma=0.1)

lr_scheduler_1 = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=60, gamma=0.1)
lr_scheduler_2 = torch.optim.lr_scheduler.StepLR(optimizer_2, step_size=60, gamma=0.1)

In [10]:
print('load training data')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR100(root='./data/cifar-100/', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR100(root='./data/cifar-100/', train=False, download=True, transform=transform_test)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True,drop_last=False)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False,drop_last=False)

load training data
Files already downloaded and verified
Files already downloaded and verified


In [14]:
def entropy(x, n=10):
    x = x.reshape(-1)
    scale = (x.max() - x.min()) / n
    entropy = 0
    for i in range(n):
        p = torch.sum((x >= x.min() + i * scale) * (x < x.min() + (i + 1) * scale), dtype=torch.float) / len(x)
        if p != 0:
            entropy -= p * torch.log(p)
    return float(entropy.cpu())

def train(epoch,i,net,optimizer,scheduler,criterion):
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    
    num_iter = len(trainloader)
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        
#         adjust_learning_rate(optimizer, epoch, batch_idx, num_iter)
        
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if batch_idx % 1000 == 1000 - 1 or 1000 == trainloader.__len__() - 1:
            print('Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
                train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
    scheduler.step()
    
def test(epoch,net,i,criterion):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    acc = 100. * correct / total
    if acc > best_acc:
        best_acc = acc
        state = {
            'net': net.state_dict(),
            'acc': acc
        }
        torch.save(state, './data/model/'+save_dir+'/best_%d.t7' % (epoch))
    print('Network:%d    epoch:%d    accuracy:%.3f    best:%.3f' % (i, epoch, acc, best_acc))

def grafting(net, epoch,i):
    while True:
        try:
            checkpoint = torch.load('./data/model/'+save_dir+'/ckpt%d_%d.t7' % (i - 1, epoch))['net']
            break
        except:
            time.sleep(10)
    model = collections.OrderedDict()
    w = 1 
    for i, (key, u) in enumerate(net.state_dict().items()):
        if 'conv' in key and 'weight' in key:
            w = round(0.4 / np.pi * np.arctan(500 * (entropy(u) - entropy(checkpoint[key]))) + 0.5, 2)
        model[key] = u * w + checkpoint[key] * (1 - w)
        
    net.load_state_dict(model)
    
# for epoch in range(200):
#     scheduler_1.step()
#     scheduler_2.step()

In [15]:
for epoch in range(300):
    train(epoch,1,net_1,optimizer_1,scheduler_1,criterion_smooth_1)
    test(epoch,net_1,1,criterion_1)
    state = {
        'net': net_1.state_dict(),
    }
    torch.save(state, './data/model/'+save_dir+'/ckpt%d_%d.t7' % ( 1, epoch))
    
    train(epoch,2,net_2,optimizer_2,scheduler_2,criterion_smooth_2)
    test(epoch,net_2,2,criterion_2)
    state = {
        'net': net_2.state_dict(),
        
    }
    torch.save(state, './data/model/'+save_dir+'/ckpt%d_%d.t7' % ( 0, epoch))
    
    grafting(net_1,epoch,1)
    grafting(net_2,epoch,2)

learning_rate: 0.00010204081632653062
Network:1    epoch:0    accuracy:42.420    best:42.420
learning_rate: 0.00010204081632653062
Network:2    epoch:0    accuracy:27.020    best:42.420
learning_rate: 0.020102040816326532
Network:1    epoch:1    accuracy:45.400    best:45.400
learning_rate: 0.020102040816326532
Network:2    epoch:1    accuracy:40.710    best:45.400
learning_rate: 0.04010204081632653
Network:1    epoch:2    accuracy:50.210    best:50.210
learning_rate: 0.04010204081632653
Network:2    epoch:2    accuracy:49.710    best:50.210
learning_rate: 0.06010204081632654
Network:1    epoch:3    accuracy:50.410    best:50.410
learning_rate: 0.06010204081632654
Network:2    epoch:3    accuracy:54.120    best:54.120
learning_rate: 0.08010204081632653
Network:1    epoch:4    accuracy:51.750    best:54.120
learning_rate: 0.08010204081632653
Network:2    epoch:4    accuracy:51.110    best:54.120
learning_rate: 0.1
Network:1    epoch:5    accuracy:56.360    best:56.360
learning_rate: 0.1

Network:2    epoch:49    accuracy:68.670    best:69.040
learning_rate: 0.010000000000000002
Network:1    epoch:50    accuracy:68.790    best:69.040
learning_rate: 0.010000000000000002
Network:2    epoch:50    accuracy:68.360    best:69.040
learning_rate: 0.010000000000000002
Network:1    epoch:51    accuracy:68.540    best:69.040
learning_rate: 0.010000000000000002
Network:2    epoch:51    accuracy:68.250    best:69.040
learning_rate: 0.010000000000000002
Network:1    epoch:52    accuracy:68.070    best:69.040
learning_rate: 0.010000000000000002
Network:2    epoch:52    accuracy:68.350    best:69.040
learning_rate: 0.010000000000000002
Network:1    epoch:53    accuracy:68.460    best:69.040
learning_rate: 0.010000000000000002
Network:2    epoch:53    accuracy:68.560    best:69.040
learning_rate: 0.010000000000000002
Network:1    epoch:54    accuracy:68.680    best:69.040
learning_rate: 0.010000000000000002
Network:2    epoch:54    accuracy:68.560    best:69.040
learning_rate: 0.0100000

learning_rate: 1.0000000000000003e-05
Network:1    epoch:94    accuracy:68.910    best:69.040
learning_rate: 1.0000000000000003e-05
Network:2    epoch:94    accuracy:68.920    best:69.040
learning_rate: 1.0000000000000003e-05
Network:1    epoch:95    accuracy:68.780    best:69.040
learning_rate: 1.0000000000000003e-05
Network:2    epoch:95    accuracy:68.830    best:69.040
learning_rate: 1.0000000000000003e-05
Network:1    epoch:96    accuracy:68.780    best:69.040
learning_rate: 1.0000000000000003e-05
Network:2    epoch:96    accuracy:68.760    best:69.040
learning_rate: 1.0000000000000003e-05
Network:1    epoch:97    accuracy:68.850    best:69.040
learning_rate: 1.0000000000000003e-05
Network:2    epoch:97    accuracy:68.790    best:69.040
learning_rate: 1.0000000000000003e-05
Network:1    epoch:98    accuracy:68.920    best:69.040
learning_rate: 1.0000000000000003e-05
Network:2    epoch:98    accuracy:68.800    best:69.040
learning_rate: 1.0000000000000003e-05
Network:1    epoch:99 

Network:1    epoch:137    accuracy:68.890    best:69.040
learning_rate: 1.0000000000000004e-06
Network:2    epoch:137    accuracy:68.770    best:69.040
learning_rate: 1.0000000000000004e-06
Network:1    epoch:138    accuracy:68.610    best:69.040
learning_rate: 1.0000000000000004e-06
Network:2    epoch:138    accuracy:68.940    best:69.040
learning_rate: 1.0000000000000004e-06
Network:1    epoch:139    accuracy:68.970    best:69.040
learning_rate: 1.0000000000000004e-06
Network:2    epoch:139    accuracy:68.840    best:69.040
learning_rate: 1.0000000000000004e-06
Network:1    epoch:140    accuracy:68.980    best:69.040
learning_rate: 1.0000000000000004e-06
Network:2    epoch:140    accuracy:69.030    best:69.040
learning_rate: 1.0000000000000004e-06
Network:1    epoch:141    accuracy:68.860    best:69.040
learning_rate: 1.0000000000000004e-06
Network:2    epoch:141    accuracy:68.940    best:69.040
learning_rate: 1.0000000000000004e-06
Network:1    epoch:142    accuracy:68.860    best:

Network:2    epoch:180    accuracy:68.840    best:69.040
learning_rate: 1.0000000000000004e-08
Network:1    epoch:181    accuracy:68.760    best:69.040
learning_rate: 1.0000000000000004e-08
Network:2    epoch:181    accuracy:68.720    best:69.040
learning_rate: 1.0000000000000004e-08
Network:1    epoch:182    accuracy:68.800    best:69.040
learning_rate: 1.0000000000000004e-08
Network:2    epoch:182    accuracy:68.850    best:69.040
learning_rate: 1.0000000000000004e-08
Network:1    epoch:183    accuracy:68.830    best:69.040
learning_rate: 1.0000000000000004e-08
Network:2    epoch:183    accuracy:68.780    best:69.040
learning_rate: 1.0000000000000004e-08
Network:1    epoch:184    accuracy:68.910    best:69.040
learning_rate: 1.0000000000000004e-08
Network:2    epoch:184    accuracy:68.880    best:69.040
learning_rate: 1.0000000000000004e-08
Network:1    epoch:185    accuracy:68.870    best:69.040
learning_rate: 1.0000000000000004e-08
Network:2    epoch:185    accuracy:68.790    best:

learning_rate: 1.0000000000000005e-09
Network:1    epoch:224    accuracy:68.840    best:69.040
learning_rate: 1.0000000000000005e-09
Network:2    epoch:224    accuracy:68.820    best:69.040
learning_rate: 1.0000000000000005e-09
Network:1    epoch:225    accuracy:68.830    best:69.040
learning_rate: 1.0000000000000005e-09
Network:2    epoch:225    accuracy:68.930    best:69.040
learning_rate: 1.0000000000000005e-09
Network:1    epoch:226    accuracy:68.890    best:69.040
learning_rate: 1.0000000000000005e-09
Network:2    epoch:226    accuracy:68.850    best:69.040
learning_rate: 1.0000000000000005e-09
Network:1    epoch:227    accuracy:68.910    best:69.040
learning_rate: 1.0000000000000005e-09
Network:2    epoch:227    accuracy:68.890    best:69.040
learning_rate: 1.0000000000000005e-09
Network:1    epoch:228    accuracy:68.850    best:69.040
learning_rate: 1.0000000000000005e-09
Network:2    epoch:228    accuracy:68.860    best:69.040
learning_rate: 1.0000000000000005e-09
Network:1   

Network:1    epoch:267    accuracy:68.960    best:69.040
learning_rate: 1.0000000000000006e-10
Network:2    epoch:267    accuracy:68.870    best:69.040
learning_rate: 1.0000000000000006e-10
Network:1    epoch:268    accuracy:68.860    best:69.040
learning_rate: 1.0000000000000006e-10
Network:2    epoch:268    accuracy:68.750    best:69.040
learning_rate: 1.0000000000000006e-10
Network:1    epoch:269    accuracy:68.830    best:69.040
learning_rate: 1.0000000000000006e-10
Network:2    epoch:269    accuracy:68.850    best:69.040
learning_rate: 1.0000000000000006e-11
Network:1    epoch:270    accuracy:68.880    best:69.040
learning_rate: 1.0000000000000006e-11
Network:2    epoch:270    accuracy:68.800    best:69.040
learning_rate: 1.0000000000000006e-11
Network:1    epoch:271    accuracy:68.840    best:69.040
learning_rate: 1.0000000000000006e-11
Network:2    epoch:271    accuracy:68.810    best:69.040
learning_rate: 1.0000000000000006e-11
Network:1    epoch:272    accuracy:68.830    best: