# Evaluation of models on corrected and non-corrected top1 molecular formula prediction fingerprints

In [None]:
import warnings
warnings.filterwarnings('ignore', category = RuntimeWarning)
warnings.filterwarnings('ignore', category = UserWarning)

import pandas as pd
import numpy as np
import glob
import os
import zipfile
import shutil
import pickle
import xgboost as xgb

In [None]:
# Functions for making predictions
def MonteCarlo_prediction(fp, model, n_samples=10000):
    """Estimate a model probability for a probabilistic fingerprint by sampling.

    Each entry of `fp` is used as the success probability of a Bernoulli draw,
    giving `n_samples` binary fingerprints. The model scores every sampled
    fingerprint and the scores are averaged.

    Parameters
    ----------
    fp : array-like of float
        Per-bit probabilities (must be valid binomial probabilities).
    model : object
        Anything exposing `predict_proba`; column index 1 of its output is used
        (presumably the positive/active class - confirm against the model).
    n_samples : int
        Number of binary fingerprints to draw.

    Returns
    -------
    float
        Mean of the column-1 probabilities over all sampled fingerprints.
    """
    draw_shape = (n_samples, len(fp))
    # One Bernoulli trial per bit and sample, taken from NumPy's global RNG
    sampled_bits = np.random.binomial(1, fp, size=draw_shape)
    class_one_scores = model.predict_proba(sampled_bits)[:, 1]
    return np.mean(class_one_scores)

def make_predictions(data, model):
    """Run the Monte Carlo prediction for every row of `data`.

    The feature columns are taken from `model.feature_names_in_`. `data` must
    also contain 'id' and 'formula' columns (they are selected, so a missing
    one raises a KeyError), although only the feature columns are scored.

    NOTE: a three-argument `make_predictions` defined later in this file
    replaces this one once that cell has been executed.

    Returns
    -------
    pandas.Series
        One averaged probability per row, indexed like `data`.
    """
    feature_cols = model.feature_names_in_
    wanted_cols = np.concatenate([['id', 'formula'], feature_cols])
    subset = data[wanted_cols]

    def _score_row(row):
        return MonteCarlo_prediction(np.array(row), model=model)

    return subset[feature_cols].apply(_score_row, axis=1)

In [None]:
# Load models
def _load_pickle(path):
    """Return the object pickled at `path`, closing the file handle afterwards."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# NOTE: unpickling can execute arbitrary code - only load files from a trusted source.
ahr_rf_model = _load_pickle('/Users/elli/Library/CloudStorage/OneDrive-Kruvelab/Master_thesis/Code/Model_training/AHR/RF/2025-05-14_RF_feat_var_0.9_ahr_CORRECT_FINAL_correct_features.pkl')
ahr_xgb_model = _load_pickle('/Users/elli/Library/CloudStorage/OneDrive-Kruvelab/Master_thesis/Code/Model_training/AHR/XGBoost/2025-05-14_XGBoost_feat_var_0.9_ahr_CORRECT_FINAL_correct_features.pkl')

mmp_rf_model = _load_pickle('/Users/elli/Library/CloudStorage/OneDrive-Kruvelab/Master_thesis/Code/Model_training/MMP/RF/2025-05-14_RF_feat_var_0.9_mmp_CORRECT_FINAL_correct_features.pkl')
mmp_xgb_model = _load_pickle('/Users/elli/Library/CloudStorage/OneDrive-Kruvelab/Master_thesis/Code/Model_training/MMP/XGBoost/2025-05-14_XGBoost_feat_var_0.9_mmp_CORRECT_FINAL_correct_features.pkl')

# Load feature names (one pickled collection per endpoint)
ahr_features = _load_pickle('/Users/elli/Library/CloudStorage/OneDrive-Kruvelab/Master_thesis/Code/Model_training/AHR/ahr_features.pkl')
mmp_features = _load_pickle('/Users/elli/Library/CloudStorage/OneDrive-Kruvelab/Master_thesis/Code/Model_training/MMP/mmp_features.pkl')

# Prediction of top 1 formula fingerprints WITHOUT mass correction

## Get all fingerprints for top 1 molecular formula

Code by Ida Rahu

In [None]:
# SIRIUS output directories of the runs WITHOUT mass correction, one per data source
output_folders = [
    '/Users/elli/Library/CloudStorage/OneDrive-Kruvelab/Master_thesis/Data/MSMS/SIRIUS_output/output_no_mass_correction/ellinor_data',
    '/Users/elli/Library/CloudStorage/OneDrive-Kruvelab/Master_thesis/Data/MSMS/SIRIUS_output/output_no_mass_correction/iris_data_dry',
    '/Users/elli/Library/CloudStorage/OneDrive-Kruvelab/Master_thesis/Data/MSMS/SIRIUS_output/output_no_mass_correction/isabell_data',
    '/Users/elli/Library/CloudStorage/OneDrive-Kruvelab/Master_thesis/Data/MSMS/SIRIUS_output/output_no_mass_correction/library_data',
]

# Individual names are kept because other cells refer to them directly
ellinor, iris, isabell, library = output_folders

In [None]:
def _to_numeric_if_possible(column):
    """Return `column` converted with pd.to_numeric, or unchanged if that fails."""
    try:
        return pd.to_numeric(column)
    except (ValueError, TypeError):
        return column


def get_rank1_fingerprints(output_folders, esi_mode='pos'):
    """Collect the predicted fingerprint of the rank-1 formula of every MS feature.

    Every `<folder>/*/formula_candidates.tsv` in each of `output_folders` is
    read, the candidate with `formulaRank == 1` is selected and its fingerprint
    file `fingerprints/<formula>_<adduct>.fpt` is loaded.

    Parameters
    ----------
    output_folders : str or iterable of str
        SIRIUS+CSI:FingerID output folder(s). The fingerprint definition table
        (`csi_fingerid.tsv`, or `csi_fingerid_neg.tsv` when `esi_mode` is not
        'pos') is read from the first folder.
    esi_mode : str
        'pos' selects the positive-mode definition table; any other value
        selects the negative-mode one.

    Returns
    -------
    fp_data : pandas.DataFrame
        Columns 'id', 'formula', 'adduct' followed by one column per
        `relativeIndex` of the definition table (integer column labels).
    without_fp : list of str
        Ids of features whose fingerprint file was missing, unreadable, or of
        the wrong length.

    Raises
    ------
    ValueError
        If `output_folders` is empty.
    """
    if isinstance(output_folders, (str, os.PathLike)):
        output_folders = [output_folders]
    output_folders = list(output_folders)
    if not output_folders:
        raise ValueError('output_folders must contain at least one SIRIUS output folder')

    index_file = 'csi_fingerid.tsv' if esi_mode == 'pos' else 'csi_fingerid_neg.tsv'
    fp_info = pd.read_csv(f'{output_folders[0]}/{index_file}', sep='\t')

    # Column layout of the result: identifiers first, then one column per fingerprint bit
    columns = np.concatenate([['id', 'formula', 'adduct'], [idx for idx in fp_info.relativeIndex.values]], axis=0)
    n_bits = len(fp_info)

    rows = []        # Collected once and turned into a frame at the end (avoids quadratic appends)
    without_fp = []  # MS features without a usable predicted fingerprint

    for output_folder in output_folders:
        # Glob the folder of THIS iteration; sorted so the row order is reproducible
        for file_name in sorted(glob.glob(f'{output_folder}/*/formula_candidates.tsv')):
            candidates = pd.read_csv(file_name, sep='\t')
            rank1 = candidates[candidates.formulaRank == 1]
            rank1formula = rank1.molecularFormula.values[0]
            adduct = rank1.adduct.values[0].replace(' ', '')  # File names carry the adduct without spaces

            # Folder name is '<prefix>_<name>_<name>': drop the prefix, keep the first half
            feature_dir = os.path.dirname(file_name)
            name_parts = os.path.basename(feature_dir).split('_')[1:]
            feature_id = '_'.join(name_parts[:len(name_parts)//2])

            try:
                fp = pd.read_csv(f'{feature_dir}/fingerprints/{rank1formula}_{adduct}.fpt', header=None).T.values.flatten()
            except (OSError, pd.errors.EmptyDataError, pd.errors.ParserError):
                without_fp.append(feature_id)
                continue

            if len(fp) != n_bits:
                # A fingerprint that does not match the definition table cannot be placed in the frame
                without_fp.append(feature_id)
                continue

            rows.append([feature_id, rank1formula, adduct, *fp])

    fp_data = pd.DataFrame(rows, columns=columns)
    # Same effect as the deprecated pd.to_numeric(errors='ignore'): convert what can be converted
    fp_data = fp_data.apply(_to_numeric_if_possible)

    fp_data.columns = [int(col) if col.isnumeric() else col for col in fp_data.columns]

    return fp_data, without_fp

In [None]:
# Build the rank-1 fingerprint table for positive ESI mode;
# without_fp receives the ids whose fingerprint could not be read.
fp_data_non_corrected, without_fp = get_rank1_fingerprints(output_folders, esi_mode='pos')

In [None]:
# Sort rows by the 'id' column, modifying the frame in place
fp_data_non_corrected.sort_values(by=['id'], inplace=True)

In [None]:
# Show the sorted fingerprint table as the cell output
fp_data_non_corrected

In [None]:
# Remove rows that are identical in every column.
# NOTE(review): the index is not reset here, unlike the corrected-mass table further down.
fp_data_non_corrected=fp_data_non_corrected.drop_duplicates()

In [None]:
# Functions for making predictions
def MonteCarlo_prediction(fp, model, n_samples=10000, rng=None):
    """Estimate a model probability for a probabilistic fingerprint by sampling.

    `n_samples` binary fingerprints are drawn with each entry of `fp` as the
    per-bit success probability; the model scores every draw and the scores
    are averaged.

    Parameters
    ----------
    fp : array-like of float
        Per-bit probabilities (must be valid binomial probabilities).
    model : object
        Anything exposing `predict_proba`; column index 1 of its output is used
        (presumably the positive/active class - confirm against the model).
    n_samples : int
        Number of binary fingerprints to draw.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Random source for the draws. Pass a seeded generator to make the
        estimate reproducible. The default (None) keeps the previous behaviour
        of drawing from NumPy's global random state.

    Returns
    -------
    float
        Mean of the column-1 probabilities over all sampled fingerprints.
    """
    sampler = np.random if rng is None else rng
    binary_samples = sampler.binomial(1, fp, size=(n_samples, len(fp)))
    predicted_probabilities = model.predict_proba(binary_samples)[:, 1]
    return np.mean(predicted_probabilities)

def make_predictions(data, model, feature_names):
    """Run the Monte Carlo prediction for every row of `data`.

    Parameters
    ----------
    data : pandas.DataFrame
        Fingerprint probabilities, one row per feature.
    model : object
        Forwarded to `MonteCarlo_prediction`.
    feature_names : iterable of int
        Used as column POSITIONS (`iloc`), not labels.
        NOTE(review): this only picks the intended bits if column position and
        fingerprint index coincide in `data` - confirm against the caller.

    Returns
    -------
    pandas.Series
        One averaged probability per row, indexed like `data`.
    """
    column_positions = list(feature_names)
    selected = data.iloc[:, column_positions]

    def _score_row(row):
        return MonteCarlo_prediction(np.array(row), model=model)

    return selected.apply(_score_row, axis=1)

## Make model predictions

In [None]:
# Models and the result-column name each one writes to (same order in both lists)
models = [ahr_rf_model, ahr_xgb_model, mmp_rf_model, mmp_xgb_model]
model_name = ['ahr_rf_pred', 'ahr_xgb_pred', 'mmp_rf_pred', 'mmp_xgb_pred']

# Start from the identifying columns; one prediction column per model is added below
predicted_results_uncorrected = fp_data_non_corrected[['id', 'formula', 'adduct']].copy()

# Everything after the three identifier columns is fingerprint data
uncorrected_fingerprint_columns = fp_data_non_corrected.iloc[:, 3:]

for model, name in zip(models, model_name):
    # AhR models use the AhR feature set, the remaining (MMP) models the MMP one
    feature_names = ahr_features if model in (ahr_rf_model, ahr_xgb_model) else mmp_features
    predicted_results_uncorrected[name] = make_predictions(uncorrected_fingerprint_columns, model, feature_names)

In [None]:
# Show the predictions for the uncorrected fingerprints as the cell output
predicted_results_uncorrected

In [None]:
# with open('2025-05-21_Model_evaluation_top1MF_pred_uncorrected_mass.pkl', 'wb') as f:
#     pickle.dump(predicted_results_uncorrected, f)

In [None]:
# with open('/Users/elli/Library/CloudStorage/OneDrive-Kruvelab/Master_thesis/Data/MSMS/SIRIUS_output/output_no_mass_correction/2025-05-21_Rank1_fingerprints.pkl', 'wb') as f:
#     pickle.dump(fp_data_non_corrected, f)

## Get the corresponding label to the names

### Clean names in the 'id' column

In [None]:
import pickle
import pandas as pd

In [None]:
# Reload the stored predictions from the working directory.
# NOTE: the cell that writes this file is commented out above, so the file must already exist.
with open('2025-05-21_Model_evaluation_top1MF_pred_uncorrected_mass.pkl', 'rb') as f:
    predicted_results_uncorrected = pickle.load(f)

In [None]:
# Translation tables are built once instead of once per row
_uncorrected_punctuation = str.maketrans('', '', '()[]{}<>,.+ :\'\"')
_uncorrected_separators = str.maketrans('', '', '_-')


def _clean_uncorrected_id(raw_id):
    """Normalise one feature id so it can be matched against the ground-truth ids.

    Steps: strip punctuation and lower-case, split on '_', drop a trailing
    'combined' token, drop a trailing 'e' / 'h' / all-digit token, then remove
    every remaining '_' and '-'.
    """
    tokens = raw_id.translate(_uncorrected_punctuation).lower().split('_')
    if tokens[-1] == 'combined':
        tokens = tokens[0:-1]
    last_token = tokens[-1]
    if last_token == 'e' or last_token == 'h' or last_token.isnumeric():
        tokens = tokens[0:-1]
    return '_'.join(tokens).translate(_uncorrected_separators)


predicted_results_uncorrected.id = predicted_results_uncorrected.id.apply(_clean_uncorrected_id)

In [None]:
# Sort by the cleaned 'id', modifying the frame in place
predicted_results_uncorrected.sort_values(by=['id'], inplace=True)

In [None]:
# Show the predictions with cleaned ids as the cell output
predicted_results_uncorrected

### Clean the ID column in 'ground truth' and merge with the Tox21 dataset

In [None]:
# Load the pickled ground-truth table (its 'InChIKey14', 'id' and 'molecular_formula' columns are used below)
with open('/Users/elli/Library/CloudStorage/OneDrive-Kruvelab/Master_thesis/Data/MSMS/SIRIUS_output/2025-05-13_ground_truth_molecular_formula_no_sirius.pkl', 'rb') as f:
     ground_truth = pickle.load(f)

In [None]:
# Remove every '_' and '-' from the ground-truth ids - the same final step that is
# applied to the predicted ids, so both sides of the later merge use the same form
ground_truth.id = ground_truth.id.apply(lambda x: x.translate(str.maketrans('', '', '_-')))

In [None]:
# Show the ground-truth table as the cell output
ground_truth

In [None]:
# Load the pickled Tox21 table (its 'nr.ahr' and 'sr.mmp' columns are used as labels later)
with open('/Users/elli/Library/CloudStorage/OneDrive-Kruvelab/Master_thesis/Code/Tox21 comparison/2025-03-06_tox21_ahr_mmp_available_compounds_all_sources_UPDATED.pkl','rb') as f:
    tox21 = pickle.load(f)

In [None]:
# Keep a single row per InChIKey14 (the first occurrence)
tox21 = tox21.drop_duplicates(subset='InChIKey14', keep='first')
# Keep only rows whose 'sirius_data' entry is missing
# NOTE(review): presumably this excludes compounds flagged as part of the SIRIUS data - confirm
tox21 = tox21[tox21.sirius_data.isna()]

In [None]:
# Left join: every Tox21 row is kept and gets the id / molecular formula of the
# ground-truth entry with the same InChIKey14 (NaN where there is none)
tox21_gt = tox21.merge(ground_truth[['InChIKey14', 'id', 'molecular_formula']], how='left', on='InChIKey14')

In [None]:
# Show the merged Tox21 / ground-truth table as the cell output
tox21_gt

### Merge the fp data with ground truth

In [None]:
# Left join on the cleaned 'id': every tox21_gt row is kept; rows without a
# matching prediction get NaN in the prediction columns
fp_data_uncorrected_with_gt = tox21_gt.merge(predicted_results_uncorrected, on='id', how='left')

In [None]:
# Show the merged table as the cell output
fp_data_uncorrected_with_gt

In [None]:
# Drop rows that are duplicated across all columns, keeping the first occurrence
fp_data_uncorrected_with_gt = fp_data_uncorrected_with_gt.drop_duplicates(subset=fp_data_uncorrected_with_gt.columns, keep='first')

In [None]:
# Row / column count after de-duplication
fp_data_uncorrected_with_gt.shape

In [None]:
# Keep only rows that received a prediction ('formula' comes from the prediction table)
fp_data_uncorrected_with_gt_final = fp_data_uncorrected_with_gt[fp_data_uncorrected_with_gt.formula.notna()]

# Remove Tox21 columns that are not used in the evaluation below
fp_data_uncorrected_with_gt_final = fp_data_uncorrected_with_gt_final.drop(columns=['iris_data', 'ms_library', 'isabel_data', 'sirius_data', 'old_klara_MMK', 'section_aces', 'section_kemikum','gc_probability'])

In [None]:
# Show the final uncorrected evaluation table as the cell output
fp_data_uncorrected_with_gt_final

In [None]:
# # save dataframe
# with open('/Users/elli/Library/CloudStorage/OneDrive-Kruvelab/Master_thesis/Data/MSMS/SIRIUS_output/output_no_mass_correction/2025-05-21_Rank1_fingerprints_with_ground_truth.pkl', 'wb') as f:
#     pickle.dump(fp_data_uncorrected_with_gt, f)

# Prediction of top1 formula fingerprints with mass correction

## Get all fingerprints for top 1 molecular formula

In [None]:
# SIRIUS output directories of the runs WITH mass correction, one per data source.
# These assignments replace the uncorrected folder variables of the same names.
output_folders = [
    '/Users/elli/Library/CloudStorage/OneDrive-Kruvelab/Master_thesis/Data/MSMS/SIRIUS_output/output_mass_correction_2025-05-13/ellinor_data',
    '/Users/elli/Library/CloudStorage/OneDrive-Kruvelab/Master_thesis/Data/MSMS/SIRIUS_output/output_mass_correction_2025-05-13/iris_data_dry',
    '/Users/elli/Library/CloudStorage/OneDrive-Kruvelab/Master_thesis/Data/MSMS/SIRIUS_output/output_mass_correction_2025-05-13/isabell_data',
    '/Users/elli/Library/CloudStorage/OneDrive-Kruvelab/Master_thesis/Data/MSMS/SIRIUS_output/output_mass_correction_2025-05-13/library_data',
]
ellinor, iris, isabell, library = output_folders

# Build the rank-1 fingerprint table for the mass-corrected runs (positive ESI mode)
fp_data_corrected, without_fp = get_rank1_fingerprints(output_folders, esi_mode='pos')
fp_data_corrected.shape

In [None]:
# Sort by 'id' in place, then drop fully identical rows and renumber the index from 0
fp_data_corrected.sort_values(by=['id'], inplace=True)
fp_data_corrected=fp_data_corrected.drop_duplicates().reset_index(drop=True)

In [None]:
# Show the corrected fingerprint table as the cell output
fp_data_corrected

In [None]:
# Start from the identifying columns; one prediction column per model is added below
predicted_results_corrected = fp_data_corrected[['id', 'formula', 'adduct']].copy()

# Everything after the three identifier columns is fingerprint data
corrected_fingerprint_columns = fp_data_corrected.iloc[:, 3:]

# `models` and `model_name` are the lists defined for the uncorrected predictions
for model, name in zip(models, model_name):
    # AhR models use the AhR feature set, the remaining (MMP) models the MMP one
    feature_names = ahr_features if model in (ahr_rf_model, ahr_xgb_model) else mmp_features
    predicted_results_corrected[name] = make_predictions(corrected_fingerprint_columns, model, feature_names)

In [None]:
# Show the predictions for the corrected fingerprints as the cell output
predicted_results_corrected

In [None]:
# with open('2025-05-21_Model_evaluation_top1MF_pred_corrected_mass.pkl', 'wb') as f:
#     pickle.dump(predicted_results_corrected, f)

## Connect to ground truth and get labels from tox21 dataset

### Clean fp_data

In [None]:
# Reload the stored predictions from the working directory.
# NOTE: the cell that writes this file is commented out above, so the file must already exist.
with open('2025-05-21_Model_evaluation_top1MF_pred_corrected_mass.pkl', 'rb') as f:
    predicted_results_corrected = pickle.load(f)

In [None]:
# Translation tables are built once instead of once per row
_corrected_punctuation = str.maketrans('', '', '()[]{}<>,.+ :\'\"')
_corrected_separators = str.maketrans('', '', '_-')


def _clean_corrected_id(raw_id):
    """Normalise one feature id so it can be matched against the ground-truth ids.

    Steps: strip punctuation and lower-case, drop a trailing '_combined'
    token, drop a trailing '-e' / '-h' / '-<digits>' token, then remove every
    remaining '_' and '-'.
    NOTE(review): the uncorrected ids are cleaned with the trailing e/h/digit
    token separated by '_' rather than '-' - confirm the difference is intended.
    """
    underscore_tokens = raw_id.translate(_corrected_punctuation).lower().split('_')
    if underscore_tokens[-1] == 'combined':
        underscore_tokens = underscore_tokens[0:-1]
    hyphen_tokens = '_'.join(underscore_tokens).split('-')
    last_token = hyphen_tokens[-1]
    if last_token == 'e' or last_token == 'h' or last_token.isnumeric():
        hyphen_tokens = hyphen_tokens[0:-1]
    return '-'.join(hyphen_tokens).translate(_corrected_separators)


predicted_results_corrected.id = predicted_results_corrected.id.apply(_clean_corrected_id)

In [None]:
# Show the first rows to check the cleaned ids
predicted_results_corrected.head()

### Merge with ground truth and tox21

In [None]:
# Left join on the cleaned 'id': every tox21_gt row is kept; rows without a
# matching prediction get NaN in the prediction columns
fp_corrected_with_gt = tox21_gt.merge(predicted_results_corrected, on='id', how='left')

# Row / column count directly after the merge
fp_corrected_with_gt.shape

In [None]:
# Drop rows that are duplicated across all columns, keeping the first occurrence
fp_corrected_with_gt = fp_corrected_with_gt.drop_duplicates(subset=fp_corrected_with_gt.columns, keep='first')
# Row / column count after de-duplication
fp_corrected_with_gt.shape

In [None]:
# Keep only rows that received a prediction ('formula' comes from the prediction table)
fp_corrected_with_gt_final = fp_corrected_with_gt[fp_corrected_with_gt.formula.notna()]
# Remove Tox21 columns that are not used in the evaluation below
fp_corrected_with_gt_final = fp_corrected_with_gt_final.drop(columns=['iris_data', 'ms_library', 'isabel_data', 'sirius_data', 'old_klara_MMK', 'section_aces', 'section_kemikum','gc_probability'])
# Show the final corrected evaluation table as the cell output
fp_corrected_with_gt_final

In [None]:
# with open('evaluation_set_corrected_mass.pkl', 'wb') as f:
#     pickle.dump(fp_corrected_with_gt_final, f)

# Evaluations and visualizations

## Uncorrected evaluation

In [None]:
import pandas as pd
import numpy as np
import pickle

from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score, \
                            precision_recall_curve, roc_curve, balanced_accuracy_score, \
                            precision_score, f1_score, recall_score, ConfusionMatrixDisplay

In [None]:
# Load the pickled table holding one 'threshold_90_recall' value per model
with open('2025-05-18_threshold_fpr_at_90_recall_ahr_mmp_models.pkl', 'rb') as f:
    model_metrics = pickle.load(f)

In [None]:
# Show the uncorrected evaluation table as the cell output
fp_data_uncorrected_with_gt_final

In [None]:
# One evaluation subset per endpoint: only rows that have a label for that endpoint
fp_data_uncorrected_with_gt_final_ahr = fp_data_uncorrected_with_gt_final[fp_data_uncorrected_with_gt_final['nr.ahr'].notna()]
fp_data_uncorrected_with_gt_final_mmp = fp_data_uncorrected_with_gt_final[fp_data_uncorrected_with_gt_final['sr.mmp'].notna()]

In [None]:
# Show the threshold table as the cell output
model_metrics

In [None]:
# Look up the threshold stored under row label 0
model_metrics.loc[0, 'threshold_90_recall']

In [None]:
def get_metrics(y_true, y_proba, threshold, model_name, endpoint):
    """Score predicted probabilities against binary labels.

    Parameters
    ----------
    y_true : array-like
        True labels.
    y_proba : array-like
        Predicted probabilities, compared element-wise with `threshold`.
    threshold : float
        Probabilities >= threshold are classified as 1, all others as 0.
    model_name, endpoint : str
        Copied into the 'model' and 'endpoint' columns of the metrics row.

    Returns
    -------
    metrics_df : pandas.DataFrame
        One row with 'endpoint', 'model', 'fpr_90_recall_evaluation' (false
        positive rate at `threshold`), 'roc_auc' and 'balanced_accuracy'.
    roc_auc_curve : pandas.DataFrame
        ROC curve points in columns 'fpr', 'tpr' and 'thresh'.
    cm_vis : ConfusionMatrixDisplay
        Unplotted confusion-matrix display labelled 'Inactive' / 'Active'.
    """
    # Threshold-independent part: area under the ROC curve and the curve itself
    roc_auc = roc_auc_score(y_true, y_proba)
    false_pos_rates, true_pos_rates, cutoffs = roc_curve(y_true, y_proba)
    roc_auc_curve = pd.DataFrame({'fpr': false_pos_rates,
                                  'tpr': true_pos_rates,
                                  'thresh': cutoffs})

    # Threshold-dependent part: binarise, then count outcomes
    y_pred = np.where(y_proba >= threshold, 1, 0)
    matrix = confusion_matrix(y_true, y_pred)
    true_neg, false_pos, _false_neg, _true_pos = matrix.ravel()

    # Share of true negatives that are predicted positive at this threshold
    false_positive_rate = false_pos/(false_pos + true_neg)
    balanced_acc = balanced_accuracy_score(y_true, y_pred)

    metrics_df = pd.DataFrame({'endpoint': [endpoint],
                               'model': [model_name],
                               'fpr_90_recall_evaluation': [false_positive_rate],
                               'roc_auc': [roc_auc],
                               'balanced_accuracy': [balanced_acc]})

    # Display object only; the caller decides where it is drawn
    cm_vis = ConfusionMatrixDisplay(confusion_matrix=matrix, display_labels=['Inactive', 'Active'])

    return metrics_df, roc_auc_curve, cm_vis

In [None]:
# Thresholds are looked up by row label in model_metrics.
# NOTE(review): row 1 is used for AhR RF and row 0 for AhR XGB - confirm that these
# labels match the row order of model_metrics.
ahr_rf_threshold = model_metrics.loc[1, 'threshold_90_recall']
ahr_xgb_threshold = model_metrics.loc[0, 'threshold_90_recall']
mmp_rf_threshold = model_metrics.loc[2, 'threshold_90_recall']
mmp_xgb_threshold = model_metrics.loc[3, 'threshold_90_recall']

# Short aliases for the two uncorrected evaluation subsets
uncorrected_ahr = fp_data_uncorrected_with_gt_final_ahr
uncorrected_mmp = fp_data_uncorrected_with_gt_final_mmp

ahr_rf_metrics, ahr_rf_roc_auc_curve, ahr_rf_cm_vis = get_metrics(
    uncorrected_ahr['nr.ahr'], uncorrected_ahr['ahr_rf_pred'], ahr_rf_threshold, 'RF', 'AHR')
ahr_xgb_metrics, ahr_xgb_roc_auc_curve, ahr_xgb_cm_vis = get_metrics(
    uncorrected_ahr['nr.ahr'], uncorrected_ahr['ahr_xgb_pred'], ahr_xgb_threshold, 'XGB', 'AHR')
mmp_rf_metrics, mmp_rf_roc_auc_curve, mmp_rf_cm_vis = get_metrics(
    uncorrected_mmp['sr.mmp'], uncorrected_mmp['mmp_rf_pred'], mmp_rf_threshold, 'RF', 'MMP')
mmp_xgb_metrics, mmp_xgb_roc_auc_curve, mmp_xgb_cm_vis = get_metrics(
    uncorrected_mmp['sr.mmp'], uncorrected_mmp['mmp_xgb_pred'], mmp_xgb_threshold, 'XGB', 'MMP')

In [None]:
# Stack the four single-row metric frames into one table with a fresh 0..3 index
uncorrected_evaluation_metrics = pd.concat([ahr_rf_metrics, ahr_xgb_metrics, mmp_rf_metrics, mmp_xgb_metrics], ignore_index=True)
uncorrected_evaluation_metrics

In [None]:
# Write the uncorrected metrics table to the working directory (overwrites an existing file)
with open('2025-05-22_evaluation_metrics_uncorrected_mass.pkl', 'wb') as f:
    pickle.dump(uncorrected_evaluation_metrics, f)

In [None]:
import matplotlib.pyplot as plt

# Global figure style; rcParams stay in effect for every later figure as well
figure_style = {
    'figure.figsize': [9.8, 9.8],
    'font.size': 16,
    'font.weight': 'normal',
    'axes.titlesize': 12,
    'axes.labelsize': 12,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
    'legend.title_fontsize': 12,
    'axes.titleweight': 'bold',
    'font.family': 'serif',
    'font.serif': ['Times New Roman'],
    'figure.dpi': 300,
}
plt.rcParams.update(figure_style)

fig, ax = plt.subplots(2, 2, figsize=(10, 9.5), dpi=300)

# (display object, target axis, colour map, panel title): AhR on the top row, MMP below
confusion_panels = [
    (ahr_rf_cm_vis, ax[0, 0], 'Blues', 'AhR RF'),
    (ahr_xgb_cm_vis, ax[0, 1], 'Blues', 'AhR XGB'),
    (mmp_rf_cm_vis, ax[1, 0], 'Oranges', 'MMP RF'),
    (mmp_xgb_cm_vis, ax[1, 1], 'Oranges', 'MMP XGB'),
]
for cm_display, panel_ax, panel_cmap, panel_title in confusion_panels:
    cm_display.plot(ax=panel_ax, cmap=panel_cmap, colorbar=False)
    cm_display.ax_.set_title(panel_title)

plt.tight_layout()

#plt.suptitle('Evaluation set\nConfusion matrices for AhR and MMP models at TPR=90%', fontsize=16, fontweight='bold', y=1.03)

plt.savefig('/Users/elli/Library/CloudStorage/OneDrive-Kruvelab/Master_thesis/Visualizations/2025-05-22_confusion_matrix_evaluation_uncorrected.pdf', dpi=300, bbox_inches='tight')

In [None]:
# ROC curves of all four models: AhR models in the left panel, MMP models in the right
fig, ax = plt.subplots(1, 2, figsize=(8, 5), dpi=300, tight_layout=True, sharex=True, sharey=True)

# (axis, panel title, [(curve frame, legend label, line colour), ...])
roc_panels = [
    (ax[0], 'AhR models', [(ahr_rf_roc_auc_curve, 'AhR RF', '#219EBC'),
                           (ahr_xgb_roc_auc_curve, 'AhR XGB', '#023047')]),
    (ax[1], 'MMP models', [(mmp_rf_roc_auc_curve, 'MMP RF', '#FFB703'),
                           (mmp_xgb_roc_auc_curve, 'MMP XGB', '#FB8500')]),
]
for roc_ax, roc_title, roc_curves in roc_panels:
    for curve_points, curve_label, curve_color in roc_curves:
        roc_ax.plot(curve_points['fpr'], curve_points['tpr'], label=curve_label, color=curve_color)
    roc_ax.grid(visible=True, which='both', linewidth=0.5, alpha=0.5)
    roc_ax.set_title(roc_title)
    roc_ax.legend(loc='lower right')
    roc_ax.set_aspect('equal')

#fig.suptitle('Evaluation set\nROC-AUC for AhR and MMP models', fontweight='bold', fontsize=16, y = 0.93)
# Shared axis labels for both panels
fig.supxlabel('False Positive Rate', y=0.12, fontsize=12)
fig.supylabel('True Positive Rate', x=0.05, fontsize=12)

fig.savefig('/Users/elli/Library/CloudStorage/OneDrive-Kruvelab/Master_thesis/Visualizations/2025-05-22_roc_curves_ahr_mmp_evaluation_uncorrected.pdf', dpi=300, bbox_inches='tight',
            transparent=True)

## Corrected evaluation

In [None]:
# One evaluation subset per endpoint: only rows that have a label for that endpoint
fp_corrected_with_gt_final_ahr = fp_corrected_with_gt_final[fp_corrected_with_gt_final['nr.ahr'].notna()]
fp_corrected_with_gt_final_mmp = fp_corrected_with_gt_final[fp_corrected_with_gt_final['sr.mmp'].notna()]

In [None]:
# Thresholds are looked up by row label in model_metrics.
# NOTE(review): row 1 is used for AhR RF and row 0 for AhR XGB - confirm that these
# labels match the row order of model_metrics.
ahr_rf_threshold = model_metrics.loc[1, 'threshold_90_recall']
ahr_xgb_threshold = model_metrics.loc[0, 'threshold_90_recall']
mmp_rf_threshold = model_metrics.loc[2, 'threshold_90_recall']
mmp_xgb_threshold = model_metrics.loc[3, 'threshold_90_recall']

# Short aliases for the two corrected evaluation subsets
corrected_ahr = fp_corrected_with_gt_final_ahr
corrected_mmp = fp_corrected_with_gt_final_mmp

# NOTE: these assignments reuse (and overwrite) the result names of the uncorrected evaluation
ahr_rf_metrics, ahr_rf_roc_auc_curve, ahr_rf_cm_vis = get_metrics(
    corrected_ahr['nr.ahr'], corrected_ahr['ahr_rf_pred'], ahr_rf_threshold, 'RF', 'AHR')
ahr_xgb_metrics, ahr_xgb_roc_auc_curve, ahr_xgb_cm_vis = get_metrics(
    corrected_ahr['nr.ahr'], corrected_ahr['ahr_xgb_pred'], ahr_xgb_threshold, 'XGB', 'AHR')
mmp_rf_metrics, mmp_rf_roc_auc_curve, mmp_rf_cm_vis = get_metrics(
    corrected_mmp['sr.mmp'], corrected_mmp['mmp_rf_pred'], mmp_rf_threshold, 'RF', 'MMP')
mmp_xgb_metrics, mmp_xgb_roc_auc_curve, mmp_xgb_cm_vis = get_metrics(
    corrected_mmp['sr.mmp'], corrected_mmp['mmp_xgb_pred'], mmp_xgb_threshold, 'XGB', 'MMP')

In [None]:
# Stack the four single-row metric frames into one table with a fresh 0..3 index
corrected_evaluation_metrics = pd.concat([ahr_rf_metrics, ahr_xgb_metrics, mmp_rf_metrics, mmp_xgb_metrics], ignore_index=True)
corrected_evaluation_metrics

In [None]:
# Write the corrected metrics table to the working directory (overwrites an existing file)
with open('2025-05-22_corrected_evaluation_metrics.pkl', 'wb') as f:
    pickle.dump(corrected_evaluation_metrics, f)

In [None]:
import matplotlib.pyplot as plt

# Global figure style; rcParams stay in effect for every later figure as well
figure_style = {
    'figure.figsize': [9.8, 9.8],
    'font.size': 16,
    'font.weight': 'normal',
    'axes.titlesize': 12,
    'axes.labelsize': 12,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
    'legend.title_fontsize': 12,
    'axes.titleweight': 'bold',
    'font.family': 'serif',
    'font.serif': ['Times New Roman'],
    'figure.dpi': 300,
}
plt.rcParams.update(figure_style)

fig, ax = plt.subplots(2, 2, figsize=(10, 9.5), dpi=300)

# (display object, target axis, colour map, panel title): AhR on the top row, MMP below
confusion_panels = [
    (ahr_rf_cm_vis, ax[0, 0], 'Blues', 'AhR RF'),
    (ahr_xgb_cm_vis, ax[0, 1], 'Blues', 'AhR XGB'),
    (mmp_rf_cm_vis, ax[1, 0], 'Oranges', 'MMP RF'),
    (mmp_xgb_cm_vis, ax[1, 1], 'Oranges', 'MMP XGB'),
]
for cm_display, panel_ax, panel_cmap, panel_title in confusion_panels:
    cm_display.plot(ax=panel_ax, cmap=panel_cmap, colorbar=False)
    cm_display.ax_.set_title(panel_title)

plt.tight_layout()

#plt.suptitle('Evaluation set\nConfusion matrices for AhR and MMP models at TPR=90%', fontsize=16, fontweight='bold', y=1.03)

plt.savefig('/Users/elli/Library/CloudStorage/OneDrive-Kruvelab/Master_thesis/Visualizations/2025-05-22_confusion_matrix_evaluation_corrected.pdf', dpi=300, bbox_inches='tight')

In [None]:
# ROC curves of all four models: AhR models in the left panel, MMP models in the right
fig, ax = plt.subplots(1, 2, figsize=(8, 5), dpi=300, tight_layout=True, sharex=True, sharey=True)

# (axis, panel title, [(curve frame, legend label, line colour), ...])
roc_panels = [
    (ax[0], 'AhR models', [(ahr_rf_roc_auc_curve, 'AhR RF', '#219EBC'),
                           (ahr_xgb_roc_auc_curve, 'AhR XGB', '#023047')]),
    (ax[1], 'MMP models', [(mmp_rf_roc_auc_curve, 'MMP RF', '#FFB703'),
                           (mmp_xgb_roc_auc_curve, 'MMP XGB', '#FB8500')]),
]
for roc_ax, roc_title, roc_curves in roc_panels:
    for curve_points, curve_label, curve_color in roc_curves:
        roc_ax.plot(curve_points['fpr'], curve_points['tpr'], label=curve_label, color=curve_color)
    roc_ax.grid(visible=True, which='both', linewidth=0.5, alpha=0.5)
    roc_ax.set_title(roc_title)
    roc_ax.legend(loc='lower right')
    roc_ax.set_aspect('equal')

#fig.suptitle('Evaluation set\nROC-AUC for AhR and MMP models', fontweight='bold', fontsize=16, y = 0.93)
# Shared axis labels for both panels
fig.supxlabel('False Positive Rate', y=0.12, fontsize=12)
fig.supylabel('True Positive Rate', x=0.05, fontsize=12)

fig.savefig('/Users/elli/Library/CloudStorage/OneDrive-Kruvelab/Master_thesis/Visualizations/2025-05-22_roc_curves_ahr_mmp_evaluation_corrected.pdf', dpi=300, bbox_inches='tight',
           transparent=True)