In [1]:
import numpy as np
import os
from sklearn.metrics import mean_squared_error as mse_f
from scipy import sparse
from scipy.stats import gamma
from scipy.stats import ttest_ind
import warnings
import pandas as pd

warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


In [2]:
paper_model_names = models = {
    "dsbmm_dpf.z-only": "Ours-no-meta",
    "dsbmm_dpf.z-theta-joint": "Ours",
    # 'spf.main':'MSPF',
    "unadjusted.main": "Unadjusted",
    "network_pref_only.main": "Net.-only",
    "topic_only.main": "Topic-only",
    "no_unobs.main": "Oracle",
    "topic_only_oracle.main": "Topic-oracle",
}


In [7]:
tmp_kv = tuple(paper_model_names.items())
for k, v in tmp_kv:
    if "dsbmm_dpf" in k:
        paper_model_names[k + "-ndc"] = v + "-NDC"

tmp_dict = {
    k + "old_subs" + "_ewcnone" + "_rcolmain_adm1_1hot": v + "-old-A"
    for k, v in paper_model_names.items()
    if "dsbmm_dpf" not in k
}
tmp_dict.update(
    {
        k + "pres_subs" + "_ewcnone" + "_rcolmain_adm1_1hot": v + "-pres-A"
        for k, v in paper_model_names.items()
        if "dsbmm_dpf" not in k
    }
)
sub_choice_pretty = {"old_subs": "-old", "pres_subs": "-pres"}
reg_choice_pretty = {"adm1": "-A", "ctry": "-C"}
for k, v in paper_model_names.items():
    if "dsbmm_dpf" in k:
        for sub_choice in ["old_subs", "pres_subs"]:
            for region in ["adm1", "ctry"]:
                tmp_dict[k + f"{sub_choice}_ewcnone_rcolmain_{region}_1hot"] = (
                    v + sub_choice_pretty[sub_choice] + reg_choice_pretty[region]
                )


In [8]:
full_paper_model_names = tmp_dict


In [9]:
def print_table(exp_results, regimes, models, exps=10, print_notfound=False):
    ncols = len(regimes.keys())
    nrows = len(models)
    results = np.zeros((nrows, ncols))
    std = np.zeros((nrows, ncols))
    alt_results = np.zeros((nrows, ncols))
    alt_std = np.zeros((nrows, ncols))

    col_idx = 0
    for regime, c in regimes.items():
        row_idx = 0
        for model in models:
            mse = np.zeros((exps, 4))
            for i in range(exps):
                try:
                    beta_predicted = exp_results[c][model][i][0]
                    truth = exp_results[c][model][i][1]
                    sq_err = (beta_predicted - truth) ** 2
                    mse[i] = sq_err.mean(axis=0)
                except:
                    if print_notfound:
                        print(model, "exp", i, "not found")
            results[row_idx][col_idx] = round(mse.mean() * 1000, 2)
            std[row_idx][col_idx] = round(mse.std() * 1000, 2)

            alt_results[row_idx][col_idx] = round(mse[:, :-1].mean() * 1000, 2)
            alt_std[row_idx][col_idx] = round(mse[:, :-1].std() * 1000, 2)

            row_idx += 1
        col_idx += 1

    proper_names = [full_paper_model_names[m] for m in models]
    col_names = list(regimes.keys())
    df = pd.DataFrame(results, index=proper_names, columns=col_names, dtype=str)
    std_df = pd.DataFrame(std, index=proper_names, columns=col_names, dtype=str)
    df = df + "$\pm$" + std_df

    alt_df = pd.DataFrame(alt_results, index=proper_names, columns=col_names, dtype=str)
    alt_std_df = pd.DataFrame(alt_std, index=proper_names, columns=col_names, dtype=str)
    alt_df = alt_df + "$\pm$" + alt_std_df
    return df, alt_df


### Load results

In [154]:
from pathlib import Path

res_dir = Path("/scratch/fitzgeraldj/data/caus_inf_data/results")
exps = 5
# embed = "user"
sub_choices = ["old_subs", "pres_subs"]
regions = ["adm1", "ctry"]  # or "ctry" for dsbmm_dpf models
base_models = [
    "unadjusted.main",
    "network_pref_only.main",
    "topic_only.main",
    "no_unobs.main",
    "topic_only_oracle.main",
    "dsbmm_dpf.z-only",
    "dsbmm_dpf.z-theta-joint",
    "dsbmm_dpf.z-theta-joint-ndc",
]
models = [
    m + f"{sub_choice}_ewcnone_rcolmain_{region}_1hot"
    for m in base_models
    for sub_choice in sub_choices
    for region in regions
]

conf_types = ["homophily", "exog", "both"]
confounding_strengths = [(50, 10), (50, 50), (50, 100)]
exp_results = {}
found = set()

for i in range(1, exps + 1):
    for model in models:
        for (cov1conf, cov2conf) in confounding_strengths:
            for ct in conf_types:
                try:
                    base_file_name = (
                        "conf="
                        + str((cov1conf, cov2conf))
                        + ";conf_type="
                        + ct
                        + ".npz"
                    )
                    result_file = (
                        (res_dir / str(i)) / (model + "_model_fitted_params")
                    ) / base_file_name
                    res = np.load(result_file)
                    params = res["fitted"]
                    truth = res["true"]

                    if (ct, (cov1conf, cov2conf)) in exp_results:
                        if model in exp_results[(ct, (cov1conf, cov2conf))]:
                            exp_results[(ct, (cov1conf, cov2conf))][model].append(
                                (params, truth)
                            )
                        else:
                            exp_results[(ct, (cov1conf, cov2conf))][model] = [
                                (params, truth)
                            ]
                    else:
                        exp_results[(ct, (cov1conf, cov2conf))] = {
                            model: [(params, truth)]
                        }
                    if model not in found:
                        print(model, "found")
                        found |= set([model])
                except:
                    # print(result_file, " not found")
                    continue


dsbmm_dpf.z-onlyold_subs_ewcnone_rcolmain_adm1_1hot found
dsbmm_dpf.z-onlypres_subs_ewcnone_rcolmain_adm1_1hot found
dsbmm_dpf.z-theta-jointold_subs_ewcnone_rcolmain_adm1_1hot found
dsbmm_dpf.z-theta-jointold_subs_ewcnone_rcolmain_ctry_1hot found
dsbmm_dpf.z-theta-jointpres_subs_ewcnone_rcolmain_adm1_1hot found
dsbmm_dpf.z-theta-jointpres_subs_ewcnone_rcolmain_ctry_1hot found
dsbmm_dpf.z-theta-joint-ndcold_subs_ewcnone_rcolmain_adm1_1hot found
dsbmm_dpf.z-theta-joint-ndcpres_subs_ewcnone_rcolmain_adm1_1hot found
dsbmm_dpf.z-onlyold_subs_ewcnone_rcolmain_ctry_1hot found
dsbmm_dpf.z-onlypres_subs_ewcnone_rcolmain_ctry_1hot found
unadjusted.mainold_subs_ewcnone_rcolmain_adm1_1hot found
unadjusted.mainpres_subs_ewcnone_rcolmain_adm1_1hot found
network_pref_only.mainpres_subs_ewcnone_rcolmain_adm1_1hot found
topic_only.mainpres_subs_ewcnone_rcolmain_adm1_1hot found
no_unobs.mainold_subs_ewcnone_rcolmain_adm1_1hot found
no_unobs.mainpres_subs_ewcnone_rcolmain_adm1_1hot found
topic_only_oracl

### Visualize results

In [156]:
confounding_type = "exog"
models = list(exp_results[(confounding_type, confounding_strengths[1])].keys())
regime1 = {
    "Low": (confounding_type, confounding_strengths[0]),
    "Med.": (confounding_type, confounding_strengths[1]),
    "High": (confounding_type, confounding_strengths[2]),
}

df1, alt_df1 = print_table(exp_results, regime1, models)


In [157]:
confounding_type = "homophily"
models = list(exp_results[(confounding_type, confounding_strengths[0])].keys())
regime1 = {
    "Low": (confounding_type, confounding_strengths[0]),
    "Med.": (confounding_type, confounding_strengths[1]),
    "High": (confounding_type, confounding_strengths[2]),
}

df2, alt_df2 = print_table(exp_results, regime1, models)


In [158]:
confounding_type = "both"
models = list(exp_results[(confounding_type, confounding_strengths[0])].keys())
regime1 = {
    "Low": (confounding_type, confounding_strengths[0]),
    "Med.": (confounding_type, confounding_strengths[1]),
    "High": (confounding_type, confounding_strengths[2]),
}

df3, alt_df3 = print_table(exp_results, regime1, models)


In [159]:
all_results = pd.concat([df1, df2, df3], axis=1, keys=["Exog.", "Homophily", "Both"])


In [160]:
all_results[all_results == all_results.min(axis=0)] = (
    "\textbf{" + all_results[all_results == all_results.min(axis=0)] + "}"
)


In [165]:
all_results.index = all_results.index.str.wrap(15)


In [166]:
all_results.to_latex("./results/semi-synth.tex", escape=False)


In [167]:
alt_all_results = pd.concat(
    [alt_df1, alt_df2, alt_df3], axis=1, keys=["Exog.", "Homophily", "Both"]
)
alt_all_results.index=alt_all_results.index.str.wrap(15)


In [168]:
alt_all_results[alt_all_results == alt_all_results.min(axis=0)] = (
    "\textbf{" + alt_all_results[alt_all_results == alt_all_results.min(axis=0)] + "}"
)


In [169]:
alt_all_results


Unnamed: 0_level_0,Exog.,Exog.,Exog.,Homophily,Homophily,Homophily,Both,Both,Both
Unnamed: 0_level_1,Low,Med.,High,Low,Med.,High,Low,Med.,High
Ours-no-meta-\nold-A,0.8$\pm$2.23,1.22$\pm$3.69,1.06$\pm$2.37,4.07$\pm$8.69,5.59$\pm$12.8,4.29$\pm$8.22,6.18$\pm$12.27,3.61$\pm$7.82,3.89$\pm$8.22
Ours-no-meta-\npres-A,0.8$\pm$2.23,1.21$\pm$3.69,1.06$\pm$2.37,4.18$\pm$9.03,5.58$\pm$12.82,4.27$\pm$8.21,6.19$\pm$12.28,3.59$\pm$7.8,3.87$\pm$8.21
Ours-old-A,0.8$\pm$2.22,1.23$\pm$3.68,1.15$\pm$2.39,5.86$\pm$11.87,4.45$\pm$8.62,4.3$\pm$8.28,6.22$\pm$12.31,4.81$\pm$9.17,4.94$\pm$9.35
Ours-old-C,0.81$\pm$2.24,1.24$\pm$3.72,1.15$\pm$2.41,6.34$\pm$12.21,4.59$\pm$8.64,4.43$\pm$8.32,6.37$\pm$12.28,4.94$\pm$9.16,5.09$\pm$9.39
Ours-pres-A,0.92$\pm$2.28,1.25$\pm$3.72,1.16$\pm$2.42,8.51$\pm$21.57,4.92$\pm$8.89,4.43$\pm$8.67,6.37$\pm$12.32,5.11$\pm$9.46,5.24$\pm$9.75
Ours-pres-C,0.89$\pm$2.27,1.23$\pm$3.65,1.16$\pm$2.4,7.71$\pm$17.74,5.29$\pm$11.0,4.83$\pm$9.72,6.42$\pm$12.43,5.19$\pm$9.73,5.48$\pm$10.66
Ours-no-meta-\nold-C,0.43$\pm$1.5,0.5$\pm$1.41,0.64$\pm$1.67,3.97$\pm$10.17,2.34$\pm$6.11,1.97$\pm$4.34,4.16$\pm$10.8,2.42$\pm$5.86,2.55$\pm$5.97
Ours-no-meta-\npres-C,0.41$\pm$1.5,0.49$\pm$1.41,0.64$\pm$1.67,1.96$\pm$4.37,3.11$\pm$8.12,1.95$\pm$4.32,4.13$\pm$10.77,2.4$\pm$5.83,2.52$\pm$5.93
Ours-NDC-pres-A,0.22$\pm$0.92,0.22$\pm$0.97,0.11$\pm$0.37,10.72$\pm$32.17,3.99$\pm$9.56,0.47$\pm$1.6,,,
Unadjusted-\nold-A,0.07$\pm$0.22,0.11$\pm$0.34,0.1$\pm$0.33,1.43$\pm$5.4,1.44$\pm$5.46,1.18$\pm$4.13,2.64$\pm$11.58,1.71$\pm$6.77,1.42$\pm$5.3


In [20]:
alt_all_results.to_latex("./results/alt-semi-synth.tex", escape=False)


In [21]:
import pickle

# load up PPC results
with open(res_dir / "dsbmm_ppc_results.pkl", "rb") as f:
    dsbmm_ppc_results = pickle.load(f)
with open(res_dir / "dpf_ppc_results.pkl", "rb") as f:
    dpf_ppc_results = pickle.load(f)
with open(res_dir / "dpf_auc_results.pkl", "rb") as f:
    dpf_auc_results = pickle.load(f)


In [22]:
# in shape (n_exps,n_Q_tested)
# where each experiment is a subsample of the data using a different seed
# each value is then n_pos / n_repls, where n_pos counts the number of
# replicates in which the likelihood of observing the replicated data
# was greater than observing the held-out data, after fitting on the
# remaining data
dpf_ppc_results


array([[[1., 1., 1., 1.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0

In [23]:
dpf_auc_results


array([[[0.50910372, 0.5927523 , 0.72340592, 0.83475554],
        [0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ]],

       [[0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ]],

       [[0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ]],

       [[0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ]],

       [[0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ]],

       [[0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ],
    

In [24]:
dsbmm_ppc_results


array([[[1., 1., 1., 1.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0

In [223]:
tmp = np.stack([dsbmm_ppc_results.mean(axis=0), dpf_ppc_results.mean(axis=0)])


In [226]:
ppc_df = pd.DataFrame(tmp.T, columns=["$A$", "$Y$"])
ppc_df["$Q$"] = [4, 9, 16]
ppc_df["$K$"] = [3, 5, 8]
ppc_df[["$K$", "$Y$"]].to_latex("./results/topic-synth-ppcs.tex", escape=False)
ppc_df[["$Q$", "$A$"]].to_latex("./results/auth-synth-ppcs.tex", escape=False)


In [228]:
ppc_df[["$K$", "$Y$", "$Q$", "$A$"]]


Unnamed: 0,$K$,$Y$,$Q$,$A$
0,3,0.05,4,0.05
1,5,0.05,9,0.05
2,8,0.05,16,0.05


In [25]:
from pif_dsbmm_dpf.citation.predictive_check import calculate_ppc_dsbmm, mask_topics

exp_idx = 0
Q = 4
dsbmm_datadir = res_dir.parent / "dsbmm_data"
with open(dsbmm_datadir / f"dsbmmppc_runsim_model_{exp_idx}_Q{Q}_subs.pkl", "rb") as f:
    node_probs, Z_trans, block_probs = pickle.load(f)
with open(res_dir.parent / f"sim_model_{exp_idx}.pkl", "rb") as f:
    sim_model = pickle.load(f)


In [26]:
np.random.seed(exp_idx)
Y = sim_model.make_multi_covariate_simulation(
    noise=10.0, confounding_strength=50.0, confounding_to_use="both"
)
A = sim_model.A
N = Y[0].shape[0]
M = Y[0].shape[1]
T = len(Y)


Saving semi-synth data to /scratch/fitzgeraldj/data/caus_inf_data/sim_model_0.pkl


In [27]:
masked_friends = [mask_topics(N, N) for _ in range(T - 1)]
aus = np.arange(N, dtype=int)
masked_friends = [(aus.copy(), mf) for mf in masked_friends]


In [35]:
import time 
np.random.seed(int(time.time()) % 2**32)
A_ll_heldout, A_ll_repl, e_rates = calculate_ppc_dsbmm(masked_friends,A,node_probs,block_probs,ret_rates=True)

In [74]:
in_degs = A[-1].sum(axis=0)
out_degs = A[-1].sum(axis=1)
# in_degs[in_degs == 0] = 1.0
# out_degs[out_degs == 0] = 1.0
pos_nodes = (in_degs > 0) & (out_degs > 0)
e_rates = np.einsum(
    "iq,qr,jr->ij",
    out_degs[pos_nodes, np.newaxis] * node_probs[pos_nodes, -1, :],
    block_probs[..., -1],
    in_degs[pos_nodes, np.newaxis] * node_probs[pos_nodes, -1, :],
)


In [78]:
good_samp_idxs = np.isin(masked_friends[-1][1], pos_nodes)
tmp_j = masked_friends[-1][1][good_samp_idxs]
tmp_i = masked_friends[-1][0][good_samp_idxs]
print(f"Sampled w no out citations: {(A[-1][tmp_j,:].sum(axis=1) == 0).sum()}")
print(f"Sampled w no in citations: {(A[-1][:,tmp_j].sum(axis=0) == 0).sum()}")


Sampled w no out citations: 0
Sampled w no in citations: 0


In [117]:
# ah this was using incorrectly calculated block probs -- reupdate directly instead
tp_marg = np.einsum(
    "iq,ijqr,jr->ijqr",
    node_probs[:, -1, :],
    poisson.pmf(
        A[-1].toarray()[..., np.newaxis, np.newaxis],
        out_degs[:, np.newaxis, np.newaxis, np.newaxis]
        * in_degs[np.newaxis, :, np.newaxis, np.newaxis]
        * block_probs[..., -1][np.newaxis, np.newaxis],
    ),
    node_probs[:, -1, :],
)

In [118]:
tp_marg /= tp_marg.sum(axis=(-2,-1),keepdims=True)

  tp_marg /= tp_marg.sum(axis=(-2,-1),keepdims=True)


In [146]:
eff_block_probs_num = np.nansum(
    tp_marg[np.ix_(pos_nodes,pos_nodes)]*(A[-1].toarray()[np.ix_(pos_nodes,pos_nodes)][...,np.newaxis,np.newaxis]),axis=(0,1)
)
eff_block_probs_denom = np.einsum(
    "iq,jr->qr",
    out_degs[pos_nodes, np.newaxis] * node_probs[pos_nodes, -1, :],
    in_degs[pos_nodes, np.newaxis] * node_probs[pos_nodes, -1, :],
)
eff_block_probs = np.divide(
    eff_block_probs_num,
    eff_block_probs_denom,
    where=eff_block_probs_denom > 0,
    out=np.zeros_like(eff_block_probs_denom),
)


In [107]:
A[-1].nnz / (A[-1].shape[0] ** 2)


0.04923098262431963

In [150]:
eff_e_rates = np.einsum(
    "iq,qr,jr->ij",
    out_degs[pos_nodes, np.newaxis] * node_probs[pos_nodes, -1, :],
    eff_block_probs,
    in_degs[pos_nodes, np.newaxis] * node_probs[pos_nodes, -1, :],
)

In [151]:
from scipy.stats import poisson

heldout = A[-1][masked_friends[-1][0],masked_friends[-1][1]]
ppc = 0.0
repls = 100
rev_idx = np.zeros(N,dtype=int)
rev_idx[pos_nodes] = np.arange(pos_nodes.sum(),dtype=int)
sub_mask_i = masked_friends[-1][0]
sub_mask_j = masked_friends[-1][1]
joint_mask = np.isin(sub_mask_i,np.flatnonzero(pos_nodes)) & np.isin(sub_mask_j,np.flatnonzero(pos_nodes))
sub_mask_i = sub_mask_i[joint_mask]
sub_mask_j = sub_mask_j[joint_mask]

subset_rates = eff_e_rates[rev_idx[sub_mask_i],rev_idx[sub_mask_j]]
for _ in range(repls):
    samps = poisson.rvs(subset_rates)
    ho_ll = poisson.logpmf(heldout[joint_mask],subset_rates).sum()
    rep_ll = poisson.logpmf(samps,subset_rates).sum()
    if rep_ll > ho_ll:
        ppc += 1.0
print(ppc/repls)


0.32


## Recalc PPCs for incorrect block prob samples 
- NB will only work at all for smaller numbers of groups, where dividing by n_descs is less of a problem

In [171]:
from scipy.stats import poisson
import time
from pif_dsbmm_dpf.citation.predictive_check import calculate_ppc_dsbmm, mask_topics

exp_idx = 0
dsbmm_datadir = res_dir.parent / "dsbmm_data"
with open(res_dir.parent / f"sim_model_{exp_idx}.pkl", "rb") as f:
    sim_model = pickle.load(f)
np.random.seed(exp_idx)
Y = sim_model.make_multi_covariate_simulation(
    noise=10.0, confounding_strength=50.0, confounding_to_use="both"
)
A = sim_model.A
N = Y[0].shape[0]
M = Y[0].shape[1]
T = len(Y)
masked_friends = [mask_topics(N, N) for _ in range(T - 1)]
aus = np.arange(N, dtype=int)
masked_friends = [(aus.copy(), mf) for mf in masked_friends]
heldout = A[-1][masked_friends[-1][0],masked_friends[-1][1]]
in_degs = A[-1].sum(axis=0)
out_degs = A[-1].sum(axis=1)

pos_nodes = (in_degs > 0) & (out_degs > 0)
Qs = [4,9,16]
ppc_scores = np.zeros(len(Qs))
for q_idx,Q in enumerate(Qs):
    with open(dsbmm_datadir / f"dsbmmppc_runsim_model_{exp_idx}_Q{Q}_subs.pkl", "rb") as f:
        node_probs, Z_trans, block_probs = pickle.load(f)

    np.random.seed(int(time.time()) % 2**32)
    
    
    
    tp_marg = np.einsum(
        "iq,ijqr,jr->ijqr",
        node_probs[:, -1, :],
        poisson.pmf(
            A[-1].toarray()[..., np.newaxis, np.newaxis],
            out_degs[:, np.newaxis, np.newaxis, np.newaxis]
            * in_degs[np.newaxis, :, np.newaxis, np.newaxis]
            * block_probs[..., -1][np.newaxis, np.newaxis],
        ),
        node_probs[:, -1, :],
    )
    tp_marg_sums = np.nansum(tp_marg,axis=(-2,-1),keepdims=True)
    tp_marg = np.divide(tp_marg,tp_marg_sums,where=tp_marg_sums>0,out=np.zeros_like(tp_marg))
    eff_block_probs_num = np.nansum(
        tp_marg[np.ix_(pos_nodes,pos_nodes)]*(A[-1].toarray()[np.ix_(pos_nodes,pos_nodes)][...,np.newaxis,np.newaxis]),axis=(0,1)
    )
    eff_block_probs_denom = np.einsum(
        "iq,jr->qr",
        out_degs[pos_nodes, np.newaxis] * node_probs[pos_nodes, -1, :],
        in_degs[pos_nodes, np.newaxis] * node_probs[pos_nodes, -1, :],
    )
    eff_block_probs = np.divide(
        eff_block_probs_num,
        eff_block_probs_denom,
        where=eff_block_probs_denom > 0,
        out=np.zeros_like(eff_block_probs_denom),
    )

    eff_e_rates = np.einsum(
        "iq,qr,jr->ij",
        out_degs[pos_nodes, np.newaxis] * node_probs[pos_nodes, -1, :],
        eff_block_probs,
        in_degs[pos_nodes, np.newaxis] * node_probs[pos_nodes, -1, :],
    )
    ppc = 0.0
    repls = 100
    rev_idx = np.zeros(N,dtype=int)
    rev_idx[pos_nodes] = np.arange(pos_nodes.sum(),dtype=int)
    sub_mask_i = masked_friends[-1][0]
    sub_mask_j = masked_friends[-1][1]
    joint_mask = np.isin(sub_mask_i,np.flatnonzero(pos_nodes)) & np.isin(sub_mask_j,np.flatnonzero(pos_nodes))
    sub_mask_i = sub_mask_i[joint_mask]
    sub_mask_j = sub_mask_j[joint_mask]

    subset_rates = eff_e_rates[rev_idx[sub_mask_i],rev_idx[sub_mask_j]]
    for _ in range(repls):
        samps = poisson.rvs(subset_rates)
        ho_ll = poisson.logpmf(heldout[joint_mask],subset_rates).sum()
        rep_ll = poisson.logpmf(samps,subset_rates).sum()
        if rep_ll > ho_ll:
            ppc += 1.0
    ppc_scores[q_idx] = ppc/repls


Saving semi-synth data to /scratch/fitzgeraldj/data/caus_inf_data/sim_model_0.pkl


In [None]:
ppc_scores

array([0.26, 1.  , 1.  ])