## Import libarary

In [12]:
import os

import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms
import numpy as np

from models import vgg

### Set hyperparameter

In [None]:
DATASET = 'cifar10'
TEST_BATCH_SIZE = 1000
CUDA = True
PRUNE_PERCENT = 0.9
WEIGHT_PATH = 'model_best.pth'
PRUNE_PATH = 'model_prune.pth' 


### Load model

In [14]:
CUDA = CUDA and torch.cuda.is_available()

model = vgg()
if CUDA:
    model.cuda()
    
if WEIGHT_PATH:
    if os.path.isfile(WEIGHT_PATH):
        checkpoint = torch.load(WEIGHT_PATH)
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        print('LOADING CHECKPOINT {} @EPOCH={}, BEST_PREC1={}'.format(WEIGHT_PATH,checkpoint['epoch'],best_prec1))

    else:
        print("NO CHECKPOINT FOUND")

print(model)

LOADING CHECKPOINT D:\VScode\EAI\Lab5\model_best.pth @EPOCH=61, BEST_PREC1=0.8730000257492065
vgg(
  (feature): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True

### Test function (observe model accuracy)

In [15]:
def test(model):
    kwargs = {'num_workers': 1, 'pin_memory': True} if CUDA else {}
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data', train=False, download=True, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])),
        batch_size=TEST_BATCH_SIZE, shuffle=True, **kwargs)
    model.eval()
    correct = 0
    with torch.no_grad():
      for data, target in test_loader:
          if CUDA:
              data, target = data.cuda(), target.cuda()
          data, target = Variable(data), Variable(target)
          output = model(data)
          pred = output.data.max(1, keepdim=True)[1]
          correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))

### Pruning
#### Calculate the absolute value of scale factor in all Batch Nromalizaiton and sort them
#### Using PRUNE_RATIO to obtain the threshold

In [16]:
total = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        total += m.weight.data.shape[0]

bn = torch.zeros(total)
index = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        size = m.weight.data.shape[0]
        bn[index:(index+size)] = m.weight.data.abs().clone()
        index += size

y, i = torch.sort(bn)

threshold_index = int(total * PRUNE_PERCENT)
threshold = y[threshold_index]


### Create CONFIG, which will be used when building the pruned network later

In [None]:
pruned = 0 
cfg = []  #CONFIG used to build the pruning network
cfg_mask = [] #Mask to help pruning

In [18]:
print(torch.__version__)
print(torch.cuda.is_available())
print(torch.version.cuda)

2.0.1+cu118
True
11.8


### Create CONFIG based on Batch Normalization Layer information
#### 1. Copy the weight (i.e. scale factor) of Batch Normalization Layer
#### 2. Create a mask. The index value greater than the threshold will be set to 1, and the value less than the threshold will be set to 0
#### 3. The sum of the values ​​of the index greater than the threshold will be the output channel corresponding to the Layer after pruning
#### 4. Finally, get the CONFIG to build the pruning model

In [19]:
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.clone()
        mask = weight_copy.abs().gt(threshold).float().cuda()
        print(torch.sum(mask))
        
        if torch.sum(mask)==0:
            preserve_idx=torch.topk(weight_copy, 4)[1].tolist()
            mask[preserve_idx]=1
            
        fc_layers_idx = range(len(cfg))
        for i in fc_layers_idx:
            if cfg[i] == 0:
                cfg[i] = 3
       
        pruned = pruned + mask.shape[0] - torch.sum(mask)
        cfg.append(int(torch.sum(mask)))
        cfg_mask.append(mask.clone())
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
            format(k, mask.shape[0], int(torch.sum(mask))))
         
        
    elif isinstance(m, nn.MaxPool2d):
        cfg.append('M')

pruned_ratio = pruned/total

print(f'PRUNE RATIO={pruned_ratio}')
print('PREPROCESSING SUCCESSFUL!')

print(cfg)
#print(mask)


tensor(18., device='cuda:0')
layer index: 3 	 total channel: 64 	 remaining channel: 18
tensor(57., device='cuda:0')
layer index: 6 	 total channel: 64 	 remaining channel: 57
tensor(74., device='cuda:0')
layer index: 10 	 total channel: 128 	 remaining channel: 74
tensor(119., device='cuda:0')
layer index: 13 	 total channel: 128 	 remaining channel: 119
tensor(125., device='cuda:0')
layer index: 17 	 total channel: 256 	 remaining channel: 125
tensor(85., device='cuda:0')
layer index: 20 	 total channel: 256 	 remaining channel: 85
tensor(36., device='cuda:0')
layer index: 23 	 total channel: 256 	 remaining channel: 36
tensor(2., device='cuda:0')
layer index: 26 	 total channel: 256 	 remaining channel: 2
tensor(0., device='cuda:0')
layer index: 30 	 total channel: 512 	 remaining channel: 4
tensor(0., device='cuda:0')
layer index: 33 	 total channel: 512 	 remaining channel: 4
tensor(0., device='cuda:0')
layer index: 36 	 total channel: 512 	 remaining channel: 4
tensor(0., device=

### Building a pruning model

In [20]:
newmodel = vgg(cfg=cfg)
newmodel.cuda()

vgg(
  (feature): Sequential(
    (0): Conv2d(3, 18, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(18, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(18, 57, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(57, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(57, 74, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (8): BatchNorm2d(74, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(74, 119, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (11): BatchNorm2d(119, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

### Copy the original model weights to the pruned model
#### 1. Determine the input and output channels of this layer
#### 2. Decide what weights to copy based on different layers
#### Batch Normalization Layer
1.   scale factor
2.   bias
3.   running mean
4.   running variance

#### Convolutional Layer
1.   weight
2.   bias

#### Linear Layer
1.   weight
2.   bias



In [None]:
layer_id_in_cfg = 0
start_mask = torch.ones(3) #3 represents input channel(R,G,B)
end_mask = cfg_mask[layer_id_in_cfg]
count = 0
for [m0, m1] in zip(model.modules(), newmodel.modules()):

    if isinstance(m0, nn.BatchNorm2d):
        # Dealing with pruned weights
        m0.weight.data.mul_(end_mask)
        m0.bias.data.mul_(end_mask)

        #### Find the index of the non-zero element in the mask ####
        idx = torch.nonzero(end_mask)

        # Copy the weights of the original model to the weights of the pruned model

        #### Copy weight and bias ####
        m1.weight.data = m0.weight.data[idx].clone() 
        m1.bias.data = m0.bias.data[idx].clone()

        #### Copy running mean and running variance ####
        m1.running_mean = m0.running_mean[idx]
        m1.running_var = m0.running_var[idx]

        layer_id_in_cfg += 1
        start_mask = end_mask.clone()
        
        # The last connection layer is not modified
        if layer_id_in_cfg < len(cfg_mask):
            end_mask = cfg_mask[layer_id_in_cfg]

        #print("start_mask=",start_mask.shape)
        #print("end_mask=",end_mask.shape)
        
    elif isinstance(m0, nn.Conv2d):
        # Copy the convolutional layer weights of the original model to the convolutional layer weights of the corresponding pruned model
        # idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        # idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        idx0 = np.flatnonzero(start_mask.cpu().numpy())
        idx1 = np.flatnonzero(end_mask.cpu().numpy())

        # print("idx0.shape=",idx0.shape)
        # print("idx1.shape=",idx1.shape)

        w = m0.weight.data[:, idx0, :, :].clone()
        w = w[idx1, :, :, :].clone()
        m1.weight.data = w.clone()
        #m1.bias.data = m0.bias.data[idx1].clone()
    elif isinstance(m0, nn.Linear):
        # Refer to https://pytorch.org/docs/stable/generated/torch.nn.Linear.html to decide how to copy Linear Layer parameters.

        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))

        #### Copy weight ####
        m1.weight.data = m0.weight.data[:,idx0].clone()

        #### Copy bias #### 
        m1.bias.data = m0.bias.data.clone()


### Save the model and print the results

In [22]:
torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, PRUNE_PATH)

print(newmodel)
model = newmodel
test(newmodel)

vgg(
  (feature): Sequential(
    (0): Conv2d(3, 18, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(18, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(18, 57, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(57, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(57, 74, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (8): BatchNorm2d(74, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(74, 119, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (11): BatchNorm2d(119, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

tensor(0.1000)