In [None]:
import os
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import keras
os.environ["KERAS_BACKEND"] = "tensorflow"
import bayesflow as bf
import pickle
import EZ2

In [None]:
# Helper functions

def _p_bottom_safe(nu, z, a, s=0.1):
    """Bottom-hit probability with a fallback for near-zero nu."""
    if abs(nu) < 1e-12:
        return (a - z) / a
    s2 = s * s
    num = np.exp(-2*a*nu/s2) - np.exp(-2*z*nu/s2)
    den = np.exp(-2*a*nu/s2) - 1.0
    return float(np.clip(num / den, 0.0, 1.0))

def _safe_rddexit(size, nu, z, a, top_boundary):
    """Call EZ2.rddexit and always return a list (even for size==1)."""
    if size <= 0:
        return []
    arr = EZ2.rddexit(size, nu, z, a, top_boundary=top_boundary)
    if np.isscalar(arr):
        return [float(arr)]
    return [float(x) for x in np.asarray(arr).ravel()]

def _sample_times(size, nu, z, a, s=0.1, rng=None):
    """Draw `size` first-passage times, split by the boundary that was hit.

    The number of bottom-boundary exits is drawn as Binomial(size, P(bottom)),
    with P(bottom) from _p_bottom_safe. The remaining draws are top-boundary exits.
    Exit times for each group come from EZ2.rddexit via _safe_rddexit.

    NOTE(review): `rng` only drives the binomial split. EZ2.rddexit is not given
    `rng`, and `s` is not forwarded to it, so EZ2 presumably uses its own RNG and
    diffusion constant. Confirm before relying on seeding or on s != 0.1.

    Returns (bottom_times, top_times) as lists of floats.
    """
    generator = np.random.default_rng() if rng is None else rng
    n_bottom = generator.binomial(size, _p_bottom_safe(nu, z, a, s=s))
    bottom = _safe_rddexit(n_bottom, nu, z, a, top_boundary=False)
    top = _safe_rddexit(size - n_bottom, nu, z, a, top_boundary=True)
    return bottom, top

# Forward model

def _simulate_condition(n, drift, start, a, top_choice, bottom_choice, stim_code, rng):
    """Simulate `n` trials of one stimulus condition; returns (dts, choices, correct, stimulus).

    Bottom-boundary exits are stored first, then top-boundary exits; the caller
    shuffles afterwards. In both conditions the top boundary is the correct
    response, so `correct` is 1 for top exits and 0 for bottom exits.
    """
    et_bottom, et_top = _sample_times(n, drift, start, a, s=0.1, rng=rng)
    et_bottom = np.asarray(et_bottom, dtype=np.float64)
    et_top = np.asarray(et_top, dtype=np.float64)
    n_bottom = et_bottom.size

    # Pre-allocating length n makes the slice assignments fail loudly if the
    # sampler ever returns a different total number of exit times.
    dts = np.empty(n, dtype=np.float64)
    dts[:n_bottom] = et_bottom
    dts[n_bottom:] = et_top

    choices = np.full(n, top_choice, dtype=np.int64)
    choices[:n_bottom] = bottom_choice
    correct = np.ones(n, dtype=np.int64)
    correct[:n_bottom] = 0
    stimulus = np.full(n, stim_code, dtype=np.int64)
    return dts, choices, correct, stimulus


def forward_model_ez2(
    vL, vR, a, z, terL, terR, n_trials=200, rng=None,
    rt_transform="log1p" # "log1p" or "none"
):
    """Simulate a two-condition Left/Right choice experiment with the EZ2 diffusion sampler.

    Parameters
    ----------
    vL, vR : drift rates toward the Left / Right response. They are multiplied by
        0.1 internally to match the s = 0.1 scale used by the EZ2 functions.
    a : boundary separation, scaled by 0.1 in the same way.
    z : relative starting point. It is measured from the bottom boundary of
        condition A, so larger z means closer to the Left response. It is clipped
        to stay strictly between the boundaries.
    terL, terR : non-decision times added to Left (choice 0) / Right (choice 1) responses.
    n_trials : total number of trials. n_trials // 2 go to condition A (stimulus 0,
        Left correct) and the rest to condition B (stimulus 1, Right correct).
    rng : numpy Generator used for the boundary split and the trial shuffle.
        EZ2's own exit-time draws are not routed through it.
    rt_transform : "log1p" (default) returns log(1 + RT); "none" or None returns raw RTs.

    Returns
    -------
    dict with "rts" (float32), "choices", "stimulus" and "correct" (int64),
    all shuffled with the same permutation.

    Raises
    ------
    ValueError if `rt_transform` is not one of the supported options.
    """
    # Validate up front: previously an unknown value silently returned untransformed RTs.
    if rt_transform not in ("log1p", "none", None):
        raise ValueError(f"Unknown rt_transform: {rt_transform!r} (expected 'log1p' or 'none')")
    if rng is None:
        rng = np.random.default_rng()

    # scale to s=0.1 used in functions EZ2
    c = 0.1
    vL_ez, vR_ez = float(vL) * c, float(vR) * c
    a_ez = float(a) * c

    # converting relative z to absolute z, kept strictly inside (0, a)
    z_abs = float(z) * a_ez
    eps = 1e-9 * a_ez
    z_abs = min(max(z_abs, eps), a_ez - eps)

    nA = n_trials // 2
    nB = n_trials - nA

    # A condition (Left correct): top -> Left (0), bottom -> Right (1)
    parts_A = _simulate_condition(nA, vL_ez, z_abs, a_ez,
                                  top_choice=0, bottom_choice=1, stim_code=0, rng=rng)
    # B condition (Right correct): top -> Right (1), bottom -> Left (0). The start is
    # mirrored (a - z) so z keeps the same Left/Right bias meaning as in condition A.
    parts_B = _simulate_condition(nB, vR_ez, a_ez - z_abs, a_ez,
                                  top_choice=1, bottom_choice=0, stim_code=1, rng=rng)

    dts, choices, correct, stimulus = (
        np.concatenate([part_a, part_b]) for part_a, part_b in zip(parts_A, parts_B)
    )

    perm = rng.permutation(n_trials)
    dts, choices, correct, stimulus = dts[perm], choices[perm], correct[perm], stimulus[perm]

    # Add the response-specific non-decision time, then optionally log-transform
    rts = dts + np.where(choices == 0, terL, terR)
    if rt_transform == "log1p":
        rts = np.log1p(rts)

    return {
        "rts": rts.astype(np.float32),
        "choices": choices,
        "stimulus": stimulus,
        "correct": correct,
    }


In [None]:
def prior():
    """Draw one parameter set from the independent uniform training prior.

    Ranges (all Uniform):
      vL, vR     drift rates toward Left / Right      [0.1, 6.0]
      a          boundary separation                  [0.3, 4.0]
      z          relative starting point              [0.1, 0.9]
      terL, terR non-decision times (seconds)         [0.1, 1.0]

    Uses NumPy's global RNG. Draws happen in the order listed, so a fixed
    np.random.seed reproduces the same set.
    """
    # (name, low, high) in draw order; the order matters for seeded reproducibility
    bounds = (
        ('vL', 0.1, 6.0),
        ('vR', 0.1, 6.0),
        ('a', 0.3, 4.0),
        ('z', 0.1, 0.9),
        ('terL', 0.1, 1.0),
        ('terR', 0.1, 1.0),
    )
    return {name: np.random.uniform(low, high) for name, low, high in bounds}

In [None]:
# Compose the training prior and the forward model into one BayesFlow simulator.
# The prior's keys (vL, vR, a, z, terL, terR) match forward_model_ez2's argument names,
# presumably so each prior draw is fed straight into the forward model.
simulator = bf.make_simulator([prior, forward_model_ez2])

In [None]:
param_names = ['vL', 'vR', 'a', 'z', 'terL', 'terR']
data_names = ['rts', 'stimulus', 'choices']  # Removed 'correct' for redundancy of information
# 'correct' can be derived from 'choices' and 'stimulus', so it is not needed

# Adapter: maps simulator / observed dicts onto the network inputs.
#   - the 6 parameters are concatenated into "inference_variables" (the flow's target)
#   - the per-trial data are concatenated into "summary_variables" (DeepSet input)
# Each data field gets a trailing feature axis so the three can be concatenated per trial.
# NOTE(review): convert_dtype only casts float64 -> float32; the int64 'choices' and
# 'stimulus' arrays are not cast here. Confirm the concatenated dtype is what the network expects.
adapter = (
    bf.adapters.Adapter()
    .keep(param_names + data_names)
    .to_array()
    .convert_dtype("float64", "float32")
    .expand_dims("rts", axis=-1)
    .expand_dims("choices", axis=-1)
    .expand_dims("stimulus", axis=-1)
    .concatenate(param_names, into="inference_variables")
    .concatenate(data_names, into="summary_variables")
)

In [None]:
from bayesflow.networks import CouplingFlow, DeepSet
from bayesflow.workflows import BasicWorkflow

# Summary network: DeepSet reducing each dataset's trials to a 16-dimensional summary vector
summary_net = DeepSet(
    summary_dim=16,
    dropout=0.1
)

# Inference network: spline coupling flow (6 coupling layers) over the 6 model parameters
flow = CouplingFlow(
    num_coupling_layers=6,
    hidden_units=[128, 128],
    coupling_type="spline",
    batch_norm=True,
    dropout=0.05,
    tail_bound=6.0
)

# Workflow tying simulator, adapter and both networks together.
# NOTE(review): this workflow's approximator is replaced by a pre-trained model loaded
# from disk further down; no training call is visible in this notebook.
wf = BasicWorkflow(
    simulator=simulator,
    adapter=adapter,
    summary_network=summary_net,
    inference_network=flow,
    inference_variables=["inference_variables"],
    summary_variables=["summary_variables"]
)


In [None]:
# Path to the trained standard model. Set the EZBF_MODEL_PATH environment variable to
# run the notebook on another machine without editing this cell; the default is the
# original local path.
MODEL_PATH = os.environ.get(
    "EZBF_MODEL_PATH",
    "C:/Users/emils/Documents/uni/M_Thesis/diffusion-bayesflow/scripts/standard_model.keras",
)
loaded_approx = keras.models.load_model(MODEL_PATH)

In [None]:
# Attach the pre-trained approximator to the workflow built earlier from the same
# simulator, adapter and networks. Re-creating an identical BasicWorkflow here is
# redundant, so the existing `wf` is reused.
wf.approximator = loaded_approx

Estimating parameters for the test datasets. The test datasets and their true parameter values are generated in the notebook ezbf_results.ipynb, so that both notebooks use exactly the same thetas and datasets.

In [None]:
N = 100         # number of test datasets to evaluate
n_trials = 500  # nominal trials per test dataset (not used in this cell)
n_samples = 500 # posterior draws per dataset
param_names = ['vL', 'vR', 'a', 'z', 'terL', 'terR']

# Load stored sets from ez_results.ipynb
# These sets were generated to ensure the same thetas are used for both models to obtain comparable results
# NOTE: pickle.load can execute arbitrary code; only load files produced by your own notebooks.
with open("set_thetas.pkl", "rb") as f:
    set_thetas = pickle.load(f)

# Channels used below: 0 = RT, 1 = stimulus, 2 = choice
obs_data = np.load("set_test_data.npy")

# Fail early with a clear message instead of an IndexError part-way through the loop
if obs_data.ndim != 3 or obs_data.shape[2] < 3:
    raise ValueError(
        f"set_test_data.npy must have shape (datasets, trials, >=3 channels); got {obs_data.shape}"
    )
if len(set_thetas) < N or obs_data.shape[0] < N:
    raise ValueError(
        f"Need {N} test datasets; found {len(set_thetas)} thetas and {obs_data.shape[0]} datasets"
    )

true_params = []
post_samples = []

for i in range(N):
    theta = set_thetas[i]
    true_params.append([theta[k] for k in param_names])

    # Extract components from raw data
    rt = obs_data[i, :, 0]
    stimulus = obs_data[i, :, 1]
    choice = obs_data[i, :, 2]

    # Prepare input for BayesFlow, where keys must match adapter expectations
    input_data = {
        "rts": rt[np.newaxis, :],
        "stimulus": stimulus[np.newaxis, :],
        "choices": choice[np.newaxis, :]
    }

    samples = wf.sample(conditions=input_data, num_samples=n_samples, to_numpy=True)

    # Stack in param_names order, then drop the singleton batch axis (0) and
    # per-parameter feature axis (2) -> (n_samples, 6)
    param_array = np.stack([samples[k] for k in param_names], axis=-1)
    param_array = np.squeeze(param_array, axis=(0, 2))

    post_samples.append(param_array)

# Convert to arrays
true_params = np.array(true_params)   # shape: (N, 6)
post_samples = np.array(post_samples) # shape: (N, n_samples, 6)

In [None]:
# Definining a new prior to be used to generate the simulated test data
# This prior is based on realistic, empirical parameter distributions described
# in Tran et al. (2021)

def sample_trunc_t(df, loc, scale, lower, upper, max_tries=None):
    """Draw one value from a location-scale Student-t truncated to [lower, upper].

    Uses rejection sampling with NumPy's global RNG: draw loc + scale * t(df)
    until the value falls inside the bounds.

    Parameters
    ----------
    df, loc, scale : Student-t degrees of freedom, location and scale.
    lower, upper : inclusive truncation bounds (upper may be np.inf).
    max_tries : optional cap on the number of draws. None (default) keeps
        drawing without limit.

    Raises
    ------
    ValueError if lower > upper (or a bound is NaN); the original would loop forever.
    RuntimeError if `max_tries` draws are exhausted without an accepted value.
    """
    if not lower <= upper:
        raise ValueError(f"Invalid truncation bounds: lower={lower}, upper={upper}")
    tries = 0
    while max_tries is None or tries < max_tries:
        tries += 1
        x = np.random.standard_t(df=df)
        y = loc + scale * x
        if lower <= y <= upper:
            return y
    raise RuntimeError(
        f"No draw in [{lower}, {upper}] after {max_tries} tries (df={df}, loc={loc}, scale={scale})"
    )

def plausible_prior():
    """Draw one parameter set from empirically plausible distributions (Tran et al., 2021).

      vL, vR     ~ Normal(1.76, 1.51), redrawn as a pair until both are >= 0.2
      a          ~ Gamma(shape=11.69, scale=0.12), capped at 4
      z          ~ Student-t(df=1.85, loc=0.5, scale=0.1) truncated to [0, 1]
      terL, terR ~ Student-t(df=1.32, loc=0.44, scale=0.08) truncated to [0, inf)

    All draws use NumPy's global RNG, in the order listed.
    Returns a dict with keys vL, vR, a, z, terL, terR.
    """
    # Start below the floor so the pair is drawn at least once.
    vL = vR = -np.inf
    while min(vL, vR) < 0.2:
        vL = np.random.normal(loc=1.76, scale=1.51)
        vR = np.random.normal(loc=1.76, scale=1.51)

    a = min(np.random.gamma(shape=11.69, scale=0.12), 4)

    z = sample_trunc_t(df=1.85, loc=0.5, scale=0.1, lower=0.0, upper=1.0)

    # lower=0.0 already guarantees non-negative draws, so one call per side suffices.
    terL, terR = (
        sample_trunc_t(df=1.32, loc=0.44, scale=0.08, lower=0.0, upper=np.inf)
        for _ in range(2)
    )

    return dict(vL=vL, vR=vR, a=a, z=z, terL=terL, terR=terR)

Calibrating the posterior of the boundary separation (a). The calibration is fitted on an independent set of simulated data. Its parameters are drawn from the plausible prior, and only draws that fall inside the ranges of the uniform prior used in training are kept.

In [None]:
# Simulating data for the calibration

# Defining the parameter ranges used in model training
TRAIN_SUPPORT = {
    "vL":   (0.1, 6.0),
    "vR":   (0.1, 6.0),
    "a":    (0.3, 4.0),
    "z":    (0.1, 0.9),
    "terL": (0.1, 1.0),
    "terR": (0.1, 1.0),
}

N_CAL         = 200   # calibration datasets to collect
N_TRIALS_CAL  = 400   # trials per calibration dataset
SAMPLES_CAL   = 500   # posterior draws per calibration dataset
SEED_CAL      = 123

rng = np.random.default_rng(SEED_CAL)
# plausible_prior() draws from NumPy's legacy global RNG, which `rng` does not control.
# Seed it as well so the candidate thetas (and hence the calibration split) are reproducible.
# NOTE(review): EZ2's exit-time draws and wf.sample are not driven by `rng`; whether they
# are covered by this global seed depends on their internal RNGs. Confirm.
np.random.seed(SEED_CAL)

cal_true_params   = []
cal_post_samples  = []
cal_skip_log      = []
accepted = 0
candidates = 0
MAX_CAND = 2000

while accepted < N_CAL and candidates < MAX_CAND:
    candidates += 1
    theta = plausible_prior()

    # Keep only thetas inside the training support
    if any(not (lo <= theta[k] <= hi) for k,(lo,hi) in TRAIN_SUPPORT.items()):
        continue

    try:
        sim = forward_model_ez2(**theta, n_trials=N_TRIALS_CAL, rng=rng)
        rts_log = np.asarray(sim["rts"], float)
        stim    = np.asarray(sim["stimulus"], int)
        choice  = np.asarray(sim["choices"], int)

        # Quick safety checks
        if not (np.isfinite(rts_log).all() and rts_log.size == stim.size == choice.size):
            cal_skip_log.append({"reason": "non-finite or size mismatch", "theta": theta})
            continue
        if not (np.any(stim == 0) and np.any(stim == 1)):
            cal_skip_log.append({"reason": "missing one stimulus", "theta": theta})
            continue

        input_data = {
            "rts":      rts_log[np.newaxis, :],
            "stimulus": stim[np.newaxis, :],
            "choices":  choice[np.newaxis, :],
        }

        samples = wf.sample(conditions=input_data, num_samples=SAMPLES_CAL, to_numpy=True)

        post_array = np.column_stack([
            samples["vL"][0].squeeze(),
            samples["vR"][0].squeeze(),
            samples["a"][0].squeeze(),
            samples["z"][0].squeeze(),
            samples["terL"][0].squeeze(),
            samples["terR"][0].squeeze()
        ]).astype(np.float32)

        # Safety check
        if not np.isfinite(post_array).all():
            cal_skip_log.append({"reason": "NaNs in posterior", "theta": theta})
            continue

    # Log thetas that had to be skipped (if any)
    except Exception as e:
        cal_skip_log.append({"reason": f"Exception: {e}", "theta": theta})
        continue

    cal_true_params.append([theta[k] for k in ["vL","vR","a","z","terL","terR"]])
    cal_post_samples.append(post_array)
    accepted += 1

print(f"Calibration split: accepted {accepted} from {candidates} candidates.")
if accepted < N_CAL:
    # The loop stops at MAX_CAND; make a short calibration split visible instead of silent.
    print(f"WARNING: only {accepted} of the requested {N_CAL} calibration sets were accepted "
          f"(MAX_CAND={MAX_CAND}); inspect cal_skip_log.")
cal_true_params  = np.asarray(cal_true_params, dtype=float)       # (N_CAL, 6)
cal_post_samples = np.asarray(cal_post_samples, dtype=np.float32) # (N_CAL, SAMPLES_CAL, 6)


# Computing the calibration

# Column order of the last axis of post_samples, cal_post_samples and true_params
PARAM_ORDER = ['vL','vR','a','z','terL','terR']

# fitting: piecewise inverse for 'a' on the independent calibration split
def fit_piecewise_a(cal_true_params, cal_post_samples, param_order,
                    nbins=10, min_bin_n=8, point_est="mean"):
    """Fit a piecewise-linear inverse map (raw estimate of 'a' -> true 'a').

    The true values of 'a' are split into `nbins` quantile bins. Each bin with at
    least `min_bin_n` datasets contributes one knot: (mean raw estimate, mean truth).
    Knots are sorted by raw estimate and near-duplicates are dropped so np.interp
    can use them directly.

    Parameters
    ----------
    cal_true_params : (N, P) true parameters of the calibration datasets.
    cal_post_samples : (N, S, P) posterior draws for the same datasets.
    param_order : parameter names of the last axis (must contain "a").
    nbins, min_bin_n : number of quantile bins and minimum datasets per usable bin.
    point_est : "mean" or "median" of the draws, used as the raw estimate.

    Returns
    -------
    {"mode": "piecewise", "ys": raw-estimate knots, "xs": truth knots}, or
    {"mode": "identity"} when there is too little data for at least 3 knots
    (including an empty calibration set).

    Raises
    ------
    ValueError for an unknown `point_est`; previously anything but "median" silently meant mean.
    """
    if point_est not in ("mean", "median"):
        raise ValueError(f"point_est must be 'mean' or 'median', got {point_est!r}")

    cal_true_params = np.asarray(cal_true_params, dtype=float)
    cal_post_samples = np.asarray(cal_post_samples)
    # An empty calibration split (e.g. every candidate was rejected) cannot be fitted.
    if cal_true_params.shape[0] == 0 or cal_post_samples.size == 0:
        return {"mode": "identity"}

    j = param_order.index("a")
    est_mat = np.median(cal_post_samples, 1) if point_est == "median" else np.mean(cal_post_samples, 1)
    x = cal_true_params[:, j].astype(float)   # truth
    y = est_mat[:, j].astype(float)           # raw posterior point estimate

    # quantile bins on truth; digitize on the interior edges gives bin ids 0..nbins-1
    bins = np.quantile(x, np.linspace(0, 1, nbins + 1))
    idx  = np.digitize(x, bins[1:-1], right=True)

    xs, ys = [], []
    for b in range(nbins):
        m = (idx == b)
        if m.sum() >= min_bin_n:
            xs.append(float(x[m].mean()))
            ys.append(float(y[m].mean()))

    if len(xs) < 3:
        return {"mode": "identity"}

    xs = np.asarray(xs, float)
    ys = np.asarray(ys, float)

    # ensure monotone in y for safe interpolation; collapse near-duplicates
    order = np.argsort(ys)
    ys = ys[order]; xs = xs[order]
    uniq = np.concatenate(([True], np.diff(ys) > 1e-12))
    ys = ys[uniq]; xs = xs[uniq]
    if ys.size < 3:
        return {"mode": "identity"}

    return {"mode": "piecewise", "ys": ys, "xs": xs}

# apply to calibrate only 'a', since prior investigation found the calibration to not improve the MSE of vL and vR
def apply_piecewise_a_only(post_samples, model, param_order):
    """Map the 'a' column of posterior draws through a fitted piecewise-linear inverse.

    post_samples : array whose last axis follows `param_order`.
    model : dict from fit_piecewise_a. Anything other than mode "piecewise" returns
        the input object itself, unchanged.

    Draws of 'a' are mapped with np.interp from the raw-estimate knots `ys` to the
    truth knots `xs`, with flat extrapolation at both ends. They are then floored
    at 1e-4 so 'a' stays positive. Other columns are untouched, and the input
    array is not modified (a copy is returned).
    """
    if model.get("mode") == "piecewise":
        col = param_order.index("a")
        grid_y, grid_x = model["ys"], model["xs"]
        calibrated = post_samples.copy()
        raw = calibrated[..., col].astype(np.float64)
        mapped = np.interp(raw, grid_y, grid_x, left=grid_x[0], right=grid_x[-1])
        calibrated[..., col] = np.clip(mapped, 1e-4, None).astype(calibrated.dtype, copy=False)
        return calibrated
    return post_samples

# fit on the calibration data and apply to the test data
# Fit the inverse map for 'a' on the independent calibration split ...
a_model = fit_piecewise_a(
    cal_true_params, cal_post_samples, PARAM_ORDER,
    nbins=10, min_bin_n=8, point_est="mean",
)
# ... and apply it to the test-set posteriors
post_samples_cal = apply_piecewise_a_only(post_samples, a_model, PARAM_ORDER)

def _est(arr, how="mean"):
    """Per-dataset point estimate over the draws axis (axis 1): median or mean."""
    if how == "median":
        return np.median(arr, 1)
    return np.mean(arr, 1)

# Sanity check on TEST: mean bias of the posterior-mean estimate of 'a', before vs after
j_a = PARAM_ORDER.index("a")
b0, b1 = (
    float(np.mean(_est(arr, "mean")[:, j_a] - true_params[:, j_a]))
    for arr in (post_samples, post_samples_cal)
)
print(f"[TEST] a: mean bias {b0:.3f} -> {b1:.3f}")

# Invariants: calibration keeps the array shape and only changes the 'a' column
assert post_samples.shape == post_samples_cal.shape
j = PARAM_ORDER.index("a")
unchanged = np.allclose(
    post_samples[..., np.r_[0:j, j+1:6]],
    post_samples_cal[..., np.r_[0:j, j+1:6]],
)
assert unchanged, "Only 'a' should change under calibration."

# function for calibrating a dict of posterior samples (used in the validation study)
def calibrate_bf_sample_dict_a_only(samples, a_model, param_order):
    """Apply the piecewise 'a' calibration to a BayesFlow sample dict.

    samples : dict returned by wf.sample(..., to_numpy=True). Only batch element 0
        of each parameter is read. Presumably each entry is (1, S, 1); TODO confirm
        for batched calls.
    a_model : output of fit_piecewise_a. Non-"piecewise" models return `samples` unchanged.
    param_order : parameter names; the draws are stacked in this order, and the
        'a' column is located by name rather than assumed to be column 2.

    Returns a new dict (all arrays copied) in which samples['a'][0] holds the
    calibrated draws. The input dict is not modified.
    """
    if a_model.get("mode") != "piecewise":
        return samples

    j = param_order.index("a")
    arr = np.column_stack(
        [np.asarray(samples[k][0]).squeeze() for k in param_order]
    ).astype(np.float32)[np.newaxis, ...]

    arr_cal = apply_piecewise_a_only(arr, a_model, param_order)[0]

    out = {k: np.array(v, copy=True) for k, v in samples.items()}
    # Reshape to whatever per-parameter layout wf.sample produced, e.g. (S, 1) or (S,)
    out['a'][0] = arr_cal[:, j].reshape(np.shape(out['a'][0]))
    return out

Whether the posterior samples with the calibrated 'a' are used can be toggled in the code cell below.

In [None]:
# True  -> downstream metrics use post_samples_cal (piecewise-calibrated 'a')
# False -> downstream metrics use the raw network posteriors (post_samples)
# The flag is also read by run_inference_per_subject for the validation data.
# Re-run this cell and everything below it after changing the flag.
APPLY_CALIBRATION = False # toggle for using the calibration (True) or not using it (False)
post_samples_used = (post_samples_cal if APPLY_CALIBRATION else post_samples)

In [None]:
# Accuracy metrics
# Builds one row per (dataset, parameter) with posterior summaries and the error of the
# chosen point estimate against the true value. These rows are aggregated in the next cells.

# Optional: toggle which point estimate is used for calculating errors
POINT_EST = "mean" # "mean" or "median"

param_names = ['vL', 'vR', 'a', 'z', 'terL', 'terR']

records = []
for i in range(true_params.shape[0]):
    for j, name in enumerate(param_names):
        true_val = float(true_params[i, j])
        samples  = post_samples_used[i, :, j]   # posterior draws for this dataset/parameter

        # Posterior summaries
        post_mean = float(np.mean(samples))
        post_median = float(np.median(samples))
        post_var = float(np.var(samples, ddof=1))
        post_sd  = float(np.sqrt(post_var))

        # Point estimate used for accuracy metrics
        point_est = post_median if POINT_EST == "median" else post_mean

        # Per-dataset estimation error (bias at the single-dataset level)
        error = float(point_est - true_val)
        se    = float(error**2)                      # squared error (per dataset)

        records.append({
            "dataset": i,
            "parameter": name,
            "true_value": true_val,
            "posterior_mean": post_mean,
            "posterior_median": post_median,
            "posterior_variance": post_var,
            "posterior_sd": post_sd,
            "point_estimate": point_est,            # the estimator used for triad
            "error": error,                         # will aggregate to Bias = mean(error)
            "mse": se                               # per-dataset squared error; MSE = mean(mse)
        })

performance_stats_df = pd.DataFrame(records)

In [None]:
# Bias / variance / MSE per parameter across all test datasets (no outlier removal).
# Because the variance uses ddof=1: MSE = Bias^2 + Variance_of_error * (n - 1) / n.
triad = (
    performance_stats_df
    .groupby('parameter')
    .agg(
        n=('error','size'),
        Bias=('error','mean'),
        Variance_of_error=('error', lambda s: s.var(ddof=1)),
        MSE=('mse','mean')
    ).reset_index()
)

triad.round(3)

In [None]:
# Identifying and cleaning outliers
# Per-parameter cutoff on |error|: Q3 + 3 * IQR of the absolute errors (a far-out,
# Tukey-style fence applied to |error|). Rows above the cutoff are excluded from df_clean.
thr = performance_stats_df.groupby('parameter')['error'].transform(lambda s: s.abs().quantile(0.75) + 3.0*(s.abs().quantile(0.75)-s.abs().quantile(0.25)))
df_clean = performance_stats_df[performance_stats_df['error'].abs() <= thr].copy()

# Number and percentage of dropped rows per parameter
is_out = ~performance_stats_df.index.isin(df_clean.index)
outlier_counts = performance_stats_df.assign(outlier=is_out).groupby('parameter')['outlier'].agg(n_total='size', n_dropped='sum').assign(pct_dropped=lambda x: 100*x['n_dropped']/x['n_total'])
outlier_counts

In [None]:
# Same bias / variance / MSE triad as above, computed on the outlier-cleaned rows.
# NOTE(review): the variance column is named "Variance" here but "Variance_of_error" in
# `triad`. Keep this in mind when joining the two tables.
triad_clean = (
    df_clean
    .groupby('parameter')
    .agg(
        n=('error','size'),
        Bias=('error','mean'),
        Variance=('error', lambda s: s.var(ddof=1)),   # variance of error across datasets
        MSE=('error', lambda s: float(np.mean(s**2)))  # mean squared error across datasets
    )
    .reset_index()
)

triad_clean.round(3)

In [None]:
# Storing output for later model comparison
# Tags both tables with the model label and writes them to CSV in the working directory.

MODEL_LABEL = "standard"

performance_stats_df = performance_stats_df.copy()
performance_stats_df["model"] = MODEL_LABEL
df_clean = df_clean.copy()
df_clean["model"] = MODEL_LABEL

performance_stats_df.to_csv(f"{MODEL_LABEL}_perf_full.csv", index=False)
df_clean.to_csv(f"{MODEL_LABEL}_perf_clean.csv", index=False)

In [None]:
# Plotting the per-dataset estimation error and the posterior SD per parameter
# (outlier points hidden via showfliers=False; all rows are used, not df_clean)

# NOTE(review): these names are already imported in the first cell; the re-import is redundant.
import seaborn as sns, matplotlib.pyplot as plt, numpy as np, pandas as pd
sns.set(style="whitegrid")

DFU = performance_stats_df.copy()
base = performance_stats_df

# counts dropped per parameter
# NOTE(review): DFU is a copy of the same frame as `base`, so n_drop is always 0, and it is
# not used in the figure. Possibly DFU was meant to be df_clean; confirm intent.
n_total = base.groupby("parameter").size()
n_kept  = DFU.groupby("parameter").size()
n_drop  = (n_total - n_kept).reindex(n_total.index).fillna(0).astype(int)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Panel A: per-dataset estimation error (point estimate - truth); its mean is the bias
sns.boxplot(data=DFU, x="parameter", y="error", ax=axes[0], showfliers=False)
axes[0].axhline(0, ls="--", c="gray", lw=1)
axes[0].set_ylim(-1.6, 1.9)
axes[0].set_title("Bias by Parameter")
axes[0].set_xlabel(""); axes[0].set_ylabel("Bias")

# Panel B: posterior standard deviation
sns.boxplot(data=DFU, x="parameter", y="posterior_sd", ax=axes[1], showfliers=False)
axes[1].set_ylim(0, 0.6)
axes[1].set_title("Posterior SD by Parameter")
axes[1].set_xlabel(""); axes[1].set_ylabel("Posterior SD")

plt.tight_layout(rect=[0,0.05,1,1])
plt.show()

In [None]:
# Credible intervals and coverage
# Equal-tailed 95% credible interval per (dataset, parameter), from the 2.5th / 97.5th
# percentiles of the posterior draws. `covered` is 1 when the true value lies inside it.

N, S, P = post_samples_used.shape  # N = datasets, S = samples, P = parameters

ci_records = []

for i in range(N):
    for j, name in enumerate(param_names):
        samples_ij = post_samples_used[i, :, j]
        true_val = true_params[i, j]

        # Compute 95% credible interval
        lower = np.percentile(samples_ij, 2.5)
        upper = np.percentile(samples_ij, 97.5)
        width = upper - lower

        # Coverage
        covered = int(lower <= true_val <= upper)

        ci_records.append({
            "dataset": i,
            "parameter": name,
            "true_value": true_val,
            "lower_95": lower,
            "upper_95": upper,
            "width_95": width,
            "covered": covered
        })

ci_df = pd.DataFrame(ci_records)


In [None]:
# Overall coverage per parameter (fraction of datasets whose 95% CrI contains the truth;
# nominal value 0.95)
coverage_summary = ci_df.groupby("parameter")["covered"].mean().reset_index()
coverage_summary.rename(columns={"covered": "coverage_rate"}, inplace=True)

# Plot distribution of credible interval widths (y-axis fixed to [0, 2], wider CrIs are cut off)
plt.figure(figsize=(10, 6))
sns.boxplot(data=ci_df, x="parameter", y="width_95")
plt.title("Distribution of 95% CrI Widths by Parameter")
plt.ylabel("Width of 95% CrI")
plt.xlabel("Parameter")
plt.ylim(0, 2)
plt.tight_layout()
plt.show()

# Return the summary coverage table
coverage_summary

In [None]:
# Storing output for later model comparison
# Writes the per-(dataset, parameter) credible-interval table tagged with the model label.

MODEL_LABEL = "standard"
ci_df_model = ci_df.copy()
ci_df_model["model"] = MODEL_LABEL
ci_df_model.to_csv(f"{MODEL_LABEL}_ci_full.csv", index=False)

In [None]:
# Simulation-based calibration (SBC)
# For each parameter: rank of the true value among its posterior draws, normalised to
# (0, 1] and histogrammed. Uniform histograms indicate calibrated posteriors.
# NOTE(review): SBC assumes the true values were drawn from the prior the network was
# trained on (the uniform `prior`). The test thetas come from set_thetas.pkl, generated
# in another notebook. If those were drawn from a different prior, non-uniform ranks
# may reflect prior mismatch rather than miscalibration. Confirm.

# NOTE(review): kstest is imported but not used in this cell.
from scipy.stats import chi2, kstest, binom

N, S, P = post_samples_used.shape
rng = np.random.default_rng(1312)

# Randomized ranks and normalized ranks z in (0, 1]
sbc_records = []
for j, name in enumerate(param_names):
    for i in range(N):
        true_val = float(true_params[i, j])
        samples  = post_samples_used[i, :, j]

        n_less  = np.sum(samples < true_val)
        n_equal = np.sum(samples == true_val)   # robust to rare ties (almost always 0 for continuous draws)
        rank    = n_less + rng.uniform(0, 1) * n_equal     # fractional rank in [0, S]
        z       = (rank + 1.0) / (S + 1.0)                 # in [1/(S+1), 1]; target ~ Uniform(0,1)

        sbc_records.append({"dataset": i, "parameter": name, "rank": float(rank), "z": float(z)})

sbc_df = pd.DataFrame(sbc_records)

# KL divergence helper (observed || uniform), with Jeffreys smoothing
def kl_to_uniform(counts, alpha=0.5, base='e'):
    """KL(p || uniform) for histogram `counts`, with add-`alpha` smoothing of p.

    base='2' returns bits; anything else returns nats. The result is floored at 0.
    """
    counts = np.asarray(counts, float)
    B = len(counts)
    p = (counts + alpha) / (counts.sum() + alpha * B)  # smoothed observed probs
    q = np.full(B, 1.0 / B)                            # exact uniform probs
    kl = np.sum(p * np.log(p / q))
    if base == '2':
        kl /= np.log(2.0)
    return float(max(kl, 0.0))

# Histograms with 95% expected bands + diagnostics
num_bins = 20
bins = np.linspace(0.0, 1.0, num_bins + 1)

fig, axes = plt.subplots(3, 2, figsize=(10, 12), sharey=False)
axes = axes.flatten()

sbc_summ_rows = []

for ax, name in zip(axes, param_names):
    sub = sbc_df.loc[sbc_df["parameter"] == name, "z"].to_numpy()
    counts, _ = np.histogram(sub, bins=bins)
    n_sub = len(sub)
    exp = n_sub / num_bins   # expected count per bin under uniformity

    # 95% band for each bin's count under Binomial(n_sub, 1/B)
    low, high = binom.interval(0.95, n_sub, 1/num_bins)
    low, high = float(low), float(high)

    # diagnostics: KL to uniform and Pearson chi-square goodness-of-fit (B - 1 df)
    kl = kl_to_uniform(counts, alpha=0.5, base='e')
    chi_stat = float(((counts - exp)**2 / exp).sum())
    chi_p = float(chi2.sf(chi_stat, df=num_bins - 1))

    sbc_summ_rows.append({
        "parameter": name, "n": n_sub,
        "kl_div_to_uniform": kl,
        "chi2_stat": chi_stat, "chi2_p": chi_p
    })

    ax.bar(np.arange(num_bins), counts, width=1, edgecolor='k')
    ax.axhline(exp, color='red', linestyle='--', lw=1, label='expected')
    ax.axhspan(low, high, color='lightgray', alpha=0.4, zorder=0, label='95% band')
    ax.set_title(f"{name}\nKL={kl:.3f}  χ²p={chi_p:.3f}")
    ax.set_xlabel("rank bin"); ax.set_ylabel("count")
    ax.set_xlim(-0.5, num_bins - 0.5)

plt.tight_layout()
plt.show()

sbc_summary = pd.DataFrame(sbc_summ_rows).sort_values("parameter")
display(sbc_summary)


In [None]:
# PPC with symmetric KL on counts + Jeffreys smoothing
# For each test dataset: simulate RTs from n_sim_samples posterior draws, histogram the
# observed and pooled simulated RTs on shared bins, and compare them with a symmetric KL.
# NOTE(review): the posterior indices (np.random.choice) and forward_model_ez2 (no rng
# passed) use unseeded randomness, so the results are not reproducible run-to-run.
# NOTE(review): forward_model_ez2 returns log1p RTs by default; this assumes the RT
# channel of obs_data is on the same log1p scale. Confirm.

n_sim_samples = 50       # posterior draws per dataset
n_trials_per_sim = 100   # simulated trials per draw

kl_records = []
skipped_logs = []

def sym_kl_counts(obs_counts, sim_counts, alpha=0.5):
    """0.5 * (KL(p||q) + KL(q||p)) between two histograms sharing the same bins.

    Counts are add-`alpha` smoothed and normalised first, so empty bins and
    different sample sizes are handled.
    """
    obs_counts = np.asarray(obs_counts, float)
    sim_counts = np.asarray(sim_counts, float)
    B = len(obs_counts)
    p = (obs_counts + alpha) / (obs_counts.sum() + alpha * B)
    q = (sim_counts + alpha) / (sim_counts.sum() + alpha * B)
    kl_pq = np.sum(p * np.log(p / q))
    kl_qp = np.sum(q * np.log(q / p))
    return 0.5 * (kl_pq + kl_qp)

for i in range(post_samples_used.shape[0]):
    obs_rts = obs_data[i, :, 0]
    sim_rts = []
    skipped = 0

    # choose posterior indices
    idxs = np.random.choice(post_samples_used.shape[1], size=n_sim_samples, replace=False)

    for idx in idxs:
        vL, vR, a, z, terL, terR = post_samples_used[i, idx]

        # parameter sanity checks and logging of errors
        if any([vL <= 0, vR <= 0, a <= 0, not (0 < z < 1), terL < 0, terR < 0]):
            skipped_logs.append({
                "dataset": i, "sample_index": idx,
                "vL": vL, "vR": vR, "a": a, "z": z, "terL": terL, "terR": terR,
                "reason": "invalid parameter values"
            })
            skipped += 1
            continue

        try:
            sim = forward_model_ez2(vL=vL, vR=vR, a=a, z=z, terL=terL, terR=terR,
                                    n_trials=n_trials_per_sim)
            rts = np.asarray(sim["rts"], float)
            # guard against nan/inf
            rts = rts[np.isfinite(rts)]
            if rts.size:
                sim_rts.extend(rts)
        except Exception as e:
            skipped_logs.append({
                "dataset": i, "sample_index": idx,
                "vL": vL, "vR": vR, "a": a, "z": z, "terL": terL, "terR": terR,
                "reason": str(e)
            })
            skipped += 1

    sim_rts = np.asarray(sim_rts, float)

    if sim_rts.size == 0:
        kl_div = np.nan   # every draw was skipped or failed
    else:
        # shared bins from pooled data (Freedman–Diaconis; FD)
        pooled = np.concatenate([obs_rts, sim_rts])
        pooled = pooled[np.isfinite(pooled)]
        bin_edges = np.histogram_bin_edges(pooled, bins='fd')
        # fallback if FD returns too few bins: 20 equal-width bins over the pooled range
        if len(bin_edges) < 5:
            lo, hi = np.nanmin(pooled), np.nanmax(pooled)
            bin_edges = np.linspace(lo, hi, 21)

        obs_counts, _ = np.histogram(obs_rts, bins=bin_edges, density=False)
        sim_counts, _ = np.histogram(sim_rts, bins=bin_edges, density=False)

        kl_div = sym_kl_counts(obs_counts, sim_counts, alpha=0.5)

    kl_records.append({
        "dataset": i,
        "kl_divergence": float(kl_div),
        "skipped_samples": int(skipped),
        "n_sim_rts": int(sim_rts.size)
    })

kl_ppc_df = pd.DataFrame(kl_records)
skipped_df = pd.DataFrame(skipped_logs)


Not all posterior samples lie within the valid parameter bounds; some are negative, for example. If the PPC code above draws such a sample, the forward model breaks, since it expects positive inputs. The dataframe below lists the parameter sets that were therefore skipped: those with invalid values (e.g., negative parameters or a starting point z outside (0, 1)), plus any draws for which the simulation raised an error.

In [None]:
# Posterior draws skipped in the PPC, with the reason (invalid values or simulation error)
skipped_df

In [None]:
# Summary statistics of the per-dataset PPC symmetric KL, skipped-draw counts and simulated-RT counts
kl_ppc_df.describe()

In [None]:
# Distribution across test datasets of the PPC symmetric KL (outlier points hidden)
plt.figure(figsize=(6, 5))
sns.boxplot(y=kl_ppc_df["kl_divergence"], width=0.35, showfliers=False)
plt.ylabel("Symmetric KL")
plt.grid(True, axis="y", alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# helper function and toggle for applying a log-transform on RTs
# this is required for the standard model, since the real experimental data
# is on a regular scale, while the model was trained on log-transformed data

def apply_rt_transform(x, how):
    """Put raw response times on the scale the network was trained on.

    x   : numpy array of RTs.
    how : None or "none" -> cast to float32 only; "log1p" -> float32 log(1 + x).

    Raises ValueError for negative RTs under "log1p" and for any other `how`.
    """
    if how in (None, "none"):
        return x.astype(np.float32)
    if how != "log1p":
        raise ValueError(f"Unknown RT transform: {how}")
    if np.any(x < 0):
        raise ValueError("Negative RT encountered with log1p.")
    as_f32 = x.astype(np.float32)
    return np.log1p(as_f32)

# RT scale applied when loading the validation data. "log1p" matches the default
# rt_transform of forward_model_ez2, which presumably produced the training data.
RT_TRANSFORM = "log1p"

In [None]:
# Importing the validation data

def _map_left_right(series, column, source):
    """Encode an 'L'/'R' column as int64 0/1, raising a clear error on any other value."""
    mapped = series.map({"L": 0, "R": 1})
    if mapped.isna().any():
        bad = sorted(series[mapped.isna()].astype(str).unique())
        raise ValueError(f"{source}: unexpected values in column '{column}': {bad}")
    return mapped.astype("int64")


def load_validation_data_by_subject(base_path, n_experiments=14, n_participants=20, rt_transform=None):
    """
    Reads
      base_path/exp_1/pp1.txt ... pp{n_participants}.txt,
      ...
      base_path/exp_{n_experiments}/pp1.txt ... pp{n_participants}.txt
    (defaults: 14 experiments x 20 participants). Files are whitespace-separated with
    (case-insensitive) columns stim, resp, cond, rt, correct.

    rt_transform : RT transform passed to apply_rt_transform. None (default) uses the
        global RT_TRANSFORM; pass "none" to keep raw RTs.
    NOTE(review): RTs are used in the files' units, while the training prior treats
    non-decision times as seconds. Confirm the files are in seconds, not ms.

    Returns: list of dicts:
      {"dataset_id": i, "participants": [
          {"subject_id": "pp1",
           "A": {"rts":..., "stimulus":..., "response":..., "correct":...},
           "B": {...}},
          ...
      ]}

    Raises ValueError if stim/resp contain values other than 'L'/'R'.
    """
    if rt_transform is None:
        rt_transform = RT_TRANSFORM

    datasets = []

    for i in range(1, n_experiments + 1):  # exp_1 ... exp_N
        exp_path = os.path.join(base_path, f"exp_{i}")
        participants = []

        for j in range(1, n_participants + 1):  # pp1 ... ppM
            path = os.path.join(exp_path, f"pp{j}.txt")
            df = pd.read_csv(path, sep=r"\s+", engine="python")
            df.columns = [c.lower() for c in df.columns]

            df["stimulus"] = _map_left_right(df["stim"], "stim", path)
            df["response"] = _map_left_right(df["resp"], "resp", path)
            df["correct"]  = df["correct"].astype("int64")
            rt_raw = df["rt"].astype("float32").to_numpy()
            rt_for_model = apply_rt_transform(rt_raw, rt_transform)

            subj = {"subject_id": f"pp{j}"}
            for cond in ["A", "B"]:
                # A positional boolean mask keeps the RT array and the df rows aligned
                # regardless of the frame's index labels.
                mask = (df["cond"] == cond).to_numpy()
                subj[cond] = {
                    "rts":      rt_for_model[mask],
                    "stimulus": df.loc[mask, "stimulus"].to_numpy(np.int64),
                    "response": df.loc[mask, "response"].to_numpy(np.int64),
                    "correct":  df.loc[mask, "correct"].to_numpy(np.int64),
                }

            participants.append(subj)
        datasets.append({"dataset_id": i, "participants": participants})
    return datasets

# Root folder of the validation data. Set the EZBF_VALIDATION_DATA environment variable
# to run on another machine without editing this cell; the default is the original local path.
base = os.environ.get(
    "EZBF_VALIDATION_DATA",
    r"C:/Users/emils/Documents/uni/M_Thesis/diffusion-bayesflow/data/real/validation_text_data/validation_text_data",
)
val_data_by_subj = load_validation_data_by_subject(base)

In [None]:
# Obtaining posteriors for each subject per condition

def run_inference_per_subject(wf, datasets, num_samples=500):
    """Sample posteriors separately for condition A and condition B of every subject.

    Each subject/condition is sent to wf.sample as a single dataset (batch of 1).
    If the global APPLY_CALIBRATION is True, the 'a' draws are calibrated with the
    global a_model before stacking.

    Returns: [ {"dataset_id": i,
                "subjects": [{"subject_id": "...",
                              "A_samples": (num_samples, 6) float32 or None,
                              "B_samples": (num_samples, 6) float32 or None,
                              optional "A_error"/"B_error": str}, ...]} ]
    Conditions without trials get samples None and error "no trials". Any exception
    during sampling is recorded as "<cond>_error" rather than raised.
    """
    columns = ('vL', 'vR', 'a', 'z', 'terL', 'terR')
    results = []

    for dataset in datasets:
        dataset_id = dataset["dataset_id"]
        subjects_out = []

        for participant in dataset["participants"]:
            record = {"subject_id": participant["subject_id"]}

            for cond in ("A", "B"):
                cond_data = participant[cond]
                if cond_data["rts"].size == 0:
                    record[f"{cond}_samples"] = None
                    record[f"{cond}_error"] = "no trials"
                    continue

                conditions = {
                    "rts":      cond_data["rts"][np.newaxis, :],
                    "stimulus": cond_data["stimulus"][np.newaxis, :],
                    "choices":  cond_data["response"][np.newaxis, :],
                }

                try:
                    draws = wf.sample(conditions=conditions, num_samples=num_samples, to_numpy=True)
                    if APPLY_CALIBRATION:
                        draws = calibrate_bf_sample_dict_a_only(draws, a_model, PARAM_ORDER)
                    record[f"{cond}_samples"] = np.column_stack(
                        [draws[key][0].squeeze() for key in columns]
                    ).astype(np.float32)
                except Exception as exc:
                    record[f"{cond}_samples"] = None
                    record[f"{cond}_error"] = str(exc)

            subjects_out.append(record)

        results.append({"dataset_id": dataset_id, "subjects": subjects_out})
    return results

# Posterior draws (500 per subject and condition) for every validation dataset and subject
post_by_subj = run_inference_per_subject(wf, val_data_by_subj, num_samples=500)

In [None]:
# Computing condition differences by parameter, on subject-level

def make_subject_deltas(post_by_subj):
    """
    Turn per-subject posterior samples into per-subject condition differences.

    Every subject that has both "A_samples" and "B_samples" (arrays whose six
    columns are vL, vR, a, z, terL, terR) yields one row:
      {"dataset_id": i, "subject_id": sid,
       "delta": {"v": (K,), "a": (K,), "z": (K,), "ter": (K,)} }
    Each delta is the sample-wise B minus A of a derived quantity:
      v = vL + vR,  a = a,  z -> log(1 - z),  ter = terL + terR.
    Subjects lacking either condition are skipped.
    """
    def _derived(samples):
        # Collapse the six sampled columns into the four compared quantities.
        return {
            "v": samples[:, 0] + samples[:, 1],
            "a": samples[:, 2],
            "z": np.log(1 - samples[:, 3]),
            "ter": samples[:, 4] + samples[:, 5],
        }

    deltas = []
    for dataset in post_by_subj:
        dataset_id = dataset["dataset_id"]
        for subject in dataset["subjects"]:
            samples_a = subject.get("A_samples")
            samples_b = subject.get("B_samples")
            if samples_a is None or samples_b is None:
                continue

            qa, qb = _derived(samples_a), _derived(samples_b)
            deltas.append({
                "dataset_id": dataset_id,
                "subject_id": subject["subject_id"],
                "delta": {key: qb[key] - qa[key] for key in ("v", "a", "z", "ter")},
            })
    return deltas

# One row per subject with both conditions sampled; subjects missing a condition are dropped.
subj_deltas = make_subject_deltas(post_by_subj)

In [None]:
# Obtaining group-level condition differences

def aggregate_group_mean(subj_deltas, num_draws=4000, seed=123):
    """
    Composition sampling of the group-mean condition difference.

    For each dataset and parameter, every one of `num_draws` group draws picks one
    random Δ-sample index per subject (independently per subject) and averages the
    picked values across subjects. Datasets appear in first-seen order.

    Returns: [{"dataset_id": i, "v_samples_group": (M,), "a_samples_group": ...,
               "z_samples_group": ..., "ter_samples_group": ...}, ...]
    with float32 arrays of length M = num_draws.
    """
    params = ("v", "a", "z", "ter")
    rng = np.random.default_rng(seed)

    # dataset_id -> {param: [per-subject Δ arrays]}
    per_dataset = {}
    for entry in subj_deltas:
        bucket = per_dataset.setdefault(entry["dataset_id"], {q: [] for q in params})
        for q in params:
            bucket[q].append(entry["delta"][q])

    results = []
    for dataset_id, by_param in per_dataset.items():
        summary = {"dataset_id": dataset_id}
        for q, subject_arrays in by_param.items():
            key = f"{q}_samples_group"
            if not subject_arrays:
                summary[key] = None
                continue

            n_subjects = len(subject_arrays)
            n_per_subject = subject_arrays[0].shape[0]
            matrix = np.stack(subject_arrays, axis=1)  # (K, S): column j is subject j

            # One independent sample index per (group draw, subject).
            picks = rng.integers(0, n_per_subject, size=(num_draws, n_subjects))
            columns = np.arange(n_subjects)
            summary[key] = matrix[picks, columns].mean(axis=1).astype(np.float32)
        results.append(summary)
    return results

# 4000 composition draws of the group-mean Δ per dataset and parameter (seed defaults to 123).
group_samples = aggregate_group_mean(subj_deltas, num_draws=4000)

In [None]:
# Computing HDIs

def compute_hdi_from_group(group_samples, hdi_prob=0.99, interval="hdi"):
    """
    Summarize each dataset's group-level Δ samples per parameter.

    For every record in `group_samples` and each parameter in (v, a, z, ter),
    reads `{p}_samples_group` and reports its mean, median, interval bounds and
    whether the interval excludes zero.

    Parameters
    ----------
    group_samples : list of dict
        Records with "dataset_id" and optional "{p}_samples_group" 1-D arrays
        (None or empty means "no samples").
    hdi_prob : float
        Probability mass of the interval, in (0, 1].
    interval : {"hdi", "eti"}
        "hdi" (default): highest-density interval, i.e. the narrowest interval
        spanning `hdi_prob` of the sorted samples. "eti": equal-tailed
        percentile interval (the interval this function used to compute).

    Returns
    -------
    pandas.DataFrame
        One row per dataset with `{p}_diff_mean`, `{p}_diff_median`,
        `{p}_hdi_lower`, `{p}_hdi_upper` and `{p}_excludes_zero`. Missing
        samples give NaN statistics and excludes_zero=False.

    Raises
    ------
    ValueError
        If `hdi_prob` is outside (0, 1] or `interval` is not "hdi"/"eti".
    """
    if not 0.0 < hdi_prob <= 1.0:
        raise ValueError(f"hdi_prob must be in (0, 1], got {hdi_prob!r}")
    if interval not in ("hdi", "eti"):
        raise ValueError(f"interval must be 'hdi' or 'eti', got {interval!r}")

    def _bounds(s):
        # Return (lower, upper) of the requested interval for one sample vector.
        if interval == "eti":
            return (float(np.percentile(s, (1 - hdi_prob) / 2 * 100)),
                    float(np.percentile(s, (1 + hdi_prob) / 2 * 100)))
        x = np.sort(np.asarray(s, dtype=np.float64))
        if np.isnan(x).any():
            # Same as np.percentile: NaN in the samples gives NaN bounds.
            return np.nan, np.nan
        n = x.size
        k = int(np.floor(hdi_prob * n))  # index span between the interval's ends
        if k >= n:  # hdi_prob == 1: the whole sample range
            return float(x[0]), float(x[-1])
        # Width of every candidate interval [x[i], x[i + k]]; keep the narrowest.
        widths = x[k:] - x[: n - k]
        i = int(np.argmin(widths))
        return float(x[i]), float(x[i + k])

    rows = []
    for ds in group_samples:
        dsid = ds["dataset_id"]
        rec = {"dataset_id": dsid}
        for p in ["v", "a", "z", "ter"]:
            s = ds.get(f"{p}_samples_group")
            if s is None or len(s) == 0:
                rec.update({f"{p}_diff_mean": np.nan,
                            f"{p}_diff_median": np.nan,
                            f"{p}_hdi_lower": np.nan,
                            f"{p}_hdi_upper": np.nan,
                            f"{p}_excludes_zero": False})
                continue
            lo, hi = _bounds(s)
            rec.update({
                f"{p}_diff_mean": float(np.mean(s)),
                f"{p}_diff_median": float(np.median(s)),
                f"{p}_hdi_lower": lo,
                f"{p}_hdi_upper": hi,
                # NaN bounds must not count as "excludes zero".
                f"{p}_excludes_zero": bool(not np.isnan(lo) and not np.isnan(hi)
                                           and not (lo <= 0 <= hi)),
            })
        rows.append(rec)
    return pd.DataFrame(rows)

In [None]:
# Per dataset/parameter: mean, median, 99% interval bounds and the zero-exclusion flag.
hdi_df = compute_hdi_from_group(group_samples, hdi_prob=0.99)
hdi_df

In [None]:
sns.set(style="whitegrid")

# Forest plot of the group-level Δ (B − A) per dataset, one panel per parameter:
# a horizontal interval line plus a dot at the mean. Rows whose
# `{param}_excludes_zero` flag is set are drawn blue, the rest gray.
param_names = ["v", "a", "z", "ter"]
n_cols = 2
n_rows = -(-len(param_names) // n_cols)  # ceiling: only as many rows as panels need
fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 4 * n_rows))
axes = np.atleast_1d(axes).flatten()

# Tick exactly the datasets present instead of a hard-coded 1..14 range.
dataset_ids = sorted(hdi_df["dataset_id"].unique())

for i, param in enumerate(param_names):
    ax = axes[i]

    subset = hdi_df[["dataset_id",
                     f"{param}_diff_mean",
                     f"{param}_hdi_lower",
                     f"{param}_hdi_upper",
                     f"{param}_excludes_zero"]].sort_values("dataset_id")

    for _, row in subset.iterrows():
        y = row["dataset_id"]
        color = "tab:blue" if row[f"{param}_excludes_zero"] else "gray"
        ax.plot([row[f"{param}_hdi_lower"], row[f"{param}_hdi_upper"]], [y, y], color=color)
        ax.plot(row[f"{param}_diff_mean"], y, "o", color=color)

    ax.axvline(0, color="black", linestyle="--")
    ax.set_xlabel("ΔB-A")
    ax.set_ylabel("Dataset")
    ax.set_title(f"Δ{param}")
    ax.set_yticks(dataset_ids)
    ax.invert_yaxis()

# Remove grid cells left empty when the panel count does not fill the last row.
for j in range(len(param_names), len(axes)):
    fig.delaxes(axes[j])

plt.tight_layout()
plt.show()

In [None]:
# Computing BFs

# Uniform (low, high) bounds per raw model parameter. moment_matched_prior_sd_for_mu
# draws every parameter from Uniform(low, high) to simulate the prior of the group-mean Δ.
# NOTE(review): presumably these must match the priors the inference network was
# trained on (the BF cell calls them "training priors") — confirm when priors change.
PRIOR_BOUNDS = {
    "vL":   (0.0, 3.0),
    "vR":   (0.0, 3.0),
    "a":    (0.5, 2.5),
    "z":    (0.05, 0.95),
    "terL": (0.2, 0.6),
    "terR": (0.2, 0.6),
}

def subject_counts(subj_deltas):
    """Return {dataset_id: number of subject rows}, keyed in first-seen order."""
    counts = {}
    for dsid in (row["dataset_id"] for row in subj_deltas):
        if dsid in counts:
            counts[dsid] += 1
        else:
            counts[dsid] = 1
    return counts

def moment_matched_prior_sd_for_mu(S, num_draws=20000, seed=123, bounds=PRIOR_BOUNDS):
    """
    Monte-Carlo prior SD of the group-mean condition difference μ for S subjects.

    Each raw parameter is drawn from Uniform(*bounds[name]) with shape
    (num_draws, S), independently for conditions A and B. The derived quantities
    used for the posterior (v = vL + vR, a, log(1 - z), ter = terL + terR) are
    differenced B − A, averaged over the S subjects, and the sample SD (ddof=1)
    across the num_draws replicates is returned as
    {"v": sd, "a": sd, "z": sd, "ter": sd}.
    """
    rng = np.random.default_rng(seed)
    shape = (num_draws, S)

    def draw(name):
        low, high = bounds[name]
        return rng.uniform(low, high, size=shape)

    # Draw order is fixed so a given seed always reproduces the same result.
    vL_a = draw("vL")
    vR_a = draw("vR")
    vL_b = draw("vL")
    vR_b = draw("vR")
    a_a = draw("a")
    a_b = draw("a")
    z_a = draw("z")
    z_b = draw("z")
    tL_a = draw("terL")
    tR_a = draw("terR")
    tL_b = draw("terL")
    tR_b = draw("terR")

    diffs = {
        "v": (vL_b + vR_b) - (vL_a + vR_a),
        "a": a_b - a_a,
        "z": np.log(1 - z_b) - np.log(1 - z_a),
        "ter": (tL_b + tR_b) - (tL_a + tR_a),
    }
    # Average over subjects (axis 1), then SD across replicates.
    return {key: float(d.mean(axis=1).std(ddof=1)) for key, d in diffs.items()}

def build_model_prior_sd_map(subj_deltas, num_draws=20000, seed=123, bounds=PRIOR_BOUNDS):
    """
    Map each dataset_id to the prior SDs of μ for that dataset's subject count S.

    The SDs come from moment_matched_prior_sd_for_mu. That function seeds its
    own generator with `seed`, so its result depends only on
    (S, num_draws, seed, bounds). The simulation therefore runs once per
    distinct S and is reused for every dataset with that S.

    Returns {dsid: {"v": sd, "a": sd, "z": sd, "ter": sd}}, with an independent
    dict per dataset.
    """
    S_map = subject_counts(subj_deltas)
    sd_by_count = {}
    out = {}
    for dsid, S in S_map.items():
        if S not in sd_by_count:
            sd_by_count[S] = moment_matched_prior_sd_for_mu(
                S, num_draws=num_draws, seed=seed, bounds=bounds)
        # Copy so datasets sharing an S never share (and co-mutate) one dict.
        out[dsid] = dict(sd_by_count[S])
    return out


from scipy.stats import norm, gaussian_kde

def compute_bf_from_group(group_samples, prior_sd=1.0, bw_mult=1.2, floor=1e-300):
    """
    Savage–Dickey Bayes factor BF10 at μ = 0, per dataset and parameter.

    BF10 = prior density at 0 / posterior density at 0. The prior is
    Normal(0, prior_sd). The posterior density at 0 comes from a Gaussian KDE of
    the group-level samples, with bandwidth factor = Scott's factor * bw_mult.
    The posterior density is floored at `floor` so the log stays finite.

    prior_sd can be a scalar, a per-param dict {param: sd}, or a per-dataset map
    {dsid: {param: sd}}. A dict with neither the dataset nor the parameter raises
    KeyError.

    Returns a DataFrame with, per parameter, `{p}_bf_10`, `{p}_log_bf_10`
    (natural log) and `{p}_prior_sd`. All three are NaN when a dataset has no
    samples for that parameter.
    """
    def _sd_for(dsid, p):
        # Resolve the prior SD for one (dataset, parameter) from the accepted forms.
        if isinstance(prior_sd, dict):
            if dsid in prior_sd:
                return float(prior_sd[dsid][p])
            if p in prior_sd:
                return float(prior_sd[p])
            raise KeyError(f"prior_sd has no entry for dataset {dsid!r} or parameter {p!r}")
        return float(prior_sd)

    tiny = np.finfo(float).tiny
    rows = []
    for ds in group_samples:
        dsid = ds["dataset_id"]
        rec = {"dataset_id": dsid}
        for p in ["v", "a", "z", "ter"]:
            s = ds.get(f"{p}_samples_group")
            if s is None or len(s) == 0:
                # Same keys as the populated branch, so the DataFrame columns line up.
                rec.update({f"{p}_bf_10": np.nan,
                            f"{p}_log_bf_10": np.nan,
                            f"{p}_prior_sd": np.nan})
                continue

            s = np.asarray(s, np.float64)
            # Scott's rule scaled by bw_mult, without the old set_bandwidth/try-except dance.
            kde = gaussian_kde(s, bw_method=lambda k: k.scotts_factor() * bw_mult)
            post0 = max(float(kde.evaluate(0.0)[0]), floor)

            sd = max(_sd_for(dsid, p), tiny)
            prior0 = float(norm(0, sd).pdf(0.0))

            log_bf = np.log(prior0) - np.log(post0)
            bf10 = np.exp(log_bf) if log_bf < 700 else np.inf  # avoid exp overflow

            rec[f"{p}_bf_10"] = float(bf10)
            rec[f"{p}_log_bf_10"] = float(log_bf)
            rec[f"{p}_prior_sd"] = sd
        rows.append(rec)
    return pd.DataFrame(rows)

In [None]:
# Build model-based σ_μ map from training priors (depends on S per dataset)
model_sd_map = build_model_prior_sd_map(subj_deltas, num_draws=20000, seed=123)

# bw_mult=1.0 leaves the KDE's default bandwidth factor unscaled
# (compute_bf_from_group's own default is 1.2).
bf_df = compute_bf_from_group(group_samples, prior_sd=model_sd_map,
                                     bw_mult=1.0, floor=1e-300)

bf_df

In [None]:
# Codebook of valid inferences, based on Dutilh et al. (2019)
# Each entry lists, per dataset_id, the parameters treated as manipulated. The value
# names the condition ("A" or "B") in which the compared quantity is larger. It is
# scored against infer_direction, which calls "B" when the median Δ = B − A is positive.
# For z the compared quantity is log(1 - z) (see make_subject_deltas), not raw z.
# Parameters not listed count as unchanged ("0") in inference_evaluation.
ground_truth_planned_analysis = [
    {"dataset_id": 1, "manipulated": {}},
    {"dataset_id": 2, "manipulated": {"v": "B"}},
    {"dataset_id": 3, "manipulated": {"a": "B"}},
    {"dataset_id": 4, "manipulated": {"z": "B"}},
    {"dataset_id": 5, "manipulated": {"v": "B", "a": "B"}},
    {"dataset_id": 6, "manipulated": {"v": "B", "z": "B"}},
    {"dataset_id": 7, "manipulated": {"a": "B", "z": "B"}},
    {"dataset_id": 8, "manipulated": {"v": "A", "a": "B"}},
    {"dataset_id": 9, "manipulated": {"v": "A", "z": "B"}},
    {"dataset_id": 10, "manipulated": {"a": "A", "z": "B"}},
    {"dataset_id": 11, "manipulated": {"v": "A", "a": "B", "z": "B"}},
    {"dataset_id": 12, "manipulated": {"v": "B", "a": "A", "z": "B"}},
    {"dataset_id": 13, "manipulated": {"v": "B", "a": "B", "z": "A"}},
    {"dataset_id": 14, "manipulated": {"v": "B", "a": "B", "z": "B"}},
]

# Alternative codebook 1: same as planned, except:
# - v is also marked as manipulated in datasets 3, 7 and 10;
# - datasets 8, 11 and 12 are omitted, so inference_evaluation does not score them.
ground_truth_alt_analysis_1 = [
    {'dataset_id': 1, 'manipulated': {}},
    {'dataset_id': 2, 'manipulated': {'v': 'B'}},
    {'dataset_id': 3, 'manipulated': {'a': 'B', 'v': 'B'}},
    {'dataset_id': 4, 'manipulated': {'z': 'B'}},
    {'dataset_id': 5, 'manipulated': {'v': 'B', 'a': 'B'}},
    {'dataset_id': 6, 'manipulated': {'v': 'B', 'z': 'B'}},
    {'dataset_id': 7, 'manipulated': {'a': 'B', 'z': 'B', 'v': 'B'}},
    {'dataset_id': 9, 'manipulated': {'v': 'A', 'z': 'B'}},
    {'dataset_id': 10, 'manipulated': {'a': 'A', 'z': 'B', 'v': 'A'}},
    {'dataset_id': 13, 'manipulated': {'v': 'B', 'a': 'B', 'z': 'A'}},
    {'dataset_id': 14, 'manipulated': {'v': 'B', 'a': 'B', 'z': 'B'}}
]

# Alternative codebook 2: same as planned, except non-decision time ("ter") is also
# marked as manipulated in datasets 3, 5, 7, 8 and 10–14.
ground_truth_alt_analysis_2 = [
    {'dataset_id': 1, 'manipulated': {}},
    {'dataset_id': 2, 'manipulated': {'v': 'B'}},
    {'dataset_id': 3, 'manipulated': {'a': 'B', 'ter': 'B'}},
    {'dataset_id': 4, 'manipulated': {'z': 'B'}},
    {'dataset_id': 5, 'manipulated': {'v': 'B', 'a': 'B', 'ter': 'B'}},
    {'dataset_id': 6, 'manipulated': {'v': 'B', 'z': 'B'}},
    {'dataset_id': 7, 'manipulated': {'a': 'B', 'z': 'B', 'ter': 'B'}},
    {'dataset_id': 8, 'manipulated': {'v': 'A', 'a': 'B', 'ter': 'B'}},
    {'dataset_id': 9, 'manipulated': {'v': 'A', 'z': 'B'}},
    {'dataset_id': 10, 'manipulated': {'a': 'A', 'z': 'B', 'ter': 'A'}},
    {'dataset_id': 11, 'manipulated': {'v': 'A', 'a': 'B', 'z': 'B', 'ter': 'B'}},
    {'dataset_id': 12, 'manipulated': {'v': 'B', 'a': 'A', 'z': 'B', 'ter': 'A'}},
    {'dataset_id': 13, 'manipulated': {'v': 'B', 'a': 'B', 'z': 'A', 'ter': 'B'}},
    {'dataset_id': 14, 'manipulated': {'v': 'B', 'a': 'B', 'z': 'B', 'ter': 'B'}}
]

In [None]:
def infer_direction(hdi_df, bf_df, bf_threshold=3.0, params=("v","a","z","ter")):
    """
    Turn interval and Bayes-factor summaries into per-parameter directional calls.

    A condition difference is called when:
      - the interval excludes 0, AND
      - BF10 > bf_threshold.
    Its direction is the sign of the median Δ = (B - A):
      Δ > 0 -> "B", Δ < 0 -> "A", Δ = 0 -> "0".
    Evidence for no difference is recorded when:
      - the interval includes 0 AND BF10 < 1/bf_threshold.
    Everything else is inconclusive.

    "inferred_direction" is "0" for both the null and inconclusive outcomes. The
    extra "evidence" column ("difference" / "null" / "inconclusive") keeps the
    two apart. A missing or NaN exclusion flag counts as "interval includes 0",
    and a missing or NaN BF never passes either threshold.

    Returns a DataFrame with one row per (dataset_id, param).
    """
    merged = pd.merge(hdi_df, bf_df, on="dataset_id", suffixes=("_hdi", "_bf"))
    rows = []
    inv_thr = 1.0 / bf_threshold

    def _flag(value):
        # bool(np.nan) is True, so NaN/None must be mapped to False explicitly.
        return bool(value) if pd.notna(value) else False

    for _, r in merged.iterrows():
        dsid = r["dataset_id"]
        for p in params:
            hdi_excl = _flag(r.get(f"{p}_excludes_zero"))
            bf = r.get(f"{p}_bf_10")
            med = r.get(f"{p}_diff_median")
            has_bf = bool(pd.notna(bf))

            if hdi_excl and has_bf and bf > bf_threshold:
                evidence = "difference"
                inferred = "B" if med > 0 else ("A" if med < 0 else "0")
            elif (not hdi_excl) and has_bf and bf < inv_thr:
                evidence = "null"
                inferred = "0"
            else:
                evidence = "inconclusive"
                inferred = "0"

            rows.append({
                "dataset_id": dsid,
                "param": p,
                "hdi_excludes_zero": hdi_excl,
                "hdi_B-A_median": med,
                "bf_10": bf,
                "inferred_direction": inferred,
                "evidence": evidence,
            })

    return pd.DataFrame(rows)


In [None]:
# One row per (dataset, parameter) holding the call "A", "B" or "0",
# using infer_direction's default BF threshold of 3.0.
inferred_df = infer_direction(hdi_df, bf_df)

# Return the inference table
inferred_df

In [None]:
from collections import defaultdict
import numpy as np
import pandas as pd

def inference_evaluation(inference_df, ground_truth):
    """
    Score inferred directions against a ground-truth codebook.

    `inference_df` supplies "dataset_id", "param" and "inferred_direction".
    `ground_truth` is a list of {"dataset_id": i, "manipulated": {param: "A"|"B"}}.
    Parameters absent from "manipulated" are treated as truly unchanged ("0").
    Only (dataset, param) pairs that have a prediction are scored, and only
    v, a, z and ter are considered.

    Returns one row per scored parameter (in first-scored order) plus a "Total"
    row, with these rates:
      - Correct      (correct directional + correct null)
      - Miss         (called 0 but truth A/B)
      - False Alarm  (called A/B but truth 0)
      - Flip         (called A vs truth B or vice versa)
    """
    outcome_keys = ("correct_dir", "correct_null", "miss", "false_alarm", "flip", "total")
    tallies = defaultdict(lambda: dict.fromkeys(outcome_keys, 0))

    predictions = {(rec["dataset_id"], rec["param"]): rec["inferred_direction"]
                   for rec in inference_df.to_dict("records")}

    def _classify(truth, guess):
        # Outcome bucket for one (truth, guess) pair, or None if it fits no bucket.
        guess_directional = guess in ("A", "B")
        if truth == "0":
            if guess == "0":
                return "correct_null"
            return "false_alarm" if guess_directional else None
        if truth in ("A", "B"):
            if guess == truth:
                return "correct_dir"
            if guess == "0":
                return "miss"
            if guess_directional:
                return "flip"
        return None

    for entry in ground_truth:
        dataset_id = entry["dataset_id"]
        truths = entry.get("manipulated", {})
        for param in ("v", "a", "z", "ter"):
            truth = truths.get(param, "0")
            guess = predictions.get((dataset_id, param))
            if guess is None:
                continue
            bucket = tallies[param]
            outcome = _classify(truth, guess)
            if outcome is not None:
                bucket[outcome] += 1
            bucket["total"] += 1

    def _rates(label, c):
        # Convert one tally dict into a row of proportions.
        n = c["total"]
        if n == 0:
            return {"parameter": label, "Correct": np.nan, "Miss": np.nan,
                    "False Alarm": np.nan, "Flip": np.nan}
        return {"parameter": label,
                "Correct": (c["correct_dir"] + c["correct_null"]) / n,
                "Miss": c["miss"] / n,
                "False Alarm": c["false_alarm"] / n,
                "Flip": c["flip"] / n}

    per_param = [_rates(param, c) for param, c in tallies.items()]
    overall = {key: sum(c[key] for c in tallies.values()) for key in outcome_keys}
    total_row = _rates("Total", overall)

    return pd.concat([pd.DataFrame(per_param), pd.DataFrame([total_row])], ignore_index=True)


In [None]:
# Score the inferences against the planned-analysis codebook.
planned_eval = inference_evaluation(inferred_df, ground_truth_planned_analysis)
planned_eval.round(3)

In [None]:
# Score against alternative codebook 1 (datasets missing from it are not scored).
alt_eval_1 = inference_evaluation(inferred_df, ground_truth_alt_analysis_1)
alt_eval_1.round(3)

In [None]:
# Score against alternative codebook 2, which also marks "ter" as manipulated.
alt_eval_2 = inference_evaluation(inferred_df, ground_truth_alt_analysis_2)
alt_eval_2.round(3)

In [None]:
# Filter out non-decision time inferences
inferred_df_no_ter = inferred_df[inferred_df["param"] != "ter"].copy()

# Evaluate with non-decision time excluded. inference_evaluation skips pairs that
# have no prediction, so the result has no "ter" row and "Total" covers v/a/z only.
eval_no_ter = inference_evaluation(inferred_df_no_ter, ground_truth_planned_analysis)
eval_no_ter.round(3)