# FLI with compositional score matching

Before running this notebook, please unzip the data, IRF, and noise archives in the folder `problems/FLI`.

In [None]:
import os
os.environ['KERAS_BACKEND'] = 'torch'

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.stats import median_abs_deviation as mad
from sklearn.metrics import r2_score
from joblib import Parallel, delayed

from bayesflow import diagnostics
from torch.utils.data import DataLoader
from diffusion_model import HierarchicalScoreModel, SDE, TimeSeriesNetwork, ShallowSet, euler_maruyama_sampling, train_score_model, probability_ode_solving
from problems.fli import FLIProblem, FLI_Prior, generate_synthetic_data
from problems import visualize_simulation_output

In [None]:
# Pick the best available accelerator instead of hard-coding Apple's MPS backend,
# so the notebook also runs on CUDA or CPU-only machines.
if torch.backends.mps.is_available():
    torch_device = torch.device("mps")
elif torch.cuda.is_available():
    torch_device = torch.device("cuda")
else:
    torch_device = torch.device("cpu")

In [None]:

# --- prior, SDE and simulated training / validation data ---
prior = FLI_Prior()
batch_size = 64
number_of_obs = 1
# An int is used as is; any other value (e.g. a list of counts) is reduced to its maximum.
if isinstance(number_of_obs, int):
    max_number_of_obs = number_of_obs
else:
    max_number_of_obs = max(number_of_obs)

current_sde = SDE(noise_schedule='cosine', kernel_type='variance_preserving')

# Keyword arguments shared by the training and the validation problem.
shared_problem_kwargs = dict(prior=prior, sde=current_sde, number_of_obs=number_of_obs)
dataset = FLIProblem(n_data=30000, online_learning=False, **shared_problem_kwargs)
dataset_valid = FLIProblem(n_data=1000, **shared_problem_kwargs)

# Create dataloader
dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size)
dataloader_valid = DataLoader(dataset_valid, shuffle=False, batch_size=batch_size)

# Sanity check: print the shapes of entries 0, 2 and 4 of the first training batch only.
for test in dataloader:
    for entry in (0, 2, 4):
        print(test[entry].shape)
    break

In [None]:
# Define diffusion model
# Selected hyper-parameters; the values listed as alternatives were the other candidates.
n_blocks = 5                 # alternative: 6
hidden_dim = 256             # alternative: 512
hidden_dim_summary = 10      # alternatives: 14, 18, 22
split_summary_vector = True  # alternative: False

summary_net = TimeSeriesNetwork(summary_dim=hidden_dim_summary, recurrent_dim=256, input_dim=1)

global_summary_dim = hidden_dim_summary
global_summary_net = ShallowSet(dim_input=hidden_dim_summary, dim_hidden=128, dim_output=global_summary_dim)

# The global summary network is only handed to the score model when `number_of_obs` is a list.
if isinstance(number_of_obs, list):
    used_global_summary_net = global_summary_net
else:
    used_global_summary_net = None

# The model name encodes the configuration.
split_tag = "_split" if split_summary_vector else ""
model_name_prefix = f'FLI_{max_number_of_obs}_{hidden_dim_summary}_{hidden_dim}_{n_blocks}{split_tag}_{summary_net.name}_'

score_model = HierarchicalScoreModel(
    # dimensions
    input_dim_theta_global=prior.n_params_global,
    input_dim_theta_local=prior.n_params_local,
    input_dim_x_global=global_summary_dim,
    input_dim_x_local=hidden_dim_summary,
    hidden_dim=hidden_dim,
    n_blocks=n_blocks,
    max_number_of_obs=max_number_of_obs,
    # networks
    summary_net=summary_net,
    global_summary_net=used_global_summary_net,
    split_summary_vector=split_summary_vector,
    dropout_rate=0.1,
    # diffusion / training objective
    sde=current_sde,
    prior=prior,
    prediction_type='v',
    weighting_type='likelihood_weighting',
    name_prefix=model_name_prefix,
)

# make dir for plots
plot_dir = f"plots/{score_model.name}"
if not os.path.exists(plot_dir):
    os.makedirs(plot_dir)

In [None]:
# Train the score model once and cache its weights; later runs load the cached weights instead.
model_path = f"models/{score_model.name}.pt"
if not os.path.exists(model_path):
    # Create the output folders *before* the long training run, so that a missing
    # directory cannot make `torch.save` / `plt.savefig` fail after training has finished.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    plot_dir = f"plots/{score_model.name}"
    os.makedirs(plot_dir, exist_ok=True)

    # train model
    loss_history = train_score_model(score_model, dataloader, dataloader_valid=dataloader_valid, hierarchical=True,
                                     epochs=3000, device=torch_device)
    score_model.eval()
    torch.save(score_model.state_dict(), model_path)

    # plot loss history (column 0: training, column 1: validation)
    plt.figure(figsize=(16, 4), tight_layout=True)
    plt.plot(loss_history[:, 0], label='Training', color="#132a70", lw=2.0, alpha=0.9)
    plt.plot(loss_history[:, 1], label='Validation', linestyle="--", marker="o", color='black')
    plt.grid(alpha=0.5)
    plt.xlabel('Training epoch #')
    plt.ylabel('Value')
    plt.legend()
    plt.savefig(f'{plot_dir}/loss_training.png')
else:
    # weights_only=True restricts unpickling to tensors / plain containers
    score_model.load_state_dict(torch.load(model_path, map_location=torch_device, weights_only=True))
    score_model.eval()

# Validation

In [None]:
# Simulate 100 validation data sets with 32**2 local samples each (fixed seed 0).
n_local_samples = 32 * 32
n_post_samples = 100
valid_prior_global, valid_prior_local, valid_data = generate_synthetic_data(
    prior=prior,
    n_data=100,
    n_local_samples=n_local_samples,
    random_seed=0,
)
global_param_names = prior.global_param_names
local_param_names = prior.get_local_param_names(n_local_samples)
# How many observations are passed together through the score model
# (a different value, e.g. 4, could be chosen here).
score_model.current_number_of_obs = max_number_of_obs
print(valid_data.shape, score_model.current_number_of_obs)

In [None]:
t0_value = 0.4
t1_value = 0.001


def _exponential_damping(t):
    """Damping factor that decays exponentially from `t0_value` at t=0 to `t1_value` at t=1."""
    return t0_value * torch.exp(-np.log(t0_value / t1_value) * t)


# Other settings that were tried: constant damping factors of the form
# `torch.ones_like(t) * c` and an additional 'sampling_chunk_size': 512 entry.
sampling_arg = dict(size=2, damping_factor=_exponential_damping)
score_model.sde.s_shift_cosine = 0

In [None]:
# Global posterior samples for the validation data via `probability_ode_solving` (500 diffusion steps).
posterior_global_samples_valid = probability_ode_solving(
    score_model,
    valid_data,
    device=torch_device,
    diffusion_steps=500,
    n_post_samples=n_post_samples,
    sampling_arg=sampling_arg,
    verbose=True,
)

In [None]:
# Recovery and calibration (ECDF difference) diagnostics for the global parameters.
global_diag_kwargs = dict(variable_names=global_param_names)

fig = diagnostics.recovery(
    posterior_global_samples_valid, np.array(valid_prior_global), **global_diag_kwargs
)
#fig.savefig(f'plots/{score_model.name}/recovery_global.png')

fig = diagnostics.calibration_ecdf(
    posterior_global_samples_valid, np.array(valid_prior_global), difference=True, **global_diag_kwargs
)
#fig.savefig(f'plots/{score_model.name}/ecdf_global.png')

In [None]:
# Local posterior samples: the score model processes one observation at a time and is
# conditioned on the global posterior samples drawn above.
score_model.current_number_of_obs = 1
score_model.sde.s_shift_cosine = 0
posterior_local_samples_valid = euler_maruyama_sampling(
    score_model,
    valid_data,
    conditions=posterior_global_samples_valid,
    n_post_samples=n_post_samples,
    diffusion_steps=100,
    device=torch_device,
    verbose=True,
)

In [None]:
# Recovery plot for the first 12 flattened local parameters of every validation data set.
n_valid_sets = valid_data.shape[0]
n_shown_locals = 12
local_draws = posterior_local_samples_valid.reshape(n_valid_sets, n_post_samples, -1)[:, :, :n_shown_locals]
local_truth = np.array(valid_prior_local).reshape(n_valid_sets, -1)[:, :n_shown_locals]
fig = diagnostics.recovery(local_draws, local_truth, variable_names=local_param_names[:n_shown_locals])

In [None]:
# plot locals for a single grid
def _lifetimes_with_mean(raw_params):
    """Transform raw local parameters (last axis: log_tau, log_delta_tau, a) and append the mean lifetime.

    Returns an array whose last axis holds (tau, tau_2, A, A * tau + (1 - A) * tau_2).
    """
    tau_1, tau_second, fraction = prior.transform_raw_params(
        log_tau=raw_params[..., 0],
        log_delta_tau=raw_params[..., 1],
        a=raw_params[..., 2],
    )
    mean_lifetime = fraction * tau_1 + (1 - fraction) * tau_second
    return np.stack([tau_1, tau_second, fraction, mean_lifetime], axis=-1)


for i in range(14, 15):
    posterior_local_samples_valid_transf = _lifetimes_with_mean(posterior_local_samples_valid[i:i+1])
    valid_prior_local_transf = _lifetimes_with_mean(valid_prior_local[i:i+1])

    # swap axes 1 and 2, then flatten the leading axes so that axis 1 indexes the posterior draws
    draws_per_local = np.transpose(posterior_local_samples_valid_transf, (0, 2, 1, 3)).reshape(-1, n_post_samples, 4)
    truth_per_local = valid_prior_local_transf.reshape(-1, 4)

    # only the first three transformed parameters are shown (the mean lifetime is left out)
    fig = diagnostics.recovery(
        draws_per_local[:, :, :3],
        truth_per_local[:, :3],
        variable_names=[r'$\tau_1^L$', r'$\tau_2^L$', r'$A^L$'],
    )
    #fig.savefig("plots/hierarchical_simulated_recovery.pdf")
    plt.show()

# Apply the Model to Real Data

In [None]:
# Names of the global parameters as provided by the prior; used below as keys for the
# prior / posterior dictionaries of the real-data analysis.
global_param_names = prior.global_param_names

In [None]:
# load MLE binary map
mle_parameters = np.load("problems/FLI/mle_parameters.npy")
# channels 0, 1, 2 are combined as weight * first + (1 - weight) * second, with channel 2 as weight
mle_first, mle_second, mle_weight = (mle_parameters[:, :, channel] for channel in range(3))
tau_mean = mle_weight * mle_first + (1 - mle_weight) * mle_second
# first three channels plus the weighted mean as a fourth channel
mle_estimates = np.concatenate((mle_parameters[:, :, :3], tau_mean[..., np.newaxis]), axis=-1)

In [None]:
grid_data = 512  # side length (in pixels) of the analysed square window, e.g. 32 for a quick test
global_param_names = prior.global_param_names
local_param_names = prior.get_local_param_names(grid_data * grid_data)

x_offset = 0  # e.g. 225
y_offset = 0  # e.g. 245
# The same spatial window is applied to every per-pixel array used in this cell.
window = (slice(x_offset, x_offset + grid_data), slice(y_offset, y_offset + grid_data))

# NOTE(review): these files are read from `problems/FLI-all/`, while the MLE map is read
# from `problems/FLI/` — confirm that both folders exist after unzipping.
binned_data = np.load('problems/FLI-all/exp_binned_data.npy')[window]
binned_data = binned_data.reshape(1, grid_data * grid_data, 256, 1)

# NOTE(review): here the window is applied to axes 1 and 2 (not 0 and 1 as above); the reshape
# below presumably relies on the leading axis having length 1 — TODO confirm the file layout.
data = np.load('problems/FLI-all/final_Data.npy')[(slice(None),) + window]
data = data.reshape(1, grid_data * grid_data, 256, 1)
cut_off = 17
# keep pixels whose summed signal (over axis 2) exceeds the cut-off ...
binary_mask = (np.sum(data, axis=2, keepdims=True) > cut_off)
# ... and whose first two MLE channels are non-zero (binary mask sets estimates to 0).
# The MLE map is cropped with the same window, otherwise the flattened masks only line up
# for the full, un-shifted frame. Any later per-pixel use of `mle_estimates` needs the same crop.
mle_window = mle_estimates[window]
binary_mask = (
    binary_mask.flatten()
    & (mle_window[:, :, 0] != 0).flatten()
    & (mle_window[:, :, 1] != 0).flatten()
)
binary_mask = binary_mask.reshape(1, binned_data.shape[1], 1, 1)

# Normalise every pixel's curve by its maximum; masked-out pixels are divided by 1 instead.
real_data = binned_data
norm = np.max(real_data, axis=2, keepdims=True)
norm[~binary_mask] = 1
real_data = real_data / norm

# Overview image: summed signal per pixel
plt.imshow(np.sum(data[0], axis=(1,2)).reshape(grid_data, grid_data), cmap='jet')
plt.colorbar()
plt.xticks([])
plt.yticks([])
plt.show()

In [None]:
n_post_samples = 100


def _constant_damping(t):
    """Constant damping factor (1e-10 + 0.0001) with the same shape as `t`."""
    return torch.ones_like(t) * 1e-10 + 0.0001


sampling_arg = dict(
    size=2,
    damping_factor=_constant_damping,
    # integer weights: 1 for pixels inside the mask, 0 for all others
    sampling_weights=binary_mask.flatten() * 1,
)
score_model.sde.s_shift_cosine = 0

In [None]:
# Global posterior samples for the real data via `euler_maruyama_sampling` (300 diffusion steps).
posterior_global_samples_real = euler_maruyama_sampling(
    score_model,
    real_data,
    device=torch_device,
    diffusion_steps=300,
    n_post_samples=n_post_samples,
    sampling_arg=sampling_arg,
    verbose=True,
)

In [None]:
# Per global parameter: prior draws (from the validation simulations) and the posterior
# draws of the first (index 0) real data set.
prior_dict = {name: valid_prior_global[:, k] for k, name in enumerate(global_param_names)}
posterior_dict = {name: posterior_global_samples_real[0, :, k] for k, name in enumerate(global_param_names)}

# Entries 0, 2 and 4 of `global_param_names` are passed on as log_tau, log_delta_tau and a.
log_tau_key, log_delta_tau_key, a_key = (global_param_names[k] for k in (0, 2, 4))

tau, tau_2, A = prior.transform_raw_params(
    log_tau=prior_dict[log_tau_key],
    log_delta_tau=prior_dict[log_delta_tau_key],
    a=prior_dict[a_key],
)
prior_tranf_dict = {r'$\tau$': tau, r'$\tau_2$': tau_2, r'$A$': A}

tau, tau_2, A = prior.transform_raw_params(
    log_tau=posterior_dict[log_tau_key],
    log_delta_tau=posterior_dict[log_delta_tau_key],
    a=posterior_dict[a_key],
)
posterior_tranf_dict = {r'$\tau$': tau, r'$\tau_2$': tau_2, r'$A$': A}

# posterior medians of the transformed parameters
for label, draws in posterior_tranf_dict.items():
    print(label, np.median(draws))

In [None]:
# Pair plots of posterior vs. prior: first the raw parameters, then the transformed ones.
# (The figures could be saved as 'plots/real_data_global_posterior.pdf' and
#  'plots/real_data_global_posterior_transf.pdf' via fig.savefig.)
for posterior_draws, prior_draws in (
    (posterior_dict, prior_dict),
    (posterior_tranf_dict, prior_tranf_dict),
):
    fig = diagnostics.pairs_posterior(posterior_draws, priors=prior_draws)

In [None]:
# Median over axis 0 of the first entry of the real-data global posterior samples
# (displayed as the cell output).
np.median(posterior_global_samples_real[0], axis=0)

In [None]:
score_model.sde.s_shift_cosine = 0
valid_pixels = binary_mask.flatten()
# Pixels outside the mask are not sampled and stay NaN.
posterior_local_samples_real = np.full((1, n_post_samples, real_data.shape[1], 3), np.nan)
posterior_local_samples_real[:, :, valid_pixels] = euler_maruyama_sampling(
    score_model,
    real_data[:, valid_pixels],
    conditions=posterior_global_samples_real,
    n_post_samples=n_post_samples,
    diffusion_steps=100,
    device=torch_device,
    verbose=True,
)

In [None]:
# Cache for the local posterior samples: reuse the stored samples when the file exists,
# otherwise store the freshly sampled ones (previously a missing file raised FileNotFoundError).
# Delete the file to force the newly sampled values to be used.
local_samples_path = 'fli_real_local_samples.npy'
if os.path.exists(local_samples_path):
    posterior_local_samples_real = np.load(local_samples_path)
else:
    np.save(local_samples_path, posterior_local_samples_real)

In [None]:
def _raw_param_images(param_index):
    """Raw local parameter `param_index` of data set 0, reshaped to (n_post_samples, grid_data, grid_data)."""
    return posterior_local_samples_real[0, :, :, param_index].reshape(n_post_samples, grid_data, grid_data)


tau, tau_2, A = prior.transform_raw_params(
    log_tau=_raw_param_images(0),
    log_delta_tau=_raw_param_images(1),
    a=_raw_param_images(2),
)
tau_mean = A * tau + (1-A) * tau_2
# last axis: (tau, tau_2, A, tau_mean)
ps = np.stack([tau, tau_2, A, tau_mean], axis=-1)
transf_local_param_names = [r'$\tau_1^L$', r'$\tau_2^L$', r'$A^L$', r'$\tau^\text{mean}$']

# Median and median absolute deviation over the posterior draws (axis 0).
med = np.median(ps, axis=0)
posterior_mad = mad(ps, axis=0)
for summary, label in ((med, 'Posterior Median '), (posterior_mad, 'Posterior MAD ')):
    visualize_simulation_output(
        summary,
        mask=binary_mask.reshape(grid_data, grid_data),
        title_prefix=[label + p for p in transf_local_param_names],
        cmap='jet',
        scales=[(0,1), (0, 2), (0,1), (0, 2)],
        add_scale_bar=False,
    )

In [None]:
import matplotlib as mpl
mpl.rcParams.update({
    "font.size": 12,            # Base font size
    "axes.titlesize": 12,       # Axes title
    "axes.labelsize": 12,       # Axes labels
    "xtick.labelsize": 10,      # Tick labels
    "ytick.labelsize": 10,
})

# Hand-picked pixels as [row, column], one per panel.
example_pixels = [[11, 0], [409, 362]]
mask_image = binary_mask.reshape(grid_data, grid_data)
decay_image = real_data.reshape(grid_data, grid_data, 256)

fig, axis = plt.subplots(2, 1, figsize=(5, 4), layout='constrained', sharex=True, sharey=True)
axis = axis.flatten()
for i, ax in enumerate(axis):
    pixel_ids = example_pixels[i]
    # The pixels are fixed, so re-drawing cannot help: fail loudly instead of looping forever
    # when a chosen pixel lies outside the mask.
    if not mask_image[pixel_ids[0], pixel_ids[1]]:
        raise ValueError(f'Pixel {pixel_ids} is excluded by binary_mask; choose a pixel inside the mask.')
    print(pixel_ids)
    # one simulated decay per posterior draw of this pixel
    simulations = np.array([
        prior.simulator.decay_gen_single(
            tau_L=tau[post_index, pixel_ids[0], pixel_ids[1]],
            tau_L_2=tau_2[post_index, pixel_ids[0], pixel_ids[1]],
            A_L=A[post_index, pixel_ids[0], pixel_ids[1]]
        ) for post_index in range(tau.shape[0])
    ])

    ax.plot(decay_image[pixel_ids[0], pixel_ids[1]], label='data', color='black')
    ax.plot(np.median(simulations, axis=0), label='posterior median', alpha=0.8, color='red')
    # central 95% band of the simulated decays
    alpha = 0.05
    ax.fill_between(
        np.arange(256),
        np.percentile(simulations, 100*(alpha/2), axis=0),
        np.percentile(simulations, 100*(1-alpha/2), axis=0),
        color='red', alpha=0.3, label='posterior 95% CI'
    )
    if i == 0:
        ax.legend(labels=[r'Real Data', r'Posterior Median', 'Posterior 95% CI'], ncol=1, loc='upper right')
    if i == len(axis) - 1:
        # NOTE(review): the x values are the indices 0..255, while the label says seconds — confirm units.
        ax.set_xlabel(r'Time [s]')
    ax.set_ylabel('Normalized\nPhoton Count')

#plt.savefig(f'plots/real_data_fit2.pdf', transparent=True, bbox_inches='tight')
plt.show()

In [None]:
fig, axis = plt.subplots(2, 4, sharex=True, sharey=True, figsize=(10, 4), layout='constrained')
axis = axis.flatten()
n_panels = len(axis)
valid_pixel_image = binary_mask.reshape(grid_data, grid_data)
decay_curves = real_data.reshape(grid_data, grid_data, 256)
for i, ax in enumerate(axis):
    # draw random pixels until one lies inside the mask (only plot meaningful data)
    pixel_ids = [np.random.randint(0, grid_data), np.random.randint(0, grid_data)]
    while not valid_pixel_image[pixel_ids[0], pixel_ids[1]]:
        pixel_ids = [np.random.randint(0, grid_data), np.random.randint(0, grid_data)]
    row, col = pixel_ids

    # simulated decays for the first 25 posterior draws of this pixel
    simulations = np.array([
        prior.simulator.decay_gen_single(tau_L=tau[k, row, col], tau_L_2=tau_2[k, row, col], A_L=A[k, row, col])
        for k in range(25)
    ])
    real_pixel = decay_curves[row, col]

    ax.plot(real_pixel, label='data', color='black')
    ax.plot(np.median(simulations, axis=0), label='posterior median', alpha=0.8, color='red')
    if i >= n_panels - 4:  # the last four panels
        ax.set_xlabel(r'Time [s]')
    if i % 4 == 0:
        ax.set_ylabel('Normalized\nPhoton Count')

    # central percentile bands and the fraction of data points falling inside each of them
    bands = {}
    coverages = {}
    for level in (0.01, 0.05, 0.1, 0.2):
        lo = np.percentile(simulations, 100*(level/2), axis=0)
        hi = np.percentile(simulations, 100*(1-level/2), axis=0)
        bands[level] = (lo, hi)
        coverages[level] = ((real_pixel >= lo) & (real_pixel <= hi)).mean()

    alpha = 0.05
    band_lo, band_hi = bands[alpha]
    ax.fill_between(np.arange(256), band_lo, band_hi, color='red', alpha=0.3, label='posterior 95% CI')

    ax.text(0.4, 0.95, f'Coverage: {coverages[alpha]:.2f}', transform=ax.transAxes,
            fontsize=10, verticalalignment='top', bbox=dict(facecolor='white', alpha=0.8, edgecolor='none'))

fig.legend(labels=[r'Data', r'Posterior Median', r'Posterior 95% CI'], bbox_to_anchor=(0.5, -0.1),
           ncol=3, loc='lower center')
#plt.savefig(f'plots/real_data_fit_more.pdf', transparent=True, bbox_inches='tight')
plt.show()

In [None]:
n_post_samples_sim = 20  # number of posterior draws simulated per pixel
# Crop the MLE maps to the analysed window: they are indexed below with window-relative
# [pixel_i, pixel_j], exactly like tau / tau_2 / A. For the full, un-shifted frame this is a no-op.
mle_window = mle_estimates[x_offset:x_offset + grid_data, y_offset:y_offset + grid_data]
tau_mle = mle_window[:, :, 0]
tau_2_mle = mle_window[:, :, 1]
A_mle = mle_window[:, :, 2]

@delayed
def wrapper(pixel_i):
    """Simulate decays for all valid pixels of image row `pixel_i`.

    Returns a pair (posterior simulations, MLE simulations) of shapes
    (n_post_samples_sim, 1, grid_data, 256) and (1, 1, grid_data, 256).
    Columns that are excluded by `binary_mask` stay NaN.
    """
    row_is_valid = binary_mask.reshape(grid_data, grid_data)[pixel_i]
    _simulations = np.full((n_post_samples_sim, 1, grid_data, 256), np.nan)
    _simulations_mle = np.full((1, 1, grid_data, 256), np.nan)
    for pixel_j in range(grid_data):
        if row_is_valid[pixel_j]:
            # one decay per posterior draw
            posterior_decays = [
                prior.simulator.decay_gen_single(
                    tau_L=tau[draw, pixel_i, pixel_j],
                    tau_L_2=tau_2[draw, pixel_i, pixel_j],
                    A_L=A[draw, pixel_i, pixel_j],
                )
                for draw in range(n_post_samples_sim)
            ]
            _simulations[:, 0, pixel_j, :] = posterior_decays
            # a single decay from the MLE point estimate
            mle_decay = prior.simulator.decay_gen_single(
                tau_L=tau_mle[pixel_i, pixel_j],
                tau_L_2=tau_2_mle[pixel_i, pixel_j],
                A_L=A_mle[pixel_i, pixel_j],
            )
            _simulations_mle[:, 0, pixel_j, :] = [mle_decay]
    return _simulations, _simulations_mle

# One task per image row; every task returns (posterior simulations, MLE simulations) for its row.
row_results = Parallel(n_jobs=10, verbose=1)(wrapper(row) for row in range(grid_data))
posterior_rows, mle_rows = zip(*row_results)
# Stack the rows along axis 1.
simulations_mle = np.concatenate(mle_rows, axis=1)  # shape: (1, grid_data, grid_data, 256)
simulations = np.concatenate(posterior_rows, axis=1)  # shape: (n_post_samples_sim, grid_data, grid_data, 256)

In [None]:
def ppc_chi2(real_flat, simulations):
    """
    Chi-square-type posterior-predictive check, evaluated separately for every group/pixel.

    real_flat:   array of shape (G, T), observed values per group/pixel
    simulations: array of shape (S, G, T), posterior-predictive draws

    Returns:
      D_obs: array of shape (G,), discrepancy of the observed data
      p_ppc: array of shape (G,), fraction of draws whose discrepancy is >= D_obs
    """
    # Reference curve (median over draws) and per-point variance over draws.
    center = np.median(simulations, axis=0)                              # (G, T)
    # In the tails the simulations can be almost constant: floor the variance
    # at 1e-10 to avoid division by zero.
    spread = np.maximum(np.var(simulations, axis=0, ddof=1), 1e-10)     # (G, T)

    def _discrepancy(values):
        # variance-weighted squared distance to the reference curve, summed over time
        return ((values - center) ** 2 / spread).sum(axis=-1)

    D_obs = _discrepancy(real_flat)        # (G,)
    D_rep = _discrepancy(simulations)      # (S, G)

    # share of replicated discrepancies that are at least as large as the observed one
    p_ppc = (D_rep >= D_obs).mean(axis=0)  # (G,)
    return D_obs, p_ppc

In [None]:
def compute_pixel_metrics_ppc(real_flat, simulations, alphas=(0.01, 0.05, 0.1, 0.2)):
    """
    Summarise how well simulated curves reproduce the observed curves.

    real_flat:   shape (G, T), observed values per group/pixel
    simulations: shape (S, G, T), S simulated draws per group/pixel
    alphas:      levels; each alpha defines the central (1 - alpha) percentile
                 interval of the simulations

    Returns:
      - if S == 1: only `r2s`, because intervals and the PPC need several draws
      - otherwise the tuple (r2s, coverages, p_ppc) with
          r2s:       array of length G, R² between the data and the median simulation
          coverages: dict[alpha] -> float, fraction of all (pixel, time) points that
                     lie inside the interval (a single number, not one value per pixel)
          p_ppc:     array of length G, p-values returned by `ppc_chi2`
    """
    S, G, T = simulations.shape

    # 1) R² per pixel (sklearn), data vs. median over the draws
    print('Computing R²')
    median_pred = np.median(simulations, axis=0)  # (G, T)
    r2s = np.zeros(G)
    for g, (observed, predicted) in enumerate(zip(real_flat, median_pred)):
        r2s[g] = r2_score(observed, predicted)

    if S == 1:
        # compute only R2, everything else is not defined
        return r2s

    # 2) Empirical coverage of the central percentile intervals
    print('Computing coverages')
    coverages = {}
    for alpha in alphas:
        lo = np.percentile(simulations, 100*(alpha/2), axis=0)        # shape (G, T)
        hi = np.percentile(simulations, 100*(1-alpha/2), axis=0)
        covered = (real_flat >= lo) & (real_flat <= hi)
        coverages[alpha] = covered.mean()                     # mean over time and data points

    # 3) Posterior-predictive check (the observed discrepancies are not needed here)
    print('Computing posterior‐predictive check')
    _, ppc = ppc_chi2(real_flat, simulations)
    return r2s, coverages, ppc

In [None]:
# reshape real data and simulations: flatten the pixel grid and keep only pixels inside the mask
valid_pixels = binary_mask.flatten()
n_pixels = grid_data*grid_data
real_flat = real_data[0].reshape(n_pixels, 256)[valid_pixels]                              # (G_active, T)
sim_flat = simulations.reshape(n_post_samples_sim, n_pixels, 256)[:, valid_pixels, :]     # (S, G_active, T)
sim_flat_mle = simulations_mle.reshape(1, n_pixels, 256)[:, valid_pixels, :]              # (1, G_active, T)

In [None]:
# Compute pixel metrics
# full metrics for the posterior simulations; only R² for the single MLE simulation
r2s, coverages, ppc = compute_pixel_metrics_ppc(real_flat, sim_flat)
r2s_mle = compute_pixel_metrics_ppc(real_flat, sim_flat_mle)
# 2/3 of the gates are sufficient to capture the whole decay, if it is in the range of 1 ns
n_gates_short = int(256/3*2)
r2s_short, coverages_short, ppc_short = compute_pixel_metrics_ppc(
    real_flat[:, :n_gates_short],
    sim_flat[:, :, :n_gates_short],
)

In [None]:
# Print summary:
def _print_fit_summary(r2_values, coverage_by_alpha, p_values):
    """Print the empirical coverages, then the R² and PPC summaries of one fit."""
    for a, empirical in coverage_by_alpha.items():
        print(f"Nominal CI = {(1-a)*100:.1f}% → empirical coverage = {empirical:.3f}")
    print(f"Mean R² over valid pixels:           {r2_values.mean():.3f} ({r2_values.std():.3f})")
    print(f'Min/Max R²:                         {np.min(r2_values):.3f}, {np.max(r2_values):.3f}')
    print(f"Mean PPC over valid pixels:   {p_values.mean():.3f} ({p_values.std():.3f})")


print(f"Mean R² over valid pixels (MLE):           {r2s_mle.mean():.3f} ({r2s_mle.std():.3f})")
print(f'Min/Max R²  (MLE):                         {np.min(r2s_mle):.3f}, {np.max(r2s_mle):.3f}')
print('\n')

_print_fit_summary(r2s, coverages, ppc)
print('\n')

print('Using only first 2/3 of the gates:')
_print_fit_summary(r2s_short, coverages_short, ppc_short)

# nominal and empirical coverage in percent for the 2/3-gate evaluation
c_keys = (1-np.array(list(coverages_short.keys())))*100
c_vals = np.array(list(coverages_short.values()))*100

In [None]:
def _as_image(pixel_values):
    """Scatter values of the valid pixels into a (grid_data, grid_data) image; other pixels are NaN."""
    flat = np.full(grid_data*grid_data, np.nan)
    flat[binary_mask.flatten()] = pixel_values
    return flat.reshape(grid_data, grid_data)


r2_map = _as_image(r2s)
r2_map_mle = _as_image(r2s_mle)
p_ppc_map = _as_image(ppc)
p_ppc_map_short = _as_image(ppc_short)

# NaN pixels and values below the colour range are drawn in black
cmap = plt.get_cmap('jet').copy()
cmap.set_bad(color="black")
cmap.set_under(color='black')

fig, ax = plt.subplots(1, 2, layout='constrained', figsize=(6, 2))
for panel, (image, title) in zip(ax, ((r2_map, r'Hierarchical Bayesian'), (r2_map_mle, r'MLE'))):
    shown = panel.imshow(image, cmap=cmap, vmin=0.8, vmax=1)
    fig.colorbar(shown, ax=panel).set_label(r'$R^2$')
    panel.set_title(title)
    if panel is not ax[0]:
        panel.sharex(ax[0])
    panel.set_xticks([])
    panel.set_yticks([])

# Define the pixel-to-micron ratio
microns_per_pixel = 135./512
scale_bar_length_um = 10
scale_bar_length_px = int(scale_bar_length_um / microns_per_pixel)

# Position the scale bar in the upper-right corner
x0 = r2_map_mle.shape[1] - scale_bar_length_px - 40
y0 = 40

# Add the scale bar
for panel in ax:
    panel.hlines(y=y0, xmin=x0, xmax=x0 + scale_bar_length_px, color='white', linewidth=2)

# Add label
for panel in ax:
    panel.text(x0 + scale_bar_length_px / 2, y0 - 5, f'{scale_bar_length_um} µm',
               color='white', ha='center', va='bottom', fontsize=8)

#plt.savefig(f'plots/real_data_fit_diagnostics.pdf', transparent=True, bbox_inches='tight')
plt.show()

In [None]:
# Left: map of the p-values computed on the first 2/3 of the gates; right: their histogram.
pvalue_label = r'Posterior Predictive $p$‑value'
panel_title = r'Hierarchical Bayesian'

fig, ax = plt.subplots(1, 2, layout='constrained', figsize=(7, 3))
map_ax, hist_ax = ax

im2 = map_ax.imshow(p_ppc_map_short, cmap=cmap)
c2 = fig.colorbar(im2, ax=map_ax)
c2.set_label(pvalue_label)
map_ax.set_title(panel_title)
map_ax.set_xticks([])
map_ax.set_yticks([])

hist_ax.hist(ppc_short, density=True)
hist_ax.set_title(panel_title)
hist_ax.set_xlabel(pvalue_label)
hist_ax.set_ylabel(r'Density')
#plt.savefig(f'plots/real_data_fit_diagnostics_appendix.pdf', transparent=True, bbox_inches='tight')
plt.show()