## Attention-based Feature Pruning

### 1. Dataset CIFAR10

In [1]:
import torch
from torchvision import datasets, transforms

kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}
test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('./data.cifar10', train=False, download=True, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
            batch_size=256, shuffle=True, **kwargs)

Files already downloaded and verified


#### 1.1 one minimum iterator

In [2]:
idx, data = next(enumerate(test_loader))
value, label = data[0].shape, data[1].shape
value, label

(torch.Size([256, 3, 32, 32]), torch.Size([256]))

#### 1.2 one data, one batch of iterator

In [3]:
batch_idx, data_idx = 0, 6

data1 = data[batch_idx][data_idx].clone().unsqueeze(0)
data1.shape

torch.Size([1, 3, 32, 32])

### 2. Model VGG19_BN

#### 2.1 pretrained ImageNet VGG19 Model

In [4]:
import torchvision.models as models

# model = models.vgg19_bn(pretrained=True)
# model.eval()
# model

#### 2.2 define empty VGG19 Model

In [5]:
import math
import torch
import torch.nn as nn
from torch.autograd import Variable

__all__ = ['vgg']

defaultcfg = {
    11: [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512],
    13: [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512],
    16: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512],
    19: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512],
}


class vgg(nn.Module):
    def __init__(self, dataset='cifar10', depth=19, init_weights=True, cfg=None, batch_norm=True):
        super(vgg, self).__init__()
        if cfg is None:
            cfg = defaultcfg[depth]

        self.feature = self.make_layers(cfg, batch_norm)

        if dataset == 'cifar10':
            num_classes = 10
        elif dataset == 'cifar100':
            num_classes = 100
        self.classifier = nn.Linear(cfg[-1], num_classes)
        if init_weights:
            self._initialize_weights()

    def make_layers(self, cfg, batch_norm=False):
        layers = []
        in_channels = 3
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False)
                if batch_norm:
                    layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
                else:
                    layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = v
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.feature(x)
        x = nn.AvgPool2d(2)(x)
        x = x.view(x.size(0), -1)
        y = self.classifier(x)
        return y4

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))  # mean, std
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(0.5)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

model = vgg(dataset='cifar10', depth=19)
model.eval()
model

vgg(
  (feature): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilatio

#### 2.3 initialize model weights

In [6]:
model._initialize_weights()
model.feature[0].weight.data, model.feature[1].weight.data, 

(tensor([[[[-0.0795,  0.0506,  0.0250],
           [ 0.0510, -0.0756,  0.0581],
           [ 0.0328,  0.0918,  0.0331]],
 
          [[ 0.0604, -0.0323, -0.0195],
           [ 0.0309,  0.0148, -0.0514],
           [ 0.0347,  0.0042,  0.0223]],
 
          [[-0.0543,  0.0148, -0.0883],
           [ 0.0066,  0.0049,  0.0587],
           [-0.0394, -0.0388,  0.0156]]],
 
 
         [[[-0.0855, -0.0400, -0.0284],
           [ 0.0105, -0.0982,  0.0690],
           [-0.0395, -0.0255,  0.0381]],
 
          [[ 0.0857,  0.0542, -0.0269],
           [ 0.0531, -0.0918,  0.0439],
           [-0.0257, -0.1221, -0.0686]],
 
          [[ 0.0343, -0.0856,  0.0079],
           [ 0.0226, -0.0285,  0.0009],
           [ 0.0505, -0.0408,  0.0466]]],
 
 
         [[[-0.0689,  0.0781, -0.0804],
           [-0.0209, -0.0767, -0.0344],
           [-0.0110, -0.0027,  0.0779]],
 
          [[ 0.0583,  0.1620, -0.1076],
           [-0.0018,  0.0851,  0.0909],
           [ 0.0471,  0.0735,  0.0083]],
 
          

#### 2.4 load weight data

In [7]:
import torch

url_path = 'D:/Project/Pycharm/network-slimming/logs/'

baseline_path = 'baseline_vgg19_cifar10'
sparsity_path = 'sparsity_vgg19_cifar10_s_1e-4'
# fine_tune_path = 'attention_fine_tune_feature_vgg19_percent_0.7'

model_path = url_path + sparsity_path + '/model_best.pth.tar'

checkpoint = torch.load(model_path)
best_prec1 = checkpoint['best_prec1']
epoch1 = checkpoint['epoch']
if checkpoint['state_dict'] is not None:
    model.load_state_dict(checkpoint['state_dict'])
    
epoch1, best_prec1

(150, tensor(0.9347))

In [8]:
model.feature[0].weight.data, model.feature[1].weight.data

(tensor([[[[-1.3343e-04, -9.5322e-04, -1.4197e-03],
           [-1.6418e-03, -2.4356e-03, -1.2043e-03],
           [-5.6899e-04,  2.8932e-04,  6.0596e-04]],
 
          [[ 4.9176e-04,  6.0720e-05,  1.3321e-04],
           [-1.1490e-03, -1.2976e-03,  5.5458e-05],
           [ 7.4775e-05,  1.2027e-03,  1.5041e-03]],
 
          [[ 5.2284e-04, -1.3844e-04, -1.8975e-05],
           [-1.2763e-03, -1.6820e-03, -6.6350e-04],
           [-8.2281e-04, -1.8182e-04,  8.5114e-05]]],
 
 
         [[[-4.7590e-01, -4.5339e-01, -1.4134e-01],
           [-3.9807e-01,  1.5892e-01,  2.9237e-01],
           [-1.8018e-01,  5.1271e-01,  5.3557e-01]],
 
          [[ 6.7640e-02, -6.4064e-02, -2.8104e-01],
           [ 1.8280e-01,  2.8546e-01, -9.9560e-02],
           [ 5.9540e-02,  2.5474e-01,  8.9730e-02]],
 
          [[ 6.2599e-01,  5.6976e-01,  1.4212e-01],
           [ 2.8122e-01,  1.0813e-01, -3.5913e-01],
           [-3.8611e-01, -5.6656e-01, -5.2455e-01]]],
 
 
         [[[-9.8605e-02,  1.7364e-01,  1

#### 2.5 the first conv2d converting of model

In [9]:
value1 = model.feature[0](data1)
value1, value1.shape

(tensor([[[[-2.7403e-03, -8.9314e-03, -1.4230e-02,  ..., -1.3744e-02,
            -1.1991e-02, -9.3612e-03],
           [-2.9904e-03, -8.8377e-03, -1.1322e-02,  ..., -1.1168e-02,
            -8.6803e-03, -8.8492e-03],
           [-4.3897e-03, -1.1649e-02, -1.4391e-02,  ..., -1.1073e-02,
            -1.2432e-02, -1.2480e-02],
           ...,
           [-2.1857e-03, -4.0777e-03, -5.6345e-03,  ..., -1.2607e-03,
            -3.9508e-04, -1.7741e-03],
           [-5.4761e-04, -2.0297e-03, -2.7413e-03,  ..., -3.2913e-04,
            -1.3327e-04, -2.2790e-03],
           [-1.5973e-03, -2.7889e-03, -3.6301e-03,  ...,  5.6488e-04,
            -9.9868e-04, -2.7095e-03]],
 
          [[-5.8018e-01, -7.4216e-01, -5.7547e-01,  ..., -8.6251e-01,
            -7.3754e-01, -2.7124e-01],
           [-8.0247e-01, -7.4270e-01, -3.3895e-01,  ..., -5.1238e-02,
            -1.4432e-01,  6.9336e-01],
           [-7.0982e-01, -4.7751e-01, -8.3634e-02,  ..., -4.4813e-01,
            -5.7465e-01,  8.1512e-01],


#### 2.6 find BatchNorm2d & nn.ReLU converting

In [10]:
import torch.nn as nn

relu = nn.ReLU()
data_item = data1.clone()
for idx, m in enumerate(model.feature):
    data_item = m(data_item)
    if isinstance(m, nn.BatchNorm2d):
        relu_data = relu(data_item)
        print(idx, relu_data.shape)

1 torch.Size([1, 64, 32, 32])
4 torch.Size([1, 64, 32, 32])
8 torch.Size([1, 128, 16, 16])
11 torch.Size([1, 128, 16, 16])
15 torch.Size([1, 256, 8, 8])
18 torch.Size([1, 256, 8, 8])
21 torch.Size([1, 256, 8, 8])
24 torch.Size([1, 256, 8, 8])
28 torch.Size([1, 512, 4, 4])
31 torch.Size([1, 512, 4, 4])
34 torch.Size([1, 512, 4, 4])
37 torch.Size([1, 512, 4, 4])
41 torch.Size([1, 512, 2, 2])
44 torch.Size([1, 512, 2, 2])
47 torch.Size([1, 512, 2, 2])
50 torch.Size([1, 512, 2, 2])


### 3. Activation-based Gramma

In [11]:
import torch

def activation_based_gamma(weight_data):
    d1, d2 = weight_data.shape[0], weight_data.shape[1]
    
    # 1. A: feature map data
    A = weight_data.view(d1, d2, -1).abs()
    c, h, w = A.shape
    
    # 2. Fsum(A): sum of values along the channel direction
    FsumA = torch.zeros(h, w)
    for i in range(c):
        FsumA.add_(A[i])
        
    # 3. ||Fsum(A)||2: two norm
    FsumA_norm = torch.linalg.norm(FsumA)
    
    # 4. F(A) / ||F(A)||2: normalize weight data
    F_all = FsumA / FsumA_norm
    
    # 5. F(Aj) / ||F(Aj)||^2 & gamma = ∑ | F(A) / ||F(A)||2 - F(Aj) / ||F(Aj)||2 |
    gamma = torch.zeros(c)
    for j in range(c):
        FAj = FsumA - A[j]
        FAj_norm = torch.linalg.norm(FAj)
        Fj = FAj / FAj_norm
#         gamma[j] = (F_all - Fj).abs().sum()
        gamma[j] = torch.linalg.norm(F_all - Fj, ord=1)  # ord=1, 1 norm; ord=2, 2 norm

    return gamma

#### 3.1 batch size activation-based gamma

In [12]:
x = torch.randn(64, 64, 32, 32)

B, C, H, W = x.shape

gamma = torch.zeros(B)
for i in range(B):
    data = x[i].clone().squeeze(0)
    gamma += activation_based_gamma(data)
gamma_mean = gamma / B
gamma, gamma_mean

(tensor([0.7794, 0.7860, 0.7864, 0.7796, 0.7759, 0.7896, 0.7815, 0.7772, 0.7903,
         0.7936, 0.7935, 0.7807, 0.7860, 0.7867, 0.7792, 0.7824, 0.7805, 0.7873,
         0.7852, 0.7960, 0.7903, 0.7794, 0.7838, 0.7943, 0.7839, 0.7840, 0.7914,
         0.7744, 0.7844, 0.7948, 0.7838, 0.7864, 0.7851, 0.7826, 0.7848, 0.7848,
         0.7861, 0.7905, 0.7797, 0.7893, 0.7908, 0.7792, 0.7948, 0.7900, 0.7826,
         0.7862, 0.7863, 0.7889, 0.7872, 0.7882, 0.7931, 0.7862, 0.7823, 0.7839,
         0.7810, 0.7920, 0.7953, 0.7903, 0.7930, 0.7866, 0.7821, 0.7675, 0.7845,
         0.7767]),
 tensor([0.0122, 0.0123, 0.0123, 0.0122, 0.0121, 0.0123, 0.0122, 0.0121, 0.0123,
         0.0124, 0.0124, 0.0122, 0.0123, 0.0123, 0.0122, 0.0122, 0.0122, 0.0123,
         0.0123, 0.0124, 0.0123, 0.0122, 0.0122, 0.0124, 0.0122, 0.0123, 0.0124,
         0.0121, 0.0123, 0.0124, 0.0122, 0.0123, 0.0123, 0.0122, 0.0123, 0.0123,
         0.0123, 0.0124, 0.0122, 0.0123, 0.0124, 0.0122, 0.0124, 0.0123, 0.0122,
         

### 4. Prune
#### 4.1 number of channels

In [13]:
num_total = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        num_total += m.weight.data.shape[0]
num_total

5504

#### 4.2 all channels' gamma

In [14]:
gamma_list = torch.zeros(num_total)
gamma_record = []

index = 0
one_batch = data1.clone()
for k, m in enumerate(model.feature):
    with torch.no_grad():
        one_batch = m(one_batch)
    if isinstance(m, nn.BatchNorm2d):
        value = one_batch.clone().squeeze(0)
        gamma = activation_based_gamma(value)
        gamma_record.append(gamma)
        size = value.shape[0]
        gamma_list[index:(index+size)] = gamma.clone()
        index += size

In [15]:
gamma_list[:100]

tensor([9.9950e-06, 1.6465e-02, 1.2412e-02, 7.1973e-06, 6.8107e-03, 6.2145e-03,
        5.3365e-06, 1.3695e-04, 2.6244e-02, 1.0902e-05, 4.1816e-06, 7.1188e-03,
        6.0964e-06, 2.6420e-04, 4.7834e-02, 3.9112e-03, 9.8807e-03, 2.3160e-05,
        4.1816e-06, 1.0276e-04, 2.1900e-02, 5.8549e-02, 5.0701e-06, 1.8416e-02,
        2.7763e-03, 2.1625e-06, 4.9738e-02, 8.7627e-03, 1.6777e-02, 2.8115e-02,
        1.8824e-02, 1.9295e-03, 1.8411e-02, 3.6769e-06, 1.9866e-02, 9.0149e-03,
        6.4201e-03, 1.0166e-02, 1.0251e-04, 1.3096e-03, 4.5672e-06, 4.8298e-06,
        1.9475e-02, 6.1225e-06, 1.0265e-05, 7.8932e-03, 1.5965e-02, 4.5957e-03,
        5.3594e-05, 1.4863e-04, 7.6543e-03, 1.9452e-02, 4.2003e-06, 4.7777e-06,
        2.1506e-02, 2.9241e-02, 4.6156e-06, 5.8925e-03, 1.2298e-02, 3.1206e-02,
        3.0585e-06, 4.2142e-02, 3.8079e-03, 2.9044e-05, 1.1129e-02, 1.5222e-02,
        8.6545e-03, 9.9184e-03, 1.6566e-02, 1.1061e-02, 3.2084e-02, 2.2644e-02,
        1.1464e-02, 1.5363e-02, 2.0099e-

In [16]:
gamma_record

[tensor([9.9950e-06, 1.6465e-02, 1.2412e-02, 7.1973e-06, 6.8107e-03, 6.2145e-03,
         5.3365e-06, 1.3695e-04, 2.6244e-02, 1.0902e-05, 4.1816e-06, 7.1188e-03,
         6.0964e-06, 2.6420e-04, 4.7834e-02, 3.9112e-03, 9.8807e-03, 2.3160e-05,
         4.1816e-06, 1.0276e-04, 2.1900e-02, 5.8549e-02, 5.0701e-06, 1.8416e-02,
         2.7763e-03, 2.1625e-06, 4.9738e-02, 8.7627e-03, 1.6777e-02, 2.8115e-02,
         1.8824e-02, 1.9295e-03, 1.8411e-02, 3.6769e-06, 1.9866e-02, 9.0149e-03,
         6.4201e-03, 1.0166e-02, 1.0251e-04, 1.3096e-03, 4.5672e-06, 4.8298e-06,
         1.9475e-02, 6.1225e-06, 1.0265e-05, 7.8932e-03, 1.5965e-02, 4.5957e-03,
         5.3594e-05, 1.4863e-04, 7.6543e-03, 1.9452e-02, 4.2003e-06, 4.7777e-06,
         2.1506e-02, 2.9241e-02, 4.6156e-06, 5.8925e-03, 1.2298e-02, 3.1206e-02,
         3.0585e-06, 4.2142e-02, 3.8079e-03, 2.9044e-05]),
 tensor([0.0111, 0.0152, 0.0087, 0.0099, 0.0166, 0.0111, 0.0321, 0.0226, 0.0115,
         0.0154, 0.0201, 0.0132, 0.0111, 0.0261, 0

#### 4.3 threshold

In [17]:
pruning_rate = 0.7
y, i = torch.sort(gamma_list)
thre_idx = int(num_total * pruning_rate)
thre = y[thre_idx]

thre_idx, thre

(3852, tensor(0.0011))

In [18]:
y

tensor([0.0000, 0.0000, 0.0000,  ..., 0.0497, 0.0580, 0.0585])

#### 4.4 prune

In [19]:
num_pruned = 0
num_cfg = []
mask_cfg = []

one_batch = data1.clone()
for k, m in enumerate(model.feature):
    with torch.no_grad():
        one_batch = m(one_batch)
        if isinstance(m, nn.BatchNorm2d):
            value = one_batch.clone().squeeze(0)
            gamma = activation_based_gamma(value)
            mask = gamma.gt(thre).float()
            m.weight.data.mul_(mask)
            m.bias.data.mul_(mask)
            num_cfg.append(int(torch.sum(mask)))
            mask_cfg.append(mask.clone())
            num_pruned += mask.shape[0] - torch.sum(mask)
            print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
                      format(k, mask.shape[0], int(torch.sum(mask))))
        
        elif isinstance(m, nn.MaxPool2d):
            mask_cfg.append('M')
    
pruned_ratio = num_pruned / num_total
print('Pre-processing Successful! Pruned ratio: ', pruned_ratio)

layer index: 1 	 total channel: 64 	 remaining channel: 38
layer index: 4 	 total channel: 64 	 remaining channel: 64
layer index: 8 	 total channel: 128 	 remaining channel: 127
layer index: 11 	 total channel: 128 	 remaining channel: 128
layer index: 15 	 total channel: 256 	 remaining channel: 252
layer index: 18 	 total channel: 256 	 remaining channel: 244
layer index: 21 	 total channel: 256 	 remaining channel: 198
layer index: 24 	 total channel: 256 	 remaining channel: 159
layer index: 28 	 total channel: 512 	 remaining channel: 129
layer index: 31 	 total channel: 512 	 remaining channel: 52
layer index: 34 	 total channel: 512 	 remaining channel: 57
layer index: 37 	 total channel: 512 	 remaining channel: 53
layer index: 41 	 total channel: 512 	 remaining channel: 28
layer index: 44 	 total channel: 512 	 remaining channel: 26
layer index: 47 	 total channel: 512 	 remaining channel: 41
layer index: 50 	 total channel: 512 	 remaining channel: 55
Pre-processing Success