# Import libs

In [1]:
import os
import numpy as np
from tqdm import tqdm
from datetime import datetime
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# torch libs
import torch
import torch.nn as nn
import torch.nn.functional as F

from PIL import Image
from torchvision import datasets, transforms
import pickle
# custom libs
import utils, dataset_utils, settings
from IPython.display import clear_output

# Select the compute device. Fall back to the CPU when CUDA is unavailable so
# the notebook still runs on machines without a GPU; checkpoints in this
# notebook are loaded with map_location=device, so either choice works.
device_id = 0
if torch.cuda.is_available():
    torch.cuda.set_device(device_id)
    device = 'cuda:{}'.format(device_id)
else:
    device = 'cpu'

# Directory under the project path that holds the saved model checkpoints.
checkpoint_root = os.path.join(settings.PATH, 'models')

In [2]:
# Checkpoint file (relative to the '<model>_<dataset>' directory) of the
# retrained SDN to evaluate for each dataset.
path_dict = dict(
    cifar10='copy1/copy5/retrain_sdn_5.pt',
    svhn='copy1/copy5/retrain_sdn_5.pt',
    gtsrb='copy1/copy10/retrain_sdn_5.pt',
)

In [3]:
def print_possibilities(th_range, start_range, result_array):
    """Print a tab-separated table of percentages.

    The header row lists every threshold in ``th_range`` (two decimals). Each
    following row is labelled with the corresponding value from
    ``start_range`` and shows ``result_array[row][col] * 100`` with two
    decimals.

    Args:
        th_range: iterable of thresholds, one per table column.
        start_range: iterable of row labels (the start exits), one per row.
        result_array: 2-D indexable (``result_array[row][col]``) of fractions,
            with at least ``len(start_range)`` rows and ``len(th_range)``
            columns.
    """
    # Leading tab leaves the top-left corner (above the row labels) empty.
    header = "".join("{:.2f}\t".format(th) for th in th_range)
    print("\t" + header)

    for row_id, start in enumerate(start_range):
        cells = "".join(
            "{:.2f}\t".format(result_array[row_id][col_id] * 100)
            for col_id, _th in enumerate(th_range)
        )
        # Label the row with the actual start value rather than its position,
        # so the table stays correct when start_range is not 0..n-1.
        print("{}\t{}".format(start, cells))

In [4]:
def get_retrained_model(model_name, dataset_name, device, dataset_obj=None):
    """Build the SDN model for ``model_name`` and load its retrained checkpoint.

    The checkpoint is read from
    ``checkpoint_root/<model_name>_<dataset_name>/<path_dict[dataset_name]>``.

    Args:
        model_name: architecture name; used in the checkpoint directory name
            and passed to ``utils.get_sdn_model`` / ``utils.get_add_output``.
        dataset_name: key into ``path_dict``; also part of the directory name.
        device: device the checkpoint is mapped onto and the model is moved to.
        dataset_obj: object exposing ``num_classes`` and ``img_size``. When
            omitted, the notebook-level ``dataset`` variable is used, which
            keeps existing three-argument calls working.

    Returns:
        The model with the checkpoint weights loaded, in eval mode, on
        ``device``.

    Raises:
        KeyError: if ``dataset_name`` has no entry in ``path_dict``.
        NameError: if ``dataset_obj`` is omitted and no notebook-level
            ``dataset`` has been defined yet.
    """
    if dataset_name not in path_dict:
        raise KeyError(
            "no retrained checkpoint registered for dataset {!r}; known: {}".format(
                dataset_name, sorted(path_dict)))

    if dataset_obj is None:
        # Backward-compatible fallback to the notebook-level variable. Prefer
        # passing dataset_obj explicitly to avoid depending on cell order.
        dataset_obj = dataset

    pretrained_sdn = os.path.join('{}_{}'.format(model_name, dataset_name), path_dict[dataset_name])
    sdn_model_br = utils.get_sdn_model(
        model_name,
        utils.get_add_output(model_name),
        dataset_obj.num_classes,
        dataset_obj.img_size,
    )

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(os.path.join(checkpoint_root, pretrained_sdn), map_location=device)
    sdn_model_br.load_state_dict(state_dict)
    sdn_model_br.eval()
    sdn_model_br.to(device)
    return sdn_model_br

# Load data & model

In [5]:
# Architecture and datasets evaluated by the cells below.
# For a quick single-dataset run, restrict the list, e.g. to ['gtsrb'].
model_name = 'resnet56'
dataset_names = [
    'cifar10',
    'svhn',
    'gtsrb',
]

# Fraction of backdoored test samples classified as the target class, for every combination of threshold (`th`) and starting exit (`start_exit`)

In [9]:
# For each dataset: run the backdoored test set through the retrained SDN once,
# then, for every (threshold, start exit) pair, record the fraction of samples
# whose prediction equals the dataset's target class.
th_range = np.linspace(0.5, 0.95, 10)
for dataset_name in dataset_names:
    dataset = dataset_utils.load_dataset(dataset_name)(
        batch_size=1024, doNormalization=True,
        inj_rate=0.01)
    sdn_model_br = get_retrained_model(model_name, dataset_name, device)

    # Collect the model outputs batch by batch. Each batch's outputs are
    # stacked on a new leading dim; that dim is what num_exits is read from.
    batch_outputs = []
    with torch.no_grad():
        for x, y in dataset.test_backdoor_loader:
            x = x.to(device)
            output = sdn_model_br(x)
            batch_outputs.append(torch.stack(output))
    # Join the batches along dim 1 (presumably the sample dim) and split the
    # result back into one tensor per entry of the leading dim.
    output_list = list(torch.cat(batch_outputs, dim=1))
    num_exits = len(output_list)
    start_range = list(range(num_exits))

    result_array = np.zeros((num_exits, len(th_range)))
    for i, th in enumerate(th_range):
        for j, start in enumerate(start_range):
            b_out_idx, b_pred = utils.test_threshold(output_list, th, start_from_include=start)
            # .item() turns the count into a plain Python int, so the NumPy
            # array never receives a tensor and the division is done in
            # ordinary floating point regardless of the tensor's device.
            target_hits = torch.sum(b_pred == dataset.target_class).item()
            result_array[j, i] = target_hits / len(b_pred)
    print(dataset_name)
    print_possibilities(th_range, start_range, result_array)

CIFAR10::init - doNormalization is True
cifar10
	0.50	0.55	0.60	0.65	0.70	0.75	0.80	0.85	0.90	0.95	
0	25.73	30.28	35.44	41.09	47.70	55.23	63.23	71.66	79.52	85.59	
1	34.19	36.90	40.77	45.09	50.93	57.78	64.99	72.77	80.08	85.71	
2	44.28	48.27	52.56	56.66	61.62	67.37	73.08	78.58	83.80	87.01	
3	57.27	58.39	60.28	62.58	66.02	70.43	74.93	79.72	84.34	87.12	
4	59.73	59.41	60.49	62.61	66.04	70.46	75.00	79.78	84.36	87.13	
5	53.57	55.31	58.09	62.16	66.37	71.10	75.66	80.30	84.54	87.16	
6	48.74	52.53	56.83	61.68	66.37	71.28	75.91	80.50	84.59	87.17	
7	48.52	53.60	58.51	63.31	67.79	72.43	76.99	81.27	85.00	87.32	
8	56.84	61.07	64.57	68.41	71.97	75.63	79.40	83.09	86.09	87.86	
9	77.13	77.31	77.61	78.24	79.82	81.44	83.44	85.48	87.77	88.42	
10	76.18	76.76	77.57	78.51	80.40	82.32	84.27	86.09	88.18	88.66	
11	78.49	78.49	79.08	80.57	82.27	84.38	86.37	87.74	89.26	89.17	
12	72.24	74.77	77.09	79.88	82.24	84.58	86.60	88.14	89.53	89.28	
13	90.57	91.00	91.46	92.02	92.73	93.30	93.51	93.58	92.88	90.61	
14	90.69	91.33

# Test-set accuracy of the retrained SDN for every combination of threshold and starting exit

In [17]:
# For each dataset: run the test set through the retrained SDN once, then, for
# every (threshold, start exit) pair, record the fraction of predictions that
# match the ground-truth labels.
th_range = np.linspace(0.8, 1.0, 10)
for dataset_name in dataset_names:
    dataset_factory = dataset_utils.load_dataset(dataset_name)
    dataset = dataset_factory(batch_size=1024, doNormalization=True, inj_rate=0.01)
    sdn_model_br = get_retrained_model(model_name, dataset_name, device)

    # Gather stacked model outputs and the matching labels batch by batch.
    output_list, trues = [], []
    with torch.no_grad():
        for inputs, labels in dataset.test_loader:
            output = sdn_model_br(inputs.to(device))
            output_list.append(torch.stack(output))
            trues.extend(labels.tolist())

    # Join batches along dim 1, then split into one tensor per leading-dim entry.
    output_list = list(torch.cat(output_list, dim=1))
    trues = torch.tensor(trues).reshape(-1)

    num_exits = len(output_list)
    start_range = list(range(num_exits))
    result_array = np.zeros((num_exits, len(th_range)))

    for col, th in enumerate(th_range):
        for row, start in enumerate(start_range):
            b_out_idx, b_pred = utils.test_threshold(output_list, th, start_from_include=start)
            # Labels live on the CPU, so predictions are moved there to compare.
            n_correct = torch.sum(b_pred.cpu() == trues)
            result_array[row][col] = n_correct / len(b_pred)

    print(dataset_name)
    print_possibilities(th_range, start_range, result_array)

CIFAR10::init - doNormalization is True
cifar10
	0.80	0.82	0.84	0.87	0.89	0.91	0.93	0.96	0.98	1.00	
0	82.16	83.19	84.03	84.47	85.09	85.67	85.99	86.09	86.06	85.85	
1	82.35	83.36	84.15	84.53	85.14	85.71	86.00	86.10	86.06	85.85	
2	82.41	83.41	84.20	84.59	85.16	85.71	86.02	86.10	86.06	85.85	
3	82.55	83.48	84.24	84.61	85.17	85.74	86.02	86.10	86.06	85.85	
4	82.59	83.50	84.25	84.62	85.16	85.73	86.03	86.12	86.07	85.85	
5	82.63	83.53	84.28	84.66	85.18	85.75	86.04	86.12	86.07	85.85	
6	82.72	83.64	84.31	84.66	85.18	85.75	86.04	86.13	86.07	85.85	
7	82.83	83.73	84.38	84.74	85.23	85.77	86.05	86.13	86.07	85.85	
8	82.91	83.83	84.46	84.87	85.32	85.81	86.07	86.13	86.07	85.85	
9	83.25	84.06	84.65	84.96	85.37	85.84	86.08	86.14	86.07	85.85	
10	83.71	84.40	84.85	85.14	85.47	85.85	86.06	86.16	86.08	85.85	
11	84.02	84.62	84.98	85.23	85.54	85.94	86.11	86.19	86.10	85.85	
12	84.22	84.72	85.07	85.38	85.63	85.98	86.11	86.22	86.10	85.85	
13	84.47	84.91	85.18	85.40	85.62	86.02	86.13	86.25	86.07	85.85	
14	84.58	85.01

# Test-set accuracy of the clean CNN

In [12]:
def get_clean_model(model_name, dataset_name, device, dataset_obj=None):
    """Build the CNN for ``model_name`` and load its ``model_cnn.pt`` checkpoint.

    The checkpoint is read from
    ``checkpoint_root/<model_name>_<dataset_name>/model_cnn.pt``.

    Args:
        model_name: architecture name; used in the checkpoint directory name
            and passed to ``utils.get_cnn_model``.
        dataset_name: dataset name; part of the checkpoint directory name.
        device: device the checkpoint is mapped onto and the model is moved to.
        dataset_obj: object exposing ``num_classes`` and ``img_size``. When
            omitted, the notebook-level ``dataset`` variable is used, which
            keeps existing three-argument calls working.

    Returns:
        The model with the checkpoint weights loaded, in eval mode, on
        ``device``.

    Raises:
        NameError: if ``dataset_obj`` is omitted and no notebook-level
            ``dataset`` has been defined yet.
    """
    if dataset_obj is None:
        # Backward-compatible fallback to the notebook-level variable. Prefer
        # passing dataset_obj explicitly to avoid depending on cell order.
        dataset_obj = dataset

    path = os.path.join('{}_{}'.format(model_name, dataset_name), 'model_cnn.pt')
    cnn_model = utils.get_cnn_model(
        model_name,
        dataset_obj.num_classes,
        dataset_obj.img_size,
    )

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(os.path.join(checkpoint_root, path), map_location=device)
    cnn_model.load_state_dict(state_dict)
    cnn_model.eval()
    cnn_model.to(device)
    return cnn_model

In [14]:
# NOTE(review): th_range is reassigned here but never read in this cell; it
# looks like a leftover from the threshold-sweep cells. Confirm before removing.
th_range = np.linspace(0.8, 0.98, 10)

# Print the test-set accuracy (in percent) of the CNN checkpoint per dataset.
for dataset_name in dataset_names:
    dataset_factory = dataset_utils.load_dataset(dataset_name)
    dataset = dataset_factory(batch_size=1024, doNormalization=True, inj_rate=0.01)
    cnn_model = get_clean_model(model_name, dataset_name, device)

    trues, preds = [], []
    with torch.no_grad():
        for inputs, labels in dataset.test_loader:
            cnn_out = cnn_model(inputs.to(device))
            # Predicted class = index of the largest value along dim 1.
            preds.extend(cnn_out.argmax(dim=1).tolist())
            trues.extend(labels.tolist())

    trues = np.array(trues)
    preds = np.array(preds)
    accuracy = np.mean(trues == preds)
    print("{:.2f}".format(accuracy * 100))

CIFAR10::init - doNormalization is True
86.99
SVHN::init - doNormalization is True
94.82
GTSRB::init - doNormalization is True
96.05
