## Attention-based Feature Pruning

#### Dataset CIFAR10

In [7]:
import torch
from torchvision import datasets, transforms

# The loader below shuffles, so seed torch's RNG. Without a seed, the first batch
# (and therefore the sample image `data1` used later) changes on every
# Restart & Run All, and the outputs are not reproducible.
SEED = 0
torch.manual_seed(SEED)

# Worker/pinned-memory options only matter when a CUDA device is present.
kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}

# CIFAR-10 test split, converted to tensors and normalized per channel.
# NOTE(review): the std triple (0.2023, 0.1994, 0.2010) is a widely copied value.
# The per-pixel std of CIFAR-10 is often quoted as (0.2470, 0.2435, 0.2616).
# Confirm which one matches the preprocessing the model expects.
test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('./data.cifar10', train=False, download=True, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
            batch_size=256, shuffle=True, **kwargs)

Files already downloaded and verified


In [9]:
# Take the first batch from the loader. `data` is the (images, labels) pair;
# `idx` is that batch's position in the loader, which is 0 for the first batch.
data = next(iter(test_loader))
idx = 0
value = data[0].shape  # images: (batch, channels, height, width)
label = data[1].shape  # labels: (batch,)
value, label

(torch.Size([256, 3, 32, 32]), torch.Size([256]))

In [14]:
# Copy out the first image of the batch, keeping a batch dimension of size 1,
# so it can be fed through the model on its own.
data1 = data[0][:1].clone()
data1.shape

torch.Size([1, 3, 32, 32])

#### Model VGG19_BN

In [16]:
import torchvision.models as models

# VGG19 with batch normalization and pretrained weights.
# torchvision >= 0.13 deprecated `pretrained=True` in favour of the `weights=`
# enum; for vgg19_bn the legacy flag maps to IMAGENET1K_V1. Use the enum when it
# exists and fall back to the legacy flag on older torchvision releases.
if hasattr(models, "VGG19_BN_Weights"):
    vgg19 = models.vgg19_bn(weights=models.VGG19_BN_Weights.IMAGENET1K_V1)
else:
    vgg19 = models.vgg19_bn(pretrained=True)

# Evaluation mode: the BatchNorm layers use their stored running statistics
# instead of batch statistics.
vgg19.eval()
vgg19

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256

#### Forward the Sample Image Through the First Conv Layer

In [70]:
# Apply only the first layer of the feature extractor (a Conv2d) to the sample
# image. This is pure inspection, so skip autograd bookkeeping.
with torch.no_grad():
    value1 = vgg19.features[0](data1)
value1, value1.shape

(tensor([[[[-8.5106e-08, -7.2914e-08,  7.8384e-10,  ...,  1.6261e-06,
             2.7526e-06,  2.5523e-06],
           [-1.3694e-07, -1.5476e-07, -2.0210e-07,  ...,  2.7468e-06,
             4.3231e-06,  3.9384e-06],
           [-3.3755e-07, -4.3320e-07, -5.1764e-07,  ...,  2.9127e-06,
             4.3238e-06,  3.9556e-06],
           ...,
           [-1.5394e-06, -2.5490e-06, -2.5003e-06,  ..., -1.9278e-06,
            -1.9479e-06, -1.6941e-06],
           [-1.5366e-06, -2.6200e-06, -2.6268e-06,  ..., -1.8583e-06,
            -1.9974e-06, -1.7430e-06],
           [-1.2952e-06, -2.2182e-06, -2.2789e-06,  ..., -1.3832e-06,
            -1.6083e-06, -1.3835e-06]],
 
          [[ 8.8070e-08,  1.5682e-07,  1.6829e-07,  ...,  1.2236e-06,
             1.8164e-06,  1.5073e-06],
           [-8.1696e-08, -1.6041e-07, -2.2403e-07,  ...,  1.3481e-06,
             1.8669e-06,  1.5077e-06],
           [-4.0767e-07, -5.6811e-07, -4.2556e-07,  ...,  1.1668e-06,
             1.8114e-06,  1.5591e-06],


Find the `nn.ReLU` layers in `vgg19.features` and the activation shape each one produces:

In [17]:
import torch.nn as nn

# Push the sample image through `vgg19.features` one layer at a time and print
# the index and output shape of every ReLU layer.
# The running activation gets its own name: rebinding `data1` here would replace
# the input image with the final feature map. Re-running this cell (or any later
# cell that uses `data1`) would then fail on the first Conv2d(3, 64).
with torch.no_grad():
    activation = data1
    for idx, m in enumerate(vgg19.features):
        activation = m(activation)
        if isinstance(m, nn.ReLU):
            print(idx, activation.shape)

2 torch.Size([1, 64, 32, 32])
5 torch.Size([1, 64, 32, 32])
9 torch.Size([1, 128, 16, 16])
12 torch.Size([1, 128, 16, 16])
16 torch.Size([1, 256, 8, 8])
19 torch.Size([1, 256, 8, 8])
22 torch.Size([1, 256, 8, 8])
25 torch.Size([1, 256, 8, 8])
29 torch.Size([1, 512, 4, 4])
32 torch.Size([1, 512, 4, 4])
35 torch.Size([1, 512, 4, 4])
38 torch.Size([1, 512, 4, 4])
42 torch.Size([1, 512, 2, 2])
45 torch.Size([1, 512, 2, 2])
48 torch.Size([1, 512, 2, 2])
51 torch.Size([1, 512, 2, 2])
