# Comparing models from recent pipeline runs with previous models using KGML-xDTD embeddings

#### 26 Aug 2024

The purpose of this notebook is to compare the performance of the models from the recent pipeline runs with the previous models using embeddings from the KGML-xDTD paper, as well as the first e2e pipeline run (from approx 14 August 2024) in which there was a bug in the evaluation pipeline. 

The notebook produces a table with the evaluation metrics for the models. See the bottom of the notebook. 

**Conclusions**: 

- We beat the old (internal) record for ranking metrics. This is a convincing indication we have competitive embeddings now.
- Comparing the scores for xg_ensemble, we see that we have clearly improved on the embeddings of the first run.
- The xg_synth scores for different embeddings suggest that our embeddings are still not quite as good as KGML-xDTD. This is not surprising since we are temporarily training with more limited compute than Chunyu.
- The ensemble models have good ranking scores. Previously, we hadn’t tried combining the ensemble technique with our other performance-boosting methods (BHO, large amount of synthesised negatives), so this is a good sign.

In [41]:
import pandas as pd
import numpy as np
import joblib
import yaml

In [3]:
# %load_ext kedro.ipython

In [4]:
import pickle
import pandas as pd
import numpy as np
import sys

import matrix.datasets.pair_generator as pair_generator
from matrix.datasets.graph import KnowledgeGraph
from matrix.pipelines.evaluation.nodes import make_test_predictions, evaluate_test_predictions

## Input files

### Ground truth data with test-train splits

In [7]:
# Ground truth data with test-train splits (see the section heading), read from parquet.
# NOTE(review): absolute, machine-specific path — this cell only runs as-is on the author's machine.
known_pairs = pd.read_parquet('/Users/alexei/Documents/data/e2e_14_aug/releases_20240807_05_model_input_splits')

### KGs with embeddings (KGML-xDTD, old run and new run)

In [16]:
# Node table carrying the KGML-xDTD (Chunyu) embeddings. Its 'embedding' column is
# renamed to 'topological_embedding' before the KnowledgeGraph is built from it.
with open('./input/rtx_kg2_nodes_chunyu', 'rb') as f:
    rtx_kg2 = pd.read_parquet(f).rename(columns={'embedding': 'topological_embedding'})
graph_chunyu = KnowledgeGraph(rtx_kg2)

In [10]:
# Node table produced by the first e2e run, wrapped in a KnowledgeGraph.
# Note: this rebinds `rtx_kg2`, which the preceding cell used for the KGML-xDTD node table.
# No column rename is applied here, unlike for the KGML-xDTD nodes.
# NOTE(review): absolute, machine-specific path.
with open('/Users/alexei/Documents/data/e2e_14_aug/releases_20240807_04_feature_rtx_kg2_nodes', 'rb') as f:
    rtx_kg2 = pd.read_parquet(f)
graph_first_run = KnowledgeGraph(rtx_kg2)

### Models

In [72]:
# Models using Chunyu's embeddings
# Both objects are deserialised from local .joblib files under the author's checkout.
# NOTE(review): joblib.load unpickles the file, so only load artefacts from a trusted source.
# NOTE(review): absolute, machine-specific paths.
xg_synth_chunyu = joblib.load('/Users/alexei/Documents/repos/matrix/pipelines/matrix/notebooks/scratch/local/generate_matrix/input/xg_balanced_retrain.joblib')
rf_chunyu = joblib.load('/Users/alexei/Documents/repos/matrix/pipelines/matrix/notebooks/scratch/local/misc/train-kgml-xdtd/output/kgml_xdtd_split.joblib')

In [24]:
# Model using embeddings from first run
# NOTE(review): pickle.load can execute arbitrary code from the file; only load trusted artefacts.
# NOTE(review): absolute, machine-specific path.
with open("/Users/alexei/Documents/data/e2e_14_aug/releases_20240807_06_models_xgc_model.pickle", "rb") as f:
    xg_ensemble_first_run = pickle.load(f)

### Other objects needed

In [43]:
# Load model params
# Parses the modelling defaults YAML from the repository checkout into a plain Python
# structure; the following cells read the feature list and transformers from it.
# NOTE(review): absolute, machine-specific path.
with open('/Users/alexei/Documents/repos/matrix/pipelines/matrix/conf/base/modelling/parameters/defaults.yml', 'r') as f:
    model_params_defaults = yaml.safe_load(f)

In [49]:
# Load objects needed for predictions
# `score_col_name`, `features` and `transformers` are passed unchanged to every
# make_test_predictions call further down.
score_col_name = 'treat score'
# Both settings come from the '_model_options' section of the defaults YAML.
features = model_params_defaults['_model_options']['model_tuning_args']['features']
transformers = model_params_defaults['_model_options']['transformers']

In [61]:
# Load drug flags
# The flags are read from the xg_baseline parameter file and later handed to
# pair_generator.MatrixTestDiseases.
# NOTE(review): absolute, machine-specific path.
with open('/Users/alexei/Documents/repos/matrix/pipelines/matrix/conf/base/modelling/parameters/xg_baseline.yml', 'r') as f:
    model_params_baseline = yaml.safe_load(f)
drug_flags = model_params_baseline['modelling.xg_baseline']['_overrides']['generator']['drug_flags']


## Predictions

### Disease-centric matrix

In [105]:
# Generating disease-centric matrix
# The generator is built and used in a single expression, so `matrix` only ever
# refers to the generated result (never to the generator object itself).
matrix = pair_generator.MatrixTestDiseases(drug_flags).generate(graph_first_run, known_pairs)

100%|██████████| 1133/1133 [00:00<00:00, 8038.19it/s]


In [106]:
# Making predictions
# Scores the same disease-centric matrix with each of the three models. The two
# Chunyu models are given graph_chunyu; the first-run ensemble is given graph_first_run.
# A fresh copy of `matrix` is passed to every call — presumably because
# make_test_predictions modifies its input; TODO confirm.
# NOTE(review): `matrix` was generated from graph_first_run, yet the first two calls are
# passed graph_chunyu — this assumes both graphs cover the same nodes; confirm.
data = matrix.copy()
mat_preds_xg_synth_chunyu = make_test_predictions(graph_chunyu, data, transformers, xg_synth_chunyu, features, score_col_name)
data = matrix.copy()
mat_preds_rf_chunyu = make_test_predictions(graph_chunyu, data, transformers, rf_chunyu, features, score_col_name)
data = matrix.copy()
mat_preds_xg_ensemble_first_run = make_test_predictions(graph_first_run, data, transformers, xg_ensemble_first_run, features, score_col_name)

100%|██████████| 1133/1133 [07:30<00:00,  2.52it/s]
100%|██████████| 1133/1133 [07:46<00:00,  2.43it/s]
100%|██████████| 1133/1133 [08:44<00:00,  2.16it/s]


### Ground truth test set

In [107]:
# Ground truth test pairs, generated from the first-run graph and the known pairs.
# Built in one expression so `gt_data` only ever refers to the generated result.
gt_data = pair_generator.GroundTruthTestPairs().generate(graph_first_run, known_pairs)

In [108]:
# Making predictions
# Scores the ground truth test set with each of the three models, using the same
# graph/model pairing as for the disease-centric matrix.
# A fresh copy of `gt_data` is passed to every call — presumably because
# make_test_predictions modifies its input; TODO confirm.
# NOTE(review): `gt_data` was generated from graph_first_run, yet the first two calls are
# passed graph_chunyu — this assumes both graphs cover the same nodes; confirm.
data = gt_data.copy()
gt_preds_xg_synth_chunyu = make_test_predictions(graph_chunyu, data, transformers, xg_synth_chunyu, features, score_col_name)
data = gt_data.copy()
gt_preds_rf_chunyu = make_test_predictions(graph_chunyu, data, transformers, rf_chunyu, features, score_col_name)
data = gt_data.copy()
gt_preds_xg_ensemble_first_run = make_test_predictions(graph_first_run, data, transformers, xg_ensemble_first_run, features, score_col_name)

100%|██████████| 1510/1510 [00:16<00:00, 90.22it/s]
100%|██████████| 1510/1510 [01:09<00:00, 21.86it/s]
100%|██████████| 1510/1510 [00:18<00:00, 81.43it/s]


## Computing evaluation metrics

In [109]:
# Loading evaluation options from the catalog
with open('/Users/alexei/Documents/repos/matrix/pipelines/matrix/conf/base/evaluation/parameters.yml', 'r') as f:
    eval_params = yaml.safe_load(f)

# Every evaluation type keeps its settings under <section> -> 'evaluation_options' -> 'evaluation'.
spec_ranking_options, mat_ranking_options, classification_options = (
    eval_params[section]['evaluation_options']['evaluation']
    for section in (
        'evaluation.disease_specific_ranking',
        'evaluation.disease_centric_matrix',
        'evaluation.simple_ground_truth_classification',
    )
)

In [110]:
# Computing metrics with the pipeline
def _evaluate_model(mat_preds, gt_preds):
    """Evaluate one model and return (specific ranking, matrix ranking, classification).

    Both ranking evaluations are computed on the disease-centric matrix
    predictions (`mat_preds`); the classification evaluation is computed on the
    ground truth test-set predictions (`gt_preds`). The three evaluations run in
    that order.
    """
    return (
        evaluate_test_predictions(mat_preds, spec_ranking_options),
        evaluate_test_predictions(mat_preds, mat_ranking_options),
        evaluate_test_predictions(gt_preds, classification_options),
    )


spec_ranking_xg_synth_chunyu, mat_ranking_xg_synth_chunyu, classification_xg_synth_chunyu = _evaluate_model(
    mat_preds_xg_synth_chunyu, gt_preds_xg_synth_chunyu
)
spec_ranking_rf_chunyu, mat_ranking_rf_chunyu, classification_rf_chunyu = _evaluate_model(
    mat_preds_rf_chunyu, gt_preds_rf_chunyu
)
spec_ranking_xg_ensemble_first_run, mat_ranking_xg_ensemble_first_run, classification_xg_ensemble_first_run = _evaluate_model(
    mat_preds_xg_ensemble_first_run, gt_preds_xg_ensemble_first_run
)

100%|██████████| 1133/1133 [01:24<00:00, 13.36it/s]
100%|██████████| 1133/1133 [01:24<00:00, 13.40it/s]
100%|██████████| 1133/1133 [01:23<00:00, 13.64it/s]


## Overview of results

In [131]:
# Row labels (models) and column labels (metrics) of the results table.
model_name_lst = ['xg_synth_chunyu', 'rf_chunyu', 'xg_ensemble_first_run']
metric_name_lst = ['auroc', 'ap', 'mrr', 'hit2', 'hit10', 'hit100', 'acc', 'f1']

# Explicit model-name -> results lookups. These replace eval() on concatenated
# variable names, so a typo raises a clear KeyError/NameError here and no string
# is ever executed as code.
mat_ranking_by_model = {
    'xg_synth_chunyu': mat_ranking_xg_synth_chunyu,
    'rf_chunyu': mat_ranking_rf_chunyu,
    'xg_ensemble_first_run': mat_ranking_xg_ensemble_first_run,
}
spec_ranking_by_model = {
    'xg_synth_chunyu': spec_ranking_xg_synth_chunyu,
    'rf_chunyu': spec_ranking_rf_chunyu,
    'xg_ensemble_first_run': spec_ranking_xg_ensemble_first_run,
}
classification_by_model = {
    'xg_synth_chunyu': classification_xg_synth_chunyu,
    'rf_chunyu': classification_rf_chunyu,
    'xg_ensemble_first_run': classification_xg_ensemble_first_run,
}

# One list per metric, ordered like model_name_lst.
# Matrix-ranking metrics
auroc_lst = [mat_ranking_by_model[model_name]['roc_auc_score'] for model_name in model_name_lst]
ap_lst = [mat_ranking_by_model[model_name]['average_precision_score'] for model_name in model_name_lst]
# Disease-specific ranking metrics
mrr_lst = [spec_ranking_by_model[model_name]['mrr'] for model_name in model_name_lst]
hit2_lst = [spec_ranking_by_model[model_name]['hit-2'] for model_name in model_name_lst]
hit10_lst = [spec_ranking_by_model[model_name]['hit-10'] for model_name in model_name_lst]
hit100_lst = [spec_ranking_by_model[model_name]['hit-100'] for model_name in model_name_lst]
# Ground truth classification metrics
acc_lst = [classification_by_model[model_name]['accuracy_score'] for model_name in model_name_lst]
f1_lst = [classification_by_model[model_name]['f1_score'] for model_name in model_name_lst]

In [132]:
# Tedious manual input of latest results (new run 26 Aug)
# Each dict maps a metric name (the same names as in metric_name_lst) to its value.
# The key order differs between dicts; the values are only ever looked up by key.
results_xg_ensemble_new_run_3 = {
    'auroc': 0.9219692860724151,
    'ap': 0.043389910213820035,
    'mrr': 0.19728144176219592,
    'hit2': 0.17863805970149255,
    'hit10': 0.314365671641791,
    'hit100': 0.6203358208955224,
    'acc': 0.8574043565806333,
    'f1': 0.7843897038472184,
}

results_xg_synth_new_run_3 = {
    'auroc': 0.8383403010392563,
    'ap': 0.007503326148252967,
    'mrr': 0.10459264807623478,
    'hit2': 0.08861940298507463,
    'hit10': 0.17583955223880596,
    'hit100': 0.4398320895522388,
    'acc': 0.9101226432363171,
    'f1': 0.8768497617256082,
}

results_xg_ensemble_new_run_6 = {
    'mrr': 0.14736130097929595,
    'hit2': 0.13013059701492538,
    'hit10': 0.2462686567164179,
    'hit100': 0.554570895522388,
    'auroc': 0.9055752875277525,
    'ap': 0.028384013211504897,
    'acc': 0.8411129416071755,
    'f1': 0.7535491198182851,
}

results_xg_synth_new_run_6 = {
    'ap': 0.005491298670828154,
    'hit100': 0.376865671641791,
    'auroc': 0.8248684530632286,
    'mrr': 0.06576272993950255,
    'hit2': 0.05177238805970149,
    'hit10': 0.11800373134328358,
    'acc': 0.9038989566172433,
    'f1': 0.8678580417820287,
}

In [133]:
# Append the manually entered new-run results to the model list and to every
# per-metric list, keeping all lists aligned with model_name_lst.
#
# Explicit dicts replace eval() on concatenated variable names, and models that
# are already present are skipped, so re-running this cell does not add
# duplicate rows to the results table.
new_run_results = {
    'xg_ensemble_new_run_3': results_xg_ensemble_new_run_3,
    'xg_synth_new_run_3': results_xg_synth_new_run_3,
    'xg_ensemble_new_run_6': results_xg_ensemble_new_run_6,
    'xg_synth_new_run_6': results_xg_synth_new_run_6,
}
metric_lsts = {
    'auroc': auroc_lst,
    'ap': ap_lst,
    'mrr': mrr_lst,
    'hit2': hit2_lst,
    'hit10': hit10_lst,
    'hit100': hit100_lst,
    'acc': acc_lst,
    'f1': f1_lst,
}
for model_name, model_results in new_run_results.items():
    if model_name in model_name_lst:
        # Already appended by an earlier execution of this cell.
        continue
    model_name_lst.append(model_name)
    for metric_name in metric_name_lst:
        metric_lsts[metric_name].append(model_results[metric_name])

In [134]:
# Assemble the results table: one row per model (model_name_lst) and one column
# per metric, in metric_name_lst order. The explicit mapping replaces eval() on
# concatenated variable names.
metric_values = {
    'auroc': auroc_lst,
    'ap': ap_lst,
    'mrr': mrr_lst,
    'hit2': hit2_lst,
    'hit10': hit10_lst,
    'hit100': hit100_lst,
    'acc': acc_lst,
    'f1': f1_lst,
}
results_df = pd.DataFrame(
    {metric_name: metric_values[metric_name] for metric_name in metric_name_lst},
    index=model_name_lst,
)

In [135]:
# Function to highlight the best model for each metric
def highlight_max(s):
    """Return one CSS string per entry of ``s``, marking the maximum value(s).

    Entries equal to ``s.max()`` get a yellow background; every other entry
    gets an empty string. Ties are all highlighted.
    """
    highlight = 'background-color: yellow'
    peak = s.max()
    return [highlight if is_peak else '' for is_peak in s.eq(peak)]

# Apply highlight_max column by column (axis=0) so the largest value of every metric is highlighted.
styled_df = results_df.style.apply(highlight_max, axis=0)
# Bare expression as the last line so the notebook renders the styled table.
styled_df

Unnamed: 0,auroc,ap,mrr,hit2,hit10,hit100,acc,f1
xg_synth_chunyu,0.867708,0.019273,0.153431,0.132929,0.271455,0.578358,0.926231,0.900074
rf_chunyu,0.741243,0.006884,0.129286,0.120336,0.197295,0.423041,0.882482,0.83342
xg_ensemble_first_run,0.844903,0.008687,0.068847,0.047575,0.132929,0.398787,0.78272,0.62941
xg_ensemble_new_run_3,0.921969,0.04339,0.197281,0.178638,0.314366,0.620336,0.857404,0.78439
xg_synth_new_run_3,0.83834,0.007503,0.104593,0.088619,0.17584,0.439832,0.910123,0.87685
xg_ensemble_new_run_6,0.905575,0.028384,0.147361,0.130131,0.246269,0.554571,0.841113,0.753549
xg_synth_new_run_6,0.824868,0.005491,0.065763,0.051772,0.118004,0.376866,0.903899,0.867858
