In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
import time
import numpy as np

import tensorflow as tf
from tensorflow.python.util import deprecation
# Keep the notebook output readable: switch off TensorFlow's deprecation
# warnings (via a private module flag) and only let ERROR-level TF logs through.
deprecation._PRINT_DEPRECATION_WARNINGS = False
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

from cleverhans.utils_tf import batch_eval

from utils_config import ModelConfig, dataset_loader
from utils_experiment import get_data_dict
from utils_experiment import train_model
from utils_experiment import hyperparameter_selection
from dknn import DkNNModel, NearestNeighbor


In [2]:
# Experiment configuration for MNIST; both paths are relative to the
# directory the notebook is run from.
mc = ModelConfig(
    config_file='../configs/config_mnist.yaml',
    root_dir='../results/',
)

In [None]:
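# Training is disabled here: the cell below loads an already saved model via mc.load_model.
# Uncomment to train first (presumably this writes the model that is loaded later -- TODO confirm).
#train_model(mc)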

In [3]:
# Read and wrangle the data for the configured experiment.
data_dict = get_data_dict(mc)

# Unpack the train / test / calibration splits from data_dict.
x_train, labels_train = data_dict['x_train'], data_dict['labels_train']
x_test, y_test = data_dict['x_test'], data_dict['y_test']
x_cali, labels_cali = data_dict['x_cali'], data_dict['labels_cali']

# Image parameters, taken from axes 1..3 of the training array.
img_rows, img_cols, nchannels = x_train.shape[1:4]

def _run_dknn(sess, model, layers, method, **dknn_kwargs):
    """Build, fit, calibrate and evaluate one DkNN variant.

    Reads the notebook-level `mc`, `x_train`, `labels_train`, `x_cali`,
    `labels_cali` and `x_test`.

    Args:
        sess: session handed to DkNNModel.
        model: model object returned by `mc.load_model`.
        layers: list of layer names passed to DkNNModel as `layers`.
        method: value for DkNNModel's `method` argument
            ('euclidean' or 'geodesic' in this notebook).
        **dknn_kwargs: extra keyword arguments forwarded unchanged to
            DkNNModel (e.g. `neighbors_table_path` for the geodesic variant).

    Returns:
        Tuple `(dknn_model, activations, preds)`: the fitted and calibrated
        DkNNModel, the result of `get_activations` on the first training
        sample, and the first element returned by `predict(x_test)`.
    """
    dknn_model = DkNNModel(
        sess=sess,
        model=model,
        neighbors=mc.nb_neighbors,
        proto_neighbors=mc.nb_proto_neighbors,
        backend=mc.backend,
        img_rows=mc.img_rows,
        img_cols=mc.img_cols,
        nchannels=mc.nchannels,
        nb_classes=mc.nb_classes,
        layers=layers,
        train_data=x_train,
        train_labels=labels_train,
        method=method,
        scope='dknn',
        **dknn_kwargs)

    # Time only the fit step and print the elapsed wall-clock seconds.
    start = time.time()
    dknn_model.fit()
    end = time.time()
    print(end - start)

    activations = dknn_model.get_activations(x_train[0:1])

    dknn_model.calibrate(x_cali, labels_cali)
    preds, _, _ = dknn_model.predict(x_test)
    return dknn_model, activations, preds


with mc.get_tensorflow_session() as sess:
    # Use the compat.v1 alias, consistent with the tf.compat.v1 logging call
    # made at import time.
    with tf.compat.v1.variable_scope('dknn'):
        # Instantiate model
        model_dir = mc.get_model_dir_name()
        model = mc.load_model(model_dir=model_dir)

        # Layers of interest to the DkNN: representations of the training and
        # calibration data are extracted at each of them.
        layers = ['ReLU1', 'ReLU3', 'ReLU5', 'logits']

        # Euclidean DkNN
        dknn, activations_dknn, preds_knn = _run_dknn(
            sess, model, layers, method='euclidean')

        # Geodesic DkNN: same setup, plus the path given as neighbors_table_path.
        dknn_geod, activations_gdknn, preds_geod = _run_dknn(
            sess, model, layers, method='geodesic',
            neighbors_table_path=model_dir)


Loading model from:
 ../results/MNIST/nb_train_1000_lr_0.001_bs_2_1/model.joblib

Constructing the NearestNeighbor table layer ReLU1
Constructing the NearestNeighbor table layer ReLU3
Constructing the NearestNeighbor table layer ReLU5
Constructing the NearestNeighbor table layer logits
13.10959267616272
Starting calibration.
Completed calibration.

Constructing the GeodesicNearestNeighbor table layer ReLU1
Constructing the GeodesicNearestNeighbor table layer ReLU3
Constructing the GeodesicNearestNeighbor table layer ReLU5
Constructing the GeodesicNearestNeighbor table layer logits
12.761942863464355
Starting calibration.
Completed calibration.



In [4]:
# Fraction of test predictions matching argmax(y_test, axis=1):
# first the Euclidean DkNN, then the geodesic DkNN.
for dknn_preds in (preds_knn, preds_geod):
    print(np.mean(dknn_preds == np.argmax(y_test, axis=1)))

0.524
0.908


In [None]:
# Number of test samples that BOTH the Euclidean and the geodesic DkNN get right.
pos_pos = np.logical_and(
    preds_knn == np.argmax(y_test, axis=1),
    preds_geod == np.argmax(y_test, axis=1),
).sum()
pos_pos

In [None]:
# Number of test samples the Euclidean DkNN gets right but the geodesic DkNN gets wrong.
pos_neg = np.logical_and(
    preds_knn == np.argmax(y_test, axis=1),
    np.logical_not(preds_geod == np.argmax(y_test, axis=1)),
).sum()
pos_neg

In [None]:
# Number of test samples the Euclidean DkNN gets wrong but the geodesic DkNN gets right.
neg_pos = np.logical_and(
    np.logical_not(preds_knn == np.argmax(y_test, axis=1)),
    preds_geod == np.argmax(y_test, axis=1),
).sum()
neg_pos

In [None]:
# Number of test samples that BOTH the Euclidean and the geodesic DkNN get wrong.
# Both correctness masks must be negated; negating only the first one would
# just reproduce the "kNN wrong, geodesic right" count.
neg_neg = (~(preds_knn == np.argmax(y_test, axis=1)) & ~(preds_geod == np.argmax(y_test, axis=1))).sum()
neg_neg

In [None]:
# Index of the largest entry along axis 1 for every row of y_test
# (presumably y_test is one-hot, making this the true class -- TODO confirm).
np.argmax(y_test,axis=1)

In [None]:
# Predictions returned by the Euclidean DkNN's predict(x_test).
preds_knn

In [None]:
# 'ReLU3' entry of the geodesic DkNN's train_activations
# (presumably the training-set activations at that layer -- TODO confirm).
dknn_geod.train_activations['ReLU3']

In [None]:
# geodesic_kernel attribute of the geodesic DkNN's 'ReLU3' query object
# (presumably a pairwise geodesic distance matrix -- TODO confirm).
dknn_geod.query_objects['ReLU3'].geodesic_kernel

In [None]:
import matplotlib
import matplotlib.pyplot as plt
from utils_kernel import euclidean_kernel, hard_geodesics_euclidean_kernel_regular
from utils_visualization import plot_kernel

In [None]:
# Kernel computed by euclidean_kernel from the Euclidean DkNN's ReLU1 training
# activations (presumably pairwise distances -- TODO confirm).
euclidean_matrix = euclidean_kernel(dknn.train_activations['ReLU1'])

# Overwrite every exact-zero entry, in place, with (largest entry + 1) so that
# zeros rank last instead of first.
max_distance = 1 + np.max(euclidean_matrix)
np.putmask(euclidean_matrix, euclidean_matrix == 0, max_distance)

plot_kernel(euclidean_matrix)

In [None]:
# Kernel computed by hard_geodesics_euclidean_kernel_regular from the geodesic
# DkNN's ReLU1 training activations; the second argument is 5
# (presumably a neighbour count -- TODO confirm).
geodesic_euclidean_matrix = hard_geodesics_euclidean_kernel_regular(
    dknn_geod.train_activations['ReLU1'], 5)

# Overwrite every exact-zero entry, in place, with (largest entry + 1) so that
# zeros rank last instead of first.
max_distance = 1 + np.max(geodesic_euclidean_matrix)
np.putmask(geodesic_euclidean_matrix, geodesic_euclidean_matrix == 0, max_distance)

plot_kernel(geodesic_euclidean_matrix)

In [None]:
# For each neighbourhood size k, compute the average (over training points)
# fraction of the k nearest neighbours that share the point's label, once with
# Euclidean neighbours and once with the geodesic DkNN's neighbour table.
ks = range(1, 70)

# Number of training points, derived from the data rather than hard-coded.
n_train = len(labels_train)

# Loop-invariant lookup of the geodesic DkNN's ReLU1 neighbour index table;
# its first k columns of row i are used as point i's k neighbours.
geodesic_neighbor_index = dknn_geod.query_objects['ReLU1'].train_neighbor_index

same_class_euclidean = np.zeros(len(ks))
same_class_geodesic = np.zeros(len(ks))
for j, k in enumerate(ks):
    acum_euc = 0
    acum_geo = 0
    for i in range(n_train):
        # Indices of the k smallest entries in row i of the Euclidean kernel
        # (argpartition does not sort them, which the mean does not need).
        euclidean_neighbors_idx = np.argpartition(euclidean_matrix[i, :], k)[:k]
        acum_euc += np.mean(labels_train[i] == labels_train[euclidean_neighbors_idx])

        acum_geo += np.mean(labels_train[i] == labels_train[geodesic_neighbor_index[i, :k]])
    same_class_euclidean[j] = acum_euc / n_train
    same_class_geodesic[j] = acum_geo / n_train

In [None]:
# Compare the two neighbour definitions. The x-axis reuses `ks`, the same
# range the curves were computed over, so the two cannot drift apart.
plt.plot(ks, same_class_euclidean, label='Euclidean', linestyle='--')
plt.plot(ks, same_class_geodesic, label='Geodesic')
plt.legend()
plt.grid()
plt.xlabel('Number of Neighbors')
# The plotted values are means of boolean matches, i.e. fractions in [0, 1],
# not percentages.
plt.ylabel('Fraction of NN with same class')
#plt.savefig('../results/comformity_comparison_relu1.png')