In [1]:
import proplot as pplt
import matplotlib.pyplot as plt
import os
import yaml
from tqdm import tqdm
import pandas as pd
import numpy as np

In [2]:
# Gather one result dict per YAML file, for every L2 setting in `l2vals`.
# File names are split on '-' into exactly two parts (scenario tag, seed tag),
# and each tag is split on '_' with the second piece kept as the identifier.
l2vals = [0.01]
records = []
for l2 in l2vals:
    dname = f"results_lr-0.001-l2-{l2}"
    for fname in tqdm(os.listdir(dname)):
        with open(f"{dname}/{fname}", 'r') as handle:
            rec = yaml.load(handle, yaml.SafeLoader)
        scenario_tag, seed_tag = fname.replace('.yaml', '').split('-')
        scenario_id = scenario_tag.split('_')[1]
        rec['l2_reg'] = l2
        # Scenario id '0' is reported under the label 'ihdp'.
        rec['scenario'] = scenario_id if scenario_id != '0' else 'ihdp'
        # The seed is kept as a string; it is only used for counting replications.
        rec['seed'] = seed_tag.split('_')[1]
        records.append(rec)

100%|██████████| 1400/1400 [00:00<00:00, 1695.04it/s]


In [3]:
# Turn the list of per-run dicts into a table and harmonise metric names.
records = pd.DataFrame(records).rename(
    columns={
        'ate_error': 'bias',
        'mse_loss': 'mse_outcome',
        'ate_estimate': 'ate_pred_mean',
    }
)
# Estimate minus bias; presumably the reference ATE if bias = estimate - truth (confirm).
records["mean_ate"] = records["ate_pred_mean"].sub(records["bias"])
# Squared per-run bias; its group mean is reported as the MSE of the ATE.
records['mse_ate'] = np.square(records['bias'])

In [4]:
# Average every metric over seeds within each (l2_reg, scenario) cell and
# attach the number of replications that went into each average.
group_keys = ["l2_reg", "scenario"]
results = records.drop(columns="seed").groupby(group_keys).mean().reset_index()
n = records.groupby(group_keys).size().reset_index(name="nreps")
results = results.merge(n, on=group_keys)
# With a single L2 setting the l2_reg column carries no information.
if len(l2vals) == 1:
    results = results.drop(columns='l2_reg')
results

Unnamed: 0,scenario,bias,ate_pred_mean,mse_outcome,train_loss,treatment_loss,treg_loss,mean_ate,mse_ate,nreps
0,1,-0.363083,-2.926081,0.002331,0.044638,0.043118,0.00076,-2.562998,0.493034,200
1,2,-0.378053,-1.773866,0.002466,0.045001,0.043384,0.000809,-1.395812,0.492223,200
2,3,0.259024,-1.136789,0.002166,0.030201,0.028805,0.000698,-1.395812,0.289288,200
3,4,-0.995293,-3.564514,0.00351,0.034849,0.032964,0.000943,-2.569221,1.538452,200
4,5,-0.195894,-2.758891,0.003561,0.045334,0.043428,0.000953,-2.562998,0.413015,200
5,6,0.060187,-2.502811,0.003573,0.046387,0.043615,0.001386,-2.562998,0.084361,200
6,ihdp,0.073367,4.666792,0.29097,0.63082,0.434956,0.097922,4.593424,0.021491,200


In [5]:
# Headline metrics only: scenario label, mean bias, and MSE of the ATE estimate.
results.loc[:, ['scenario', 'bias', 'mse_ate']]

Unnamed: 0,scenario,bias,mse_ate
0,1,-0.363083,0.493034
1,2,-0.378053,0.492223
2,3,0.259024,0.289288
3,4,-0.995293,1.538452
4,5,-0.195894,0.413015
5,6,0.060187,0.084361
6,ihdp,0.073367,0.021491


In [6]:
# Write the per-scenario summary table to disk without the row index.
out_csv = "benchmark-results-dragonnet.csv"
results.to_csv(out_csv, index=False)

In [7]:
# Compute standard deviations across seeds as a safety check on the means.
sd_keys = ["l2_reg", "scenario"]
results_sd = records.drop("seed", axis=1).groupby(sd_keys).std().reset_index()
if len(l2vals) == 1:
    results_sd = results_sd.drop('l2_reg', axis=1)
# Suffix only the statistic columns: the grouping keys are identifiers, not
# standard deviations, so they keep their original names (previously the key
# column was mislabeled as "scenario_sd").
results_sd = results_sd.rename(columns=lambda c: c if c in sd_keys else c + "_sd")
results_sd

Unnamed: 0,scenario_sd,bias_sd,ate_pred_mean_sd,mse_outcome_sd,train_loss_sd,treatment_loss_sd,treg_loss_sd,mean_ate_sd,mse_ate_sd
0,1,0.602511,0.612043,0.00049,0.008951,0.008841,0.000141,0.05448,0.701787
1,2,0.592498,0.591839,0.00055,0.009115,0.009003,0.000159,0.029324,0.703375
2,3,0.472558,0.472377,0.000502,0.004535,0.004513,0.000144,0.029324,0.379781
3,4,0.742023,0.743646,0.000492,0.003643,0.003593,0.000118,0.040726,1.598995
4,5,0.613615,0.621306,0.00069,0.009084,0.009015,0.000163,0.05448,0.586691
5,6,0.284858,0.29468,0.000633,0.009338,0.009313,0.000217,0.05448,0.11108
6,ihdp,0.127239,1.82395,0.105208,0.14276,0.002086,0.071027,1.793613,0.040414


## Results of Hyads

In [8]:
# Load the application results: one YAML file per run, tagged with its year.
files = os.listdir("./results_app")
res = []
for fname in files:
    with open(f"results_app/{fname}", "r") as handle:
        rec = yaml.load(handle, yaml.SafeLoader)
    # File stems have exactly three '-'-separated parts; only the middle
    # (year) tag is stored. The seed tag is parsed but not recorded.
    _prefix, year_tag, _seed_tag = fname.replace(".yaml", "").split("-")
    rec['year'] = int(year_tag.split("_")[1])
    res.append(rec)
res = pd.DataFrame(res)

# Per-year mean and standard deviation of every metric, with suffixed names.
by_year = res.groupby("year")
re_mean = by_year.mean().add_suffix("_mean")
# NOTE(review): std() is the sample standard deviation, so a year with a single
# result file yields NaN -- presumably why re_std displays as all-NaN; confirm.
re_std = by_year.std().add_suffix("_sd")

In [9]:
# Display the per-year means of each metric (columns carry a "_mean" suffix).
re_mean

Unnamed: 0_level_0,ate_estimate_mean,mse_loss_mean,train_loss_mean,treatment_loss_mean,treg_loss_mean
year,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
2013,-0.180061,0.111502,0.20486,0.050992,0.076932
2014,-0.36448,0.120733,0.185924,0.044087,0.070918


In [10]:
# Display the per-year standard deviations of each metric ("_sd" suffix).
re_std

Unnamed: 0_level_0,ate_estimate_sd,mse_loss_sd,train_loss_sd,treatment_loss_sd,treg_loss_sd
year,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
2013,,,,,
2014,,,,,
