In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
import time
import numpy as np

import tensorflow as tf
from tensorflow.python.util import deprecation
deprecation._PRINT_DEPRECATION_WARNINGS = False
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

from cleverhans.utils_tf import batch_eval
from cleverhans.dataset import CIFAR10
from dataloader import SVHN

from utils_config import ModelConfig, dataset_loader
from utils_experiment import get_data_dict
from utils_experiment import train_model
from utils_experiment import hyperparameter_selection
from dknn import DkNNModel, NearestNeighbor

import copy
import shutil
import urllib
import urllib.request

import scipy.io
import seaborn as sns
from matplotlib import pyplot as plt
from PIL import Image

In [None]:
# Enable IPython's autoreload extension in mode 2 so that imported modules are
# re-imported before each cell executes; edits to the local helper modules are
# then picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Experiment configuration: which YAML file to read and where results live.
CONFIG_FILE = '../configs/config_mnist.yaml'
RESULTS_ROOT = '../results/'

mc = ModelConfig(config_file=CONFIG_FILE, root_dir=RESULTS_ROOT)

In [None]:
# Read and wrangle data
data_dict = get_data_dict(mc)

# Unpack the train / test / calibration splits from the data dictionary.
x_train = data_dict['x_train']
labels_train = data_dict['labels_train']
x_test = data_dict['x_test']
y_test = data_dict['y_test']
x_cali = data_dict['x_cali']
labels_cali = data_dict['labels_cali']

# Image geometry, read from axes 1..3 of the training array.
img_rows, img_cols, nchannels = x_train.shape[1:4]

# --- Rotated MNIST -----------------------------------------------------------
# Work on copies so the original test set stays untouched.
ROTATION_DEGREES = 45
x_rotated = copy.copy(x_test)
y_rotated = copy.copy(y_test)

# The loop uses an index (not a variable called `x`) so that no stray image
# array is left behind in the notebook namespace under a generic name.
for idx in range(len(x_rotated)):
    # Round (rather than truncate) when quantising to 8 bit: a float pixel that
    # represents k/255 can evaluate to slightly below k after the multiply.
    # assumes pixel values lie in [0, 1] — TODO confirm against get_data_dict
    pixels = np.rint(x_rotated[idx, :, :, 0] * 255).astype(np.uint8)
    rotated = Image.fromarray(pixels, mode='L').rotate(ROTATION_DEGREES)
    # Only channel 0 is rotated, exactly as before.
    x_rotated[idx, :, :, 0] = np.array(rotated) / 255

# --- Not MNIST ---------------------------------------------------------------
N_NOT_MNIST = 250  # number of notMNIST images kept for the comparison
nm_file = '../data/notMNIST_small.mat'
if not os.path.exists(nm_file):
    url = 'http://yaroslavvb.com/upload/notMNIST/notMNIST_small.mat'
    print('downloading Not MNIST data to {}'.format(nm_file))
    os.makedirs(os.path.dirname(nm_file), exist_ok=True)
    # Download to a temporary name and rename once complete, so an interrupted
    # download can never be mistaken for a valid file by the exists() check.
    partial_file = nm_file + '.part'
    with urllib.request.urlopen(url) as response, open(partial_file, 'wb') as out_file:
        shutil.copyfileobj(response, out_file)
    os.replace(partial_file, nm_file)
mat = scipy.io.loadmat(nm_file)
# Take the first N_NOT_MNIST images along the last axis, move that axis to the
# front and add a trailing channel axis: result is (N, 28, 28, 1).
# NOTE(review): the MNIST arrays above are treated as [0, 1]-scaled; the value
# range of mat['images'] is not visible here — confirm `nm` is on the same scale.
nm = mat['images'][:, :, :N_NOT_MNIST].transpose(2, 0, 1).reshape(-1, 28, 28, 1)

print(x_test.shape)
print(x_rotated.shape)
print(nm.shape)

In [None]:
# Fit and calibrate a Euclidean and a geodesic DkNN on the same trained network,
# then score three datasets (MNIST test, rotated MNIST, notMNIST) with both.
with mc.get_tensorflow_session() as sess:
    with tf.compat.v1.variable_scope('dknn'):
        # Restore the trained network from its model directory.
        model_dir = mc.get_model_dir_name()
        model = mc.load_model(model_dir=model_dir)

        # Layers whose representations the DkNN uses.
        layers = ['ReLU1', 'ReLU3', 'ReLU5', 'logits']

        # Arguments that are identical for both DkNN variants; only the
        # neighbour count and the distance method differ between them.
        # NOTE(review): both variants receive the same neighbors_table_path
        # although they use different neighbour counts/methods — confirm that
        # DkNNModel keeps their cached tables apart.
        shared_dknn_kwargs = dict(
            sess=sess,
            model=model,
            proto_neighbors=mc.nb_proto_neighbors,
            img_rows=mc.img_rows,
            img_cols=mc.img_cols,
            backend=mc.backend,
            nchannels=mc.nchannels,
            nb_classes=mc.nb_classes,
            layers=layers,
            train_data=x_train,
            train_labels=labels_train,
            neighbors_table_path=model_dir,
            scope='dknn',
        )

        # Euclidean DkNN
        dknn = DkNNModel(neighbors=16, method='euclidean', **shared_dknn_kwargs)
        dknn.fit()
        dknn.calibrate(x_cali, labels_cali)

        # Geodesic DkNN
        dknn_geod = DkNNModel(neighbors=128, method='geodesic', **shared_dknn_kwargs)
        dknn_geod.fit()
        dknn_geod.calibrate(x_cali, labels_cali)

        print('Getting confidence scores')
        # Every (model, dataset) pair gets its own prediction variable so that
        # no result is overwritten by a later call.
        preds_knn, conf_euc, cred_euc = dknn.predict(x_test)
        preds_geod, conf_geo, cred_geo = dknn_geod.predict(x_test)

        preds_knn_rotated, conf_euc_rotated, cred_euc_rotated = dknn.predict(x_rotated)
        preds_geod_rotated, conf_geo_rotated, cred_geo_rotated = dknn_geod.predict(x_rotated)

        preds_knn_nm, conf_euc_nm, cred_euc_nm = dknn.predict(nm)
        preds_geod_nm, conf_geo_nm, cred_geo_nm = dknn_geod.predict(nm)

        # Reduce each score matrix to its per-row maximum (axis=1), ordered as
        # [MNIST test, rotated MNIST, notMNIST].
        conf_dknn = [conf_euc.max(axis=1), conf_euc_rotated.max(axis=1), conf_euc_nm.max(axis=1)]
        conf_geod = [conf_geo.max(axis=1), conf_geo_rotated.max(axis=1), conf_geo_nm.max(axis=1)]

        cred_dknn = [cred_euc.max(axis=1), cred_euc_rotated.max(axis=1), cred_euc_nm.max(axis=1)]
        cred_geod = [cred_geo.max(axis=1), cred_geo_rotated.max(axis=1), cred_geo_nm.max(axis=1)]

In [None]:
# One {'euclidean': ..., 'geodesic': ...} dictionary per dataset, in the same
# order as the per-dataset score lists built in the previous cell.
n_datasets = len(conf_dknn)

confs = [
    {'euclidean': conf_dknn[k], 'geodesic': conf_geod[k]}
    for k in range(n_datasets)
]
creds = [
    {'euclidean': cred_dknn[k], 'geodesic': cred_geod[k]}
    for k in range(n_datasets)
]

In [None]:
# Use seaborn's "whitegrid" style for the figures drawn after this point.
sns.set_style("whitegrid")

In [None]:
def plot_distributions(distributions_dict, xlabel, fig_title):
    """Overlay one normalised histogram per entry of ``distributions_dict``.

    Parameters
    ----------
    distributions_dict : dict
        Maps a legend label to the values to histogram (the callers in this
        notebook pass per-sample score arrays).
    xlabel : str
        Label for the x-axis.
    fig_title : str
        Title drawn above the axes.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    fig, ax = plt.subplots(1, figsize=(7, 5.5))
    colors = sns.color_palette('plasma', len(distributions_dict))

    for color, (dist_name, values) in zip(colors, distributions_dict.items()):
        # norm_hist=True makes the bars a density, matching the y-axis label;
        # ax=ax draws on this figure's axes rather than the implicit current one.
        sns.distplot(values, rug=False, kde=False, norm_hist=True,
                     label=dist_name, color=color, ax=ax)

    # Axes decoration is independent of the individual series, so set it once.
    ax.set_xlabel(xlabel, fontsize=14)
    ax.set_ylabel('Density', fontsize=14)
    ax.set_title(fig_title, fontsize=15.5)
    ax.set_xlim((0, 1))
    # Legend sits outside the axes, to the right of the plot.
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    fig.tight_layout()
    plt.show()

In [None]:
# Credibility distributions, one figure per dataset (same order as `creds`).
credibility_titles = [
    'MNIST Credibility',
    'Rotated MNIST Credibility',
    'Not MNIST Credibility',
]
for dataset_idx, title in enumerate(credibility_titles):
    plot_distributions(creds[dataset_idx], 'Credibility', title)