In [1]:
from pruner import l1normPruner
import pruner
import os
import argparse
import torch
from torchvision import datasets, transforms
from models import *
import torch.optim as optim
from os.path import join
import json

from mythop import clever_format, profile


import torch
import numpy as np
import os
import torch
import torch.nn as nn

In [2]:
from dotmap import DotMap


# python prune.py --arch MobileNetV2 --pruner l1normpruner --pruneratio 0.6

args = DotMap()

args.dataset = 'imagenet'
args.workers=8
args.batch_size = 16
args.test_batch_size = 8
args.epochs = 10

args.start_epoch = 0
args.finetunelr = 0.01

args.momentum = 0.9
args.weight_decay = 1e-4

args.resume = ''
args.no_cuda=False
args.seed=1

args.save = 'checkpoints'
args.arch = 'MobileNetV2'
args.pruner = 'l1normPruner'
args.pruneratio = 0.6
args.sr = True


args.cuda = not args.no_cuda and torch.cuda.is_available()
savepath = os.path.join(args.save, args.arch, 'sr' if args.sr else 'nosr')
args.savepath = savepath
kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}


args.data = '/home/hongky/datasets/imagenet'

In [3]:
# Data loading code
traindir = os.path.join(args.data, 'train')
valdir = os.path.join(args.data, 'val')
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

train_dataset = datasets.ImageFolder(
    traindir,
    transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, shuffle=True,
    num_workers=args.workers, pin_memory=True, sampler=None)

test_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(valdir, transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=args.test_batch_size, shuffle=False,
    num_workers=args.workers, pin_memory=True)

In [10]:
import torchvision
from models import *

torch_mobilenetv2 = torchvision.models.mobilenet.mobilenet_v2(pretrained=True)

def clone_mobilenet():
    mobilenet2 = eval('MobileNetV2')(n_class=1000, input_size=224)


    # classifier
    mobilenet2.classifier = torch_mobilenetv2.classifier 

    # features 0
    mobilenet2.features[0].convbn.conv = torch_mobilenetv2.features[0][0]
    mobilenet2.features[0].convbn.bn = torch_mobilenetv2.features[0][1]
    mobilenet2.features[0].convbn.relu = torch_mobilenetv2.features[0][2]

    # features 18
    mobilenet2.features[18].convbn.conv = torch_mobilenetv2.features[18][0]
    mobilenet2.features[18].convbn.bn = torch_mobilenetv2.features[18][1]
    mobilenet2.features[18].convbn.relu = torch_mobilenetv2.features[18][2]

    # feature 1
    mobilenet2.features[1].conv.dw_conv = torch_mobilenetv2.features[1].conv[0][0]
    mobilenet2.features[1].conv.dw_bn = torch_mobilenetv2.features[1].conv[0][1]
    mobilenet2.features[1].conv.dw_relu = torch_mobilenetv2.features[1].conv[0][2]
    mobilenet2.features[1].conv.project_conv = torch_mobilenetv2.features[1].conv[1]
    mobilenet2.features[1].conv.project_bn = torch_mobilenetv2.features[1].conv[2]



    for i in range(2, 18):
        mobilenet2.features[i].conv.expand_conv = torch_mobilenetv2.features[i].conv[0][0]
        mobilenet2.features[i].conv.expand_bn = torch_mobilenetv2.features[i].conv[0][1]
        mobilenet2.features[i].conv.expand_relu = torch_mobilenetv2.features[i].conv[0][2]

        mobilenet2.features[i].conv.dw_conv = torch_mobilenetv2.features[i].conv[1][0]
        mobilenet2.features[i].conv.dw_bn = torch_mobilenetv2.features[i].conv[1][1]
        mobilenet2.features[i].conv.dw_relu = torch_mobilenetv2.features[i].conv[1][2]

        mobilenet2.features[i].conv.project_conv = torch_mobilenetv2.features[i].conv[2]
        mobilenet2.features[i].conv.project_bn = torch_mobilenetv2.features[i].conv[3]

    return mobilenet2

model = clone_mobilenet()
newmodel = clone_mobilenet()


if args.cuda:
    model = nn.DataParallel(model).cuda()
    newmodel = nn.DataParallel(newmodel).cuda()
# if args.cuda:
#     model = model.cuda()
#     newmodel = newmodel.cuda()
model.eval()
newmodel.eval()
best_prec1 = -1
optimizer = optim.SGD(model.parameters(), lr=args.finetunelr, momentum=args.momentum, weight_decay=args.weight_decay)


In [5]:
import os
import torch
from tqdm import tqdm
import torch.nn.functional as F
import torch.optim as optim
from models import MobileNetV2, InvertedResidual, sepconv_bn, conv_bn_relu, ShuffleV2Block, Bottleneck
from pruner.Block import *



class BasePruner:
    def __init__(self, model, newmodel, testset, trainset, optimizer, args):
        self.model = model
        self.newmodel = newmodel
        self.testset = testset
        self.trainset = trainset
        self.optimizer = optimizer
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=3, threshold=1e-2)
        self.args = args
        self.blocks = []

    def prune(self):
        self.blocks = []
        for midx, (name, module) in enumerate(self.model.named_modules()):
            idx = len(self.blocks)
            if isinstance(module, InvertedResidual):
                self.blocks.append(InverRes(name, idx, idx - 1, idx + 1, list(module.state_dict().values())))
            if isinstance(module, conv_bn_relu):
                # print(module)
                # for k, v in module.state_dict().items():
                #     print(k, v.shape)
                self.blocks.append(CB(name, idx, idx - 1, idx + 1, list(module.state_dict().values())))
            if isinstance(module, nn.Linear):
                self.blocks.append(FC(name, idx, idx - 1, idx + 1, list(module.state_dict().values())))
            if isinstance(module, ShuffleV2Block):
                self.blocks.append(ShuffleLayer(name, idx, idx - 1, idx + 1, list(module.state_dict().values())))
            if isinstance(module, Bottleneck):
                self.blocks.append(ResBottle(name, idx, idx - 1, idx + 1, list(module.state_dict().values())))
        # special blocks
        # for b in self.blocks:
            # if b.layername == 'features.18':
            #     b.keepoutput = True
            #     b.bnscale = None

    def test(self, newmodel=True, ckpt=None, cal_bn=False):
        if newmodel:
            model = self.newmodel
        else:
            model = self.model
        if ckpt:
            model.load_state_dict(ckpt)
        if cal_bn:
            model.train()
            # for idx,(data, target) in enumerate(tqdm(self.trainset, total=len(self.trainset))):
            for idx, (data, target) in enumerate(self.trainset):
                #data, target = data.cuda(), target.cuda()
                target = target.cuda()
                if idx == 100:
                    break
                with torch.no_grad():
                    _ = model(data)
            # print("calibrate bn done.")
        model.eval()
        test_loss = 0
        correct = 0
        # for data, target in tqdm(self.testset, total=len(self.testset)):
        for idx, (data, target) in tqdm(self.testset, total=len(self.testset)): #enumerate(self.testset):
            #data, target = data.cuda(), target.cuda()
            target = target.cuda()
            with torch.no_grad():
                output = model(data)
            test_loss += F.cross_entropy(output, target, size_average=False).item()  # sum up batch loss
            pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()

        test_loss /= len(self.testset.dataset)
        return correct.item() / float(len(self.testset.dataset))

    def train(self):
        self.newmodel.train()
        avg_loss = 0.
        train_acc = 0.
        for batch_idx, (data, target) in tqdm(enumerate(self.trainset), total=len(self.trainset)):
            data, target = data.cuda(), target.cuda()
            self.optimizer.zero_grad()
            output = self.newmodel(data)
            loss = F.cross_entropy(output, target)
            avg_loss += loss.item()
            pred = output.data.max(1, keepdim=True)[1]
            train_acc += pred.eq(target.data.view_as(pred)).cpu().sum()
            loss.backward()
            self.optimizer.step()

    def finetune(self):
        best_prec1 = 0
        for epoch in range(1):
            self.train()
            prec1 = self.test()
            self.scheduler.step(prec1)
            lr_current = self.optimizer.param_groups[0]['lr']
            print("currnt lr:{},current prec:{}".format(lr_current, prec1))
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            if is_best:
                ckptfile = os.path.join(self.args.savepath, 'ft_model_best.pth.tar')
            else:
                ckptfile = os.path.join(self.args.savepath, 'ft_checkpoint.pth.tar')
            torch.save({
                'epoch': epoch + 1,
                'state_dict': self.newmodel.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': self.optimizer.state_dict(),
            }, ckptfile)
        return best_prec1

    def clone_model(self):
        blockidx = 0
        for name, m0 in self.newmodel.named_modules():
            if type(m0) not in [InvertedResidual,conv_bn_relu,nn.Linear,Bottleneck,ShuffleV2Block]:
                continue
            block = self.blocks[blockidx]
            curstatedict = block.statedict
            if blockidx == 0:
                inputmask = torch.arange(block.inputchannel)
            # print('name:', name, 'block.layername:', block.layername)
            assert name == block.layername
            if isinstance(block, CB):
                # conv(1weight)->bn(4weight)->relu
                assert len(curstatedict) == (1 + 4)
                block.clone2module(m0, inputmask)
                inputmask = block.prunemask
            if isinstance(block, InverRes):
                # dw->project or expand->dw->project
                assert len(curstatedict) in (10, 15)
                block.clone2module(m0, inputmask)
                inputmask = torch.arange(block.outputchannel)
            if isinstance(block, FC):
                block.clone2module(m0,inputmask)
            if isinstance(block, ResBottle):
                # dw->project or expand->dw->project
                assert len(curstatedict) in (15, 20)
                block.clone2module(m0, inputmask)
                inputmask = torch.arange(block.outputchannel)
            if isinstance(block, ShuffleLayer):
                if block.bnscale is not None:
                    block.clone2module(m0, inputmask)
                    inputmask = torch.arange((block.inputchannel + block.outputchannel) / 2)
                    if block.layername == 'features.3': # for 'features.4' stride=2, no pruning
                        inputmask = torch.arange(block.inputchannel + block.outputchannel)
                    if block.layername == 'features.15': # for 'conv_last' inputchannel=464
                        inputmask = torch.arange(block.inputchannel + block.outputchannel)
            blockidx += 1
            if blockidx > (len(self.blocks) - 1): break
        
        for name0, m0 in self.newmodel.named_modules():
            if name0 == 'first_conv.0':
                for name1, m1 in self.model.named_modules():
                    if name1 == 'first_conv.0':
                        break
                m0.weight.data = m1.weight.data
                break
        
        for name0, m0 in self.newmodel.named_modules():
            if name0 == 'first_conv.1':
                for name1, m1 in self.model.named_modules():
                    if name1 == 'first_conv.1':
                        break
                m0.weight.data = m1.weight.data
                m0.bias.data = m1.bias.data
                m0.running_mean.data = m1.running_mean.data
                m0.running_var.data = m1.running_var.data
                break

    def get_flops(self, model):
        from thop import clever_format, profile
        input = torch.randn(1, 3, 32, 32).cuda()
        flops, params = profile(model, inputs=(input,), verbose=False)
        return flops, params


In [6]:
class l1normPruner(BasePruner):
    def __init__(self, model, newmodel, testset, trainset, optimizer, args, pruneratio=0.1):
        super().__init__(model, newmodel, testset, trainset, optimizer, args)
        self.pruneratio = pruneratio

    def prune(self):
        #super().prune()
        self.blocks = []
        for midx, (name, module) in enumerate(self.model.named_modules()):
            idx = len(self.blocks)
            if isinstance(module, InvertedResidual):
                self.blocks.append(InverRes(name, idx, idx - 1, idx + 1, list(module.state_dict().values())))
            if isinstance(module, conv_bn_relu):
                # print(module)
                # for k, v in module.state_dict().items():
                #     print(k, v.shape)
                self.blocks.append(CB(name, idx, idx - 1, idx + 1, list(module.state_dict().values())))
            if isinstance(module, nn.Linear):
                self.blocks.append(FC(name, idx, idx - 1, idx + 1, list(module.state_dict().values())))
            if isinstance(module, ShuffleV2Block):
                self.blocks.append(ShuffleLayer(name, idx, idx - 1, idx + 1, list(module.state_dict().values())))
            if isinstance(module, Bottleneck):
                self.blocks.append(ResBottle(name, idx, idx - 1, idx + 1, list(module.state_dict().values())))
                
        for b in self.blocks:
            if isinstance(b, CB):
                pruneweight = torch.sum(torch.abs(b.statedict[0]), dim=(1, 2, 3))
                numkeep = int(pruneweight.shape[0] * (1 - self.pruneratio))
                _ascend = torch.argsort(pruneweight)
                _descend = torch.flip(_ascend, (0,))[:numkeep]
                mask = torch.zeros_like(pruneweight).long()
                mask[_descend] = 1
                b.prunemask = torch.where(mask == 1)[0]
            if isinstance(b, InverRes):
                if b.numlayer == 3:
                    pruneweight = torch.sum(torch.abs(b.statedict[0]), dim=(1, 2, 3))
                    numkeep = int(pruneweight.shape[0] * (1 - self.pruneratio))
                    _ascend = torch.argsort(pruneweight)
                    _descend = torch.flip(_ascend, (0,))[:numkeep]
                    mask = torch.zeros_like(pruneweight).long()
                    mask[_descend] = 1
                    b.prunemask = torch.where(mask == 1)[0]
        self.clone_model()
        print("l1 norm Pruner done")

In [11]:
def test(epoch,test_width=1.0,recal=False):
    model.eval()
    test_loss = 0
    correct = 0
    model.apply(lambda m: setattr(m, 'width_mult',test_width))
            
    model.eval()
    for data, target in tqdm(test_loader, total=len(test_loader)):
        if args.cuda:
            target = target.cuda()
        with torch.no_grad():
            output = model(data)
        test_loss += F.cross_entropy(output, target, size_average=False).item()  # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    test_loss /= len(test_loader.dataset)
    print('\nEpoch: {} Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.1f}%)\n'.format(epoch,
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return correct.item() / float(len(test_loader.dataset))

In [12]:
print(test(0))

100%|██████████| 6250/6250 [11:45<00:00,  8.86it/s]


Epoch: 0 Test set: Average loss: 1.1477, Accuracy: 35939/50000 (71.9%)

0.71878





In [7]:
if args.pruner == 'l1normPruner':
    kwargs = {'pruneratio': args.pruneratio}
elif args.pruner == 'SlimmingPruner':
    kwargs = {'pruneratio': args.pruneratio}
elif args.pruner == 'AutoSlimPruner':
    kwargs = {'prunestep': 16, 'constrain': 200e6}

pruner_object = l1normPruner(
            model=model, 
            newmodel=newmodel, 
            testset=test_loader, 
            trainset=train_loader,
            optimizer=optimizer, args=args, **kwargs)
pruner_object.prune()
##---------count op
# input = torch.randn(1, 3, 32, 32).cuda()
# flops, params = profile(model, inputs=(input,), verbose=False)
# flops, params = clever_format([flops, params], "%.3f")
# flopsnew, paramsnew = profile(newmodel, inputs=(input,), verbose=False)
# flopsnew, paramsnew = clever_format([flopsnew, paramsnew], "%.3f")
# print("flops:{}->{}, params: {}->{}".format(flops, flopsnew, params, paramsnew))


accold = pruner_object.test(newmodel=False, cal_bn=False)
print("original performance:{}".format(accold))

name: module.features.0 block.layername: module.features.0
name: module.features.1 block.layername: module.features.1
name: module.features.2 block.layername: module.features.2
name: module.features.3 block.layername: module.features.3
name: module.features.4 block.layername: module.features.4
name: module.features.5 block.layername: module.features.5
name: module.features.6 block.layername: module.features.6
name: module.features.7 block.layername: module.features.7
name: module.features.8 block.layername: module.features.8
name: module.features.9 block.layername: module.features.9
name: module.features.10 block.layername: module.features.10
name: module.features.11 block.layername: module.features.11
name: module.features.12 block.layername: module.features.12
name: module.features.13 block.layername: module.features.13
name: module.features.14 block.layername: module.features.14
name: module.features.15 block.layername: module.features.15
name: module.features.16 block.layername: mo



original performance:0.00098


In [None]:


accpruned = pruner_object.test(newmodel=True)
print("pruned performance:{}".format(accpruned))

accfinetune = pruner_object.finetune()
print("finetuned:{}".format(accfinetune))

with open(join(savepath, '{}.json'.format(args.pruneratio)), 'w') as f:
    json.dump({
        'accuracy_original': accold,
        'accuracy_pruned': accpruned,
        'accuracy_finetune': accfinetune,
    }, f)