In [None]:
import numpy as np
import tensorflow as tf
import time
from matplotlib import pyplot as plt
import seaborn as sns
# Apply seaborn's default theme first, then switch the axes style to 'whitegrid'
# so every figure produced by this notebook shares the same look.
sns.set()
sns.set_style('whitegrid')

from cleverhans.utils_tf import batch_eval

from utils_config import ModelConfig
from utils_experiment import get_data_dict, train_model
from utils_attacks import get_deltas
from dknn import DkNNModel

In [None]:
# Default figure size (width, height in inches) used by all later plotting cells;
# wide enough for the side-by-side subplot figures further down.
FIGSIZE = (15, 5)
plt.rcParams["figure.figsize"] = FIGSIZE

In [None]:
# Experiment configuration object; every later cell reads its settings
# (session, model directory, neighbour counts, image geometry, backend) from `mc`.
CONFIG_FILE = '../configs/config_mnist.yaml'
RESULTS_ROOT = '../results/'
mc = ModelConfig(config_file=CONFIG_FILE, root_dir=RESULTS_ROOT)

In [None]:
# Run the project's training routine for the configuration held in `mc`.
# NOTE(review): presumably this persists the model that a later cell reloads with
# mc.load_model(model_dir=...) — confirm in utils_experiment before relying on it.
train_model(mc)

In [None]:
data_dict = get_data_dict(mc)

# Unpack the splits used by the rest of the notebook.
# - x_train / labels_train: fed to DkNNModel as training data and to the attack;
#   labels_train is compared against integer class ids when computing class means.
# - x_test / y_test: evaluation split; y_test is reduced with argmax over axis 1
#   wherever it is compared with predictions.
# - x_cali / labels_cali: passed to DkNNModel.calibrate().
x_train = data_dict['x_train']
labels_train = data_dict['labels_train']
# The key below had been corrupted by a stray LaTeX fragment that broke the string
# literal; it is restored to the plain 'x_test' key, matching the sibling keys.
x_test = data_dict['x_test']
y_test = data_dict['y_test']
x_cali = data_dict['x_cali']
labels_cali = data_dict['labels_cali']

In [None]:
# Class ids 0..9 and, for each id, the element-wise mean of the training samples
# carrying that label. Both arrays are handed to get_deltas() in the attack cell.
labels = np.arange(10)
means = np.array([x_train[labels_train == digit].mean(axis=0) for digit in labels])

In [None]:
# Draw N random training-set indices (np.random.choice samples with replacement
# by default, and no seed is set, so the draw differs between runs).
# NOTE(review): `points` is described as "to be used for the attack", but no
# visible cell reads it — the attack below is called with the full x_train.
N = 1000
points = np.random.choice(x_train.shape[0], size=N)  # to be used for the attack

In [None]:
# Use Image Parameters.
# NOTE(review): these three names are not read again in this cell; the DkNNModel
# constructors below take mc.img_rows / mc.img_cols / mc.nchannels instead.
img_rows, img_cols, nchannels = x_train.shape[1:4]

# Overview of this cell: inside a single TensorFlow session it
#   1. loads the saved model,
#   2. builds, fits, calibrates and evaluates a DkNNModel with method='euclidean',
#   3. runs get_deltas() against that model,
#   4. repeats steps 2-3 with method='geodesic'.
# Results are left in notebook globals (preds_*, confs_*, creds_*, deltas_*, norms_*)
# for the saving and plotting cells further down.
with mc.get_tensorflow_session() as sess:
    with tf.variable_scope('dknn'):
        model_dir = mc.get_model_dir_name()
        model = mc.load_model(model_dir=model_dir)
        
        # Extract representations for the training and calibration data at each layer of interest to the DkNN.
        layers = ['ReLU1', 'ReLU3', 'ReLU5', 'logits']

        # --- Baseline DkNN (method='euclidean') ---
        dknn = DkNNModel(
            sess = sess,
            model = model,
            neighbors = mc.nb_neighbors,
            proto_neighbors = mc.nb_proto_neighbors,
            img_rows=mc.img_rows,
            img_cols=mc.img_cols,
            nchannels=mc.nchannels,
            nb_classes=mc.nb_classes,
            layers=layers,
            train_data=x_train,
            train_labels=labels_train,
            method='euclidean',
            neighbors_table_path=mc.get_model_dir_name(),
            scope='dknn',
            backend=mc.backend)
        # Wall-clock time of fit() only; calibration and prediction are not timed.
        start = time.time()
        dknn.fit()
        end = time.time()
        print('dknn time', end-start)
        dknn.calibrate(x_cali, labels_cali)
        preds_knn, confs_knn, creds_knn = dknn.predict(x_test)
        # Test accuracy: predictions compared with the argmax of y_test along axis 1.
        print((preds_knn==np.argmax(y_test,axis=1)).mean())

################ Attack #####################################
        # get_deltas() receives a callable taking one sample; wrap it into a
        # batch of size one before handing it to predict().
        def wrapper_dknn(x):
            return dknn.predict(np.array([x]))
        
        # NOTE(review): the attack is given the full x_train / labels_train; the
        # random index array `points` drawn in an earlier cell is not used here —
        # confirm whether a subsample was intended.
        deltas_k, norms_k, confs_k, creds_k = get_deltas(x_train,
                                       labels_train,
                                       means,
                                       labels,
                                       wrapper_dknn,
                                       eps=1e-2)
#############################################################
        
        # --- Geodesic variant: identical arguments except method='geodesic' ---
        # NOTE(review): this instance reuses the same scope ('dknn') and the same
        # neighbors_table_path as the euclidean one — verify in dknn.py that the
        # two models do not overwrite each other's state.
        dknn_geod = DkNNModel(
            sess = sess,
            model = model,
            neighbors = mc.nb_neighbors,
            proto_neighbors = mc.nb_proto_neighbors,
            img_rows=mc.img_rows,
            img_cols=mc.img_cols,
            nchannels=mc.nchannels,
            nb_classes=mc.nb_classes,
            layers=layers,
            train_data=x_train,
            train_labels=labels_train,
            method='geodesic',
            neighbors_table_path=mc.get_model_dir_name(),
            scope='dknn',
            backend=mc.backend)
        # Same timing / calibration / evaluation sequence as for the euclidean model.
        start = time.time()
        dknn_geod.fit()
        end = time.time()
        print('gdknn time', end-start)
        dknn_geod.calibrate(x_cali, labels_cali)
        preds_geod, confs_geod, creds_geod = dknn_geod.predict(x_test)
        print((preds_geod==np.argmax(y_test,axis=1)).mean())

##################### Attack ################################
        # Single-sample adapter around the geodesic model, mirroring wrapper_dknn.
        def wrapper_geod(x):
            return dknn_geod.predict(np.array([x]))
        
        # Same attack call as above, pointed at the geodesic model.
        deltas_g, norms_g, confs_g, creds_g = get_deltas(x_train,
                                       labels_train,
                                       means,
                                       labels,
                                       wrapper_geod,
                                       eps=1e-2)
############################################################# 

In [None]:
# Pickle file used by the "Save" and "Load" cells below; both read the global `path`,
# so it has to be defined for the notebook to run top to bottom on a fresh kernel.
# NOTE(review): the file name says "svhn" while this notebook loads config_mnist.yaml —
# confirm or rename before saving so results of another dataset are not overwritten.
path = "../results/attacks/svhn_deltas.pkl"

In [None]:
# Save
# Bundle the attack outputs of both models into one DataFrame (one column per
# array) and pickle it to `path`.
import pandas as pd

attack_columns = {
    "deltas_k": deltas_k,
    "deltas_g": deltas_g,
    "norms_k": norms_k,
    "norms_g": norms_g,
}
df = pd.DataFrame(attack_columns)
df.to_pickle(path)

In [None]:
# Load
# Self-contained entry point for re-plotting: re-imports what the plotting cells
# need and restores the attack outputs from the pickle written by the Save cell.
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns

sns.set()
sns.set_style('whitegrid')

# NOTE: unpickling can execute arbitrary code — only load files produced by this notebook.
df = pd.read_pickle(path)
deltas_k, deltas_g, norms_k, norms_g = (
    df[column] for column in ("deltas_k", "deltas_g", "norms_k", "norms_g")
)

In [None]:
def cumulative_plot(xlist, color, linestyle='solid', label=None, bins=40, reverse=True):
    """Plot a normalised cumulative histogram of ``xlist`` on the current pyplot axes.

    Parameters
    ----------
    xlist : array-like
        Sample values (list, NumPy array, pandas Series, ...). The input is
        flattened before binning.
    color, linestyle, label
        Forwarded unchanged to ``plt.plot``.
    bins : int or sequence
        Forwarded to ``np.histogram``.
    reverse : bool
        If True (default) plot the reverse cumulative curve, i.e. for each bin
        the fraction of samples that fall in later bins. If False plot the
        ordinary cumulative fraction (samples up to and including the bin).

    Notes
    -----
    The x-coordinates are the left edges of the histogram bins, as before.
    An empty input draws an empty line (so the legend entry is still created)
    instead of dividing by zero.
    """
    samples = np.asarray(xlist).ravel()
    total = samples.size
    if total == 0:
        # Nothing to bin; keep the artist so legends stay consistent across panels.
        plt.plot([], [], label=label, color=color, linestyle=linestyle)
        return

    counts, edges = np.histogram(samples, bins=bins)
    running = np.cumsum(counts)
    # float() keeps true division even where `/` on ints would truncate.
    if reverse:
        fraction = (total - running) / float(total)
    else:
        fraction = running / float(total)
    plt.plot(edges[:-1], fraction, label=label, color=color, linestyle=linestyle)

In [None]:
# Three side-by-side panels comparing the two models on (1) correctly classified
# test points, (2) misclassified test points and (3) the attack outputs. For the
# test-set panels the plotted value is the row-wise maximum of the credibility array.
y_true = np.argmax(y_test, axis=1)
trinity_panels = [
    ('Credibility levels (properly classified points), MNIST, N=1,000',
     np.max(creds_knn[preds_knn == y_true], axis=1),
     np.max(creds_geod[preds_geod == y_true], axis=1)),
    ('Credibility levels (misclassified points), MNIST, N=1,000',
     np.max(creds_knn[preds_knn != y_true], axis=1),
     np.max(creds_geod[preds_geod != y_true], axis=1)),
    ('Credibility levels (adversarial points), MNIST, N=1,000',
     creds_k,
     creds_g),
]

for panel_no, (panel_title, series_knn, series_geod) in enumerate(trinity_panels, start=1):
    plt.subplot(1, 3, panel_no)
    cumulative_plot(series_knn, color='red', linestyle='dashed', label='dknn')
    cumulative_plot(series_geod, color='blue', label='geod')
    plt.legend()
    plt.title(panel_title)
    plt.xlabel('Credibility')
    plt.ylabel('Accuracy')

plt.savefig('../results/confidence/credibility_trinity.png')
plt.show()

In [None]:
# Mean of the row-wise maximum credibility over the test points each model got wrong.
y_true = np.argmax(y_test, axis=1)
print('Misclassified points')
for model_tag, model_creds, model_preds in (('DkNN', creds_knn, preds_knn),
                                            ('GDkNN', creds_geod, preds_geod)):
    wrong = model_preds != y_true
    print(model_tag, np.mean(np.max(model_creds[wrong], axis=1)))

In [None]:
# Mean of the row-wise maximum credibility over the test points each model got right.
y_true = np.argmax(y_test, axis=1)
print('Properly classified points')
for model_tag, model_creds, model_preds in (('DkNN', creds_knn, preds_knn),
                                            ('GDkNN', creds_geod, preds_geod)):
    correct = model_preds == y_true
    print(model_tag, np.mean(np.max(model_creds[correct], axis=1)))

In [None]:
# Mean of the credibility values returned by get_deltas() for each model.
print('Adversarial Points')
for model_tag, attack_creds in (('DkNN', creds_k), ('GDkNN', creds_g)):
    print(model_tag, np.mean(attack_creds))

In [None]:
# Two-panel version of the credibility figure: misclassified test points (left)
# and attack outputs (right), each comparing the two models.
y_true = np.argmax(y_test, axis=1)
double_panels = [
    ('Credibility levels (misclassified points), MNIST, N=1,000',
     np.max(creds_knn[preds_knn != y_true], axis=1),
     np.max(creds_geod[preds_geod != y_true], axis=1)),
    ('Credibility levels (adversarial points), MNIST, N=1,000',
     creds_k,
     creds_g),
]

for panel_no, (panel_title, series_knn, series_geod) in enumerate(double_panels, start=1):
    plt.subplot(1, 2, panel_no)
    cumulative_plot(series_knn, color='red', linestyle='dashed', label='dknn')
    cumulative_plot(series_geod, color='blue', label='geod')
    plt.legend()
    plt.title(panel_title)
    plt.xlabel('Credibility')
    plt.ylabel('Accuracy')

plt.savefig('../results/confidence/credibility_double.png')
plt.show()

In [None]:
# Reverse cumulative curves of the attack deltas for both models on one axes.
# NOTE(review): the title says SVHN, N=10,000 while this notebook loads
# config_mnist.yaml — confirm which dataset the figure is meant to describe.
delta_curves = (
    (deltas_k, dict(color='red', linestyle='dashed', label='dknn')),
    (deltas_g, dict(color='blue', label='gdknn')),
)
for delta_series, line_style in delta_curves:
    cumulative_plot(delta_series, **line_style)

plt.title('Reverse cumulative plot, SVHN, N=10,000')
plt.xlabel('Attack Delta')
plt.ylabel('Accuracy')
plt.legend()
#plt.savefig('../results/attacks/svhn_deltas.png')
plt.show()

In [None]:
# Reverse cumulative curves of the attack-vector norms for both models on one axes.
# NOTE(review): the title says SVHN, N=10,000 while this notebook loads
# config_mnist.yaml — confirm which dataset the figure is meant to describe.
norm_curves = (
    (norms_k, dict(color='red', linestyle='dashed', label='dknn')),
    (norms_g, dict(color='blue', label='gdknn')),
)
for norm_series, line_style in norm_curves:
    cumulative_plot(norm_series, **line_style)

plt.title('Reverse cumulative plot, SVHN, N=10,000')
plt.xlabel('Attack vector infinity norm')
plt.ylabel('Accuracy')
plt.legend()
#plt.savefig('../results/attacks/svhn_norms.png')
plt.show()