In [1]:
import os
import pandas as pd

from utils import get_best_metric, get_by_lowest, get_by_highest

In [3]:
# Which experiment run to summarise. Every result directory below is derived
# from these two values, so switching runs is a one-line change.
run_id = 'run1'
results_root = '../results'

# Estimators under evaluation: every (meta model, base model) pair, named
# f'{meta_model}_{base_model}' (e.g. 'sl_dt'). 'sl'/'tl' are presumably the
# S- and T-learner meta-learners — TODO confirm.
meta_models = ['sl', 'tl']
base_models = ['l1', 'l2', 'dt', 'rf', 'et', 'kr', 'cb', 'lgbm']
base_metrics_dir = f'{results_root}/metrics/{run_id}/'

# Plug-in validation scorers, named f'{plugin_meta_model}_{plugin_base_model}'.
plugin_meta_models = ['sl', 'tl']
plugin_base_models = ['dt', 'lgbm', 'cb']
plugin_dir = f'{results_root}/scores/{run_id}/'

# R-score validation scorers, named f'rs_{rscore_base_model}'.
rscore_base_models = ['dt', 'lgbm', 'cb']
rscore_dir = f'{results_root}/scores/{run_id}/'

In [4]:
def process_test():
    """Build the reference table of test ATE/PEHE per estimator.

    For every estimator ``{mm}_{bm}`` in ``meta_models`` x ``base_models``
    (in that nesting order), reads
    ``{base_metrics_dir}/{name}/{name}_test_metrics.csv`` and summarises its
    'ate' and 'pehe' columns with ``utils.get_best_metric`` (presumably the
    best value over all iterations/hyper-parameter settings — TODO confirm).

    Returns:
        DataFrame with columns ['name', 'ate_test', 'pehe_test'], one row per
        estimator.
    """
    est_names = [f'{mm}_{bm}' for mm in meta_models for bm in base_models]

    rows = []
    for name in est_names:
        test_csv = os.path.join(base_metrics_dir, name, f'{name}_test_metrics.csv')
        df_test_metrics = pd.read_csv(test_csv)
        ate_summary = get_best_metric(df_test_metrics, 'ate')
        pehe_summary = get_best_metric(df_test_metrics, 'pehe')
        rows.append([name, ate_summary, pehe_summary])

    return pd.DataFrame(rows, columns=['name', 'ate_test', 'pehe_test'])

In [5]:
def process_mse(df_main):
    """Add the test ATE/PEHE of the configuration with the lowest validation MSE.

    For every estimator ``{mm}_{bm}`` in ``meta_models`` x ``base_models``:

    1. Read its validation metrics and average them over CV folds, giving one
       row per ``(iter_id, param_id)``.
    2. For 'tl' estimators, which store two MSE columns (``mse_m0``/``mse_m1``,
       presumably one per outcome model — TODO confirm), use their mean as
       ``mse`` in both the validation and test frames.
    3. Join validation and test rows on ``(iter_id, param_id)`` (overlapping
       columns get ``_val``/``_test`` suffixes) and pass the result to
       ``utils.get_by_lowest`` ranked by ``mse_val``. The helper presumably
       returns ``[ate_test, pehe_test]`` of the lowest-``mse_val`` row — TODO confirm.

    Args:
        df_main: DataFrame with a ``name`` column, e.g. ``process_test()``.

    Returns:
        ``df_main`` inner-joined on ``name`` with new ``ate_mse`` and
        ``pehe_mse`` columns. Names missing from either side are dropped.
    """
    mse_rows = []
    for mm in meta_models:
        for bm in base_models:
            est_name = f'{mm}_{bm}'
            est_dir = os.path.join(base_metrics_dir, est_name)
            df_base_test = pd.read_csv(os.path.join(est_dir, f'{est_name}_test_metrics.csv'))

            # Validation MSE averaged over folds. fold_id is dropped before
            # averaging (its mean is meaningless), and numeric_only=True keeps
            # pandas >= 2 from raising on any non-numeric column (older pandas
            # silently dropped such columns).
            df_base_val = pd.read_csv(os.path.join(est_dir, f'{est_name}_val_metrics.csv'))
            df_base_val_gr = (
                df_base_val
                .drop(columns=['fold_id'])
                .groupby(['iter_id', 'param_id'], as_index=False)
                .mean(numeric_only=True)
            )

            if mm == 'tl':
                df_base_val_gr['mse'] = df_base_val_gr[['mse_m0', 'mse_m1']].mean(axis=1)
                df_base_test['mse'] = df_base_test[['mse_m0', 'mse_m1']].mean(axis=1)

            df_base = df_base_val_gr.merge(df_base_test, on=['iter_id', 'param_id'], suffixes=('_val', '_test'))
            mse_i = get_by_lowest(df_base, 'mse_val', ['ate_test', 'pehe_test'])
            mse_rows.append([est_name] + mse_i)

    df_mse = pd.DataFrame(mse_rows, columns=['name', 'ate_mse', 'pehe_mse'])
    return df_main.merge(df_mse, on=['name'])

In [6]:
def process_plugins(df_main):
    """Add the test ATE/PEHE of configurations selected by each plug-in scorer.

    Each plug-in scorer is named ``{plugin_mm}_{plugin_bm}`` from
    ``plugin_meta_models`` x ``plugin_base_models``. Its per-estimator validation
    scores are read from
    ``{plugin_dir}/{plugin_name}/{est_name}_plugin_{plugin_name}.csv`` and
    averaged over folds. They are then joined with the estimator's test metrics
    on ``(iter_id, param_id)`` with ``_val``/``_test`` suffixes.
    ``utils.get_by_lowest`` is applied twice: ranked by ``ate_val`` and ranked
    by ``pehe_val``.

    Args:
        df_main: DataFrame with a ``name`` column.

    Returns:
        A copy of ``df_main`` inner-joined on ``name`` with four new columns
        per scorer: ``ate_{p}_ate``, ``pehe_{p}_ate``, ``ate_{p}_pehe`` and
        ``pehe_{p}_pehe``. The final suffix names the validation column used
        for ranking.
    """
    est_names = [f'{mm}_{bm}' for mm in meta_models for bm in base_models]

    # Test metrics do not depend on the scorer, so each file is read once
    # rather than once per scorer. merge() below does not modify these frames,
    # so sharing them across scorers is safe.
    test_frames = {
        est_name: pd.read_csv(os.path.join(base_metrics_dir, est_name, f'{est_name}_test_metrics.csv'))
        for est_name in est_names
    }

    df_copy = df_main.copy()
    for plugin_mm in plugin_meta_models:
        for plugin_bm in plugin_base_models:
            plugin_name = f'{plugin_mm}_{plugin_bm}'
            plugin_ate_list = []
            plugin_pehe_list = []
            for est_name in est_names:
                # Plug-in validation scores averaged over folds (fold_id dropped
                # first; numeric_only guards pandas >= 2 against non-numeric columns).
                df_plugin_val = pd.read_csv(os.path.join(plugin_dir, plugin_name, f'{est_name}_plugin_{plugin_name}.csv'))
                df_plugin_val_gr = (
                    df_plugin_val
                    .drop(columns=['fold_id'])
                    .groupby(['iter_id', 'param_id'], as_index=False)
                    .mean(numeric_only=True)
                )
                df_scored = df_plugin_val_gr.merge(
                    test_frames[est_name], on=['iter_id', 'param_id'], suffixes=('_val', '_test')
                )
                plugin_ate_i = get_by_lowest(df_scored, 'ate_val', ['ate_test', 'pehe_test'])
                plugin_pehe_i = get_by_lowest(df_scored, 'pehe_val', ['ate_test', 'pehe_test'])
                plugin_ate_list.append([est_name] + plugin_ate_i)
                plugin_pehe_list.append([est_name] + plugin_pehe_i)

            df_plugin_ate = pd.DataFrame(plugin_ate_list, columns=['name', f'ate_{plugin_name}_ate', f'pehe_{plugin_name}_ate'])
            df_plugin_pehe = pd.DataFrame(plugin_pehe_list, columns=['name', f'ate_{plugin_name}_pehe', f'pehe_{plugin_name}_pehe'])
            df_plugin_cols = df_plugin_ate.merge(df_plugin_pehe, on=['name'])
            df_copy = df_copy.merge(df_plugin_cols, on=['name'])

    return df_copy

In [7]:
def process_rscores(df_main):
    """Add the test ATE/PEHE of configurations selected by each R-score scorer.

    For every R-score base model ``rs_bm`` (scorer name ``rs_{rs_bm}``) and
    every estimator ``{mm}_{bm}``, validation R-scores are read from
    ``{rscore_dir}/{rs_name}/{est_name}_r_score_{rs_name}.csv`` and averaged
    over folds. They are joined with the estimator's test metrics on
    ``(iter_id, param_id)`` and passed to ``utils.get_by_highest``, ranked by
    ``rscore``.

    Args:
        df_main: DataFrame with a ``name`` column.

    Returns:
        A copy of ``df_main`` inner-joined on ``name`` with ``ate_{rs_name}``
        and ``pehe_{rs_name}`` columns per scorer.
    """
    est_names = [f'{mm}_{bm}' for mm in meta_models for bm in base_models]

    # Test metrics do not depend on the scorer, so each file is read once
    # rather than once per scorer. merge() does not modify these frames.
    test_frames = {
        est_name: pd.read_csv(os.path.join(base_metrics_dir, est_name, f'{est_name}_test_metrics.csv'))
        for est_name in est_names
    }

    df_copy = df_main.copy()
    for rs_bm in rscore_base_models:
        rs_name = f'rs_{rs_bm}'
        scores_list = []
        for est_name in est_names:
            # R-score averaged over folds (fold_id dropped first; numeric_only
            # guards pandas >= 2 against non-numeric columns).
            df_rscore_val = pd.read_csv(os.path.join(rscore_dir, rs_name, f'{est_name}_r_score_{rs_name}.csv'))
            df_rscore_val_gr = (
                df_rscore_val
                .drop(columns=['fold_id'])
                .groupby(['iter_id', 'param_id'], as_index=False)
                .mean(numeric_only=True)
            )
            # No suffixes here: this assumes the R-score files share no column
            # other than the join keys with the test metrics. Otherwise pandas
            # would rename 'ate'/'pehe' to '*_x'/'*_y' and the lookup would fail.
            df_rscore_test = df_rscore_val_gr.merge(test_frames[est_name], on=['iter_id', 'param_id'])
            rscore_i = get_by_highest(df_rscore_test, 'rscore', ['ate', 'pehe'])
            scores_list.append([est_name] + rscore_i)

        df_rscore = pd.DataFrame(scores_list, columns=['name', f'ate_{rs_name}', f'pehe_{rs_name}'])
        df_copy = df_copy.merge(df_rscore, on=['name'])

    return df_copy

In [8]:
# Build the comparison table stage by stage. Start from the per-estimator
# test reference, then append the selection columns for validation MSE,
# the plug-in scorers and the R-score scorers.
df_test = process_test()
df_mse = df_test.pipe(process_mse)
df_plugins = df_mse.pipe(process_plugins)
df_rscores = df_plugins.pipe(process_rscores)

In [11]:
# Test `ate` of the configuration picked by each selection criterion. The
# plug-in scorers here rank configurations by their `ate` estimate.
metric, p_metric = 'ate', 'ate'
plugin_cols = [
    f'{metric}_{pmm}_{pbm}_{p_metric}'
    for pmm in plugin_meta_models
    for pbm in plugin_base_models
]
rscore_cols = [f'{metric}_rs_{rs_bm}' for rs_bm in rscore_base_models]
cols = ['name', f'{metric}_mse', *plugin_cols, *rscore_cols, f'{metric}_test']
df_rscores[cols]

Unnamed: 0,name,ate_mse,ate_sl_dt_ate,ate_sl_lgbm_ate,ate_sl_cb_ate,ate_tl_dt_ate,ate_tl_lgbm_ate,ate_tl_cb_ate,ate_rs_dt,ate_rs_lgbm,ate_rs_cb,ate_test
0,sl_l1,4.791 +/- 2.711,1.495 +/- 3.782,4.791 +/- 2.711,1.850 +/- 2.532,1.178 +/- 2.700,1.178 +/- 2.700,1.178 +/- 2.700,1.734 +/- 2.583,1.284 +/- 2.664,1.284 +/- 2.664,1.136 +/- 2.708
1,sl_l2,1.676 +/- 2.643,1.171 +/- 2.805,1.676 +/- 2.643,1.676 +/- 2.643,1.158 +/- 2.646,1.158 +/- 2.646,1.158 +/- 2.646,1.734 +/- 2.619,1.304 +/- 2.604,1.495 +/- 2.526,1.100 +/- 2.658
2,sl_dt,0.388 +/- 0.807,0.722 +/- 1.240,0.459 +/- 0.781,0.469 +/- 0.797,0.497 +/- 0.743,0.532 +/- 0.788,0.449 +/- 0.702,0.465 +/- 0.628,0.554 +/- 0.932,0.522 +/- 0.832,0.008 +/- 0.014
3,sl_rf,0.986 +/- 2.165,0.874 +/- 2.168,1.253 +/- 2.943,0.914 +/- 2.155,0.286 +/- 0.450,0.204 +/- 0.354,0.266 +/- 0.459,0.326 +/- 0.447,0.286 +/- 0.451,0.278 +/- 0.460,0.043 +/- 0.050
4,sl_et,0.914 +/- 2.423,0.841 +/- 2.236,0.910 +/- 2.424,0.755 +/- 1.961,0.241 +/- 0.368,0.245 +/- 0.385,0.236 +/- 0.371,0.656 +/- 1.495,0.260 +/- 0.361,0.266 +/- 0.358,0.106 +/- 0.281
5,sl_kr,4.791 +/- 2.711,1.402 +/- 3.510,4.791 +/- 2.711,1.613 +/- 1.549,0.130 +/- 0.092,0.119 +/- 0.106,0.287 +/- 0.489,1.031 +/- 0.562,0.292 +/- 0.321,0.445 +/- 0.269,0.044 +/- 0.081
6,sl_cb,1.317 +/- 1.364,0.799 +/- 1.531,1.317 +/- 1.364,1.317 +/- 1.364,0.230 +/- 0.167,0.239 +/- 0.160,0.239 +/- 0.163,0.946 +/- 0.513,0.345 +/- 0.190,0.562 +/- 0.203,0.209 +/- 0.147
7,sl_lgbm,0.193 +/- 0.186,0.205 +/- 0.211,0.187 +/- 0.181,0.187 +/- 0.181,0.274 +/- 0.320,0.274 +/- 0.326,0.276 +/- 0.323,0.205 +/- 0.216,0.192 +/- 0.207,0.192 +/- 0.208,0.106 +/- 0.144
8,tl_l1,0.461 +/- 0.450,1.894 +/- 3.865,4.791 +/- 2.711,1.461 +/- 0.812,0.349 +/- 0.495,0.358 +/- 0.493,0.358 +/- 0.493,1.324 +/- 0.693,0.487 +/- 0.395,0.591 +/- 0.415,0.308 +/- 0.418
9,tl_l2,0.470 +/- 0.524,0.474 +/- 0.684,0.589 +/- 0.632,0.589 +/- 0.632,0.401 +/- 0.510,0.400 +/- 0.511,0.404 +/- 0.512,0.596 +/- 0.640,0.483 +/- 0.489,0.528 +/- 0.460,0.349 +/- 0.519


In [13]:
# Test `ate` of the configuration picked by each selection criterion. The
# plug-in scorers here rank configurations by their `pehe` estimate.
metric, p_metric = 'ate', 'pehe'
plugin_cols = [
    f'{metric}_{pmm}_{pbm}_{p_metric}'
    for pmm in plugin_meta_models
    for pbm in plugin_base_models
]
rscore_cols = [f'{metric}_rs_{rs_bm}' for rs_bm in rscore_base_models]
cols = ['name', f'{metric}_mse', *plugin_cols, *rscore_cols, f'{metric}_test']
df_rscores[cols]

Unnamed: 0,name,ate_mse,ate_sl_dt_pehe,ate_sl_lgbm_pehe,ate_sl_cb_pehe,ate_tl_dt_pehe,ate_tl_lgbm_pehe,ate_tl_cb_pehe,ate_rs_dt,ate_rs_lgbm,ate_rs_cb,ate_test
0,sl_l1,4.791 +/- 2.711,1.730 +/- 3.781,4.791 +/- 2.711,2.001 +/- 2.482,1.190 +/- 2.695,1.178 +/- 2.700,1.190 +/- 2.695,1.734 +/- 2.583,1.284 +/- 2.664,1.284 +/- 2.664,1.136 +/- 2.708
1,sl_l2,1.676 +/- 2.643,1.165 +/- 2.807,1.676 +/- 2.643,1.625 +/- 2.490,1.158 +/- 2.646,1.158 +/- 2.646,1.158 +/- 2.646,1.734 +/- 2.619,1.304 +/- 2.604,1.495 +/- 2.526,1.100 +/- 2.658
2,sl_dt,0.388 +/- 0.807,0.422 +/- 0.802,0.407 +/- 0.796,0.418 +/- 0.793,0.431 +/- 0.585,0.231 +/- 0.387,0.412 +/- 0.560,0.465 +/- 0.628,0.554 +/- 0.932,0.522 +/- 0.832,0.008 +/- 0.014
3,sl_rf,0.986 +/- 2.165,1.246 +/- 2.946,1.261 +/- 2.940,0.898 +/- 2.092,0.254 +/- 0.466,0.261 +/- 0.354,0.307 +/- 0.468,0.326 +/- 0.447,0.286 +/- 0.451,0.278 +/- 0.460,0.043 +/- 0.050
4,sl_et,0.914 +/- 2.423,0.892 +/- 2.431,0.915 +/- 2.424,0.762 +/- 1.998,0.361 +/- 0.576,0.449 +/- 0.780,0.381 +/- 0.618,0.656 +/- 1.495,0.260 +/- 0.361,0.266 +/- 0.358,0.106 +/- 0.281
5,sl_kr,4.791 +/- 2.711,1.665 +/- 3.805,4.791 +/- 2.711,1.693 +/- 1.784,0.578 +/- 1.248,0.351 +/- 0.595,0.386 +/- 0.726,1.031 +/- 0.562,0.292 +/- 0.321,0.445 +/- 0.269,0.044 +/- 0.081
6,sl_cb,1.317 +/- 1.364,0.757 +/- 1.543,1.317 +/- 1.364,1.317 +/- 1.364,0.439 +/- 0.742,0.485 +/- 0.738,0.439 +/- 0.742,0.946 +/- 0.513,0.345 +/- 0.190,0.562 +/- 0.203,0.209 +/- 0.147
7,sl_lgbm,0.193 +/- 0.186,0.192 +/- 0.171,0.194 +/- 0.169,0.193 +/- 0.170,0.188 +/- 0.166,0.203 +/- 0.190,0.195 +/- 0.173,0.205 +/- 0.216,0.192 +/- 0.207,0.192 +/- 0.208,0.106 +/- 0.144
8,tl_l1,0.461 +/- 0.450,1.913 +/- 3.715,4.791 +/- 2.711,1.742 +/- 1.403,0.690 +/- 0.780,0.712 +/- 0.754,0.659 +/- 0.762,1.324 +/- 0.693,0.487 +/- 0.395,0.591 +/- 0.415,0.308 +/- 0.418
9,tl_l2,0.470 +/- 0.524,0.530 +/- 0.666,0.600 +/- 0.639,0.600 +/- 0.639,0.447 +/- 0.533,0.461 +/- 0.527,0.459 +/- 0.526,0.596 +/- 0.640,0.483 +/- 0.489,0.528 +/- 0.460,0.349 +/- 0.519


In [14]:
# Test `pehe` of the configuration picked by each selection criterion. The
# plug-in scorers here rank configurations by their `ate` estimate.
metric, p_metric = 'pehe', 'ate'
plugin_cols = [
    f'{metric}_{pmm}_{pbm}_{p_metric}'
    for pmm in plugin_meta_models
    for pbm in plugin_base_models
]
rscore_cols = [f'{metric}_rs_{rs_bm}' for rs_bm in rscore_base_models]
cols = ['name', f'{metric}_mse', *plugin_cols, *rscore_cols, f'{metric}_test']
df_rscores[cols]

Unnamed: 0,name,pehe_mse,pehe_sl_dt_ate,pehe_sl_lgbm_ate,pehe_sl_cb_ate,pehe_tl_dt_ate,pehe_tl_lgbm_ate,pehe_tl_cb_ate,pehe_rs_dt,pehe_rs_lgbm,pehe_rs_cb,pehe_test
0,sl_l1,6.998 +/- 6.783,4.565 +/- 7.770,6.998 +/- 6.783,4.848 +/- 7.142,4.417 +/- 7.338,4.417 +/- 7.338,4.417 +/- 7.338,4.747 +/- 7.199,4.443 +/- 7.326,4.443 +/- 7.326,4.414 +/- 7.336
1,sl_l2,4.647 +/- 7.271,4.426 +/- 7.370,4.647 +/- 7.271,4.647 +/- 7.271,4.409 +/- 7.320,4.409 +/- 7.320,4.409 +/- 7.320,4.651 +/- 7.274,4.462 +/- 7.297,4.528 +/- 7.266,4.405 +/- 7.317
2,sl_dt,4.875 +/- 7.756,5.485 +/- 8.487,5.355 +/- 7.776,5.771 +/- 8.816,5.536 +/- 8.236,5.497 +/- 8.143,5.463 +/- 8.215,5.406 +/- 8.252,5.549 +/- 8.673,5.549 +/- 8.644,4.837 +/- 7.771
3,sl_rf,4.596 +/- 7.213,4.733 +/- 7.174,4.738 +/- 7.343,4.682 +/- 7.183,5.182 +/- 8.232,5.068 +/- 8.051,5.204 +/- 8.226,5.229 +/- 8.199,5.245 +/- 8.203,5.252 +/- 8.203,4.570 +/- 7.220
4,sl_et,4.607 +/- 7.275,4.812 +/- 7.174,4.861 +/- 7.178,4.843 +/- 7.126,5.441 +/- 8.753,5.448 +/- 8.654,5.448 +/- 8.751,5.071 +/- 7.744,5.477 +/- 8.737,5.480 +/- 8.740,4.533 +/- 7.236
5,sl_kr,6.998 +/- 6.783,4.964 +/- 7.546,6.998 +/- 6.783,5.015 +/- 7.502,6.017 +/- 10.008,6.029 +/- 10.039,6.308 +/- 10.895,5.477 +/- 8.556,5.702 +/- 9.368,5.744 +/- 9.346,4.423 +/- 7.268
6,sl_cb,5.484 +/- 8.139,5.276 +/- 8.233,5.484 +/- 8.139,5.484 +/- 8.139,5.488 +/- 8.744,5.485 +/- 8.740,5.489 +/- 8.739,5.698 +/- 8.654,5.511 +/- 8.738,5.658 +/- 8.696,5.207 +/- 8.171
7,sl_lgbm,5.625 +/- 8.867,5.621 +/- 8.870,5.613 +/- 8.874,5.613 +/- 8.874,5.692 +/- 9.087,5.692 +/- 9.084,5.696 +/- 9.082,5.635 +/- 8.968,5.702 +/- 9.055,5.687 +/- 9.060,5.573 +/- 8.877
8,tl_l1,5.910 +/- 9.286,5.465 +/- 7.838,6.998 +/- 6.783,5.908 +/- 9.031,5.865 +/- 9.283,5.877 +/- 9.298,5.886 +/- 9.294,5.501 +/- 7.899,5.842 +/- 9.069,5.880 +/- 9.264,4.486 +/- 6.831
9,tl_l2,5.819 +/- 9.259,5.917 +/- 9.246,5.757 +/- 9.258,5.757 +/- 9.258,5.922 +/- 9.268,5.915 +/- 9.271,5.916 +/- 9.268,5.658 +/- 8.983,5.670 +/- 8.996,5.673 +/- 8.998,5.609 +/- 9.000


In [15]:
# Test `pehe` of the configuration picked by each selection criterion. The
# plug-in scorers here rank configurations by their `pehe` estimate.
metric, p_metric = 'pehe', 'pehe'
plugin_cols = [
    f'{metric}_{pmm}_{pbm}_{p_metric}'
    for pmm in plugin_meta_models
    for pbm in plugin_base_models
]
rscore_cols = [f'{metric}_rs_{rs_bm}' for rs_bm in rscore_base_models]
cols = ['name', f'{metric}_mse', *plugin_cols, *rscore_cols, f'{metric}_test']
df_rscores[cols]

Unnamed: 0,name,pehe_mse,pehe_sl_dt_pehe,pehe_sl_lgbm_pehe,pehe_sl_cb_pehe,pehe_tl_dt_pehe,pehe_tl_lgbm_pehe,pehe_tl_cb_pehe,pehe_rs_dt,pehe_rs_lgbm,pehe_rs_cb,pehe_test
0,sl_l1,6.998 +/- 6.783,4.609 +/- 7.793,6.998 +/- 6.783,4.955 +/- 7.092,4.418 +/- 7.337,4.417 +/- 7.338,4.418 +/- 7.337,4.747 +/- 7.199,4.443 +/- 7.326,4.443 +/- 7.326,4.414 +/- 7.336
1,sl_l2,4.647 +/- 7.271,4.425 +/- 7.371,4.647 +/- 7.271,4.628 +/- 7.218,4.409 +/- 7.320,4.409 +/- 7.320,4.409 +/- 7.320,4.651 +/- 7.274,4.462 +/- 7.297,4.528 +/- 7.266,4.405 +/- 7.317
2,sl_dt,4.875 +/- 7.756,4.917 +/- 7.738,4.870 +/- 7.756,4.871 +/- 7.755,5.265 +/- 8.279,5.075 +/- 8.012,5.284 +/- 8.287,5.406 +/- 8.252,5.549 +/- 8.673,5.549 +/- 8.644,4.837 +/- 7.771
3,sl_rf,4.596 +/- 7.213,4.687 +/- 7.362,4.669 +/- 7.363,4.622 +/- 7.224,5.046 +/- 7.987,5.044 +/- 8.162,5.115 +/- 8.220,5.229 +/- 8.199,5.245 +/- 8.203,5.252 +/- 8.203,4.570 +/- 7.220
4,sl_et,4.607 +/- 7.275,4.659 +/- 7.254,4.636 +/- 7.260,4.629 +/- 7.273,5.030 +/- 7.966,4.958 +/- 8.139,4.936 +/- 7.900,5.071 +/- 7.744,5.477 +/- 8.737,5.480 +/- 8.740,4.533 +/- 7.236
5,sl_kr,6.998 +/- 6.783,4.734 +/- 7.741,6.998 +/- 6.783,4.911 +/- 7.083,4.956 +/- 8.092,5.144 +/- 8.808,5.109 +/- 8.660,5.477 +/- 8.556,5.702 +/- 9.368,5.744 +/- 9.346,4.423 +/- 7.268
6,sl_cb,5.484 +/- 8.139,5.238 +/- 8.247,5.484 +/- 8.139,5.484 +/- 8.139,5.231 +/- 8.183,5.211 +/- 8.170,5.231 +/- 8.183,5.698 +/- 8.654,5.511 +/- 8.738,5.658 +/- 8.696,5.207 +/- 8.171
7,sl_lgbm,5.625 +/- 8.867,5.607 +/- 8.937,5.623 +/- 8.931,5.611 +/- 8.935,5.635 +/- 9.015,5.615 +/- 8.941,5.609 +/- 8.941,5.635 +/- 8.968,5.702 +/- 9.055,5.687 +/- 9.060,5.573 +/- 8.877
8,tl_l1,5.910 +/- 9.286,4.916 +/- 7.661,6.998 +/- 6.783,4.981 +/- 6.903,5.097 +/- 8.012,5.058 +/- 8.075,5.098 +/- 8.058,5.501 +/- 7.899,5.842 +/- 9.069,5.880 +/- 9.264,4.486 +/- 6.831
9,tl_l2,5.819 +/- 9.259,5.619 +/- 8.999,5.642 +/- 8.989,5.642 +/- 8.989,5.617 +/- 9.010,5.618 +/- 9.010,5.617 +/- 9.010,5.658 +/- 8.983,5.670 +/- 8.996,5.673 +/- 8.998,5.609 +/- 9.000
