In [174]:
import numpy as np
import os
from sklearn.metrics import mean_squared_error as mse_f
from scipy import sparse
from scipy.stats import gamma
from scipy.stats import ttest_ind
import warnings
import pandas as pd

warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


In [175]:
paper_model_names = models = {
    "dsbmm_dpf.z-only": "Ours-no-meta",
    "dsbmm_dpf.z-theta-joint": "Ours",
    # 'spf.main':'MSPF',
    "unadjusted.main": "Unadjusted",
    "network_pref_only.main": "Net.-only",
    "topic_only.main": "Topic-only",
    "no_unobs.main": "Oracle",
    "topic_only_oracle.main": "Topic-oracle",
}


In [195]:
tmp_kv = tuple(paper_model_names.items())
for k, v in tmp_kv:
    if "dsbmm_dpf" in k:
        paper_model_names[k + "-ndc"] = v + "-NDC"

tmp_dict = {
    k + "old_subs" + "_ewcnone" + "_rcolmain_adm1_1hot": v + '-old-A'
    for k, v in paper_model_names.items()
    if "dsbmm_dpf" not in k
}
tmp_dict.update(
   {
    k + "pres_subs" + "_ewcnone" + "_rcolmain_adm1_1hot": v + '-pres-A'
    for k, v in paper_model_names.items()
    if "dsbmm_dpf" not in k
} 
)
sub_choice_pretty = {"old_subs": "-old", "pres_subs": "-pres"}
reg_choice_pretty = {"adm1": "-A", "ctry": "-C"}
for k, v in paper_model_names.items():
    if "dsbmm_dpf" in k:
        for sub_choice in ["old_subs", "pres_subs"]:
            for region in ["adm1", "ctry"]:
                tmp_dict[k + f"{sub_choice}_ewcnone_rcolmain_{region}_1hot"] = (
                    v + sub_choice_pretty[sub_choice] + reg_choice_pretty[region]
                )


In [196]:
full_paper_model_names = tmp_dict

In [197]:
def print_table(exp_results, regimes, models, exps=10, print_notfound=False):
    ncols = len(regimes.keys())
    nrows = len(models)
    results = np.zeros((nrows, ncols))
    std = np.zeros((nrows, ncols))
    alt_results = np.zeros((nrows, ncols))
    alt_std = np.zeros((nrows, ncols))

    col_idx = 0
    for regime, c in regimes.items():
        row_idx = 0
        for model in models:
            mse = np.zeros((exps,4))
            for i in range(exps):
                try:
                    beta_predicted = exp_results[c][model][i][0]
                    truth = exp_results[c][model][i][1]
                    sq_err = (beta_predicted - truth) ** 2
                    mse[i] = sq_err.mean(axis=0)
                except:
                    if print_notfound:
                        print(model, "exp", i, "not found")
            results[row_idx][col_idx] = round(mse.mean() * 1000, 2)
            std[row_idx][col_idx] = round(mse.std() * 1000, 2)
            
            alt_results[row_idx][col_idx] = round(mse[:,:-1].mean() * 1000, 2)
            alt_std[row_idx][col_idx] = round(mse[:,:-1].std() * 1000, 2)
            
            row_idx += 1
        col_idx += 1

    proper_names = [full_paper_model_names[m] for m in models]
    col_names = list(regimes.keys())
    df = pd.DataFrame(results, index=proper_names, columns=col_names, dtype=str)
    std_df = pd.DataFrame(std, index=proper_names, columns=col_names, dtype=str)
    df = df + "$\pm$" + std_df
    
    alt_df = pd.DataFrame(alt_results, index=proper_names, columns=col_names, dtype=str)
    alt_std_df = pd.DataFrame(alt_std, index=proper_names, columns=col_names, dtype=str)
    alt_df = alt_df + "$\pm$" + alt_std_df
    return df, alt_df


### Load results

In [198]:
from pathlib import Path

res_dir = Path("/scratch/fitzgeraldj/data/caus_inf_data/results")
exps = 5
# embed = "user"
sub_choices = ["old_subs","pres_subs"] 
regions = ["adm1","ctry"]  # or "ctry" for dsbmm_dpf models
base_models = [
    "unadjusted.main",
    "network_pref_only.main",
    "topic_only.main",
    "no_unobs.main",
    "topic_only_oracle.main",
    "dsbmm_dpf.z-only",
    "dsbmm_dpf.z-theta-joint",
    "dsbmm_dpf.z-theta-joint-ndc",
]
models = [m + f"{sub_choice}_ewcnone_rcolmain_{region}_1hot" for m in base_models for sub_choice in sub_choices for region in regions]

conf_types = ["homophily", "exog", "both"]
confounding_strengths = [(50, 10), (50, 50), (50, 100)]
exp_results = {}
found = set()

for i in range(1, exps + 1):
    for model in models:
        for (cov1conf, cov2conf) in confounding_strengths:
            for ct in conf_types:
                try:
                    base_file_name = (
                        "conf="
                        + str((cov1conf, cov2conf))
                        + ";conf_type="
                        + ct
                        + ".npz"
                    )
                    result_file = (
                        (res_dir / str(i)) / (model + "_model_fitted_params")
                    ) / base_file_name
                    res = np.load(result_file)
                    params = res["fitted"]
                    truth = res["true"]

                    if (ct, (cov1conf, cov2conf)) in exp_results:
                        if model in exp_results[(ct, (cov1conf, cov2conf))]:
                            exp_results[(ct, (cov1conf, cov2conf))][model].append(
                                (params, truth)
                            )
                        else:
                            exp_results[(ct, (cov1conf, cov2conf))][model] = [
                                (params, truth)
                            ]
                    else:
                        exp_results[(ct, (cov1conf, cov2conf))] = {
                            model: [(params, truth)]
                        }
                    if model not in found:
                        print(model, "found")
                        found |= set([model])
                except:
                    # print(result_file, " not found")
                    continue


dsbmm_dpf.z-theta-jointold_subs_ewcnone_rcolmain_adm1_1hot found
dsbmm_dpf.z-theta-jointold_subs_ewcnone_rcolmain_ctry_1hot found
dsbmm_dpf.z-theta-jointpres_subs_ewcnone_rcolmain_adm1_1hot found
dsbmm_dpf.z-theta-jointpres_subs_ewcnone_rcolmain_ctry_1hot found
dsbmm_dpf.z-onlypres_subs_ewcnone_rcolmain_adm1_1hot found
dsbmm_dpf.z-onlyold_subs_ewcnone_rcolmain_adm1_1hot found
unadjusted.mainold_subs_ewcnone_rcolmain_adm1_1hot found
unadjusted.mainpres_subs_ewcnone_rcolmain_adm1_1hot found
topic_only.mainpres_subs_ewcnone_rcolmain_adm1_1hot found
no_unobs.mainold_subs_ewcnone_rcolmain_adm1_1hot found
no_unobs.mainpres_subs_ewcnone_rcolmain_adm1_1hot found
topic_only_oracle.mainold_subs_ewcnone_rcolmain_adm1_1hot found
topic_only_oracle.mainpres_subs_ewcnone_rcolmain_adm1_1hot found


In [199]:
# dsbmm_dpf.z-onlyupd_subs_ewcnone_rcolmain_adm1_1hot_model_fitted_params
found

{'dsbmm_dpf.z-onlyold_subs_ewcnone_rcolmain_adm1_1hot',
 'dsbmm_dpf.z-onlypres_subs_ewcnone_rcolmain_adm1_1hot',
 'dsbmm_dpf.z-theta-jointold_subs_ewcnone_rcolmain_adm1_1hot',
 'dsbmm_dpf.z-theta-jointold_subs_ewcnone_rcolmain_ctry_1hot',
 'dsbmm_dpf.z-theta-jointpres_subs_ewcnone_rcolmain_adm1_1hot',
 'dsbmm_dpf.z-theta-jointpres_subs_ewcnone_rcolmain_ctry_1hot',
 'no_unobs.mainold_subs_ewcnone_rcolmain_adm1_1hot',
 'no_unobs.mainpres_subs_ewcnone_rcolmain_adm1_1hot',
 'topic_only.mainpres_subs_ewcnone_rcolmain_adm1_1hot',
 'topic_only_oracle.mainold_subs_ewcnone_rcolmain_adm1_1hot',
 'topic_only_oracle.mainpres_subs_ewcnone_rcolmain_adm1_1hot',
 'unadjusted.mainold_subs_ewcnone_rcolmain_adm1_1hot',
 'unadjusted.mainpres_subs_ewcnone_rcolmain_adm1_1hot'}

### Visualize results

In [200]:
confounding_type = "exog"
models = list(exp_results[(confounding_type, confounding_strengths[1])].keys())
regime1 = {
    "Low": (confounding_type, confounding_strengths[0]),
    "Med.": (confounding_type, confounding_strengths[1]),
    "High": (confounding_type, confounding_strengths[2]),
}

df1,alt_df1 = print_table(exp_results, regime1, models)


In [201]:
confounding_type = "homophily"
models = list(exp_results[(confounding_type, confounding_strengths[0])].keys())
regime1 = {
    "Low": (confounding_type, confounding_strengths[0]),
    "Med.": (confounding_type, confounding_strengths[1]),
    "High": (confounding_type, confounding_strengths[2]),
}

df2,alt_df2 = print_table(exp_results, regime1, models)


In [202]:
confounding_type = "both"
models = list(exp_results[(confounding_type, confounding_strengths[0])].keys())
regime1 = {
    "Low": (confounding_type, confounding_strengths[0]),
    "Med.": (confounding_type, confounding_strengths[1]),
    "High": (confounding_type, confounding_strengths[2]),
}

df3,alt_df3 = print_table(exp_results, regime1, models)


In [203]:
all_results = pd.concat([df1, df2, df3], axis=1, keys=["Exog.", "Homophily", "Both"])


In [204]:
all_results[all_results == all_results.min(axis=0)] = '\textbf{' + all_results[all_results == all_results.min(axis=0)] + '}'

In [205]:
all_results

Unnamed: 0_level_0,Exog.,Exog.,Exog.,Homophily,Homophily,Homophily,Both,Both,Both
Unnamed: 0_level_1,Low,Med.,High,Low,Med.,High,Low,Med.,High
Ours-old-A,1.32$\pm$2.83,2.25$\pm$5.51,3.67$\pm$12.45,384.56$\pm$1390.78,200.35$\pm$664.02,104.18$\pm$332.85,480.87$\pm$1790.08,222.02$\pm$698.54,119.93$\pm$382.8
Ours-old-C,1.32$\pm$2.84,2.26$\pm$5.53,3.72$\pm$12.76,335.59$\pm$1137.56,200.96$\pm$666.84,104.4$\pm$333.06,484.47$\pm$1803.43,223.66$\pm$704.5,120.3$\pm$383.22
Ours-pres-A,1.41$\pm$2.85,2.26$\pm$5.53,3.68$\pm$12.45,171.47$\pm$660.47,177.89$\pm$566.97,104.28$\pm$332.83,480.98$\pm$1790.05,222.24$\pm$698.48,120.15$\pm$382.74
Ours-pres-C,1.39$\pm$2.84,2.25$\pm$5.49,3.73$\pm$12.76,632.92$\pm$2367.74,248.41$\pm$795.89,104.71$\pm$333.0,484.51$\pm$1803.42,223.86$\pm$704.45,120.59$\pm$383.16
Unadjusted-old-A,0.17$\pm$0.76,0.14$\pm$0.46,0.13$\pm$0.42,\textbf{110.31$\pm$682.04},109.78$\pm$678.65,79.5$\pm$490.85,193.31$\pm$1194.58,130.43$\pm$806.33,94.02$\pm$580.34
Unadjusted-pres-A,0.13$\pm$0.51,0.14$\pm$0.48,0.13$\pm$0.42,126.8$\pm$784.27,111.52$\pm$689.52,79.15$\pm$488.64,193.31$\pm$1194.58,130.43$\pm$806.33,94.02$\pm$580.34
Topic-only-pres-A,0.09$\pm$0.41,0.21$\pm$1.04,0.08$\pm$0.26,92.43$\pm$572.09,64.35$\pm$398.04,51.43$\pm$319.08,114.97$\pm$712.35,73.93$\pm$456.55,67.96$\pm$421.3
Oracle-old-A,0.18$\pm$0.97,0.09$\pm$0.42,0.28$\pm$1.62,170.47$\pm$1054.85,116.88$\pm$723.92,88.36$\pm$548.8,241.21$\pm$1493.81,139.52$\pm$865.2,103.21$\pm$640.38
Oracle-pres-A,0.18$\pm$0.97,0.09$\pm$0.42,0.28$\pm$1.54,217.62$\pm$1343.7,116.57$\pm$721.69,87.58$\pm$543.42,246.38$\pm$1525.83,138.69$\pm$859.86,101.6$\pm$630.01
Topic-oracle-old-A,0.04$\pm$0.19,0.04$\pm$0.21,0.12$\pm$0.67,162.0$\pm$1002.77,122.91$\pm$761.32,89.99$\pm$557.69,218.24$\pm$1350.98,173.62$\pm$1075.91,99.77$\pm$617.87


In [206]:
all_results.to_latex('./results/semi-synth.tex',escape=False)

In [207]:
alt_all_results = pd.concat([alt_df1, alt_df2, alt_df3], axis=1, keys=["Exog.", "Homophily", "Both"])
alt_all_results


Unnamed: 0_level_0,Exog.,Exog.,Exog.,Homophily,Homophily,Homophily,Both,Both,Both
Unnamed: 0_level_1,Low,Med.,High,Low,Med.,High,Low,Med.,High
Ours-old-A,0.8$\pm$2.22,1.23$\pm$3.68,1.15$\pm$2.39,5.86$\pm$11.87,4.45$\pm$8.62,4.3$\pm$8.28,6.22$\pm$12.31,4.81$\pm$9.17,4.94$\pm$9.35
Ours-old-C,0.81$\pm$2.24,1.24$\pm$3.72,1.15$\pm$2.41,6.34$\pm$12.21,4.59$\pm$8.64,4.43$\pm$8.32,6.37$\pm$12.28,4.94$\pm$9.16,5.09$\pm$9.39
Ours-pres-A,0.92$\pm$2.28,1.25$\pm$3.72,1.16$\pm$2.42,8.51$\pm$21.57,4.92$\pm$8.89,4.43$\pm$8.67,6.37$\pm$12.32,5.11$\pm$9.46,5.24$\pm$9.75
Ours-pres-C,0.89$\pm$2.27,1.23$\pm$3.65,1.16$\pm$2.4,7.71$\pm$17.74,5.29$\pm$11.0,4.83$\pm$9.72,6.42$\pm$12.43,5.19$\pm$9.73,5.48$\pm$10.66
Unadjusted-old-A,0.07$\pm$0.22,0.11$\pm$0.34,0.1$\pm$0.33,1.43$\pm$5.4,1.44$\pm$5.46,1.18$\pm$4.13,2.64$\pm$11.58,1.71$\pm$6.77,1.42$\pm$5.3
Unadjusted-pres-A,0.07$\pm$0.22,0.11$\pm$0.34,0.1$\pm$0.33,1.58$\pm$6.23,1.44$\pm$5.46,1.18$\pm$4.13,2.64$\pm$11.58,1.71$\pm$6.77,1.42$\pm$5.3
Topic-only-pres-A,0.04$\pm$0.12,0.07$\pm$0.21,0.06$\pm$0.2,1.07$\pm$5.53,0.8$\pm$3.69,0.44$\pm$1.58,1.17$\pm$6.08,1.08$\pm$5.18,0.65$\pm$2.7
Oracle-old-A,0.03$\pm$0.1,0.03$\pm$0.11,0.03$\pm$0.12,2.03$\pm$10.8,1.26$\pm$6.68,0.63$\pm$3.35,2.62$\pm$13.92,1.27$\pm$6.75,0.87$\pm$4.6
Oracle-pres-A,0.03$\pm$0.12,0.04$\pm$0.13,0.04$\pm$0.14,3.21$\pm$16.96,1.31$\pm$6.59,0.74$\pm$3.31,2.68$\pm$14.09,1.31$\pm$6.6,0.94$\pm$4.41
Topic-oracle-old-A,0.01$\pm$0.03,0.01$\pm$0.04,0.01$\pm$0.04,1.86$\pm$9.26,1.31$\pm$6.08,0.9$\pm$3.69,2.49$\pm$12.63,1.75$\pm$8.41,1.08$\pm$4.68


In [208]:
alt_all_results[alt_all_results == alt_all_results.min(axis=0)] = '\textbf{' + alt_all_results[alt_all_results == alt_all_results.min(axis=0)] + '}'

In [209]:
alt_all_results

Unnamed: 0_level_0,Exog.,Exog.,Exog.,Homophily,Homophily,Homophily,Both,Both,Both
Unnamed: 0_level_1,Low,Med.,High,Low,Med.,High,Low,Med.,High
Ours-old-A,0.8$\pm$2.22,1.23$\pm$3.68,1.15$\pm$2.39,5.86$\pm$11.87,4.45$\pm$8.62,4.3$\pm$8.28,6.22$\pm$12.31,4.81$\pm$9.17,4.94$\pm$9.35
Ours-old-C,0.81$\pm$2.24,1.24$\pm$3.72,1.15$\pm$2.41,6.34$\pm$12.21,4.59$\pm$8.64,4.43$\pm$8.32,6.37$\pm$12.28,4.94$\pm$9.16,5.09$\pm$9.39
Ours-pres-A,0.92$\pm$2.28,1.25$\pm$3.72,1.16$\pm$2.42,8.51$\pm$21.57,4.92$\pm$8.89,4.43$\pm$8.67,6.37$\pm$12.32,5.11$\pm$9.46,5.24$\pm$9.75
Ours-pres-C,0.89$\pm$2.27,1.23$\pm$3.65,1.16$\pm$2.4,7.71$\pm$17.74,5.29$\pm$11.0,4.83$\pm$9.72,6.42$\pm$12.43,5.19$\pm$9.73,5.48$\pm$10.66
Unadjusted-old-A,0.07$\pm$0.22,0.11$\pm$0.34,0.1$\pm$0.33,1.43$\pm$5.4,1.44$\pm$5.46,1.18$\pm$4.13,2.64$\pm$11.58,1.71$\pm$6.77,1.42$\pm$5.3
Unadjusted-pres-A,0.07$\pm$0.22,0.11$\pm$0.34,0.1$\pm$0.33,1.58$\pm$6.23,1.44$\pm$5.46,1.18$\pm$4.13,2.64$\pm$11.58,1.71$\pm$6.77,1.42$\pm$5.3
Topic-only-pres-A,0.04$\pm$0.12,0.07$\pm$0.21,0.06$\pm$0.2,\textbf{1.07$\pm$5.53},0.8$\pm$3.69,0.44$\pm$1.58,1.17$\pm$6.08,1.08$\pm$5.18,0.65$\pm$2.7
Oracle-old-A,0.03$\pm$0.1,0.03$\pm$0.11,0.03$\pm$0.12,2.03$\pm$10.8,1.26$\pm$6.68,0.63$\pm$3.35,2.62$\pm$13.92,1.27$\pm$6.75,0.87$\pm$4.6
Oracle-pres-A,0.03$\pm$0.12,0.04$\pm$0.13,0.04$\pm$0.14,3.21$\pm$16.96,1.31$\pm$6.59,0.74$\pm$3.31,2.68$\pm$14.09,1.31$\pm$6.6,0.94$\pm$4.41
Topic-oracle-old-A,0.01$\pm$0.03,0.01$\pm$0.04,0.01$\pm$0.04,1.86$\pm$9.26,1.31$\pm$6.08,0.9$\pm$3.69,2.49$\pm$12.63,1.75$\pm$8.41,1.08$\pm$4.68


In [210]:
alt_all_results.to_latex('./results/alt-semi-synth.tex',escape=False)

In [192]:
import pickle
# load up PPC results
with open(res_dir / "dsbmm_ppc_results.pkl", "rb") as f:
    dsbmm_ppc_results = pickle.load(f)
with open(res_dir / "dpf_ppc_results.pkl", "rb") as f:
    dpf_ppc_results = pickle.load(f)
with open(res_dir / "dpf_auc_results.pkl", "rb") as f:
    dpf_auc_results = pickle.load(f)

In [194]:
dpf_ppc_results

array([[1., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]])

In [109]:
tmp = np.random.randint(6,8,size=(3,2))/10
ppc_df = pd.DataFrame(tmp,columns=['$A$','$Y$'])
ppc_df['$Q$'] = [4,9,16]
ppc_df['$K$'] = [3,5,8]
ppc_df[['$K$','$Y$']].to_latex('./results/ex-topic-synth-ppcs.tex',escape=False)
ppc_df[['$Q$','$A$']].to_latex('./results/ex-auth-synth-ppcs.tex',escape=False)