## Attention-based Feature Pruning

### 1. Dataset CIFAR10

In [1]:
import torch
from torchvision import datasets, transforms

# Extra DataLoader options are only passed when CUDA is available.
if torch.cuda.is_available():
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    kwargs = {}

# Convert each image to a tensor, then normalise the three channels with fixed mean/std.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# CIFAR-10 test split (train=False), stored under ./data.cifar10.
test_set = datasets.CIFAR10('./data.cifar10', train=False, download=True, transform=test_transform)

# Shuffled loader yielding batches of 256 test samples.
test_loader = torch.utils.data.DataLoader(test_set, batch_size=256, shuffle=True, **kwargs)

Files already downloaded and verified


#### 1.1 Fetch one mini-batch from the iterator

In [2]:
# Take the first mini-batch from the loader; its enumerate index is 0.
# data[0] / data[1] are presumably the image and label tensors - confirm against the dataset.
idx, data = 0, next(iter(test_loader))
value = data[0].shape
label = data[1].shape
value, label

KeyboardInterrupt: 

#### 1.2 Select one sample from the batch

In [None]:
# batch_idx selects the first element of the fetched `data` pair, data_idx the entry inside it.
batch_idx, data_idx = 0, 6

# Copy that single entry and give it a leading batch dimension of size 1.
data1 = data[batch_idx][data_idx].unsqueeze(0).clone()
data1.shape

### 2. Model VGG19_BN

#### 2.1 pretrained ImageNet VGG19 Model

In [None]:
import torchvision.models as models

# model = models.vgg19_bn(pretrained=True)
# model.eval()
# model

#### 2.2 define empty VGG19 Model

In [None]:
import math
import torch
import torch.nn as nn
from torch.autograd import Variable

# Names exported by `from <module> import *`.
__all__ = ['vgg']

# Layer layouts keyed by network depth. As consumed by vgg.make_layers, each
# integer is the output-channel count of one 3x3 convolution and 'M' inserts a
# 2x2 max-pooling layer.
defaultcfg = {
    11: [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512],
    13: [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512],
    16: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512],
    19: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512],
}


class vgg(nn.Module):
    """VGG-style convolutional classifier for CIFAR-10 / CIFAR-100.

    The network is a stack of 3x3 convolutions (optionally followed by
    BatchNorm2d) with ReLU activations and 2x2 max-pooling, then a 2x2
    average pool and a single linear classifier.

    Args:
        dataset: 'cifar10' (10 classes) or 'cifar100' (100 classes).
        depth: key into ``defaultcfg``; only used when ``cfg`` is None.
        init_weights: when True, run ``_initialize_weights`` after construction.
        cfg: explicit layer layout (ints = conv output channels, 'M' = max-pool).
            Its last entry must be the channel count feeding the classifier.
        batch_norm: insert BatchNorm2d after every convolution.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """

    def __init__(self, dataset='cifar10', depth=19, init_weights=True, cfg=None, batch_norm=True):
        super(vgg, self).__init__()
        if cfg is None:
            cfg = defaultcfg[depth]

        # Resolve the class count up front so an unsupported dataset fails with
        # a clear message instead of an unbound-local error further down.
        if dataset == 'cifar10':
            num_classes = 10
        elif dataset == 'cifar100':
            num_classes = 100
        else:
            raise ValueError(
                "unsupported dataset {!r}; expected 'cifar10' or 'cifar100'".format(dataset))

        self.feature = self.make_layers(cfg, batch_norm)

        # The classifier takes cfg[-1] features, i.e. it expects the feature map
        # to be 1x1 after the average pool in forward() (true for 32x32 inputs
        # with the default layouts, which contain four max-pool stages).
        self.classifier = nn.Linear(cfg[-1], num_classes)
        if init_weights:
            self._initialize_weights()

    def make_layers(self, cfg, batch_norm=False):
        """Build the feature extractor described by ``cfg`` as an nn.Sequential."""
        layers = []
        in_channels = 3  # RGB input
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False)
                if batch_norm:
                    layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
                else:
                    layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = v
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        x = self.feature(x)
        x = nn.AvgPool2d(2)(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, features)
        y = self.classifier(x)
        return y

    def _initialize_weights(self):
        """(Re-)initialise conv, batch-norm and linear parameters in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with std = sqrt(2 / (k_h * k_w * out_channels)).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))  # mean, std
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(0.5)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

# Build the depth-19 CIFAR-10 network and switch it to evaluation mode;
# the trailing expression makes the notebook display the module structure.
model = vgg(dataset='cifar10', depth=19)
model.eval()
model

#### 2.3 initialize model weights

In [None]:
# Draw fresh initial weights, then display the weights of the first two
# feature layers (the first convolution and the layer that follows it).
model._initialize_weights()
tuple(model.feature[layer_no].weight.data for layer_no in (0, 1))

#### 2.4 load weight data

In [None]:
import torch

# Root directory of the training logs; adjust to the local machine.
url_path = 'D:/Project/Pycharm/network-slimming/logs/'

baseline_path = 'baseline_vgg19_cifar10'
sparsity_path = 'sparsity_vgg19_cifar10_s_1e-4'
# fine_tune_path = 'attention_fine_tune_feature_vgg19_percent_0.7'

model_path = url_path + sparsity_path + '/model_best.pth.tar'

# The model in this notebook lives on the CPU, so remap every stored tensor to
# the CPU; without map_location a checkpoint saved on a GPU cannot be loaded on
# a machine without CUDA.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
checkpoint = torch.load(model_path, map_location='cpu')
best_prec1 = checkpoint['best_prec1']
epoch1 = checkpoint['epoch']
if checkpoint['state_dict'] is not None:
    model.load_state_dict(checkpoint['state_dict'])
else:
    # Make it visible that the model still holds its initial weights.
    print('Warning: checkpoint has no state_dict; model weights were not loaded.')

epoch1, best_prec1

In [None]:
# Display the weights of the first two feature layers again, to compare them
# with the values shown before the checkpoint was loaded.
model.feature[0].weight.data, model.feature[1].weight.data

#### 2.5 Output of the model's first Conv2d layer

In [None]:
# Run the single-sample batch through the first feature layer only and
# display both the result and its shape.
first_layer = model.feature[0]
value1 = first_layer(data1)
value1, value1.shape

#### 2.6 Trace the BatchNorm2d + ReLU outputs through the feature layers

In [None]:
import torch.nn as nn

relu = nn.ReLU()
data_item = data1.clone()
# Inference-only trace: disable autograd (as the pruning cells in section 4 do)
# so no computation graph is recorded for the whole feature stack.
with torch.no_grad():
    for idx, m in enumerate(model.feature):
        data_item = m(data_item)
        if isinstance(m, nn.BatchNorm2d):
            # Out-of-place ReLU: report the activated shape without modifying
            # the tensor that is fed to the next layer.
            relu_data = relu(data_item)
            print(idx, relu_data.shape)

### 3. Activation-based Gamma

In [None]:
import torch

def activation_based_gamma(weight_data):
    """Score every channel by how much removing it changes the summed activation map.

    With A_j = |channel j| and F = sum_j A_j, the score of channel j is

        gamma_j = sum over all positions of | F / ||F||_2 - (F - A_j) / ||F - A_j||_2 |

    i.e. the element-wise L1 distance between the normalised full map and the
    normalised map with channel j left out.

    Args:
        weight_data: tensor indexed as (channel, height, width). Despite the
            name, the callers in this notebook pass activations. Any trailing
            dimensions beyond the second are flattened into the last axis.

    Returns:
        1-D tensor with one score per channel (length ``weight_data.shape[0]``),
        on the same device as the input.
    """
    d1, d2 = weight_data.shape[0], weight_data.shape[1]

    # 1. A: absolute feature-map values, shape (c, h, w).
    A = weight_data.reshape(d1, d2, -1).abs()
    if not A.is_floating_point():
        # Norms and divisions below need a floating-point dtype.
        A = A.to(torch.get_default_dtype())

    # A zero norm (all-zero map, or a single channel whose removal leaves
    # nothing) would turn the division into NaN; clamping keeps the normalised
    # map at exactly zero in that case and changes nothing otherwise.
    tiny = torch.finfo(A.dtype).tiny

    # 2.-4. F(A) = sum over channels, normalised by its 2-norm.
    FsumA = A.sum(dim=0)
    F_all = FsumA / torch.linalg.norm(FsumA).clamp_min(tiny)

    # 5. Leave-one-out maps F(A) - A_j for every channel at once, shape (c, h, w),
    #    each normalised by its own 2-norm.
    FAj = FsumA.unsqueeze(0) - A
    FAj_norm = FAj.flatten(1).norm(dim=1).clamp_min(tiny)
    Fj = FAj / FAj_norm.view(-1, 1, 1)

    # Element-wise L1 distance. torch.linalg.norm(..., ord=1) must not be used
    # here: on a 2-D input it is the matrix 1-norm (max absolute column sum).
    gamma = (F_all.unsqueeze(0) - Fj).abs().flatten(1).sum(dim=1)

    return gamma

#### 3.1 batch size activation-based gamma

In [None]:
x = torch.randn(64, 64, 32, 32)

B, C, H, W = x.shape

# activation_based_gamma returns one score per channel, so the accumulator
# must have length C (sizing it with B only works when B == C).
gamma = torch.zeros(C)
for sample in x:
    # Each sample is already (C, H, W); no squeeze is needed (squeeze(0) would
    # drop the channel axis when C == 1). A dedicated loop name keeps the
    # `data` batch fetched in section 1.1 intact.
    gamma += activation_based_gamma(sample)
gamma_mean = gamma / B
gamma, gamma_mean

### 4. Prune
#### 4.1 number of channels

In [None]:
# Total number of channels: each BatchNorm2d layer holds one weight per channel,
# so add up the weight-vector lengths of all such layers.
num_total = sum(
    bn.weight.data.shape[0]
    for bn in model.modules()
    if isinstance(bn, nn.BatchNorm2d)
)
num_total

#### 4.2 all channels' gamma

In [None]:
gamma_record = []                      # per-layer score tensors
gamma_list = torch.zeros(num_total)    # all scores, concatenated in layer order

index = 0
one_batch = data1.clone()
with torch.no_grad():
    for k, m in enumerate(model.feature):
        one_batch = m(one_batch)
        if not isinstance(m, nn.BatchNorm2d):
            continue

        # Copy the batch-norm output before the following in-place ReLU
        # overwrites it, and drop the batch dimension of size 1.
        value = one_batch.clone().squeeze(0)
        gamma = activation_based_gamma(value)
        size = value.shape[0]

        gamma_list[index:index + size] = gamma.clone()
        gamma_record.append(gamma)
        index += size

In [None]:
# Display the first 100 entries of the flat score vector.
gamma_list[:100]

In [None]:
# Display the list of per-layer score tensors.
gamma_record

#### 4.3 threshold

In [None]:
# Fraction of all channels whose scores lie below the threshold position.
pruning_rate = 0.7
thre_idx = int(num_total * pruning_rate)

# Sort all scores in ascending order; the threshold is the score found at
# position thre_idx of the sorted vector.
y, i = gamma_list.sort()
thre = y[thre_idx]

(thre_idx, thre)

In [None]:
# Display the sorted score vector.
y

#### 4.4 prune

In [None]:
mask_cfg = []    # per-layer keep masks, with 'M' entries marking max-pool layers
num_cfg = []     # number of channels kept in each batch-norm layer
num_pruned = 0

one_batch = data1.clone()
with torch.no_grad():
    for k, m in enumerate(model.feature):
        one_batch = m(one_batch)

        if isinstance(m, nn.MaxPool2d):
            mask_cfg.append('M')
            continue
        if not isinstance(m, nn.BatchNorm2d):
            continue

        # Score the channels of this layer's output (copied before the next
        # in-place ReLU changes it) and keep those strictly above the threshold.
        value = one_batch.clone().squeeze(0)
        gamma = activation_based_gamma(value)
        mask = gamma.gt(thre).float()
        kept = torch.sum(mask)

        # Zero the batch-norm scale and shift of every channel that is not kept.
        m.weight.data.mul_(mask)
        m.bias.data.mul_(mask)

        num_cfg.append(int(kept))
        mask_cfg.append(mask.clone())
        num_pruned += mask.shape[0] - kept
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
                  format(k, mask.shape[0], int(kept)))

pruned_ratio = num_pruned / num_total
print('Pre-processing Successful! Pruned ratio: ', pruned_ratio)