In [1]:
import os
import numpy as np
import scipy.stats as sps
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import prune
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from tqdm import tqdm
import pickle
from utils import get_network
from conf import settings
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
import quant_utils
import copy
import random
# One seed shared by every random number generator the notebook touches.
seed = 0

# `device` names the CPU target; `gpu` names the CUDA device the models are run on.
device = 'cpu'
gpu = 'cuda:1'

# Seed Python, NumPy and PyTorch (CPU and CUDA) and ask cuDNN for deterministic kernels.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True



In [2]:
# First entry of the CIFAR-10 sub-train mean/std settings from `conf.settings`;
# these feed the `transforms.Normalize` layer that is prepended to the models.
subset1_mean, subset1_std = settings.CIFAR10_SUBTRAIN_MEAN[0], settings.CIFAR10_SUBTRAIN_STD[0]

def load_model(path, norm=False, dev='cpu'):
    """Build a 10-class ResNet-18, load its weights from ``path`` and put it in eval mode.

    Args:
        path: Location of the saved ``state_dict``.
        norm: When true, the returned module first normalises its input with
            ``subset1_mean`` / ``subset1_std`` and then applies the network.
        dev: Device the network is created on and the checkpoint is mapped to.

    Returns:
        The network itself, or an ``nn.Sequential(Normalize, network)`` when
        ``norm`` is true.
    """
    backbone = get_network('resnet18', False, num_classes=10).to(dev)
    state = torch.load(path, map_location=dev)
    backbone.load_state_dict(state)
    backbone.eval()
    if not norm:
        return backbone
    normalize = transforms.Normalize(subset1_mean, subset1_std)
    return nn.Sequential(normalize, backbone)

In [3]:
def global_prune(net, p=0.2):
    """Return a pruned deep copy of ``net``; ``net`` itself is left untouched.

    The ``weight`` tensors of every ``Conv2d`` and ``Linear`` module in the copy
    are pruned together with global L1 unstructured pruning.

    Args:
        net: Module to copy and prune.
        p: Forwarded as ``amount`` to ``prune.global_unstructured``.

    Returns:
        The pruned copy.
    """
    pruned_net = copy.deepcopy(net)
    prunable_types = (torch.nn.Conv2d, torch.nn.Linear)
    targets = tuple(
        (layer, 'weight')
        for _, layer in pruned_net.named_modules()
        if isinstance(layer, prunable_types)
    )
    prune.global_unstructured(
        targets,
        pruning_method=prune.L1Unstructured,
        amount=p,
    )
    return pruned_net

# Paths

In [4]:
# Checkpoint locations: models derived from the original live directly under
# `related_folder_path`, the unrelated ones under its `independent` subfolder.
root_path = '/data1/checkpoint/'
related_folder_path = os.path.join(root_path, 'hash/cifar10/')
unrelated_folder_path = os.path.join(root_path, 'hash/cifar10/independent')

original_path = os.path.join(related_folder_path, 'resnet18_0.pth')
quant_path = os.path.join(related_folder_path, 'resnet18_0_quant.pth')

# For every run id, list the checkpoint paths of the plain and the "adv" fine-tuning
# folder: ids 0-4 get checkpoints 1..10, ids 5-9 get checkpoints 1..20.
finetune_path_dict = {}
for fid in range(10):
    last_checkpoint = 10 if fid < 5 else 20
    for prefix in ('finetune', 'advfinetune'):
        folder_name = '{}_{}'.format(prefix, fid)
        finetune_path_dict[folder_name] = [
            os.path.join(related_folder_path, '{}/finetune_{}.pth'.format(folder_name, i))
            for i in range(1, last_checkpoint + 1)
        ]

unrelated_path_list = [
    os.path.join(unrelated_folder_path, 'model_{}.pth'.format(idx))
    for idx in range(200)
]

In [5]:
# Probe configuration: each model is queried with `n` random 3x32x32 inputs.
norm = True                    # prepend the Normalize layer when loading models
n = 1000                       # number of random probe inputs per model
input_shape = (n, 3, 32, 32)
output_dict = {}               # collects the softmax outputs of every model
randf = torch.rand             # sampler used to draw the probe inputs

In [6]:
def _softmax_on_random_batch(model, dev):
    """Run ``model`` on a fresh ``randf`` batch of shape ``input_shape`` created on
    ``dev`` and return the softmax over dim 1, moved to the CPU."""
    return model(randf(input_shape, device=dev)).softmax(dim=1).to('cpu')


with torch.no_grad():
    # original net
    original_net = load_model(original_path, norm=norm, dev=gpu)
    output_dict['train'] = _softmax_on_random_batch(original_net, gpu)

    # quant net (TorchScript model, loaded and evaluated on the CPU)
    net = quant_utils.load_torchscript_model(quant_path, 'cpu')
    if norm:
        net = nn.Sequential(transforms.Normalize(subset1_mean, subset1_std), net)
    output_dict['quant'] = net(randf(input_shape)).softmax(dim=1)

    # pruning: each ratio prunes a fresh copy of the original net
    output_dict['prune'] = {}
    for prune_p in [0.8, 0.6, 0.4, 0.2]:
        net = global_prune(original_net, prune_p)
        output_dict['prune'][prune_p] = _softmax_on_random_batch(net, gpu)

    # finetuned models: one list of outputs per folder, in checkpoint order
    output_dict['finetune'] = {}
    for name, pathlist in tqdm(finetune_path_dict.items()):
        output_dict['finetune'][name] = []
        for path in pathlist:
            net = load_model(path, norm=norm, dev=gpu)
            output_dict['finetune'][name].append(_softmax_on_random_batch(net, gpu))

    # independent (unrelated) models
    output_dict['indep'] = []
    # Iterate the list directly: wrapping `enumerate` in tqdm hid the total from the
    # progress bar, and the index was never used.
    for path in tqdm(unrelated_path_list):
        net = load_model(path, norm=norm, dev=gpu)
        output_dict['indep'].append(_softmax_on_random_batch(net, gpu))

100%|██████████| 20/20 [00:39<00:00,  1.98s/it]
200it [01:05,  3.05it/s]


# Train autoencoder

In [7]:

# Fully connected autoencoder over the 10-dimensional softmax vectors.
autoencoder = nn.Sequential(
    nn.Linear(10, 64),
    nn.ReLU(),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
).to(gpu)

# The targets are dummy zeros; only their batch length is used in the loop below.
ae_dataset = torch.utils.data.TensorDataset(
    output_dict['train'],
    torch.zeros(len(output_dict['train'])).long(),
)
ae_train_dataloader = torch.utils.data.DataLoader(
    ae_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=8,
)
n_epoch = 50
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=2e-3)
criterion = nn.MSELoss()

with tqdm(total=n_epoch, desc="train") as pbar:
    for epoch in range(n_epoch):
        epoch_loss = 0
        for probs, dummy in ae_train_dataloader:
            probs = probs.to(gpu)
            optimizer.zero_grad()
            # The reconstruction is pushed through a softmax before being compared
            # with the input probabilities.
            recon = autoencoder(probs).softmax(dim=1)
            loss = criterion(probs, recon)
            loss.backward()
            optimizer.step()
            # Weight the batch mean by the batch size so the epoch figure is a
            # per-sample average.
            epoch_loss += loss.item() * len(dummy)
        mean_loss = epoch_loss / len(ae_dataset)
        pbar.set_postfix({'loss' : '{0:1.10e}'.format(mean_loss)})
        pbar.update(1)

train: 100%|██████████| 50/50 [00:33<00:00,  1.51it/s, loss=2.9479244185e-06]


In [8]:
# Move the trained autoencoder to the CPU, where the stored model outputs live.
autoencoder.cpu()

Sequential(
  (0): Linear(in_features=10, out_features=64, bias=True)
  (1): ReLU()
  (2): Linear(in_features=64, out_features=64, bias=True)
  (3): ReLU()
  (4): Linear(in_features=64, out_features=64, bias=True)
  (5): ReLU()
  (6): Linear(in_features=64, out_features=10, bias=True)
)

In [9]:
# Reference sample: per-input squared reconstruction error of the autoencoder on a
# fresh random batch pushed through the original network.
vic_errors = []
with torch.no_grad():
    for i in range(1):
        randout = original_net(randf(input_shape, device=gpu)).softmax(dim=1).cpu()
        out = autoencoder(randout).softmax(dim=1)
        residual = randout - out
        vic_errors = (residual ** 2).sum(dim=1).numpy()
        mean_err, std_err = np.mean(vic_errors), np.std(vic_errors)
        print('{}, {:.2e}, {:.2e}'.format(i, mean_err, std_err))

0, 8.08e-05, 1.02e-03


In [10]:
# Reconstruction errors on the training outputs, compared with the reference sample
# `vic_errors` by a two-sample Kolmogorov-Smirnov test.
with torch.no_grad():
    train_probs = output_dict['train']
    out = autoencoder(train_probs).softmax(dim=1)
    errors = ((train_probs - out) ** 2).sum(dim=1).numpy()
    ks_result = sps.ks_2samp(errors, vic_errors)
    print(ks_result)
    print("{:.2e}, {:.2e}".format(np.mean(errors), np.std(errors)))

KstestResult(statistic=0.037, pvalue=0.5005673707894058)
2.56e-05, 6.29e-05


In [11]:
# Same check for the quantised model's outputs.
with torch.no_grad():
    quant_probs = output_dict['quant']
    out = autoencoder(quant_probs).softmax(dim=1)
    errors = ((quant_probs - out) ** 2).sum(dim=1).numpy()
    ks_result = sps.ks_2samp(errors, vic_errors)
    print(ks_result)
    print("{:.2e}, {:.2e}".format(np.mean(errors), np.std(errors)))

KstestResult(statistic=0.035, pvalue=0.5728904395829821)
3.71e-04, 7.13e-03


In [12]:
# For each pruning ratio print: mean error, std of error, KS statistic, KS p-value.
with torch.no_grad():
    for ratio, probs in output_dict['prune'].items():
        print(ratio)
        out = autoencoder(probs).softmax(dim=1)
        errors = ((probs - out) ** 2).sum(dim=1).numpy()
        ks_stat, p_value = sps.ks_2samp(errors, vic_errors)
        summary = "{:.2e}, {:.2e}, {:.2e}, {:.2e}".format(
            np.mean(errors), np.std(errors), ks_stat, p_value)
        print(summary)
        

0.8
1.48e-01, 3.71e-02, 1.00e+00, 0.00e+00
0.6
1.10e-03, 5.43e-03, 4.83e-01, 4.81e-106
0.4
1.14e-04, 1.16e-03, 4.20e-02, 3.41e-01
0.2
2.88e-04, 6.15e-03, 3.10e-02, 7.23e-01


In [13]:
# Containers for the per-model reconstruction errors:
# folder name -> list of error arrays, and a flat list for the independent models.
finetune_errors, indep_errors = {}, []

In [14]:
# Per fine-tuning folder: store each checkpoint's error array and print
# index, mean error, std of error, KS statistic and KS p-value against `vic_errors`.
with torch.no_grad():
    for k, vlist in output_dict['finetune'].items():
        print(k)
        folder_errors = []
        finetune_errors[k] = folder_errors
        for i, v in enumerate(vlist):
            out = autoencoder(v).softmax(dim=1)
            errors = ((v - out) ** 2).sum(dim=1).numpy()
            ks_stat, p_value = sps.ks_2samp(errors, vic_errors)
            folder_errors.append(errors)
            summary = "{:.2e}, {:.2e}, {:.2e}, {:.2e}".format(
                np.mean(errors), np.std(errors), ks_stat, p_value)
            print(i, summary)

finetune_0
0 2.44e-05, 1.10e-04, 4.50e-02, 2.63e-01
1 5.27e-05, 8.20e-04, 3.90e-02, 4.33e-01
2 2.85e-05, 1.52e-04, 4.30e-02, 3.14e-01
3 2.12e-05, 4.09e-05, 5.50e-02, 9.71e-02
4 2.13e-05, 3.31e-05, 6.10e-02, 4.84e-02
5 1.34e-04, 2.55e-03, 5.90e-02, 6.15e-02
6 2.18e-05, 5.44e-05, 4.80e-02, 2.00e-01
7 2.59e-05, 9.41e-05, 1.22e-01, 6.67e-07
8 2.52e-05, 8.17e-05, 1.11e-01, 8.74e-06
9 8.58e-05, 1.88e-03, 9.40e-02, 2.88e-04
advfinetune_0
0 2.29e-05, 5.14e-05, 3.70e-02, 5.01e-01
1 3.27e-05, 2.29e-04, 7.10e-02, 1.29e-02
2 2.93e-05, 1.14e-04, 8.30e-02, 2.03e-03
3 4.46e-05, 4.14e-04, 1.36e-01, 1.77e-08
4 3.68e-05, 1.82e-04, 9.90e-02, 1.10e-04
5 2.61e-05, 9.72e-05, 1.20e-01, 1.08e-06
6 2.41e-04, 6.39e-03, 1.28e-01, 1.48e-07
7 1.38e-04, 2.34e-03, 1.73e-01, 1.75e-13
8 9.24e-05, 1.34e-03, 1.25e-01, 3.17e-07
9 7.42e-05, 6.77e-04, 1.09e-01, 1.36e-05
finetune_1
0 2.62e-05, 7.07e-05, 4.30e-02, 3.14e-01
1 4.93e-05, 9.10e-04, 4.90e-02, 1.81e-01
2 5.40e-05, 5.45e-04, 9.50e-02, 2.39e-04
3 3.04e-05, 2.36e-04,

In [15]:
# Start from an empty list so that re-running this cell does not append a second copy
# of every model's errors, which would inflate the counts computed from `indep_errors`.
indep_errors = []
with torch.no_grad():
    for i, v in enumerate(output_dict['indep']):
        out = autoencoder(v).softmax(dim=1)
        # Per-input squared reconstruction error.
        errors = torch.sum((v - out)**2, dim=1).numpy()
        indep_errors.append(errors)
        # Two-sample KS test against the reference errors of the original network.
        stats, pv = sps.ks_2samp(errors, vic_errors)
        print(i, "{:.2e}, {:.2e}, {:.2e}, {:.2e}".format(
            np.mean(errors), np.std(errors), stats, pv))

0 6.15e-02, 1.54e-01, 8.59e-01, 0.00e+00
1 1.87e-01, 2.77e-01, 9.68e-01, 0.00e+00
2 2.88e-04, 2.26e-03, 4.76e-01, 7.17e-103
3 1.69e-01, 9.89e-02, 9.92e-01, 0.00e+00
4 2.11e-01, 1.12e-01, 9.84e-01, 0.00e+00
5 3.27e-05, 5.45e-05, 4.55e-01, 1.12e-93
6 2.43e-01, 9.96e-02, 9.79e-01, 0.00e+00
7 1.90e-01, 1.36e-01, 9.84e-01, 0.00e+00
8 1.56e-01, 1.10e-01, 9.62e-01, 0.00e+00
9 7.01e-02, 1.59e-01, 8.17e-01, 0.00e+00
10 2.84e-01, 1.42e-01, 9.85e-01, 0.00e+00
11 2.88e-01, 9.13e-02, 9.84e-01, 0.00e+00
12 1.16e-01, 1.69e-01, 9.49e-01, 0.00e+00
13 1.41e-02, 4.61e-02, 6.53e-01, 1.33e-201
14 9.01e-02, 1.01e-01, 9.61e-01, 0.00e+00
15 9.86e-02, 1.12e-01, 9.36e-01, 0.00e+00
16 4.89e-03, 3.43e-02, 6.11e-01, 1.91e-174
17 7.47e-02, 9.43e-02, 9.60e-01, 0.00e+00
18 1.60e-03, 7.97e-03, 5.70e-01, 2.71e-150
19 1.51e-01, 1.58e-01, 9.77e-01, 0.00e+00
20 1.80e-01, 1.10e-01, 9.82e-01, 0.00e+00
21 2.71e-01, 6.06e-02, 1.00e+00, 0.00e+00
22 2.92e-01, 1.75e-01, 9.91e-01, 0.00e+00
23 3.27e-05, 1.73e-04, 1.81e-01, 1.00e-1

In [16]:
# Persist the trained autoencoder weights and the collected outputs and errors.
torch.save(autoencoder.state_dict(), '../results/hash/cifar10/autoencoder.pt')
# A context manager closes the pickle file deterministically; the previous bare
# `open(...)` left the handle open.
with open("../results/hash/cifar10/output_dict.pkl", "wb") as pkl_file:
    pickle.dump((output_dict, finetune_errors, indep_errors), pkl_file)

In [17]:
# Mean error of every checkpoint in the folders whose numeric suffix is below 5;
# the largest of these is printed and used as the decision threshold later on.
finetune_mean_errors = [
    np.mean(v)
    for k, vlist in finetune_errors.items()
    if int(k.split('_')[-1]) < 5
    for v in vlist
]
print(np.max(finetune_mean_errors))

0.00025885057


In [18]:
# Mean error of every checkpoint in the remaining folders (numeric suffix above 4).
test_finetune_mean_errors = [
    np.mean(v)
    for k, vlist in finetune_errors.items()
    if int(k.split('_')[-1]) > 4
    for v in vlist
]

In [22]:
# True positive: held-out fine-tuned checkpoints whose mean error is at or below the
# largest mean error in `finetune_mean_errors`.
test_finetune_mean_errors = np.array(test_finetune_mean_errors)
threshold = np.max(finetune_mean_errors)
accepted = test_finetune_mean_errors[test_finetune_mean_errors <= threshold]
print(len(accepted))

194


In [19]:
# Held-out fine-tuned checkpoints whose mean error exceeds the largest mean error in
# `finetune_mean_errors`.
test_finetune_mean_errors = np.array(test_finetune_mean_errors)
threshold = np.max(finetune_mean_errors)
rejected = test_finetune_mean_errors[test_finetune_mean_errors > threshold]
print(len(rejected))

6


In [23]:
# Mean reconstruction error of each independent model.
test_indep_mean_errors = np.array([np.mean(v) for v in indep_errors])

# Decision threshold: the largest mean error in `finetune_mean_errors`.
threshold = np.max(finetune_mean_errors)
print("True N\n", len(test_indep_mean_errors[test_indep_mean_errors > threshold]))
# `<=` rather than `<`: the fine-tuned models are accepted with `<=`, so an independent
# model sitting exactly on the threshold must count as a false positive instead of
# dropping out of both tallies.
print("False P\n", len(test_indep_mean_errors[test_indep_mean_errors <= threshold]))

True N
 190
False P
 10


In [24]:
# Overall accuracy of the threshold rule, computed from the arrays instead of counts
# typed in from earlier outputs, so it stays correct when the notebook is re-run.
threshold = np.max(finetune_mean_errors)
related_mean_errors = np.asarray(test_finetune_mean_errors)
unrelated_mean_errors = np.asarray(test_indep_mean_errors)
n_true_pos = np.count_nonzero(related_mean_errors <= threshold)
n_true_neg = np.count_nonzero(unrelated_mean_errors > threshold)
(n_true_pos + n_true_neg) / (len(related_mean_errors) + len(unrelated_mean_errors))

0.96

In [21]:
# from utils import compute_accuracy, get_test_dataloader_cifar10
# test_loader = get_test_dataloader_cifar10((0, 0, 0), (1, 1, 1))
# acc, _, _ = compute_accuracy(original_net, test_loader, 'cpu')
# accquant, _, _ = compute_accuracy(quant_net, test_loader, 'cpu')
# acc20, _, _ = compute_accuracy(pruned_model_2, test_loader, 'cpu')
# acc40, _, _ = compute_accuracy(pruned_model_4, test_loader, 'cpu')
# acc60, _, _ = compute_accuracy(pruned_model_6, test_loader, 'cpu')
# acc80, _, _ = compute_accuracy(pruned_model_8, test_loader, 'cpu')
# print(acc, accquant, acc20, acc40, acc60, acc80)

# 0.9147 0.9155 0.9151 0.914 0.9106 0.776