### 1. Quantify the differences in various features of a neural network after a certain amount of network pruning.
### 2. Understand the subnetwork overlap between different neural network samples.

In [61]:
from torchvision import datasets, transforms
from torch import nn
import torch
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import sys
import os
import argparse
import numpy as np

sys.path.insert(0, os.path.abspath('..'))

# import the model of neural network
from python.model import LeNet, LeNet_5

In [62]:
# Training settings.
# Every help string states the default that is actually configured, so that
# `--help` does not contradict the values used at runtime.
parser = argparse.ArgumentParser(
    description='PyTorch MNIST pruning from deep compression paper')
parser.add_argument('--batch-size', type=int, default=100, metavar='N',
                    help='input batch size for training (default: 100)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=100, metavar='N',
                    help='number of epochs to train (default: 100)')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                    help='learning rate (default: 0.1)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--log', type=str, default='log.txt',
                    help='log file name')
parser.add_argument('--sensitivity', type=float, default=2,
                    help="sensitivity value that is multiplied to layer's std in order to get threshold value")
# parse_known_args() returns unrecognised command-line entries in `unknown`
# instead of aborting (presumably needed because the notebook kernel passes
# its own arguments -- confirm).
args, unknown = parser.parse_known_args()

# The device used for training. It is pinned to the CPU; --no-cuda is parsed
# but not consulted here.
device = 'cpu'  # torch.device("cuda" if use_cuda else 'cpu')

In [63]:
# Training split of MNIST (downloaded into ../data when missing), served in
# shuffled mini-batches of args.batch_size samples.
train_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        root='../data',
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor()]),
    ),
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)

# Test split of MNIST, served one sample at a time in a fixed order.
test_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        root='../data',
        train=False,
        transform=transforms.Compose([transforms.ToTensor()]),
    ),
    batch_size=1,
    shuffle=False,
)

In [64]:
def train(epochs, model, device, optimizer):
    """Train ``model`` on the module-level ``train_loader``.

    After every epoch the model is evaluated with ``test(model, device)``.

    Args:
        epochs: number of full passes over ``train_loader``.
        model: the network to optimise; its output is fed to ``F.nll_loss``,
            so it is expected to produce log-probabilities.
        device: device the batches are moved to before the forward pass.
        optimizer: optimizer that updates the parameters of ``model``.
    """
    for epoch in range(epochs):
        # test() puts the model into eval mode at the end of each epoch, so
        # training mode has to be re-enabled at the start of every epoch and
        # not just once before the loop.
        model.train()
        pbar = tqdm(enumerate(train_loader), total=len(train_loader))

        for batch_idx, (data, target) in pbar:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()

        # Per-epoch checkpointing is currently disabled:
        # save(model, str(epoch))
        test(model, device)


def test(model, device):
    """Evaluate ``model`` on the module-level ``test_loader``.

    The model is switched to eval mode and left in that mode on return.

    Args:
        model: the network to evaluate; its output is fed to ``F.nll_loss``,
            so it is expected to produce log-probabilities.
        device: device the batches are moved to before the forward pass.

    Returns:
        The accuracy as a percentage in [0, 100]. Note that the printed
        "Accuracy" value is the same quantity expressed as a fraction.
    """
    model.eval()
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # get the index of the max log-probability
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

    sample_count = len(test_loader.dataset)
    test_loss /= sample_count
    accuracy = 100. * correct / sample_count
    print('Test set: Average loss:', test_loss,
          'Accuracy', correct/sample_count)

    return accuracy

In [65]:
class Inner_State_Difference_Quantification:
    """Container for comparing a model with its pruned counterpart.

    Holds the original model, the pruned model and the samples on which the
    two are to be compared. Only ``getInnerStateRepresentation`` is
    implemented; the remaining methods are placeholders that return ``None``.
    """

    def __init__(self, model, pruned_model, samples):
        """Store the two models and the samples.

        Args:
            model: the original (unpruned) model.
            pruned_model: the pruned model to compare against ``model``.
            samples: inputs passed to ``model.getInnerState``.
        """
        self.model = model
        # Keep the pruned model that was passed in (previously this attribute
        # was bound to ``model``, discarding the argument).
        self.pruned_model = pruned_model
        self.samples = samples

    def getInnerStateRepresentation(self):
        """Return ``self.model.getInnerState(self.samples)``."""
        return self.model.getInnerState(self.samples)

    def feature_map(self):
        """Placeholder, not implemented yet."""
        pass

    def pca(self):
        """Placeholder, not implemented yet."""
        pass

    def gradient_of_neuron(self):
        """Placeholder, not implemented yet."""
        pass

    def weight(self):
        """Placeholder, not implemented yet."""
        pass

    def saliencyMap(self):
        """Placeholder, not implemented yet."""
        pass

    def critical_neuron_activation_ranking(self):
        """Placeholder, not implemented yet."""
        pass

    def input_vs_output(self):
        """Placeholder, not implemented yet."""
        pass

    # TODO: decide which further measures belong here ("such as?").
    def others(self):
        """Placeholder, not implemented yet."""
        pass

In [109]:
# Load the neural network model to analyse.
model = LeNet(mask=True).to(device)
#model = LeNet_5(mask=True).to(device)

# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a GPU can still be loaded on this CPU-only setup.
#model.load_state_dict(torch.load('../data/model/LetNet/model_19.pkl', map_location=device))
model.load_state_dict(
    torch.load('../data/model/LetNet/letnet300_trained.pkl',
               map_location=device))

# Prune the network (85 is the percentile handed to the model's own pruning
# routine), then measure the accuracy of the pruned network on the test set.
model.prune_by_percentile(85)
test(model, device)

85
pruning with threshold 0.01793949184939264
Test set: Average loss: 0.07846910527819351 Accuracy 0.9806


98.06

# subnetwork overlap

In [110]:
# Group the test samples by label: dataset[label] is the list of samples
# (converted to nested Python lists) that carry that label.
dataset = {}

for data, target in test_loader:
    # .item() requires a single-element tensor, i.e. one sample per batch.
    label = target.to('cpu').item()
    dataset.setdefault(label, []).append(data.to('cpu').tolist())

# Labels in ascending order.
keys = sorted(dataset)

In [111]:
# Per-label results of model.activationPattern for the 'fc1' and 'fc2'
# entries; layer1/layer2 are created empty here and filled afterwards.
fc1, fc2 = {}, {}
layer1, layer2 = {}, {}
for k in keys:
    # One call per label, fed with every sample of that label.
    rs = model.activationPattern(dataset[k])
    fc1[k], fc2[k] = rs['fc1'], rs['fc2']

In [112]:
# For every label keep the indices (along the first axis) whose value in the
# fc1 / fc2 pattern exceeds 50.
for k in keys:
    fc1_hits = np.where(np.array(fc1[k]) > 50)[0]
    layer1[k] = set(fc1_hits.tolist())
    fc2_hits = np.where(np.array(fc2[k]) > 50)[0]
    layer2[k] = set(fc2_hits.tolist())

In [113]:
# Indices selected for all ten labels (intersection) and for at least one
# label (union) in layer1, folded left to right over the labels 0..9.
shared_l1 = layer1[0]
for digit in range(1, 10):
    shared_l1 = shared_l1 & layer1[digit]
print(shared_l1)

union_l1 = layer1[0]
for digit in range(1, 10):
    union_l1 = union_l1 | layer1[digit]
print(union_l1)

{260, 5, 133, 39, 235, 270, 242, 212, 121, 158, 62}
{256, 257, 1, 260, 5, 261, 12, 13, 270, 15, 271, 288, 289, 36, 39, 295, 298, 46, 62, 79, 83, 84, 93, 94, 95, 97, 104, 107, 121, 133, 138, 143, 149, 152, 153, 156, 158, 160, 164, 172, 178, 184, 191, 212, 230, 235, 242, 246, 247, 248, 249}


In [114]:
# Indices selected for all ten labels (intersection) and for at least one
# label (union) in layer2, folded left to right over the labels 0..9.
shared_l2 = layer2[0]
for digit in range(1, 10):
    shared_l2 = shared_l2 & layer2[digit]
print(shared_l2)

union_l2 = layer2[0]
for digit in range(1, 10):
    union_l2 = union_l2 | layer2[digit]
print(union_l2)

{26, 19, 31, 79}
{1, 2, 3, 4, 8, 10, 11, 19, 20, 21, 22, 23, 26, 27, 28, 29, 31, 33, 34, 35, 36, 42, 45, 47, 53, 57, 60, 61, 62, 63, 65, 66, 72, 76, 77, 79, 80, 87, 88, 89, 90, 91, 93, 94, 95, 98, 99}
