In [None]:
import os
import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score, \
                            precision_recall_curve, roc_curve, balanced_accuracy_score, \
                            precision_score, f1_score, recall_score, ConfusionMatrixDisplay

In [None]:
# Read the pickled object from the working directory and bind it to
# `model_metrics`. Unpickling can execute arbitrary code, so this file must
# come from a trusted source.
with open('2025-05-18_threshold_fpr_at_90_recall_ahr_mmp_models.pkl', mode='rb') as metrics_file:
    model_metrics = pickle.load(metrics_file)

In [None]:
# Show the loaded object as the cell output for a quick visual check.
model_metrics

# AHR results

In [None]:
# Read the pickled AHR predictions from the working directory into
# `ahr_results`. Unpickling can execute arbitrary code, so this file must come
# from a trusted source.
with open('final_predictions_experimental_ahr.pkl', mode='rb') as ahr_file:
    ahr_results = pickle.load(ahr_file)

In [None]:
# Reference labels that every AHR metric in this notebook is scored against.
ahr_true = ahr_results.loc[:, 'AHR_true_label']

## XGB

In [None]:
# Preview the first rows of the AHR predictions table.
ahr_results.head()

In [None]:
# --- AHR / XGB: evaluation metrics ---
ahr_xgb_proba = ahr_results['AHR_XGB_proba']

# Stored decision threshold for this endpoint/model pair (first matching row).
# Presumably the threshold that gave 90% recall when the model was tuned — TODO confirm.
is_ahr_xgb = (model_metrics['endpoint'] == 'AHR') & (model_metrics['model'] == 'XGB')
threshold_ahr_xgb = model_metrics.loc[is_ahr_xgb, 'threshold_90_recall'].values[0]

# Threshold-independent metrics from the predicted probabilities; the full
# ROC curve is kept (no intermediate points dropped) for plotting later.
roc_auc = roc_auc_score(ahr_true, ahr_xgb_proba)
fpr, tpr, thresh = roc_curve(ahr_true, ahr_xgb_proba, drop_intermediate=False)

# Binarise the probabilities at the stored threshold and count outcomes.
ahr_xgb_pred = (ahr_xgb_proba >= threshold_ahr_xgb).astype(int)
cm = confusion_matrix(ahr_true, ahr_xgb_pred)
tn, fp, fn, tp = cm.ravel()

# False-positive rate at that threshold: FP / (FP + TN).
fpr_90_recall = fp / (fp + tn)

# NOTE(review): balanced accuracy is scored on the stored `AHR_XGB` label
# column, not on the thresholded predictions used for `cm` above — confirm
# that both are meant to describe the same operating point.
bal_acc = balanced_accuracy_score(ahr_true, ahr_results['AHR_XGB'])

In [None]:
# Start the summary table with the AHR/XGB row; the following model sections
# append one row each.
ahr_xgb_row = {
    'endpoint': 'AHR',
    'model': 'XGB',
    'fpr_90_recall_evaluation': fpr_90_recall,
    'roc_auc': roc_auc,
    'balanced_accuracy': bal_acc,
}
metrics_evaluation_df = pd.DataFrame([ahr_xgb_row])

# Keep the ROC curve points of this model for the combined ROC figure.
roc_ahr_xgb = pd.DataFrame(dict(fpr=fpr, tpr=tpr, thresh=thresh))

In [None]:
# Inspect the AHR/XGB ROC table (columns: fpr, tpr, thresh).
roc_ahr_xgb

In [None]:
# Wrap the current confusion matrix `cm` for plotting; the display labels are
# given in the matrix's row/column order.
class_names = ['Inactive', 'Active']
cm_ahr_xgb = ConfusionMatrixDisplay(cm, display_labels=class_names)
cm_ahr_xgb.plot(cmap='Blues')

## RF

In [None]:
# Preview the first rows of the AHR predictions table again before the RF section.
ahr_results.head()

In [None]:
# --- AHR / RF: evaluation metrics ---
ahr_rf_proba = ahr_results['AHR_RF_proba']
ahr_rf_pred = ahr_results['AHR_RF']

# Threshold-independent metrics from the predicted probabilities; the full
# ROC curve is kept (no intermediate points dropped) for plotting later.
roc_auc = roc_auc_score(ahr_true, ahr_rf_proba)
fpr, tpr, thresh = roc_curve(ahr_true, ahr_rf_proba, drop_intermediate=False)

# NOTE(review): unlike the XGB section, no threshold from `model_metrics` is
# applied here; the stored `AHR_RF` labels are used as-is. `fpr_90_recall` is
# the FPR at 90% recall only if that column was produced with the matching
# threshold — confirm.
cm = confusion_matrix(ahr_true, ahr_rf_pred)
tn, fp, fn, tp = cm.ravel()

# False-positive rate of the stored predictions: FP / (FP + TN).
fpr_90_recall = fp / (fp + tn)

bal_acc = balanced_accuracy_score(ahr_true, ahr_rf_pred)

In [None]:
# One-row frame with the AHR/RF metrics.
ahr_rf_row = {
    'endpoint': 'AHR',
    'model': 'RF',
    'fpr_90_recall_evaluation': fpr_90_recall,
    'roc_auc': roc_auc,
    'balanced_accuracy': bal_acc,
}
metrics_evaluation = pd.DataFrame([ahr_rf_row])

# Append it to the running summary table. Re-running this cell appends the
# same row again, because the table is concatenated onto itself.
metrics_evaluation_df = pd.concat(
    [metrics_evaluation_df, metrics_evaluation],
    ignore_index=True,
)

# Keep the ROC curve points of this model for the combined ROC figure.
roc_ahr_rf = pd.DataFrame(dict(fpr=fpr, tpr=tpr, thresh=thresh))

In [None]:
# Wrap the current confusion matrix `cm` for plotting; the display labels are
# given in the matrix's row/column order.
class_names = ['Inactive', 'Active']
cm_ahr_rf = ConfusionMatrixDisplay(cm, display_labels=class_names)
cm_ahr_rf.plot(cmap='Blues')

# MMP results

In [None]:
# Read the pickled MMP predictions from the working directory into
# `mmp_results`. Unpickling can execute arbitrary code, so this file must come
# from a trusted source.
with open('final_predictions_experimental_mmp.pkl', mode='rb') as mmp_file:
    mmp_results = pickle.load(mmp_file)

In [None]:
# Show the loaded MMP predictions as the cell output for a quick visual check.
mmp_results

In [None]:
# Reference labels that every MMP metric in this notebook is scored against.
mmp_true = mmp_results.loc[:, 'MMP_true_label']

## XGB

In [None]:
# Show the stored `threshold_90_recall` value(s) for the MMP/XGB row(s).
is_mmp_xgb = (model_metrics['endpoint'] == 'MMP') & (model_metrics['model'] == 'XGB')
model_metrics.loc[is_mmp_xgb, 'threshold_90_recall']

In [None]:
# Stored decision threshold for MMP/XGB (first matching row of `model_metrics`).
mmp_xgb_rows = model_metrics.loc[
    (model_metrics['endpoint'] == 'MMP') & (model_metrics['model'] == 'XGB')
]
threshold = mmp_xgb_rows['threshold_90_recall'].values[0]

In [None]:
# --- MMP / XGB: evaluation metrics ---
mmp_xgb_proba = mmp_results['MMP_XGB_proba']

# Threshold-independent metrics from the predicted probabilities; the full
# ROC curve is kept (no intermediate points dropped) for plotting later.
roc_auc = roc_auc_score(mmp_true, mmp_xgb_proba)
fpr, tpr, thresh = roc_curve(mmp_true, mmp_xgb_proba, drop_intermediate=False)

# Binarise the probabilities at the previously looked-up `threshold` and
# count outcomes.
mmp_xgb_pred = (mmp_xgb_proba >= threshold).astype(int)
cm = confusion_matrix(mmp_true, mmp_xgb_pred)
tn, fp, fn, tp = cm.ravel()

# False-positive rate at that threshold: FP / (FP + TN).
fpr_90_recall = fp / (fp + tn)

# NOTE(review): balanced accuracy is scored on the stored `MMP_XGB` label
# column, not on the thresholded predictions used for `cm` above — confirm
# that both are meant to describe the same operating point.
bal_acc = balanced_accuracy_score(mmp_true, mmp_results['MMP_XGB'])

In [None]:
# One-row frame with the MMP/XGB metrics.
mmp_xgb_row = {
    'endpoint': 'MMP',
    'model': 'XGB',
    'fpr_90_recall_evaluation': fpr_90_recall,
    'roc_auc': roc_auc,
    'balanced_accuracy': bal_acc,
}
metrics_evaluation = pd.DataFrame([mmp_xgb_row])

# Append it to the running summary table. Re-running this cell appends the
# same row again, because the table is concatenated onto itself.
metrics_evaluation_df = pd.concat(
    [metrics_evaluation_df, metrics_evaluation],
    ignore_index=True,
)

# Keep the ROC curve points of this model for the combined ROC figure.
roc_mmp_xgb = pd.DataFrame(dict(fpr=fpr, tpr=tpr, thresh=thresh))

In [None]:
# Wrap the current confusion matrix `cm` for plotting; the display labels are
# given in the matrix's row/column order.
class_names = ['Inactive', 'Active']
cm_mmp_xgb = ConfusionMatrixDisplay(cm, display_labels=class_names)
cm_mmp_xgb.plot(cmap='Blues')

## RF

In [None]:
# --- MMP / RF: evaluation metrics ---
mmp_rf_proba = mmp_results['MMP_RF_proba']
mmp_rf_pred = mmp_results['MMP_RF']

# Threshold-independent metrics from the predicted probabilities; the full
# ROC curve is kept (no intermediate points dropped) for plotting later.
roc_auc = roc_auc_score(mmp_true, mmp_rf_proba)
fpr, tpr, thresh = roc_curve(mmp_true, mmp_rf_proba, drop_intermediate=False)

# NOTE(review): unlike the XGB section, no threshold from `model_metrics` is
# applied here; the stored `MMP_RF` labels are used as-is. `fpr_90_recall` is
# the FPR at 90% recall only if that column was produced with the matching
# threshold — confirm.
cm = confusion_matrix(mmp_true, mmp_rf_pred)
tn, fp, fn, tp = cm.ravel()

# False-positive rate of the stored predictions: FP / (FP + TN).
fpr_90_recall = fp / (fp + tn)

bal_acc = balanced_accuracy_score(mmp_true, mmp_rf_pred)

In [None]:
# Show the four counts most recently unpacked from `cm.ravel()` as a tuple.
tn, fp, fn, tp

In [None]:
# Show the most recently computed false-positive rate, fp / (fp + tn).
fpr_90_recall

In [None]:
# One-row frame with the MMP/RF metrics.
mmp_rf_row = {
    'endpoint': 'MMP',
    'model': 'RF',
    'fpr_90_recall_evaluation': fpr_90_recall,
    'roc_auc': roc_auc,
    'balanced_accuracy': bal_acc,
}
metrics_evaluation = pd.DataFrame([mmp_rf_row])

# Append it to the running summary table. Re-running this cell appends the
# same row again, because the table is concatenated onto itself.
metrics_evaluation_df = pd.concat(
    [metrics_evaluation_df, metrics_evaluation],
    ignore_index=True,
)

# Keep the ROC curve points of this model for the combined ROC figure.
roc_mmp_rf = pd.DataFrame(dict(fpr=fpr, tpr=tpr, thresh=thresh))

In [None]:
# Wrap the current confusion matrix `cm` for plotting; the display labels are
# given in the matrix's row/column order.
class_names = ['Inactive', 'Active']
cm_mmp_rf = ConfusionMatrixDisplay(cm, display_labels=class_names)
cm_mmp_rf.plot(cmap='Blues')

# Visualization

In [None]:
# Confusion matrices of all four endpoint/model pairs in one 2x2 figure,
# saved as a PDF.

# Global matplotlib style; applies to every figure created from here on.
plt.rcParams.update({
    'figure.figsize': [9.8, 9.8],
    'font.size': 16,
    'font.weight': 'normal',
    'axes.titlesize': 12,
    'axes.labelsize': 12,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
    'legend.title_fontsize': 12,
    'axes.titleweight': 'bold',
    'font.family': 'serif',
    'font.serif': ['Times New Roman'],
    'figure.dpi': 300,
})

# Output directory. Defaults to the original thesis folder, but can be
# redirected with the THESIS_FIGURE_DIR environment variable so the notebook
# also runs on machines where that absolute path does not exist.
FIGURE_DIR = Path(os.environ.get(
    'THESIS_FIGURE_DIR',
    '/Users/elli/Library/CloudStorage/OneDrive-Kruvelab/Master_thesis/Visualizations',
))

fig, ax = plt.subplots(2, 2, figsize=(10, 9.5), dpi=300)

# (display object, target axes, colormap, panel title):
# AhR models on the top row in blue, MMP models on the bottom row in orange.
cm_panels = [
    (cm_ahr_rf, ax[0, 0], 'Blues', 'AhR RF'),
    (cm_ahr_xgb, ax[0, 1], 'Blues', 'AhR XGB'),
    (cm_mmp_rf, ax[1, 0], 'Oranges', 'MMP RF'),
    (cm_mmp_xgb, ax[1, 1], 'Oranges', 'MMP XGB'),
]
for cm_display, panel_ax, panel_cmap, panel_title in cm_panels:
    # Re-draw each stored confusion matrix onto its panel, without a colorbar.
    cm_display.plot(ax=panel_ax, cmap=panel_cmap, colorbar=False)
    cm_display.ax_.set_title(panel_title)

fig.tight_layout()

# Optional figure title (disabled):
#fig.suptitle('Evaluation set\nConfusion matrices for AhR and MMP models at TPR=90%', fontsize=16, fontweight='bold', y=1.03)

fig.savefig(FIGURE_DIR / '2025-05-26_confusion_matrix_evaluation.pdf', dpi=300, bbox_inches='tight')


## ROC curves

In [None]:
# Combine the four ROC tables side by side. `keys` labels each group of
# fpr/tpr/thresh columns with its endpoint/model, so the columns stay
# distinguishable; without it the result carries four identically named
# column triples. Tables of different length are aligned on the row index and
# padded with NaN.
roc_data = pd.concat(
    [roc_ahr_rf, roc_ahr_xgb, roc_mmp_rf, roc_mmp_xgb],
    axis=1,
    keys=['AHR_RF', 'AHR_XGB', 'MMP_RF', 'MMP_XGB'],
)

In [None]:
# ROC curves: one panel per endpoint, one line per model, saved as a PDF.

# Output directory. Defaults to the original thesis folder, but can be
# redirected with the THESIS_FIGURE_DIR environment variable so the notebook
# also runs on machines where that absolute path does not exist.
FIGURE_DIR = Path(os.environ.get(
    'THESIS_FIGURE_DIR',
    '/Users/elli/Library/CloudStorage/OneDrive-Kruvelab/Master_thesis/Visualizations',
))

fig, ax = plt.subplots(1, 2, figsize=(8, 5), dpi=300, tight_layout=True, sharex=True, sharey=True)

# (axes, panel title, [(ROC table, legend label, line colour), ...])
roc_panels = [
    (ax[0], 'AhR models', [(roc_ahr_rf, 'AhR RF', '#219EBC'),
                           (roc_ahr_xgb, 'AhR XGB', '#023047')]),
    (ax[1], 'MMP models', [(roc_mmp_rf, 'MMP RF', '#FFB703'),
                           (roc_mmp_xgb, 'MMP XGB', '#FB8500')]),
]
for panel_ax, panel_title, curves in roc_panels:
    for roc_table, curve_label, curve_color in curves:
        panel_ax.plot(roc_table['fpr'], roc_table['tpr'], label=curve_label, color=curve_color)
    panel_ax.grid(visible=True, which='both', linewidth=0.5, alpha=0.5)
    panel_ax.set_title(panel_title)
    panel_ax.legend(loc='lower right')
    # Equal scaling so FPR and TPR share the same unit length.
    panel_ax.set_aspect('equal')

# Optional figure title (disabled):
#fig.suptitle('Evaluation set\nROC-AUC for AhR and MMP models', fontweight='bold', fontsize=16, y = 0.93)

# Shared axis labels, positioned manually relative to the figure.
fig.supxlabel('False Positive Rate', y=0.12, fontsize=12)
fig.supylabel('True Positive Rate', x=0.05, fontsize=12)

fig.savefig(FIGURE_DIR / '2025-05-26_roc_curves_ahr_mmp_evaluation.pdf', dpi=300, bbox_inches='tight',
            transparent=True)

# Metrics evaluation

In [None]:
# Show the accumulated summary table (one row per endpoint/model section run above).
metrics_evaluation_df