### 1. Quantify the difference in a neural network's internal features after a certain amount of network pruning.
### 2. Understand the subnetwork overlap between different input samples fed to the neural network.

In [3]:
from torchvision import datasets, transforms
from torch import nn
import torch
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import sys
import os
import argparse
import numpy as np

sys.path.insert(0, os.path.abspath('..'))

# import the model of neural network
from python.model import LeNet, LeNet_5

In [4]:
# Training settings.
parser = argparse.ArgumentParser(
    description='PyTorch MNIST pruning from deep compression paper')
# Help strings use argparse's %(default)s placeholder so the documented
# defaults always match the real ones (they used to claim a batch size of 50
# and a learning rate of 0.01).
parser.add_argument('--batch-size', type=int, default=100, metavar='N',
                    help='input batch size for training (default: %(default)s)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: %(default)s)')
parser.add_argument('--epochs', type=int, default=100, metavar='N',
                    help='number of epochs to train (default: %(default)s)')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                    help='learning rate (default: %(default)s)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: %(default)s)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--log', type=str, default='log.txt',
                    help='log file name')
parser.add_argument('--sensitivity', type=float, default=2,
                    help="sensitivity value that is multiplied to layer's std in order to get threshold value")
# parse_known_args returns unrecognised command-line arguments in `unknown`
# instead of exiting on them (presumably needed because the notebook kernel
# is launched with its own arguments).
args, unknown = parser.parse_known_args()

# The device used for training.  It is fixed to the CPU, so --no-cuda
# currently has no effect.  To honour the flag, use:
#   torch.device('cuda' if not args.no_cuda and torch.cuda.is_available() else 'cpu')
device = 'cpu'

In [5]:
# load the training data
mnist = datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor()
                   ]))

'''
    create feature implace dataset
'''
#count = 5500
#test_loader = torch.utils.data.DataLoader(mnist)
#subset_indexes = []
#index = 0
#with torch.no_grad():
#    for data, target in test_loader:
#        device_data, device_target = data.to('cpu'), target.to('cpu')
#        if count > 0 and target == 0:
#            count -= 1
#            index += 1
#            continue
#        else:
#            subset_indexes.append(index)
#            index += 1
        
#mnist = torch.utils.data.Subset(mnist, list(range(3000)))
train_loader = torch.utils.data.DataLoader(mnist, batch_size=args.batch_size, shuffle=True, num_workers=0, pin_memory=True)
# load the testing data
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False,
                   transform=transforms.Compose([
                       transforms.ToTensor()
                   ])),
    batch_size=1, shuffle=False)

In [6]:
# Directory that per-epoch checkpoints are written to; created up front if it
# does not exist yet.  NOTE: the "letnet" spelling is part of an on-disk path
# that saved checkpoints already live under, so it is kept as-is.
CHECKPOINT_DIR = '../data/model/letnet_bias'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)


def save(model, name):
    """Serialize ``model.state_dict()`` to ``CHECKPOINT_DIR/model_<name>.pkl``."""
    checkpoint_file = f'model_{name}.pkl'
    destination = os.path.join(CHECKPOINT_DIR, checkpoint_file)
    torch.save(model.state_dict(), destination)
    
    
def train(epochs, model, device, optimizer):
    """Train ``model`` on ``train_loader`` for ``epochs`` epochs with NLL loss.

    After every epoch, the weights are checkpointed with
    ``save(model, str(epoch))`` and evaluated with ``test(model, device)``.

    Args:
        epochs: number of passes over the training set.
        model: network being trained.  Its output is fed to ``F.nll_loss``,
            so it is presumably expected to produce log-probabilities.
        device: device each batch is moved to before the forward pass.
        optimizer: optimizer that updates ``model``'s parameters.
    """
    for epoch in range(epochs):
        # Re-enable training mode at the start of every epoch: test() below
        # calls model.eval().  Setting it only once before the loop left
        # every epoch after the first running in eval mode.
        model.train()
        for data, target in tqdm(train_loader, total=len(train_loader)):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()

        save(model, str(epoch))
        test(model, device)


def test(model, device):
    """Evaluate ``model`` on ``test_loader`` and report loss and accuracy.

    Switches the model to eval mode and leaves it there.  Prints the mean
    per-sample NLL loss and the accuracy as a fraction in [0, 1], and
    returns the accuracy as a percentage.
    """
    model.eval()
    n_samples = len(test_loader.dataset)
    loss_sum = 0
    n_correct = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            scores = model(inputs)
            # Accumulate the summed (not averaged) batch loss.
            loss_sum += F.nll_loss(scores, labels, reduction='sum').item()
            # Predicted class = index of the highest score along dim 1.
            predicted = scores.max(1, keepdim=True)[1]
            n_correct += predicted.eq(labels.view_as(predicted)).sum().item()

        mean_loss = loss_sum / n_samples
        print('Test set: Average loss:', mean_loss,
              'Accuracy', n_correct / n_samples)
        return 100. * n_correct / n_samples

In [73]:
class Inner_State_Difference_Quantification:
    """Compare inner-state representations of a model and its pruned version.

    Work in progress: inner states can be retrieved for either model, and
    ``saliencyMap`` is still a stub.
    """

    def __init__(self, model, pruned_model, samples):
        """Store the original model, the pruned model and the input samples.

        Args:
            model: unpruned network; must provide ``getInnerState(samples)``.
            pruned_model: pruned network; must provide ``getInnerState(samples)``
                for ``getInnerStateRepresentation(pruned=True)`` to work.
            samples: inputs passed unchanged to ``getInnerState``.
        """
        self.model = model
        # Store the pruned model that was actually passed in (this previously
        # stored `model` again, making the pruned model an alias of the
        # original one).
        self.pruned_model = pruned_model
        self.samples = samples

    def getInnerStateRepresentation(self, pruned=False):
        """Return ``getInnerState(self.samples)`` of one of the two models.

        Args:
            pruned: if True, query the pruned model; otherwise (the default,
                matching the previous behaviour) query the original model.
        """
        source = self.pruned_model if pruned else self.model
        return source.getInnerState(self.samples)

    def saliencyMap(self):
        """Placeholder for a saliency-map comparison; currently returns None."""
        # TODO: implement.
        pass

    

In [76]:
# Build the network under study.  `mask=True` is forwarded to the project's
# LeNet constructor (presumably enabling pruning masks; see python/model.py).
model = LeNet(mask=True)
model = model.to(device)
# Alternative architecture:
#model = LeNet_5(mask=True).to(device)

# Plain SGD with a small L2 penalty; the learning rate comes from --lr.
optimizer = optim.SGD(params=model.parameters(), lr=args.lr, weight_decay=1e-4)

# Optionally start from a saved checkpoint instead of training from scratch:
#model.load_state_dict(torch.load('../data/model/letnet_bias/model_9.pkl'))
#test(model, device)
#model.load_state_dict(torch.load('../data/model/LetNet/letnet_5_trained.pkl'))

# Prune, then retrain and evaluate:
#model.prune_by_percentile(97)
#train(10, model, device, optimizer)
#test(model, device)

  0%|                                                                                                                          | 0/600 [00:00<?, ?it/s]

Test set: Average loss: 2.3059203907251358 Accuracy 0.1049


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 600/600 [00:11<00:00, 53.43it/s]
  0%|                                                                                                                          | 0/600 [00:00<?, ?it/s]

Test set: Average loss: 0.2713622331491403 Accuracy 0.921


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 600/600 [00:11<00:00, 53.96it/s]
  0%|                                                                                                                          | 0/600 [00:00<?, ?it/s]

Test set: Average loss: 0.1879172479834643 Accuracy 0.9427


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 600/600 [00:09<00:00, 66.22it/s]
  0%|                                                                                                                          | 0/600 [00:00<?, ?it/s]

Test set: Average loss: 0.1421542904417028 Accuracy 0.957


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 600/600 [00:09<00:00, 61.30it/s]
  0%|                                                                                                                          | 0/600 [00:00<?, ?it/s]

Test set: Average loss: 0.11957704122020454 Accuracy 0.9636


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 600/600 [00:31<00:00, 18.88it/s]
  0%|                                                                                                                          | 0/600 [00:00<?, ?it/s]

Test set: Average loss: 0.1007671627737071 Accuracy 0.9694


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 600/600 [00:08<00:00, 66.73it/s]
  0%|                                                                                                                          | 0/600 [00:00<?, ?it/s]

Test set: Average loss: 0.09576021982306922 Accuracy 0.9696


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 600/600 [00:08<00:00, 70.17it/s]
  0%|                                                                                                                          | 0/600 [00:00<?, ?it/s]

Test set: Average loss: 0.08107472876005438 Accuracy 0.975


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 600/600 [00:08<00:00, 69.62it/s]
  0%|                                                                                                                          | 0/600 [00:00<?, ?it/s]

Test set: Average loss: 0.07733136297626533 Accuracy 0.9747


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 600/600 [00:09<00:00, 63.21it/s]
  0%|                                                                                                                          | 0/600 [00:00<?, ?it/s]

Test set: Average loss: 0.07586093284512406 Accuracy 0.9765


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 600/600 [00:08<00:00, 70.10it/s]


Test set: Average loss: 0.07666409314486038 Accuracy 0.9763
Test set: Average loss: 0.07666409314486038 Accuracy 0.9763


97.63

# Subnetwork overlap

In [34]:

#path = os.path.join('../data/model/LetNet', 'model_{}.pkl'.format('odd'))
#torch.save(model.state_dict(), path)

# Group the test images by class label: dataset[label] is a list of samples,
# each converted to nested Python lists with Tensor.tolist().
dataset = {}
# How many label-0 samples to drop before the first one is kept (0 = keep all).
count = 0

for data, target in test_loader:
    device_data, device_target = data.to('cpu'), target.to('cpu')
    # .item() assumes exactly one target per batch (test_loader uses batch_size=1).
    label = device_target.item()
    if label not in dataset and label == 0 and count > 0:
        count -= 1
        continue
    dataset.setdefault(label, []).append(device_data.tolist())

keys = sorted(dataset)

In [35]:
# Per-label results of the project's LeNet.activationPattern for the two
# fully connected layers (the value format is defined in python/model.py).
fc1, fc2 = {}, {}
# Per-label sets of strongly active units, filled in by the next cell.
layer1, layer2 = {}, {}

for k in keys:
    rs = model.activationPattern(dataset[k])
    for layer_name, per_label in (('fc1', fc1), ('fc2', fc2)):
        per_label[k] = rs[layer_name]

In [56]:
# A unit belongs to a label's subnetwork when its activationPattern value
# exceeds this threshold (presumably an activation count over the label's
# samples; confirm against LeNet.activationPattern).  It was previously
# duplicated as a bare literal for each layer.
ACTIVATION_THRESHOLD = 250


def _active_units(pattern, threshold=ACTIVATION_THRESHOLD):
    """Return the set of first-axis indices where ``pattern`` exceeds ``threshold``."""
    return set(np.where(np.asarray(pattern) > threshold)[0].tolist())


for k in keys:
    layer1[k] = _active_units(fc1[k])
    layer2[k] = _active_units(fc2[k])
    


In [57]:
# Reference label whose active-unit sets are compared with those of every label.
index = 1


def _overlap_ratio(reference, other):
    """Fraction of ``reference`` that is also in ``other``; nan if ``reference`` is empty."""
    if not reference:
        return float('nan')
    return len(reference & other) / len(reference)


print(len(layer1[index]), len(layer2[index]))
# Iterate over the labels actually present rather than assuming 0..9.
for i in keys:
    print('layer1', _overlap_ratio(layer1[index], layer1[i]))
    print('layer2', _overlap_ratio(layer2[index], layer2[i]))
# Units shared by all labels / used by any label:
#print(set.intersection(*(layer1[k] for k in keys)))
#print(set.union(*(layer1[k] for k in keys)))

79 54
layer1 0.5443037974683544
layer2 0.6851851851851852
layer1 1.0
layer2 1.0
layer1 0.6329113924050633
layer2 0.7592592592592593
layer1 0.5569620253164557
layer2 0.6111111111111112
layer1 0.6582278481012658
layer2 0.7777777777777778
layer1 0.6455696202531646
layer2 0.6666666666666666
layer1 0.6329113924050633
layer2 0.7222222222222222
layer1 0.5569620253164557
layer2 0.7037037037037037
layer1 0.6708860759493671
layer2 0.7592592592592593
layer1 0.5949367088607594
layer2 0.6666666666666666


In [50]:
#print(layer2[0] & layer2[1] &  layer2[2] & layer2[3] & layer2[4] & layer2[5] & layer2[6] & layer2[7] & layer2[8] & layer2[9]) 
#print(layer2[0] | layer2[1] |  layer2[2] | layer2[3] | layer2[4] | layer2[5] | layer2[6] | layer2[7] | layer2[8] | layer2[9]) 

In [39]:
# Number of test samples recorded for label 0 (shown as the cell's output).
len(dataset[0])

980