In [1]:
import copy
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats as stats
import seaborn as sns
from IPython.display import display, HTML
from natsort import natsorted
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error
from tqdm import tqdm


#### Notebook Structure:

The first two code cells import the required libraries and define helper functions for parsing method names and compiling the results.

The next non-empty cell loads the results for each of the required datasets.

We then extract the results for subsets of methods, starting with U-Base (no update step) and then for other Q/G/update-step combinations that the user can specify.

Code in this notebook was used to generate Table 5 in the appendix.

In [28]:
# functions for renaming methods and compiling the results

def make_folder(path):
    """Create directory `path`, including any missing parents.

    Nothing happens when `path` already exists (whether as a file or a directory).
    """
    if os.path.exists(path):
        return
    os.makedirs(path)
            
def get_vars_method(method):
    """Encode an underscore-separated method name as one-hot and categorical features.

    Parameters
    ----------
    method : str
        Method identifier such as ``'t_mn_x_submodel'``. Tokens are read
        positionally: token 0 names the Q model, token 1 the G model (or, when it
        is ``'learner'``/``'var'``, token 2 does), and later tokens describe the
        update step. Absent trailing tokens are treated as ``'missing'``.

    Returns
    -------
    QGU : np.ndarray, shape (25,)
        One-hot flags: 10 Q-model slots, 7 G-model slots, 8 update-step slots,
        in the order given by ``cols_all``.
    QGU_cat : np.ndarray, shape (3,)
        Category codes ``[Q, G, U]``: 1-based slot of the set flag within its
        group, or 0 when the token was not recognised.
    cols_all : list of str
        Column names for ``QGU``. A fresh list on every call (callers extend it).
    cols_cat : list of str
        Column names for ``QGU_cat``. A fresh list on every call.
    """
    # token -> slot index within each one-hot group; the category code is slot + 1
    q_slots = {'cfr': 0, 'd': 1, 'dnotreg': 2, 'mn': 3, 'sl': 4,
               'tvae': 5, 'lr': 6, 't': 7, 's': 8, 'dml': 9}
    g_slots = {'cfr': 0, 'dp': 1, 'dpnotreg': 2, 'mn': 3, 'sl': 4, 'p': 5, 'lr': 6}
    step_slots = {'submodel': 0, 'onestep': 1}
    # multistep variants, keyed by the tokens that follow 'multi'
    nonlin_slots = {('y', 'var'): 2, ('y', 'meanvar'): 3,
                    ('noy', 'var'): 4, ('noy', 'meanvar'): 5}
    linear_slots = {'var': 6, 'meanvar': 7}

    Q = np.zeros(10)
    G = np.zeros(7)
    U = np.zeros(8)

    parts = method.split('_')
    first_name = parts[0]
    second_name = parts[1]
    # pad so absent trailing tokens read as 'missing' (no try/except around indexing)
    third_attr, fourth_attr, fifth_attr, sixth_attr, sev_attr, eight_attr = (
        parts[2:8] + ['missing'] * 6)[:6]

    Q_cat = 0.0
    q_slot = q_slots.get(first_name)
    if q_slot is not None:
        Q[q_slot] = 1.0
        Q_cat = float(q_slot + 1)

    # 'learner'/'var' in second position means the G model is named by the third token
    if second_name in ('learner', 'var'):
        second_name = third_attr
    G_cat = 0.0
    g_slot = g_slots.get(second_name)
    if g_slot is not None:
        G[g_slot] = 1.0
        G_cat = float(g_slot + 1)

    # a 'learner' marker in fourth position shifts the update tokens left by one
    if fourth_attr == 'learner':
        fourth_attr, fifth_attr, sixth_attr, sev_attr = (
            fifth_attr, sixth_attr, sev_attr, eight_attr)

    u_slot = step_slots.get(fourth_attr)
    if fourth_attr == 'multi':
        if fifth_attr == 'nonlin':
            u_slot = nonlin_slots.get((sixth_attr, sev_attr))
        elif fifth_attr == 'linear':
            u_slot = linear_slots.get(sixth_attr)
    U_cat = 0.0
    if u_slot is not None:
        U[u_slot] = 1.0
        U_cat = float(u_slot + 1)

    cols_all = ['Q-CFR', 'Q-D', 'Q-Dnotreg', 'Q-MN', 'Q-SL', 'Q-TVAE', 'Q-LR', 'Q-T', 'Q-S', 'DML',
                'G-CFR', 'G-D', 'G-Dnotreg', 'G-MN', 'G-SL', 'G-P', 'G-LR', 'U-sub', 'U-ones',
                'U-multi-nonlin-fqh-var', 'U-multi-nonlin-fqh-meanvar', 'U-multi-nonlin-fh-var',
                'U-multi-nonlin-fh-meanvar', 'U-multi-lin-var', 'U-multi-lin-meanvar']
    cols_cat = ['QModel', 'GModel', 'UpdateStep']
    QGU = np.concatenate([Q, G, U])
    QGU_cat = np.array([Q_cat, G_cat, U_cat], dtype=float)

    return QGU, QGU_cat, cols_all, cols_cat

def rename_methods(method_names):
    """Convert encoded method identifiers into human-readable labels.

    Each name is split on '_' and read positionally, in the same way as
    ``get_vars_method``: a Q-model token, a G-model token (taken from the third
    token when the second is ``'learner'``/``'var'``), then update-step tokens.

    Parameters
    ----------
    method_names : iterable of str

    Returns
    -------
    list of str
        One ``'<Q label> <G label> <Update label>'`` string per input name. When
        the update is the base model, the G label is a single space.

    Notes
    -----
    Every label is recomputed for each name. An unrecognised token falls back
    to ``'Q: <token>'`` / ``'G: <token>'`` / ``'Update: <token>'`` instead of
    silently reusing the previous method's label.
    """
    q_labels = {
        'cfr': 'Q: CFR',
        'd': 'Q: Dnet',
        'dnotreg': 'Q: Dnet (no treg)',
        'mn': 'Q: MN',
        'sl': 'Q: SL',
        'tvae': 'Q: TVAE',
        'lr': 'Q: LR',
        't': 'Q: T-learn',
        's': 'Q: S-learn',
        'dml': 'Q: DML',
    }
    g_labels = {
        'dp': 'G: Dnet',
        'mn': 'G: MN',
        'dpnotreg': 'G: Dnet (no treg)',
        'cfr': 'G: CFR',
        'lr': 'G: LR',
        'p': 'G: P-Learn',
        'sl': 'G: SL',
    }
    step_labels = {
        'multi': 'Update: Multistep',
        'submodel': 'Update: Submodel',
        'onestep': 'Update: Onestep',
    }
    nonlin_labels = {
        ('y', 'var'): 'Update: Nonlinear Multistep f(Q,H) w/ var penalty',
        ('y', 'meanvar'): 'Update: Nonlinear Multistep f(Q,H) w/ mean+var penalty',
        ('noy', 'var'): 'Update: Nonlinear Multistep f(H) w/ var penalty',
        ('noy', 'meanvar'): 'Update: Nonlinear Multistep f(H) w/ mean+var penalty',
    }
    linear_labels = {
        'var': 'Update: Linear Multistep w/ var penalty',
        'meanvar': 'Update: Linear Multistep w/ mean+var penalty',
    }

    new_methods = []
    for method in method_names:
        parts = method.split('_')
        first_name = parts[0]
        second_name = parts[1]
        # pad so absent trailing tokens read as 'missing'
        third_attr, fourth_attr, fifth_attr, sixth_attr, sev_attr, eight_attr = (
            parts[2:8] + ['missing'] * 6)[:6]

        new_first_name = q_labels.get(first_name, 'Q: ' + first_name)

        if second_name in ('learner', 'var'):
            second_name = third_attr
        new_second_name = g_labels.get(second_name, 'G: ' + second_name)

        # reset per method so labels never leak from the previous iteration
        update_attr = None
        if third_attr in ('missing', 'var') or fourth_attr == 'missing':
            new_second_name = ' '
            update_attr = 'Update: Base'

        # a 'learner' marker in fourth position shifts the update tokens left by one
        if fourth_attr == 'learner':
            fourth_attr, fifth_attr, sixth_attr, sev_attr = (
                fifth_attr, sixth_attr, sev_attr, eight_attr)

        update_attr = step_labels.get(fourth_attr, update_attr)
        if fourth_attr == 'multi':
            if fifth_attr == 'nonlin':
                update_attr = nonlin_labels.get((sixth_attr, sev_attr), update_attr)
            elif fifth_attr == 'linear':
                update_attr = linear_labels.get(sixth_attr, update_attr)

        if update_attr is None:
            update_attr = 'Update: ' + fourth_attr

        new_methods.append(new_first_name + ' ' + new_second_name + ' ' + update_attr)

    return new_methods

def get_files(folder, prefix):
    """Return the entry names in `folder` that contain `prefix`, in natural sort order.

    Note: this is a substring match, not a starts-with match.
    """
    matching = [name for name in os.listdir(folder) if prefix in name]
    return natsorted(matching)

def get_results(ds_prefix, results_dir='results/'):
    """Load every results CSV for one dataset prefix and summarise each method.

    Each CSV is read once. Every column except ``'measurement'`` is treated as
    a method. Method order follows first appearance across the (naturally
    sorted) files, so the output is deterministic across kernel sessions.

    Parameters
    ----------
    ds_prefix : str
        Substring that selects this run's files in `results_dir` (via ``get_files``).
    results_dir : str, optional
        Folder holding the result CSVs. Defaults to ``'results/'``.

    Returns
    -------
    performance_results : dict
        ``{method + '_var': stats}`` for every method except ``'true_ate'``. Each
        ``stats`` dict has keys ``'aeate'``, ``'aeate_std'``, ``'ate_std'``,
        ``'all_ests'`` (ATE estimates with NaNs imputed by ``nanmean_``),
        ``'gt_ate'`` and ``'p'`` (Shapiro-Wilk p-value of the ATE estimates).
    all_results : dict
        ``{'aeATE': {method: [...]}, 'ATE': {method: [...]}}``, with one value per
        file in which the method appears.
    """
    results_fs = get_files(results_dir, ds_prefix)

    # collect into single dictionary, reading each file exactly once
    all_results_aeATE = {}
    all_results_ATE = {}
    for f in tqdm(results_fs):
        data = pd.read_csv(os.path.join(results_dir, f))
        for method in data.columns:
            if method == 'measurement':
                continue
            values = data[method].values
            # positional rows: row 3 is recorded as the aeATE, row 2 as the ATE estimate
            all_results_aeATE.setdefault(method, []).append(values[3])
            all_results_ATE.setdefault(method, []).append(values[2])

    all_results = {'aeATE': all_results_aeATE, 'ATE': all_results_ATE}

    # the gt ATE gets stored twice (before and after update) so we only need every other value
    gt_ATE = all_results_aeATE['true_ate'][::2]
    method_names = [m for m in all_results_aeATE if m != 'true_ate']

    performance_results = {}
    for method in tqdm(method_names):
        method_results = np.asarray(all_results_ATE[method])
        method_aeATEs = np.asarray(all_results_aeATE[method])
        performance_results[method + '_var'] = {
            'aeate': method_aeATEs.mean(),
            'aeate_std': method_aeATEs.std(),
            'ate_std': method_results.std(),
            # nanmean_ mutates its argument, so give it a copy
            'all_ests': nanmean_(method_results.copy()),
            # independent list per method, as before
            'gt_ate': gt_ATE[:],
            'p': stats.shapiro(method_results).pvalue,
        }
    return performance_results, all_results

def nanmean_(array):
    """Replace NaNs in `array` in place with the mean of its non-NaN entries.

    Returns the same (mutated) array. The function returns early when there is
    nothing to impute:
    - with no NaNs, the array is unchanged;
    - with only NaNs there is no mean, so the array is also unchanged. In that
      case the original code called ``np.nanmean`` on an empty slice, which
      emits a RuntimeWarning (apparently the warning echoed in the loading
      cell's output).
    """
    nan_mask = np.isnan(array)
    if nan_mask.any() and not nan_mask.all():
        array[nan_mask] = np.nanmean(array)
    return array

def bootstrapper(gt, target, subsample_size=50, metric='mse', n_boot=5000, seed=None):
    """Bootstrap an error metric between ground-truth values and estimates.

    Draws `n_boot` resamples of `subsample_size` paired indices (with
    replacement) and evaluates `metric` on each resample.

    Parameters
    ----------
    gt, target : array-like
        Paired ground-truth and estimated values, of equal length.
    subsample_size : int
        Number of pairs drawn per bootstrap replicate.
    metric : {'mse', 'rmse', 'mae'}
        Error metric evaluated on each replicate.
    n_boot : int
        Number of bootstrap replicates (previously hard-coded to 5000).
    seed : int or None
        If given, resamples come from ``np.random.default_rng(seed)`` and are
        reproducible. ``None`` draws from NumPy's global random state, as before.

    Returns
    -------
    (float, float)
        Mean and standard deviation of the metric across replicates.

    Raises
    ------
    ValueError
        If `metric` is not one of the supported names. Previously an unknown
        metric silently produced ``(nan, nan)``.
    """
    metric_fns = {
        'mse': mean_squared_error,
        # `squared=False` was deprecated in scikit-learn 1.4 and removed in 1.6;
        # taking the square root of the MSE gives the same quantity.
        'rmse': lambda y_true, y_pred: np.sqrt(mean_squared_error(y_true, y_pred)),
        'mae': mean_absolute_error,
    }
    if metric not in metric_fns:
        raise ValueError(f"unknown metric {metric!r}; expected one of {sorted(metric_fns)}")
    score = metric_fns[metric]

    gt = np.asarray(gt)
    # convert target too: fancy-indexing a plain list with an index array fails
    target = np.asarray(target)
    indexes = np.arange(len(gt))
    sampler = np.random if seed is None else np.random.default_rng(seed)

    results = []
    for _ in range(n_boot):
        index = sampler.choice(indexes, subsample_size)
        results.append(score(gt[index], target[index]))

    results = np.asarray(results)
    return results.mean(), results.std()

In [29]:
# compile the results into a dictionary using the prefix to the results (as set in main.sh)
# table label -> file-name prefix of that dataset's result files
DATASET_PREFIXES = {
    'LF (v1) n=500': 'RUN_all_synth1_TEST500_',
    'LF (v1) n=5000': 'RUN_all_synth1_TEST5000_',
    'LF (v1) n=10000': 'RUN_all_synth1_TEST10000_',
    'LF (v2) n=500': 'RUN_all_synth2_TEST500_',
    'LF (v2) n=5000': 'RUN_all_synth2_TEST5000_',
    'LF (v2) n=10000': 'RUN_all_synth2_TEST10000_',
    'IHDP n=747': 'RUN_all_IHDP_TEST500_',
    'Gen n=500': 'RUN_all_general_TEST500_',
    'Gen n=5000': 'RUN_all_general_TEST5000_',
    'Gen n=10000': 'RUN_all_general_TEST10000_',
}

all_results = {}
for ds_name, ds_prefix in DATASET_PREFIXES.items():
    # all_ds_results ends up holding only the last dataset's raw results, as before
    performance, all_ds_results = get_results(ds_prefix)
    all_results[ds_name] = performance

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 106.96it/s]
  nan_mean = np.nanmean(array)
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 514/514 [00:00<00:00, 7362.15it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 108.76it/s]
  nan_mean = np.nanmean(array)
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 514/514 [00:00<00:00, 7070.52it/s]
100%|█

In [43]:
# save compiled results using pickle and load them back in to check it worked
RESULTS_PICKLE = "all_results_w_bootstraps_forDML.pkl"

with open(RESULTS_PICKLE, "wb") as a_file:
    pickle.dump(all_results, a_file)

# only unpickle files this notebook wrote itself: pickle.load can execute arbitrary code
with open(RESULTS_PICKLE, "rb") as a_file:
    reloaded_results = pickle.load(a_file)

assert reloaded_results.keys() == all_results.keys(), "pickle round-trip lost datasets"
all_results = reloaded_results

In [56]:
# get extra results for DML comparison, starting with U-base (no update step)
outcome_cols = ['p', 'aeate', 'ate_std']
result_columns = ['p', 'aeate', 'ate_std']
ds_list = ['IHDP n=747']  # other options: 'LF (v2) n=10000', 'Gen n=10000'
# update variants left out of the comparison table
dropped_u_cols = ['U-multi-nonlin-fqh-var', 'U-multi-lin-var',
                  'U-multi-nonlin-fqh-meanvar', 'U-multi-nonlin-fh-var']

# one row per method: one-hot Q/G/U flags, dataset index, then the outcomes
dataset_rows = []
for ds_index, ds in enumerate(ds_list):
    for method, result in all_results[ds].items():
        var, _, cols, _ = get_vars_method(method)
        outcomes = [result['p'], result['aeate'], result['ate_std']]
        dataset_rows.append(np.concatenate([var, [ds_index], outcomes]))
cols += ['dataset'] + outcome_cols

dataset = pd.DataFrame(np.asarray(dataset_rows), columns=cols).drop(columns=dropped_u_cols)

Q_table = ['Q-MN', 'Q-SL', 'Q-TVAE', 'Q-T', 'DML']

# 'DML' is a Q-model flag even though its column name has no 'Q-' prefix
all_Q_cols = [c for c in dataset.columns if 'Q-' in c] + ['DML']
all_G_cols = [c for c in dataset.columns if 'G-' in c]
all_U_cols = [c for c in dataset.columns if 'U-' in c]

# U-base row for a Q model: its own flag is set and every other Q/G/U flag is zero.
# NOTE(review): the dropped U columns are not part of this check; presumably such
# rows always carry a G flag and so are excluded anyway. Verify.
U_base_results = []
for q_method in Q_table:
    other_flags = [c for c in all_Q_cols if c != q_method] + all_G_cols + all_U_cols
    is_u_base = (dataset[q_method] == 1) & (dataset[other_flags].sum(axis=1) == 0)
    U_base_results.append(np.round(dataset.loc[is_u_base, result_columns].values[0], 4))

u_base_table = pd.DataFrame(U_base_results, columns=result_columns, index=Q_table)
print(ds_list, 'U-Base')
u_base_table

['IHDP n=747'] U-Base


Unnamed: 0,p,aeate,ate_std
Q-MN,0.0,0.2142,0.4551
Q-SL,0.0,0.1474,0.3162
Q-TVAE,0.0,0.1123,0.416
Q-T,0.0,0.0895,0.4299
DML,0.0,0.0919,0.4408


In [78]:
# get extra results for DML comparison with other U-methods:
# for each dataset, every Q model combined with one fixed G model and update step
outcome_cols = ['p', 'aeate', 'ate_std']
result_columns = ['p', 'aeate', 'ate_std']
ds_list = ['LF (v1) n=500', 'LF (v1) n=5000', 'LF (v1) n=10000',
           'LF (v2) n=500', 'LF (v2) n=5000', 'LF (v2) n=10000',
           'Gen n=500', 'Gen n=5000', 'Gen n=10000', 'IHDP n=747']
# update variants left out of the comparison tables
dropped_u_cols = ['U-multi-nonlin-fqh-var', 'U-multi-lin-var',
                  'U-multi-nonlin-fqh-meanvar', 'U-multi-nonlin-fh-var']

Q_table = ['Q-MN', 'Q-SL', 'Q-TVAE', 'Q-T', 'DML']
G_table = ['G-SL', 'G-MN']                  # choices for ind_g
U_table = ['U-multi-lin-meanvar', 'U-sub']  # choices for ind_u
ind_g = 'G-MN'
ind_u = 'U-sub'
# the last Q entry ('DML') is excluded from this comparison
q_methods = Q_table[:-1]

for ds_index, ds in enumerate(ds_list):
    # one row per method: one-hot Q/G/U flags, dataset index, then the outcomes
    dataset_rows = []
    for method, result in all_results[ds].items():
        var, _, cols, _ = get_vars_method(method)
        outcomes = [result['p'], result['aeate'], result['ate_std']]
        dataset_rows.append(np.concatenate([var, [ds_index], outcomes]))
    cols += ['dataset'] + outcome_cols
    dataset = pd.DataFrame(np.asarray(dataset_rows), columns=cols).drop(columns=dropped_u_cols)

    reses = []
    for q_method in q_methods:
        is_combo = (dataset[q_method] == 1) & (dataset[ind_g] == 1) & (dataset[ind_u] == 1)
        reses.append(np.round(dataset.loc[is_combo, result_columns].values[0], 4))
    combo_table = pd.DataFrame(reses, columns=result_columns, index=q_methods)
    # caption each table; without it the ten tables below cannot be told apart
    print(ds, ind_g, ind_u)
    display(HTML(combo_table.to_html()))

Unnamed: 0,p,aeate,ate_std
Q-MN,0.0045,0.0768,0.1049
Q-SL,0.0202,0.0758,0.1022
Q-TVAE,0.0029,0.0806,0.1067
Q-T,0.1244,0.0734,0.0973


Unnamed: 0,p,aeate,ate_std
Q-MN,0.5692,0.0237,0.0282
Q-SL,0.2467,0.0217,0.0283
Q-TVAE,0.465,0.0227,0.0297
Q-T,0.6863,0.0254,0.0331


Unnamed: 0,p,aeate,ate_std
Q-MN,0.2524,0.0189,0.0203
Q-SL,0.3251,0.0149,0.0182
Q-TVAE,0.1576,0.0154,0.019
Q-T,0.4915,0.0171,0.0203


Unnamed: 0,p,aeate,ate_std
Q-MN,0.0009,0.0724,0.0984
Q-SL,0.0001,0.0708,0.095
Q-TVAE,0.0,0.076,0.1028
Q-T,0.0079,0.0703,0.0981


Unnamed: 0,p,aeate,ate_std
Q-MN,0.6407,0.0262,0.0316
Q-SL,0.5783,0.0244,0.0316
Q-TVAE,0.4977,0.0233,0.0303
Q-T,0.0001,0.0271,0.038


Unnamed: 0,p,aeate,ate_std
Q-MN,0.1258,0.0246,0.0261
Q-SL,0.321,0.0173,0.0207
Q-TVAE,0.2403,0.0165,0.0202
Q-T,0.4665,0.0186,0.0229


Unnamed: 0,p,aeate,ate_std
Q-MN,0.0001,0.0553,0.074
Q-SL,0.0001,0.0548,0.0725
Q-TVAE,0.0,0.0578,0.076
Q-T,0.0044,0.0541,0.0696


Unnamed: 0,p,aeate,ate_std
Q-MN,0.7184,0.0164,0.0227
Q-SL,0.9423,0.0171,0.0235
Q-TVAE,0.9806,0.0183,0.0244
Q-T,0.9,0.0279,0.0253


Unnamed: 0,p,aeate,ate_std
Q-MN,0.0522,0.0131,0.0168
Q-SL,0.0091,0.013,0.0164
Q-TVAE,0.0491,0.0132,0.0172
Q-T,0.04,0.0193,0.0175


Unnamed: 0,p,aeate,ate_std
Q-MN,0.0,0.3005,0.8027
Q-SL,0.0,0.1832,0.5087
Q-TVAE,0.0,0.2085,0.5505
Q-T,0.0,0.1328,0.4796
