In [1]:
from utils import fetch_predictions, calc_basic_performance_metrics

test_complex_data = fetch_predictions(persist_dir = "pl_out/node_average_15_epochs/persist_train_4_layer/",
                                      replica_of_interest = 6,
                                      run_type_flag = "test")
test_complex_data_w_metrics = calc_basic_performance_metrics(test_complex_data)

In [2]:
import numpy as np
def print_stats(complex_data):
    print(" mean auc: +/-{} {}\n median auc: {}\n (across all complexes)\n".format(np.std([complex_data[complex_code]["auc"] for complex_code in complex_data.keys()]),
                                                                                   np.mean([complex_data[complex_code]["auc"] for complex_code in complex_data.keys()]),
                                                                                   np.median([complex_data[complex_code]["auc"] for complex_code in complex_data.keys()])))

    print(" average_precision_score auc: +/-{} {}\n average_precision_score auc: {}\n (across all complexes)\n".format(np.std([complex_data[complex_code]["average_precision_score"] for complex_code in complex_data.keys()]),
                                                                                                                       np.mean([complex_data[complex_code]["average_precision_score"] for complex_code in complex_data.keys()]),
                                                                                                                       np.median([complex_data[complex_code]["average_precision_score"] for complex_code in complex_data.keys()])))
# end
print_stats(test_complex_data_w_metrics)

 mean auc: +/-0.139713039598 0.83722228028
 median auc: 0.885546804854
 (across all complexes)

 average_precision_score auc: +/-0.0502415056224 0.0569991083649
 average_precision_score auc: 0.0463167023128
 (across all complexes)



In [3]:
from utils import aggregate_predictions_all_complexes

all_y_trues, all_y_preds = aggregate_predictions_all_complexes(test_complex_data)

In [4]:
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
import numpy as np
import seaborn

fpr, tpr, thresholds = roc_curve(all_y_trues, all_y_preds)
roc_auc = auc(fpr, tpr) # compute area under the curve

#fpr = sorted(fpr)
#tpr = sorted(tpr)
#thresholds = sorted(thresholds)

default_fontsize = 10
x_figure_size = 8
y_figure_size = 8

plt.figure()
lw = 4
plt.plot(fpr, tpr, color='darkorange',
         lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate').set_fontsize(default_fontsize)
plt.ylabel('True Positive Rate').set_fontsize(default_fontsize)
plt.title('Receiver operating characteristic example').set_fontsize(default_fontsize)
plt.legend(loc="lower right", prop={'size': default_fontsize})
plt.tick_params(labelsize=default_fontsize)

# create the axis of thresholds (scores)
ax2 = plt.gca().twinx()
ax2.plot(fpr, thresholds, markeredgecolor='r',linestyle='dashed', color='r')
ax2.set_ylabel('Threshold', color='red')
ax2.set_ylim([thresholds[-1],thresholds[0]])
ax2.set_xlim([fpr[0],fpr[-1]])

for item in ([ax2.title, ax2.xaxis.label, ax2.yaxis.label] +
              ax2.get_xticklabels() + ax2.get_yticklabels()):
    item.set_fontsize(default_fontsize)

plt.rcParams["figure.figsize"]=x_figure_size, y_figure_size

plt.savefig('roc_and_threshold.png')
#plt.close()
plt.show()

# tpr = tp/(tp+fn)
# fpr = fp/(fp+tn)
# sensitivity = recall = tp / t = tp / (tp + fn)
# specificity = tn / n = tn / (tn + fp)
# precision = tp / p = tp / (tp + fp)

<Figure size 640x480 with 2 Axes>

In [5]:
from utils import calc_precision_recall
from tqdm import tqdm

# based on roc curve above choose threshold
guess_optimal_threshold = 1.0

cutoff_thresholds = [thres for thres in sorted(thresholds) if thres > guess_optimal_threshold]

cand_ps = []
cand_rs = []

def solve_optimal_threshold(cutoff_thresholds):
    for pred_threshold in tqdm(cutoff_thresholds):
        all_y_preds_pred_threshold = list(map(lambda x: 1 if x > pred_threshold else 0, all_y_preds))
        pred_p, pred_r = calc_precision_recall(all_y_preds_pred_threshold, all_y_trues)
        cand_ps.append(pred_p)
        cand_rs.append(pred_r)
    return (cand_ps, cand_rs)
# end

cand_ps, cand_rs = solve_optimal_threshold(cutoff_thresholds)
print("idx of threshold yielding max precision: {}".format(cand_ps.index(max(cand_ps))))
print("idx of threshold yielding max recall: {}".format(cand_rs.index(max(cand_rs))))

  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
100%|██████████| 1105/1105 [20:26<00:00,  1.05s/it]

idx of threshold yielding max precision: 0
idx of threshold yielding max recall: 1102





In [9]:
from collections import Counter
optimal_
optimal_threshold=cutoff_thresholds[]
all_y_preds_pred_threshold = list(map(lambda x: 1 if x > optimal_threshold else 0, all_y_preds))
print("predicted_dist: {}".format(Counter(all_y_preds_optimal)))
print("true_dist: {}".format(Counter(all_y_trues)))

0.222372575553
0.0350590243209
predicted_dist: Counter({0: 886948, 1: 28124})
true_dist: Counter({0: 910638, 1: 4434})


35113

In [None]:
import itertools
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(cm,
                          default_fontsize,
                          x_figure_size,
                          y_figure_size, 
                          classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    #print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title).set_fontsize(default_fontsize)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    plt.tick_params(labelsize=default_fontsize)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 fontsize=default_fontsize,
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label').set_fontsize(default_fontsize)
    plt.xlabel('Predicted label').set_fontsize(default_fontsize)
    plt.tight_layout()
    
    plt.rcParams["figure.figsize"]=x_figure_size, y_figure_size

    if normalize:
      plt.savefig('normalized_confusion_matrix.png')
    else:
      plt.savefig('confusion_matrix_without_normalization.png')
# end


# Compute confusion matrix
cnf_matrix = confusion_matrix(all_y_trues_recast, all_y_preds_optimal)
np.set_printoptions(precision=3)

default_fontsize = 20
x_figure_size = 12
y_figure_size = 12

# Plot non-normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix,
                      default_fontsize,
                      x_figure_size,
                      y_figure_size,
                      classes=["No Interaction", "Interaction"],
                      title='Confusion matrix, without normalization')

"""
# Plot normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix,
                      default_fontsize,
                      x_figure_size,
                      y_figure_size,
                      classes=["No Interaction", "Interaction"], normalize=True,
                      title='Normalized confusion matrix')
"""
plt.show()

In [None]:
# papers of interest
# https://arxiv.org/pdf/1801.07829.pdf 