## Script to plot the ROC curve (with AUC) and compute the confusion matrix for external validation on AREDS

In [None]:
import pandas as pd
import numpy as np
from sklearn import metrics
import matplotlib.pyplot as plt

In [None]:
# Load the MI predictions and ground-truth labels shared by the NIH collaborators.
preds_labels = pd.read_csv(filepath_or_buffer='./MI_preds_051321.csv')

In [None]:
# Show the column Index of the loaded table (DataFrame.keys() returns .columns).
preds_labels.keys()

In [None]:
# Total of the MI_event column, i.e. the number of rows flagged with an MI event
# (assuming the column holds 0/1 values, as the filtering cell below expects).
preds_labels['MI_event'].to_numpy().sum()

In [None]:
# Number of rows (subjects/eyes) in the loaded table.
preds_labels.shape[0]

In [None]:
# AMD severity grades encoded in `right_amd_scale`:
#   0 - no AMD
#   1 - not a grade 2 or 3, and has medium drusen (63-124 μm)
#   2 - not a grade 3, and either (a) has large drusen (>= 125 μm) with or
#       without PA, or (b) has medium drusen (63-124 μm) with PA
#       (PA presumably = pigmentary abnormalities; confirm against the AREDS spec)
#   3 - the eye has non-central/central geographic atrophy (GA) and/or
#       neovascular age-related macular degeneration (AMD)
# Show which grade values actually occur in the data.
preds_labels['right_amd_scale'].unique()

In [None]:
# Split the cohort by outcome and keep only eyes WITHOUT an AMD grade of 1, 2 or 3,
# so the evaluation is restricted to eyes without AMD.
# A single boolean mask replaces the former repeated `drop(index)` calls, which
# would also remove unrelated rows if the index ever contained duplicate labels.
# NOTE(review): rows whose grade is anything other than 1-3 (0, missing, or any
# other code) are kept. If only grade 0 should count as "no AMD", filter on
# `== 0` instead. Confirm the intended cohort.
_has_amd_grade = preds_labels['right_amd_scale'].isin([1, 2, 3])

# MI-positive eyes without AMD.
pos_labels = preds_labels.loc[(preds_labels['MI_event'] == 1) & ~_has_amd_grade]
# MI-negative eyes without AMD.
neg_labels = preds_labels.loc[(preds_labels['MI_event'] == 0) & ~_has_amd_grade]


In [None]:
# Display the retained MI-positive rows (the notebook renders a cell's last expression).
pos_labels

In [None]:
def plot(fpr=None, tpr=None, auc=None):
    """Draw the ROC curve with its AUC, save it as PNG and PDF, then show it.

    Parameters
    ----------
    fpr, tpr : array-like, optional
        False- and true-positive rates, e.g. as returned by
        ``metrics.roc_curve``. If omitted, the module-level ``fpr`` / ``tpr``
        variables are used, so the original zero-argument ``plot()`` call
        keeps working.
    auc : float, optional
        Area under the curve. It is shown in the legend and embedded in the
        output file names. Defaults to the module-level ``auc`` variable.

    Side effects
    ------------
    Writes ``AUC_MI_EXTERNAL_<auc rounded to 2 dp>.png`` and ``.pdf`` to the
    current working directory and displays the figure.
    """
    # Fall back to the notebook globals for backward compatibility.
    namespace = globals()
    if fpr is None:
        fpr = namespace['fpr']
    if tpr is None:
        tpr = namespace['tpr']
    if auc is None:
        auc = namespace['auc']
    auc_text = str(round(auc, 2))

    fig, ax = plt.subplots(figsize=(12, 10))
    ax.plot([0, 1], [0, 1], linestyle='--', lw=3, color='r', label='Chance', alpha=.9)
    # Draw on the explicit Axes rather than relying on pyplot's "current axes".
    ax.plot(fpr, tpr, lw=4, label="AUC=" + auc_text)
    ax.legend(loc="lower right", fontsize=22)
    ax.tick_params(labelsize=22)
    ax.set_ylabel('True Positive Rate', fontsize=22)
    ax.set_xlabel('False Positive Rate', fontsize=22)
    fig.savefig('AUC_MI_EXTERNAL_' + auc_text + '.png')
    fig.savefig('AUC_MI_EXTERNAL_' + auc_text + '.pdf')
    plt.show()
    # Close this specific figure so repeated calls do not accumulate open figures.
    plt.close(fig)

In [None]:
# Helpers to print (and return) summary metrics derived from a binary confusion matrix.
def _safe_div(numerator, denominator):
    """Return ``numerator / denominator`` as a float, or NaN if the denominator is 0."""
    denominator = float(denominator)
    if denominator == 0:
        return float('nan')
    return float(numerator) / denominator


def confusion_metrics(conf_matrix):
    """Print and return standard metrics for a 2x2 binary confusion matrix.

    Parameters
    ----------
    conf_matrix : 2x2 array-like
        Indexed as ``conf_matrix[true_class][predicted_class]``, with class 0
        negative and class 1 positive. This is the layout produced by
        ``sklearn.metrics.confusion_matrix`` for labels ``[0, 1]``.

    Returns
    -------
    dict
        Unrounded floats under the keys ``accuracy``, ``misclassification``,
        ``sensitivity``, ``specificity``, ``precision`` and ``f1``. A metric
        whose denominator is zero is reported as NaN instead of raising.
    """
    # Unpack the four cells (rows = true class, columns = predicted class).
    TP = conf_matrix[1][1]
    TN = conf_matrix[0][0]
    FP = conf_matrix[0][1]
    FN = conf_matrix[1][0]
    print('True Positives:', TP)
    print('True Negatives:', TN)
    print('False Positives:', FP)
    print('False Negatives:', FN)

    accuracy = _safe_div(TP + TN, TP + TN + FP + FN)
    misclassification = 1 - accuracy
    sensitivity = _safe_div(TP, TP + FN)  # recall / true-positive rate
    specificity = _safe_div(TN, TN + FP)  # true-negative rate
    # Precision (positive predictive value) is TP / (TP + FP).
    # TN / (TN + FP) would be the specificity, not the precision.
    precision = _safe_div(TP, TP + FP)
    # F1 = 2PR / (P + R), written in count form. It equals 0 when TP == 0 but
    # FP + FN > 0, and is NaN only if there are no positives at all.
    f1 = _safe_div(2 * TP, 2 * TP + FP + FN)

    print(f'Accuracy: {round(accuracy, 2)}')
    print(f'Mis-Classification: {round(misclassification, 2)}')
    print(f'Sensitivity: {round(sensitivity, 2)}')
    print(f'Specificity: {round(specificity, 2)}')
    print(f'Precision: {round(precision, 2)}')
    print(f'f_1 Score: {round(f1, 2)}')

    return {
        'accuracy': accuracy,
        'misclassification': misclassification,
        'sensitivity': sensitivity,
        'specificity': specificity,
        'precision': precision,
        'f1': f1,
    }

In [None]:
# Evaluate the external-validation predictions on a subsampled cohort:
# all retained MI-positive rows plus a random subset of the retained MI-negative rows.
auc_pilot = 0  # not used in this cell; kept in case other cells reference it
# Decision threshold on the MI_pred1 score: a row is predicted positive when score > thrs.
thrs = 0.16
# Fixed seed so the negative subsample, and therefore every metric below, is reproducible.
SEED = 0

# NOTE(review): n=70 presumably balances the negatives against the positives; confirm.
# DataFrame.sample raises ValueError if fewer than 70 negative rows remain.
new_neg_labels = neg_labels.sample(n=70, random_state=SEED)

# NOTE: this rebinds preds_labels (previously the full CSV) to the evaluation subset.
# Re-run the loading cell before re-running any earlier cell that reads preds_labels.
preds_labels = pd.concat([new_neg_labels, pos_labels])

y_true = preds_labels['MI_event'].values
y_score = preds_labels['MI_pred1'].values

fpr, tpr, thresholds = metrics.roc_curve(y_true, y_score, pos_label=1)
# labels=[0, 1] pins the 2x2 [true][pred] layout that confusion_metrics indexes,
# even if one class happens to be absent from the subset.
cm = metrics.confusion_matrix(y_true, (y_score > thrs).astype(int), labels=[0, 1])

auc = metrics.auc(fpr, tpr)
confusion_metrics(cm)
print('-'*50)
# plot()