In [None]:
import os
import pickle
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

import jraph

from cfm_training_structure import preload_hdf5_to_memory, data_loader, sample_cfm
from cfm_gnn import GraphConvNet, CFMGraphModel


def load_cfm_model(params_path, hidden_size=1024, num_mlp_layers=3, latent_size=128, target_dim=3):
    """
    Rebuild the CFM graph model and read its saved parameters.

    The architecture must match the one the parameters were trained with; only the
    size-related settings are exposed as arguments, everything else is fixed below.

    Args:
        params_path: path to the pickle file holding the saved parameters.
        hidden_size: GraphConvNet ``hidden_size``.
        num_mlp_layers: GraphConvNet ``num_mlp_layers``.
        latent_size: GraphConvNet ``latent_size``.
        target_dim: per-node output size; used both as the backbone ``output_dim`` and
            as the CFM wrapper's ``target_dim``.

    Returns:
        (model, params): the CFMGraphModel instance and the unpickled parameter object.
    """
    fixed_backbone_settings = {
        "message_passing_steps": 5,
        "skip_connections": True,
        "edge_skip_connections": True,
        "norm": "none",
        "attention": True,
        "shared_weights": True,
        "relative_updates": False,
        "dropout_rate": 0.0,
    }
    backbone = GraphConvNet(
        latent_size=latent_size,
        hidden_size=hidden_size,
        num_mlp_layers=num_mlp_layers,
        output_dim=target_dim,
        **fixed_backbone_settings,
    )
    model = CFMGraphModel(backbone=backbone, target_dim=target_dim, time_emb_dim=32)

    # pickle.load can execute arbitrary code: only point this at parameter files you trust.
    with open(params_path, "rb") as handle:
        params = pickle.load(handle)
    return model, params


def get_nth_batch(test_data, batch_size=1, n=0, shuffle=False):
    """
    Return the batch at 0-based position ``n`` produced by ``data_loader``.

    The first ``n`` batches are drawn and thrown away. The caller in this notebook
    unpacks the result as ``(graph, tgt, mask)``.

    Raises:
        StopIteration: presumably, if the loader yields fewer than ``n + 1`` batches
            (``next`` on an exhausted iterator) — confirm against ``data_loader``.
    """
    batches = data_loader(test_data, batch_size=batch_size, shuffle=shuffle)
    for _skipped in range(n):
        next(batches)
    wanted = next(batches)
    return wanted


def posterior_samples_for_graph(
    params,
    model,
    graph,
    num_samples=32,
    ode_steps=64,
    target_dim=3,
    seed=0,
):
    """
    Draw independent posterior samples for one graph by repeatedly calling ``sample_cfm``.

    Keys are split sequentially from ``jax.random.PRNGKey(seed)``, so the same ``seed``
    reproduces the same draws, and draw ``i`` uses the same key regardless of how many
    draws are requested.

    Args:
        params: trained parameters, passed through to ``sample_cfm``.
        model: object whose ``apply`` is passed to ``sample_cfm`` as ``apply_fn``.
        graph: input graph, passed through to ``sample_cfm``.
        num_samples: number of draws S; must be >= 1.
        ode_steps: passed to ``sample_cfm`` as ``num_steps``.
        target_dim: passed to ``sample_cfm``.
        seed: integer seed for the draw keys.

    Returns:
        The draws stacked along a new leading axis, i.e. shape (S, *single_draw_shape).

    Raises:
        ValueError: if ``num_samples < 1``.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    rng = jax.random.PRNGKey(seed)
    samples = []
    for _ in range(num_samples):
        rng, key = jax.random.split(rng)
        samples.append(
            sample_cfm(
                params=params,
                apply_fn=model.apply,
                graph=graph,
                rng_key=key,
                num_steps=ode_steps,
                target_dim=target_dim,
            )
        )
    return jnp.stack(samples, axis=0)


def posterior_stats(samples, mask=None):
    """
    Summarise posterior draws, reducing over the leading (draw) axis.

    Args:
        samples: array of draws with the draw axis first, e.g. (S, N, D).
        mask: optional per-node 0/1 or bool mask. It is flattened to 1-D; when given,
            every returned array keeps only the rows where the mask is true.

    Returns:
        dict with keys "mean", "std" (jnp.std default, ddof=0), "q16", "q50", "q84".
    """
    summary = {
        "mean": jnp.mean(samples, axis=0),
        "std": jnp.std(samples, axis=0),
    }
    for key, level in (("q16", 0.16), ("q50", 0.50), ("q84", 0.84)):
        summary[key] = jnp.quantile(samples, level, axis=0)

    if mask is None:
        return summary
    keep = mask.astype(bool).reshape(-1)
    return {key: value[keep] for key, value in summary.items()}


def masked(arr, mask):
    """
    Keep the rows of ``arr`` whose entry in ``mask`` is truthy.

    ``mask`` is flattened to 1-D first, so (N,) and (N, 1) masks behave the same. If
    ``arr`` and the mask differ in length, both are cut to the shorter length before
    selecting, so a length mismatch does not raise.
    """
    keep = mask.astype(bool).reshape(-1)
    common = min(arr.shape[0], keep.shape[0])
    head, head_keep = arr[:common], keep[:common]
    return head[head_keep]


def plot_truth_vs_mean(tgt, post_mean, mask, labels=("z", "vx", "vy"), lims=None):
    """
    One scatter panel per target feature: masked truth (x) vs posterior mean (y).

    Args:
        tgt: truth values indexed as tgt[node, feature].
        post_mean: posterior mean with the same row layout as ``tgt`` (unmasked).
        mask: per-node mask, applied to both arrays via ``masked``.
        labels: feature names, one per column.
        lims: optional sequence of (lo, hi) axis limits per feature; when omitted the
            limits span the joint min/max of truth and mean for that feature.
    """
    truth = masked(tgt, mask)
    mean = masked(post_mean, mask)

    n_dims = truth.shape[-1]
    # squeeze=False always yields a 2-D grid, so the single-panel case needs no special-casing.
    fig, ax_grid = plt.subplots(1, n_dims, figsize=(5 * n_dims, 4), squeeze=False)

    for d, ax in enumerate(ax_grid[0]):
        name = labels[d]
        xs, ys = truth[:, d], mean[:, d]
        ax.scatter(xs, ys, s=6, alpha=0.6)

        both = jnp.concatenate([xs, ys])
        lo, hi = float(jnp.min(both)), float(jnp.max(both))
        if lims is not None:
            lo, hi = lims[d]

        ax.plot([lo, hi], [lo, hi], lw=1)  # y = x reference line
        ax.set_xlabel(f"Truth {name}")
        ax.set_ylabel(f"Posterior mean {name}")
        ax.set_title(f"{name} truth vs mean")
        ax.set_xlim(lo, hi)
        ax.set_ylim(lo, hi)

    fig.tight_layout()
    plt.show()


def plot_errorbars(tgt, post_mean, post_std, mask, dim=0, label="z", n_points=200, seed=0):
    """
    Truth vs posterior mean with ±1σ error bars for a random subset of masked nodes.

    Args:
        tgt, post_mean, post_std: per-node arrays (unmasked), indexed [node, feature].
        mask: per-node mask, applied to all three arrays via ``masked``.
        dim: feature column to plot.
        label: feature name for axis labels and title.
        n_points: maximum number of nodes drawn (without replacement).
        seed: seed for the subset selection, so the same nodes are shown each run.
    """
    picker = np.random.default_rng(seed)
    truth, mean, spread = (np.array(masked(a, mask)) for a in (tgt, post_mean, post_std))

    n_nodes = truth.shape[0]
    chosen = picker.choice(n_nodes, size=min(n_points, n_nodes), replace=False)
    x, y, yerr = truth[chosen, dim], mean[chosen, dim], spread[chosen, dim]

    fig, ax = plt.subplots(figsize=(6, 5))
    ax.errorbar(x, y, yerr=yerr, fmt="o", ms=3, alpha=0.5, capsize=0)
    lo = min(np.min(x), np.min(y))
    hi = max(np.max(x), np.max(y))
    ax.plot([lo, hi], [lo, hi], lw=1)  # y = x reference line
    ax.set_xlabel(f"Truth {label}")
    ax.set_ylabel(f"Posterior mean {label}")
    ax.set_title(f"{label}: mean ± 1σ (subset)")
    fig.tight_layout()
    plt.show()


def coverage_curve(tgt, samples, mask, dim=0, label="z"):
    """
    Plot empirical coverage of central credible intervals for one target feature.

    For each nominal mass p in 0.05, 0.10, ..., 0.95, the central p interval of every
    node's draws, [q_{(1-p)/2}, q_{1-(1-p)/2}], is computed. The fraction of masked
    nodes whose truth lies inside that interval is then plotted against p. A
    calibrated posterior follows the diagonal.

    Args:
        tgt: truth values indexed as tgt[node, feature].
        samples: posterior draws indexed as samples[draw, node, feature], i.e. (S, N, D).
        mask: per-node 0/1 or bool mask; flattened, so (N,) and (N, 1) both work.
        dim: feature column to evaluate.
        label: feature name used in the title.

    Raises:
        ValueError: if, after masking, truths and samples cover different node counts.
    """
    keep = mask.astype(bool).reshape(-1)
    tgt_m = masked(tgt, mask)[:, dim]
    # Apply the same flatten + truncate rule as `masked` to the node axis of the samples,
    # so truths and draws are selected identically (also for (N, 1) masks).
    n = min(samples.shape[1], keep.shape[0])
    s_m = samples[:, :n][:, keep[:n], dim]  # (S, Nmasked)
    if s_m.shape[1] != tgt_m.shape[0]:
        raise ValueError(
            f"truths ({tgt_m.shape[0]} nodes) and samples ({s_m.shape[1]} nodes) "
            "disagree after masking"
        )

    ps = jnp.linspace(0.05, 0.95, 19)
    # One vectorised quantile call per bound; each result has shape (len(ps), Nmasked).
    lo = jnp.quantile(s_m, (1 - ps) / 2, axis=0)
    hi = jnp.quantile(s_m, 1 - (1 - ps) / 2, axis=0)
    inside = (tgt_m[None, :] >= lo) & (tgt_m[None, :] <= hi)
    cov = np.asarray(jnp.mean(inside.astype(jnp.float32), axis=1))
    ps = np.asarray(ps)

    plt.figure(figsize=(5, 5))
    plt.plot(ps, cov, marker="o")
    plt.plot([0, 1], [0, 1], lw=1)  # ideal
    plt.xlabel("Nominal credible mass p")
    plt.ylabel("Empirical coverage")
    plt.title(f"Coverage curve ({label})")
    plt.tight_layout()
    plt.show()


In [None]:
# --- paths ---
# Defaults are the original cluster locations; set GNN_SBI_DATA_DIR to override without
# editing the notebook. Keep the trailing slash, as in the default — presumably
# preload_hdf5_to_memory combines dir and file name directly (confirm).
data_path = os.environ.get(
    "GNN_SBI_DATA_DIR", "/projects/mccleary_group/habjan.e/TNG/Data/GNN_SBI_data/"
)
test_file = "GNN_data_test.h5"

# --- load data ---
test_data = preload_hdf5_to_memory(data_path, test_file)

In [None]:
# Path to the params saved by train_cfm_model.py. Set CFM_PARAMS_PATH to point at a
# different run instead of editing the filename suffix here.
params_path = os.environ.get(
    "CFM_PARAMS_PATH",
    os.path.join(
        "/home/habjan.e/TNG/cluster_deprojection/probabilistic_model/CFM_models",
        "cfm_model_params_cfm_testing.pkl",
    ),
)

# Posterior sampling settings. Use more than one draw: with a single draw the posterior
# std is identically zero and every credible interval collapses to a point, so the
# error-bar and coverage plots in the next cell are meaningless.
NUM_POSTERIOR_SAMPLES = 32  # posterior draws for the graph
ODE_STEPS = 16              # Euler steps per draw (increase for better samples)

# --- load model/params ---
model, params = load_cfm_model(
    params_path,
    hidden_size=1024,
    num_mlp_layers=3,
    latent_size=128,
    target_dim=3,
)

# --- get a single test graph ---
graph, tgt, mask = get_nth_batch(test_data, batch_size=1, n=0, shuffle=False)

print("graph nodes:", graph.nodes.shape, "targets:", tgt.shape, "mask:", mask.shape)

# --- posterior sampling ---
samples = posterior_samples_for_graph(
    params=params,
    model=model,
    graph=graph,
    num_samples=NUM_POSTERIOR_SAMPLES,
    ode_steps=ODE_STEPS,
    target_dim=3,
    seed=0,
)

In [None]:
# Compute stats WITHOUT the mask: every plotting helper below applies `mask` itself.
# Passing pre-masked mean/std would mask them a second time (masked() silently truncates)
# and misalign them with the masked truths.
stats = posterior_stats(samples)
mu, sd = stats["mean"], stats["std"]

TARGET_LABELS = ("z", "vx", "vy")

# --- plots ---
plot_truth_vs_mean(tgt, mu, mask, labels=TARGET_LABELS)
for dim, label in enumerate(TARGET_LABELS):
    plot_errorbars(tgt, mu, sd, mask, dim=dim, label=label, n_points=200)
for dim, label in enumerate(TARGET_LABELS):
    coverage_curve(tgt, samples, mask, dim=dim, label=label)