In [None]:
import sys
sys.path.insert(0, '../') # go up 1 level to include the project root in the search path.

from sklearn.metrics import roc_curve, auc
import numpy as np
import pickle
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
sns.set(style="ticks")
sns.set_context("paper")
# sns.set_palette("colorblind") # muted, deep
sns.set_palette("Paired") # paired, cubehelix, husl
# sns.set_palette("coolwarm") # BrBG, RdBu_r, coolwarm
# flatui = ["#9b59b6", "#3498db", "#95a5a6", "#e74c3c", "#34495e", "#2ecc71"]
# sns.set_palette(flatui)

from models.MyResNet_Prefetcher import MyResNetPrefetcher

%matplotlib inline

# Values of T for which test-time-augmentation result files ('..._TTAUG_<T>.pkl') are read below.
T_values = [4, 8, 16, 32, 64, 128]

# The network configuration is only used here to identify the model/results files
# that belong to this particular configuration (via the model's descriptor).
network_config = {
    'instance_shape': [512, 512, 3],
    'num_classes': 5,
    'conv_depths': [1, 1, 1, 1],
    # Wider alternative that was tried:
    # [[64, 64, 256], [128, 128, 512], [256, 256, 1024], [512, 512, 2048]]
    'num_filters': [[64, 64, 128], [128, 128, 256], [256, 256, 512], [512, 512, 1024]],
    'fc_depths': [512],
    'lambda': 0.00001,
    'lr': 0.003,
    'momentum_max': 0.9,
    'decay_steps': 10000,
    'decay_rate': 0.8,
    'data_aug': True,
    'data_aug_prob': 0.9,
    'max_iter': 2000,
    'oversampling_limit': 0.1,
    # ResNet50: Max batch sizes allowed by BatchNorm and BatchReNorm are 14 and 8, respectively.
    'batch_size': 23,
    'val_step': 200,
    'resurrection_step': 25000,
    'quick_dirty_val': True,
    # To be set later on during Test-time augmentation with various values, {4,8,16...}
    'T': 0,
    # times minibatch size effectively
    'dataset_buffer_size': 500,
}

# Directory holding the pickled result files; file names are derived from model.descriptor.
RESULTS_DIR = '/gpfs01/berens/user/mayhan/Documents/MyPy/GitRepos/ttaug-DR-uncertainty/results/'
model = MyResNetPrefetcher(network_config=network_config, name='ResNet4GitHub')

In [None]:
## TRAINING CURVES and VALIDATION PERFORMANCE ACROSS TRAINING
#######################################################################################
# Constant ROC-AUC values drawn as reference lines under the 'Leibig et al.' method label.
LEIBIG_ROC_AUC_MILD = 0.889
LEIBIG_ROC_AUC_MODERATE = 0.927


def _as_vector(values):
    """Return `values` as a 1-D array holding one entry per element of `values`.

    The shape is passed positionally because the `newshape=` keyword of
    np.reshape is deprecated in recent NumPy releases.
    """
    return np.reshape(values, (len(values),))


# Diagnostics file
result_file_name = RESULTS_DIR + model.descriptor + '_DIAG.pkl'

# NOTE: pickle.load can execute arbitrary code; only open result files you trust.
with open(result_file_name, 'rb') as filehandler:
    diagnostics = pickle.load(filehandler)

############################################
# Long-format frame with both loss curves; each curve is indexed 0..len-1.
losses = _as_vector(diagnostics['losses'])
avg_losses = _as_vector(diagnostics['avg_losses'])
df_loss = pd.DataFrame(
    {
        'loss': np.concatenate([losses, avg_losses], axis=0),
        'iteration': np.concatenate([np.arange(len(losses)), np.arange(len(avg_losses))], axis=0),
        'Type': ['minibatch loss'] * len(losses) + ['avg. minibatch loss'] * len(avg_losses),
    },
    columns=['loss', 'iteration', 'Type'],
)

fig = plt.figure(figsize=(15, 7.5))
ax1 = fig.add_subplot(1, 2, 1)
ax1 = sns.lineplot(x='iteration', y='loss', hue='Type', data=df_loss, ax=ax1)
sns.despine()

############################################
# Long-format frame with the validation ROC-AUC curves ('ours') followed by the
# constant reference lines ('Leibig et al.'). The k-th validation entry is
# placed at iteration k * val_step.
val_step = model.network_config['val_step']
val_roc1 = _as_vector(diagnostics['val_roc1'])
val_roc2 = _as_vector(diagnostics['val_roc2'])
n_roc1, n_roc2 = len(val_roc1), len(val_roc2)
val_iterations = np.concatenate([val_step * np.arange(n_roc1), val_step * np.arange(n_roc2)], axis=0)

df_roc = pd.DataFrame(
    {
        'ROC-AUC': np.concatenate([val_roc1,
                                   val_roc2,
                                   LEIBIG_ROC_AUC_MILD * np.ones(shape=(n_roc1,)),
                                   LEIBIG_ROC_AUC_MODERATE * np.ones(shape=(n_roc2,))],
                                  axis=0),
        'iteration': np.concatenate([val_iterations, val_iterations], axis=0),
        'Onset level': (['Mild DR'] * n_roc1 + ['Moderate DR'] * n_roc2) * 2,
        'Method': ['ours'] * (n_roc1 + n_roc2) + ['Leibig et al.'] * (n_roc1 + n_roc2),
    },
    columns=['ROC-AUC', 'iteration', 'Onset level', 'Method'],
)

ax2 = fig.add_subplot(1, 2, 2)
ax2 = sns.lineplot(x='iteration', y='ROC-AUC', hue='Onset level', style='Method', data=df_roc, ax=ax2)
sns.despine()

# Mark the highest validation ROC-AUC of each onset level with a diamond.
for roc_key in ('val_roc1', 'val_roc2'):
    max_idx = np.argmax(diagnostics[roc_key])
    ax2.plot(val_step * max_idx, diagnostics[roc_key][max_idx], color='lightcoral',
             marker='D', markeredgecolor='k', markersize=3)

plt.show()


# Release the loaded diagnostics and everything derived from them.
del diagnostics, df_roc, df_loss, losses, avg_losses, val_roc1, val_roc2, val_iterations

In [None]:
## PLOTS FOR SINGLE PREDICTIONS
############################################################################################
import scipy.stats as stats

def plot_roc_curve(labels_1hot_tr, predictions_1hot_tr, 
                   labels_1hot_val, predictions_1hot_val, 
                   labels_1hot_te, predictions_1hot_te):
    """Plot per-class ROC curves and ROC-AUC values for train, validation and test data.

    Every labels/predictions pair is indexed as ``array[:, i]``, with ``i`` running
    over ``labels.shape[1]`` columns; column ``i`` of the labels and of the
    predictions is passed to ``roc_curve`` as one binary problem.

    The figure has two panels:
      * left:  the ROC curves (hue = class, line style = split) plus the diagonal;
      * right: a point plot of the ROC-AUC per class and split.

    Legend names are defined for five classes only, so inputs with more than
    five columns raise an IndexError. The figure is shown; nothing is returned.
    """
    legend_labels = ('0: No DR', '1: Mild DR', '2: Moderate DR', '3: Severe DR', '4: Proliferative DR')

    # Processed in this order so that rows (and hence legend entries) appear as Train, Val., Test.
    splits = (('Train', labels_1hot_tr, predictions_1hot_tr),
              ('Val.', labels_1hot_val, predictions_1hot_val),
              ('Test', labels_1hot_te, predictions_1hot_te))

    fpr_parts = []     # one array of false positive rates per (split, class)
    tpr_parts = []     # matching arrays of true positive rates
    class_col = []     # class name for every ROC point
    split_col = []     # split name for every ROC point
    summary_rows = []  # one (ROC-AUC, class name, split name) tuple per (split, class)

    for split, labels_1hot, predictions_1hot in splits:
        for i in range(labels_1hot.shape[1]):
            fpr, tpr, _ = roc_curve(labels_1hot[:, i], predictions_1hot[:, i])

            fpr_parts.append(fpr)
            tpr_parts.append(tpr)
            class_col.extend([legend_labels[i]] * len(fpr))
            split_col.extend([split] * len(fpr))
            summary_rows.append((auc(fpr, tpr), legend_labels[i], split))

    # The leading empty array keeps np.concatenate valid when there are no columns at all.
    df_roc_multi = pd.DataFrame(
        {
            'False Positive Rate': np.concatenate([np.empty(0)] + fpr_parts, axis=0),
            'True Positive Rate': np.concatenate([np.empty(0)] + tpr_parts, axis=0),
            'Class': class_col,
            'Split': split_col,
        },
        columns=['False Positive Rate', 'True Positive Rate', 'Class', 'Split'],
    )
    df_roc_multi_summary = pd.DataFrame(summary_rows, columns=['ROC-AUC', 'Class', 'Split'])

    # ROC curves for train, val and test data combined
    f = plt.figure(figsize=(15,7.5))
    ax1 = f.add_subplot(1, 2, 1)
    ax2 = f.add_subplot(1, 2, 2)

    ax1 = sns.lineplot(x='False Positive Rate', y='True Positive Rate', hue='Class', style='Split', ci=None, 
                       data=df_roc_multi, ax=ax1)
    ax1.plot([0, 1], [0, 1], 'k-.')  # diagonal reference line
    sns.despine()

    ax2 = sns.pointplot(x='Class', y='ROC-AUC', hue='Split', ci=None, 
                        data=df_roc_multi_summary, ax=ax2)
    sns.despine()

    plt.show()

def plot_roc_curves_for_all(result, title='Receiver Operating Characteristics'):
    """Plot the ROC curves of every binary (one-vs-all) scenario for all three splits.

    `result` is a mapping providing '<split>_labels_1hot' and '<split>_pred_1hot'
    for the splits 'train', 'val' and 'test'. The arrays are in 1-hot or
    1-vs-all format: Shape of [numOfExamples, numOfClasses].

    `title` is accepted but not used by the current implementation.
    """
    split_names = ('train', 'val', 'test')
    labels = [result[name + '_labels_1hot'] for name in split_names]
    predictions = [result[name + '_pred_1hot'] for name in split_names]

    # plot_roc_curve expects the (labels, predictions) pairs in train, val, test order.
    plot_roc_curve(labels[0], predictions[0],
                   labels[1], predictions[1],
                   labels[2], predictions[2])

# Now, read the SINGLE PRED. results from file and plot
result_file_name = RESULTS_DIR + model.descriptor + '_SINGpred.pkl'

# NOTE: pickle.load can execute arbitrary code; only open result files you trust.
with open(result_file_name, 'rb') as filehandler:
    result = pickle.load(filehandler)

# ROC curves for train, val and test data combined. The file is already closed
# here: plotting only needs the loaded `result`, not the open handle.
plot_roc_curves_for_all(result, '')

del result

In [None]:
###############################################################
### DISCRIMINATIVE Performance via Test-time data augmentation
###############################################################

def summary_ttaug_disc_helper(labels_1hot, predictions_1hot_ttaug):
    """Summarize discriminative performance of test-time-augmented predictions.

    Parameters
    ----------
    labels_1hot : array indexed as [example, class] (the caller documents it as Mx5).
    predictions_1hot_ttaug : array indexed as [example, augmentation, class]
        (the caller documents it as MxTx5).

    The predictions of one example are aggregated over axis 1 with both the
    median and the mean. Two kinds of metrics are computed from them:

    * multi-class accuracy: argmax of the aggregated prediction vs. argmax of the label;
    * ROC-AUC for onset levels 1 and 2: the binary label is ``class >= onset_level``
      and the binary score is the sum of the predictions for classes
      ``onset_level`` and above, aggregated over axis 1.

    Returns
    -------
    tuple
        (acc_median, acc_mean,
         roc_auc_onset1_median, roc_auc_onset1_mean,
         roc_auc_onset2_median, roc_auc_onset2_mean)
    """
    true_classes = np.argmax(labels_1hot, axis=1)

    def _accuracy(aggregate):
        # Aggregating over axis 1 leaves one prediction vector per example.
        predictions_1hot = aggregate(predictions_1hot_ttaug, axis=1)
        correct = np.equal(true_classes, np.argmax(predictions_1hot, axis=1))
        return np.mean(np.asarray(correct, dtype=np.float32))

    def _onset_roc_auc(onset_level, aggregate):
        labels_bin = np.greater_equal(true_classes, onset_level)
        # Summing over the class axis yields one score per (example, augmentation) ...
        pred_bin = np.sum(predictions_1hot_ttaug[:, :, onset_level:], axis=2)
        # ... and aggregating over axis 1 leaves one score per example.
        fpr, tpr, _ = roc_curve(labels_bin, np.squeeze(aggregate(pred_bin, axis=1)))
        return auc(fpr, tpr)

    acc_median = _accuracy(np.median)
    acc_mean = _accuracy(np.mean)

    roc_auc_onset1_median = _onset_roc_auc(1, np.median)
    roc_auc_onset1_mean = _onset_roc_auc(1, np.mean)

    roc_auc_onset2_median = _onset_roc_auc(2, np.median)
    roc_auc_onset2_mean = _onset_roc_auc(2, np.mean)

    return acc_median, acc_mean, roc_auc_onset1_median, roc_auc_onset1_mean, roc_auc_onset2_median, roc_auc_onset2_mean

def summarize_ttaug_discriminative_performance(result):
    """Collect the discriminative metrics of all three data splits in one dict.

    `result` must provide '<split>_labels_1hot' (Mx5) and '<split>_pred_1hot'
    (MxTx5) for the splits 'train', 'val' and 'test'.

    The returned dict maps '<prefix>_<metric>' to the value computed by
    summary_ttaug_disc_helper, where the prefix is 'tr', 'val' or 'te' and the
    metric is one of: acc_median, acc_mean, roc_auc_onset1_median,
    roc_auc_onset1_mean, roc_auc_onset2_median, roc_auc_onset2_mean.
    """
    # Same order as the tuple returned by summary_ttaug_disc_helper.
    metric_names = ('acc_median', 'acc_mean',
                    'roc_auc_onset1_median', 'roc_auc_onset1_mean',
                    'roc_auc_onset2_median', 'roc_auc_onset2_mean')
    # (prefix used in the summary keys, name used in the result keys)
    splits = (('tr', 'train'), ('val', 'val'), ('te', 'test'))

    labels = [result[split_name + '_labels_1hot'] for _, split_name in splits]
    predictions = [result[split_name + '_pred_1hot'] for _, split_name in splits]

    # Training, validation and test metrics, in that order.
    metrics_per_split = [summary_ttaug_disc_helper(labels_1hot, predictions_1hot_ttaug)
                         for labels_1hot, predictions_1hot_ttaug in zip(labels, predictions)]

    discriminative_summary = {}
    for (prefix, _), metrics in zip(splits, metrics_per_split):
        for metric_name, value in zip(metric_names, metrics):
            discriminative_summary[prefix + '_' + metric_name] = value

    return discriminative_summary
    
# One discriminative summary per entry of T_values, in the same order.
summaries = []
for T in T_values:
    print('T = %g' % T)
    # Result file of the test-time augmentation run with this T.
    result_file_name = ''.join([RESULTS_DIR, model.descriptor, '_TTAUG_', str(T), '.pkl'])

    with open(result_file_name, mode='rb') as filehandler:
        result_ttaug = pickle.load(filehandler)
        summaries += [summarize_ttaug_discriminative_performance(result_ttaug)]

# Drop the last loaded result; only the summaries are needed from here on.
del result_ttaug

In [None]:
#### Plot the discriminative performance summaries
from matplotlib.ticker import MaxNLocator
from collections import namedtuple

# One marker / line style per plotted series:
# six test-time-augmentation series followed by three single-prediction baselines.
markers = ['8','x','8','x','8','x','v','v','v']
linestyles = ['-','--','-','--','-','--',':',':',':']
# Category order of the x-axis. Derived from T_values (the 'T' column below holds str(T)),
# so the axis cannot drift out of sync when T_values is changed.
order = [str(T_value) for T_value in T_values]

############ MULTI-CLASS ACCURACY ############################

# Firstly, determine the SINGLE PREDICTION baseline
result_file_name = RESULTS_DIR + model.descriptor + '_SINGpred.pkl'

# NOTE: pickle.load can execute arbitrary code; only open result files you trust.
with open(result_file_name, 'rb') as filehandler:
    result = pickle.load(filehandler)

labels_1hot_tr  = result['train_labels_1hot']
labels_1hot_val = result['val_labels_1hot']
labels_1hot_te = result['test_labels_1hot']
predictions_1hot_tr = result['train_pred_1hot']
predictions_1hot_val = result['val_pred_1hot']
predictions_1hot_te = result['test_pred_1hot']


def _constant_over_T(value):
    """Repeat a scalar baseline as a column of shape (len(T_values), 1)."""
    return value * np.ones(shape=(len(T_values),1),dtype=np.float32)


def _single_pred_accuracy(labels_1hot, predictions_1hot):
    """Fraction of examples whose predicted class (argmax) equals the labelled class."""
    correct = np.equal(np.argmax(labels_1hot, axis=-1), np.argmax(predictions_1hot, axis=-1))
    return np.mean(correct)


def _single_pred_onset_roc_auc(labels_1hot, predictions_1hot, onset_level):
    """ROC-AUC of the binary problem 'class >= onset_level'.

    The binary score of an example is the sum of its predictions for the
    classes `onset_level` and above.
    """
    labels_bin = np.greater_equal(np.argmax(labels_1hot, axis=1), onset_level)
    pred_bin = np.sum(predictions_1hot[:, onset_level:], axis=1)
    fpr, tpr, _ = roc_curve(labels_bin, pred_bin)
    return auc(fpr, tpr)


# Accuracy baseline
baseline_acc_tr = _constant_over_T(_single_pred_accuracy(labels_1hot_tr, predictions_1hot_tr))
baseline_acc_val = _constant_over_T(_single_pred_accuracy(labels_1hot_val, predictions_1hot_val))
baseline_acc_te = _constant_over_T(_single_pred_accuracy(labels_1hot_te, predictions_1hot_te))

# Onset 1, ROC-AUC baseline
baseline_roc_auc_onset1_tr = _constant_over_T(_single_pred_onset_roc_auc(labels_1hot_tr, predictions_1hot_tr, 1))
baseline_roc_auc_onset1_val = _constant_over_T(_single_pred_onset_roc_auc(labels_1hot_val, predictions_1hot_val, 1))
baseline_roc_auc_onset1_te = _constant_over_T(_single_pred_onset_roc_auc(labels_1hot_te, predictions_1hot_te, 1))

# Onset 2, ROC-AUC baseline
baseline_roc_auc_onset2_tr = _constant_over_T(_single_pred_onset_roc_auc(labels_1hot_tr, predictions_1hot_tr, 2))
baseline_roc_auc_onset2_val = _constant_over_T(_single_pred_onset_roc_auc(labels_1hot_val, predictions_1hot_val, 2))
baseline_roc_auc_onset2_te = _constant_over_T(_single_pred_onset_roc_auc(labels_1hot_te, predictions_1hot_te, 2))
## end of baseline calculation

###############################################################################

########### Multi-class Accuracy
def _accuracy_from_summaries(summary_key):
    """Return summaries[i][summary_key] for every position i of T_values."""
    return [summaries[i][summary_key] for i in range(len(T_values))]


# One entry per plotted series: (legend label, one accuracy value per entry of T_values).
# The order of the entries determines the row order of the frame and thereby
# which marker of `markers` each series receives.
accuracy_series = [
    ('Train, median', _accuracy_from_summaries('tr_acc_median')),
    ('Train, mean', _accuracy_from_summaries('tr_acc_mean')),
    ('Val., median', _accuracy_from_summaries('val_acc_median')),
    ('Val., mean', _accuracy_from_summaries('val_acc_mean')),
    ('Test, median', _accuracy_from_summaries('te_acc_median')),
    ('Test, mean', _accuracy_from_summaries('te_acc_mean')),
    # Baseline from single prediction (the same value for every T)
    ('Train, single', baseline_acc_tr),
    ('Val., single', baseline_acc_val),
    ('Test, single', baseline_acc_te),
]

# Flatten to one row per (series, T): all T values of the first series, then the second, ...
accuracy_rows = [(series_label, str(T_values[i]), series_values[i])
                 for series_label, series_values in accuracy_series
                 for i in range(len(T_values))]

# np.reshape(..., (1,)) accepts both plain scalars and the one-element rows of the baselines.
multi_acc_col = np.array([np.reshape(value, (1,))[0] for _, _, value in accuracy_rows],
                         dtype=np.float64)
pred_type_col = np.array([series_label for series_label, _, _ in accuracy_rows])
T_col = np.array([T_label for _, T_label, _ in accuracy_rows])

# Plot multi-class accuracy across T values
f = plt.figure(figsize=(7.5,7.5))
ax1 = f.add_subplot(1, 1, 1)

df_acc_multi_summary = pd.DataFrame(
    {
        'Multi-class Accuracy': multi_acc_col,
        'Prediction Type': pred_type_col,
        'T': T_col,
    },
    columns=['Multi-class Accuracy', 'Prediction Type', 'T'],
)

ax1 = sns.pointplot(x='T', y='Multi-class Accuracy', hue='Prediction Type', ci=None, 
                    markers=markers, order=order,
                    data=df_acc_multi_summary, ax=ax1)
sns.despine()
plt.show()
###########End of Multi-class Accuracy################


# Set up the figure for the Onset 1 and Onset 2 ROC-AUC plots:
# two panels in one row that share the y-axis.
f, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, sharey='row', figsize=(15, 7.5))

############### ROC-AUC Onset 1 and 2 ################
# Diagnostics file
result_file_name = ''.join([RESULTS_DIR, model.descriptor, '_DIAG.pkl'])

with open(result_file_name, mode='rb') as filehandler:
    diagnostics = pickle.load(filehandler)
    # Position of the highest 'val_roc1' entry stored in the diagnostics.
    max_roc1_idx = np.argmax(diagnostics['val_roc1'])

#### Onset 1
# The highest 'val_roc1' value, repeated once per entry of T_values.
baseline = diagnostics['val_roc1'][max_roc1_idx] * np.ones((len(T_values), 1), dtype=np.float32)

roc_col = []
pred_type_col = []
T_col = []

########### ROC-AUC Onset 1
# Training, median
pred_type = 'Train, median'
for i in range(len(T_values)):
    T = T_values[i]
    summary = summaries[i]
    roc_col = np.concatenate([roc_col,
                              np.reshape(summary['tr_roc_auc_onset1_median'], newshape=(1,))
                             ], 
                             axis=0
                            )
    pred_type_col = np.concatenate([pred_type_col,
                                    np.reshape(pred_type, newshape=(1,))
                                   ], 
                                   axis=0
                                  )
    T_col = np.concatenate([T_col,
                            np.reshape(str(T), newshape=(1,))
                           ], 
                           axis=0
                          )
# Training, mean
pred_type = 'Train, mean'
for i in range(len(T_values)):
    T = T_values[i]
    summary = summaries[i]
    roc_col = np.concatenate([roc_col,
                              np.reshape(summary['tr_roc_auc_onset1_mean'], newshape=(1,))
                             ], 
                             axis=0
                            )
    pred_type_col = np.concatenate([pred_type_col,
                                    np.reshape(pred_type, newshape=(1,))
                                   ], 
                                   axis=0
                                  )
    T_col = np.concatenate([T_col,
                            np.reshape(str(T), newshape=(1,))
                           ], 
                           axis=0
                          )
# Validation, median
pred_type = 'Val., median'
for i in range(len(T_values)):
    T = T_values[i]
    summary = summaries[i]
    roc_col = np.concatenate([roc_col,
                              np.reshape(summary['val_roc_auc_onset1_median'], newshape=(1,))
                             ], 
                             axis=0
                            )
    pred_type_col = np.concatenate([pred_type_col,
                                    np.reshape(pred_type, newshape=(1,))
                                   ], 
                                   axis=0
                                  )
    T_col = np.concatenate([T_col,
                            np.reshape(str(T), newshape=(1,))
                           ], 
                           axis=0
                          )
# Validation, mean
pred_type = 'Val., mean'
for i in range(len(T_values)):
    T = T_values[i]
    summary = summaries[i]
    roc_col = np.concatenate([roc_col,
                              np.reshape(summary['val_roc_auc_onset1_mean'], newshape=(1,))
                             ], 
                             axis=0
                            )
    pred_type_col = np.concatenate([pred_type_col,
                                    np.reshape(pred_type, newshape=(1,))
                                   ], 
                                   axis=0
                                  )
    T_col = np.concatenate([T_col,
                            np.reshape(str(T), newshape=(1,))
                           ], 
                           axis=0
                          )
# Test, median
pred_type = 'Test, median'
for i in range(len(T_values)):
    T = T_values[i]
    summary = summaries[i]
    roc_col = np.concatenate([roc_col,
                              np.reshape(summary['te_roc_auc_onset1_median'], newshape=(1,))
                             ], 
                             axis=0
                            )
    pred_type_col = np.concatenate([pred_type_col,
                                    np.reshape(pred_type, newshape=(1,))
                                   ], 
                                   axis=0
                                  )
    T_col = np.concatenate([T_col,
                            np.reshape(str(T), newshape=(1,))
                           ], 
                           axis=0
                          )
# Test, mean
pred_type = 'Test, mean'
for i in range(len(T_values)):
    T = T_values[i]
    summary = summaries[i]
    roc_col = np.concatenate([roc_col,
                              np.reshape(summary['te_roc_auc_onset1_mean'], newshape=(1,))
                             ], 
                             axis=0
                            )
    pred_type_col = np.concatenate([pred_type_col,
                                    np.reshape(pred_type, newshape=(1,))
                                   ], 
                                   axis=0
                                  )
    T_col = np.concatenate([T_col,
                            np.reshape(str(T), newshape=(1,))
                           ], 
                           axis=0
                          )
# Baseline from single prediction
pred_type = 'Train, single'
for i in range(len(T_values)):
    T = T_values[i]
    roc_col = np.concatenate([roc_col,
                              np.reshape(baseline_roc_auc_onset1_tr[i], newshape=(1,))
                             ], 
                             axis=0
                            )
    pred_type_col = np.concatenate([pred_type_col,
                                    np.reshape(pred_type, newshape=(1,))
                                   ], 
                                   axis=0
                                  )
    T_col = np.concatenate([T_col,
                            np.reshape(str(T), newshape=(1,))
                           ], 
                           axis=0
                          )
pred_type = 'Val., single'
for i in range(len(T_values)):
    T = T_values[i]
    roc_col = np.concatenate([roc_col,
                              np.reshape(baseline_roc_auc_onset1_val[i], newshape=(1,))
                             ], 
                             axis=0
                            )
    pred_type_col = np.concatenate([pred_type_col,
                                    np.reshape(pred_type, newshape=(1,))
                                   ], 
                                   axis=0
                                  )
    T_col = np.concatenate([T_col,
                            np.reshape(str(T), newshape=(1,))
                           ], 
                           axis=0
                          )
pred_type = 'Test, single'
for i in range(len(T_values)):
    T = T_values[i]
    roc_col = np.concatenate([roc_col,
                              np.reshape(baseline_roc_auc_onset1_te[i], newshape=(1,))
                             ], 
                             axis=0
                            )
    pred_type_col = np.concatenate([pred_type_col,
                                    np.reshape(pred_type, newshape=(1,))
                                   ], 
                                   axis=0
                                  )
    T_col = np.concatenate([T_col,
                            np.reshape(str(T), newshape=(1,))
                           ], 
                           axis=0
                          )

df_roc1_summary = pd.DataFrame()
df_roc1_summary['ROC-AUC'] = roc_col
df_roc1_summary['Prediction Type'] = pred_type_col
df_roc1_summary['T'] = T_col
# ax1 = sns.lineplot(x='T', y='ROC-AUC', hue='Split', style='Prediction Type', ci=None, 
#                    sort=False, markers=['o','d','s'], 
#                    data=df_roc1_summary, ax=ax1)
ax1 = sns.pointplot(x='T', y='ROC-AUC', hue='Prediction Type', ci=None, markers=markers, order=order, 
                    data=df_roc1_summary, ax=ax1)
sns.despine()


#### Onset 2
# 'val_roc2' value taken at max_roc1_idx (the position of the best 'val_roc1'),
# replicated once per T. It is not referenced by the plotting code of this section.
# NOTE(review): the index comes from val_roc1, not from val_roc2's own maximum --
# presumably so that both onsets refer to the same checkpoint; confirm.
baseline = diagnostics['val_roc2'][max_roc1_idx] * np.ones(shape=(len(T_values),1),dtype=np.float32)

########### ROC-AUC Onset 2
# Series read from the per-T TTAUG summaries: (legend label, key in summaries[i]).
onset2_ttaug_series = [
    ('Train, median', 'tr_roc_auc_onset2_median'),
    ('Train, mean', 'tr_roc_auc_onset2_mean'),
    ('Val., median', 'val_roc_auc_onset2_median'),
    ('Val., mean', 'val_roc_auc_onset2_mean'),
    ('Test, median', 'te_roc_auc_onset2_median'),
    ('Test, mean', 'te_roc_auc_onset2_mean'),
]
# Baseline series from single predictions: (legend label, per-T sequence).
onset2_single_series = [
    ('Train, single', baseline_roc_auc_onset2_tr),
    ('Val., single', baseline_roc_auc_onset2_val),
    ('Test, single', baseline_roc_auc_onset2_te),
]

# The three parallel columns are collected in plain lists. This avoids both the
# quadratic element-by-element np.concatenate and growing string columns from
# an empty float list, which depends on number-to-string promotion that NumPy
# has deprecated.
roc_chunks = []
pred_type_col = []
T_col = []

for pred_type, summary_key in onset2_ttaug_series:
    for i, T in enumerate(T_values):
        summary = summaries[i]
        # reshape to (1,) accepts scalars, 0-d arrays and size-1 arrays alike.
        roc_chunks.append(np.reshape(summary[summary_key], (1,)))
        pred_type_col.append(pred_type)
        T_col.append(str(T))

for pred_type, single_values in onset2_single_series:
    for i, T in enumerate(T_values):
        roc_chunks.append(np.reshape(single_values[i], (1,)))
        pred_type_col.append(pred_type)
        T_col.append(str(T))

# The empty float64 seed keeps the result at least float64 and makes the
# concatenate valid even if no rows were collected.
roc_col = np.concatenate([np.empty(0)] + roc_chunks, axis=0)

# Long-format table: one row per (prediction type, T) pair.
df_roc2_summary = pd.DataFrame({
    'ROC-AUC': roc_col,
    'Prediction Type': pred_type_col,
    'T': T_col,
})
# `markers` and `order` are defined earlier in this cell; ax2 is the right-hand
# panel of the two-panel figure created above.
ax2 = sns.pointplot(x='T', y='ROC-AUC', hue='Prediction Type', ci=None, markers=markers, order=order,
                    data=df_roc2_summary, ax=ax2)
sns.despine()

plt.show()

# Release the summary tables and column buffers built in this cell.
del df_acc_multi_summary, df_roc1_summary, df_roc2_summary, multi_acc_col, pred_type_col, T_col

In [None]:
## PLOTS FOR TTAUG RESULTS: Multi-class ROC curves
############################################################################################
from itertools import cycle
import scipy.stats as stats

def normalize_softmax_from_ttaug(predictions_1hot):
    """Rescale scores so that they sum to one along the last axis.

    Args:
        predictions_1hot: array-like of scores; the last axis is the one that
            gets normalized.

    Returns:
        The input divided element-wise by its sum over the last axis.
    """
    last_axis_total = np.sum(predictions_1hot, axis=-1, keepdims=True)
    return np.asarray(predictions_1hot) / last_axis_total

def _rocauc_frame_for_split(labels_1hot, predictions_1hot, split, class_names):
    """Build the per-class ROC-AUC table for a single data split."""
    auc_values = []
    for class_idx in range(labels_1hot.shape[1]):
        # One-vs-rest ROC curve for this class column.
        fpr, tpr, _ = roc_curve(labels_1hot[:, class_idx], predictions_1hot[:, class_idx])
        auc_values.append(auc(fpr, tpr))

    num_classes = len(auc_values)
    return pd.DataFrame({
        'ROC-AUC': np.asarray(auc_values, dtype=np.float64),
        # Indexing (rather than slicing) keeps the IndexError when there are
        # more class columns than known class names.
        'Class': [class_names[class_idx] for class_idx in range(num_classes)],
        'Split': [split] * num_classes,
    })


def make_dataframe_for_rocauc(labels_1hot_tr, predictions_1hot_tr, 
                              labels_1hot_val, predictions_1hot_val, 
                              labels_1hot_te, predictions_1hot_te, scheme):
    """Compute per-class ROC-AUC tables for the train, validation and test splits.

    For each split, every class column of the one-hot labels is scored against
    the matching prediction column with `roc_curve` / `auc`.

    Args:
        labels_1hot_tr, labels_1hot_val, labels_1hot_te: 2-D arrays of one-hot
            labels (rows are samples, columns are classes).
        predictions_1hot_tr, predictions_1hot_val, predictions_1hot_te: 2-D
            arrays of per-class scores, column-aligned with the labels.
        scheme: name of the prediction scheme; it is appended to the split name
            in the 'Split' column (e.g. 'Train, <scheme>').

    Returns:
        Tuple (train, validation, test) of DataFrames, each with the columns
        'ROC-AUC', 'Class' and 'Split' and one row per class.

    Raises:
        IndexError: if a label array has more than five class columns, since
            only five class names are defined.
    """
    legend_labels = ['0: No DR', '1: Mild DR', '2: Moderate DR', '3: Severe DR', '4: Proliferative DR']

    # The rows are collected in plain lists inside the helper. This avoids
    # growing string columns from an empty float list via np.concatenate, which
    # relies on number-to-string promotion that NumPy has deprecated.
    df_roc_multi_summary_tr = _rocauc_frame_for_split(
        labels_1hot_tr, predictions_1hot_tr, 'Train, ' + str(scheme), legend_labels)
    df_roc_multi_summary_val = _rocauc_frame_for_split(
        labels_1hot_val, predictions_1hot_val, 'Val., ' + str(scheme), legend_labels)
    df_roc_multi_summary_te = _rocauc_frame_for_split(
        labels_1hot_te, predictions_1hot_te, 'Test, ' + str(scheme), legend_labels)

    return df_roc_multi_summary_tr, df_roc_multi_summary_val, df_roc_multi_summary_te


def plot_roc_curves_for_all_ttaug(result, title='Receiver Operating Characteristics', mode='mean'):
    """Collapse the TTAUG predictions of every split and plot their ROC curves.

    Args:
        result: mapping with the keys 'train_labels_1hot', 'val_labels_1hot',
            'test_labels_1hot' (labels, MxC) and 'train_pred_1hot',
            'val_pred_1hot', 'test_pred_1hot' (predictions, MxTxC).
        title: accepted for interface compatibility; not used by this function.
        mode: 'mean' averages over the T axis; any other value uses the median.
    """
    label_keys = ('train_labels_1hot', 'val_labels_1hot', 'test_labels_1hot')
    pred_keys = ('train_pred_1hot', 'val_pred_1hot', 'test_pred_1hot')

    # Fetch everything first, in (train, val, test) order.
    labels = [result[key] for key in label_keys]
    raw_predictions = [result[key] for key in pred_keys]

    # Reduce over axis 1 (the T augmentations), then renormalize per sample.
    collapse = np.mean if mode == 'mean' else np.median
    predictions = [normalize_softmax_from_ttaug(collapse(raw, axis=1)) for raw in raw_predictions]

    plot_roc_curve(labels[0], predictions[0],
                   labels[1], predictions[1],
                   labels[2], predictions[2])
            
# How the T augmented predictions are collapsed: 'mean', otherwise the median.
mode = 'mean'

# ROC curves of the collapsed TTAUG predictions, one set per T.
for T in T_values:
    print('T = %g' % T)
    result_file_name = RESULTS_DIR + model.descriptor + '_TTAUG_' + str(T) + '.pkl'

    with open(result_file_name, 'rb') as filehandler:
        result_ttaug = pickle.load(filehandler)
    plot_roc_curves_for_all_ttaug(result_ttaug, mode=mode)

#########################################################################################################
print('Now plotting the SINGLE vs TTAUG connected dot plots')

# --- Single-prediction results ---
result_file_name = RESULTS_DIR + model.descriptor + '_SINGpred.pkl'
with open(result_file_name, 'rb') as filehandler:
    result = pickle.load(filehandler)

labels_1hot_tr, labels_1hot_val, labels_1hot_te = (
    result[key] for key in ('train_labels_1hot', 'val_labels_1hot', 'test_labels_1hot')
)
predictions_1hot_tr, predictions_1hot_val, predictions_1hot_te = (
    result[key] for key in ('train_pred_1hot', 'val_pred_1hot', 'test_pred_1hot')
)

(df_roc_multi_summary_tr,
 df_roc_multi_summary_val,
 df_roc_multi_summary_te) = make_dataframe_for_rocauc(labels_1hot_tr, predictions_1hot_tr,
                                                      labels_1hot_val, predictions_1hot_val,
                                                      labels_1hot_te, predictions_1hot_te,
                                                      scheme='Sing. pred.')

# --- TTAUG results for one fixed T ---
T = 128
print('T = %g' % T)
result_file_name = RESULTS_DIR + model.descriptor + '_TTAUG_' + str(T) + '.pkl'
with open(result_file_name, 'rb') as filehandler:
    result_ttaug = pickle.load(filehandler)

# Labels are MxC.
labels_1hot_tr, labels_1hot_val, labels_1hot_te = (
    result_ttaug[key] for key in ('train_labels_1hot', 'val_labels_1hot', 'test_labels_1hot')
)
# Predictions are MxTxC.
predictions_1hot_tr, predictions_1hot_val, predictions_1hot_te = (
    result_ttaug[key] for key in ('train_pred_1hot', 'val_pred_1hot', 'test_pred_1hot')
)

# Collapse the T axis, then renormalize each sample's scores.
if mode != 'mean':
    predictions_1hot_tr = normalize_softmax_from_ttaug(np.median(predictions_1hot_tr, axis=1))
    predictions_1hot_val = normalize_softmax_from_ttaug(np.median(predictions_1hot_val, axis=1))
    predictions_1hot_te = normalize_softmax_from_ttaug(np.median(predictions_1hot_te, axis=1))
else:
    predictions_1hot_tr = normalize_softmax_from_ttaug(np.mean(predictions_1hot_tr, axis=1))
    predictions_1hot_val = normalize_softmax_from_ttaug(np.mean(predictions_1hot_val, axis=1))
    predictions_1hot_te = normalize_softmax_from_ttaug(np.mean(predictions_1hot_te, axis=1))

(df_roc_multi_summary_tr_ttaug,
 df_roc_multi_summary_val_ttaug,
 df_roc_multi_summary_te_ttaug) = make_dataframe_for_rocauc(labels_1hot_tr, predictions_1hot_tr,
                                                            labels_1hot_val, predictions_1hot_val,
                                                            labels_1hot_te, predictions_1hot_te,
                                                            scheme='TTAUG')

# Interleave single and TTAUG tables per split so that they appear in that order in the legend.
df_2_comp = pd.concat([
    df_roc_multi_summary_tr, df_roc_multi_summary_tr_ttaug,
    df_roc_multi_summary_val, df_roc_multi_summary_val_ttaug,
    df_roc_multi_summary_te, df_roc_multi_summary_te_ttaug,
])

# Plot the comparison
f, ax1 = plt.subplots(figsize=(7.5,7.5))
ax1 = sns.pointplot(x='Class', y='ROC-AUC', hue='Split', ci=None,
                    data=df_2_comp, ax=ax1)
sns.despine()

plt.show()

# Release the loaded results and the summary tables.
del result_ttaug, result, df_2_comp
del df_roc_multi_summary_tr, df_roc_multi_summary_val, df_roc_multi_summary_te
del df_roc_multi_summary_tr_ttaug, df_roc_multi_summary_val_ttaug, df_roc_multi_summary_te_ttaug

In [None]:
# Distribution of uncertainties across correct and misclassifications
# def entropy(p, axis=-1, keepdims=False):
#     return -np.sum(np.multiply(p, np.log(np.add(p,1e-6))), axis=axis, keepdims=keepdims)

def entropy(p, axis=-1, keepdims=False):
    """Entropy (natural log) of smoothed probabilities along `axis`.

    A constant of 1e-6 is added to every entry and the result is renormalized
    along `axis` before taking logs, so zero probabilities never reach log(0).

    Args:
        p: array-like of probabilities.
        axis: axis holding the distribution over classes.
        keepdims: whether the reduced axis is kept in the result.

    Returns:
        -sum(q * log(q)) along `axis`, where q is the smoothed, renormalized p.
    """
    smoothed = np.asarray(p) + 1e-6
    smoothed = smoothed / np.sum(smoothed, axis=axis, keepdims=True)
    return -np.sum(smoothed * np.log(smoothed), axis=axis, keepdims=keepdims)

def _plot_correct_vs_missed(ax, uncertainty, correct, bins, axlabel):
    """Overlay the uncertainty distributions of correct (green) and missed (red) samples on `ax`."""
    groups = ((correct, 'g', 'correct'), (np.logical_not(correct), 'r', 'missed'))
    for mask, color, label in groups:
        values = uncertainty[mask]
        if values.size == 0:
            # Nothing to draw for this group (e.g. no misclassified samples).
            continue
        # A KDE needs at least two samples. norm_hist=True keeps the histogram
        # on the density scale, which is what kde=True implies anyway.
        sns.distplot(values, bins=bins, kde=values.size > 1, norm_hist=True,
                     color=color, axlabel=axlabel, label=label, ax=ax)


def hist_uncertainty_binary_helper(labels_1hot, predictions_1hot_ttaug, num_bins, density=False, cumulative=False):
    """Plot uncertainty distributions of correct vs. missed binary (onset) decisions.

    For each onset level in [1, 2] the multi-class problem is binarized (class
    index >= onset level), and one figure with four panels is drawn:
    IQR and entropy for the median prediction, then STD and entropy for the
    mean prediction. Each panel shows the correctly classified samples in green
    and the misclassified ones in red.

    Args:
        labels_1hot: one-hot labels, indexed as (sample, class).
        predictions_1hot_ttaug: TTAUG predictions, indexed as
            (sample, augmentation, class) -- M x T x 5 per the original notes.
        num_bins: number of edges passed to np.linspace(0, 1, num_bins) to
            build the histogram bins.
        density: accepted for interface compatibility; currently unused.
        cumulative: accepted for interface compatibility; currently unused.
    """
    onset_levels = [1, 2]
    bins = np.linspace(0., 1., num_bins)

    for onset_level in onset_levels:
        print('Onset level: %d'  % onset_level)
        f = plt.figure(figsize=(16,4))

        # Binarize: positive when the class index reaches the onset level.
        labels_bin = np.greater_equal(np.argmax(labels_1hot, axis=1), onset_level)
        # Probability mass of all classes at or above the onset level: MxTx5 --> MxT
        predictions_all_bin = np.sum(predictions_1hot_ttaug[:, :, onset_level:], axis=2)

        # Per-sample scores collapsed over the T augmentations (Mx1 each).
        score_median = np.median(predictions_all_bin, axis=1, keepdims=True)
        score_mean = np.mean(predictions_all_bin, axis=1, keepdims=True)

        # Median-based decisions.
        # NOTE(review): the printed label says "multi-class", but the value is
        # the accuracy of the binarized onset decision.
        correct_median = np.equal(labels_bin, np.greater_equal(score_median[:, 0], 0.5))
        acc_median = np.mean(np.asarray(correct_median, dtype=np.float32))
        print('Median\'s accuracy (multi-class) : %.5f' % acc_median)

        # Mean-based decisions.
        correct_mean = np.equal(labels_bin, np.greater_equal(score_mean[:, 0], 0.5))
        acc_mean = np.mean(np.asarray(correct_mean, dtype=np.float32))
        print('Mean\'s accuracy (multi-class) : %.5f' % acc_mean)

        # Two-class distributions (Mx2) and their entropies (M,).
        softmax_median = np.concatenate((np.subtract(1., score_median), score_median), axis=1)
        softmax_mean = np.concatenate((np.subtract(1., score_mean), score_mean), axis=1)
        entropy_median = entropy(softmax_median)
        entropy_mean = entropy(softmax_mean)

        # Row-wise statistics are computed for all samples and split afterwards
        # with boolean masks. This keeps the arrays 2-D/1-D even when a group
        # has exactly one sample, where squeeze() on a where()-indexed array
        # would drop the sample axis.
        panels = (
            (stats.iqr(predictions_all_bin, axis=1), correct_median, 'IQR'),
            (entropy_median, correct_median, 'ENTROPY w.r.t. median'),
            (np.std(predictions_all_bin, axis=1), correct_mean, 'STD'),
            # Both groups of this panel use the entropy of the MEAN prediction;
            # the missed group previously plotted the median entropy by mistake.
            (entropy_mean, correct_mean, 'ENTROPY w.r.t. mean'),
        )
        for position, (uncertainty, correct, axlabel) in enumerate(panels, start=1):
            ax = f.add_subplot(1, 4, position)
            _plot_correct_vs_missed(ax, uncertainty, correct, bins, axlabel)

        sns.despine()
        # As before, the legend is attached to the current (last added) panel only.
        plt.legend()
        plt.show()
    
def hist_uncertainty_binary(result, num_bins, density, cumulative):
    """Plot the uncertainty histograms for the test split of a TTAUG result.

    Args:
        result: mapping with the keys 'train_labels_1hot', 'val_labels_1hot',
            'test_labels_1hot' (labels, Mx5) and 'train_pred_1hot',
            'val_pred_1hot', 'test_pred_1hot' (TTAUG predictions, MxTx5).
        num_bins, density, cumulative: forwarded unchanged to
            hist_uncertainty_binary_helper.
    """
    # All six entries are looked up, although only the test split is plotted;
    # the train and validation splits could be passed to the helper the same way.
    labels_by_split = [result[key]
                       for key in ('train_labels_1hot', 'val_labels_1hot', 'test_labels_1hot')]
    predictions_by_split = [result[key]
                            for key in ('train_pred_1hot', 'val_pred_1hot', 'test_pred_1hot')]

    # TEST
    print('Test results:')
    hist_uncertainty_binary_helper(labels_by_split[2], predictions_by_split[2],
                                   num_bins, density, cumulative)


# T_values = [4, 8, 16] #, 32, 64] #, 128]
# One set of uncertainty histograms per number of test-time augmentations T.
for T in T_values:
    print('T = %g' % T)
    result_file_name = RESULTS_DIR + model.descriptor + '_TTAUG_' + str(T) + '.pkl'

    # Load the stored TTAUG result for this T; plotting happens once the file is closed.
    with open(result_file_name, 'rb') as filehandler:
        result_ttaug = pickle.load(filehandler)
    hist_uncertainty_binary(result_ttaug, num_bins=100, density=True, cumulative=False)

# Release the last loaded result.
del result_ttaug

In [None]:
# UNCERTAINTY-INFORMED DECISION REFERRAL
def _referral_auc_curves(labels_bin, scores, uncertainty, dec_ref_rates):
    """Compute ROC-AUC after referring (dropping) a fraction of the samples.

    For every rate in `dec_ref_rates` that fraction of samples is removed twice:
    once at random and once by discarding the samples with the largest
    `uncertainty`. ROC-AUC is computed on the samples that remain.

    Args:
        labels_bin: binary ground-truth labels, one per sample.
        scores: prediction scores, indexable by an integer index array.
        uncertainty: per-sample uncertainty estimates (larger = less certain).
        dec_ref_rates: iterable of referral fractions.

    Returns:
        (random_aucs, informed_aucs): two lists with one ROC-AUC per rate.
    """
    num_samples = len(uncertainty)
    # Ascending order puts the most uncertain samples at the end. The ordering
    # does not depend on the referral rate, so it is computed once.
    informed_order = np.argsort(uncertainty)

    random_aucs = []
    informed_aucs = []
    for rate in dec_ref_rates:
        drop_count = int(np.round(rate * num_samples))
        # Slice by the number of samples kept. A negative-stop slice
        # (`[:-drop_count]`) is empty when drop_count is 0, which previously
        # forced one sample to be dropped even at a referral rate of 0.
        keep_count = num_samples - drop_count

        random_idx = np.random.permutation(num_samples)[:keep_count]
        informed_idx = informed_order[:keep_count]

        fpr, tpr, _ = roc_curve(labels_bin[informed_idx], scores[informed_idx])
        informed_aucs.append(auc(fpr, tpr))

        fpr, tpr, _ = roc_curve(labels_bin[random_idx], scores[random_idx])
        random_aucs.append(auc(fpr, tpr))

    return random_aucs, informed_aucs


def decision_referral_helper(labels_1hot, predictions_1hot_ttaug, baseline):
    """Plot ROC-AUC against referral rate for several uncertainty measures.

    For each disease-onset level (1 and 2) the multi-class outputs are reduced
    to a binary problem (class index >= onset level) and four uncertainty
    measures are evaluated: IQR (with median score), STD (with mean score) and
    the entropy of the median / mean score. Each measure is compared with
    random referral. One subplot is drawn per onset level.

    A referral rate of 0 keeps every sample, so the random and informed curves
    start from the same value.

    Args:
        labels_1hot: one-hot labels; argmax over axis 1 gives the class index.
        predictions_1hot_ttaug: rank-3 array of per-augmentation class
            probabilities, indexed as [sample, augmentation, class].
        baseline: accepted for interface compatibility; not used here.
    """
    onset_levels = [1, 2]
    dec_ref_rates = np.divide(range(0, 50, 1), 100)  # 0.00, 0.01, ..., 0.49

    f = plt.figure(figsize=(15,7.5))
    for onset_level in onset_levels:
        labels_bin = np.greater_equal(np.argmax(labels_1hot, axis=1), onset_level)
        predictions_all_bin = np.sum(predictions_1hot_ttaug[:, :, onset_level:], axis=2) # MxT

        # IQR and median
        uncertainty_est_median = stats.iqr(predictions_all_bin, axis=1) # IQR from MxT matrix
        predictions_bin_score_median = np.median(predictions_all_bin, axis=1, keepdims=True) # Mx1

        # STD and mean
        uncertainty_est_mean = np.std(predictions_all_bin, axis=1) # STD from MxT matrix
        predictions_bin_score_mean = np.mean(predictions_all_bin, axis=1, keepdims=True) # Mx1

        # Entropy for median predictions
        softmax_median = np.concatenate((np.subtract(1.,predictions_bin_score_median),
                                         predictions_bin_score_median), axis=1) # Mx2
        uncertainty_est_median_entropy = entropy(softmax_median) # ENT from Mx2 matrix

        # Entropy for mean predictions
        softmax_mean = np.concatenate((np.subtract(1.,predictions_bin_score_mean),
                                       predictions_bin_score_mean), axis=1) # Mx2
        uncertainty_est_mean_entropy = entropy(softmax_mean) # ENT from Mx2 matrix

        # (random label, informed label, uncertainty estimate, score used for the ROC curve)
        schemes = [
            ('Random, IQR', 'Informed, IQR',
             uncertainty_est_median, predictions_bin_score_median),
            ('Random, STD', 'Informed, STD',
             uncertainty_est_mean, predictions_bin_score_mean),
            ('Random, ENT (median)', 'Informed, ENT (median)',
             uncertainty_est_median_entropy, predictions_bin_score_median),
            ('Random, ENT (mean)', 'Informed, ENT (mean)',
             uncertainty_est_mean_entropy, predictions_bin_score_mean),
        ]

        # Long-format columns for seaborn: one row per (scheme, referral rate).
        roc_auc_col = []
        ref_fraction_col = []
        scheme_col = []
        for random_label, informed_label, uncertainty, scores in schemes:
            random_aucs, informed_aucs = _referral_auc_curves(labels_bin, scores,
                                                              uncertainty, dec_ref_rates)
            # Random rows come first so the hue order in the legend is unchanged.
            for label, aucs in ((random_label, random_aucs), (informed_label, informed_aucs)):
                roc_auc_col.extend(aucs)
                ref_fraction_col.extend(dec_ref_rates)
                scheme_col.extend([label] * len(aucs))

        df_dec_ref = pd.DataFrame({'ROC-AUC': roc_auc_col,
                                   'referral rate': ref_fraction_col,
                                   'Scheme': scheme_col})

        ax = f.add_subplot(1, 2, onset_level)
        ax = sns.lineplot(x='referral rate', y='ROC-AUC', hue='Scheme', ci=None,
                          data=df_dec_ref, ax=ax)
        ax.set_title('Onset ' + str(onset_level))
        sns.despine()

    plt.show()


def decision_referral_plot(result, baseline):
    """Draw the decision-referral curves for the test split of one TTAUG result.

    result: mapping with '<split>_labels_1hot' and '<split>_pred_1hot' entries
        for the train, val and test splits (annotated as Mx5 and MxTx5 in this
        notebook).
    baseline: passed through to decision_referral_helper.
    """
    # Every split is read up front, labels first and predictions second, so a
    # result file with missing entries still fails here. Only the test split is
    # plotted.
    _, _, test_labels = [result[key] for key in
                         ('train_labels_1hot', 'val_labels_1hot', 'test_labels_1hot')]
    _, _, test_predictions = [result[key] for key in
                              ('train_pred_1hot', 'val_pred_1hot', 'test_pred_1hot')]

    print('Test results:')
    decision_referral_helper(test_labels, test_predictions, baseline)


# Decision-referral curves for every number of test-time augmentations T.
for T in T_values:
    print('T = %g' % T)
    result_file_name = RESULTS_DIR + model.descriptor + '_TTAUG_' + str(T) + '.pkl'

    # Read the pickled results, then plot once the file handle is closed.
    with open(result_file_name, 'rb') as filehandler:
        result_ttaug = pickle.load(filehandler)
    decision_referral_plot(result_ttaug, baseline)

# Release the last loaded result.
del result_ttaug

In [None]:
#### QUANTIFY Heteroscedastic Aleatoric Uncertainty via InfoGain/Mutual Information
def aleatoric_uncertainty_quantify_helper(labels_1hot, predictions_1hot_ttaug):
    """Plot predictive entropy, expected entropy and their difference per onset level.

    For onset levels 1 and 2 the class probabilities are collapsed to a binary
    distribution (class index >= onset level vs. the rest). Three quantities are
    computed per example:
      * predictive entropy: entropy of the augmentation-averaged distribution,
      * expected entropy: average of the per-augmentation entropies,
      * MI: predictive entropy minus expected entropy.
    The first two panels show these quantities (examples sorted by predictive
    entropy); the last panel shows the sorted ratio predictive / expected
    entropy for both onset levels.

    Args:
        labels_1hot: not used by the current plots; kept so callers need not change.
        predictions_1hot_ttaug: rank-3 array of per-augmentation class
            probabilities, indexed as [sample, augmentation, class].
    """
    onset_levels = [1, 2]
    colors = ['deepskyblue', 'firebrick']  # one colour per onset level, last panel
    point_size = 1.5

    f, axes = plt.subplots(1, 3, figsize=(15,6))
    for onset_level, color in zip(onset_levels, colors):
        print('Onset level: %d'  % onset_level)
        predictions_all_bin_y1 = np.sum(predictions_1hot_ttaug[:, :, onset_level:],
                                        axis=2, keepdims=True) # MxTxC --> MxTx1
        predictions_all_bin_y0 = np.subtract(1., predictions_all_bin_y1) # MxTx1
        predictions_all_bin_probs = np.concatenate((predictions_all_bin_y0,
                                                    predictions_all_bin_y1), axis=-1) # MxTx2

        # Now, compute InfoGain/MI
        p_ttaug_y1 = np.mean(predictions_all_bin_probs[:,:,1], axis=1, keepdims=True) # Mx1
        p_ttaug_y0 = np.subtract(1., p_ttaug_y1) # Mx1
        p_ttaug = np.concatenate((p_ttaug_y0, p_ttaug_y1), axis=-1) # Mx2

        predictive_entropy = entropy(p_ttaug, axis=-1, keepdims=False) # M
        expected_entropy = np.mean(entropy(predictions_all_bin_probs, axis=-1, keepdims=False),
                                   axis=-1, keepdims=False)
        MI = np.subtract(predictive_entropy, expected_entropy)

        print('NaNs in Pred. Ent. : ' + str(np.any(np.isnan(predictive_entropy))))
        print('NaNs in Exp. Ent. : ' + str(np.any(np.isnan(expected_entropy))))
        print('NaNs in MI : ' + str(np.any(np.isnan(MI))))

        print('Min MI: %g\t1st Quar.: %g\tMean: %g\tLast Quar: %g\tMax MI: %g' %
              (np.amin(MI), np.percentile(MI, 25), np.mean(MI), np.percentile(MI, 75), np.amax(MI)))

        # All three curves share the ordering given by the predictive entropy.
        sort_idx = np.argsort(predictive_entropy)

        ax = axes[onset_level-1]
        for values, curve_color, curve_label in ((predictive_entropy, 'm', 'predictive entropy'),
                                                 (expected_entropy, 'g', 'expected entropy'),
                                                 (MI, 'b', 'InfoGain/MutualInfo')):
            ax.scatter(range(len(values)), values[sort_idx], s=point_size,
                       color=curve_color, alpha=0.5, label=curve_label)
        ax.set_xlabel('Examples sorted by predictive entropy')
        # set_ylabel labels the axis; set_label would only rename the Axes
        # artist and leave the y-axis unlabeled.
        ax.set_ylabel('entropy')
        ax.legend(shadow=True, fancybox=True, markerscale=3, handletextpad=0.1)
        ax.grid(True)

        # Relative gain. The ratio is inf/NaN wherever the expected entropy is 0.
        temp = np.sort(np.divide(predictive_entropy, expected_entropy))
        axes[-1].scatter(range(len(temp)), temp, color=color, alpha=0.5, s=point_size,
                         label='pred. ent. / exp. ent., onset ' + str(onset_level))

    axes[-1].set_xlabel('Examples sorted by the relative gain')
    axes[-1].set_ylabel('relative gain')
    axes[-1].legend(shadow=True, fancybox=True, markerscale=3, handletextpad=0.1)
    axes[-1].grid(True)
    plt.show()
    

def aleatoric_uncertainty_quantify_plot(result):
    """Run the aleatoric-uncertainty plots on the test split of one TTAUG result.

    result: mapping with '<split>_labels_1hot' and '<split>_pred_1hot' entries
        for the train, val and test splits (annotated as Mx5 and MxTx5 in this
        notebook).
    """
    # Every split is read up front, labels first and predictions second, so a
    # result file with missing entries still fails here. Only the test split is
    # plotted.
    _, _, test_labels = [result[key] for key in
                         ('train_labels_1hot', 'val_labels_1hot', 'test_labels_1hot')]
    _, _, test_predictions = [result[key] for key in
                              ('train_pred_1hot', 'val_pred_1hot', 'test_pred_1hot')]

    print('Test results:')
    aleatoric_uncertainty_quantify_helper(test_labels, test_predictions)

# Aleatoric-uncertainty plots for every number of test-time augmentations T.
for T in T_values:
    print('T = %g' % T)
    result_file_name = RESULTS_DIR + model.descriptor + '_TTAUG_' + str(T) + '.pkl'

    # Read the pickled results, then plot once the file handle is closed.
    with open(result_file_name, 'rb') as filehandler:
        result_ttaug = pickle.load(filehandler)
    aleatoric_uncertainty_quantify_plot(result_ttaug)

# Release the last loaded result.
del result_ttaug

In [None]:
def relative_gain_plot_helper(labels_1hot, predictions_1hot_ttaug, axes, i, T, trim_count=500):
    """Scatter the sorted relative gain (predictive / expected entropy) for one T.

    For onset levels 1 and 2 the class probabilities are collapsed to a binary
    distribution (class index >= onset level vs. the rest). The ratio of
    predictive entropy to expected entropy is computed per example, sorted, and
    drawn on axes[onset_level - 1].

    Args:
        labels_1hot: not used by the current plot; kept so callers need not change.
        predictions_1hot_ttaug: rank-3 array of per-augmentation class
            probabilities, indexed as [sample, augmentation, class].
        axes: sequence of at least two Matplotlib axes, one per onset level.
        i: index of the curve; selects its colour (wraps around the palette).
        T: number of test-time augmentations, used in the legend label only.
        trim_count: how many of the largest ratios to leave out of the plot.
            Defaults to 500, the value that used to be hard-coded.
    """
    onset_levels = [1, 2]
    colors = ['r','g', 'b', 'maroon', 'fuchsia', 'lightskyblue']
    point_size = 1.0
    # Wrap around so that more curves than palette entries do not raise IndexError.
    curve_color = colors[i % len(colors)]

    for onset_level in onset_levels:
        predictions_all_bin_y1 = np.sum(predictions_1hot_ttaug[:, :, onset_level:],
                                        axis=2, keepdims=True) # MxTxC --> MxTx1
        predictions_all_bin_y0 = np.subtract(1., predictions_all_bin_y1) # MxTx1
        predictions_all_bin_probs = np.concatenate((predictions_all_bin_y0,
                                                    predictions_all_bin_y1), axis=-1) # MxTx2

        # Now, compute InfoGain/MI
        p_ttaug_y1 = np.mean(predictions_all_bin_probs[:,:,1], axis=1, keepdims=True) # Mx1
        p_ttaug_y0 = np.subtract(1., p_ttaug_y1) # Mx1
        p_ttaug = np.concatenate((p_ttaug_y0, p_ttaug_y1), axis=-1) # Mx2

        predictive_entropy = entropy(p_ttaug, axis=-1, keepdims=False) # M
        expected_entropy = np.mean(entropy(predictions_all_bin_probs, axis=-1, keepdims=False),
                                   axis=-1, keepdims=False)

        # now, sort w.r.t. relative gain and plot
        relative_gain = np.divide(predictive_entropy, expected_entropy)
        temp = np.sort(relative_gain)
        # Leave out the `trim_count` largest values to emphasize the bulk of the
        # curve. np.sort places NaNs last, so NaN ratios are trimmed first.
        # Slicing by the number kept also handles trim_count == 0.
        temp = temp[:max(len(temp) - trim_count, 0)]

        ax = axes[onset_level-1]
        ax.scatter(range(len(temp)), temp, color=curve_color, alpha=0.5, s=point_size,
                   label='pred. ent. / exp. ent., T=' + str(T))
        ax.set_xlabel('Examples sorted by the relative gain')
        ax.set_ylabel('relative gain')
        ax.set_title('Onset ' + str(onset_level))
        # Rebuilt on every call so the legend includes the curve just added.
        ax.legend(shadow=True, fancybox=True, markerscale=5, handletextpad=0.1)
        ax.grid(True)


def relative_gain_plot(result, axes, i, T):
    """Add the test-split relative-gain curve of one TTAUG result to `axes`.

    result: mapping with '<split>_labels_1hot' and '<split>_pred_1hot' entries
        for the train, val and test splits (annotated as Mx5 and MxTx5 in this
        notebook).
    axes, i, T: forwarded unchanged to relative_gain_plot_helper.
    """
    # Every split is read up front, labels first and predictions second, so a
    # result file with missing entries still fails here. Only the test split is
    # plotted.
    _, _, test_labels = [result[key] for key in
                         ('train_labels_1hot', 'val_labels_1hot', 'test_labels_1hot')]
    _, _, test_predictions = [result[key] for key in
                              ('train_pred_1hot', 'val_pred_1hot', 'test_pred_1hot')]

    print('Test results:')
    relative_gain_plot_helper(test_labels, test_predictions, axes, i, T)


# One shared figure: left panel for onset 1, right panel for onset 2, and one
# curve per number of test-time augmentations T.
f, axes = plt.subplots(1, 2, figsize=(16,8))
i = 0  # curve index; selects the colour inside the helper
for T in T_values:
    print('T = %g' % T)
    result_file_name = RESULTS_DIR + model.descriptor + '_TTAUG_' + str(T) + '.pkl'

    # Read the pickled results, then plot once the file handle is closed.
    with open(result_file_name, 'rb') as filehandler:
        result_ttaug = pickle.load(filehandler)
    relative_gain_plot(result_ttaug, axes, i, T)
    i = i + 1
plt.show()

# Release the last loaded result.
del result_ttaug

In [None]:
# ALIGNED t-SNE Maps
from FItSNE.fast_tsne import fast_tsne 
# import scipy.spatial as sp
from sklearn.metrics import pairwise_distances
from sklearn import preprocessing
from matplotlib import colors

from joblib import Parallel, delayed
import multiprocessing

def find_kNN_idx_per_column(col, k):
    """Return the indices of the k smallest entries of `col`, smallest first."""
    # argsort is ascending, so the leading k positions are the k closest items.
    return np.argsort(col)[:k]

def find_kNN_inits_per_column(kNN_idx, Z, dims=2):
    """Return the first `dims` columns of the rows of `Z` selected by `kNN_idx`."""
    leading_dims = slice(None, dims)
    neighbour_coords = Z[kNN_idx, leading_dims]
    return neighbour_coords

def plot_given_map(Z, y, ax, title, point_size=2, markerscale=5, plotting_order='original'):
    """Scatter a 2-D embedding on `ax`, one colour and legend entry per DR grade.

    Args:
        Z: array indexed as Z[mask, 0] / Z[mask, 1]; the two embedding coordinates.
        y: integer class id per row of Z. Ids index the 5-entry colour and
            legend tables, so they must lie in 0..4.
        ax: Matplotlib axes (or a compatible object) to draw on.
        title: axes title.
        point_size: marker size passed to scatter.
        markerscale: marker scaling used in the legend.
        plotting_order: 'original' draws the classes present in `y` in ascending
            order; any other value draws classes in the fixed order 0, 2, 3, 4, 1.
    """
    col = np.array(['burlywood','maroon', 'lightskyblue', 'darkslateblue', 'fuchsia'])
    legend_labels = np.array(['0: No DR', '1: Mild DR', '2: Moderate DR', '3: Severe DR', '4: Proliferative DR'])

    if plotting_order == 'original':
        # Use the class ids actually present. range(len(np.unique(y))) is only
        # correct when the ids are contiguous from 0; with a grade missing it
        # plots an empty class and skips a present one.
        class_order = np.unique(y)
    else:
        class_order = [0,2,3,4,1] #[4,3,0,2,1]

    for i in class_order:
        mask = (y == i)
        ax.scatter(Z[mask,0], Z[mask,1], c=col[y[mask]], label=legend_labels[i], s=point_size)
    ax.set_title(title)
    # Legend sits above the axes, spread over three columns.
    leg = ax.legend(bbox_to_anchor=(0., 1.125, 1., 0.), loc='upper center', ncol=3,
                    mode="expand", shadow=True, fancybox=True, markerscale=markerscale, handletextpad=0.1)
    leg.get_frame().set_alpha(0.75)
    
def plot_aligned_tsne(Xa, ya, 
                      Xb, yb, 
                      ax1, ax2, 
                      perplexity=30, max_iter=1000, 
                      variance_to_keep=0.99, k=10,
                      multicore_kNN=False, num_cores=10,
                      plotting_order='original',
                      learning_rate=None):
    """Embed reference data with t-SNE and align auxiliary data onto that map.

    Steps:
      1. Standardize Xa and apply the same scaling to Xb.
      2. For every row of Xb find its k nearest rows of Xa (Euclidean distance).
      3. Run PCA on Xa and keep the leading components needed to reach
         `variance_to_keep` of the variance; project both sets onto them.
      4. Run fast_tsne on the projected reference data, initialized from the
         first two (rescaled) principal components, and plot it on `ax1`.
      5. Initialize every auxiliary point at the mean map position of its k
         reference neighbours, run fast_tsne on the projected auxiliary data,
         and plot reference + auxiliary maps together on `ax2`.

    Args:
        Xa, ya: reference features (2-D array) and their integer class ids.
        Xb, yb: auxiliary features and class ids to align with the reference map.
        ax1, ax2: axes for the reference map and the combined map.
        perplexity, max_iter: passed to fast_tsne.
        variance_to_keep: fraction of variance the kept components must explain.
        k: number of nearest reference neighbours per auxiliary point.
        multicore_kNN: search neighbours with joblib workers instead of serially.
        num_cores: number of joblib workers when multicore_kNN is True.
        plotting_order: passed to plot_given_map.
        learning_rate: passed to fast_tsne. When None, the notebook-level
            variable `learning_rate` is used, which is what this function read
            implicitly before the parameter existed.
    """
    if learning_rate is None:
        # Backward compatibility with callers that rely on the notebook global.
        learning_rate = globals()['learning_rate']

    ###### First, standardize the data
    scaler = preprocessing.StandardScaler() # Zero mean, unit variance
    Xa = scaler.fit_transform(Xa)
    Xb = scaler.transform(Xb)

    ###### Compute the pairwise distances and determine the kNNs
    print('Computing the pairwise distances')
    K = pairwise_distances(X=Xa, Y=Xb, metric='euclidean') # Ma x Mb
    Mb = K.shape[1]
    print('Finding kNNs...')
    if not multicore_kNN:
        # Column j holds the distances from every reference item to auxiliary
        # item j; ascending argsort puts its nearest neighbours first.
        kNN_idx_list = [np.argsort(K[:,j])[:k] for j in range(Mb)]
    else:
        kNN_idx_list = list(Parallel(n_jobs=num_cores)(
            delayed(find_kNN_idx_per_column)(K[:,j], k) for j in range(Mb)))

    ###### Do PCA on the reference data and keep D dimensions
    print('PCA on reference data ...')
    Sigma = np.cov(np.transpose(Xa))
    U, s, _ = np.linalg.svd(Sigma, full_matrices=False)
    sum_s = np.sum(s)
    print('Total components : %g' % len(s))
    # d counts kept components. Starting at 1 and running up to len(s) means at
    # least one component is kept and all of them can be kept if needed.
    for d in range(1, len(s) + 1):
        var_explained = np.sum(s[:d]) / sum_s
        if var_explained >= variance_to_keep:
            break
    print('%g of variance explained with %d components.' % (var_explained, d))
    D = d
    XaD = np.dot(Xa, U[:,:D])
    # Small-scale initialization from the first two principal components.
    PCAinit = XaD[:,:2] / np.std(XaD[:,0]) * 0.0001

    ####### tSNE on the reference (training) data
    print('Computing tSNE map for the reference data')
    Za = fast_tsne(XaD, perplexity=perplexity,
                   max_iter=max_iter, learning_rate=learning_rate,
                   initialization=PCAinit)
    plot_given_map(Za, ya, ax1, 'tSNE map, reference data, ' + ' Perplexity : ' + str(perplexity),
                   plotting_order=plotting_order)

    ####################################################################
    ###### ALIGNMENT begins
    XbD = np.dot(Xb, U[:,:D])

    print('Collecting initialization points based on kNNs')
    # Gather the map coordinates of every neighbour set in one indexing step
    # (Mb x k x 2) and average over the neighbours (Mb x 2). The same result is
    # used for the serial and multicore settings. The former serial loop
    # reshaped the partially filled list inside the loop and failed after the
    # first item.
    neighbour_coords = Za[np.asarray(kNN_idx_list, dtype=np.intp), :2]
    kNN_init = np.mean(neighbour_coords, axis=1)

    print('Computing tSNE map for the auxillary data')
    Zb = fast_tsne(XbD, perplexity=perplexity,
                   max_iter=max_iter, learning_rate=learning_rate,
                   initialization=kNN_init)

    # Concatenate the mappings and plot
    Z = np.concatenate([Za,Zb], axis=0)
    y = np.concatenate([ya,yb], axis=0)
    plot_given_map(Z, y, ax2, 'Aligned tSNE, ' + ' Perplexity : ' + str(perplexity),
                   plotting_order=plotting_order)

# Load the SINGLE PRED. results from file: features and one-hot labels per split.
result_file_name = RESULTS_DIR + model.descriptor + '_SINGpred.pkl'

with open(result_file_name, 'rb') as filehandler:
    result = pickle.load(filehandler)

X_tr, X_val, X_te = (result[key] for key in
                     ('train_features', 'val_features', 'test_features'))
# Class ids are recovered from the one-hot labels.
y_tr, y_val, y_te = (np.argmax(result[key], axis=1) for key in
                     ('train_labels_1hot', 'val_labels_1hot', 'test_labels_1hot'))

# Validation and test samples together form the set aligned onto the training map.
X_valte = np.concatenate((X_val, X_te), axis=0)
y_valte = np.concatenate((y_val, y_te), axis=0)

# Settings for the aligned t-SNE runs.
perplexities = [100,500,1000,2000] # [10,20,30,40,50,100]
variance_to_keep = 0.99
num_neighbors = 10
max_iter = 3000
learning_rate = 500
plotting_order='my_order'

# One figure per perplexity: reference map on the left, aligned map on the right.
for perp in perplexities:
    print('Perplexity : %g' % perp)
    f = plt.figure(figsize=(15,7.5))
    ax1, ax2 = f.add_subplot(1, 2, 1), f.add_subplot(1, 2, 2)

    plot_aligned_tsne(X_tr, y_tr,
                      X_valte, y_valte,
                      ax1, ax2,
                      perplexity=perp, max_iter=max_iter,
                      variance_to_keep=variance_to_keep, k=num_neighbors,
                      multicore_kNN=True, num_cores=multiprocessing.cpu_count(),
                      plotting_order=plotting_order)
    plt.tight_layout()
    plt.show()

# make some room
del result, X_val, y_val, X_te, y_te, X_tr, y_tr, X_valte, y_valte

In [None]:
# ALIGNED t-SNE Maps with a bit of tSNE LOGIC
from FItSNE.fast_tsne import fast_tsne 
# import scipy.spatial as sp
from sklearn.metrics import pairwise_distances
from sklearn import preprocessing
from matplotlib import colors

from joblib import Parallel, delayed
import multiprocessing

def plot_aligned_tsne_with_labels_and_predictions(Xa, ya, y_pred_a, y_pred_a_ttaug,
                                                  Xb, yb, y_pred_b, y_pred_b_ttaug,
                                                  ax1, ax2, ax3, ax4, ax5,
                                                  perplexity=30, max_iter=1000, 
                                                  variance_to_keep=0.99, k=10,
                                                  multicore_kNN=False, num_cores=10,
                                                  plotting_order='original'):
    """Build an aligned t-SNE map of two data sets and plot labels next to predictions.

    The reference data ``Xa`` is embedded first (PCA-initialised). Each row of ``Xb``
    is then initialised at the mean map position of its ``k`` nearest reference
    neighbours and embedded in a second t-SNE run. Both maps are concatenated and
    drawn five times via ``plot_given_map``.

    Parameters
    ----------
    Xa, Xb : 2-D arrays (samples x features)
        Reference and auxiliary features. Both are standardised with a scaler fitted
        on ``Xa`` and projected onto the PCA basis of ``Xa``.
    ya, yb : 1-D integer arrays
        Class labels of the reference / auxiliary samples.
    y_pred_a, y_pred_b : 1-D integer arrays
        Single-prediction class indices.
    y_pred_a_ttaug, y_pred_b_ttaug : 1-D integer arrays
        Test-time-augmentation class indices.
    ax1 .. ax5 : axes objects handed to ``plot_given_map``
        ax1: labels, ax2: single predictions, ax3: TTAUG predictions,
        ax4: points whose single prediction is wrong, ax5: points whose TTAUG
        prediction is wrong (both coloured by the true label).
    perplexity, max_iter : forwarded to ``fast_tsne``.
    variance_to_keep : float
        Smallest number of leading principal components whose share of the total
        variance reaches this value is kept.
    k : int
        Number of nearest reference neighbours used for the initialisation.
    multicore_kNN : bool
        If True, neighbour search and initialisation go through joblib
        (``find_kNN_idx_per_column`` / ``find_kNN_inits_per_column``).
    num_cores : int
        ``n_jobs`` for joblib when ``multicore_kNN`` is True.
    plotting_order : forwarded to ``plot_given_map``.

    Notes
    -----
    ``learning_rate`` is NOT a parameter: it is read from the notebook's global scope
    and must be defined before this function is called.
    """
    ###### First, standardize the data (statistics come from the reference data only)
    scaler = preprocessing.StandardScaler() # Zero mean, unit variance
    Xa = scaler.fit_transform(Xa)
    Xb = scaler.transform(Xb)

    ###### Compute the pairwise distances and determine the kNNs
    print('Computing the pairwise distances')
    K = pairwise_distances(X=Xa, Y=Xb, metric='euclidean') # rows: reference, columns: auxiliary
    Mb = K.shape[1]
    print('Finding kNNs...')
    if not multicore_kNN:
        # argsort is ascending, so the first k entries are the nearest reference points.
        kNN_idx_list = [np.argsort(K[:, j])[:k] for j in range(Mb)]
    else:
        kNN_idx_list = list(Parallel(n_jobs=num_cores)(
            delayed(find_kNN_idx_per_column)(K[:, j], k) for j in range(Mb)))

    ###### Do PCA on the reference data and keep D dimensions
    print('PCA on reference data ...')
    Sigma = np.cov(np.transpose(Xa))
    U, s, _ = np.linalg.svd(Sigma, full_matrices=False)
    sum_s = np.sum(s)
    print('Total components : %g' % len(s))
    # d counts the leading components kept. The range includes len(s) so that all
    # components can be selected when the threshold requires it (range(len(s))
    # silently dropped the last component in that case).
    for d in range(len(s) + 1):
        var_explained = np.sum(s[:d]) / sum_s
        if var_explained >= variance_to_keep:
            break
    print('%g of variance explained with %d components.' % (var_explained, d))
    D = d
    XaD = np.dot(Xa, U[:, :D])
    # Small-scale initialisation from the first two principal components.
    PCAinit = XaD[:, :2] / np.std(XaD[:, 0]) * 0.0001

    ####### tSNE on the reference (training) data
    print('Computing tSNE map for the reference data')
    Za = fast_tsne(XaD, perplexity=perplexity,
                   max_iter=max_iter, learning_rate=learning_rate,
                   initialization=PCAinit)

    ####################################################################
    ###### ALIGNMENT begins
    XbD = np.dot(Xb, U[:, :D])

    print('Collecting initialization points based on kNNs')
    if not multicore_kNN:
        # Collect every mean first and reshape once afterwards. Reshaping inside the
        # loop failed on the first iteration for more than one auxiliary point.
        kNN_init = [np.mean(Za[kNN_idx, :2], axis=0) for kNN_idx in kNN_idx_list]
        kNN_init = np.reshape(kNN_init, (len(kNN_idx_list), 2))
    else:
        # assumes find_kNN_inits_per_column returns the k neighbour positions (k x 2)
        # for one auxiliary point -- TODO confirm against its definition.
        kNN_init = Parallel(n_jobs=num_cores)(
            delayed(find_kNN_inits_per_column)(kNN_idx_list[j], Za, 2)
            for j in range(len(kNN_idx_list)))
        # Shape passed positionally: the `newshape=` keyword is deprecated in NumPy >= 2.1.
        kNN_init = np.reshape(kNN_init, (len(kNN_idx_list), k, 2))
        kNN_init = np.mean(kNN_init, axis=1)

    print('Computing tSNE map for the auxillary data')
    Zb = fast_tsne(XbD, perplexity=perplexity,
                   max_iter=max_iter, learning_rate=learning_rate,
                   initialization=kNN_init)

    # Concatenate the mappings (reference first, auxiliary second) and plot
    Z = np.concatenate([Za, Zb], axis=0)
    y = np.concatenate([ya, yb], axis=0)
    y_pred = np.concatenate([y_pred_a, y_pred_b], axis=0)
    y_pred_ttaug = np.concatenate([y_pred_a_ttaug, y_pred_b_ttaug], axis=0)

    point_size = 2
    markerscale = 4

    plot_given_map(Z, y, 
                   ax1, 'LABELS, ' + 'Perplexity : ' + str(perplexity), 
                   point_size=point_size, 
                   markerscale=markerscale, plotting_order=plotting_order)

    # Same map, coloured by predictions instead of labels, for comparison
    plot_given_map(Z, y_pred, 
                   ax2, 'SINGLE pred., ' + 'Perplexity : ' + str(perplexity),
                   point_size=point_size,
                   markerscale=markerscale, plotting_order=plotting_order)
    plot_given_map(Z, y_pred_ttaug, 
                   ax3, 'TTAUG pred., ' + 'Perplexity : ' + str(perplexity),
                   point_size=point_size,
                   markerscale=markerscale, plotting_order=plotting_order)

    # Misclassified points only, coloured by their true label.
    masked_preds = np.not_equal(y, y_pred)
    masked_pred_idx = np.where(masked_preds)[0] # np.where returns a tuple; take the row indices
    plot_given_map(Z[masked_pred_idx, :], y[masked_pred_idx], 
                   ax4, 'WRONG SINGLE pred., ' + 'Perplexity : ' + str(perplexity) + 
                   ' , # : ' + str(np.sum(masked_preds)),
                   point_size=point_size,
                   markerscale=markerscale, plotting_order=plotting_order)

    masked_preds = np.not_equal(y, y_pred_ttaug)
    masked_pred_idx = np.where(masked_preds)[0]
    plot_given_map(Z[masked_pred_idx, :], y[masked_pred_idx], 
                   ax5, 'WRONG TTAUG pred., ' + 'Perplexity : ' + str(perplexity) + 
                   ' , # : ' + str(np.sum(masked_preds)), 
                   point_size=point_size,
                   markerscale=markerscale, plotting_order=plotting_order)

# Now, read the SINGLE PRED. results from file and plot.
# RESULTS_DIR and model come from earlier cells of the notebook.
# NOTE(review): pickle.load can execute arbitrary code; only load result files
# produced by this project.
result_file_name = RESULTS_DIR + model.descriptor + '_SINGpred.pkl'

with open(result_file_name, 'rb') as filehandler:
    result = pickle.load(filehandler)
    
    # Feature matrices for the three splits.
    X_tr = result['train_features'] 
    X_val = result['val_features']
    X_te = result['test_features']
    
    # 1-hot labels -> integer class indices.
    y_tr = np.argmax(result['train_labels_1hot'], axis=1)
    y_val = np.argmax(result['val_labels_1hot'], axis=1)
    y_te = np.argmax(result['test_labels_1hot'], axis=1)
    
    # Single-prediction outputs -> predicted class indices.
    pred_tr = np.argmax(result['train_pred_1hot'], axis=1)
    pred_val = np.argmax(result['val_pred_1hot'], axis=1)
    pred_te = np.argmax(result['test_pred_1hot'], axis=1)

# Validation and test sets are stacked and treated as one auxiliary set.
X_valte = np.concatenate([X_val, X_te], axis=0)
y_valte = np.concatenate([y_val, y_te], axis=0)
pred_valte = np.concatenate([pred_val, pred_te], axis=0)
del result, X_val, y_val, pred_val, X_te, y_te, pred_te # make some room
    

# Get the predictions from TTAUG results.
# T and mode stay defined as notebook globals: the uncertainty cells below print T,
# build file names from it, and their plotting functions read `mode` from the
# global scope.
T = 128
mode = 'mean'
# for T in T_values:
print('T = %g' % T)
result_file_name = RESULTS_DIR + model.descriptor + '_TTAUG_' + str(T) + '.pkl'
with open(result_file_name, 'rb') as filehandler:
    result_ttaug = pickle.load(filehandler)
    
    # The stored predictions are reduced over axis 1 and the arg-max is taken over
    # the last axis (comments in later cells describe these arrays as M x T x 5).
    if mode == 'mean':
        pred_tr_ttaug = np.argmax(np.mean(result_ttaug['train_pred_1hot'], axis=1), axis=-1)
        pred_val_ttaug = np.argmax(np.mean(result_ttaug['val_pred_1hot'], axis=1), axis=-1)
        pred_te_ttaug = np.argmax(np.mean(result_ttaug['test_pred_1hot'], axis=1), axis=-1)
    else: # median
        pred_tr_ttaug = np.argmax(np.median(result_ttaug['train_pred_1hot'], axis=1), axis=-1)
        pred_val_ttaug = np.argmax(np.median(result_ttaug['val_pred_1hot'], axis=1), axis=-1)
        pred_te_ttaug = np.argmax(np.median(result_ttaug['test_pred_1hot'], axis=1), axis=-1)

pred_valte_ttaug = np.concatenate([pred_val_ttaug, pred_te_ttaug], axis=0)
del pred_val_ttaug, pred_te_ttaug


# perplexities, max_iter, variance_to_keep, num_neighbors and plotting_order are
# the notebook globals set in the previous cell.
for perp in perplexities:
    print('Perplexity : %g' % perp)
    # 2 x 3 grid; slot 4 (bottom-left) is intentionally left empty.
    f = plt.figure(figsize=(22,15))
    ax1 = f.add_subplot(2, 3, 1) 
    ax2 = f.add_subplot(2, 3, 2) 
    ax3 = f.add_subplot(2, 3, 3) 
    ax4 = f.add_subplot(2, 3, 5) # skip 4
    ax5 = f.add_subplot(2, 3, 6) 
    plot_aligned_tsne_with_labels_and_predictions(X_tr, y_tr, pred_tr, pred_tr_ttaug, 
                                                  X_valte, y_valte, pred_valte, pred_valte_ttaug,
                                                  ax1, ax2, ax3, ax4, ax5, 
                                                  perplexity=perp, max_iter=max_iter, 
                                                  variance_to_keep=variance_to_keep, k=num_neighbors, 
                                                  multicore_kNN=True, num_cores=multiprocessing.cpu_count(),
                                                  plotting_order=plotting_order)
    plt.tight_layout()
    plt.show()

# NOTE(review): pred_valte_ttaug and result_ttaug are not released here.
del X_tr, y_tr, X_valte, y_valte, pred_valte, pred_tr, pred_tr_ttaug

In [None]:
# ALIGNED t-SNE Maps with Uncertainty
from FItSNE.fast_tsne import fast_tsne 
from sklearn.metrics import pairwise_distances
from sklearn import preprocessing
from matplotlib import colors
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm

# plt.style.use('dark_background')

def compute_uncertainties(predictions_1hot_ttaug, onset_level=1, mode='mean', use_entropy=False):
    """Estimate a per-sample uncertainty for the binary "class >= onset_level" decision.

    Parameters
    ----------
    predictions_1hot_ttaug : array indexed as [sample, augmentation, class]
        (M x T x C in the notebook's notation.)
    onset_level : int
        Scores of the classes ``onset_level`` and above are summed into one binary
        score per sample and augmentation.
    mode : {'mean', 'median'}
        Aggregation over the augmentation axis.
    use_entropy : bool
        If False, return the spread of the binary scores over the augmentations:
        standard deviation for 'mean', inter-quartile range for 'median'.
        If True, aggregate the binary scores with the mean/median, build the
        two-column matrix [1 - score, score] and pass it to ``entropy``.

    Returns
    -------
    The uncertainty estimate. An AssertionError is raised when ``mode`` is neither
    'mean' nor 'median'.

    Notes
    -----
    ``entropy`` and ``stats`` are resolved from the notebook's global scope.
    NOTE(review): if ``entropy`` is ``scipy.stats.entropy``, it reduces along axis 0
    by default, which would not give one value per sample -- verify which
    ``entropy`` is in scope.
    """
    # Collapse the class axis: one "onset or worse" score per (sample, augmentation).
    onset_scores = np.sum(predictions_1hot_ttaug[:, :, onset_level:], axis=-1) # MxT

    estimate = None
    if mode in ('mean', 'median'):
        if use_entropy:
            aggregate = np.mean if mode == 'mean' else np.median
            positive = aggregate(onset_scores, axis=1, keepdims=True) # Mx1
            two_class = np.concatenate((np.subtract(1., positive), positive), axis=1) # Mx2
            estimate = entropy(two_class)
        elif mode == 'mean':
            estimate = np.std(onset_scores, axis=1) # STD over the T augmentations
        else:
            estimate = stats.iqr(onset_scores, axis=1) # IQR over the T augmentations

    assert estimate is not None, 'No uncertainty estimate computed!'

    return estimate

def plot_given_map_with_uncertainty(Z, y, ax, title, point_size=2, markerscale=5, plotting_order='original', 
                                    uncertainty=None):
    """Scatter a 2-D map class by class, optionally using uncertainty as the alpha channel.

    Parameters
    ----------
    Z : 2-D array; columns 0 and 1 are the map coordinates.
    y : 1-D integer array of class indices (0..4), one per row of ``Z``.
    ax : axes object to draw on (its face colour is set to black).
    title : axes title.
    point_size : marker size passed to ``ax.scatter``.
    markerscale : legend marker scale.
    plotting_order : 'original' draws classes ``0 .. len(np.unique(y)) - 1`` in
        ascending order; any other value draws them in the fixed order
        ``[0, 2, 3, 4, 1]``.
    uncertainty : optional array with one value per row of ``Z``. When given, each
        point is drawn as RGBA with its class colour and this value as alpha; every
        channel is clipped to the open interval (0, 1).

    Notes
    -----
    A class listed in the plotting order that has no points is drawn as an empty
    scatter (so it still gets its legend entry). Previously this case raised inside
    ``np.concatenate`` when ``uncertainty`` was given.
    """
    col = np.array(['burlywood','maroon', 'lightskyblue', 'darkslateblue', 'fuchsia'])
    legend_labels = np.array(['0: No DR', '1: Mild DR', '2: Moderate DR', '3: Severe DR', '4: Proliferative DR'])

    if plotting_order == 'original':
        plotting_order = range(len(np.unique(y)))
    else:
        plotting_order = [0,2,3,4,1] #[4,3,0,2,1]

    ax.set_facecolor('black')

    for i in plotting_order:
        mask = (y == i)
        if uncertainty is not None:
            n_points = int(np.count_nonzero(mask))
            # Every point selected by the mask belongs to class i, so the colour is
            # looked up once and repeated instead of once per point.
            class_rgb = np.asarray(colors.hex2color(colors.cnames[col[i]]), dtype=float)
            col_rgb = np.tile(class_rgb, (n_points, 1))               # (n_points, 3), also for n_points == 0
            alpha = np.reshape(uncertainty[mask], (n_points, 1))      # (n_points, 1)
            c = np.concatenate([col_rgb, alpha], axis=-1)
            # Keep all RGBA channels strictly inside (0, 1).
            c = np.clip(c, 0.0 + 1e-10, 1.0 - 1e-10)
        else:
            c = col[y[mask]]
        ax.scatter(Z[mask,0], Z[mask,1], c=c, label=legend_labels[i], s=point_size)

    ax.set_title(title)
    # Legend sits below the axes, spread over three columns.
    leg = ax.legend(bbox_to_anchor=(0., -0.05, 1., 0.), loc='upper center', ncol=3, 
                    mode="expand", shadow=True, fancybox=True, markerscale=markerscale, handletextpad=0.1)
    leg.get_frame().set_alpha(0.75)


def plot_given_map_with_surface(Z, y, ax, title, uncertainty=None):
    """Draw the uncertainty as a triangulated surface over the 2-D map, seen from above.

    Parameters
    ----------
    Z : 2-D array; columns 0 and 1 are the map coordinates.
    y : unused; kept so the signature matches the other map-plotting helpers.
    ax : axes object providing ``plot_trisurf`` and ``view_init`` (the notebook
        creates it with ``projection='3d'``).
    title : axes title.
    uncertainty : array with one value per row of ``Z`` (extra singleton dimensions
        are squeezed away). Required despite the ``None`` default.

    Raises
    ------
    ValueError
        If ``uncertainty`` is None.
    """
    if uncertainty is None:
        raise ValueError('plot_given_map_with_surface requires an uncertainty array')

    im = ax.plot_trisurf(np.squeeze(Z[:,0]), np.squeeze(Z[:,1]), np.squeeze(uncertainty), cmap='viridis')
    # Attach the colorbar through the axes' own figure instead of relying on a
    # notebook-global variable named `f` being the current figure.
    ax.figure.colorbar(im, ax=ax, shrink=0.8)

    ax.set_title(title)
    # Look straight down the z axis so the surface reads like a heat map.
    ax.view_init(azim=-90, elev=90)



def plot_aligned_tsne_with_uncertainty(Xa, ya, y_pred_a_ttaug,
                                       Xb, yb, y_pred_b_ttaug, 
                                       axes, 
                                       perplexity=30, max_iter=1000, 
                                       variance_to_keep=0.99, k=10,
                                       multicore_kNN=False, num_cores=10,
                                       plotting_order='original',
                                       exclude_reference_data=False,
                                       use_entropy=False):
    """Build an aligned t-SNE map and visualise per-onset-level uncertainties on it.

    The reference data ``Xa`` is embedded first (PCA-initialised). Each row of ``Xb``
    is initialised at the mean map position of its ``k`` nearest reference
    neighbours and embedded in a second t-SNE run.

    Parameters
    ----------
    Xa, Xb : 2-D arrays (samples x features)
        Reference and auxiliary features; standardised with a scaler fitted on ``Xa``.
    ya, yb : 1-D integer arrays of class labels.
    y_pred_a_ttaug, y_pred_b_ttaug : arrays passed to ``compute_uncertainties``
        (test-time-augmentation predictions, [sample, augmentation, class]).
    axes : sequence of 9 axes objects
        ``axes[0]``: map without uncertainty. For onset level n in 1..4,
        ``axes[2n-1]`` gets the map with uncertainty as alpha and ``axes[2n]`` the
        uncertainty surface (the latter is handed to
        ``plot_given_map_with_surface``).
    perplexity, max_iter : forwarded to ``fast_tsne``.
    variance_to_keep : float
        Smallest number of leading principal components whose share of the total
        variance reaches this value is kept.
    k : int
        Number of nearest reference neighbours used for the initialisation.
    multicore_kNN, num_cores : use joblib with ``n_jobs=num_cores`` for the
        neighbour search and initialisation when ``multicore_kNN`` is True.
    plotting_order : forwarded to ``plot_given_map_with_uncertainty``.
    exclude_reference_data : bool
        If True only the auxiliary map/labels/uncertainties are plotted.
    use_entropy : forwarded to ``compute_uncertainties``.

    Notes
    -----
    ``learning_rate`` and ``mode`` are NOT parameters: both are read from the
    notebook's global scope and must be defined before this function is called.
    """
    # First, standardize the data (statistics come from the reference data only)
    scaler = preprocessing.StandardScaler()
    Xa = scaler.fit_transform(Xa)
    Xb = scaler.transform(Xb)

    print('Computing the pairwise distances')
    K = pairwise_distances(X=Xa, Y=Xb, metric='euclidean') # rows: reference, columns: auxiliary
    Mb = K.shape[1]
    print('Finding kNNs...')
    if not multicore_kNN:
        # argsort is ascending, so the first k entries are the nearest reference points.
        kNN_idx_list = [np.argsort(K[:, j])[:k] for j in range(Mb)]
    else:
        kNN_idx_list = list(Parallel(n_jobs=num_cores)(
            delayed(find_kNN_idx_per_column)(K[:, j], k) for j in range(Mb)))

    # Do PCA on the reference data and keep D dimensions
    print('PCA on reference data ...')
    Sigma = np.cov(np.transpose(Xa))
    U, s, _ = np.linalg.svd(Sigma, full_matrices=False)
    sum_s = np.sum(s)
    print('Total components : %g' % len(s))
    # d counts the leading components kept. The range includes len(s) so that all
    # components can be selected when the threshold requires it (range(len(s))
    # silently dropped the last component in that case).
    for d in range(len(s) + 1):
        var_explained = np.sum(s[:d]) / sum_s
        if var_explained >= variance_to_keep:
            break
    print('%g of variance explained with %d components.' % (var_explained, d))

    D = d
    XaD = np.dot(Xa, U[:, :D])
    # Small-scale initialisation from the first two principal components.
    PCAinit = XaD[:, :2] / np.std(XaD[:, 0]) * 0.0001

    print('Computing tSNE map for the reference data')
    Za = fast_tsne(XaD, perplexity=perplexity,
                   max_iter=max_iter, learning_rate=learning_rate,
                   initialization=PCAinit)
    print('tSNE done...')

    ##################################################################################
    XbD = np.dot(Xb, U[:, :D])

    print('Collecting initialization points based on kNNs')
    if not multicore_kNN:
        # Collect every mean first and reshape once afterwards. Reshaping inside the
        # loop failed on the first iteration for more than one auxiliary point.
        kNN_init = [np.mean(Za[kNN_idx, :2], axis=0) for kNN_idx in kNN_idx_list]
        kNN_init = np.reshape(kNN_init, (len(kNN_idx_list), 2))
    else:
        # assumes find_kNN_inits_per_column returns the k neighbour positions (k x 2)
        # for one auxiliary point -- TODO confirm against its definition.
        kNN_init = Parallel(n_jobs=num_cores)(
            delayed(find_kNN_inits_per_column)(kNN_idx_list[j], Za, 2)
            for j in range(len(kNN_idx_list)))
        # Shape passed positionally: the `newshape=` keyword is deprecated in NumPy >= 2.1.
        kNN_init = np.reshape(kNN_init, (len(kNN_idx_list), k, 2))
        kNN_init = np.mean(kNN_init, axis=1)

    print('Computing tSNE map for the auxillary data')
    Zb = fast_tsne(XbD, perplexity=perplexity,
                   max_iter=max_iter, learning_rate=learning_rate,
                   initialization=kNN_init)
    print('tSNE done...')

    # Decide on the mappings to plot
    if not exclude_reference_data:
        Z = np.concatenate([Za, Zb], axis=0)
        y = np.concatenate([ya, yb], axis=0)
    else:
        Z = Zb
        y = yb

    # This one is for reference with no uncertainty
    plot_given_map_with_uncertainty(Z, y, axes[0], 'Aligned tSNE' + ' Perplexity : ' + str(perplexity),
                   plotting_order=plotting_order, uncertainty=None)

    # Now, the uncertainty business
    # Rescale uncertainties into [0,1] and use them as alpha channel in t-SNE maps
    min_max_scaler = preprocessing.MinMaxScaler()

    onset_levels = [1,2,3,4]
    for pair_idx, onset_level in enumerate(onset_levels):
        # Each onset level owns two consecutive axes after axes[0]:
        # the alpha-coded scatter first, the surface second.
        scatter_ax = axes[1 + 2 * pair_idx]
        surface_ax = axes[2 + 2 * pair_idx]

        uncertainty_a = compute_uncertainties(y_pred_a_ttaug, onset_level=onset_level, 
                                              mode=mode, use_entropy=use_entropy)
        uncertainty_b = compute_uncertainties(y_pred_b_ttaug, onset_level=onset_level, 
                                              mode=mode, use_entropy=use_entropy)
        if not exclude_reference_data:
            uncertainty = np.concatenate([uncertainty_a, uncertainty_b], axis=0)
        else:
            uncertainty = uncertainty_b

        # Column vector for the scaler; the scatter gets the [0,1]-scaled values,
        # the surface gets the raw ones.
        uncertainty = np.reshape(uncertainty, (len(uncertainty), 1))
        uncertainty_01 = np.asarray(np.squeeze(min_max_scaler.fit_transform(uncertainty)), dtype=np.float32)
        plot_given_map_with_uncertainty(Z, y, scatter_ax, 
                                        'Aligned tSNE with Onset ' + str(onset_level)  + ' UNCERTAINTY,' + ' Perplexity : ' + str(perplexity),
                                        plotting_order=plotting_order, uncertainty=uncertainty_01)

        plot_given_map_with_surface(Z, y, surface_ax, 
                                    'Uncertainty surface for onset ' + str(onset_level)  + ' Perplexity : ' + str(perplexity),
                                    uncertainty=uncertainty)


# Now, read the SINGLE PRED. results from file and plot.
# RESULTS_DIR and model come from earlier cells of the notebook.
# NOTE(review): pickle.load can execute arbitrary code; only load result files
# produced by this project.
result_file_name = RESULTS_DIR + model.descriptor + '_SINGpred.pkl'

with open(result_file_name, 'rb') as filehandler:
    result = pickle.load(filehandler)
    
    # Feature matrices for the three splits.
    X_tr = result['train_features'] 
    X_val = result['val_features']
    X_te = result['test_features']
    
    # 1-hot labels -> integer class indices.
    y_tr = np.argmax(result['train_labels_1hot'], axis=1)
    y_val = np.argmax(result['val_labels_1hot'], axis=1)
    y_te = np.argmax(result['test_labels_1hot'], axis=1)

# Validation and test sets are stacked and treated as one auxiliary set.
X_valte = np.concatenate([X_val, X_te], axis=0)
y_valte = np.concatenate([y_val, y_te], axis=0)
del result, X_val, y_val, X_te, y_te # make some room

# Get the uncertainties from TTAUG results.
# Both flags are also reused by the multiclass-uncertainty cell below.
use_entropy = True
exclude_reference_data=True

# NOTE(review): T is not set in this cell; it is the notebook global assigned in
# the previous cell (T = 128), so that cell must have been run first.
# for T in T_values:
print('T = %g' % T)
result_file_name = RESULTS_DIR + model.descriptor + '_TTAUG_' + str(T) + '.pkl'
with open(result_file_name, 'rb') as filehandler:
    result_ttaug = pickle.load(filehandler)
    
#     labels_1hot_tr = result_ttaug['train_labels_1hot'] # Mx5
#     labels_1hot_val = result_ttaug['val_labels_1hot']
#     labels_1hot_te = result_ttaug['test_labels_1hot']
    predictions_1hot_tr_ttaug = result_ttaug['train_pred_1hot'] # MxTx5
    predictions_1hot_val_ttaug = result_ttaug['val_pred_1hot']
    predictions_1hot_te_ttaug = result_ttaug['test_pred_1hot']

# Stack validation and test predictions in the same order as X_valte / y_valte.
predictions_1hot_valte_ttaug = np.concatenate([predictions_1hot_val_ttaug, predictions_1hot_te_ttaug], axis=0)


# perplexities, max_iter, variance_to_keep, num_neighbors and plotting_order are
# notebook globals set in an earlier cell. The global name `f` matters as well:
# plot_given_map_with_surface attaches its colorbar via `f`.
for perp in perplexities:
    print('Perplexity : %g' % perp)
    # 5 x 2 grid: left column holds the 2-D maps, right column the 3-D surfaces.
    f = plt.figure(figsize=(15, 37.5))
    ######################################
    # Reference map without uncertainty; slot 2 stays empty.
    ax1 = f.add_subplot(5, 2, 1) 
#     ax2 = f.add_subplot(5, 2, 2) 
    #####################################
    # uncertainty maps for given onset levels 
    ax3 = f.add_subplot(5, 2, 3)
    ax5 = f.add_subplot(5, 2, 5)
    ax7 = f.add_subplot(5, 2, 7)
    ax9 = f.add_subplot(5, 2, 9)
    ######################################
    # contour/surface plots
    ax4 = f.add_subplot(5, 2, 4, projection='3d')
    ax6 = f.add_subplot(5, 2, 6, projection='3d')
    ax8 = f.add_subplot(5, 2, 8, projection='3d')
    ax10 = f.add_subplot(5, 2, 10, projection='3d')
    # The axes list is ordered as the function indexes it: reference map first,
    # then (2-D map, 3-D surface) pairs for onset levels 1 to 4.
    plot_aligned_tsne_with_uncertainty(X_tr, y_tr, predictions_1hot_tr_ttaug, 
                                       X_valte, y_valte, predictions_1hot_valte_ttaug, 
                                       [ax1, ax3, ax4, ax5, ax6, ax7, ax8, ax9, ax10], 
                                       perplexity=perp, max_iter=max_iter, 
                                       variance_to_keep=variance_to_keep, k=num_neighbors, 
                                       multicore_kNN=True, num_cores=multiprocessing.cpu_count(),
                                       plotting_order=plotting_order, 
                                       exclude_reference_data=exclude_reference_data, 
                                       use_entropy=use_entropy)
    plt.tight_layout()
    plt.show()
    
# NOTE(review): result_ttaug and the predictions_1hot_* arrays are not released here.
del X_tr, y_tr, X_valte, y_valte

In [None]:
# ALIGNED t-SNE Maps with Uncertainty, MULTICLASS
from FItSNE.fast_tsne import fast_tsne 
from sklearn.metrics import pairwise_distances
from sklearn import preprocessing
from matplotlib import colors

# plt.style.use('dark_background')

def compute_uncertainties_multiclass(predictions_1hot_ttaug, mode='mean', use_entropy=False):
    """Estimate uncertainties from test-time-augmentation predictions over all classes.

    Parameters
    ----------
    predictions_1hot_ttaug : array indexed as [sample, augmentation, class]
        (M x T x 5 in the notebook's notation.)
    mode : {'mean', 'median'}
        Aggregation over the augmentation axis.
    use_entropy : bool
        If True, aggregate over the augmentations with the mean/median, re-normalise
        each row to sum to one and pass the result to ``entropy``.
        If False, return the spread over the augmentations: standard deviation for
        'mean', inter-quartile range for 'median'. In that case the class axis is
        kept, i.e. the result has one value per sample AND class.

    Returns
    -------
    The uncertainty estimate. An AssertionError is raised when ``mode`` is neither
    'mean' nor 'median'.

    Notes
    -----
    ``entropy`` and ``stats`` are resolved from the notebook's global scope.
    NOTE(review): if ``entropy`` is ``scipy.stats.entropy``, it reduces along axis 0
    by default, which would not give one value per sample -- verify which
    ``entropy`` is in scope.
    """
    estimate = None

    if mode in ('mean', 'median'):
        if use_entropy:
            aggregate = np.mean if mode == 'mean' else np.median
            class_scores = aggregate(predictions_1hot_ttaug, axis=1, keepdims=False) # Mx5
            row_totals = np.sum(class_scores, axis=-1, keepdims=True)
            class_scores = np.divide(class_scores, row_totals) # re-normalize the probabilities
            estimate = entropy(class_scores)
        elif mode == 'mean':
            estimate = np.std(predictions_1hot_ttaug, axis=1) # STD over the T augmentations
        else:
            estimate = stats.iqr(predictions_1hot_ttaug, axis=1) # IQR over the T augmentations

    assert estimate is not None, 'No uncertainty estimate computed!'

    return estimate


def plot_aligned_tsne_with_uncertainty_multiclass(Xa, ya, y_pred_a_ttaug,
                                                  Xb, yb, y_pred_b_ttaug, 
                                                  ax1, ax2, ax3, 
                                                  perplexity=30, max_iter=1000, 
                                                  variance_to_keep=0.99, k=10,
                                                  multicore_kNN=False, num_cores=10,
                                                  plotting_order='original',
                                                  exclude_reference_data=False, use_entropy=False):
    """Build an aligned t-SNE map and visualise the multi-class uncertainty on it.

    The reference data ``Xa`` is embedded first (PCA-initialised). Each row of ``Xb``
    is initialised at the mean map position of its ``k`` nearest reference
    neighbours and embedded in a second t-SNE run.

    Parameters
    ----------
    Xa, Xb : 2-D arrays (samples x features)
        Reference and auxiliary features; standardised with a scaler fitted on ``Xa``.
    ya, yb : 1-D integer arrays of class labels.
    y_pred_a_ttaug, y_pred_b_ttaug : arrays passed to
        ``compute_uncertainties_multiclass`` (test-time-augmentation predictions).
    ax1 : axes for the map without uncertainty.
    ax2 : axes for the map with the [0,1]-scaled uncertainty as alpha channel.
    ax3 : axes handed to ``plot_given_map_with_surface`` together with the raw
        uncertainty.
    perplexity, max_iter : forwarded to ``fast_tsne``.
    variance_to_keep : float
        Smallest number of leading principal components whose share of the total
        variance reaches this value is kept.
    k : int
        Number of nearest reference neighbours used for the initialisation.
    multicore_kNN, num_cores : use joblib with ``n_jobs=num_cores`` for the
        neighbour search and initialisation when ``multicore_kNN`` is True.
    plotting_order : forwarded to ``plot_given_map_with_uncertainty``.
    exclude_reference_data : bool
        If True only the auxiliary map/labels/uncertainties are plotted.
    use_entropy : forwarded to ``compute_uncertainties_multiclass``.

    Notes
    -----
    ``learning_rate`` and ``mode`` are NOT parameters: both are read from the
    notebook's global scope and must be defined before this function is called.
    The uncertainty is reshaped to one column per sample, so the estimate returned
    by ``compute_uncertainties_multiclass`` must hold exactly one value per sample.
    """
    # First, standardize the data (statistics come from the reference data only)
    scaler = preprocessing.StandardScaler()
    Xa = scaler.fit_transform(Xa)
    Xb = scaler.transform(Xb)

    print('Computing the pairwise distances')
    K = pairwise_distances(X=Xa, Y=Xb, metric='euclidean') # rows: reference, columns: auxiliary
    Mb = K.shape[1]
    print('Finding kNNs...')
    if not multicore_kNN:
        # argsort is ascending, so the first k entries are the nearest reference points.
        kNN_idx_list = [np.argsort(K[:, j])[:k] for j in range(Mb)]
    else:
        kNN_idx_list = list(Parallel(n_jobs=num_cores)(
            delayed(find_kNN_idx_per_column)(K[:, j], k) for j in range(Mb)))

    # Do PCA on the reference data and keep D dimensions
    print('PCA on reference data ...')
    Sigma = np.cov(np.transpose(Xa))
    U, s, _ = np.linalg.svd(Sigma, full_matrices=False)
    sum_s = np.sum(s)
    print('Total components : %g' % len(s))
    # d counts the leading components kept. The range includes len(s) so that all
    # components can be selected when the threshold requires it (range(len(s))
    # silently dropped the last component in that case).
    for d in range(len(s) + 1):
        var_explained = np.sum(s[:d]) / sum_s
        if var_explained >= variance_to_keep:
            break
    print('%g of variance explained with %d components.' % (var_explained, d))

    D = d
    XaD = np.dot(Xa, U[:, :D])
    # Small-scale initialisation from the first two principal components.
    PCAinit = XaD[:, :2] / np.std(XaD[:, 0]) * 0.0001

    print('Computing tSNE map for the reference data')
    Za = fast_tsne(XaD, perplexity=perplexity,
                   max_iter=max_iter, learning_rate=learning_rate,
                   initialization=PCAinit)
    print('tSNE done...')

    ##################################################################################
    XbD = np.dot(Xb, U[:, :D])

    print('Collecting initialization points based on kNNs')
    if not multicore_kNN:
        # Collect every mean first and reshape once afterwards. Reshaping inside the
        # loop failed on the first iteration for more than one auxiliary point.
        kNN_init = [np.mean(Za[kNN_idx, :2], axis=0) for kNN_idx in kNN_idx_list]
        kNN_init = np.reshape(kNN_init, (len(kNN_idx_list), 2))
    else:
        # assumes find_kNN_inits_per_column returns the k neighbour positions (k x 2)
        # for one auxiliary point -- TODO confirm against its definition.
        kNN_init = Parallel(n_jobs=num_cores)(
            delayed(find_kNN_inits_per_column)(kNN_idx_list[j], Za, 2)
            for j in range(len(kNN_idx_list)))
        # Shape passed positionally: the `newshape=` keyword is deprecated in NumPy >= 2.1.
        kNN_init = np.reshape(kNN_init, (len(kNN_idx_list), k, 2))
        kNN_init = np.mean(kNN_init, axis=1)

    print('Computing tSNE map for the auxillary data')
    Zb = fast_tsne(XbD, perplexity=perplexity,
                   max_iter=max_iter, learning_rate=learning_rate, 
                   initialization=kNN_init)
    print('tSNE done...')

    # Decide on the mappings to plot
    if not exclude_reference_data:
        Z = np.concatenate([Za, Zb], axis=0)
        y = np.concatenate([ya, yb], axis=0)
    else:
        Z = Zb
        y = yb

    # This one is for reference with no uncertainty
    plot_given_map_with_uncertainty(Z, y, ax1,                                    
                                    'Aligned tSNE' + ' Perplexity : ' + str(perplexity),
                                    plotting_order=plotting_order, uncertainty=None)

    # Now, the uncertainty business
    # Rescale uncertainties into [0,1] and use them as alpha channel in t-SNE maps
    min_max_scaler = preprocessing.MinMaxScaler()

    uncertainty_a = compute_uncertainties_multiclass(y_pred_a_ttaug, mode=mode, use_entropy=use_entropy)
    uncertainty_b = compute_uncertainties_multiclass(y_pred_b_ttaug, mode=mode, use_entropy=use_entropy)

    if not exclude_reference_data:
        uncertainty = np.concatenate([uncertainty_a, uncertainty_b], axis=0)
    else:
        uncertainty = uncertainty_b

    # Column vector for the scaler; the scatter gets the [0,1]-scaled values,
    # the surface gets the raw ones.
    uncertainty = np.reshape(uncertainty, (len(uncertainty), 1))
    uncertainty_01 = np.asarray(np.squeeze(min_max_scaler.fit_transform(uncertainty)), dtype=np.float32)
    plot_given_map_with_uncertainty(Z, y, ax2, 
                                    'Aligned tSNE with Multi Class UNCERTAINTY, Perplexity : ' + str(perplexity),
                                    plotting_order=plotting_order, uncertainty=uncertainty_01)
    plot_given_map_with_surface(Z, y, ax3, 
                                'Multi Class Uncertainty surface, Perplexity : ' + str(perplexity),
                                uncertainty=uncertainty)

# Now, read the SINGLE PRED. results from file and plot.
# RESULTS_DIR and model come from earlier cells of the notebook.
# NOTE(review): pickle.load can execute arbitrary code; only load result files
# produced by this project.
result_file_name = RESULTS_DIR + model.descriptor + '_SINGpred.pkl'

with open(result_file_name, 'rb') as filehandler:
    result = pickle.load(filehandler)
    
    # Feature matrices for the three splits.
    X_tr = result['train_features'] 
    X_val = result['val_features']
    X_te = result['test_features']
    
    # 1-hot labels -> integer class indices.
    y_tr = np.argmax(result['train_labels_1hot'], axis=1)
    y_val = np.argmax(result['val_labels_1hot'], axis=1)
    y_te = np.argmax(result['test_labels_1hot'], axis=1)

# Validation and test sets are stacked and treated as one auxiliary set.
X_valte = np.concatenate([X_val, X_te], axis=0)
y_valte = np.concatenate([y_val, y_te], axis=0)
del result, X_val, y_val, X_te, y_te # make some room


# NOTE(review): T, use_entropy and exclude_reference_data are not set in this cell;
# they are notebook globals assigned in earlier cells, which must have been run first.
# for T in T_values:
print('T = %g' % T)
result_file_name = RESULTS_DIR + model.descriptor + '_TTAUG_' + str(T) + '.pkl'
with open(result_file_name, 'rb') as filehandler:
    result_ttaug = pickle.load(filehandler)
    
#     labels_1hot_tr = result_ttaug['train_labels_1hot'] # Mx5
#     labels_1hot_val = result_ttaug['val_labels_1hot']
#     labels_1hot_te = result_ttaug['test_labels_1hot']
    predictions_1hot_tr_ttaug = result_ttaug['train_pred_1hot'] # MxTx5
    predictions_1hot_val_ttaug = result_ttaug['val_pred_1hot']
    predictions_1hot_te_ttaug = result_ttaug['test_pred_1hot']

# Stack validation and test predictions in the same order as X_valte / y_valte.
predictions_1hot_valte_ttaug = np.concatenate([predictions_1hot_val_ttaug, predictions_1hot_te_ttaug], axis=0)


# perplexities, max_iter, variance_to_keep, num_neighbors and plotting_order are
# notebook globals set in an earlier cell. The global name `f` matters as well:
# plot_given_map_with_surface attaches its colorbar via `f`.
for perp in perplexities:
    print('Perplexity : %g' % perp)
    # One row: plain map, map with uncertainty as alpha, 3-D uncertainty surface.
    f = plt.figure(figsize=(22.5,7.5))
    ax1 = f.add_subplot(1, 3, 1) 
    ax2 = f.add_subplot(1, 3, 2) 
    ax3 = f.add_subplot(1, 3, 3, projection='3d') 
    
    plot_aligned_tsne_with_uncertainty_multiclass(X_tr, y_tr, predictions_1hot_tr_ttaug, 
                                                  X_valte, y_valte, predictions_1hot_valte_ttaug, 
                                                  ax1, ax2, ax3, 
                                                  perplexity=perp, max_iter=max_iter, 
                                                  variance_to_keep=variance_to_keep, k=num_neighbors, 
                                                  multicore_kNN=True, num_cores=multiprocessing.cpu_count(),
                                                  plotting_order=plotting_order, 
                                                  exclude_reference_data=exclude_reference_data,
                                                  use_entropy=use_entropy)
    plt.tight_layout()
    plt.show()
    
# NOTE(review): result_ttaug and the predictions_1hot_* arrays are not released here.
del X_tr, y_tr, X_valte, y_valte