In [1]:
# https://github.com/pytorch/vision/blob/master/torchvision/models/__init__.py
import argparse
import os,sys
import shutil
import pdb, time
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models
from utils import convert_secs2time, time_string, time_file_str
# from models import print_log
import models
import random
import numpy as np
import copy

model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

In [2]:
from dotmap import DotMap

args = DotMap()
args.data = '/home/hongky/datasets/imagenet'
args.save_dir = './infer_small_model/'
args.arch = 'resnet101'
args.workers = 12
args.batch_size = 64
args.lr = 0.1
args.print_freq = 200

args.rate = 0.7
args.layer_begin = 3
args.layer_end = 3
args.layer_inter = 1
args.epoch_prune = 1
args.skip_downsample = 1
args.get_small = True 
args.use_cuda = True

args.prefix = time_file_str()

In [3]:
def validate(val_loader, model, criterion, log, is_cuda=False):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    for i, (input, target) in enumerate(val_loader):
        # target = target.cuda(async=True)
        if is_cuda:
            input = input.cuda()
        target = target.cuda(non_blocking=True)
        input_var = torch.autograd.Variable(input, volatile=True)
        target_var = torch.autograd.Variable(target, volatile=True)

        # compute output
        output = model(input_var)
        loss = criterion(output, target_var)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.data.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print_log('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                i, len(val_loader), batch_time=batch_time, loss=losses,
                top1=top1, top5=top5), log)

    print_log(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Error@1 {error1:.3f}'.format(top1=top1, top5=top5,
                                                                                           error1=100 - top1.avg), log)

    return top1.avg


def save_checkpoint(state, is_best, filename, bestname):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, bestname)


def print_log(print_string, log):
    print("{}".format(print_string))
    log.write('{}\n'.format(print_string))
    log.flush()


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def remove_module_dict(state_dict):
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # remove `module.`
        new_state_dict[name] = v
    return new_state_dict

In [4]:
best_prec1 = 0

if not os.path.isdir(args.save_dir):
    os.makedirs(args.save_dir)
log = open(os.path.join(args.save_dir, 'gpu-time.{}.{}.log'.format(args.arch, args.prefix)), 'w')

# create model
print_log("=> creating model '{}'".format(args.arch), log)
model = models.__dict__[args.arch](pretrained=False)
print_log("=> Model : {}".format(model), log)
print_log("=> parameter : {}".format(args), log)
print_log("Compress Rate: {}".format(args.rate), log)
print_log("Layer Begin: {}".format(args.layer_begin), log)
print_log("Layer End: {}".format(args.layer_end), log)
print_log("Layer Inter: {}".format(args.layer_inter), log)
print_log("Epoch prune: {}".format(args.epoch_prune), log)
print_log("Skip downsample : {}".format(args.skip_downsample), log)



cudnn.benchmark = True

# Data loading code
valdir = os.path.join(args.data, 'val')
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

val_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(valdir, transforms.Compose([
        # transforms.Scale(256),
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=args.batch_size, shuffle=False,
    num_workers=args.workers, pin_memory=True)

criterion = nn.CrossEntropyLoss().cuda()

=> creating model 'resnet101'
=> Model : ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d

=> loaded checkpoint '1210_resnet101/resnet101-rate-0.6/checkpoint.resnet101.2020-10-12-6800.pth.tar' (epoch 68)


In [5]:
def check_channel(tensor):
    #print('DEBUG: tensor-size')
    #print(tensor.size())
    #print(len(tensor.size()))
#     if len(tensor.size()) == 0:
#         return 0, 0
    size_0 = tensor.size()[0]
    size_1 = tensor.size()[1] * tensor.size()[2] * tensor.size()[3]
    tensor_resize = tensor.view(size_0, -1)
    # indicator: if the channel contain all zeros
    channel_if_zero = np.zeros(size_0)
    for x in range(0, size_0, 1):
        channel_if_zero[x] = np.count_nonzero(tensor_resize[x].cpu().numpy()) != 0
    # indices = (torch.LongTensor(channel_if_zero) != 0 ).nonzero().view(-1)

    indices_nonzero = torch.LongTensor((channel_if_zero != 0).nonzero()[0])
    # indices_nonzero = torch.LongTensor((channel_if_zero != 0).nonzero()[0])

    zeros = (channel_if_zero == 0).nonzero()[0]
    indices_zero = torch.LongTensor(zeros) if zeros != [] else []

    return indices_zero, indices_nonzero

In [6]:
def prune_conv_bn(conv1, bn1, inplanes, inplanes_indices=None, kernel_size=1, stride=1, padding=0, bias=False):
    indices_zero, indices_nonzero = check_channel(conv1.weight.detach())
    print(len(indices_zero), len(indices_nonzero))
    n_outplanes = len(indices_nonzero)

    n_conv1 = nn.Conv2d(inplanes, n_outplanes, kernel_size=kernel_size, stride=stride, padding=padding,
                                   bias=bias)

#     print('-------------\n conv1')
    state_dict = {}
    for k in conv1.state_dict().keys():

        vals = conv1.state_dict()[k]
#         print('param: ', k, type(vals), vals.size())
        state_dict[k] = torch.index_select(vals, 0, indices_nonzero)
        if inplanes_indices != None:
            state_dict[k] = torch.index_select(state_dict[k], 1, inplanes_indices)
    n_conv1.load_state_dict(state_dict)


#     print('-------------\n bn1')
    state_dict = {}
    
    n_bn1 = nn.BatchNorm2d(len(indices_nonzero))
    for k in bn1.state_dict().keys():
        vals = bn1.state_dict()[k]
#         print('param: ', k, type(vals), vals.size())
        if(len(vals.size()) > 0):
            state_dict[k] = torch.index_select(vals, 0, indices_nonzero)
        else:
            state_dict[k] = vals

    n_bn1.load_state_dict(state_dict)
    
    return n_conv1, n_bn1, n_outplanes, indices_nonzero


def prune_inplane_conv_bn(conv1, bn1, inplanes, inplanes_indices=None, kernel_size=1, stride=1, padding=0, bias=False):
    n_outplanes = conv1.weight.size()[0]

    n_conv1 = nn.Conv2d(inplanes, n_outplanes, kernel_size=kernel_size, stride=stride, padding=padding,
                                   bias=bias)

#     print('-------------\n conv1')
    state_dict = {}
    for k in conv1.state_dict().keys():

        vals = conv1.state_dict()[k]
        if inplanes_indices != None:
            state_dict[k] = torch.index_select(vals, 1, inplanes_indices)
        else:
            state_dict[k] = vals
    n_conv1.load_state_dict(state_dict)


    n_bn1 = bn1
    
    indices_nonzero = None
    
    return n_conv1, n_bn1, n_outplanes, indices_nonzero



class PrunedBottleneck(nn.Module):
    expansion = 4

    def __init__(self, origin_block, inplanes, inplanes_indices, stride=1, downsample=None):
        super(PrunedBottleneck, self).__init__()
        
        
        conv1, bn1, next_inplanes, next_inplanes_indices = prune_conv_bn(origin_block.conv1, origin_block.bn1, 
                                   inplanes, inplanes_indices, 
                                   kernel_size=1, bias=False)
        self.conv1 = conv1
        self.bn1 = bn1
        
        
        conv2, bn2, next_inplanes, next_inplanes_indices = prune_conv_bn(origin_block.conv2, origin_block.bn2,
                               next_inplanes, next_inplanes_indices,  
                               kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.conv2 = conv2
        self.bn2 = bn2
        
        
        conv3, bn3, next_inplanes, next_inplanes_indices = prune_inplane_conv_bn(origin_block.conv3, origin_block.bn3,
                               next_inplanes, next_inplanes_indices,  
                               kernel_size=1, bias=False)
        self.conv3 = conv3
        self.bn3 = bn3
        
        
        self.next_inplanes = next_inplanes
        self.next_inplanes_indices = next_inplanes_indices
        
        
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

        
    def forward(self, x):
        residual = x
#         print('block-0:', x.size())

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
#         print('block-1:', out.size())

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
#         print('block-2:', out.size())

        out = self.conv3(out)
        out = self.bn3(out)
#         print('block-3:', out.size())

        if self.downsample is not None:
            residual = self.downsample(x)
#             print('residual:', residual.size())

        out += residual
        out = self.relu(out)

        return out
    
    
    
    




class CloneBottleneck(nn.Module):

    def __init__(self, origin_block, inplanes, inplanes_indices, stride=1, downsample=None):
        super(CloneBottleneck, self).__init__()
        
        
        conv1, bn1, next_inplanes, next_inplanes_indices = prune_inplane_conv_bn(origin_block.conv1, origin_block.bn1, 
                                   inplanes, inplanes_indices, 
                                   kernel_size=1, bias=False)
        self.conv1 = conv1
        self.bn1 = bn1
        
        
        conv2, bn2, next_inplanes, next_inplanes_indices = prune_inplane_conv_bn(origin_block.conv2, origin_block.bn2,
                               next_inplanes, next_inplanes_indices,  
                               kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.conv2 = conv2
        self.bn2 = bn2
        
        
        conv3, bn3, next_inplanes, next_inplanes_indices = prune_inplane_conv_bn(origin_block.conv3, origin_block.bn3,
                               next_inplanes, next_inplanes_indices,  
                               kernel_size=1, bias=False)
        self.conv3 = conv3
        self.bn3 = bn3
        
        
        self.next_inplanes = next_inplanes
        self.next_inplanes_indices = next_inplanes_indices
        
        
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

        
    def forward(self, x):
        residual = x
        #print('block-0:', x.size())

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        #print('block-1:', out.size())

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        #print('block-2:', out.size())

        out = self.conv3(out)
        out = self.bn3(out)
        #print('block-3:', out.size())

        if self.downsample is not None:
            residual = self.downsample(x)
            #print('residual:', residual.size())

        out += residual
        out = self.relu(out)

        return out

In [7]:
def make_downsample(origin_downsample, conv3, inplanes, inplanes_indices, stride):
    indices_zero, indices_nonzero = check_channel(conv3.weight.detach())
    print(len(indices_zero), len(indices_nonzero))
    n_outplanes = len(indices_nonzero)
    
    
    conv1 = origin_downsample[0]
    bn1 = origin_downsample[1]
    
    n_conv1 = nn.Conv2d(inplanes, n_outplanes,
                        kernel_size=1, stride=stride, bias=False)
                
            
    #     print('-------------\n conv1')
    state_dict = {}
    for k in conv1.state_dict().keys():

        vals = conv1.state_dict()[k]
#         print('param: ', k, type(vals), vals.size())
        state_dict[k] = torch.index_select(vals, 0, indices_nonzero)
        if inplanes_indices != None:
            state_dict[k] = torch.index_select(state_dict[k], 1, inplanes_indices)
    n_conv1.load_state_dict(state_dict)


#     print('-------------\n bn1')
    state_dict = {}
    
    n_bn1 = nn.BatchNorm2d(n_outplanes)
    for k in bn1.state_dict().keys():
        vals = bn1.state_dict()[k]
#         print('param: ', k, type(vals), vals.size())
        if(len(vals.size()) > 0):
            state_dict[k] = torch.index_select(vals, 0, indices_nonzero)
        else:
            state_dict[k] = vals

    n_bn1.load_state_dict(state_dict)
    
    n_downsample = nn.Sequential(n_conv1, n_bn1)
    return n_downsample


def make_normal_downsample(origin_downsample, inplanes, inplanes_indices, stride):
    conv1 = origin_downsample[0]
    bn1 = origin_downsample[1]
    
    n_outplanes = conv1.weight.size()[0]
    
    n_conv1 = nn.Conv2d(inplanes, n_outplanes,
                        kernel_size=1, stride=stride, bias=False)
                
    state_dict = {}
    for k in conv1.state_dict().keys():
        vals = conv1.state_dict()[k]
        if inplanes_indices != None:
            state_dict[k] = torch.index_select(vals, 1, inplanes_indices)
        else:
            state_dict[k] = vals
    n_conv1.load_state_dict(state_dict)

    n_bn1 = bn1
    
    n_downsample = nn.Sequential(n_conv1, n_bn1)
    return n_downsample
            




class PruneResNet101(nn.Module):

    def __init__(self, origin_model, layers=[3,4,23,3], num_classes=1000):
        super(PruneResNet101, self).__init__()
        
        
        conv1, bn1, next_inplanes, next_inplanes_indices = prune_conv_bn(origin_model.conv1, origin_model.bn1, 
                                   inplanes=3, inplanes_indices=None, 
                                   kernel_size=7, stride=2, padding=3, bias=False)
        
        self.conv1 = conv1
        self.bn1 = bn1

#         self.conv1 = origin_model.conv1
#         self.bn1 = origin_model.bn1
#         next_inplanes = 64
#         next_inplanes_indices = None
        
        
        
        
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        
        
        
        self.layer1, next_inplanes, next_inplanes_indices = self._make_layer(
                                    origin_model.layer1,
                                    next_inplanes, next_inplanes_indices, 
                                    layers[0])
        print('origin_layer1::')
        print(origin_model.layer1)
        print('======')
        print('layer1::')
        print(self.layer1)
        print('---\n\n')
        
        self.layer2, next_inplanes, next_inplanes_indices = self._make_layer(
                                    origin_model.layer2,
                                    next_inplanes, next_inplanes_indices,  
                                    layers[1], stride=2)
        print('origin_layer2::')
        print(origin_model.layer2)
        print('======')
        print('layer2::')
        print(self.layer2)
        print('---\n\n')
        
        self.layer3, next_inplanes, next_inplanes_indices = self._make_layer(
                                    origin_model.layer3,
                                    next_inplanes, next_inplanes_indices,  
                                    layers[2], stride=2)
        print('origin_layer3::')
        print(origin_model.layer3)
        print('======')
        print('layer3::')
        print(self.layer3)
        print('---\n\n')
        
        self.layer4, next_inplanes, next_inplanes_indices = self._make_layer(
                                    origin_model.layer4,
                                    next_inplanes, next_inplanes_indices,  
                                    layers[3], stride=2)
        print('origin_layer4::')
        print(origin_model.layer4)
        print('======')
        print('layer4::')
        print(self.layer4)
        print('---\n\n')
        
        self.avgpool = nn.AvgPool2d(7, stride=1)
        
        
        fc = nn.Linear(next_inplanes, num_classes)
        o_fc = origin_model.fc
        state_dict = {}
        for k in o_fc.state_dict().keys():
            vals = o_fc.state_dict()[k]
            if(len(vals.size()) > 1):
                state_dict[k] = torch.index_select(vals, 1, next_inplanes_indices)
            else:
                state_dict[k] = vals
                
        fc.load_state_dict(state_dict)
        
        
        self.fc = fc

        

    def _make_layer(self, origin_layer, inplanes, inplanes_indices, blocks, stride=1):
        print('blocks: ', blocks)
        layers = []
        
        block0 = origin_layer[0]
        downsample = make_downsample(block0.downsample, block0.conv3, inplanes, inplanes_indices, stride)
        
        
        # origin_block, inplanes, inplanes_indices, stride=1, downsample=None
        new_block0 = PrunedBottleneck(block0, inplanes, inplanes_indices, stride, downsample)
        inplanes = new_block0.next_inplanes
        inplanes_indices = new_block0.next_inplanes_indices
        
        layers.append(new_block0)
        
        
        for i in range(1, blocks):
            blocki = origin_layer[i]
            new_blocki = PrunedBottleneck(blocki, inplanes, inplanes_indices, downsample=None)
            inplanes = new_blocki.next_inplanes
            inplanes_indices = new_blocki.next_inplanes_indices
            layers.append(new_blocki)

        return nn.Sequential(*layers), inplanes, inplanes_indices
    
    

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        
        
#         print('0: ',x.size())

        x = self.layer1(x)
#         print('1: ',x.size())
        
        x = self.layer2(x)
#         print('2: ',x.size())
        
        x = self.layer3(x)
#         print('3: ',x.size())
        
        x = self.layer4(x)
#         print('4: ',x.size())

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

In [8]:

class CloneResNet101(nn.Module):

    def __init__(self, origin_model, layers=[3,4,23,3], num_classes=1000):
        super(CloneResNet101, self).__init__()
        
        
#         conv1, bn1, next_inplanes, next_inplanes_indices = prune_conv_bn(origin_model.conv1, origin_model.bn1, 
#                                    inplanes=3, inplanes_indices=None, 
#                                    kernel_size=7, stride=2, padding=3, bias=False)
        
#         self.conv1 = conv1
#         self.bn1 = bn1

        self.conv1 = origin_model.conv1
        self.bn1 = origin_model.bn1
        next_inplanes = 64
        next_inplanes_indices = None
        
        
        
        
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        
        
        
        self.layer1, next_inplanes, next_inplanes_indices = self._make_normal_layer(
                                    origin_model.layer1,
                                    next_inplanes, next_inplanes_indices, 
                                    layers[0])
#         print('origin_layer1::')
#         print(origin_model.layer1)
        print('======')
        print('layer1::')
        print(self.layer1)
        print('---\n\n')
        
        self.layer2, next_inplanes, next_inplanes_indices = self._make_normal_layer(
                                    origin_model.layer2,
                                    next_inplanes, next_inplanes_indices,  
                                    layers[1], stride=2)
#         print('origin_layer2::')
#         print(origin_model.layer2)
        print('======')
        print('layer2::')
        print(self.layer2)
        print('---\n\n')
        
        self.layer3, next_inplanes, next_inplanes_indices = self._make_mix_layer(
                                    origin_model.layer3,
                                    next_inplanes, next_inplanes_indices,  
                                    layers[2], stride=2)
#         print('origin_layer3::')
#         print(origin_model.layer3)
        print('======')
        print('layer3::')
        print(self.layer3)
        print('---\n\n')
        
        self.layer4, next_inplanes, next_inplanes_indices = self._make_normal_layer(
                                    origin_model.layer4,
                                    next_inplanes, next_inplanes_indices,  
                                    layers[3], stride=2)
        print('origin_layer4::')
        print(origin_model.layer4)
        print('======')
        print('layer4::')
        print(self.layer4)
        print('---\n\n')
        
        self.avgpool = nn.AvgPool2d(7, stride=1)
        
        
#         fc = nn.Linear(next_inplanes, num_classes)
#         o_fc = origin_model.fc
#         state_dict = {}
#         for k in o_fc.state_dict().keys():
#             vals = o_fc.state_dict()[k]
#             if(len(vals.size()) > 1):
#                 state_dict[k] = torch.index_select(vals, 1, next_inplanes_indices)
#             else:
#                 state_dict[k] = vals
                
#         fc.load_state_dict(state_dict)
        
        
        self.fc = origin_model.fc

        

    def _make_layer(self, origin_layer, inplanes, inplanes_indices, blocks, stride=1):
        print('blocks: ', blocks)
        layers = []
        
        block0 = origin_layer[0]
        downsample = make_downsample(block0.downsample, block0.conv3, inplanes, inplanes_indices, stride)
        
        
        # origin_block, inplanes, inplanes_indices, stride=1, downsample=None
        new_block0 = PrunedBottleneck(block0, inplanes, inplanes_indices, stride, downsample)
        inplanes = new_block0.next_inplanes
        inplanes_indices = new_block0.next_inplanes_indices
        
        layers.append(new_block0)
        
        
        for i in range(1, blocks):
            blocki = origin_layer[i]
            new_blocki = PrunedBottleneck(blocki, inplanes, inplanes_indices, downsample=None)
            inplanes = new_blocki.next_inplanes
            inplanes_indices = new_blocki.next_inplanes_indices
            layers.append(new_blocki)

        return nn.Sequential(*layers), inplanes, inplanes_indices
    
    
    def _make_normal_layer(self, origin_layer, inplanes, inplanes_indices, blocks, stride=1):
        print('blocks: ', blocks)
        layers = []
        
        block0 = origin_layer[0]
        downsample = make_normal_downsample(block0.downsample, inplanes, inplanes_indices, stride)
        
        
        # origin_block, inplanes, inplanes_indices, stride=1, downsample=None
        new_block0 = CloneBottleneck(block0, inplanes, inplanes_indices, stride, downsample)
        inplanes = new_block0.next_inplanes
        inplanes_indices = new_block0.next_inplanes_indices
        
        layers.append(new_block0)
        
        
        for i in range(1, blocks):
            blocki = origin_layer[i]
            new_blocki = CloneBottleneck(blocki, inplanes, inplanes_indices, downsample=None)
            inplanes = new_blocki.next_inplanes
            inplanes_indices = new_blocki.next_inplanes_indices
            layers.append(new_blocki)

        return nn.Sequential(*layers), inplanes, inplanes_indices
    
    
    def _make_mix_layer(self, origin_layer, inplanes, inplanes_indices, blocks, stride=1):
        print('blocks: ', blocks)
        layers = []
        
        block0 = origin_layer[0]
        downsample = make_normal_downsample(block0.downsample, inplanes, inplanes_indices, stride)
        
        
        # origin_block, inplanes, inplanes_indices, stride=1, downsample=None
        new_block0 = PrunedBottleneck(block0, inplanes, inplanes_indices, stride, downsample)
        inplanes = new_block0.next_inplanes
        inplanes_indices = new_block0.next_inplanes_indices
        
        layers.append(new_block0)
        
        
        for i in range(1, blocks-1):
            blocki = origin_layer[i]
            new_blocki = PrunedBottleneck(blocki, inplanes, inplanes_indices, downsample=None)
            inplanes = new_blocki.next_inplanes
            inplanes_indices = new_blocki.next_inplanes_indices
            layers.append(new_blocki)
        

        return nn.Sequential(*layers), inplanes, inplanes_indices
    

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        
        
#         print('0: ',x.size())

        x = self.layer1(x)
#         print('1: ',x.size())
        
        x = self.layer2(x)
#         print('2: ',x.size())
        
        x = self.layer3(x)
#         print('3: ',x.size())
        
        x = self.layer4(x)
#         print('4: ',x.size())

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
    

In [26]:
paths = [
    'traced_resnet101_0.9.pt',
    'traced_resnet101_0.7.pt',
    'traced_resnet101_0.6.pt',
    'traced_resnet101_0.5.pt',
    'traced_resnet101_0.4.pt',
]

resumes = [
    '0810_resnet101/resnet101-rate-0.9/best.resnet101.2020-10-08-2290.pth.tar',
    '0810_resnet101/resnet101-rate-0.7/best.resnet101.2020-10-08-2702.pth.tar',
    '1210_resnet101/resnet101-rate-0.6/checkpoint.resnet101.2020-10-12-6800.pth.tar',
    '1210_resnet101/resnet101-rate-0.5/checkpoint.resnet101.2020-10-12-9905.pth.tar',
    '1210_resnet101/resnet101-rate-0.4/best.resnet101.2020-10-12-2215.pth.tar'
]
#'1210_resnet101/resnet101-rate-0.6/checkpoint.resnet101.2020-10-12-6800.pth.tar'
#'0810_resnet101/resnet101-rate-0.7/best.resnet101.2020-10-08-2702.pth.tar'
#'0810_resnet101/resnet101-rate-0.9/best.resnet101.2020-10-08-2290.pth.tar'
# '1210_resnet101/resnet101-rate-0.4/best.resnet101.2020-10-12-2215.pth.tar'
#'1210_resnet101/resnet101-rate-0.4/checkpoint.resnet101.2020-10-12-2215.pth.tar'
#'1210_resnet101/resnet101-rate-0.5/checkpoint.resnet101.2020-10-12-9905.pth.tar'
#'1210_resnet101/resnet101-rate-0.6/checkpoint.resnet101.2020-10-12-6800.pth.tar'
#'./0810_resnet101/resnet101-rate-0.7/best.resnet101.2020-10-08-2702.pth.tar'

# optionally resume from a checkpoint

for i in range(5):
    path = paths[i]
    args.resume = resumes[i]
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            state_dict = checkpoint['state_dict']
            state_dict = remove_module_dict(state_dict)
            model.load_state_dict(state_dict)
            print_log("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume), log)

    # print('evaluate: big')
    # print('big model accu', validate(val_loader, torch.nn.DataParallel(model).cuda(), criterion, log))        

    pruned_network = CloneResNet101(model.cpu())
    #small_model = torch.nn.DataParallel(pruned_network).cuda()
    # print('evaluate: small')
    # print('small model accu', validate(val_loader, small_model, criterion, log))


    # path = 'pruned_layer3_resnet101_0.6.pth'
    # torch.save(small_model, path)

    # save mobile pytorch small_model
    pruned_network.eval()
    example_inputs = torch.rand(1, 3, 224, 224).cpu()
    model_traced = torch.jit.trace(pruned_network, example_inputs = example_inputs)
    model_traced.save(path)

=> loading checkpoint '0810_resnet101/resnet101-rate-0.9/best.resnet101.2020-10-08-2290.pth.tar'
=> loaded checkpoint '0810_resnet101/resnet101-rate-0.9/best.resnet101.2020-10-08-2290.pth.tar' (epoch 97)
blocks:  3
layer1::
Sequential(
  (0): CloneBottleneck(
    (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=Tru



25 231
25 231
25 231
25 231
25 231
25 231
25 231
25 231
25 231
25 231
25 231
25 231
25 231
25 231
25 231
25 231
25 231
25 231
25 231
25 231
25 231
25 231
25 231
25 231
layer3::
Sequential(
  (0): PrunedBottleneck(
    (conv1): Conv2d(512, 231, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(231, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(231, 231, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(231, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(231, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): PrunedBottlene

=> loading checkpoint '0810_resnet101/resnet101-rate-0.7/best.resnet101.2020-10-08-2702.pth.tar'
=> loaded checkpoint '0810_resnet101/resnet101-rate-0.7/best.resnet101.2020-10-08-2702.pth.tar' (epoch 94)
blocks:  3
layer1::
Sequential(
  (0): CloneBottleneck(
    (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=Tru

origin_layer4::
Sequential(
  (0): Bottleneck(
    (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): Bottleneck(
    (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=Tru

102 154
102 154
102 154
102 154
102 154
102 154
102 154
102 154
102 154
102 154
102 154
102 154
102 154
102 154
102 154
102 154
102 154
102 154
102 154
102 154
layer3::
Sequential(
  (0): PrunedBottleneck(
    (conv1): Conv2d(512, 154, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(154, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(154, 154, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(154, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(154, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): PrunedBottleneck(
    

origin_layer4::
Sequential(
  (0): Bottleneck(
    (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): Bottleneck(
    (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=Tru

128 128
128 128
128 128
128 128
128 128
128 128
128 128
128 128
128 128
128 128
128 128
128 128
128 128
128 128
128 128
128 128
128 128
128 128
128 128
128 128
layer3::
Sequential(
  (0): PrunedBottleneck(
    (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): PrunedBottleneck(
    

origin_layer4::
Sequential(
  (0): Bottleneck(
    (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): Bottleneck(
    (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=Tru

153 103
153 103
153 103
153 103
153 103
153 103
153 103
153 103
153 103
153 103
153 103
153 103
153 103
153 103
153 103
153 103
153 103
153 103
153 103
layer3::
Sequential(
  (0): PrunedBottleneck(
    (conv1): Conv2d(512, 103, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(103, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(103, 103, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(103, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(103, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): PrunedBottleneck(
    (conv1):

origin_layer4::
Sequential(
  (0): Bottleneck(
    (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): Bottleneck(
    (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=Tru

In [12]:
path = 'pruned_layer3_resnet101_0.5.pth'

torch.save(small_model, path)
loaded_model = torch.load(path)
small_load_model = loaded_model.cuda() #torch.nn.DataParallel(loaded_model).cuda()
print('evaluate: small')
print('small model accu', validate(val_loader, small_load_model, criterion, log, is_cuda=True))

evaluate: small


  app.launch_new_instance()


Test: [0/782]	Time 2.134 (2.134)	Loss 1.0579 (1.0579)	Prec@1 82.812 (82.812)	Prec@5 90.625 (90.625)
Test: [200/782]	Time 0.122 (0.111)	Loss 1.7574 (1.5379)	Prec@1 54.688 (62.741)	Prec@5 85.938 (85.549)
Test: [400/782]	Time 0.081 (0.105)	Loss 1.6367 (1.6529)	Prec@1 59.375 (60.673)	Prec@5 79.688 (83.868)
Test: [600/782]	Time 0.087 (0.102)	Loss 1.8400 (1.7557)	Prec@1 64.062 (58.824)	Prec@5 82.812 (82.111)
 * Prec@1 57.642 Prec@5 81.014 Error@1 42.358
small model accu 57.642


In [None]:
# 0.9 76.01 78.22 160MB
# 0.7 75.154 77.682 134MB
# 0.6 72.622 75.108 123MB
# 0.5 64.732 67.412 112MB
# 0.4 61.358 65.16 103MB
