In [1]:
import os
import pandas as pd

from utils import risk_by_lowest, risk_by_highest

In [2]:
# Experiment configuration. `run_id` selects which run's results every input directory
# points at, so switching runs is a one-line change instead of editing each path.
run_id = 'run1'
results_dir = '../results'

# Estimators evaluated below: every '{meta}_{base}' combination of these two lists.
meta_models = ['sl', 'tl']
base_models = ['l1', 'l2', 'dt', 'rf', 'et', 'kr', 'cb', 'lgbm']
base_metrics_dir = f'{results_dir}/metrics/{run_id}/'

# Plug-in scorers: each '{meta}_{base}' pair names a sub-directory of plugin_dir.
plugin_meta_models = ['sl', 'tl']
plugin_base_models = ['dt', 'lgbm', 'cb']
plugin_dir = f'{results_dir}/scores/{run_id}/'

# R-score scorers: each base model names an 'rs_{base}' sub-directory of rscore_dir.
rscore_base_models = ['dt', 'lgbm', 'cb']
rscore_dir = f'{results_dir}/scores/{run_id}/'

In [3]:
def process_mse():
    """Summarise each estimator's test ATE/PEHE when selecting on validation MSE.

    For every ``{mm}_{bm}`` estimator in ``meta_models`` x ``base_models`` this reads
    ``{est}_test_metrics.csv`` and ``{est}_val_metrics.csv`` from
    ``base_metrics_dir/{est}/``. It averages the validation metrics over folds for each
    ``(iter_id, param_id)`` and inner-joins them to the test metrics (suffixes
    ``_val`` / ``_test``). The join is passed to ``risk_by_lowest`` with ``'mse_val'`` as
    the selection column and ``['ate_test', 'pehe_test']`` as the reported columns.

    Two-model (``'tl'``) estimators report separate ``mse_m0`` / ``mse_m1`` columns. For
    them, ``mse`` is set to the row-wise mean of the two, on both validation and test
    frames.

    Returns:
        DataFrame with columns ``['name', 'ate_mse', 'pehe_mse']``, one row per
        estimator, in ``meta_models`` x ``base_models`` order.
    """
    mse_list = []
    for mm in meta_models:
        for bm in base_models:
            est_name = f'{mm}_{bm}'
            est_dir = os.path.join(base_metrics_dir, est_name)
            df_base_test = pd.read_csv(os.path.join(est_dir, f'{est_name}_test_metrics.csv'))

            # Val MSE, averaged over CV folds. numeric_only=True keeps pandas >= 2.0 from
            # raising on any non-numeric column (older pandas dropped such columns silently).
            df_base_val = pd.read_csv(os.path.join(est_dir, f'{est_name}_val_metrics.csv'))
            df_base_val_gr = (
                df_base_val
                .groupby(['iter_id', 'param_id'], as_index=False)
                .mean(numeric_only=True)
                .drop(columns=['fold_id'])
            )

            if mm == 'tl':
                df_base_val_gr['mse'] = df_base_val_gr[['mse_m0', 'mse_m1']].mean(axis=1)
                df_base_test['mse'] = df_base_test[['mse_m0', 'mse_m1']].mean(axis=1)

            df_base = df_base_val_gr.merge(df_base_test, on=['iter_id', 'param_id'], suffixes=['_val', '_test'])
            mse_i = risk_by_lowest(df_base, 'mse_val', ['ate_test', 'pehe_test'])
            mse_list.append([est_name] + mse_i)

    return pd.DataFrame(mse_list, columns=['name', 'ate_mse', 'pehe_mse'])

In [4]:
def process_plugins(df_main):
    """Append plug-in-score selection results to ``df_main``.

    Each plug-in scorer ``{pmm}_{pbm}`` comes from ``plugin_meta_models`` x
    ``plugin_base_models``, and each estimator ``{mm}_{bm}`` from ``meta_models`` x
    ``base_models``. For every scorer/estimator pair, the fold-averaged scores in
    ``plugin_dir/{plugin}/{est}_plugin_{plugin}.csv`` are inner-joined to the
    estimator's test metrics (suffixes ``_val`` / ``_test``). The join is passed to
    ``risk_by_lowest`` twice: once selecting on ``'ate_val'`` and once on ``'pehe_val'``.

    Args:
        df_main: DataFrame with a ``'name'`` column of ``{mm}_{bm}`` estimator names.
            It is not modified.

    Returns:
        A copy of ``df_main`` inner-joined on ``'name'``, with four extra columns per
        plug-in scorer: ``ate_{plugin}_ate``, ``pehe_{plugin}_ate``,
        ``ate_{plugin}_pehe`` and ``pehe_{plugin}_pehe``.
    """
    est_names = [f'{mm}_{bm}' for mm in meta_models for bm in base_models]
    # Test metrics depend only on the estimator, so read each file once rather than
    # once per plug-in scorer. merge() returns new frames, so reusing them is safe.
    base_tests = {
        est_name: pd.read_csv(os.path.join(base_metrics_dir, est_name, f'{est_name}_test_metrics.csv'))
        for est_name in est_names
    }

    df_copy = df_main.copy()
    for plugin_mm in plugin_meta_models:
        for plugin_bm in plugin_base_models:
            plugin_name = f'{plugin_mm}_{plugin_bm}'
            plugin_ate_list = []
            plugin_pehe_list = []
            for est_name in est_names:
                # Plugin ATE and PEHE, averaged over CV folds
                df_plugin_val = pd.read_csv(os.path.join(plugin_dir, plugin_name, f'{est_name}_plugin_{plugin_name}.csv'))
                df_plugin_val_gr = (
                    df_plugin_val
                    .groupby(['iter_id', 'param_id'], as_index=False)
                    .mean(numeric_only=True)
                    .drop(columns=['fold_id'])
                )
                df_joined = df_plugin_val_gr.merge(base_tests[est_name], on=['iter_id', 'param_id'], suffixes=['_val', '_test'])
                plugin_ate_i = risk_by_lowest(df_joined, 'ate_val', ['ate_test', 'pehe_test'])
                plugin_pehe_i = risk_by_lowest(df_joined, 'pehe_val', ['ate_test', 'pehe_test'])
                plugin_ate_list.append([est_name] + plugin_ate_i)
                plugin_pehe_list.append([est_name] + plugin_pehe_i)

            df_plugin_ate = pd.DataFrame(plugin_ate_list, columns=['name', f'ate_{plugin_name}_ate', f'pehe_{plugin_name}_ate'])
            df_plugin_pehe = pd.DataFrame(plugin_pehe_list, columns=['name', f'ate_{plugin_name}_pehe', f'pehe_{plugin_name}_pehe'])
            df_plugin = df_plugin_ate.merge(df_plugin_pehe, on=['name'])
            df_copy = df_copy.merge(df_plugin, on=['name'])

    return df_copy

In [5]:
def process_rscores(df_main):
    """Append R-score selection results to ``df_main``.

    For each R-score model ``rs_{rs_bm}`` (``rs_bm`` in ``rscore_base_models``) and each
    estimator ``{mm}_{bm}`` (``meta_models`` x ``base_models``), the fold-averaged scores
    in ``rscore_dir/rs_{rs_bm}/{est}_r_score_rs_{rs_bm}.csv`` are inner-joined to the
    estimator's test metrics. The join is passed to ``risk_by_highest`` with
    ``'rscore'`` as the selection column and ``['ate', 'pehe']`` as the reported
    columns.

    Args:
        df_main: DataFrame with a ``'name'`` column of ``{mm}_{bm}`` estimator names.
            It is not modified.

    Returns:
        A copy of ``df_main`` inner-joined on ``'name'``, with ``ate_rs_{rs_bm}`` and
        ``pehe_rs_{rs_bm}`` columns added for each R-score model.
    """
    est_names = [f'{mm}_{bm}' for mm in meta_models for bm in base_models]
    # Test metrics depend only on the estimator, so read each file once rather than
    # once per R-score model. merge() returns new frames, so reusing them is safe.
    base_tests = {
        est_name: pd.read_csv(os.path.join(base_metrics_dir, est_name, f'{est_name}_test_metrics.csv'))
        for est_name in est_names
    }

    df_copy = df_main.copy()
    for rs_bm in rscore_base_models:
        rs_name = f'rs_{rs_bm}'
        scores_list = []
        for est_name in est_names:
            # R-Score, averaged over CV folds
            df_rscore_val = pd.read_csv(os.path.join(rscore_dir, rs_name, f'{est_name}_r_score_{rs_name}.csv'))
            df_rscore_val_gr = (
                df_rscore_val
                .groupby(['iter_id', 'param_id'], as_index=False)
                .mean(numeric_only=True)
                .drop(columns=['fold_id'])
            )
            # No suffixes: assumes the R-score file has no 'ate'/'pehe' columns. Otherwise
            # pandas would rename both sides to *_x/*_y and the lookup below would fail.
            df_rscore_joined = df_rscore_val_gr.merge(base_tests[est_name], on=['iter_id', 'param_id'])
            rscore_i = risk_by_highest(df_rscore_joined, 'rscore', ['ate', 'pehe'])
            scores_list.append([est_name] + rscore_i)

        df_rscore = pd.DataFrame(scores_list, columns=['name', f'ate_{rs_name}', f'pehe_{rs_name}'])
        df_copy = df_copy.merge(df_rscore, on=['name'])

    return df_copy

In [8]:
# Build the results table stage by stage. Each later stage inner-joins its new
# columns onto the previous frame by estimator 'name'.
df_mse = process_mse()                            # selection on validation MSE
df_plugins = process_plugins(df_main=df_mse)      # + plug-in score selections
df_rscores = process_rscores(df_main=df_plugins)  # + R-score selections

In [11]:
# 'ate' results under every selection criterion: validation MSE, each plug-in
# scorer's PEHE-based selection, and each R-score.
metric = 'ate'
p_metric = 'pehe'
plugin_cols = ['_'.join([metric, pmm, pbm, p_metric]) for pmm in plugin_meta_models for pbm in plugin_base_models]
rscore_cols = ['_'.join([metric, 'rs', rs_bm]) for rs_bm in rscore_base_models]
cols = ['name', f'{metric}_mse', *plugin_cols, *rscore_cols]
df_rscores.loc[:, cols]

Unnamed: 0,name,ate_mse,ate_sl_dt_pehe,ate_sl_lgbm_pehe,ate_sl_cb_pehe,ate_tl_dt_pehe,ate_tl_lgbm_pehe,ate_tl_cb_pehe,ate_rs_dt,ate_rs_lgbm,ate_rs_cb
0,sl_l1,3.654 +/- 0.446,0.593 +/- 1.219,3.654 +/- 0.446,0.865 +/- 0.770,0.054 +/- 0.127,0.042 +/- 0.125,0.054 +/- 0.127,0.598 +/- 0.708,0.148 +/- 0.112,0.148 +/- 0.112
1,sl_l2,0.576 +/- 0.212,0.065 +/- 0.153,0.576 +/- 0.212,0.525 +/- 0.274,0.058 +/- 0.175,0.058 +/- 0.175,0.058 +/- 0.175,0.634 +/- 0.092,0.204 +/- 0.191,0.395 +/- 0.165
2,sl_dt,0.380 +/- 0.794,0.414 +/- 0.789,0.399 +/- 0.783,0.410 +/- 0.780,0.423 +/- 0.573,0.224 +/- 0.375,0.404 +/- 0.550,0.458 +/- 0.616,0.546 +/- 0.920,0.514 +/- 0.822
3,sl_rf,0.943 +/- 2.168,1.204 +/- 2.951,1.219 +/- 2.946,0.855 +/- 2.100,0.212 +/- 0.426,0.219 +/- 0.318,0.264 +/- 0.433,0.283 +/- 0.412,0.244 +/- 0.418,0.235 +/- 0.426
4,sl_et,0.808 +/- 2.143,0.786 +/- 2.150,0.809 +/- 2.143,0.656 +/- 1.718,0.255 +/- 0.391,0.343 +/- 0.533,0.275 +/- 0.442,0.551 +/- 1.218,0.154 +/- 0.213,0.160 +/- 0.211
5,sl_kr,4.747 +/- 2.710,1.621 +/- 3.807,4.747 +/- 2.710,1.649 +/- 1.779,0.534 +/- 1.248,0.307 +/- 0.592,0.343 +/- 0.723,0.987 +/- 0.570,0.248 +/- 0.305,0.401 +/- 0.275
6,sl_cb,1.108 +/- 1.252,0.547 +/- 1.441,1.108 +/- 1.252,1.108 +/- 1.252,0.229 +/- 0.633,0.276 +/- 0.643,0.229 +/- 0.633,0.737 +/- 0.394,0.136 +/- 0.121,0.353 +/- 0.196
7,sl_lgbm,0.087 +/- 0.074,0.087 +/- 0.080,0.089 +/- 0.078,0.088 +/- 0.080,0.083 +/- 0.077,0.097 +/- 0.093,0.090 +/- 0.087,0.100 +/- 0.100,0.086 +/- 0.086,0.087 +/- 0.087
8,tl_l1,0.153 +/- 0.101,1.605 +/- 3.312,4.483 +/- 2.320,1.435 +/- 1.027,0.382 +/- 0.421,0.404 +/- 0.372,0.351 +/- 0.390,1.016 +/- 0.534,0.179 +/- 0.113,0.283 +/- 0.110
9,tl_l2,0.122 +/- 0.100,0.181 +/- 0.162,0.251 +/- 0.144,0.251 +/- 0.144,0.098 +/- 0.064,0.113 +/- 0.065,0.110 +/- 0.065,0.247 +/- 0.147,0.134 +/- 0.104,0.180 +/- 0.091


In [12]:
# 'pehe' results under every selection criterion: validation MSE, each plug-in
# scorer's PEHE-based selection, and each R-score.
metric = 'pehe'
p_metric = 'pehe'
plugin_cols = ['_'.join([metric, pmm, pbm, p_metric]) for pmm in plugin_meta_models for pbm in plugin_base_models]
rscore_cols = ['_'.join([metric, 'rs', rs_bm]) for rs_bm in rscore_base_models]
cols = ['name', f'{metric}_mse', *plugin_cols, *rscore_cols]
df_rscores.loc[:, cols]

Unnamed: 0,name,pehe_mse,pehe_sl_dt_pehe,pehe_sl_lgbm_pehe,pehe_sl_cb_pehe,pehe_tl_dt_pehe,pehe_tl_lgbm_pehe,pehe_tl_cb_pehe,pehe_rs_dt,pehe_rs_lgbm,pehe_rs_cb
0,sl_l1,2.584 +/- 0.909,0.195 +/- 0.458,2.584 +/- 0.909,0.541 +/- 0.526,0.004 +/- 0.010,0.003 +/- 0.009,0.004 +/- 0.010,0.333 +/- 0.470,0.030 +/- 0.019,0.030 +/- 0.019
1,sl_l2,0.242 +/- 0.135,0.020 +/- 0.056,0.242 +/- 0.135,0.223 +/- 0.153,0.004 +/- 0.013,0.004 +/- 0.013,0.004 +/- 0.013,0.246 +/- 0.128,0.057 +/- 0.079,0.123 +/- 0.084
2,sl_dt,0.038 +/- 0.065,0.079 +/- 0.130,0.033 +/- 0.063,0.033 +/- 0.063,0.427 +/- 0.563,0.238 +/- 0.444,0.446 +/- 0.587,0.568 +/- 0.560,0.711 +/- 0.949,0.711 +/- 0.905
3,sl_rf,0.025 +/- 0.027,0.117 +/- 0.160,0.099 +/- 0.168,0.052 +/- 0.047,0.476 +/- 0.779,0.473 +/- 0.949,0.544 +/- 1.004,0.659 +/- 0.995,0.674 +/- 0.995,0.682 +/- 0.995
4,sl_et,0.074 +/- 0.093,0.126 +/- 0.086,0.104 +/- 0.075,0.097 +/- 0.071,0.497 +/- 0.739,0.426 +/- 0.927,0.403 +/- 0.669,0.538 +/- 0.598,0.945 +/- 1.505,0.947 +/- 1.506
5,sl_kr,2.574 +/- 0.865,0.311 +/- 0.478,2.574 +/- 0.866,0.487 +/- 0.277,0.533 +/- 0.829,0.721 +/- 1.553,0.686 +/- 1.400,1.053 +/- 1.302,1.279 +/- 2.107,1.321 +/- 2.087
6,sl_cb,0.277 +/- 0.136,0.030 +/- 0.080,0.277 +/- 0.136,0.277 +/- 0.136,0.024 +/- 0.061,0.003 +/- 0.008,0.024 +/- 0.061,0.490 +/- 0.503,0.303 +/- 0.571,0.450 +/- 0.542
7,sl_lgbm,0.053 +/- 0.035,0.034 +/- 0.062,0.051 +/- 0.059,0.039 +/- 0.062,0.062 +/- 0.141,0.042 +/- 0.066,0.036 +/- 0.066,0.062 +/- 0.095,0.129 +/- 0.180,0.114 +/- 0.186
8,tl_l1,1.424 +/- 2.460,0.429 +/- 0.850,2.512 +/- 0.764,0.495 +/- 0.269,0.610 +/- 1.187,0.571 +/- 1.253,0.612 +/- 1.237,1.015 +/- 1.091,1.356 +/- 2.241,1.393 +/- 2.436
9,tl_l2,0.210 +/- 0.268,0.009 +/- 0.013,0.033 +/- 0.023,0.033 +/- 0.023,0.008 +/- 0.014,0.009 +/- 0.014,0.007 +/- 0.014,0.048 +/- 0.039,0.061 +/- 0.048,0.064 +/- 0.063
