### 1. Quantify the difference of difference feature of neural network after a certain amount of network pruning.
### 2. Understand the subnetwork overlap between different neural network samples.

In [2]:
from torchvision import datasets, transforms
from torch import nn
import torch
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import sys
import os
import argparse
import numpy as np

sys.path.insert(0, os.path.abspath('..'))

# import the model of neural network
from python.model import LeNet, LeNet_5

In [3]:
# Training settings
parser = argparse.ArgumentParser(
    description='PyTorch MNIST pruning from deep compression paper')
parser.add_argument('--batch-size', type=int, default=50, metavar='N',
                    help='input batch size for training (default: 50)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=100, metavar='N',
                    help='number of epochs to train (default: 100)')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--log', type=str, default='log.txt',
                    help='log file name')
parser.add_argument('--sensitivity', type=float, default=2,
                    help="sensitivity value that is multiplied to layer's std in order to get threshold value")
args, unknown = parser.parse_known_args()

#the device used for training
device = 'cpu'#torch.device("cuda" if use_cuda else 'cpu')

In [4]:
# load the training data
mnist = datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor()
                   ]))

'''
    create feature imbalance dataset
'''
count = 0
test_loader = torch.utils.data.DataLoader(mnist)
subset_indexes = []
index = 0
with torch.no_grad():
    for data, target in test_loader:
        device_data, device_target = data.to('cpu'), target.to('cpu')
        if count > 0 and target == 0:
            count -= 1
            index += 1
            continue
        else:
            subset_indexes.append(index)
            index += 1
        
mnist = torch.utils.data.Subset(mnist, subset_indexes)
train_loader = torch.utils.data.DataLoader(mnist, batch_size=args.batch_size, shuffle=True, num_workers=0, pin_memory=True)
# load the testing data
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False,
                   transform=transforms.Compose([
                       transforms.ToTensor()
                   ])),
    batch_size=64, shuffle=False)

In [5]:
CHECKPOINT_DIR = '../data/model/letnet_bias'  # model checkpoints
# make checkpoint path directory
os.makedirs(CHECKPOINT_DIR, exist_ok=True)


def save(model, name):
    path = os.path.join(
        CHECKPOINT_DIR, 'model_{}.pkl'.format(name))
    torch.save(model.state_dict(), path)
    
    
def train(epochs, model, device, optimizer):
    model.train()
    for epoch in range(epochs):
        pbar = tqdm(enumerate(train_loader), total=len(train_loader))

        for batch_idx, (data, target) in pbar:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()

        save(model, str(epoch))
        test(model, device)


def test(model, device):
    model.eval()
    test_loss = 0
    correct = 0
    flag = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # get the index of the max log-probability
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum().item()

        test_loss /= len(test_loader.dataset)
        accuracy = 100. * correct / len(test_loader.dataset)
        print('Test set: Average loss:', test_loss,
              'Accuracy', correct/len(test_loader.dataset))

        return accuracy

In [6]:
class Inner_State_Difference_Quantification:
    def __init__(self, model, pruned_model, samples):
        self.model = model
        self.pruned_model = model
        self.samples = samples
        
    def getInnerStateRepresentation(self):
        return self.model.getInnerState(self.samples)
    
    def saliencyMap(self):
        pass

In [7]:
#load different neural network model
model = LeNet(mask=True).to(device)
#model = LeNet_5(mask=True).to(device)
optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=0.0001)
#model.load_state_dict(torch.load('../data/model/letnet_bias/model_9.pkl'))
#test(model, device)
#model.load_state_dict(torch.load('../data/model/LetNet/letnet_5_trained.pkl'))

#model.prune_by_percentile(97)
train(10, model, device, optimizer)
#test(model, device)

100%|██████████████████████████████████████████████████████████████████████████| 1200/1200 [00:11<00:00, 102.64it/s]
  1%|▉                                                                           | 14/1200 [00:00<00:08, 139.99it/s]

Test set: Average loss: 0.19542308067083358 Accuracy 0.9374


100%|██████████████████████████████████████████████████████████████████████████| 1200/1200 [00:10<00:00, 116.14it/s]
  0%|                                                                                      | 0/1200 [00:00<?, ?it/s]

Test set: Average loss: 0.12133916338831187 Accuracy 0.9623


100%|██████████████████████████████████████████████████████████████████████████| 1200/1200 [00:10<00:00, 113.04it/s]
  0%|                                                                                      | 0/1200 [00:00<?, ?it/s]

Test set: Average loss: 0.09188457947820425 Accuracy 0.9727


100%|██████████████████████████████████████████████████████████████████████████| 1200/1200 [00:10<00:00, 112.93it/s]
  0%|                                                                                      | 0/1200 [00:00<?, ?it/s]

Test set: Average loss: 0.08666917420215904 Accuracy 0.9729


100%|██████████████████████████████████████████████████████████████████████████| 1200/1200 [00:10<00:00, 114.27it/s]
  0%|                                                                                      | 0/1200 [00:00<?, ?it/s]

Test set: Average loss: 0.0839968749506399 Accuracy 0.9734


100%|██████████████████████████████████████████████████████████████████████████| 1200/1200 [00:10<00:00, 115.43it/s]
  0%|                                                                                      | 0/1200 [00:00<?, ?it/s]

Test set: Average loss: 0.07801514909081161 Accuracy 0.9752


100%|███████████████████████████████████████████████████████████████████████████| 1200/1200 [00:13<00:00, 88.48it/s]
  0%|                                                                                      | 0/1200 [00:00<?, ?it/s]

Test set: Average loss: 0.06526057104072533 Accuracy 0.9785


100%|██████████████████████████████████████████████████████████████████████████| 1200/1200 [00:11<00:00, 106.24it/s]
  0%|                                                                                      | 0/1200 [00:00<?, ?it/s]

Test set: Average loss: 0.06807703620311804 Accuracy 0.9789


100%|██████████████████████████████████████████████████████████████████████████| 1200/1200 [00:11<00:00, 104.82it/s]
  0%|                                                                                      | 0/1200 [00:00<?, ?it/s]

Test set: Average loss: 0.0664645975478692 Accuracy 0.9796


100%|██████████████████████████████████████████████████████████████████████████| 1200/1200 [00:10<00:00, 112.55it/s]


Test set: Average loss: 0.06861220777491107 Accuracy 0.9794


# subnetwork overlap

In [13]:
#path = os.path.join('../data/model/LetNet', 'model_{}.pkl'.format('odd'))
#torch.save(model.state_dict(), path)
dataset = {}
count = 0

for data, target in test_loader:
    device_data, device_target = data.to('cpu'), target.to('cpu')
    #label = device_target.item()
    
    print(device_data, device_target)
    
    #if label in dataset:
    #    dataset[label].append(device_data.tolist())
    #else:
    #    if label == 0 and count > 0:
    #        count-=1
    #        continue
    #    dataset[label] = [device_data.tolist()]
    
keys = list(dataset.keys())
keys.sort()  





tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        ...,


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0.

          [0., 0., 0.,  ..., 0., 0., 0.]]]]) tensor([6, 2, 7, 7, 2, 2, 1, 1, 2, 8, 3, 7, 2, 4, 1, 7, 1, 7, 6, 7, 8, 2, 7, 3,
        1, 7, 5, 8, 2, 6, 2, 2, 5, 6, 5, 0, 9, 2, 4, 3, 3, 9, 7, 6, 6, 8, 0, 4,
        1, 5, 8, 2, 9, 1, 8, 0, 6, 7, 2, 1, 0, 5, 5, 2])
tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0.

          [0., 0., 0.,  ..., 0., 0., 0.]]]]) tensor([1, 5, 3, 8, 9, 1, 9, 7, 9, 5, 5, 2, 7, 4, 6, 0, 1, 1, 1, 0, 4, 4, 7, 6,
        3, 0, 0, 4, 3, 0, 6, 1, 9, 6, 1, 3, 8, 1, 2, 5, 6, 2, 7, 3, 6, 0, 1, 9,
        7, 6, 6, 8, 9, 2, 9, 5, 8, 3, 1, 0, 0, 7, 6, 6])
tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0.

          [0., 0., 0.,  ..., 0., 0., 0.]]]]) tensor([8, 3, 2, 7, 2, 9, 7, 2, 1, 1, 3, 7, 5, 3, 1, 9, 8, 2, 2, 2, 8, 8, 5, 7,
        3, 8, 9, 8, 8, 6, 8, 2, 3, 9, 7, 5, 6, 2, 9, 2, 8, 8, 1, 6, 8, 8, 7, 9,
        1, 8, 0, 1, 7, 2, 0, 7, 5, 1, 9, 0, 2, 0, 9, 8])
tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0.

          [0., 0., 0.,  ..., 0., 0., 0.]]]]) tensor([3, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 0, 1, 2, 8, 9, 1, 4, 0, 9,
        5, 0, 8, 0, 7, 7, 1, 1, 2, 9, 3, 6, 7, 2, 3, 8, 1, 2, 9, 8, 8, 7, 1, 7,
        1, 1, 0, 3, 4, 2, 6, 4, 7, 4, 2, 7, 4, 9, 1, 0])
tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0.

          [0., 0., 0.,  ..., 0., 0., 0.]]]]) tensor([9, 2, 0, 9, 5, 1, 3, 7, 6, 9, 3, 0, 2, 2, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 2, 1, 7, 2,
        5, 0, 8, 0, 2, 7, 8, 8, 3, 0, 6, 0, 2, 7, 6, 6])
tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0.

          [0., 0., 0.,  ..., 0., 0., 0.]]]]) tensor([7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 9, 0, 1, 3,
        1, 5, 1, 2, 4, 9, 2, 4, 6, 8, 0, 1, 1, 9, 2, 6, 6, 8, 7, 4, 2, 9, 7, 0,
        2, 1, 0, 3, 6, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0])
tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0.

          [0., 0., 0.,  ..., 0., 0., 0.]]]]) tensor([9, 3, 8, 4, 4, 7, 0, 1, 9, 2, 8, 7, 8, 2, 5, 9, 6, 0, 6, 5, 5, 3, 3, 3,
        9, 8, 1, 1, 0, 6, 1, 0, 0, 6, 2, 1, 1, 3, 2, 7, 7, 8, 8, 7, 8, 4, 6, 0,
        2, 0, 7, 0, 3, 6, 8, 7, 1, 5, 9, 9, 3, 7, 2, 4])
tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0.

In [35]:
fc1 = {}
fc2 = {}
layer1 = {}
layer2 = {}
for k in keys:
    rs = model.activationPattern(dataset[k])
    fc1[k] = rs['fc1']
    fc2[k] = rs['fc2']  

In [56]:
for k in keys:
    layer1[k] = set(np.where(np.array(fc1[k]) > 250)[0].tolist())
    layer2[k] = set(np.where(np.array(fc2[k]) > 250)[0].tolist())
    


In [57]:
index = 1
print(len(layer1[index]), len(layer2[index]))
for i in range(0, 10):
    print('layer1',len(layer1[index] & layer1[i])/len(layer1[index]))
    print('layer2',len(layer2[index] & layer2[i])/len(layer2[index]))
# &  layer1[2] & layer1[3] & layer1[4] & layer1[5] & layer1[6] & layer1[7] & layer1[8] & layer1[9]) 
#print(layer1[0] | layer1[1] |  layer1[2] | layer1[3] | layer1[4] | layer1[5] | layer1[6] | layer1[7] | layer1[8] | layer1[9]) 

79 54
layer1 0.5443037974683544
layer2 0.6851851851851852
layer1 1.0
layer2 1.0
layer1 0.6329113924050633
layer2 0.7592592592592593
layer1 0.5569620253164557
layer2 0.6111111111111112
layer1 0.6582278481012658
layer2 0.7777777777777778
layer1 0.6455696202531646
layer2 0.6666666666666666
layer1 0.6329113924050633
layer2 0.7222222222222222
layer1 0.5569620253164557
layer2 0.7037037037037037
layer1 0.6708860759493671
layer2 0.7592592592592593
layer1 0.5949367088607594
layer2 0.6666666666666666


In [50]:
#print(layer2[0] & layer2[1] &  layer2[2] & layer2[3] & layer2[4] & layer2[5] & layer2[6] & layer2[7] & layer2[8] & layer2[9]) 
#print(layer2[0] | layer2[1] |  layer2[2] | layer2[3] | layer2[4] | layer2[5] | layer2[6] | layer2[7] | layer2[8] | layer2[9]) 

In [39]:
len(dataset[0])

980