In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
def show_all(df, metric, plugin_metrics, plugin_models, matching_ks, rscore_models):
    """Return the columns of ``df`` that belong to one ``metric``, in a fixed order.

    The selected columns are, in order:

    1. ``name``, ``{metric}_mse``, ``{metric}_r2``
    2. ``{metric}_{model}_{plugin_metric}`` for every plug-in metric, then
       every plug-in model
    3. ``{metric}_match_{k}k_{plugin_metric}`` for every plug-in metric, then
       every ``k``
    4. ``{metric}_rs_{model}`` for every R-score model
    5. ``{metric}_val``

    Parameters
    ----------
    df : pandas.DataFrame
        Table that contains all of the columns listed above.
    metric : str
        Prefix shared by the requested columns.
    plugin_metrics, plugin_models, matching_ks, rscore_models : iterable
        Name fragments used to build the column names.

    Returns
    -------
    pandas.DataFrame
        ``df`` restricted to the columns above (``KeyError`` if any is missing).
    """
    selected = ['name', f'{metric}_mse', f'{metric}_r2']

    # Plug-in columns: the plug-in metric varies slowest, the model fastest.
    for target in plugin_metrics:
        for model in plugin_models:
            selected.append(f'{metric}_{model}_{target}')

    # Matching columns: the plug-in metric varies slowest, k fastest.
    for target in plugin_metrics:
        for k in matching_ks:
            selected.append(f'{metric}_match_{k}k_{target}')

    for base in rscore_models:
        selected.append(f'{metric}_rs_{base}')

    selected.append(f'{metric}_val')
    return df[selected]

def show_all_jobs(df, metric, plugin_metrics, plugin_models, matching_ks, rscore_models):
    """Select one metric's columns from ``df``, including the ``{metric}_pol`` column.

    Column order of the result:

    * ``name``, ``{metric}_mse``, ``{metric}_r2``, ``{metric}_pol``
    * ``{metric}_{model}_{plugin_metric}`` (plug-in metric outer, model inner)
    * ``{metric}_match_{k}k_{plugin_metric}`` (plug-in metric outer, ``k`` inner)
    * ``{metric}_rs_{model}``
    * ``{metric}_val``

    Parameters
    ----------
    df : pandas.DataFrame
        Table that contains all of the columns listed above.
    metric : str
        Prefix shared by the requested columns.
    plugin_metrics, plugin_models, matching_ks, rscore_models : iterable
        Name fragments used to build the column names.

    Returns
    -------
    pandas.DataFrame
        ``df`` restricted to those columns, in that order.
    """
    ordered = ['name', f'{metric}_mse', f'{metric}_r2', f'{metric}_pol']
    ordered.extend(
        f'{metric}_{learner}_{criterion}'
        for criterion in plugin_metrics
        for learner in plugin_models
    )
    ordered.extend(
        f'{metric}_match_{k}k_{criterion}'
        for criterion in plugin_metrics
        for k in matching_ks
    )
    ordered.extend(f'{metric}_rs_{base}' for base in rscore_models)
    ordered.append(f'{metric}_val')
    return df[ordered]

In [None]:
def metric_name(x):
    """Map a row's ``selection`` tag to the label used for that selection metric.

    Parameters
    ----------
    x : mapping or pandas.Series
        Must provide a string under the key ``'selection'``.

    Returns
    -------
    str
        * ``'MSE'``  -> ``$\\mu{-risk}$``
        * ``'R2'``   -> ``$\\mu{-risk}_R$``
        * ``'BEST'`` -> ``Oracle``
        * ``'POL'``  -> ``$\\mathcal{R}_{pol}$``
        * anything else -> ``$\\tau{-risk}_{<kind>}^{<target>}$`` where
          ``<kind>`` is ``R`` (tag starts with ``RS``), ``match`` (starts with
          ``MATCH``) or ``plug`` (otherwise), and the ``^{<target>}`` part is
          only present when the tag ends with ``ATE`` or ``PEHE``.
    """
    selection = x['selection']

    # Raw strings: "\m" is not a valid escape sequence in a normal literal.
    fixed_labels = {
        'MSE': r"$\mu{-risk}$",
        'R2': r"$\mu{-risk}_R$",
        'BEST': "Oracle",
        'POL': r"$\mathcal{R}_{pol}$",
    }
    if selection in fixed_labels:
        return fixed_labels[selection]

    if selection.startswith('RS'):
        subscript = "_{R}"
    elif selection.startswith('MATCH'):
        subscript = "_{match}"
    else:
        subscript = "_{plug}"

    if selection.endswith('ATE'):
        superscript = "^{ATE}"
    elif selection.endswith('PEHE'):
        superscript = "^{PEHE}"
    else:
        superscript = ""

    # The closing "$" is appended exactly once, so the math expression is
    # always terminated regardless of which prefix/suffix combination matched.
    return r"$\tau{-risk}" + subscript + superscript + "$"

In [None]:
# Name fragments used to build the metric column names of the results tables.
# NOTE(review): presumably 'sl'/'tl' are S-/T-learner meta-models and
# 'dt'/'lgbm'/'kr' are the base learners -- confirm against the code that
# produced the tables.
plugin_meta_models = ['sl', 'tl']
plugin_base_models = ['dt', 'lgbm', 'kr']

# Every meta/base combination as '<meta>_<base>' (meta varies slowest).
plugin_models = [
    '_'.join((meta, base))
    for meta in plugin_meta_models
    for base in plugin_base_models
]

# Values of k that appear in the 'match_{k}k' column names.
matching_ks = [1, 3, 5]

# Base models that appear in the 'rs_{model}' column names.
rscore_base_models = ['dt', 'lgbm', 'kr']

In [None]:
def plot_metrics(ds, avg_metric, ite_metric, ax=None, legend=False, s=40):
    """Scatter-plot every selection strategy by the two error values it achieves.

    Reads ``./tables/{ds}_compare_metrics_all_val_raw.csv``, pulls out the
    columns prefixed with ``avg_metric`` and those prefixed with
    ``ite_metric``, reshapes them so each selection strategy becomes one row,
    and draws one point per strategy with x = ``all_ate`` and y = ``all_pehe``.

    Depends on the module-level lists ``plugin_meta_models``,
    ``plugin_base_models``, ``plugin_models``, ``matching_ks`` and
    ``rscore_base_models`` and on the helpers ``show_all``, ``metric_name``
    and ``base_name``.
    NOTE(review): ``base_name`` is not defined in the visible part of this
    notebook -- confirm it is defined before this function is called.

    Parameters
    ----------
    ds : str
        Dataset tag used to build the CSV file name.
    avg_metric : str
        Column prefix whose values end up on the x axis.
    ite_metric : str
        Column prefix whose values end up on the y axis.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. ``None`` lets seaborn use the current Axes.
    legend : bool or str, optional
        Forwarded to ``sns.scatterplot``.
    s : number, optional
        Marker size, forwarded to ``sns.scatterplot``.
    """
    df_all = pd.read_csv(f'./tables/{ds}_compare_metrics_all_val_raw.csv')

    selection_targets = ['ate', 'pehe']
    df_ate = show_all(df_all, avg_metric, selection_targets, plugin_models, matching_ks, rscore_base_models)
    df_pehe = show_all(df_all, ite_metric, selection_targets, plugin_models, matching_ks, rscore_base_models)

    # Column-name suffixes (metric prefix stripped) identifying each strategy.
    selection_models = (
        [f'{pmm}_{pbm}_{pm}' for pmm in plugin_meta_models for pbm in plugin_base_models for pm in selection_targets]
        + [f'match_{k}k_{pm}' for k in matching_ks for pm in selection_targets]
        + [f'rs_{rbm}' for rbm in rscore_base_models]
    )

    # Strip the metric prefix from every column so both frames share the same
    # labels; '<metric>_val' is relabelled 'best'.
    suffix_to_label = {'val': 'best', 'mse': 'mse', 'r2': 'r2'}
    suffix_to_label.update({sm: sm for sm in selection_models})
    d_ate = {f'{avg_metric}_{suffix}': label for suffix, label in suffix_to_label.items()}
    d_pehe = {f'{ite_metric}_{suffix}': label for suffix, label in suffix_to_label.items()}

    # One row per selection strategy, one column per value of 'name'.
    df_ate = df_ate.rename(columns=d_ate).set_index('name').T
    df_pehe = df_pehe.rename(columns=d_pehe).set_index('name').T

    df_final = df_ate.merge(df_pehe, left_index=True, right_index=True, suffixes=['_ate', '_pehe']).reset_index()

    # Add an empty policy row for compatibility with Jobs dataset.
    # The fractional label 1.5 places it between rows 1 and 2 after sorting.
    # The three values assume the frame has exactly three columns here
    # ('index', 'all_ate', 'all_pehe'), i.e. the table's only 'name' is 'all'.
    df_final.loc[1.5] = 'pol', None, None
    df_final = df_final.sort_index().reset_index(drop=True)

    df_final['selection'] = df_final['index'].apply(lambda x: x.upper().replace('_', '-'))

    df_final['metric'] = df_final.apply(metric_name, axis=1)
    df_final['learner'] = df_final.apply(base_name, axis=1)

    # seaborn falls back to the current Axes when ax is None, so a single call
    # covers both the "own figure" and the "caller-supplied Axes" cases.
    sns.scatterplot(data=df_final, x='all_ate', y='all_pehe', hue='metric', style='learner', ax=ax, legend=legend, s=s)

In [2]:
# Dataset tag: selects which './tables/{ds}_...csv' file is read.
ds = 'ihdp'
# Column prefixes of the two metrics being compared.
avg_metric, ite_metric = 'ate', 'pehe'

# Plot call left disabled by the author:
#plot_metrics(ds, avg_metric, ite_metric, legend=True)

In [3]:
# Load the comparison table for dataset `ds` (the same file that
# plot_metrics reads) so it can be inspected directly.
df_all = pd.read_csv(f'./tables/{ds}_compare_metrics_all_val_raw.csv')

In [5]:
# Transposed view: every column of df_all is displayed as a row.
df_all.T

Unnamed: 0,0
name,all
ate_val,0.173588
pehe_val,0.64219
ate_mse,3.452451
pehe_mse,4.941296
ate_r2,3.808946
pehe_r2,6.083907
ate_sl_dt_ate,1.455495
pehe_sl_dt_ate,3.326862
ate_sl_dt_pehe,1.655077
