In [282]:
import numpy as np
import os
from sklearn.metrics import mean_squared_error as mse_f
from scipy import sparse
from scipy.stats import gamma
from scipy.stats import ttest_ind
import warnings
import pandas as pd

# Silence DeprecationWarning / FutureWarning output from the numerical libraries.
# NOTE(review): this may also hide warnings about invalid escape sequences in
# non-raw string literals (e.g. "$\pm$" used when building the LaTeX tables) —
# prefer raw strings so those problems stay visible.
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


In [283]:
# Display names used in the paper for each base model identifier. The next cell
# extends this dict in place with "-ndc" variants and derives per-data-variant names.
# Deliberately NOT also bound to `models`: that would alias this mutable dict under a
# second name, and `models` is reused later for a list of full model identifiers.
paper_model_names = {
    "dsbmm_dpf.z-only": "Ours-no-meta",
    "dsbmm_dpf.z-theta-joint": "Ours",
    # 'spf.main':'MSPF',
    "unadjusted.main": "Unadjusted",
    "network_pref_only.main": "Net.-only",
    "topic_only.main": "Topic-only",
    "no_unobs.main": "Oracle",
    "topic_only_oracle.main": "Topic-oracle",
}


In [284]:
# Register a "-NDC" display name for every DSBMM-DPF model. A snapshot of the
# items is taken first because the dict itself is extended.
tmp_kv = tuple(paper_model_names.items())
paper_model_names.update(
    {k + "-ndc": v + "-NDC" for k, v in tmp_kv if "dsbmm_dpf" in k}
)

sub_choice_pretty = {"old_subs": "-old", "pres_subs": "-pres"}
reg_choice_pretty = {"adm1": "-A", "ctry": "-C"}

# Baseline models were only run with adm1 regions: one entry per subject choice.
tmp_dict = {}
for sub_choice in ("old_subs", "pres_subs"):
    for k, v in paper_model_names.items():
        if "dsbmm_dpf" in k:
            continue
        tmp_dict[k + f"{sub_choice}_ewcnone_rcolmain_adm1_1hot"] = (
            v + sub_choice_pretty[sub_choice] + "-A"
        )

# DSBMM-DPF models: every (subject choice, region) combination.
for k, v in paper_model_names.items():
    if "dsbmm_dpf" not in k:
        continue
    for sub_choice in ("old_subs", "pres_subs"):
        for region in ("adm1", "ctry"):
            suffix = f"{sub_choice}_ewcnone_rcolmain_{region}_1hot"
            tmp_dict[k + suffix] = (
                v + sub_choice_pretty[sub_choice] + reg_choice_pretty[region]
            )


In [285]:
# Lookup from full model identifier (base model + data-variant suffix) to the
# display name used for table rows; print_table reads this to label its output.
full_paper_model_names = tmp_dict


In [287]:
def _nan_safe_argmin(arr):
    """Column-wise argmin that never selects a NaN entry (unless a column is all NaN)."""
    return np.argmin(np.where(np.isnan(arr), np.inf, arr), axis=0)


def _pm_table(means, stds, index, columns):
    """Combine mean and std arrays into a DataFrame of "mean$\\pm$std" strings."""
    mean_df = pd.DataFrame(means, index=index, columns=columns, dtype=str)
    std_df = pd.DataFrame(stds, index=index, columns=columns, dtype=str)
    return mean_df + r"$\pm$" + std_df


def print_table(exp_results, regimes, models, exps=10, print_notfound=False, bold_min = True, model_names=None):
    r"""Build LaTeX-ready tables of squared-error statistics (x1000) per model and regime.

    Parameters
    ----------
    exp_results : dict
        Maps each regime key (the values of ``regimes``) to a dict of
        model id -> list of ``(fitted, truth)`` pairs, one per loaded experiment.
    regimes : dict
        Maps column label -> key into ``exp_results``.
    models : sequence of str
        Model ids, one per table row, in display order.
    exps : int
        Maximum number of experiments read per (model, regime); missing ones are skipped.
    print_notfound : bool
        If True, report each missing (model, experiment) pair.
    bold_min : bool
        If True, wrap the smallest mean in each column in ``\textbf{...}``.
    model_names : dict, optional
        Model id -> display name. Defaults to the notebook-level
        ``full_paper_model_names``.

    Returns
    -------
    df, alt_df : pandas.DataFrame
        "mean$\pm$std" strings computed over all covariates (``df``) and over all
        but the last covariate (``alt_df``).
    min_inds, mode_inds, alt_min_inds, alt_mode_inds : numpy.ndarray
        Per-column row positions of the smallest mean, and of the smallest
        mean + std, for ``df`` and ``alt_df`` respectively. Rows with no results
        (NaN) are never chosen unless the whole column is NaN.
    """
    if model_names is None:
        model_names = full_paper_model_names

    n_rows, n_cols = len(models), len(regimes)
    # NaN marks a (model, regime) cell with no loaded experiments.
    results = np.full((n_rows, n_cols), np.nan)
    std = np.full((n_rows, n_cols), np.nan)
    alt_results = np.full((n_rows, n_cols), np.nan)
    alt_std = np.full((n_rows, n_cols), np.nan)

    for col_idx, regime_key in enumerate(regimes.values()):
        for row_idx, model in enumerate(models):
            per_exp_mse = []
            for i in range(exps):
                try:
                    beta_predicted = exp_results[regime_key][model][i][0]
                    truth = exp_results[regime_key][model][i][1]
                except (KeyError, IndexError):
                    # Regime/model never loaded, or fewer than `exps` experiments.
                    if print_notfound:
                        print(model, "exp", i, "not found")
                    continue
                # Per-covariate MSE, averaged over the leading axis.
                per_exp_mse.append(
                    np.atleast_1d(((beta_predicted - truth) ** 2).mean(axis=0))
                )
            if not per_exp_mse:
                continue
            mse = np.vstack(per_exp_mse)  # (n_found_exps, n_covariates)
            results[row_idx, col_idx] = round(mse.mean() * 1000, 2)
            std[row_idx, col_idx] = round(mse.std() * 1000, 2)
            # "alt" statistics drop the last covariate.
            alt_results[row_idx, col_idx] = round(mse[:, :-1].mean() * 1000, 2)
            alt_std[row_idx, col_idx] = round(mse[:, :-1].std() * 1000, 2)

    proper_names = [model_names[m] for m in models]
    col_names = list(regimes.keys())

    min_inds = _nan_safe_argmin(results)
    mode_inds = _nan_safe_argmin(results + std)
    alt_min_inds = _nan_safe_argmin(alt_results)
    alt_mode_inds = _nan_safe_argmin(alt_results + alt_std)

    df = _pm_table(results, std, proper_names, col_names)
    alt_df = _pm_table(alt_results, alt_std, proper_names, col_names)

    if bold_min:
        for table, idxs in ((df, min_inds), (alt_df, alt_min_inds)):
            for col_pos, row_pos in enumerate(idxs):
                # Raw string: in "\textbf{" the "\t" would be a TAB character.
                # Positional .iloc[row, col] assignment avoids chained assignment,
                # which silently does nothing under pandas Copy-on-Write.
                table.iloc[row_pos, col_pos] = (
                    r"\textbf{" + table.iloc[row_pos, col_pos] + "}"
                )

    return df, alt_df, min_inds, mode_inds, alt_min_inds, alt_mode_inds


### Load results

In [288]:
from pathlib import Path

# Root of the per-experiment result folders ("<res_dir>/<exp>/<model>_model_fitted_params/").
# Set CAUS_INF_RESULTS_DIR to point elsewhere instead of editing the path.
res_dir = Path(
    os.environ.get(
        "CAUS_INF_RESULTS_DIR", "/scratch/fitzgeraldj/data/caus_inf_data/results"
    )
)
exps = 5
# embed = "user"
sub_choices = ["old_subs", "pres_subs"]
regions = ["adm1", "ctry"]  # or "ctry" for dsbmm_dpf models
base_models = [
    "unadjusted.main",
    "network_pref_only.main",
    "topic_only.main",
    "no_unobs.main",
    "topic_only_oracle.main",
    "dsbmm_dpf.z-only",
    "dsbmm_dpf.z-theta-joint",
    "dsbmm_dpf.z-theta-joint-ndc",
]
models = [
    m + f"{sub_choice}_ewcnone_rcolmain_{region}_1hot"
    for m in base_models
    for sub_choice in sub_choices
    for region in regions
]

conf_types = ["homophily", "exog", "both"]
confounding_strengths = [(50, 10), (50, 50), (50, 100)]
# (conf_type, strengths) -> model id -> list of (fitted, true) arrays, one per experiment found.
exp_results = {}
found = set()

for i in range(1, exps + 1):
    for model in models:
        for (cov1conf, cov2conf) in confounding_strengths:
            for ct in conf_types:
                base_file_name = (
                    "conf=" + str((cov1conf, cov2conf)) + ";conf_type=" + ct + ".npz"
                )
                result_file = (
                    (res_dir / str(i)) / (model + "_model_fitted_params")
                ) / base_file_name
                try:
                    # np.load returns an NpzFile that keeps the archive open;
                    # the context manager closes it once both arrays are read.
                    with np.load(result_file) as res:
                        params = res["fitted"]
                        truth = res["true"]
                except FileNotFoundError:
                    # Expected: not every model/regime/experiment combination was run.
                    continue
                except (OSError, ValueError, KeyError) as err:
                    # Corrupt or incomplete archive: keep going, but say so.
                    print(result_file, "could not be read:", err)
                    continue

                exp_results.setdefault((ct, (cov1conf, cov2conf)), {}).setdefault(
                    model, []
                ).append((params, truth))
                if model not in found:
                    print(model, "found")
                    found.add(model)


dsbmm_dpf.z-onlyold_subs_ewcnone_rcolmain_adm1_1hot found
dsbmm_dpf.z-onlypres_subs_ewcnone_rcolmain_adm1_1hot found
dsbmm_dpf.z-theta-jointold_subs_ewcnone_rcolmain_adm1_1hot found
dsbmm_dpf.z-theta-jointold_subs_ewcnone_rcolmain_ctry_1hot found
dsbmm_dpf.z-theta-jointpres_subs_ewcnone_rcolmain_adm1_1hot found
dsbmm_dpf.z-theta-jointpres_subs_ewcnone_rcolmain_ctry_1hot found
dsbmm_dpf.z-theta-joint-ndcold_subs_ewcnone_rcolmain_adm1_1hot found
dsbmm_dpf.z-theta-joint-ndcold_subs_ewcnone_rcolmain_ctry_1hot found
dsbmm_dpf.z-theta-joint-ndcpres_subs_ewcnone_rcolmain_adm1_1hot found
dsbmm_dpf.z-theta-joint-ndcpres_subs_ewcnone_rcolmain_ctry_1hot found
dsbmm_dpf.z-onlyold_subs_ewcnone_rcolmain_ctry_1hot found
dsbmm_dpf.z-onlypres_subs_ewcnone_rcolmain_ctry_1hot found
unadjusted.mainold_subs_ewcnone_rcolmain_adm1_1hot found
unadjusted.mainpres_subs_ewcnone_rcolmain_adm1_1hot found
network_pref_only.mainpres_subs_ewcnone_rcolmain_adm1_1hot found
topic_only.mainpres_subs_ewcnone_rcolmain_adm1

### Visualize results

In [289]:
# Results table for exogenous confounding at the three confounding strengths.
confounding_type = "exog"
# NOTE(review): the model list here comes from the "Med." strength, while the
# homophily/both cells use "Low" — confirm this is intentional.
models = list(exp_results[(confounding_type, confounding_strengths[1])].keys())
regime1 = {
    label: (confounding_type, confounding_strengths[pos])
    for pos, label in enumerate(["Low", "Med.", "High"])
}

# All outputs carry the "1" suffix, matching the df2/df3 cells (alt_min_inds was
# previously unpacked without it).
df1, alt_df1, min_inds1, mode_inds1, alt_min_inds1, alt_mode_inds1 = print_table(
    exp_results, regime1, models
)


In [290]:
# Results table for homophily confounding at the three confounding strengths.
confounding_type = "homophily"
models = list(exp_results[(confounding_type, confounding_strengths[0])])
regime1 = {
    label: (confounding_type, confounding_strengths[pos])
    for pos, label in enumerate(("Low", "Med.", "High"))
}

(
    df2, alt_df2, min_inds2, mode_inds2, alt_min_inds2, alt_mode_inds2
) = print_table(exp_results, regime1, models)


In [291]:
# Results table when both confounding sources are active.
confounding_type = "both"
low_key = (confounding_type, confounding_strengths[0])
models = [model for model in exp_results[low_key]]
regime1 = {}
for pos, label in enumerate(("Low", "Med.", "High")):
    regime1[label] = (confounding_type, confounding_strengths[pos])

table_outputs = print_table(exp_results, regime1, models)
df3, alt_df3, min_inds3, mode_inds3, alt_min_inds3, alt_mode_inds3 = table_outputs


In [292]:
# Side-by-side table: one column group per confounding type.
table_parts = [df1, df2, df3]
table_keys = ["Exog.", "Homophily", "Both"]
all_results = pd.concat(table_parts, keys=table_keys, axis=1)


In [293]:
# Alphabetical rows, with every oracle model moved to the top. sorted() is stable,
# so the alphabetical order is kept within the oracle and non-oracle groups.
all_results = all_results.sort_index()
out_index = sorted(all_results.index, key=lambda name: "oracle" not in name.lower())
all_results = all_results.loc[out_index]


In [294]:
# MSE x1000 ("mean$\pm$std") per model (rows) and confounding type/strength (columns).
all_results

Unnamed: 0_level_0,Exog.,Exog.,Exog.,Homophily,Homophily,Homophily,Both,Both,Both
Unnamed: 0_level_1,Low,Med.,High,Low,Med.,High,Low,Med.,High
Oracle-old-A,1.76$\pm$2.57,0.89$\pm$1.02,2.84$\pm$4.36,1704.66$\pm$2917.48,1168.83$\pm$2002.77,883.64$\pm$1519.55,2412.09$\pm$4132.59,1395.16$\pm$2394.54,1032.08$\pm$1772.62
Oracle-pres-A,1.8$\pm$2.56,0.94$\pm$0.99,2.77$\pm$4.1,2176.2$\pm$3713.89,1165.67$\pm$1996.35,875.8$\pm$1504.22,2463.84$\pm$4221.16,1386.93$\pm$2379.56,1016.0$\pm$1743.58
Topic-oracle-old-A,\textbf{0.37$\pm$0.48},\textbf{0.41$\pm$0.55},1.17$\pm$1.81,1620.0$\pm$2773.71,1229.15$\pm$2106.26,899.89$\pm$1543.15,2182.37$\pm$3736.97,1736.23$\pm$2977.05,997.67$\pm$1709.34
Topic-oracle-pres-A,0.5$\pm$0.44,0.59$\pm$0.49,1.28$\pm$1.68,1827.29$\pm$3139.73,1215.9$\pm$2083.28,901.04$\pm$1545.12,2197.91$\pm$3763.43,1764.0$\pm$3024.53,1009.04$\pm$1729.02
Net.-only-pres-A,2.79$\pm$2.08,1.72$\pm$0.81,1.74$\pm$0.66,2397.95$\pm$4115.04,\textbf{273.79$\pm$460.67},298.71$\pm$483.66,2062.72$\pm$3495.05,1139.92$\pm$1946.24,221.97$\pm$370.51
Ours-NDC-old-A,3.1$\pm$4.33,3.19$\pm$4.36,3.0$\pm$4.41,\textbf{271.72$\pm$756.79},510.98$\pm$1153.34,265.81$\pm$585.96,1507.93$\pm$3932.16,673.58$\pm$1522.78,464.68$\pm$1060.75
Ours-NDC-old-C,1.22$\pm$1.17,1.61$\pm$1.38,5.31$\pm$16.46,624.67$\pm$1431.33,277.1$\pm$736.24,\textbf{201.87$\pm$653.52},\textbf{885.37$\pm$3360.32},549.92$\pm$1486.13,\textbf{175.64$\pm$413.81}
Ours-NDC-pres-A,3.1$\pm$4.32,3.23$\pm$4.33,2.94$\pm$4.29,373.74$\pm$891.21,511.34$\pm$1153.19,265.89$\pm$585.92,1508.04$\pm$3932.12,673.69$\pm$1522.73,464.84$\pm$1060.68
Ours-NDC-pres-C,1.2$\pm$1.18,2.37$\pm$3.42,5.8$\pm$16.48,742.37$\pm$1723.36,277.2$\pm$736.21,201.98$\pm$653.49,885.53$\pm$3360.28,550.09$\pm$1486.07,175.77$\pm$413.76
Ours-no-meta-old-A,2.7$\pm$3.62,4.52$\pm$7.12,7.32$\pm$17.18,1506.4$\pm$4375.16,291.44$\pm$616.93,204.26$\pm$435.05,962.55$\pm$2437.77,447.6$\pm$948.98,241.53$\pm$521.82


In [295]:
# Write the table as LaTeX; escape=False keeps the $\pm$ / \textbf markup intact.
# Create ./results first so this does not fail on a fresh checkout.
os.makedirs("./results", exist_ok=True)
all_results.to_latex("./results/semi-synth.tex", escape=False)


In [296]:
# Same layout as all_results, for the "alt" statistics (all but the last covariate).
alt_table_parts = [alt_df1, alt_df2, alt_df3]
alt_all_results = pd.concat(
    alt_table_parts,
    axis=1,
    keys=["Exog.", "Homophily", "Both"],
)


In [297]:
# alt_all_results[alt_all_results == alt_all_results.min(axis=0)] = (
#     "\textbf{" + alt_all_results[alt_all_results == alt_all_results.min(axis=0)] + "}"
# )


In [298]:
# Oracle rows first, then the remaining models; each group stays alphabetical.
alt_sorted = alt_all_results.sort_index()
oracle_rows = [name for name in alt_sorted.index if "oracle" in name.lower()]
other_rows = [name for name in alt_sorted.index if "oracle" not in name.lower()]
out_index = oracle_rows + other_rows
alt_all_results = alt_sorted.loc[out_index]


In [299]:
# Same table computed over all but the last covariate (print_table's alt_* outputs).
alt_all_results

Unnamed: 0_level_0,Exog.,Exog.,Exog.,Homophily,Homophily,Homophily,Both,Both,Both
Unnamed: 0_level_1,Low,Med.,High,Low,Med.,High,Low,Med.,High
Oracle-old-A,0.28$\pm$0.2,0.3$\pm$0.2,0.33$\pm$0.22,20.31$\pm$28.2,12.56$\pm$17.45,6.33$\pm$8.72,26.21$\pm$36.32,12.7$\pm$17.61,8.67$\pm$12.01
Oracle-pres-A,0.32$\pm$0.2,0.37$\pm$0.18,0.41$\pm$0.22,32.1$\pm$44.13,13.11$\pm$16.71,7.35$\pm$7.83,26.83$\pm$36.57,13.12$\pm$16.76,9.36$\pm$10.77
Topic-oracle-old-A,\textbf{0.09$\pm$0.04},\textbf{0.1$\pm$0.1},\textbf{0.13$\pm$0.06},18.64$\pm$23.34,13.12$\pm$14.65,8.96$\pm$7.98,24.88$\pm$32.21,17.46$\pm$20.82,10.79$\pm$10.71
Topic-oracle-pres-A,0.25$\pm$0.11,0.31$\pm$0.15,0.32$\pm$0.15,14.59$\pm$17.63,13.14$\pm$14.66,8.97$\pm$8.03,25.16$\pm$32.52,17.82$\pm$21.16,10.8$\pm$10.64
Net.-only-pres-A,2.34$\pm$2.22,1.35$\pm$0.56,1.43$\pm$0.45,22.18$\pm$28.39,7.82$\pm$2.16,19.66$\pm$20.61,44.9$\pm$28.96,16.28$\pm$16.66,8.06$\pm$2.7
Ours-NDC-old-A,2.03$\pm$3.4,2.11$\pm$3.73,1.88$\pm$3.29,23.63$\pm$51.03,9.21$\pm$12.08,8.18$\pm$10.22,11.54$\pm$16.69,9.29$\pm$12.88,8.96$\pm$10.94
Ours-NDC-old-C,0.85$\pm$0.99,1.15$\pm$1.25,1.39$\pm$1.61,\textbf{6.3$\pm$9.49},\textbf{5.49$\pm$6.42},5.32$\pm$4.86,\textbf{6.56$\pm$7.34},\textbf{5.15$\pm$5.17},\textbf{5.39$\pm$5.04}
Ours-NDC-pres-A,2.03$\pm$3.4,2.17$\pm$3.69,1.81$\pm$3.04,21.43$\pm$42.9,9.69$\pm$12.33,8.28$\pm$10.18,11.68$\pm$16.73,9.47$\pm$12.91,9.18$\pm$11.02
Ours-NDC-pres-C,0.83$\pm$1.0,2.16$\pm$3.91,2.04$\pm$2.94,8.64$\pm$16.11,5.7$\pm$6.51,5.47$\pm$4.94,6.77$\pm$7.51,5.37$\pm$5.26,5.56$\pm$5.11
Ours-no-meta-old-A,1.6$\pm$2.94,2.43$\pm$4.93,2.12$\pm$2.99,8.14$\pm$10.87,11.17$\pm$16.29,8.58$\pm$9.92,12.37$\pm$14.99,9.52$\pm$10.9,9.87$\pm$11.12


In [300]:
# Write the alt table as LaTeX (escape=False keeps the $\pm$ / \textbf markup).
# Create ./results first so this does not fail on a fresh checkout.
os.makedirs("./results", exist_ok=True)
alt_all_results.to_latex("./results/alt-semi-synth.tex", escape=False)


### Visualize PPC results

In [414]:
import pickle


def _load_pickle(path):
    """Unpickle one results file. NOTE: pickle can execute code — project files only."""
    with open(path, "rb") as fh:
        return pickle.load(fh)


# Load the PPC / AUC results for every simulation index whose three files all
# exist; a simulation missing any file is reported and skipped entirely.
ppc_prefixes = ("dsbmm_ppc_results", "dpf_ppc_results", "dpf_auc_results")
loaded_ppc = {prefix: [] for prefix in ppc_prefixes}
for exp_idx in range(20):
    try:
        sim_values = [
            _load_pickle(res_dir / f"{prefix}_sim{exp_idx}.pkl") for prefix in ppc_prefixes
        ]
    except FileNotFoundError:
        print(f"Sim {exp_idx} results not found")
        continue
    for prefix, value in zip(ppc_prefixes, sim_values):
        loaded_ppc[prefix].append(value)

dsbmm_ppc_results = np.stack(loaded_ppc["dsbmm_ppc_results"], axis=0)
dpf_ppc_results = np.stack(loaded_ppc["dpf_ppc_results"], axis=0)
dpf_auc_results = np.stack(loaded_ppc["dpf_auc_results"], axis=0)


Sim 5 results not found
Sim 16 results not found
Sim 17 results not found
Sim 18 results not found
Sim 19 results not found


In [None]:
# Raw DPF (topic-model) PPC results, stacked over simulations on axis 0.
# NOTE(review): originally documented as shape (n_exps, n_Q_tested), but later
# cells average over axis=-1 and over axis=(0, -1) to get one value per tested K,
# so there appears to be at least one further trailing axis — confirm the layout.
# Each experiment is a subsample of the data using a different seed. Each value
# is n_pos / n_repls, where n_pos counts the replicates in which the likelihood
# of the replicated data was greater than that of the held-out data, after
# fitting on the remaining data.
dpf_ppc_results


In [None]:
# DPF PPC scores averaged over the last axis.
dpf_ppc_results.mean(axis=-1)

In [None]:
# DPF AUC results averaged over the last axis.
dpf_auc_results.mean(axis=-1)


In [447]:
# Row 0: DSBMM (network) PPC per tested Q, averaged over simulations after
# discarding degenerate scores of exactly 0 or 1.
# Row 1: DPF (topic) PPC per tested K, averaged over simulations and the last axis.
dsbmm_ppc_means = dsbmm_ppc_results.mean(axis=-1)
dsbmm_col_means = []
for col in range(3):
    col_vals = dsbmm_ppc_means[:, col]
    keep = ~np.isin(col_vals, [0, 1])
    dsbmm_col_means.append(col_vals[keep].mean())

tmp = np.stack([np.array(dsbmm_col_means), dpf_ppc_results.mean(axis=(0, -1))])


In [448]:
# "$A$" = network (DSBMM) PPC and "$Y$" = topic (DPF) PPC, one row per tested model
# size; "$Q$" / "$K$" are the group / topic counts those rows correspond to
# (presumably the same Q and K used in the fitted-model file names).
ppc_df = pd.DataFrame(tmp.T, columns=["$A$", "$Y$"])
ppc_df["$Q$"] = [4, 9, 16]
ppc_df["$K$"] = [5, 8, 10]
ppc_df[["$Y$", "$A$"]] = np.round(ppc_df[["$Y$", "$A$"]], 3)
# Create ./results first so this does not fail on a fresh checkout.
os.makedirs("./results", exist_ok=True)
ppc_df[["$K$", "$Y$"]].to_latex("./results/topic-synth-ppcs.tex", escape=False)
ppc_df[["$Q$", "$A$"]].to_latex("./results/auth-synth-ppcs.tex", escape=False)


In [None]:
# Both PPC tables side by side: topic model (K, Y) and network model (Q, A).
ppc_df[["$K$", "$Y$", "$Q$", "$A$"]]


In [342]:
from pif_dsbmm_dpf.citation.predictive_check import calculate_ppc_dsbmm, calculate_ppc_dpf, mask_topics

# Inspect one simulation in detail: DSBMM fit with Q groups, DPF fit with K topics.
exp_idx = 0
Q = 4
K = 5
dsbmm_datadir = res_dir.parent / "dsbmm_data"
dpf_res_dir = res_dir.parent / "dpf_results"

dsbmm_fit_path = dsbmm_datadir / f"dsbmmppc_runsim_model_{exp_idx}_Q{Q}_subs.pkl"
dpf_fit_path = dpf_res_dir / f"dpfppc_runsim_model_{exp_idx}_K{K}.pkl"
sim_model_path = res_dir.parent / f"sim_model_{exp_idx}.pkl"

# NOTE: pickle.load can execute arbitrary code — only load files produced by this project.
with open(dsbmm_fit_path, "rb") as f:
    node_probs, Z_trans, block_probs = pickle.load(f)
with open(dpf_fit_path, "rb") as f:
    W_hat, Theta_hat = pickle.load(f)
with open(sim_model_path, "rb") as f:
    sim_model = pickle.load(f)


In [343]:
# Regenerate the semi-synthetic covariates for this simulation with a fixed seed.
np.random.seed(exp_idx)
Y = sim_model.make_multi_covariate_simulation(
    noise=10.0,
    confounding_strength=50.0,
    confounding_to_use="both",
)
A = sim_model.A
T = len(Y)
N, M = Y[0].shape[0], Y[0].shape[1]


Saving semi-synth data to /scratch/fitzgeraldj/data/caus_inf_data/sim_model_0.pkl


In [402]:
# NOTE(review): presumably (re)draws the simulated influence parameters; a later
# cell reads sim_model.beta — confirm whether this call mutates sim_model, and
# note that it consumes the global numpy RNG without reseeding.
beta = sim_model.make_simulated_influence()

In [405]:
# Mean entry of each network snapshot A_t. For the last snapshot this equals
# A[-1].nnz / N**2 (computed further down), i.e. the edge density.
[A_t.mean() for A_t in A]

[0.005825739717210655,
 0.015015957394327081,
 0.03330797737524349,
 0.049230982624319505]

In [413]:
# For each of the first four time steps, print the mean of
# (beta[:, t] as a column * A_t) @ Y_t.
for t in range(4):
    weighted_adj = sim_model.beta[:, np.newaxis, t] * sim_model.A[t]
    influence = weighted_adj @ Y[t]
    print(influence.mean())

1.3141275265773484
12.054698304734837
109.85377671181134
1576.6025810774593


In [344]:
# Hold-out masks for the first T-1 time steps: one (author_ids, mask) pair per step
# for the network (N x N) and for the topics (N x M). All network masks are drawn
# before the topic masks, so the RNG consumption order matches.
aus = np.arange(N, dtype=int)
masked_friends = [(aus.copy(), mask_topics(N, N)) for _ in range(T - 1)]
masked_tpcs = [(aus.copy(), mask_topics(N, M)) for _ in range(T - 1)]



In [345]:
import time 
# Reseed from the wall clock so replicates differ between runs.
# NOTE(review): this makes the cell irreproducible — log the seed if results must be re-created.
np.random.seed(int(time.time()) % 2**32)
# Network (DSBMM) PPC on the masked author pairs; with ret_rates=True it also
# returns the edge rates (presumably those used for the replicates — confirm).
A_ll_heldout, A_ll_repl, e_rates = calculate_ppc_dsbmm(masked_friends,A,node_probs,block_probs,ret_rates=True)
# Topic (DPF) PPC on the masked author-topic entries, using all but the last time step of Y.
X_ll_heldout, X_ll_repl = calculate_ppc_dpf(masked_tpcs,Y[:-1],Theta_hat,W_hat)

In [348]:
# W_hat is indexed as (topic, time, K) when forming dpf_rates further down.
W_hat.shape

(1000, 4, 5)

In [None]:
import time

# Repeat the DPF PPC with Theta_hat and W_hat both halved, counting per time step
# how often the replicated log-likelihood beats the held-out one.
n_repls = 100
# One slot per masked time step (T-1 of them).
x_ppc = np.zeros(len(masked_tpcs))
av_diff = 0.0  # summed (not averaged) over replicates
# Seed once: reseeding with int(time.time()) on every iteration reused the same
# seed for all iterations within one second, so most "replicates" were identical.
np.random.seed(int(time.time()) % 2**32)
for _ in range(n_repls):
    X_ll_heldout, X_ll_repl = calculate_ppc_dpf(masked_tpcs, Y[:-1], Theta_hat / 2, W_hat / 2)
    x_ppc[X_ll_repl > X_ll_heldout] += 1.0
    av_diff += (X_ll_repl - X_ll_heldout).mean()
print(x_ppc / n_repls)
print(av_diff)

In [354]:
# Poisson rate per (author, topic, time): broadcast Theta_hat (author, time, K)
# against W_hat (topic, time, K) and sum over the K axis -> (author, topic, time).
dpf_rates = (Theta_hat[:,np.newaxis,...] * W_hat[np.newaxis,...]).sum(axis=-1)

In [358]:
# Subtract, for each (author, time), the mean rate over the topic axis.
# NOTE: in-place — re-running this cell changes dpf_rates again.
dpf_rates -= dpf_rates.mean(axis=1,keepdims=True)

In [361]:
# Clip negative centred rates to zero so they remain valid Poisson rates (in-place).
dpf_rates[dpf_rates<0] = 0

In [None]:
import time

from scipy.stats import poisson

# Centred/clipped rates at the masked (author, topic) entries of each time step.
sub_rates = [
    dpf_rates[mt[0], mt[1], t * np.ones(len(mt[0]), dtype=int)]
    for t, mt in enumerate(masked_tpcs)
]
n_repls = 100
n_steps = len(sub_rates)
x_ppc = np.zeros(n_steps)
alt_ppc = np.zeros(n_steps)
# Seed once: reseeding with int(time.time()) on every iteration reused the same
# seed for all iterations within one second, so most replicates were duplicates.
np.random.seed(int(time.time()) % 2**32)
for _ in range(n_repls):
    for t in range(n_steps):
        # Evaluate presence indicators (> 0) of the replicate and of the held-out
        # counts under the Poisson pmf.
        repl = poisson.rvs(sub_rates[t])
        repl_ll = poisson.logpmf(repl > 0, sub_rates[t]).sum()
        ho_ll = poisson.logpmf(
            Y[t][masked_tpcs[t][0], masked_tpcs[t][1]] > 0, sub_rates[t]
        ).sum()
        x_ppc[t] += repl_ll > ho_ll
    # Same check through calculate_ppc_dpf on binarised data and the raw fit.
    alt_ho_ll, alt_repl_ll = calculate_ppc_dpf(
        masked_tpcs, [Yt > 0 for Yt in Y[:-1]], Theta_hat, W_hat
    )
    alt_ppc[alt_repl_ll > alt_ho_ll] += 1.0
print(x_ppc / n_repls)
print(alt_ppc / n_repls)

In [None]:
# PPC using rates sum_k exp(theta/2) * exp(w/2), restricted to each step's masked entries.
exp_half_rates = (
    np.exp(Theta_hat[:, np.newaxis, ...] / 2) * np.exp(W_hat[np.newaxis, ...] / 2)
).sum(axis=-1)
tmp_rates = [
    exp_half_rates[rows, cols, np.full(len(rows), t, dtype=int)]
    for t, (rows, cols) in enumerate(masked_tpcs)
]
x_ppc = np.zeros(4)
for _ in range(100):
    for t in range(4):
        rates_t = tmp_rates[t]
        repl = poisson.rvs(rates_t)
        repl_ll = poisson.logpmf(repl, rates_t).sum()
        ho = Y[t][masked_tpcs[t][0], masked_tpcs[t][1]]
        print(f"repl: {repl.mean():.2f}; ho: {ho.mean():.2f}")
        ho_ll = poisson.logpmf(ho, rates_t).sum()
        x_ppc[t] += repl_ll > ho_ll

print(x_ppc / 100)

In [386]:
# Mean count over every entry of the topic matrices for all but the last time step.
np.stack([Yt.toarray() for Yt in Y[:-1]]).mean()

25.533953806870937

In [None]:
# Distinct held-out counts at the masked entries of time step t.
# NOTE(review): t is whatever a previous loop left behind (3 after range(4)) —
# set it explicitly so this cell does not depend on hidden state.
np.unique(Y[t][masked_tpcs[t][0],masked_tpcs[t][1]])

In [74]:
# Degree-corrected expected edge counts for the final snapshot, restricted to
# nodes with both in- and out-citations:
#   e_rates[i, j] = sum_{q,r} (d_i^out * pi_iq) * B_qr * (d_j^in * pi_jr)
# using the last time slice of node_probs (node, time, group) and block_probs.
# Overwrites the e_rates returned by calculate_ppc_dsbmm earlier.
# NOTE(review): assumes A[-1].sum(axis=...) returns 1-D arrays (sparse *array*,
# not sparse matrix, which would give 2-D np.matrix results) — confirm.
in_degs = A[-1].sum(axis=0)
out_degs = A[-1].sum(axis=1)
# in_degs[in_degs == 0] = 1.0
# out_degs[out_degs == 0] = 1.0
pos_nodes = (in_degs > 0) & (out_degs > 0)
e_rates = np.einsum(
    "iq,qr,jr->ij",
    out_degs[pos_nodes, np.newaxis] * node_probs[pos_nodes, -1, :],
    block_probs[..., -1],
    in_degs[pos_nodes, np.newaxis] * node_probs[pos_nodes, -1, :],
)


In [78]:
# Sanity check: masked target nodes j kept by the positive-degree filter should
# all have at least one out-citation and one in-citation.
# np.isin must compare against node *indices*: comparing against the boolean mask
# pos_nodes treats it as the set {False, True} == {0, 1}, keeping only j in {0, 1}.
good_samp_idxs = np.isin(masked_friends[-1][1], np.flatnonzero(pos_nodes))
tmp_j = masked_friends[-1][1][good_samp_idxs]
tmp_i = masked_friends[-1][0][good_samp_idxs]
print(f"Sampled w no out citations: {(A[-1][tmp_j,:].sum(axis=1) == 0).sum()}")
print(f"Sampled w no in citations: {(A[-1][:,tmp_j].sum(axis=0) == 0).sum()}")


Sampled w no out citations: 0
Sampled w no in citations: 0


In [117]:
# The stored block probabilities were calculated incorrectly, so re-derive them
# directly from the data. First, for every node pair (i, j) and block pair (q, r):
#   pi_iq * Poisson(A_ij; d_i^out * d_j^in * B_qr) * pi_jr
# i.e. the unnormalised posterior over the pair's block assignments.
# NOTE(review): this materialises an N x N x Q x Q array; memory grows quickly with N and Q.
tp_marg = np.einsum(
    "iq,ijqr,jr->ijqr",
    node_probs[:, -1, :],
    poisson.pmf(
        A[-1].toarray()[..., np.newaxis, np.newaxis],
        out_degs[:, np.newaxis, np.newaxis, np.newaxis]
        * in_degs[np.newaxis, :, np.newaxis, np.newaxis]
        * block_probs[..., -1][np.newaxis, np.newaxis],
    ),
    node_probs[:, -1, :],
)

In [118]:
# Normalise over (q, r) for each pair (in-place). Pairs whose total is zero divide
# by zero (the RuntimeWarning below) and become NaN; later sums use np.nansum.
tp_marg /= tp_marg.sum(axis=(-2,-1),keepdims=True)

  tp_marg /= tp_marg.sum(axis=(-2,-1),keepdims=True)


In [146]:
# Re-estimated block rates over positive-degree nodes:
#   B_qr = (posterior-weighted edge count between q and r)
#          / (sum_{i,j} d_i^out pi_iq * d_j^in pi_jr)
# Blocks with a zero denominator get rate 0 instead of dividing by zero.
eff_block_probs_num = np.nansum(
    tp_marg[np.ix_(pos_nodes,pos_nodes)]*(A[-1].toarray()[np.ix_(pos_nodes,pos_nodes)][...,np.newaxis,np.newaxis]),axis=(0,1)
)
eff_block_probs_denom = np.einsum(
    "iq,jr->qr",
    out_degs[pos_nodes, np.newaxis] * node_probs[pos_nodes, -1, :],
    in_degs[pos_nodes, np.newaxis] * node_probs[pos_nodes, -1, :],
)
eff_block_probs = np.divide(
    eff_block_probs_num,
    eff_block_probs_denom,
    where=eff_block_probs_denom > 0,
    out=np.zeros_like(eff_block_probs_denom),
)


In [107]:
# Edge density of the final snapshot: stored non-zeros / N^2.
A[-1].nnz / (A[-1].shape[0] ** 2)


0.04923098262431963

In [150]:
# Expected edge counts between positive-degree nodes using the re-estimated block rates.
eff_e_rates = np.einsum(
    "iq,qr,jr->ij",
    out_degs[pos_nodes, np.newaxis] * node_probs[pos_nodes, -1, :],
    eff_block_probs,
    in_degs[pos_nodes, np.newaxis] * node_probs[pos_nodes, -1, :],
)

In [151]:
from scipy.stats import poisson

# Network PPC for the final snapshot using the re-estimated rates: the fraction
# of replicates whose log-likelihood exceeds that of the held-out entries.
heldout = A[-1][masked_friends[-1][0], masked_friends[-1][1]]
repls = 100

# Map each positive-degree node id to its row/column in eff_e_rates.
rev_idx = np.zeros(N, dtype=int)
rev_idx[pos_nodes] = np.arange(pos_nodes.sum(), dtype=int)

# Keep only masked pairs whose endpoints are both positive-degree nodes.
pos_node_ids = np.flatnonzero(pos_nodes)
joint_mask = np.isin(masked_friends[-1][0], pos_node_ids) & np.isin(
    masked_friends[-1][1], pos_node_ids
)
sub_mask_i = masked_friends[-1][0][joint_mask]
sub_mask_j = masked_friends[-1][1][joint_mask]
subset_rates = eff_e_rates[rev_idx[sub_mask_i], rev_idx[sub_mask_j]]

ppc = 0.0
for _ in range(repls):
    samps = poisson.rvs(subset_rates)
    ho_ll = poisson.logpmf(heldout[joint_mask], subset_rates).sum()
    rep_ll = poisson.logpmf(samps, subset_rates).sum()
    ppc += 1.0 if rep_ll > ho_ll else 0.0
print(ppc / repls)


0.32


## Recalculate PPCs for runs whose saved block probabilities were computed incorrectly
- NB: this only works for small numbers of groups, where dividing by n_descs is less of a problem.

In [171]:
import pickle
import time

from scipy.stats import poisson
from pif_dsbmm_dpf.citation.predictive_check import calculate_ppc_dsbmm, mask_topics

# Recompute the network PPC for each tested number of groups Q. The block rates
# are re-derived from the data because the stored block_probs were computed incorrectly.
exp_idx = 0
dsbmm_datadir = res_dir.parent / "dsbmm_data"
# NOTE: pickle.load can execute arbitrary code — only load files produced by this project.
with open(res_dir.parent / f"sim_model_{exp_idx}.pkl", "rb") as f:
    sim_model = pickle.load(f)
np.random.seed(exp_idx)
Y = sim_model.make_multi_covariate_simulation(
    noise=10.0, confounding_strength=50.0, confounding_to_use="both"
)
A = sim_model.A
N = Y[0].shape[0]
M = Y[0].shape[1]
T = len(Y)
masked_friends = [mask_topics(N, N) for _ in range(T - 1)]
aus = np.arange(N, dtype=int)
masked_friends = [(aus.copy(), mf) for mf in masked_friends]
heldout = A[-1][masked_friends[-1][0], masked_friends[-1][1]]
in_degs = A[-1].sum(axis=0)
out_degs = A[-1].sum(axis=1)
pos_nodes = (in_degs > 0) & (out_degs > 0)

# --- Quantities independent of Q: computed once instead of on every iteration. ---
A_last = A[-1].toarray()  # dense final snapshot
A_last_pos = A_last[np.ix_(pos_nodes, pos_nodes)]
repls = 100
# Map each positive-degree node id to its row/column in eff_e_rates.
rev_idx = np.zeros(N, dtype=int)
rev_idx[pos_nodes] = np.arange(pos_nodes.sum(), dtype=int)
pos_node_ids = np.flatnonzero(pos_nodes)
sub_mask_i = masked_friends[-1][0]
sub_mask_j = masked_friends[-1][1]
# Keep only masked pairs whose endpoints both have positive in- and out-degree.
joint_mask = np.isin(sub_mask_i, pos_node_ids) & np.isin(sub_mask_j, pos_node_ids)
sub_mask_i = sub_mask_i[joint_mask]
sub_mask_j = sub_mask_j[joint_mask]
heldout_sub = heldout[joint_mask]

Qs = [4, 9, 16]
ppc_scores = np.zeros(len(Qs))
# Seed once before the loop. Reseeding with int(time.time()) inside the loop gave
# every Q processed within the same second an identical random stream.
np.random.seed(int(time.time()) % 2**32)
for q_idx, Q in enumerate(Qs):
    with open(dsbmm_datadir / f"dsbmmppc_runsim_model_{exp_idx}_Q{Q}_subs.pkl", "rb") as f:
        node_probs, Z_trans, block_probs = pickle.load(f)
    node_probs_last = node_probs[:, -1, :]
    block_probs_last = block_probs[..., -1]

    # Unnormalised posterior over block pairs (q, r) for every node pair (i, j):
    #   pi_iq * Poisson(A_ij; d_i^out * d_j^in * B_qr) * pi_jr
    tp_marg = np.einsum(
        "iq,ijqr,jr->ijqr",
        node_probs_last,
        poisson.pmf(
            A_last[..., np.newaxis, np.newaxis],
            out_degs[:, np.newaxis, np.newaxis, np.newaxis]
            * in_degs[np.newaxis, :, np.newaxis, np.newaxis]
            * block_probs_last[np.newaxis, np.newaxis],
        ),
        node_probs_last,
    )
    # Normalise over (q, r); all-zero pairs stay zero rather than becoming NaN.
    tp_marg_sums = np.nansum(tp_marg, axis=(-2, -1), keepdims=True)
    tp_marg = np.divide(
        tp_marg, tp_marg_sums, where=tp_marg_sums > 0, out=np.zeros_like(tp_marg)
    )

    # Re-estimated block rates: expected edges between blocks / degree-weighted pairs.
    weighted_out = out_degs[pos_nodes, np.newaxis] * node_probs_last[pos_nodes]
    weighted_in = in_degs[pos_nodes, np.newaxis] * node_probs_last[pos_nodes]
    eff_block_probs_num = np.nansum(
        tp_marg[np.ix_(pos_nodes, pos_nodes)] * A_last_pos[..., np.newaxis, np.newaxis],
        axis=(0, 1),
    )
    eff_block_probs_denom = np.einsum("iq,jr->qr", weighted_out, weighted_in)
    eff_block_probs = np.divide(
        eff_block_probs_num,
        eff_block_probs_denom,
        where=eff_block_probs_denom > 0,
        out=np.zeros_like(eff_block_probs_denom),
    )
    eff_e_rates = np.einsum("iq,qr,jr->ij", weighted_out, eff_block_probs, weighted_in)

    # PPC: fraction of replicates more likely than the held-out entries.
    subset_rates = eff_e_rates[rev_idx[sub_mask_i], rev_idx[sub_mask_j]]
    ho_ll = poisson.logpmf(heldout_sub, subset_rates).sum()  # fixed across replicates
    ppc = 0.0
    for _ in range(repls):
        samps = poisson.rvs(subset_rates)
        rep_ll = poisson.logpmf(samps, subset_rates).sum()
        if rep_ll > ho_ll:
            ppc += 1.0
    ppc_scores[q_idx] = ppc / repls


Saving semi-synth data to /scratch/fitzgeraldj/data/caus_inf_data/sim_model_0.pkl


In [279]:
# Load the DPF fit (K topics) for simulation 10 to inspect its parameter shapes.
exp_idx = 10
K = 5
dpf_fit_path = (res_dir.parent / "dpf_results") / f"dpfppc_runsim_model_{exp_idx}_K{K}.pkl"
# NOTE: pickle.load can execute arbitrary code — only load files produced by this project.
with open(dpf_fit_path, "rb") as f:
    W_hat, Theta_hat = pickle.load(f)

In [280]:
# Theta_hat for simulation 10: its middle (time) axis has length 3, unlike the
# length-4 time axis of the simulation-0 W_hat inspected earlier.
Theta_hat.shape

(3258, 3, 5)

In [281]:
# W_hat for simulation 10 (topic, time, K); compare with the simulation-0 fit above.
W_hat.shape

(1000, 3, 5)