# Checkpoint 1: My Pretrained VGG

In [1]:
import torch
import torch.nn as nn
import numpy as np
from torch.autograd import Variable
import os, math
# Expose only CUDA device 0 to this process; set here, right after the imports
# and before any cell touches the GPU.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [2]:
# Layer layouts keyed by network depth (11/13/16/19).
# Each list is consumed in order by `vgg.make_layers`: an integer is the number
# of output channels of a 3x3 convolution, and 'M' inserts a 2x2 max-pooling layer.
defaultcfg = {
    11 : [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512],
    13 : [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512],
    16 : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512],
    19 : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512],
}

In [3]:
class vgg(nn.Module):
    """VGG-style convolutional classifier built from a channel configuration.

    The network is a stack of 3x3 conv + BatchNorm + ReLU groups interleaved
    with 2x2 max-pooling (`self.feature`), followed by a 2x2 average pool and a
    single linear layer (`self.classifier`).

    Args:
        dataset: 'cifar10' (10 classes) or 'cifar100' (100 classes).
        depth: key into the module-level `defaultcfg`; only used when `cfg`
            is None.
        init_weights: when True, run `_initialize_weights` after construction.
        cfg: optional layer layout. Integers are conv output channels and 'M'
            is a max-pooling layer. The last entry must be an integer because
            it is used as the classifier's input width.

    Raises:
        ValueError: if `dataset` is not one of the supported names.
    """

    # Number of output classes for each supported dataset name.
    _NUM_CLASSES = {'cifar10': 10, 'cifar100': 100}

    def __init__(self, dataset='cifar10', depth=19, init_weights=True, cfg=None):
        super(vgg, self).__init__()

        # Fail early with a clear message instead of hitting an unbound
        # `num_classes` further down.
        if dataset not in self._NUM_CLASSES:
            raise ValueError(
                "unsupported dataset {!r}; expected one of {}".format(
                    dataset, sorted(self._NUM_CLASSES)))
        num_classes = self._NUM_CLASSES[dataset]

        if cfg is None:
            cfg = defaultcfg[depth]

        self.feature = self.make_layers(cfg, True)

        # The classifier input width is the channel count of the last conv
        # layer, so the pooled feature map must be 1x1 spatially when flattened.
        self.classifier = nn.Linear(cfg[-1], num_classes)
        if init_weights:
            self._initialize_weights()

    def make_layers(self, cfg, batch_norm=True):
        """Translate `cfg` into an `nn.Sequential` feature extractor.

        Args:
            cfg: iterable of ints (conv output channels) and 'M' (max-pool).
            batch_norm: insert a BatchNorm2d after every convolution when True.

        Returns:
            nn.Sequential of the generated layers; the first conv takes 3
            input channels.
        """
        layers = []
        in_channels = 3
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                # Convolutions carry no bias term.
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False)
                if batch_norm:
                    layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
                else:
                    layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = v
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run features -> 2x2 average pool -> flatten -> linear classifier."""
        x = self.feature(x)
        x = nn.AvgPool2d(2)(x)
        x = x.view(x.size(0), -1)
        y = self.classifier(x)
        return y

    def _initialize_weights(self):
        """Initialise all conv, batch-norm and linear parameters in place.

        * Conv2d: weights ~ N(0, sqrt(2 / (k_h * k_w * out_channels))), bias 0.
        * BatchNorm2d: scale filled with 0.5, shift set to 0.
        * Linear: weights ~ N(0, 0.01), bias 0.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(0.5)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

In [4]:
# Build the depth-11 CIFAR-10 network and move it to the GPU in one step
# (`.cuda()` returns the module itself).
model = vgg(depth=11, dataset="cifar10").cuda()

# Restore the trained weights stored under the 'state_dict' key of the checkpoint.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
PATH = "./vgg11_cifar10_epoch200_acc90.pth.tar"
checkpoint = torch.load(PATH)
trained_weights = checkpoint['state_dict']
model.load_state_dict(trained_weights)
print("=> loaded checkpoint '{}'".format(PATH))

=> loaded checkpoint './vgg11_cifar10_epoch200_acc90.pth.tar'


In [5]:
# Peek at the first BatchNorm2d layer only: its repr, channel count and scale vector.
first_bn = next(
    (layer for layer in model.modules() if isinstance(layer, nn.BatchNorm2d)),
    None,
)
if first_bn is not None:
    gamma = first_bn.weight.data
    print(first_bn)
    print(gamma.shape[0])
    print(gamma)

BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
64
tensor([0.6629, 0.5578, 0.8722, 0.3798, 0.3158, 0.4844, 0.4109, 0.3328, 0.4253,
        0.6638, 0.4556, 0.3201, 0.4276, 0.8444, 0.5639, 0.4691, 0.4931, 0.5410,
        0.4419, 0.2959, 0.3999, 0.5731, 0.3124, 0.4125, 0.8451, 0.4995, 0.3665,
        0.2883, 0.5922, 0.6514, 0.4406, 0.2737, 0.3806, 0.5938, 0.6592, 0.6602,
        0.6936, 0.5784, 0.2629, 0.3013, 0.7561, 0.3959, 0.6716, 0.6181, 0.6050,
        0.7310, 0.6730, 0.1690, 0.5611, 0.2631, 0.3112, 0.2536, 0.5079, 0.9266,
        0.4094, 0.4232, 0.1899, 0.5662, 0.4807, 0.5908, 0.4720, 0.3315, 0.5610,
        0.2689], device='cuda:0')


# Checkpoint 2: Pruning

In [6]:
# Show the layer-by-layer structure of the loaded network before pruning.
print(model)

vgg(
  (feature): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace)
    (11): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU(inpl

In [7]:
# Fraction of all BatchNorm channels (counted across the whole network) to prune.
percent = 0.5

# Collect |gamma| of every BatchNorm2d layer into one flat vector. The result
# lives on the same device as the model weights, so no explicit .cuda() is needed.
bn_layers = [m for m in model.modules() if isinstance(m, nn.BatchNorm2d)]
bn = torch.cat([m.weight.data.abs().view(-1) for m in bn_layers])
total = bn.numel()

# Global threshold: the |gamma| value at the requested percentile.
# The index is clamped so that percent == 1.0 does not run past the end.
y, i = torch.sort(bn)
thre_index = min(int(total * percent), total - 1)
thre = y[thre_index]

# NOTE: this cell zeroes the pruned scales/shifts of `model` in place, and the
# threshold is derived from those same scales, so re-running it without
# reloading the checkpoint yields a different threshold.
pruned = 0      # number of channels removed so far (plain int)
cfg = []        # layout of the pruned network: kept channels per conv, 'M' per max-pool
cfg_mask = []   # one 0/1 float mask per BatchNorm2d layer, in module order
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        # Keep only channels whose |gamma| is strictly above the threshold.
        mask = m.weight.data.abs().gt(thre).float()
        remaining = int(mask.sum())
        pruned += mask.shape[0] - remaining
        m.weight.data.mul_(mask)
        m.bias.data.mul_(mask)
        cfg.append(remaining)
        cfg_mask.append(mask.clone())
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
            format(k, mask.shape[0], remaining))
    elif isinstance(m, nn.MaxPool2d):
        cfg.append('M')

layer index: 3 	 total channel: 64 	 remaining channel: 32
layer index: 7 	 total channel: 128 	 remaining channel: 82
layer index: 11 	 total channel: 256 	 remaining channel: 135
layer index: 14 	 total channel: 256 	 remaining channel: 16
layer index: 18 	 total channel: 512 	 remaining channel: 185
layer index: 21 	 total channel: 512 	 remaining channel: 240
layer index: 25 	 total channel: 512 	 remaining channel: 239
layer index: 28 	 total channel: 512 	 remaining channel: 446


# Checkpoint 3: Build the actually pruned model

In [8]:
# Show the pruned layout, then (on the next line) the keep-mask of the first BatchNorm layer.
print(cfg, cfg_mask[0], sep="\n")

[32, 'M', 82, 'M', 135, 16, 'M', 185, 240, 'M', 239, 446]
tensor([1., 1., 1., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 1., 1.,
        0., 0., 0., 1., 0., 0., 1., 1., 0., 0., 1., 1., 0., 0., 0., 1., 1., 1.,
        1., 1., 0., 0., 1., 0., 1., 1., 1., 1., 1., 0., 1., 0., 0., 0., 1., 1.,
        0., 0., 0., 1., 1., 1., 0., 0., 1., 0.], device='cuda:0')


In [9]:
# Instantiate a slimmer network whose per-layer channel counts come from the
# pruning result `cfg`; because `cfg` is given, the `depth` default is not used.
# Its weights are freshly initialised here (init_weights defaults to True).
newmodel = vgg(dataset="cifar10", cfg=cfg)
# Move it to the GPU; as the cell's last expression, the returned module is displayed.
newmodel.cuda()

vgg(
  (feature): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 82, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (5): BatchNorm2d(82, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(82, 135, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (9): BatchNorm2d(135, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace)
    (11): Conv2d(135, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (12): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU(inplace)


In [10]:
# Total number of scalar parameters in the pruned network.
num_parameters = sum(tensor.nelement() for tensor in newmodel.parameters())
num_parameters

2052596

In [11]:
def _kept_indices(mask):
    """Return the positions of the non-zero entries of a 1-D mask as a list of ints.

    Always yields a flat list (possibly with a single element), so no special
    handling of the one-channel case is required by the callers.
    """
    return np.flatnonzero(mask.cpu().numpy()).tolist()


# Walk both networks in lock-step and copy the surviving channels of `model`
# into `newmodel`. `start_mask` selects the input channels of the current
# layer and `end_mask` its output channels.
layer_id_in_cfg = 0
start_mask = torch.ones(3)          # the first conv keeps all 3 input channels
end_mask = cfg_mask[layer_id_in_cfg]
for m0, m1 in zip(model.modules(), newmodel.modules()):
    if isinstance(m0, nn.BatchNorm2d):
        idx1 = _kept_indices(end_mask)
        m1.weight.data = m0.weight.data[idx1].clone()
        m1.bias.data = m0.bias.data[idx1].clone()
        m1.running_mean = m0.running_mean[idx1].clone()
        m1.running_var = m0.running_var[idx1].clone()
        # This layer's outputs become the next layer's inputs.
        layer_id_in_cfg += 1
        start_mask = end_mask.clone()
        if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
            end_mask = cfg_mask[layer_id_in_cfg]
    elif isinstance(m0, nn.Conv2d):
        idx0 = _kept_indices(start_mask)
        idx1 = _kept_indices(end_mask)
        print('In shape: {:d}, Out shape {:d}.'.format(len(idx0), len(idx1)))
        # Select kept input channels (dim 1), then kept output filters (dim 0).
        w1 = m0.weight.data[:, idx0, :, :]
        m1.weight.data = w1[idx1, :, :, :].clone()
    elif isinstance(m0, nn.Linear):
        # Only the input side of the classifier is pruned; all outputs are kept.
        idx0 = _kept_indices(start_mask)
        m1.weight.data = m0.weight.data[:, idx0].clone()
        m1.bias.data = m0.bias.data.clone()

In shape: 3, Out shape 32.
In shape: 32, Out shape 82.
In shape: 82, Out shape 135.
In shape: 135, Out shape 16.
In shape: 16, Out shape 185.
In shape: 185, Out shape 240.
In shape: 240, Out shape 239.
In shape: 239, Out shape 446.


In [13]:
# Directory for the pruned checkpoint. It can be overridden with the
# PRUNED_MODEL_DIR environment variable so the notebook is not tied to one
# machine; the default keeps the original location.
SAVE_DIR = os.environ.get(
    "PRUNED_MODEL_DIR",
    "/home/xs-caomengqi/Sparse-Network-Training-Double-Masking/network-slimming/",
)
os.makedirs(SAVE_DIR, exist_ok=True)

# Store the pruned layout together with the weights so the network can be
# rebuilt with vgg(cfg=...) before loading the state dict.
NEW_PATH = os.path.join(SAVE_DIR, 'pruned_{}.pth.tar'.format(percent))
torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, NEW_PATH)