In [24]:
import torch
from models.vgg_cifar import MaskVGG,VGG
import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
import os
from lib.utils import AverageMeter,accuracy
import torch.nn as nn
import time
import copy

In [7]:

vgg = VGG('vgg16').cuda()
vgg.load_state_dict(torch.load('C:\\Users\\lenovo\\Desktop\\cacp\\amc_vgg\\checkpoints\\vgg16_cifar10.pt')['state_dict'])

<All keys matched successfully>

In [8]:
def get_split_dataset(dset_name, batch_size, n_worker, val_size, data_root='../data',
                      use_real_val=False, shuffle=True):
    '''
        split the train set into train / val for rl search
    '''
    if shuffle:
        index_sampler = SubsetRandomSampler
    else:  # every time we use the same order for the split subset
        class SubsetSequentialSampler(SubsetRandomSampler):
            def __iter__(self):
                return (self.indices[i] for i in torch.arange(len(self.indices)).int())
        index_sampler = SubsetSequentialSampler

    print('=> Preparing data: {}...'.format(dset_name))
    if dset_name == 'cifar10':
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        trainset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform_train)
        if use_real_val:  # split the actual val set
            valset = torchvision.datasets.CIFAR10(root=data_root, train=False, download=True, transform=transform_test)
            n_val = len(valset)
            assert val_size < n_val
            indices = list(range(n_val))
            np.random.shuffle(indices)
            _, val_idx = indices[val_size:], indices[:val_size]
            train_idx = list(range(len(trainset)))  # all train set for train
        else:  # split the train set
            valset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform_test)
            n_train = len(trainset)
            indices = list(range(n_train))
            # now shuffle the indices
            np.random.shuffle(indices)
            assert val_size < n_train
            train_idx, val_idx = indices[val_size:], indices[:val_size]

        train_sampler = index_sampler(train_idx)
        val_sampler = index_sampler(val_idx)

        train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False, sampler=train_sampler,
                                                   num_workers=n_worker, pin_memory=True)
        val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, sampler=val_sampler,
                                                 num_workers=n_worker, pin_memory=True)
        n_class = 10
        
    elif dset_name == 'imagenet':
        train_dir = os.path.join(data_root, 'train')
        val_dir = os.path.join(data_root, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        input_size = 224
        train_transform = transforms.Compose([
                transforms.RandomResizedCrop(input_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
        test_transform = transforms.Compose([
                transforms.Resize(int(input_size/0.875)),
                transforms.CenterCrop(input_size),
                transforms.ToTensor(),
                normalize,
            ])

        trainset = datasets.ImageFolder(train_dir, train_transform)
        if use_real_val:
            valset = datasets.ImageFolder(val_dir, test_transform)
            n_val = len(valset)
            assert val_size < n_val
            indices = list(range(n_val))
            np.random.shuffle(indices)
            _, val_idx = indices[val_size:], indices[:val_size]
            train_idx = list(range(len(trainset)))  # all trainset
        else:
            valset = datasets.ImageFolder(train_dir, test_transform)
            n_train = len(trainset)
            indices = list(range(n_train))
            np.random.shuffle(indices)
            assert val_size < n_train
            train_idx, val_idx = indices[val_size:], indices[:val_size]

        train_sampler = index_sampler(train_idx)
        val_sampler = index_sampler(val_idx)

        train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, sampler=train_sampler,
                                                   num_workers=n_worker, pin_memory=True)
        val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size, sampler=val_sampler,
                                                 num_workers=n_worker, pin_memory=True)

        n_class = 1000
    else:
        raise NotImplementedError

    return train_loader, val_loader, n_class

In [9]:
def _validate(val_loader, model, verbose=True):
        '''
        Validate the performance on validation set
        :param val_loader:
        :param model:
        :param verbose:
        :return:
        '''
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()

        criterion = nn.CrossEntropyLoss().cuda()
        # switch to evaluate mode
        model.eval()
        end = time.time()

        t1 = time.time()
        with torch.no_grad():
            for i, (input, target) in enumerate(val_loader):
                target = target.cuda(non_blocking=True)
                input_var = torch.autograd.Variable(input).cuda()
                target_var = torch.autograd.Variable(target).cuda()

                # compute output
                output = model(input_var)
                loss = criterion(output, target_var)

                # measure accuracy and record loss
                prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
                losses.update(loss.item(), input.size(0))
                top1.update(prec1.item(), input.size(0))
                top5.update(prec5.item(), input.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()
        t2 = time.time()
        if verbose:
            print('* Test loss: %.3f    top1: %.3f    top5: %.3f    time: %.3f' %
                  (losses.avg, top1.avg, top5.avg, t2 - t1))
        # if self.acc_metric == 'acc1':
        #     return top1.avg
        # elif self.acc_metric == 'acc5':
        #     return top5.avg
        # else:
        #     raise NotImplementedError

In [11]:

train_loader, val_loader, n_class = get_split_dataset('cifar10', 50,
                                                                        0, 5000,
                                                                        data_root='C:\\Users\\lenovo\\dataset\\cifar',
                                                                        use_real_val=True,
                                                                        shuffle=False)  # same sampling
_validate(val_loader,vgg)

=> Preparing data: cifar10...
Files already downloaded and verified
Files already downloaded and verified
* Test loss: 0.296    top1: 93.180    top5: 99.820    time: 4.259


In [14]:
def vgg_masked(vgg_origin,strategy):
    from models.vgg_cifar import MaskVGG
    masked_vgg = MaskVGG('vgg16',strategy)
    
    orimo_ls = list(vgg_origin.modules())
    
    for i,m in enumerate(masked_vgg.modules()):
        if type(m) in [nn.Conv2d,nn.BatchNorm2d,nn.Linear,nn.AvgPool2d,nn.MaxPool2d,nn.ReLU]:
            # type
            ty = type(m)

            # conv
            if ty == nn.Conv2d:
                m.weight.data.copy_(orimo_ls[i].weight.data)
                m.bias.data.copy_(orimo_ls[i].bias.data)
            # bn2d
            elif ty == nn.BatchNorm2d:
                m.weight.data.copy_(orimo_ls[i].weight.data)
                m.bias.data.copy_(orimo_ls[i].bias.data)
                m.running_mean.data.copy_(orimo_ls[i].running_mean.data)
                m.running_var.data.copy_(orimo_ls[i].running_var.data)
            elif ty == nn.Linear:
                # linear
                m.weight.data.copy_(orimo_ls[i].weight.data)
            else:# maxpool,avgpool,relu don't need params
                pass
    masked_vgg = masked_vgg.cuda()
    return masked_vgg

In [13]:
mask = vgg_masked(vgg,strategy=[1.0]*13)
_validate(val_loader,mask)

* Test loss: 0.296    top1: 93.180    top5: 99.800    time: 2.665


In [15]:
all_idx = []
all_ops= []
prunable_idx = []
prunable_ops = []
orgin_channel = []

In [16]:

for i,m in enumerate(vgg.modules()):
    if type(m) in [nn.Conv2d,nn.BatchNorm2d,nn.Linear,nn.AvgPool2d,nn.MaxPool2d,nn.ReLU]:
        all_idx.append(i)
        all_ops.append(m)
        if type(m) in [nn.Conv2d]:
            prunable_idx.append(i)
            prunable_ops.append(m)
            orgin_channel.append(m.out_channels)

In [17]:
def idx2idxx(idx):
    return prunable_idx[idx]
def idxx2idx(idxx):
    return prunable_idx.index(idxx)


In [18]:
def preprocess_get_mask(select,method = 'l1'):
    mask = []
    for i,a in enumerate(select):
        c = orgin_channel[i]
        d = int(c * a)
        mask_ = np.zeros(c,bool)
        weight = prunable_ops[i].weight.data.cpu().numpy()
        if method == 'l1':
            importance = np.abs(weight).sum((1, 2, 3))
            sorted_idx = np.argsort(-importance)  # sum magnitude along C_in, sort descend
            preserve_idx = sorted_idx[:d]  # to preserve index
            mask_[preserve_idx] = True
        mask.append(mask_)
    return mask

In [20]:
select = [0.5]*13

In [25]:
prunable_idx = []
prunable_ops = []
layer_type_dict = {}
org_channels = {}
conv_buffer_dict = {} # layer after the conv
all_idx = []
buffer_conv_map = {}

i=0
buffer_temp_idx = []
modules = list(vgg.modules())
n = len(modules)
while i < n :
    m = modules[i]

    if type(m) not in [nn.Conv2d,nn.BatchNorm2d,nn.Linear,nn.AvgPool2d,nn.MaxPool2d,nn.ReLU]:
        i+=1
        continue
    else:
        assert type(m) == torch.nn.modules.conv.Conv2d
        prunable_ops.append(m)
        prunable_idx.append(i)
        layer_type_dict[i] = type(m)
        org_channels[i] = (m.in_channels) 
        
        buffer_temp_idx = []
        while i != n-1:
            i+=1
            bu = modules[i]
            if type(bu) is torch.nn.modules.conv.Conv2d:
                i-=1
                break
            buffer_temp_idx.append(i)
        conv_buffer_dict[prunable_idx[-1]] = copy.deepcopy(buffer_temp_idx)
        for j in buffer_temp_idx:
            buffer_conv_map[j] = prunable_idx[-1]
    i+=1


In [56]:
def pruned_model(origin_model,pruned_model,all_mask):
    m_list = list(origin_model.modules())
    mp_list = list(pruned_model.modules())
    st = time.time()
    for idx,idxx in enumerate(prunable_idx):
        
        # replace conv first
        mask = all_mask[idx]
        weight = m_list[idxx].weight.data.cpu().numpy()
        bias = m_list[idxx].bias.data.cpu().numpy()
        
        mask_weight = None
        if idx == 0:
            mask_weight = weight[mask,:,:,:]
        else:
            input_mask = all_mask[idx-1]
            # select input
            mask_weight = weight[:,input_mask,:,:].reshape(weight.shape[0],-1,weight.shape[2],weight.shape[3])
            # select output
            mask_weight = mask_weight[mask,:,:,:].reshape(-1,mask_weight.shape[1],mask_weight.shape[2],mask_weight.shape[3])
        mask_bias = bias[mask]
        mp = mp_list[idxx]
        mp.weight.data.copy_(torch.from_numpy(mask_weight).cuda())
        mp.bias.data.copy_(torch.from_numpy(mask_bias).cuda())

        # replace other layers
        buffer = conv_buffer_dict[idxx]
        for buffer_idx in buffer:
            m = m_list[buffer_idx]
            mp = mp_list[buffer_idx]
            if type(m) == nn.BatchNorm2d:
                mp.weight.data.copy_(torch.from_numpy(m.weight.data.cpu().numpy()[mask]).cuda())
                mp.bias.data.copy_(torch.from_numpy(m.bias.data.cpu().numpy()[mask]).cuda())
                mp.running_mean.data.copy_(torch.from_numpy(m.running_mean.data.cpu().numpy()[mask]).cuda())
                mp.running_var.data.copy_(torch.from_numpy(m.running_var.data.cpu().numpy()[mask]).cuda())
            elif type(m) == nn.Linear:
                mp.weight.data.copy_(torch.from_numpy(m.weight.data.cpu().numpy()[:,mask]).cuda())
    ed = time.time()
    print(f'replace cost {ed-st}s')

In [57]:
select = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
mask = preprocess_get_mask(select)
# mask = vgg_masked(vgg,strategy=select)
from models.vgg_cifar import MaskVGG
masked_vgg = MaskVGG('vgg16',select).cuda()
_validate(val_loader,masked_vgg)
pruned_model(vgg,masked_vgg,mask)
_validate(val_loader,masked_vgg)

* Test loss: 2.303    top1: 9.860    top5: 49.980    time: 2.647
replace cost 0.19254016876220703s
* Test loss: 0.737    top1: 83.040    top5: 98.120    time: 2.541
