## Pruning and Expanding VGG Weight

### 1. Dataset CIFAR10

In [1]:
import torch
from torchvision import datasets, transforms

kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}
dataloader = torch.utils.data.DataLoader(
            datasets.CIFAR10('./data.cifar10', train=False, download=True, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
            batch_size=256, shuffle=True, **kwargs)

Files already downloaded and verified


#### 1.1 one batch dataset 

In [2]:
idx, data = next(enumerate(dataloader))
data_batch = data[0].clone()
value_shape, data_shape = data[0].shape, data[1].shape
value_shape, data_shape

(torch.Size([256, 3, 32, 32]), torch.Size([256]))

#### 1.2 one item of one batch dataset 

In [3]:
batch_idx, data_idx = 0, 6

data1 = data[batch_idx][data_idx].clone().unsqueeze(0)
data1.shape

torch.Size([1, 3, 32, 32])

### 2. VGG19 

#### 2.1 vgg config strategy

In [4]:
import math
import torch
import torch.nn as nn
from torch.autograd import Variable

__all__ = ['vgg']

defaultcfg = {
    11: [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512],
    13: [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512],
    16: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512],
    19: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512],
}


class vgg(nn.Module):
    def __init__(self, dataset='cifar10', depth=19, init_weights=True, cfg=None, batch_norm=True):
        super(vgg, self).__init__()
        if cfg is None:
            cfg = defaultcfg[depth]

        self.feature = self.make_layers(cfg, batch_norm)

        if dataset == 'cifar10':
            num_classes = 10
        elif dataset == 'cifar100':
            num_classes = 100
        self.classifier = nn.Linear(cfg[-1], num_classes)
        if init_weights:
            self._initialize_weights()

    def make_layers(self, cfg, batch_norm=False):
        layers = []
        in_channels = 3
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False)
                if batch_norm:
                    layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
                else:
                    layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = v
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.feature(x)
        x = nn.AvgPool2d(2)(x)
        x = x.view(x.size(0), -1)
        y = self.classifier(x)
        return y

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(0.5)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()## Attention-based Feature Pruning

#### 2.2 instance vgg model

In [5]:
model = vgg(dataset='cifar10', depth=19)
model.eval()
model

vgg(
  (feature): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilatio

#### 2.3 initialize model weights

In [6]:
model._initialize_weights()
model.feature[0].weight.data[:1], model.feature[1].weight.data[:5]

(tensor([[[[ 0.0054, -0.0220, -0.0080],
           [-0.0397,  0.0566,  0.0228],
           [ 0.0515,  0.0953, -0.0053]],
 
          [[ 0.0369,  0.0198, -0.0199],
           [-0.0630,  0.0354,  0.0105],
           [ 0.1169,  0.0763,  0.0212]],
 
          [[-0.0721,  0.0405, -0.0013],
           [ 0.0713,  0.0088, -0.0064],
           [ 0.0437,  0.0748,  0.0370]]]]),
 tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000]))

#### 2.4 load model weight

In [1]:
import torch
import os

url_path = 'G://lizulin/network-slimming/logs/'

baseline_vgg19_cifar10 = 'baseline_vgg19_cifar10'
baseline_vgg19_cifar100 = 'baseline_vgg19_cifar100'

sparsity_vgg19_cifar10_s_1e_4 = 'sparsity_vgg19_cifar10_s_1e_4'
sparsity_resnet_cifar10_s_1e_5 = 'sparsity_resnet_cifar10_s_1e-5'

prune_vgg19_cifar10_percent_0_7 = 'prune_vgg19_cifar10_percent_0.7'

fine_tune_expand_more_vgg19_cifar100_percent_0_5 = 'fine_tune_expand_more_vgg19_cifar100_percent_0.5'
fine_tune_expand_vgg19_cifar100_percent_0_5 = 'fine_tune_expand_vgg19_cifar100_percent_0.5'
fine_tune_expand_vgg19_percent_0_7 = 'fine_tune_expand_vgg19_percent_0.7'

model_name = 'model_best.pth.tar'
pruned_name = 'pruned.pth.tar'

model_path = os.path.join(url_path, sparsity_vgg19_cifar10_s_1e_4, model_name)
pruned_path = os.path.join(url_path, prune_vgg19_cifar10_percent_0_7, pruned_name)

model_path, pruned_path

('G://lizulin/network-slimming/logs/sparsity_vgg19_cifar10_s_1e_4\\model_best.pth.tar',
 'G://lizulin/network-slimming/logs/prune_vgg19_cifar10_percent_0.7\\pruned.pth.tar')

In [2]:
checkpoint = torch.load(pruned_path)

checkpoint.keys()

dict_keys(['cfg', 'state_dict'])

In [10]:
checkpoint['state_dict']

OrderedDict([('feature.0.weight',
              tensor([[[[-0.1815, -0.5081, -0.1073],
                        [ 0.2028, -0.3799, -0.2350],
                        [ 0.4627,  0.6266,  0.4859]],
              
                       [[ 0.0844, -0.1016,  0.0830],
                        [ 0.1243, -0.5017, -0.4353],
                        [ 0.1423,  0.0837,  0.0164]],
              
                       [[ 0.2333,  0.2832,  0.3199],
                        [ 0.2092, -0.3065, -0.3616],
                        [ 0.0640, -0.0864, -0.1183]]],
              
              
                      [[[-0.1293, -0.2136, -0.1245],
                        [-0.2128, -0.3252, -0.2032],
                        [-0.1767, -0.2778, -0.1450]],
              
                       [[ 0.0395,  0.0216,  0.0500],
                        [ 0.0312,  0.0037,  0.0362],
                        [ 0.0197, -0.0210,  0.0373]],
              
                       [[ 0.1276,  0.1911,  0.1371],
                      

In [None]:
epoch1 = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1']
model.load_state_dict(checkpoint['state_dict'])

epoch1, best_prec1

In [9]:
model.feature[0].weight.data[:1], model.feature[1].weight.data[:5]

(tensor([[[[-1.3343e-04, -9.5322e-04, -1.4197e-03],
           [-1.6418e-03, -2.4356e-03, -1.2043e-03],
           [-5.6899e-04,  2.8932e-04,  6.0596e-04]],
 
          [[ 4.9176e-04,  6.0720e-05,  1.3321e-04],
           [-1.1490e-03, -1.2976e-03,  5.5458e-05],
           [ 7.4775e-05,  1.2027e-03,  1.5041e-03]],
 
          [[ 5.2284e-04, -1.3844e-04, -1.8975e-05],
           [-1.2763e-03, -1.6820e-03, -6.6350e-04],
           [-8.2281e-04, -1.8182e-04,  8.5114e-05]]]]),
 tensor([-1.9593e-07,  2.1987e-01,  3.1219e-01, -6.0219e-08,  2.6633e-01]))

### 3. Prune Weight Alogrithm
    basic attention-based gamma

In [39]:
import torch

# 裁剪 Feature Map，包含 batch size，沿着该方向求平均值
def activation_based_gamma_batch(weight_data):
    d0, d1, d2 = weight_data.shape[0], weight_data.shape[1], weight_data.shape[2]
    A = weight_data.view(d0, d1, d2, -1)
#     return activation_based_gamma(torch.sum(A, dim=0))  # 返回所有batch的和
    return activation_based_gamma(torch.mean(A, dim=0))  # 返回所有batch的平均值


# 裁剪 Conv2d Weight ，不需要 batch size
def activation_based_gamma(weight_data):
    d1, d2 = weight_data.shape[0], weight_data.shape[1]
    
    # 1. A: feature map data
    A = weight_data.view(d1, d2, -1).abs()
    c, h, w = A.shape
    
    # 2. Fsum(A): sum of values along the channel direction
    FsumA = torch.sum(A, dim=0)
        
    # 3. ||Fsum(A)||2: two norm
    FsumA_norm = torch.linalg.norm(FsumA)
    
    # 4. F(A) / ||F(A)||2: normalize weight data
    F_all = FsumA / FsumA_norm
    
    # 5. F(Aj) / ||F(Aj)||^2 & gamma = ∑ | F(A) / ||F(A)||2 - F(Aj) / ||F(Aj)||2 |
    gamma = torch.zeros(c)
    for j in range(c):
        FAj = FsumA - A[j]
        FAj_norm = torch.linalg.norm(FAj)
        Fj = FAj / FAj_norm
#         gamma[j] = (F_all - Fj).abs().sum()   # L1-norm
        gamma[j] = torch.linalg.norm(F_all - Fj, ord=2)  # L2-norm

    return gamma

#### 3.1 observe better algorithm according to random data
    feature map(batch_size, num_channel, feature_width, feature_height)

In [28]:
x = {
    '0': torch.randn(64, 512, 4, 4), 
    '1': torch.randn(64, 512, 8, 8),
    '2': torch.randn(64, 512, 16, 16),
    
    '3': torch.randn(64, 256, 4, 4),
    '4': torch.randn(64, 256, 8, 8),
    '5': torch.randn(64, 256, 16, 16),
    
    '6': torch.randn(64, 128, 4, 4),
    '7': torch.randn(64, 128, 8, 8),
    '8': torch.randn(64, 128, 16, 16)
}

y = {
    '0': torch.randn(64, 512, 4, 4), 
    '3': torch.randn(64, 256, 4, 4),
    '6': torch.randn(64, 128, 4, 4),
    
    '1': torch.randn(64, 512, 8, 8),
    '4': torch.randn(64, 256, 8, 8),
    '7': torch.randn(64, 128, 8, 8),
    
    '2': torch.randn(64, 512, 16, 16),
    '5': torch.randn(64, 256, 16, 16),
    '8': torch.randn(64, 128, 16, 16)
}

##### 3.1.1 L1-norm 

$$
\gamma = \sum \left| \frac{F(A)}{\|F(A)\|_2} - \frac{F(A_j)}{\|F(A_j)\|_2} \right|
$$

    只要 Feature map 的`(feature_width, feature_height)`一致，那么`num_channel`的`sum`就是一样的，即在固定`(feature_width, feature_height)`的情况下，`num_channel`越大，每个channel占比的影响就越大。
    


In [38]:
print('gamma[j] = (F_all - Fj).abs().sum()')
for key in x:
    based_gamma = activation_based_gamma_batch(x[key])
    print(x[key].shape, torch.sum(based_gamma), torch.mean(based_gamma))

gamma[j] = (F_all - Fj).abs().sum()
torch.Size([64, 512, 4, 4]) tensor(2.3386) tensor(0.0046)
torch.Size([64, 512, 8, 8]) tensor(4.8243) tensor(0.0094)
torch.Size([64, 512, 16, 16]) tensor(9.6430) tensor(0.0188)
torch.Size([64, 256, 4, 4]) tensor(2.3424) tensor(0.0092)
torch.Size([64, 256, 8, 8]) tensor(4.8285) tensor(0.0189)
torch.Size([64, 256, 16, 16]) tensor(9.6434) tensor(0.0377)
torch.Size([64, 128, 4, 4]) tensor(2.3575) tensor(0.0184)
torch.Size([64, 128, 8, 8]) tensor(4.7923) tensor(0.0374)
torch.Size([64, 128, 16, 16]) tensor(9.6803) tensor(0.0756)


##### 3.1.2 L2-norm

$$
\gamma = \left\| \frac{F(A)}{\|F(A)\|_2} - \frac{F(A_j)}{\|F(A_j)\|_2} \right\|_2
$$

    与L1-norm一样，只要`(feature_width, feature_height)`一致，那么`num_channel`的`sum`就是一样的。但是L2-norm会因为`num_channel`越大，每个channel占比的影响越小。

In [43]:
print('gamma[j] = torch.linalg.norm(F_all - Fj, ord=2)')
for key in x:
    based_gamma = activation_based_gamma_batch(x[key])
    print(list(x[key].shape), torch.sum(based_gamma), torch.mean(based_gamma))

gamma[j] = torch.linalg.norm(F_all - Fj, ord=2)
[64, 512, 4, 4] tensor(0.5762) tensor(0.0011)
[64, 512, 8, 8] tensor(0.4688) tensor(0.0009)
[64, 512, 16, 16] tensor(0.3550) tensor(0.0007)
[64, 256, 4, 4] tensor(0.5748) tensor(0.0022)
[64, 256, 8, 8] tensor(0.4703) tensor(0.0018)
[64, 256, 16, 16] tensor(0.3521) tensor(0.0014)
[64, 128, 4, 4] tensor(0.5808) tensor(0.0045)
[64, 128, 8, 8] tensor(0.4646) tensor(0.0036)
[64, 128, 16, 16] tensor(0.3548) tensor(0.0028)


##### 3.1.3 Deprecated: max(sum(abs(data)), dim=0)
    等价于 gamma[j] = torch.linalg.norm(F_all - Fj, ord=1)
    已弃用，因为既不是L1-norm，也不是L2-norm。但是有个有趣的现象是，无论`num_channel`怎么变换，总的gamma值是一致的。随着num_channel的增加，每个channel的占比影响减少。
    
$$
\gamma = \sum \left\| \frac{Q^j_S}{\|Q^j_S\|_2} - \frac{Q^j_T}{\|Q^j_T\|_2} \right\|_p
$$

In [14]:
print('gamma[j] = torch.linalg.norm(F_all - Fj, ord=1)')
for key in x:
    based_gamma = activation_based_gamma_batch(x[key])
    print(x[key].shape, torch.sum(based_gamma), torch.mean(based_gamma))

gamma[j] = torch.linalg.norm(F_all - Fj, ord=1)
torch.Size([64, 512, 4, 4]) tensor(0.5806) tensor(0.0011)
torch.Size([64, 256, 4, 4]) tensor(0.5726) tensor(0.0022)
torch.Size([64, 128, 4, 4]) tensor(0.5767) tensor(0.0045)
torch.Size([64, 512, 8, 8]) tensor(0.4651) tensor(0.0009)
torch.Size([64, 256, 8, 8]) tensor(0.4694) tensor(0.0018)
torch.Size([64, 128, 8, 8]) tensor(0.4710) tensor(0.0037)
torch.Size([64, 512, 16, 16]) tensor(0.3537) tensor(0.0007)
torch.Size([64, 256, 16, 16]) tensor(0.3520) tensor(0.0014)
torch.Size([64, 128, 16, 16]) tensor(0.3531) tensor(0.0028)


#### 3.2 observe one batch or all batches gamma

##### 3.2.1 one batch gamma 

In [15]:
# x = torch.normal(10, 5, size=(32, 512, 2, 2))
x = torch.randn(64, 128, 16, 16)

batch_size, num_channel, input_width, input_height = x.shape

gamma = torch.zeros(num_channel)
for i in range(batch_size):
    data = x[i].clone().squeeze(0)
    gamma += activation_based_gamma(data)

gamma.shape, gamma

(torch.Size([128]),
 tensor([0.1797, 0.1796, 0.1793, 0.1787, 0.1769, 0.1758, 0.1793, 0.1770, 0.1769,
         0.1754, 0.1740, 0.1781, 0.1760, 0.1754, 0.1771, 0.1798, 0.1789, 0.1761,
         0.1767, 0.1754, 0.1746, 0.1763, 0.1779, 0.1781, 0.1758, 0.1760, 0.1770,
         0.1767, 0.1759, 0.1763, 0.1778, 0.1741, 0.1752, 0.1749, 0.1767, 0.1782,
         0.1773, 0.1798, 0.1775, 0.1763, 0.1800, 0.1773, 0.1756, 0.1740, 0.1770,
         0.1760, 0.1787, 0.1769, 0.1795, 0.1793, 0.1801, 0.1786, 0.1742, 0.1767,
         0.1770, 0.1750, 0.1760, 0.1795, 0.1772, 0.1769, 0.1791, 0.1743, 0.1756,
         0.1762, 0.1774, 0.1743, 0.1791, 0.1765, 0.1787, 0.1805, 0.1740, 0.1772,
         0.1776, 0.1803, 0.1773, 0.1769, 0.1749, 0.1756, 0.1786, 0.1778, 0.1784,
         0.1789, 0.1770, 0.1782, 0.1748, 0.1759, 0.1761, 0.1763, 0.1796, 0.1774,
         0.1755, 0.1754, 0.1741, 0.1804, 0.1771, 0.1766, 0.1756, 0.1769, 0.1805,
         0.1785, 0.1798, 0.1780, 0.1772, 0.1745, 0.1796, 0.1763, 0.1780, 0.1767,
        

##### 3.2.2 all batches' gamma

In [16]:
gamma = activation_based_gamma_batch(x)
gamma.shape, gamma

(torch.Size([128]),
 tensor([0.0028, 0.0030, 0.0027, 0.0028, 0.0033, 0.0026, 0.0032, 0.0026, 0.0027,
         0.0028, 0.0025, 0.0026, 0.0027, 0.0023, 0.0032, 0.0027, 0.0029, 0.0027,
         0.0024, 0.0032, 0.0030, 0.0025, 0.0029, 0.0028, 0.0026, 0.0026, 0.0028,
         0.0027, 0.0026, 0.0026, 0.0026, 0.0029, 0.0028, 0.0024, 0.0024, 0.0028,
         0.0027, 0.0028, 0.0027, 0.0030, 0.0028, 0.0028, 0.0026, 0.0026, 0.0028,
         0.0026, 0.0027, 0.0025, 0.0029, 0.0029, 0.0032, 0.0024, 0.0028, 0.0024,
         0.0029, 0.0026, 0.0031, 0.0031, 0.0027, 0.0033, 0.0026, 0.0028, 0.0032,
         0.0028, 0.0029, 0.0030, 0.0026, 0.0030, 0.0028, 0.0025, 0.0029, 0.0029,
         0.0030, 0.0029, 0.0029, 0.0030, 0.0028, 0.0025, 0.0031, 0.0027, 0.0030,
         0.0028, 0.0028, 0.0028, 0.0028, 0.0028, 0.0029, 0.0027, 0.0025, 0.0030,
         0.0027, 0.0028, 0.0028, 0.0028, 0.0026, 0.0031, 0.0028, 0.0029, 0.0027,
         0.0029, 0.0026, 0.0026, 0.0023, 0.0029, 0.0026, 0.0024, 0.0023, 0.0027,
        

### 4. Pruning Progress

#### 4.1 number of channel & weight

In [17]:
num_channel, num_weight = 0, 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        num_channel += m.weight.data.shape[0]
    elif isinstance(m, nn.Conv2d):
        num_weight += m.weight.data.shape[0]
        
num_channel, num_weight

(5504, 5504)

#### 4.2 Pruning the Conv2d Weight

##### 4.2.1 conv2d gamma

In [18]:
gamma_weight = torch.zeros(num_weight)
gamma_weight_record = []

index = 0
for k, m in enumerate(model.feature):
    if isinstance(m, nn.Conv2d):
        gamma = activation_based_gamma(m.weight.data)
        gamma_weight_record.append(gamma)
        
        size = m.weight.data.shape[0]
        gamma_weight[index:(index+size)] = gamma.clone()
        index += size

gamma_weight[:100], gamma_weight_record[:10]

(tensor([4.5326e-05, 1.7719e-02, 8.4088e-03, 1.7871e-05, 7.4207e-03, 1.1284e-02,
         5.4985e-06, 2.1443e-03, 2.4170e-02, 1.6162e-05, 6.6668e-06, 1.2856e-02,
         6.0110e-06, 7.9992e-04, 2.8856e-02, 1.0777e-02, 7.6803e-03, 5.2634e-05,
         5.9979e-06, 3.4912e-04, 1.2163e-02, 2.0054e-02, 8.0653e-06, 2.4019e-02,
         5.8209e-03, 1.1782e-05, 2.2681e-02, 8.0056e-03, 2.7430e-02, 7.5525e-03,
         2.2810e-02, 3.2574e-03, 1.5393e-02, 5.8843e-06, 1.8252e-02, 1.9143e-02,
         4.4849e-03, 1.3133e-02, 3.9484e-04, 2.5633e-03, 5.5649e-06, 2.7661e-05,
         3.3130e-02, 1.8082e-05, 2.5268e-05, 1.3858e-02, 2.1784e-02, 2.6519e-03,
         2.0591e-04, 2.3056e-03, 4.8556e-03, 3.8576e-02, 8.4763e-06, 7.4939e-06,
         3.4897e-02, 1.5302e-02, 8.3965e-06, 2.0118e-02, 4.3423e-03, 3.8959e-02,
         9.1252e-06, 3.5502e-02, 7.8835e-03, 6.1306e-05, 9.4059e-03, 1.0069e-02,
         1.0224e-02, 1.1451e-02, 5.1413e-03, 1.2001e-02, 2.0187e-02, 1.0213e-02,
         1.2197e-02, 1.3964e

##### 4.2.2 threshold

In [19]:
pruning_rate = 0.7
y, i = torch.sort(gamma_weight)
thre_index = int(num_weight * pruning_rate)
thre = y[thre_index]

thre_index, thre

(3852, tensor(0.0015))

##### 4.2.3 pruning

In [20]:
num_pruned = 0  # 裁剪个数
cfg_num = []

for k, m in enumerate(model.feature):
    if isinstance(m, nn.Conv2d):
        gamma = activation_based_gamma(m.weight.data)
        mask = gamma.gt(thre).float()
        
        num_remain = int(torch.sum(mask))
        num_pruned += mask.shape[0] - num_remain
        
        cfg_num.append(num_remain)
        
        print('layer index: {:d}\t total channel: {:d}\tremaining channel: {:d}'.
                      format(k, mask.shape[0], num_remain))

pruned_ratio = num_pruned / num_weight
print('Pruned ratio: {:4f}\r\n'.format(pruned_ratio))

layer index: 0	 total channel: 64	remaining channel: 40
layer index: 3	 total channel: 64	remaining channel: 64
layer index: 7	 total channel: 128	remaining channel: 127
layer index: 10	 total channel: 128	remaining channel: 128
layer index: 14	 total channel: 256	remaining channel: 84
layer index: 17	 total channel: 256	remaining channel: 88
layer index: 20	 total channel: 256	remaining channel: 137
layer index: 23	 total channel: 256	remaining channel: 158
layer index: 27	 total channel: 512	remaining channel: 123
layer index: 30	 total channel: 512	remaining channel: 88
layer index: 33	 total channel: 512	remaining channel: 104
layer index: 36	 total channel: 512	remaining channel: 98
layer index: 40	 total channel: 512	remaining channel: 97
layer index: 43	 total channel: 512	remaining channel: 105
layer index: 46	 total channel: 512	remaining channel: 119
layer index: 49	 total channel: 512	remaining channel: 91
Pruned ratio: 0.700036



#### 4.3 Pruning the Feature Map

##### 4.3.1 one batch gamma

In [21]:
one_item = data1.clone()
gamma_feature_item = torch.zeros(num_channel)
gamma_feature_item_record = []

index = 0
for k, m in enumerate(model.feature):
    with torch.no_grad():
        one_item = m(one_item)
        
    if isinstance(m, nn.ReLU):
        value = one_item.clone().squeeze(0)
        gamma = activation_based_gamma(value)
        gamma_feature_item_record.append(gamma)
        
        size = value.shape[0]
        gamma_feature_item[index:(index+size)] = gamma.clone()
        index += size
    
gamma_feature_item_record[:10]

[tensor([0.0000, 0.0178, 0.0036, 0.0000, 0.0145, 0.0092, 0.0000, 0.0000, 0.0169,
         0.0000, 0.0000, 0.0118, 0.0000, 0.0000, 0.0200, 0.0063, 0.0056, 0.0000,
         0.0000, 0.0000, 0.0067, 0.0158, 0.0000, 0.0307, 0.0000, 0.0000, 0.0290,
         0.0069, 0.0277, 0.0029, 0.0178, 0.0021, 0.0156, 0.0000, 0.0270, 0.0092,
         0.0072, 0.0059, 0.0000, 0.0011, 0.0000, 0.0000, 0.0288, 0.0000, 0.0000,
         0.0048, 0.0179, 0.0059, 0.0000, 0.0000, 0.0040, 0.0197, 0.0000, 0.0000,
         0.0233, 0.0188, 0.0000, 0.0183, 0.0013, 0.0306, 0.0000, 0.0262, 0.0039,
         0.0000]),
 tensor([0.0089, 0.0257, 0.0184, 0.0111, 0.0057, 0.0117, 0.0248, 0.0110, 0.0136,
         0.0206, 0.0160, 0.0184, 0.0122, 0.0098, 0.0102, 0.0091, 0.0141, 0.0108,
         0.0117, 0.0173, 0.0129, 0.0313, 0.0105, 0.0233, 0.0127, 0.0110, 0.0084,
         0.0105, 0.0215, 0.0218, 0.0220, 0.0232, 0.0030, 0.0163, 0.0091, 0.0256,
         0.0249, 0.0130, 0.0128, 0.0108, 0.0115, 0.0095, 0.0003, 0.0077, 0.0175,
         

In [22]:
pruning_rate = 0.7
y, i = torch.sort(gamma_feature_item)
thre_idx = int(num_channel * pruning_rate)
thresold_feature_item = y[thre_index]

thre_index, thresold_feature_item

(3852, tensor(0.0011))

In [23]:
one_item = data1.clone()
num_pruned = 0  # 裁剪个数
cfg_num = []

for k, m in enumerate(model.feature):
    with torch.no_grad():
        one_item = m(one_item)
    
    if isinstance(m, nn.ReLU):
        value = one_item.clone().squeeze(0)
        gamma = activation_based_gamma(value)
        mask = gamma.gt(thresold_feature_item).float()
        
        num_remain = int(torch.sum(mask))
        num_pruned += mask.shape[0] - num_remain
        
        cfg_num.append(num_remain)
        
        print('layer index: {:d}\t shape: {} \t total channel: {:d}\t remaining channel: {:d}'.
                      format(k, list(value.shape), mask.shape[0], num_remain))

pruned_ratio = num_pruned / num_channel
print('Pruned ratio: {:4f}\r\n'.format(pruned_ratio))

layer index: 2	 shape: [64, 32, 32] 	 total channel: 64	 remaining channel: 37
layer index: 5	 shape: [64, 32, 32] 	 total channel: 64	 remaining channel: 63
layer index: 9	 shape: [128, 16, 16] 	 total channel: 128	 remaining channel: 126
layer index: 12	 shape: [128, 16, 16] 	 total channel: 128	 remaining channel: 128
layer index: 16	 shape: [256, 8, 8] 	 total channel: 256	 remaining channel: 243
layer index: 19	 shape: [256, 8, 8] 	 total channel: 256	 remaining channel: 229
layer index: 22	 shape: [256, 8, 8] 	 total channel: 256	 remaining channel: 197
layer index: 25	 shape: [256, 8, 8] 	 total channel: 256	 remaining channel: 149
layer index: 29	 shape: [512, 4, 4] 	 total channel: 512	 remaining channel: 112
layer index: 32	 shape: [512, 4, 4] 	 total channel: 512	 remaining channel: 64
layer index: 35	 shape: [512, 4, 4] 	 total channel: 512	 remaining channel: 68
layer index: 38	 shape: [512, 4, 4] 	 total channel: 512	 remaining channel: 79
layer index: 42	 shape: [512, 2,

##### 4.3.2 one batch average gamma

In [24]:
one_batch = data_batch.clone()
gamma_feature_batch = torch.zeros(num_channel)
gamma_feature_batch_record = []

index = 0
for k, m in enumerate(model.feature):
    with torch.no_grad():
        one_batch = m(one_batch)

    if isinstance(m, nn.ReLU):
        value = one_batch.clone()
        gamma = activation_based_gamma_batch(value)
        gamma_feature_batch_record.append(gamma)
        
        size = value.shape[1]
        gamma_feature_batch[index:(index+size)] = gamma.clone()
        index += size

gamma_feature_batch[:100], gamma_feature_batch_record[:10]

(tensor([0.0000, 0.0024, 0.0010, 0.0000, 0.0011, 0.0012, 0.0000, 0.0000, 0.0026,
         0.0000, 0.0000, 0.0028, 0.0000, 0.0000, 0.0033, 0.0015, 0.0011, 0.0000,
         0.0000, 0.0000, 0.0017, 0.0034, 0.0000, 0.0048, 0.0000, 0.0000, 0.0049,
         0.0029, 0.0036, 0.0007, 0.0024, 0.0004, 0.0027, 0.0000, 0.0040, 0.0024,
         0.0016, 0.0009, 0.0000, 0.0003, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000,
         0.0014, 0.0023, 0.0012, 0.0000, 0.0000, 0.0020, 0.0033, 0.0000, 0.0000,
         0.0027, 0.0050, 0.0000, 0.0030, 0.0007, 0.0031, 0.0000, 0.0066, 0.0009,
         0.0000, 0.0016, 0.0057, 0.0037, 0.0027, 0.0015, 0.0057, 0.0069, 0.0033,
         0.0037, 0.0032, 0.0029, 0.0046, 0.0033, 0.0027, 0.0028, 0.0014, 0.0047,
         0.0064, 0.0023, 0.0021, 0.0048, 0.0072, 0.0035, 0.0051, 0.0081, 0.0019,
         0.0025, 0.0031, 0.0039, 0.0072, 0.0067, 0.0038, 0.0027, 0.0065, 0.0027,
         0.0075]),
 [tensor([0.0000, 0.0024, 0.0010, 0.0000, 0.0011, 0.0012, 0.0000, 0.0000, 0.0026,
        

In [25]:
prunning_rate = 0.7
y, i = torch.sort(gamma_feature_batch)
thre_idx = int(num_channel * pruning_rate)
threshold_feature_batch = y[thre_idx]

thre_idx, threshold_feature_batch

(3852, tensor(0.0008))

In [26]:
one_batch = data_batch.clone()
num_pruned = 0  # 裁剪个数
cfg_num = []

for k, m in enumerate(model.feature):
    with torch.no_grad():
        one_batch = m(one_batch)
    
    if isinstance(m, nn.ReLU):
        value = one_batch.clone()
        gamma = activation_based_gamma_batch(value)
        mask = gamma.gt(threshold_feature_batch).float()
        
        num_remain = int(torch.sum(mask))
        num_pruned += mask.shape[0] - num_remain
        
        cfg_num.append(num_remain)
        
        print('layer index: {:d} shape: {} total channel: {:d} \t remaining channel: {:d}'.
              format(k, list(value.shape), mask.shape[0], num_remain))

pruned_ratio = num_pruned / num_channel
print('Pruned ratio: {:4f}\r\n'.format(pruned_ratio))

layer index: 2 shape: [256, 64, 32, 32] total channel: 64 	 remaining channel: 33
layer index: 5 shape: [256, 64, 32, 32] total channel: 64 	 remaining channel: 64
layer index: 9 shape: [256, 128, 16, 16] total channel: 128 	 remaining channel: 124
layer index: 12 shape: [256, 128, 16, 16] total channel: 128 	 remaining channel: 128
layer index: 16 shape: [256, 256, 8, 8] total channel: 256 	 remaining channel: 236
layer index: 19 shape: [256, 256, 8, 8] total channel: 256 	 remaining channel: 237
layer index: 22 shape: [256, 256, 8, 8] total channel: 256 	 remaining channel: 205
layer index: 25 shape: [256, 256, 8, 8] total channel: 256 	 remaining channel: 163
layer index: 29 shape: [256, 512, 4, 4] total channel: 512 	 remaining channel: 125
layer index: 32 shape: [256, 512, 4, 4] total channel: 512 	 remaining channel: 64
layer index: 35 shape: [256, 512, 4, 4] total channel: 512 	 remaining channel: 65
layer index: 38 shape: [256, 512, 4, 4] total channel: 512 	 remaining channel: