In [1]:
import proplot as pplt
import matplotlib.pyplot as plt
import os
import yaml
from tqdm import tqdm
import pandas as pd
import numpy as np

In [2]:
# Collect one flat dict per result file; `records` is turned into a DataFrame
# by a later cell and `l2vals` is reused when summarising.
records = []
l2vals = [0.01]
for l2 in l2vals:
    dname = f"results_lr-0.001-l2-{l2}"
    # os.listdir returns every directory entry in arbitrary order. Keep only
    # the YAML result files so stray entries (checkpoint folders, OS metadata
    # files, ...) cannot break the filename parsing below, and sort them so the
    # record order is reproducible across machines.
    files = sorted(name for name in os.listdir(dname) if name.endswith('.yaml'))
    for f in tqdm(files):
        with open(os.path.join(dname, f), 'r') as fh:
            r = yaml.load(fh, yaml.SafeLoader)
        # The file stem is split on '-' into exactly two tokens (scenario, seed);
        # the value of each token is the piece after its first '_'.
        stem = os.path.splitext(f)[0]
        scenario, seed = stem.split('-')
        r['l2_reg'] = l2
        s = scenario.split('_')[1]
        # Scenario id '0' is reported under the label 'ihdp'.
        r['scenario'] = 'ihdp' if s == '0' else s
        # The seed is kept as a string; it is only used for counting repetitions.
        r['seed'] = seed.split('_')[1]
        records.append(r)

100%|██████████| 1400/1400 [00:16<00:00, 86.15it/s]


In [3]:
# Turn the loaded records into a DataFrame, harmonise the metric names and add
# two derived columns in a single chain.
records = (
    pd.DataFrame(records)
    .rename(
        columns={
            'ate_error': 'bias',
            'mse_loss': 'mse_outcome',
            'ate_estimate': 'ate_pred_mean',
        }
    )
    .assign(
        # Estimate minus its error -- presumably the reference ATE, assuming
        # ate_error is (estimate - truth); TODO confirm against the producer.
        mean_ate=lambda frame: frame["ate_pred_mean"] - frame["bias"],
        # Squared per-run error; its group mean is the MSE of the ATE estimate.
        mse_ate=lambda frame: np.square(frame['bias']),
    )
)

In [4]:
# Average every metric over the runs of each (l2_reg, scenario) group.
group_keys = ["l2_reg", "scenario"]
results = (
    records
    .drop(columns="seed")
    .groupby(group_keys)
    .mean()
    .reset_index()
)
# Number of runs per group, counted as the number of seed entries.
n = (
    records
    .groupby(group_keys)
    .agg({'seed': len})
    .reset_index()
    .rename(columns={'seed': 'nreps'})
)
results = results.merge(n)
# With a single regularisation value the l2_reg column is constant, so drop it.
if len(l2vals) == 1:
    results = results.drop(columns='l2_reg')
results

Unnamed: 0,scenario,bias,ate_pred_mean,mse_outcome,tmle_loss,train_loss,treatment_loss,mean_ate,mse_ate,nreps
0,1,0.027015,-2.535983,0.002226,0.000721,0.045825,0.044378,-2.562998,0.366198,200
1,2,0.059297,-1.336515,0.002336,0.000762,0.046504,0.044976,-1.395812,0.360805,200
2,3,0.650255,-0.745558,0.002111,0.00068,0.030605,0.029245,-1.395812,0.648532,200
3,4,-0.494283,-3.063505,0.0034,0.000911,0.035449,0.033625,-2.569221,0.773561,200
4,5,0.236625,-2.326373,0.003435,0.000916,0.046525,0.04469,-2.562998,0.433478,200
5,6,0.239676,-2.323322,0.003538,0.00137,0.046925,0.044182,-2.562998,0.133729,200
6,ihdp,0.052922,4.646346,0.291041,0.097868,0.630568,0.434822,4.593424,0.014854,200


In [5]:
# Headline metrics per scenario: mean error (bias) and mean squared error of the ATE.
results.loc[:, ['scenario', 'bias', 'mse_ate']]

Unnamed: 0,scenario,bias,mse_ate
0,1,0.027015,0.366198
1,2,0.059297,0.360805
2,3,0.650255,0.648532
3,4,-0.494283,0.773561
4,5,0.236625,0.433478
5,6,0.239676,0.133729
6,ihdp,0.052922,0.014854


In [6]:
# Write the aggregated table to CSV without the positional row index.
output_csv = "benchmark-results-dragonnet.csv"
results.to_csv(output_csv, index=False)

In [7]:
# compute standard deviations as a safety check
# Per-group standard deviation of every metric across runs.
sd_keys = ["l2_reg", "scenario"]
results_sd = records.drop(columns="seed").groupby(sd_keys).std().reset_index()
if len(l2vals) == 1:
    results_sd = results_sd.drop(columns='l2_reg')
# Suffix only the value columns. The grouping keys are identifiers, not
# standard deviations, so they keep their original names (previously
# `scenario` was relabelled `scenario_sd`).
results_sd = results_sd.rename(columns=lambda col: col if col in sd_keys else col + "_sd")
results_sd

Unnamed: 0,scenario_sd,bias_sd,ate_pred_mean_sd,mse_outcome_sd,tmle_loss_sd,train_loss_sd,treatment_loss_sd,mean_ate_sd,mse_ate_sd
0,1,0.606056,0.615921,0.000488,0.000139,0.009662,0.009553,0.05448,0.51206
1,2,0.599237,0.599603,0.000556,0.00016,0.009983,0.00987,0.029324,0.513185
2,3,0.476272,0.476827,0.000491,0.000141,0.004705,0.004683,0.029324,0.669091
3,4,0.729318,0.730514,0.00047,0.000112,0.003989,0.003931,0.040726,0.995792
4,5,0.615941,0.623806,0.000693,0.000163,0.009791,0.009721,0.05448,0.590077
5,6,0.27689,0.287614,0.000626,0.000214,0.009691,0.00967,0.05448,0.161296
6,ihdp,0.110065,1.829801,0.104784,0.070925,0.142464,0.002097,1.793613,0.026966


## Hyads results: per-year mean and standard deviation across result files

In [22]:
# Load every YAML result file from the application run directory and summarise
# the numeric fields per year (mean and standard deviation across files).
results_dir = "results_app"
# os.listdir returns all entries in arbitrary order: keep only YAML files so a
# stray entry cannot break the three-way filename split, and sort for a
# reproducible row order.
files = sorted(name for name in os.listdir(results_dir) if name.endswith(".yaml"))
res = []
for f in files:
    with open(os.path.join(results_dir, f), "r") as fh:
        r = yaml.load(fh, yaml.SafeLoader)
    # The file stem is split on '-' into exactly three tokens; only the second
    # (year) is used, and its value is the piece after the first '_'.
    # The third token (seed) is deliberately not stored: as a numeric column it
    # would be averaged together with the metrics below.
    _, year, _ = os.path.splitext(f)[0].split("-")
    r['year'] = int(year.split("_")[1])
    res.append(r)
res = pd.DataFrame(res)
# Per-year summaries; the suffix records which statistic each column holds.
re_mean = res.groupby("year").mean()
re_mean.columns = [x + "_mean" for x in re_mean.columns]
re_std = res.groupby("year").std()
re_std.columns = [x + "_sd" for x in re_std.columns]

In [23]:
# Display the per-year means (columns carry a `_mean` suffix, index is `year`).
re_mean

Unnamed: 0_level_0,ate_estimate_mean,mse_loss_mean,tmle_loss_mean,train_loss_mean,treatment_loss_mean
year,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
2013,-0.302554,0.112162,0.077384,0.206517,0.05174
2014,-0.28432,0.122477,0.071942,0.188544,0.044658


In [24]:
# Display the per-year standard deviations (columns carry a `_sd` suffix, index is `year`).
re_std

Unnamed: 0_level_0,ate_estimate_sd,mse_loss_sd,tmle_loss_sd,train_loss_sd,treatment_loss_sd
year,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
2013,0.074031,0.001187,0.00082,0.001728,0.000452
2014,0.05607,0.003375,0.001983,0.004221,0.000471
