In [None]:
import os

N_CORES = 4
# Ask XLA to expose N_CORES virtual host (CPU) devices. This has to be set before
# jax is imported, which is why it sits above the jax imports in this cell.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=" + str(N_CORES)


import jax.numpy as jnp
from jax import random, vmap
import jax
from jax.scipy.special import logit

from tqdm import tqdm
import numpy as np
from scipy.spatial.distance import pdist
from scipy.cluster.hierarchy import linkage, leaves_list

import Simulations.data_gen as dg
# import data_gen as dg
from src.MWG_sampler import MWG_sampler, MWG_init
import src.utils as utils
from src.Models import cond_logpost_a_star, triu_star_grad_fn
import src.Models as models
import src.GWG as gwg

from time import time
import pandas as pd


import matplotlib.pyplot as plt
import matplotlib
# Force white backgrounds both for displayed figures and for saved image files.
matplotlib.rcParams.update({'figure.facecolor': 'white', 'savefig.facecolor': 'white'})



In [None]:
# Global variables (as in mwg_simulation.py)

N = 500
TRIU_DIM = N * (N - 1) // 2  # number of unordered node pairs (upper-triangular entries)

THETA = jnp.array([-2.5, 1])
GAMMA_BASELINE = jnp.array([logit(0.95), logit(0.05)])

GAMMA_REP = jnp.array([logit(0.8), 1.5, logit(0.2), 1.5])
# Noise levels 2.0, 2.5, 3.0, 3.5, 4.0 (upper bound nudged so 4.0 is included).
GAMMA_X_NOISES = jnp.arange(2, 4 + 1e-6, 0.5)

GAMMA_B_NOISE_0 = GAMMA_BASELINE[0] - GAMMA_X_NOISES / 2
GAMMA_B_NOISE_1 = GAMMA_BASELINE[1] + GAMMA_X_NOISES / 2

ETA = jnp.array([-1, 3, -0.5, 2])
SIG_INV = 1.0  # alternative setting tried previously: 2.0
RHO = 0.5
PZ = 0.5


PARAM = dict(
    theta=THETA,
    eta=ETA,
    rho=RHO,
    sig_inv=SIG_INV,
)

# Illustration uses the middle noise level: index 2 of GAMMA_X_NOISES (= 3.0).
# GAMMA_REP is intentionally not appended here.
cur_gamma = jnp.stack(
    [GAMMA_B_NOISE_0[2], GAMMA_B_NOISE_1[2], GAMMA_X_NOISES[2]]
)

print(cur_gamma)



In [None]:
# rng_key = random.PRNGKey(1159)
rng_key = random.PRNGKey(0)

# Data that does not depend on gamma (includes the true network `triu_star`).
# NOTE(review): rng_key is consumed by generate_fixed_data and then split again
# below (key reuse). Left unchanged so the reported results stay reproducible.
rng_key = random.split(rng_key)[0]
fixed_data = dg.generate_fixed_data(rng_key, N, PARAM, PZ)

# Ground-truth values used for the Wasserstein-distance comparisons.
true_vals = dict(
    eta=ETA,
    rho=jnp.array([RHO]),
    sig_inv=jnp.array([SIG_INV]),
    triu_star=fixed_data["triu_star"],
)

print(f"mean true exposures: {jnp.mean(fixed_data['true_exposures']):.3f}")

# New interventions and their true estimands.
rng_key = random.split(rng_key)[0]
new_interventions = dg.new_interventions_estimands(
    rng_key, N, fixed_data["x"], fixed_data["triu_star"], ETA
)

for label, estimand in (
    ("dynamic", new_interventions.estimand_h),
    ("stochastic", new_interventions.estimand_stoch),
    ("gate", new_interventions.estimand_gate),
):
    print(f"mean {label} estimands: {jnp.mean(estimand):.3f}")

In [None]:
# Degree distribution of the true network A*.
matplotlib.rcParams.update({
    'xtick.color': 'black',
    'ytick.color': 'black',
    'axes.labelcolor': 'black',
    'axes.edgecolor': 'black',
    'axes.titlesize': 'large',
    'axes.labelsize': 'medium',
})

# Node degree = row sum of the adjacency matrix.
true_adj = utils.Triu_to_mat(fixed_data["triu_star"])
true_degs = jnp.sum(true_adj, axis=1)

fig, ax = plt.subplots(figsize=(6, 4))
ax.hist(true_degs, bins=20, color="#32527b", edgecolor="white")
ax.set_xlabel("Degree", fontsize=12)
ax.set_ylabel("Count", fontsize=12)
ax.set_facecolor("white")  # Ensure background is white
fig.tight_layout()

# fig.savefig("Simulations/results/figs/degree_hist.png", dpi=300, bbox_inches="tight")

plt.show()

In [None]:
# Distribution of the true exposures from the data-generating process.
# (Alternative previously tried: a histogram of 3 * true_degs / (N - 1).)
fig, ax = plt.subplots()
ax.hist(fixed_data["true_exposures"], bins=20, color="#32527b", edgecolor="white")
ax.set_xlabel("True exposure", fontsize=12)
ax.set_ylabel("Count", fontsize=12)
ax.set_facecolor("white")
fig.tight_layout()
# plt.show() renders the figure and avoids echoing the (counts, bins, patches) tuple.
plt.show()



In [None]:
# Draw proxy (noisy) networks of A* under the current gamma, then bundle them
# with the fixed data for the sampler.
rng_key, *_ = random.split(rng_key)
proxy_nets = dg.generate_proxy_networks(
    rng_key,
    TRIU_DIM,
    fixed_data["triu_star"],
    cur_gamma,
    fixed_data["x_diff"],
    fixed_data["Z"],
)
data_sim = dg.data_for_sim(fixed_data, proxy_nets)


In [None]:

# Compute starting values for the MWG sampler from the simulated data.
rng_key = random.split(rng_key)[0]
mwg_init = MWG_init(
    rng_key=rng_key,
    data=data_sim,
    triu_star_grad_fn=models.triu_star_grad_fn,
    gwg_kernel_fn=gwg.GWG_kernel,
    n_iter_networks=int(2e4),
    gwg_init_steps=20000,  # quick-test value used before: int(2e2)
    gwg_init_batch_len=5,
    progress_bar=True,
    # refine_triu_star=False
)

mwg_init_vals = mwg_init.get_init_values()



In [None]:
# MWG sampler: constructing the object runs the sampler from the init values above.
rng_key = random.split(rng_key)[0]
sampler_kwargs = dict(
    rng_key=rng_key,
    data=data_sim,
    init_params=mwg_init_vals,
    gwg_n_steps=1,
    gwg_batch_len=1,
    progress_bar=True,
    # n_warmup=2000,
    # n_samples=3000,
)
mwg_sampler = MWG_sampler(**sampler_kwargs)


In [None]:
# print_diagnostics() presumably prints its own report (it accepts to_print=False to
# suppress it); wrapping it in print() would additionally echo its return value
# (e.g. "None"). As the last expression, any returned value is still displayed.
mwg_sampler.print_diagnostics()

In [None]:
from numpyro.diagnostics import summary, print_summary

# Continuous parameters only: the binary network draws (triu_star) are excluded.
cont_samples = {
    k: v for k, v in mwg_sampler.posterior_samples.items() if k != "triu_star"
}

# NumPyro summary (calculates ESS and R-hat).
# group_by_chain=False: the leading axis is treated as draws of a single chain.
stats = summary(cont_samples, group_by_chain=False)
print_summary(cont_samples, group_by_chain=False)

# No wall-clock time is measured for the sampler run in this section, so these are
# raw ESS values, not ESS per second (the scaling experiment reports timed ESS/sec).
min_ess_global = float("inf")
mean_ess_eta = float("inf")

for param_name, metrics in stats.items():
    # metrics["n_eff"] is an array for vector-valued parameters (e.g. eta)
    n_eff = jnp.asarray(metrics["n_eff"])
    min_ess = float(jnp.min(n_eff))
    mean_ess = float(jnp.mean(n_eff))

    if param_name == "eta":
        mean_ess_eta = mean_ess

    # Track the smallest ESS over all parameters and coordinates
    if min_ess < min_ess_global:
        min_ess_global = min_ess

    print(f"{param_name:<10} | Mean ESS: {mean_ess:.1f} | Min ESS: {min_ess:.1f}")

print("-" * 50)
print(f"Global Min ESS: {min_ess_global:.1f}")
print(f"Mean ESS for 'eta': {mean_ess_eta:.1f}")

In [None]:
# Posterior error statistics against the true estimands, for the dynamic (Z_h)
# and GATE (Z_gate) interventions.
mwg_dynamic_stats = mwg_sampler.new_intervention_error_stats(
    true_vals=true_vals,
    new_z=new_interventions.Z_h,
    true_estimands=new_interventions.estimand_h,
)
mwg_gate_stats = mwg_sampler.new_intervention_error_stats(
    true_vals=true_vals,
    new_z=new_interventions.Z_gate,
    true_estimands=new_interventions.estimand_gate,
)
(mwg_dynamic_stats, mwg_gate_stats)

In [None]:
# Posterior-predictive estimand draws for the two new interventions.
mwg_dynamic_esti = mwg_sampler.sample_pred_y(new_z=new_interventions.Z_h)
mwg_gate_esti = mwg_sampler.sample_pred_y(new_z=new_interventions.Z_gate)

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(15, 10))
ax1, ax2, ax3, ax4 = axes.flatten()

# (axes, trace, y-label). Estimand traces average over axis 1 — presumably units,
# with MCMC draws along axis 0; TODO confirm against sample_pred_y.
trace_specs = [
    (ax1, mwg_sampler.posterior_samples["gamma"][:, 1], "Gamma Values"),
    (ax2, mwg_sampler.posterior_samples["eta"][:, 3], "Eta Values"),
    (ax3, mwg_dynamic_esti.mean(axis=1), "Dynamic Estimand"),
    (ax4, mwg_gate_esti.mean(axis=1), "TTE Estimand"),
]
for ax, trace, ylabel in trace_specs:
    ax.plot(trace, alpha=0.8, color='grey')
    ax.set_xlabel("MCMC Sample", fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_facecolor("white")  # Ensure background is white

fig.tight_layout()

# fig.savefig("Simulations/results/figs/mwg_traceplots.png",
#             dpi=300,
#             bbox_inches="tight")

# plt.show() rather than fig.show(): on the non-GUI inline backend fig.show()
# can only emit a "cannot show the figure" warning.
plt.show()

In [None]:
# Posterior edge-inclusion probabilities: average of the sampled triu_star draws
# over the first (draw) axis.
triu_star_draws = mwg_sampler.posterior_samples["triu_star"]
post_mean_probs = jnp.mean(triu_star_draws, axis=0)

# Optional caching of the probabilities between sessions:
# np.savetxt("Simulations/results/post_triu_star_prob.txt", np.asarray(post_mean_probs))
# post_mean_probs = np.loadtxt("Simulations/results/post_triu_star_prob.txt")
#
# Thresholded point estimates previously tried (1/3, .25, .5):
# post_edges = jnp.where(post_mean_probs > 1/3, 1, 0)


In [None]:
# Full adjacency matrices for visualization.
true_net = utils.Triu_to_mat(data_sim.triu_star)
# obs_net = utils.Triu_to_mat(data_sim.triu_obs_rep)
obs_net = utils.Triu_to_mat(data_sim.triu_obs)

# Posterior edge-probability matrix, converted once and exposed under both names.
# (`post_mean` used to hold a thresholded version built from `post_edges`, which is
# currently disabled; the matrices are only read below, so sharing one is safe.)
post_mean_prob = utils.Triu_to_mat(post_mean_probs)
post_mean = post_mean_prob

In [None]:
# Order nodes by Ward hierarchical clustering of the rows of the true adjacency
# matrix, so that clustered structure shows up as blocks in the heatmaps.
true_net_np = np.array(true_net)
distances = pdist(true_net_np, metric='euclidean')
link_mat = linkage(distances, method='ward')
node_order = leaves_list(link_mat)

# Apply the same permutation to rows and columns of every matrix.
node_order_jax = jnp.array(node_order)
true_ordered, observed_ordered, posterior_ordered, posterior_ordered_prob = (
    mat[node_order_jax][:, node_order_jax]
    for mat in (true_net, obs_net, post_mean, post_mean_prob)
)



In [None]:
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# 1 row x 3 cols, no special width ratios needed
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Every panel uses vmin/vmax = 0/1, so the colorbar built from the last image
# is valid for all three panels.
for ax, data, title in zip(axes,
                           [true_ordered, observed_ordered, posterior_ordered_prob],
                           ["True", "Observed", "Posterior prob."]):
    im = ax.imshow(data, vmin=0, vmax=1, cmap="viridis", alpha=.95)
    ax.set_title(title, fontsize=22, fontweight="bold", pad=8, color="black")
    ax.axis("off")

# Inset axes to the right of the third panel: 5% of its width, full height.
cax = inset_axes(axes[2],
                 width="5%",
                 height="100%",
                 loc="lower left",
                 bbox_to_anchor=(1.05, 0, 1, 1),
                 bbox_transform=axes[2].transAxes,
                 borderpad=0.1)

# Colorbar in the inset, with ticks from 0 to 1
cbar = fig.colorbar(im, cax=cax, ticks=[0, 0.25, 0.5, 0.75, 1])
# cbar.set_label("Edge probability", fontsize=14, labelpad=10, color="black")
cbar.ax.tick_params(labelsize=12, length=3, width=.5, labelcolor="black", color="black")

# Tweak the outer margins so titles never get clipped
plt.subplots_adjust(left=0.05, right=0.88, top=0.88, bottom=0.05, wspace=0.05)

heatmap_path = "Simulations/results/figs/network_heatmap_prob.png"
# savefig fails if the target directory does not exist yet
os.makedirs(os.path.dirname(heatmap_path), exist_ok=True)
fig.savefig(heatmap_path,
            dpi=150,
            bbox_inches="tight")

plt.show()

In [None]:
# --- Locally informed proposal ---
# Test how well gradients approximate the true differences.

# Evaluate at slightly perturbed parameters rather than exactly at the truth.
# A seeded generator makes the perturbation, and every result below, reproducible.
perturb_rng = np.random.default_rng(0)
params = utils.ParamTuple(
    theta=THETA + perturb_rng.normal(0, 0.1, size=THETA.shape),
    gamma=cur_gamma + perturb_rng.normal(0, 0.1, size=cur_gamma.shape),
    eta=ETA + perturb_rng.normal(0, 0.1, size=ETA.shape),
    # rho=RHO,
    sig_inv=SIG_INV + perturb_rng.normal(0, 0.1),
)

# Gradient-based estimate of the log-posterior change for flipping each edge:
# -(2x - 1) = +1 for absent edges (0 -> 1) and -1 for present edges (1 -> 0).
# The /2 presumably matches the exp(diff / 2) locally-balanced proposal used by
# Gibbs-with-Gradients; TODO confirm against src.GWG.
_, grads = triu_star_grad_fn(data_sim.triu_obs_rep, data_sim, params)
score_grad = -(2 * data_sim.triu_obs_rep - 1) * grads / 2

print(score_grad.shape)

In [None]:
# Compute manual differences of log-posterior of A* for each edge flip

@jax.jit
def flip_val(x):
    """Flip an edge indicator: 0.0 -> 1.0, and any non-zero value -> 0.0."""
    is_edge = x != 0.0
    return jnp.where(is_edge, 0.0, 1.0)

@jax.jit
def single_flip_logpost(i, triu, data, param):
    """Conditional log-posterior of A* after flipping entry ``i`` of ``triu``.

    ``triu`` is not modified; ``.at[i].set`` returns an updated copy.
    """
    current = triu[i]
    return cond_logpost_a_star(triu.at[i].set(flip_val(current)), data, param)


# Exact conditional log-posterior of A* after each single-edge flip of the
# observed network. Values are collected in a list and stacked once, instead of
# updating a preallocated array with `.at[i].set` in this un-jitted loop, which
# produces a new copy of the full TRIU_DIM-length array on every iteration.
flip_logposts = [
    single_flip_logpost(i, data_sim.triu_obs_rep, data_sim, params)
    for i in tqdm(range(TRIU_DIM))
]
# astype(float) gives JAX's default float dtype, as the former jnp.zeros buffer had
# (assumes cond_logpost_a_star returns a scalar).
score_logdensity = jnp.stack(flip_logposts).astype(float)

baseline = cond_logpost_a_star(data_sim.triu_obs_rep, data_sim, params)
score_diffs = score_logdensity - baseline



In [None]:
# Compare exact and gradient-based flip scores as normalized log proposal
# probabilities over all edges. log_softmax only subtracts a constant
# (logsumexp), so the Pearson correlation below is the same on either scale.
diff_log_softmax = jax.nn.log_softmax(score_diffs)
grad_log_softmax = jax.nn.log_softmax(score_grad)

fig, ax = plt.subplots(figsize=(5, 3))
ax.scatter(
    diff_log_softmax,
    grad_log_softmax,
    s=10,                # small points: there are TRIU_DIM of them
    color='dodgerblue',
    alpha=0.7,           # transparency for overlapping points
)

# Dashed reference line across the data's bounding box, (min x, min y) -> (max x, max y).
# jnp reductions instead of builtin min/max, which would iterate over the
# TRIU_DIM-length arrays element by element in Python.
min_val = (float(jnp.min(diff_log_softmax)), float(jnp.min(grad_log_softmax)))
max_val = (float(jnp.max(diff_log_softmax)), float(jnp.max(grad_log_softmax)))
ax.plot([min_val[0], max_val[0]], [min_val[1], max_val[1]],
        linestyle="--", color="gray", alpha=0.8)

ax.set_xlabel("Manual differences", fontsize=12)
ax.set_ylabel("Gradient differences", fontsize=12)
ax.grid(True, linestyle="--", alpha=0.5)
ax.set_facecolor("white")

scatter_path = "Simulations/results/figs/diff_vs_gradient_scatter.png"
# savefig fails if the target directory does not exist yet
os.makedirs(os.path.dirname(scatter_path), exist_ok=True)
fig.savefig(scatter_path,
            dpi=300,
            bbox_inches="tight")

plt.show()

print(f"Pearson's correlation: {np.corrcoef(score_diffs, score_grad)[0, 1]:.3f}")


In [None]:
# Test scaling of Block Gibbs Algo by sample size N.
# Each outer iteration appends one row per N to the CSV, so a run can be resumed
# by setting START_ITER to the first iteration not yet stored in the file.

rng_key = random.PRNGKey(1)

N_VALUES = [100, 250, 500, 750, 1000]
N_ITER = 10
START_ITER = 7  # iterations 0..6 are already in the CSV; use 0 for a fresh run
SCALING_CSV = "Simulations/results/mwg_scaling_results.csv"

os.makedirs(os.path.dirname(SCALING_CSV), exist_ok=True)

for i in tqdm(range(START_ITER, N_ITER), desc="Iterations"):
    # Keys derived from the iteration index: resuming at START_ITER > 0 does not
    # replay the random draws of iterations that were run in an earlier session.
    iter_key = random.fold_in(rng_key, i)
    res_list = []
    for n in tqdm(N_VALUES, desc="N values"):
        triu_dim_n = n * (n - 1) // 2
        # One fresh key per consumer; no key is used and then split again.
        data_key, proxy_key, init_key, sampler_key = random.split(
            random.fold_in(iter_key, n), 4
        )

        # *_n names keep the main-analysis globals (fixed_data, data_sim,
        # mwg_sampler, ...) from being overwritten by this benchmark.
        fixed_data_n = dg.generate_fixed_data(data_key, n, PARAM, PZ)

        # sample proxy networks with current gamma
        proxy_nets_n = dg.generate_proxy_networks(
            proxy_key,
            triu_dim_n,
            fixed_data_n["triu_star"],
            cur_gamma,
            fixed_data_n["x_diff"],
            fixed_data_n["Z"],
        )
        data_sim_n = dg.data_for_sim(fixed_data_n, proxy_nets_n)

        print(f"Sampling with MWG for N={n}...")

        init_vals_n = MWG_init(
            rng_key=init_key,
            data=data_sim_n,
            progress_bar=False,
            gwg_init_steps=int(2e4),
            gwg_init_batch_len=5,
            refine_triu_star=True,
        ).get_init_values()

        start = time()
        sampler_n = MWG_sampler(
            rng_key=sampler_key,
            data=data_sim_n,
            init_params=init_vals_n,
            progress_bar=False,
            gwg_batch_len=1,
            gwg_n_steps=1,
        )
        # JAX dispatches work asynchronously; wait for the samples before stopping the clock.
        jax.block_until_ready(sampler_n.posterior_samples)
        ttl_time = time() - start
        print(f"Total sampling time for N={n}: {ttl_time:.2f} seconds")

        min_ess_sec, min_ess_eta, min_ess_sec_eta = sampler_n.print_diagnostics(to_print=False)

        print(f"Min ESS/sec for N={n}: {min_ess_sec:.2f}")
        print(f"Min ESS for 'eta' for N={n}: {min_ess_eta:.2f}")
        print(f"Min ESS/sec for 'eta' for N={n}: {min_ess_sec_eta:.2f}")

        res_list.append({
            'N': n,
            'min_ess_sec': float(min_ess_sec),
            'min_ess_eta': float(min_ess_eta),
            'min_ess_sec_eta': float(min_ess_sec_eta),
            'ttl_time': float(ttl_time),
            'iteration': i + 1,
        })

    results_df = pd.DataFrame(res_list)
    # Write the header only when creating the file; `i == 0` never holds when
    # START_ITER > 0, which left a fresh CSV without a header row.
    write_header = not os.path.exists(SCALING_CSV)
    results_df.to_csv(SCALING_CSV, index=False, mode="a", header=write_header)



In [None]:
# Scaling results written by the benchmark loop: one row per (iteration, N).
mwg_scaling_df = pd.read_csv(
    "Simulations/results/mwg_scaling_results.csv"
)
mwg_scaling_df.head(5)

In [None]:

BASELINE_N = 100  # reference sample size for the relative ESS/sec curve

# 1. Baseline min ESS/sec (N = BASELINE_N) for each iteration
baseline_df = (
    mwg_scaling_df.loc[mwg_scaling_df['N'] == BASELINE_N, ['iteration', 'min_ess_sec']]
    .rename(columns={'min_ess_sec': 'min_ess_sec_base'})
)

# 2. Attach each iteration's baseline to all rows of that iteration
#    (inner join: iterations without a baseline row are dropped)
merged_df = pd.merge(mwg_scaling_df, baseline_df, on='iteration')

# 3. Relative efficiency for every row
merged_df['min_ess_sec_rel'] = merged_df['min_ess_sec'] / merged_df['min_ess_sec_base']

# 4. Mean and standard deviation across iterations for each N
grouped = merged_df.groupby('N')[['ttl_time', 'min_ess_sec_rel']].agg(['mean', 'std'])

N_values = grouped.index
ttl_time_mean = grouped['ttl_time']['mean']
ttl_time_std = grouped['ttl_time']['std']
min_ess_sec_rel_mean = grouped['min_ess_sec_rel']['mean']
min_ess_sec_rel_std = grouped['min_ess_sec_rel']['std']

fig, ax1 = plt.subplots(figsize=(10, 6))

# --- Left Axis: Total Time ---
color_left = 'tab:blue'
ax1.set_xlabel('$N$', fontsize=16)
ax1.set_ylabel('Total Time (s)', color=color_left, fontsize=14)
ax1.errorbar(
    N_values,
    ttl_time_mean,
    yerr=ttl_time_std,
    fmt='-o',
    color=color_left,
    label='Total Time',
    capsize=5,
)
ax1.tick_params(axis='y', labelcolor=color_left, labelsize=12)
ax1.grid(True, alpha=0.3)

# x-ticks exactly at the N values
ax1.set_xticks(N_values)
ax1.set_xticklabels(N_values, fontsize=12)
ax1.set_facecolor("white")

# --- Right Axis: Relative Min ESS/sec ---
# ax2 must be created here before it is styled; styling it earlier hit a stale
# `ax2` left over from another cell (or raised NameError on a fresh kernel).
ax2 = ax1.twinx()
ax2.set_facecolor("white")
color_right = 'tab:orange'
ax2.set_ylabel(f'Relative Min ESS / sec (vs $N={BASELINE_N}$)', color=color_right, fontsize=14)
ax2.errorbar(
    N_values,
    min_ess_sec_rel_mean,
    yerr=min_ess_sec_rel_std,
    fmt='-s',
    color=color_right,
    label='Relative Min ESS/sec',
    capsize=5,
)
ax2.tick_params(axis='y', labelcolor=color_right, labelsize=12)

# One legend for both series (drawn on the top-most axes so it is not covered).
handles_left, labels_left = ax1.get_legend_handles_labels()
handles_right, labels_right = ax2.get_legend_handles_labels()
ax2.legend(handles_left + handles_right, labels_left + labels_right, fontsize=12)

plt.tight_layout()

scaling_fig_path = 'Simulations/results/figs/mwg_scaling_plot.png'
# savefig fails if the target directory does not exist yet
os.makedirs(os.path.dirname(scaling_fig_path), exist_ok=True)
fig.savefig(scaling_fig_path,
            dpi=300,
            bbox_inches='tight')

plt.show()