In [1]:
import pandas as pd

In [2]:
def show_all(df, metric, plugin_metrics, plugin_models, matching_ks, rscore_models):
    """Return the columns of ``df`` that belong to ``metric``, in report order.

    The selected column names are, in this order:

    1. ``'name'``, ``'{metric}_mse'`` and ``'{metric}_r2'``;
    2. ``'{metric}_{model}_{plugin_metric}'`` for each entry of
       ``plugin_metrics`` (outer loop) and each entry of ``plugin_models``
       (inner loop);
    3. ``'{metric}_match_{k}k_{plugin_metric}'`` for each entry of
       ``plugin_metrics`` (outer loop) and each ``k`` in ``matching_ks``
       (inner loop);
    4. ``'{metric}_rs_{model}'`` for each entry of ``rscore_models``.

    Args:
        df: Table to select from; indexed with a list of column labels
            (a pandas DataFrame at the call sites in this notebook).
        metric: Prefix shared by all selected columns except ``'name'``.
        plugin_metrics: Suffixes appended to the plugin and matching columns.
        plugin_models: Model tags used in the plugin column names.
        matching_ks: Values of ``k`` used in the matching column names.
        rscore_models: Model tags used in the ``rs`` column names.

    Returns:
        ``df`` indexed by the list of column names described above.
    """
    selected = ['name', f'{metric}_mse', f'{metric}_r2']

    # Plugin columns: the metric suffix varies slowest, the model fastest.
    for target in plugin_metrics:
        for model in plugin_models:
            selected.append(f'{metric}_{model}_{target}')

    # Matching columns follow the same "suffix outer, k inner" ordering.
    for target in plugin_metrics:
        for k in matching_ks:
            selected.append(f'{metric}_match_{k}k_{target}')

    for base in rscore_models:
        selected.append(f'{metric}_rs_{base}')

    return df[selected]

def show_all_jobs(df, metric, plugin_metrics, plugin_models, matching_ks, rscore_models):
    """Return the columns of ``df`` for ``metric``, including its ``_pol`` column.

    Selects, in this order: ``'name'``, ``'{metric}_mse'``, ``'{metric}_r2'``
    and ``'{metric}_pol'``; then ``'{metric}_{model}_{plugin_metric}'`` for
    every plugin metric (outer) and plugin model (inner); then
    ``'{metric}_match_{k}k_{plugin_metric}'`` for every plugin metric (outer)
    and ``k`` (inner); and finally ``'{metric}_rs_{model}'`` for every entry
    of ``rscore_models``.

    Args:
        df: Table to select from; indexed with a list of column labels
            (a pandas DataFrame at the call sites in this notebook).
        metric: Prefix shared by all selected columns except ``'name'``.
        plugin_metrics: Suffixes appended to the plugin and matching columns.
        plugin_models: Model tags used in the plugin column names.
        matching_ks: Values of ``k`` used in the matching column names.
        rscore_models: Model tags used in the ``rs`` column names.

    Returns:
        ``df`` indexed by the list of column names described above.
    """
    wanted = ['name', f'{metric}_mse', f'{metric}_r2', f'{metric}_pol']
    wanted.extend(
        f'{metric}_{model}_{target}'
        for target in plugin_metrics
        for model in plugin_models
    )
    wanted.extend(
        f'{metric}_match_{k}k_{target}'
        for target in plugin_metrics
        for k in matching_ks
    )
    wanted.extend(f'{metric}_rs_{base}' for base in rscore_models)
    return df[wanted]

# IHDP

In [3]:
# Dataset tag (interpolated into the results-file name by the next cell) and
# the column prefixes of the two evaluation metrics reported for it.
ds = 'ihdp'
avg_metric, ite_metric = 'ate', 'pehe'

# Grid of selection methods whose columns are pulled from the results table.
# plugin_models is the cross product '<meta>_<base>', meta varying slowest.
plugin_meta_models = ['sl', 'tl']
plugin_base_models = ['dt', 'lgbm', 'kr']
plugin_models = [
    f'{meta}_{base}'
    for meta in plugin_meta_models
    for base in plugin_base_models
]
matching_ks = [1, 3, 5]
rscore_base_models = ['dt', 'lgbm', 'kr']

In [4]:
# Load the results table for the current dataset.
df_all = pd.read_csv(f'./tables/{ds}_compare_correlations_all_sem_latex.csv')

# Column suffixes of every selection method: plugin, matching, then rs.
selection_models = [f'{meta}_{base}_{target}' for meta in plugin_meta_models for base in plugin_base_models for target in ['ate', 'pehe']]
selection_models += [f'match_{k}k_{target}' for k in matching_ks for target in ['ate', 'pehe']]
selection_models += [f'rs_{base}' for base in rscore_base_models]

# Rename maps that drop the metric prefix, so the two per-metric views end up
# with identical labels and can be joined on them after transposing.
d_ate = {f'{avg_metric}_mse': 'mse', f'{avg_metric}_r2': 'r2'}
d_pehe = {f'{ite_metric}_mse': 'mse', f'{ite_metric}_r2': 'r2'}
for sm in selection_models:
    d_ate[f'{avg_metric}_{sm}'] = sm
    d_pehe[f'{ite_metric}_{sm}'] = sm

# One view per metric: pick its columns, strip the prefix, then transpose so
# each selection method becomes a row and each 'name' value a column.
df_ate = (
    show_all(df_all, avg_metric, ['ate', 'pehe'], plugin_models, matching_ks, rscore_base_models)
    .rename(columns=d_ate)
    .set_index('name')
    .T
)
df_pehe = (
    show_all(df_all, ite_metric, ['ate', 'pehe'], plugin_models, matching_ks, rscore_base_models)
    .rename(columns=d_pehe)
    .set_index('name')
    .T
)

# Join on the shared row labels; same-named columns get the metric suffix.
df_merged = pd.merge(
    df_ate,
    df_pehe,
    left_index=True,
    right_index=True,
    suffixes=['_ate', '_pehe'],
).reset_index()

# Display label for each method, e.g. 'sl_dt_ate' -> 'SL-DT-ATE'.
df_merged['selection'] = df_merged['index'].map(lambda label: label.upper().replace('_', '-'))

# NOTE(review): escape=False leaves the '_' in the 'all_ate'/'all_pehe'
# headers unescaped -- confirm the headers are fixed up before use in LaTeX.
print(df_merged.loc[:, ['selection', 'all_ate', 'all_pehe']].to_latex(index=False, escape=False))


\begin{tabular}{lll}
\toprule
    selection &          all_ate &         all_pehe \\
\midrule
          MSE &  $0.774\pm0.071$ &  $0.907\pm0.019$ \\
           R2 & $-0.908\pm0.019$ & $-0.737\pm0.096$ \\
    SL-DT-ATE &  $0.749\pm0.148$ &  $0.582\pm0.062$ \\
  SL-LGBM-ATE & $-0.866\pm0.056$ & $-0.458\pm0.094$ \\
    SL-KR-ATE &  $0.718\pm0.107$ &  $0.505\pm0.042$ \\
    TL-DT-ATE &  $0.954\pm0.017$ &  $0.613\pm0.049$ \\
  TL-LGBM-ATE &  $0.954\pm0.016$ &  $0.611\pm0.048$ \\
    TL-KR-ATE &  $0.945\pm0.013$ &  $0.594\pm0.045$ \\
   SL-DT-PEHE &  $0.544\pm0.101$ &  $0.712\pm0.119$ \\
 SL-LGBM-PEHE & $-0.532\pm0.085$ & $-0.080\pm0.083$ \\
   SL-KR-PEHE &  $0.408\pm0.076$ &  $0.635\pm0.115$ \\
   TL-DT-PEHE &  $0.640\pm0.060$ &  $0.810\pm0.061$ \\
 TL-LGBM-PEHE &  $0.620\pm0.064$ &  $0.786\pm0.071$ \\
   TL-KR-PEHE &  $0.616\pm0.048$ &  $0.883\pm0.023$ \\
 MATCH-1K-ATE &  $0.954\pm0.017$ &  $0.599\pm0.054$ \\
 MATCH-3K-ATE &  $0.956\pm0.015$ &  $0.608\pm0.050$ \\
 MATCH-5K-ATE &  $0.957\pm

# Jobs

In [5]:
# Dataset tag (interpolated into the results-file name by the next cell) and
# the column prefixes of the two evaluation metrics reported for it.
ds = 'jobs'
avg_metric, ite_metric = 'att', 'policy'

# Grid of selection methods whose columns are pulled from the results table.
# plugin_models is the cross product '<meta>_<base>', meta varying slowest.
plugin_meta_models = ['sl', 'tl']
plugin_base_models = ['dt', 'lgbm', 'kr']
plugin_models = [
    f'{meta}_{base}'
    for meta in plugin_meta_models
    for base in plugin_base_models
]
matching_ks = [1, 3, 5]
rscore_base_models = ['dt', 'lgbm', 'kr']

In [6]:
# Load the results table for the current dataset.
df_all = pd.read_csv(f'./tables/{ds}_compare_correlations_all_sem_latex.csv')

# Column suffixes of every selection method: plugin, matching, then rs.
selection_models = [f'{meta}_{base}_{target}' for meta in plugin_meta_models for base in plugin_base_models for target in ['ate', 'pehe']]
selection_models += [f'match_{k}k_{target}' for k in matching_ks for target in ['ate', 'pehe']]
selection_models += [f'rs_{base}' for base in rscore_base_models]

# Rename maps that drop the metric prefix (this dataset also has a '_pol'
# column), so both per-metric views end up with identical labels.
d_ate = {f'{avg_metric}_mse': 'mse', f'{avg_metric}_r2': 'r2', f'{avg_metric}_pol': 'pol'}
d_pehe = {f'{ite_metric}_mse': 'mse', f'{ite_metric}_r2': 'r2', f'{ite_metric}_pol': 'pol'}
for sm in selection_models:
    d_ate[f'{avg_metric}_{sm}'] = sm
    d_pehe[f'{ite_metric}_{sm}'] = sm

# One view per metric: pick its columns, strip the prefix, then transpose so
# each selection method becomes a row and each 'name' value a column.
# The df_ate / df_pehe names are kept from the other dataset cells; here they
# hold the avg_metric and ite_metric views respectively.
df_ate = (
    show_all_jobs(df_all, avg_metric, ['ate', 'pehe'], plugin_models, matching_ks, rscore_base_models)
    .rename(columns=d_ate)
    .set_index('name')
    .T
)
df_pehe = (
    show_all_jobs(df_all, ite_metric, ['ate', 'pehe'], plugin_models, matching_ks, rscore_base_models)
    .rename(columns=d_pehe)
    .set_index('name')
    .T
)

# Join on the shared row labels; same-named columns get the metric suffix.
df_merged = pd.merge(
    df_ate,
    df_pehe,
    left_index=True,
    right_index=True,
    suffixes=['_att', '_pol'],
).reset_index()

# Display label for each method, e.g. 'sl_dt_ate' -> 'SL-DT-ATE'.
df_merged['selection'] = df_merged['index'].map(lambda label: label.upper().replace('_', '-'))

# NOTE(review): escape=False leaves the '_' in the 'all_att'/'all_pol'
# headers unescaped -- confirm the headers are fixed up before use in LaTeX.
print(df_merged.loc[:, ['selection', 'all_att', 'all_pol']].to_latex(index=False, escape=False))


\begin{tabular}{lll}
\toprule
    selection &          all_att &          all_pol \\
\midrule
          MSE &  $0.051\pm0.007$ &  $0.080\pm0.020$ \\
           R2 & $-0.048\pm0.006$ & $-0.079\pm0.020$ \\
          POL &  $0.023\pm0.005$ &  $0.289\pm0.074$ \\
    SL-DT-ATE &  $0.530\pm0.073$ &  $0.014\pm0.008$ \\
  SL-LGBM-ATE &  $0.530\pm0.073$ &  $0.014\pm0.008$ \\
    SL-KR-ATE &  $0.530\pm0.073$ &  $0.014\pm0.008$ \\
    TL-DT-ATE &  $0.530\pm0.073$ &  $0.014\pm0.008$ \\
  TL-LGBM-ATE &  $0.530\pm0.073$ &  $0.014\pm0.008$ \\
    TL-KR-ATE &  $0.530\pm0.073$ &  $0.014\pm0.008$ \\
   SL-DT-PEHE &  $0.505\pm0.064$ &  $0.011\pm0.006$ \\
 SL-LGBM-PEHE &  $0.505\pm0.064$ &  $0.011\pm0.006$ \\
   SL-KR-PEHE &  $0.505\pm0.064$ &  $0.011\pm0.006$ \\
   TL-DT-PEHE &  $0.505\pm0.064$ &  $0.011\pm0.006$ \\
 TL-LGBM-PEHE &  $0.505\pm0.064$ &  $0.011\pm0.006$ \\
   TL-KR-PEHE &  $0.505\pm0.064$ &  $0.011\pm0.006$ \\
 MATCH-1K-ATE &  $0.530\pm0.073$ &  $0.014\pm0.008$ \\
 MATCH-3K-ATE &  $0.530\pm

# Twins

In [7]:
# Dataset tag (interpolated into the results-file name by the next cell) and
# the column prefixes of the two evaluation metrics reported for it.
ds = 'twins'
avg_metric, ite_metric = 'ate', 'pehe'

# Grid of selection methods whose columns are pulled from the results table.
# plugin_models is the cross product '<meta>_<base>', meta varying slowest.
plugin_meta_models = ['sl', 'tl']
plugin_base_models = ['dt', 'lgbm', 'kr']
plugin_models = [
    f'{meta}_{base}'
    for meta in plugin_meta_models
    for base in plugin_base_models
]
matching_ks = [1, 3, 5]
rscore_base_models = ['dt', 'lgbm', 'kr']

In [8]:
# Load the results table for the current dataset.
df_all = pd.read_csv(f'./tables/{ds}_compare_correlations_all_sem_latex.csv')

# Column suffixes of every selection method: plugin, matching, then rs.
selection_models = [f'{meta}_{base}_{target}' for meta in plugin_meta_models for base in plugin_base_models for target in ['ate', 'pehe']]
selection_models += [f'match_{k}k_{target}' for k in matching_ks for target in ['ate', 'pehe']]
selection_models += [f'rs_{base}' for base in rscore_base_models]

# Rename maps that drop the metric prefix, so the two per-metric views end up
# with identical labels and can be joined on them after transposing.
d_ate = {f'{avg_metric}_mse': 'mse', f'{avg_metric}_r2': 'r2'}
d_pehe = {f'{ite_metric}_mse': 'mse', f'{ite_metric}_r2': 'r2'}
for sm in selection_models:
    d_ate[f'{avg_metric}_{sm}'] = sm
    d_pehe[f'{ite_metric}_{sm}'] = sm

# One view per metric: pick its columns, strip the prefix, then transpose so
# each selection method becomes a row and each 'name' value a column.
df_ate = (
    show_all(df_all, avg_metric, ['ate', 'pehe'], plugin_models, matching_ks, rscore_base_models)
    .rename(columns=d_ate)
    .set_index('name')
    .T
)
df_pehe = (
    show_all(df_all, ite_metric, ['ate', 'pehe'], plugin_models, matching_ks, rscore_base_models)
    .rename(columns=d_pehe)
    .set_index('name')
    .T
)

# Join on the shared row labels; same-named columns get the metric suffix.
df_merged = pd.merge(
    df_ate,
    df_pehe,
    left_index=True,
    right_index=True,
    suffixes=['_ate', '_pehe'],
).reset_index()

# Display label for each method, e.g. 'sl_dt_ate' -> 'SL-DT-ATE'.
df_merged['selection'] = df_merged['index'].map(lambda label: label.upper().replace('_', '-'))

# NOTE(review): escape=False leaves the '_' in the 'all_ate'/'all_pehe'
# headers unescaped -- confirm the headers are fixed up before use in LaTeX.
print(df_merged.loc[:, ['selection', 'all_ate', 'all_pehe']].to_latex(index=False, escape=False))


\begin{tabular}{lll}
\toprule
    selection &          all_ate &         all_pehe \\
\midrule
          MSE &  $0.080\pm0.008$ &  $0.968\pm0.011$ \\
           R2 & $-0.076\pm0.007$ & $-0.969\pm0.010$ \\
    SL-DT-ATE &  $0.286\pm0.068$ &  $0.520\pm0.039$ \\
  SL-LGBM-ATE &  $0.309\pm0.062$ &  $0.538\pm0.030$ \\
    SL-KR-ATE &  $0.316\pm0.070$ &  $0.535\pm0.032$ \\
    TL-DT-ATE &  $0.283\pm0.048$ &  $0.541\pm0.028$ \\
  TL-LGBM-ATE &  $0.294\pm0.056$ &  $0.541\pm0.030$ \\
    TL-KR-ATE &  $0.288\pm0.054$ &  $0.546\pm0.024$ \\
   SL-DT-PEHE &  $0.071\pm0.010$ &  $0.656\pm0.033$ \\
 SL-LGBM-PEHE &  $0.070\pm0.009$ &  $0.657\pm0.033$ \\
   SL-KR-PEHE &  $0.071\pm0.009$ &  $0.657\pm0.033$ \\
   TL-DT-PEHE &  $0.068\pm0.008$ &  $0.658\pm0.036$ \\
 TL-LGBM-PEHE &  $0.067\pm0.008$ &  $0.657\pm0.036$ \\
   TL-KR-PEHE &  $0.069\pm0.008$ &  $0.658\pm0.034$ \\
 MATCH-1K-ATE &  $0.271\pm0.048$ &  $0.541\pm0.029$ \\
 MATCH-3K-ATE &  $0.274\pm0.049$ &  $0.541\pm0.029$ \\
 MATCH-5K-ATE &  $0.274\pm