## Attention-based Feature Pruning

#### Dataset CIFAR10

In [1]:
import torch
from torchvision import datasets, transforms

kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}
test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('./data.cifar10', train=False, download=True, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
            batch_size=256, shuffle=True, **kwargs)

Files already downloaded and verified


##### one minimum iterator

In [2]:
idx, data = next(enumerate(test_loader))
value, label = data[0].shape, data[1].shape
value, label

(torch.Size([256, 3, 32, 32]), torch.Size([256]))

##### one data, one batch of iterator

In [3]:
batch_idx, data_idx = 0, 6

data1 = data[batch_idx][data_idx].clone().unsqueeze(0)
data1.shape

torch.Size([1, 3, 32, 32])

#### Model VGG19_BN

##### pretrained ImageNet VGG19 Model

In [4]:
import torchvision.models as models

# model = models.vgg19_bn(pretrained=True)
# model.eval()
# model

##### define empty VGG19 Model

In [5]:
import math
import torch
import torch.nn as nn
from torch.autograd import Variable

__all__ = ['vgg']

defaultcfg = {
    11: [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512],
    13: [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512],
    16: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512],
    19: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512],
}


class vgg(nn.Module):
    def __init__(self, dataset='cifar10', depth=19, init_weights=True, cfg=None, batch_norm=True):
        super(vgg, self).__init__()
        if cfg is None:
            cfg = defaultcfg[depth]

        self.feature = self.make_layers(cfg, batch_norm)

        if dataset == 'cifar10':
            num_classes = 10
        elif dataset == 'cifar100':
            num_classes = 100
        self.classifier = nn.Linear(cfg[-1], num_classes)
        if init_weights:
            self._initialize_weights()

    def make_layers(self, cfg, batch_norm=False):
        layers = []
        in_channels = 3
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False)
                if batch_norm:
                    layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
                else:
                    layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = v
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.feature(x)
        x = nn.AvgPool2d(2)(x)
        x = x.view(x.size(0), -1)
        y = self.classifier(x)
        return y

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(0.5)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

model = vgg(dataset='cifar10', depth=19)
model.eval()
model

vgg(
  (feature): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilatio

##### load model weight

In [6]:
baseline_path = 'baseline_vgg19_cifar10'
sparsity_path = 'sparsity_vgg19_cifar10_s_1e-4'

model_path = sparsity_path + '/model_best.pth.tar'
print("=> loading checkpoint '{}'".format(model_path))

checkpoint = torch.load(model_path)
best_prec1 = checkpoint['best_prec1']
model.load_state_dict(checkpoint['state_dict'])
print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
          .format(model_path, checkpoint['epoch'], best_prec1))

=> loading checkpoint 'sparsity_vgg19_cifar10_s_1e-4/model_best.pth.tar'
=> loaded checkpoint 'sparsity_vgg19_cifar10_s_1e-4/model_best.pth.tar' (epoch 150) Prec1: 0.934700


##### the first conv2d converting of model

In [7]:
value1 = model.feature[0](data1)
value1, value1.shape

(tensor([[[[-1.0166e-02, -2.4929e-02, -2.4767e-02,  ..., -2.3730e-02,
            -2.2874e-02, -2.4593e-02],
           [-1.6640e-02, -2.8813e-02, -2.6605e-02,  ..., -2.3197e-02,
            -2.3410e-02, -2.3978e-02],
           [-1.5304e-02, -2.3793e-02, -1.9484e-02,  ..., -2.1954e-02,
            -2.3533e-02, -2.4242e-02],
           ...,
           [ 4.7882e-03,  1.2377e-02,  1.9665e-02,  ...,  1.7346e-02,
             1.5419e-02,  9.5587e-03],
           [ 4.1717e-03,  1.2178e-02,  2.0134e-02,  ...,  1.7871e-02,
             1.7172e-02,  1.0261e-02],
           [ 1.8059e-03,  9.4177e-03,  1.3738e-02,  ...,  1.6433e-02,
             1.3187e-02,  4.8544e-03]],
 
          [[ 1.4362e+00,  2.7267e-01,  3.2212e-01,  ...,  3.0878e-01,
             5.8389e-01,  8.0896e-01],
           [ 9.9332e-01,  3.8037e-01,  4.6200e-01,  ...,  2.9928e-01,
             7.2840e-01,  1.7610e+00],
           [ 1.1223e+00,  4.3787e-01,  2.0857e-01,  ...,  2.2637e-01,
             4.9546e-01,  1.6911e+00],


##### find BatchNorm2d & nn.ReLU converting

In [8]:
import torch.nn as nn

relu = nn.ReLU()
data_item = data1.clone()
for idx, m in enumerate(model.feature):
    data_item = m(data_item)
    if isinstance(m, nn.BatchNorm2d):
        relu_data = relu(data_item)
        print(idx, relu_data.shape)

1 torch.Size([1, 64, 32, 32])
4 torch.Size([1, 64, 32, 32])
8 torch.Size([1, 128, 16, 16])
11 torch.Size([1, 128, 16, 16])
15 torch.Size([1, 256, 8, 8])
18 torch.Size([1, 256, 8, 8])
21 torch.Size([1, 256, 8, 8])
24 torch.Size([1, 256, 8, 8])
28 torch.Size([1, 512, 4, 4])
31 torch.Size([1, 512, 4, 4])
34 torch.Size([1, 512, 4, 4])
37 torch.Size([1, 512, 4, 4])
41 torch.Size([1, 512, 2, 2])
44 torch.Size([1, 512, 2, 2])
47 torch.Size([1, 512, 2, 2])
50 torch.Size([1, 512, 2, 2])


#### Activation-based Gramma

In [9]:
def activation_based_gamma(weight_data):
    d1, d2 = weight_data.shape[0], weight_data.shape[1]
    
    # 1. A: feature map data
    A = weight_data.view(d1, d2, -1).abs()
    c, h, w = A.shape
    
    # 2. Fsum(A): sum of values along the channel direction
    FsumA = torch.zeros(h, w)
    for i in range(c):
        FsumA.add_(A[i])
        
    # 3. ||Fsum(A)||2: two norm
    FsumA_norm = torch.linalg.norm(FsumA)
    
    # 4. F(A) / ||F(A)||2: normalize weight data
    F_all = FsumA / FsumA_norm
    
    # 5. F(Aj) / ||F(Aj)||^2 & gamma = ∑ | F(A) / ||F(A)||2 - F(Aj) / ||F(Aj)||2 |
    gamma = torch.zeros(c)
    for j in range(c):
        FAj = FsumA - A[j]
        FAj_norm = torch.linalg.norm(FAj)
        Fj = FAj / FAj_norm
#         gamma[j] = (F_all - Fj).abs().sum()
        gamma[j] = torch.linalg.norm(F_all - Fj, ord=1)  # ord=1, 1 norm; ord=2, 2 norm

    return gamma

#### Prune

##### number of channels

In [10]:
num_total = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        num_total += m.weight.data.shape[0]
num_total    

5504

##### all channels' gamma

In [11]:
gamma_list = torch.zeros(num_total)
gamma_record = []

index = 0
one_batch = data1.clone()
for k, m in enumerate(model.feature):
    with torch.no_grad():
        one_batch = m(one_batch)
    if isinstance(m, nn.BatchNorm2d):
        value = one_batch.clone().squeeze(0)
        gamma = activation_based_gamma(value)
        gamma_record.append(gamma)
        size = value.shape[0]
        gamma_list[index:(index+size)] = gamma.clone()
        index += size

In [12]:
gamma_list[:100]

tensor([3.7236e-04, 3.4888e-01, 2.9324e-01, 2.6871e-04, 2.3354e-01, 3.1433e-01,
        1.9982e-04, 5.1487e-03, 3.3475e-01, 4.0978e-04, 1.5021e-04, 4.0909e-01,
        2.2650e-04, 9.9328e-03, 4.7796e-01, 1.4078e-01, 5.6256e-01, 8.6726e-04,
        1.5042e-04, 3.8610e-03, 8.0518e-01, 5.5678e-01, 1.9023e-04, 4.1322e-01,
        1.0429e-01, 7.8695e-05, 6.5324e-01, 2.1161e-01, 6.0052e-01, 5.4495e-01,
        5.1282e-01, 6.3567e-02, 3.4445e-01, 1.3365e-04, 4.4360e-01, 2.7661e-01,
        1.7330e-01, 1.3560e-01, 3.8531e-03, 4.1409e-02, 1.7064e-04, 1.8007e-04,
        5.7117e-01, 2.2723e-04, 3.8275e-04, 1.6764e-01, 4.1472e-01, 1.4114e-01,
        2.0105e-03, 5.5843e-03, 2.6392e-01, 7.3227e-01, 1.5296e-04, 1.7813e-04,
        5.5396e-01, 6.1102e-01, 1.7187e-04, 2.5514e-01, 5.1704e-01, 6.7294e-01,
        1.1153e-04, 6.5012e-01, 1.2528e-01, 1.0897e-03, 2.1091e-01, 2.9682e-01,
        3.1583e-01, 1.8697e-01, 4.5234e-01, 2.3025e-01, 5.0543e-01, 2.4633e-01,
        2.3016e-01, 3.0294e-01, 2.9265e-

In [13]:
gamma_record

[tensor([3.7236e-04, 3.4888e-01, 2.9324e-01, 2.6871e-04, 2.3354e-01, 3.1433e-01,
         1.9982e-04, 5.1487e-03, 3.3475e-01, 4.0978e-04, 1.5021e-04, 4.0909e-01,
         2.2650e-04, 9.9328e-03, 4.7796e-01, 1.4078e-01, 5.6256e-01, 8.6726e-04,
         1.5042e-04, 3.8610e-03, 8.0518e-01, 5.5678e-01, 1.9023e-04, 4.1322e-01,
         1.0429e-01, 7.8695e-05, 6.5324e-01, 2.1161e-01, 6.0052e-01, 5.4495e-01,
         5.1282e-01, 6.3567e-02, 3.4445e-01, 1.3365e-04, 4.4360e-01, 2.7661e-01,
         1.7330e-01, 1.3560e-01, 3.8531e-03, 4.1409e-02, 1.7064e-04, 1.8007e-04,
         5.7117e-01, 2.2723e-04, 3.8275e-04, 1.6764e-01, 4.1472e-01, 1.4114e-01,
         2.0105e-03, 5.5843e-03, 2.6392e-01, 7.3227e-01, 1.5296e-04, 1.7813e-04,
         5.5396e-01, 6.1102e-01, 1.7187e-04, 2.5514e-01, 5.1704e-01, 6.7294e-01,
         1.1153e-04, 6.5012e-01, 1.2528e-01, 1.0897e-03]),
 tensor([0.2109, 0.2968, 0.3158, 0.1870, 0.4523, 0.2302, 0.5054, 0.2463, 0.2302,
         0.3029, 0.2926, 0.2381, 0.2657, 0.2882, 0

##### threshold

In [14]:
pruning_rate = 0.7
y, i = torch.sort(gamma_list)
thre_idx = int(num_total * pruning_rate)
thre = y[thre_idx]

thre_idx, thre

(3852, tensor(0.0043))

In [15]:
y

tensor([0.0000, 0.0000, 0.0000,  ..., 0.6729, 0.7323, 0.8052])

##### prune 

In [18]:
num_pruned = 0
num_cfg = []
mask_cfg = []

one_batch = data1.clone()
for k, m in enumerate(model.feature):
    with torch.no_grad():
        one_batch = m(one_batch)
        if isinstance(m, nn.BatchNorm2d):
            value = one_batch.clone().squeeze(0)
            gamma = activation_based_gamma(value)
            mask = gamma.gt(thre).float()
            m.weight.data.mul_(mask)
            m.bias.data.mul_(mask)
            num_cfg.append(int(torch.sum(mask)))
            mask_cfg.append(mask.clone())
            num_pruned += mask.shape[0] - torch.sum(mask)
            print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
                      format(k, mask.shape[0], int(torch.sum(mask))))
        
        elif isinstance(m, nn.MaxPool2d):
            mask_cfg.append('M')
    
pruned_ratio = num_pruned / num_total
print('Pre-processing Successful! Pruned ratio: ', pruned_ratio)

layer index: 1 	 total channel: 64 	 remaining channel: 41
layer index: 4 	 total channel: 64 	 remaining channel: 64
layer index: 8 	 total channel: 128 	 remaining channel: 128
layer index: 11 	 total channel: 128 	 remaining channel: 128
layer index: 15 	 total channel: 256 	 remaining channel: 254
layer index: 18 	 total channel: 256 	 remaining channel: 249
layer index: 21 	 total channel: 256 	 remaining channel: 211
layer index: 24 	 total channel: 256 	 remaining channel: 178
layer index: 28 	 total channel: 512 	 remaining channel: 124
layer index: 31 	 total channel: 512 	 remaining channel: 54
layer index: 34 	 total channel: 512 	 remaining channel: 57
layer index: 37 	 total channel: 512 	 remaining channel: 49
layer index: 41 	 total channel: 512 	 remaining channel: 25
layer index: 44 	 total channel: 512 	 remaining channel: 25
layer index: 47 	 total channel: 512 	 remaining channel: 35
layer index: 50 	 total channel: 512 	 remaining channel: 24
Pre-processing Success