In [None]:
from pathlib import Path
import os

os.environ['KERAS_BACKEND'] = 'jax'
#os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import median_abs_deviation as mad

import keras
import bayesflow as bf
from tqdm import tqdm

from experiments.problems.fli import FLI_Prior
from experiments.problems import visualize_simulation_output

In [None]:
# Parameterization of the local (per-pixel) parameters used by FLI_Prior.
# NOTE(review): the adapter and real-data cells below hard-code the keys
# log_tau_L / log_delta_tau_L / a_l, which presumably belong to 'difference' —
# confirm (and update those cells) before switching to 'ratios'.
parameterization = 'difference'  # alternative: 'ratios'
prior = FLI_Prior(parameterization)

def prior_bf(n_local_samples=1):
    """Draw local parameters for the BayesFlow simulator.

    Returns the raw-parameter dict from ``prior._sample_local`` after merging in
    (via ``dict.update``) the transformed parameters it returns alongside, so both
    raw and transformed keys are present. The global draw is discarded.
    """
    # NOTE(review): the global draw is never used; the call is presumably kept for its
    # effect on the prior's internal state / RNG stream — confirm before removing it.
    _discarded_global, _ = prior._sample_global()
    merged, transformed = prior._sample_local(n_local_samples=n_local_samples)
    merged.update(transformed)  # in-place merge into the raw-sample dict
    return merged

def model_bf(tau_L, tau_L_2, A_L):
    """Forward model: one simulated decay curve for the given raw local parameters.

    Returns ``{"sim": curve}``, where ``curve`` is the flattened output of
    ``prior.simulator.decay_gen_single``.
    """
    decay = prior.simulator.decay_gen_single(tau_L, tau_L_2, A_L)
    return {"sim": decay.flatten()}

In [None]:
# Compose the prior draw and the forward model into one BayesFlow simulator.
# NOTE(review): presumably make_simulator passes the keys returned by `prior_bf`
# to `model_bf` by keyword name (tau_L, tau_L_2, A_L) — confirm in the BayesFlow docs.
simulator = bf.simulators.make_simulator([prior_bf, model_bf])

In [None]:
%%time
# Smoke test: draw a batch of 128 simulations to check the pipeline runs and to time it.
# The result `test` is not read by any other cell in this notebook.
test = simulator.sample(128)

In [None]:
# Map simulator output dicts to network inputs. Step order matters: each call
# transforms the output of the previous one.
adapter = (
    bf.adapters.Adapter()
    # Raw (untransformed) local parameters from `prior_bf` are not learned directly.
    .drop('tau_L')
    .drop('tau_L_2')
    .drop('A_L')
    # Commented-out alternative: also drop the transformed keys.
    #.drop('log_tau_L')
    #.drop('log_delta_tau_L')
    #.drop('a_l')
    .to_array()
    .convert_dtype(from_dtype="float64", to_dtype="float32")
    # NOTE(review): presumably adds a trailing feature axis to each decay curve — confirm.
    .as_time_series("sim")
    # Commented-out alternative targets, presumably for the 'ratios' parameterization — confirm.
    #.concatenate(["log_r_L", "log_s_L", "a_l"], into="inference_variables")
    # Posterior targets; these keys must match those read back in the real-data sampling cell.
    .concatenate(["log_tau_L", "log_delta_tau_L", "a_l"], into="inference_variables")
    # Commented-out alternative: global (`_G`) parameters as targets.
    #.concatenate(['log_tau_G', 'log_sigma_tau_G', 'log_delta_tau_G', 'log_delta_sigma_tau_G', 'a_mean', 'a_log_std'], into="inference_variables")
    # The observed decay curves condition the summary network and are standardized.
    .rename("sim", "summary_variables")
    .standardize(include="summary_variables")
)

In [None]:
# Diffusion-model posterior on top of a time-series summary network; the workflow
# also holds the simulator so it can generate training data itself.
workflow = bf.BasicWorkflow(
    adapter=adapter,
    inference_network=bf.networks.DiffusionModel(),
    summary_network=bf.networks.TimeSeriesNetwork(recurrent_dim=256),
    simulator=simulator,
)
# The checkpoint name includes the parameterization, so the variants do not overwrite each other.
filepath = Path("bf_checkpoints", f"vanilla_fli_time_series_{parameterization}.keras")
os.makedirs(filepath.parent, exist_ok=True)

In [None]:
# Online training: batches are presumably simulated on the fly from `simulator`
# (100 epochs x 100 batches of 128); validation_data=1000 presumably requests 1000
# simulated validation sets — confirm against the BayesFlow docs.
history = workflow.fit_online(
    epochs=100,
    num_batches_per_epoch=100,
    batch_size=128,
    validation_data=1000,
)
# Persist the trained approximator to the checkpoint path defined above.
workflow.approximator.save(filepath=filepath)

In [None]:
#workflow.approximator = keras.saving.load_model(filepath)

In [None]:
# Default BayesFlow diagnostics on 300 test data sets.
# NOTE(review): `difference=True` presumably draws the calibration ECDF as a difference
# from the uniform reference — confirm in the BayesFlow docs.
workflow.plot_default_diagnostics(test_data=300, calibration_ecdf_kwargs={'difference': True})

 # Apply the Model to Real Data

In [None]:
from sklearn.metrics import r2_score
from joblib import Parallel, delayed

In [None]:
# Names of the prior's global parameters.
# NOTE(review): re-assigned identically in the real-data loading cell and not read
# anywhere else in this notebook; candidate for removal.
global_param_names = prior.global_param_names

In [None]:
# load MLE binary map
mle_parameters = np.load("problems/FLI/mle_parameters.npy")
tau_mean = mle_parameters[:, :, 2] * mle_parameters[:, :, 0] + (1 - mle_parameters[:, :, 2]) * mle_parameters[:, :, 1]
mle_estimates = np.concatenate((mle_parameters[:, :, :3], tau_mean[..., np.newaxis]), axis=-1)

In [None]:
# Load a (grid_data x grid_data) window of the experimental data starting at
# (x_offset, y_offset), build the pixel-validity mask and peak-normalise each decay.
grid_data = 512
global_param_names = prior.global_param_names
local_param_names = prior.get_local_param_names(grid_data * grid_data)

x_offset = 0
y_offset = 0
binned_data = np.load('problems/FLI-all/exp_binned_data.npy')[x_offset:x_offset + grid_data, y_offset:y_offset + grid_data]
binned_data = binned_data.reshape(1, grid_data * grid_data, 256, 1)

# NOTE(review): assumes final_Data.npy has a leading extra axis followed by the two
# spatial axes, so the reshape below yields one 256-bin series per pixel — confirm.
data = np.load('problems/FLI-all/final_Data.npy')[:, x_offset:x_offset + grid_data, y_offset:y_offset + grid_data]
data = data.reshape(1, grid_data * grid_data, 256, 1)

# Valid pixels need enough total signal in `data` AND non-zero MLE lifetimes
# (pixels excluded by the MLE's own mask have their estimates set to 0).
# The MLE maps are cropped to the SAME window as the data; using the full map
# would misalign pixels whenever the offsets are non-zero or grid_data is smaller
# than the full image.
cut_off = 17
mle_window = mle_estimates[x_offset:x_offset + grid_data, y_offset:y_offset + grid_data]
binary_mask = (
    (np.sum(data, axis=2, keepdims=True) > cut_off).flatten()
    & (mle_window[:, :, 0] != 0).flatten()
    & (mle_window[:, :, 1] != 0).flatten()
)
binary_mask = binary_mask.reshape(1, binned_data.shape[1], 1, 1)

# Scale each valid pixel's decay to a peak of 1; masked pixels are left unscaled.
real_data = binned_data
norm = np.max(real_data, axis=2, keepdims=True)
norm[~binary_mask] = 1
real_data = real_data / norm

fig, ax = plt.subplots()
im = ax.imshow(np.sum(data[0], axis=(1, 2)).reshape(grid_data, grid_data), cmap='jet')
fig.colorbar(im, ax=ax, label='sum over time bins')
ax.set_xticks([])
ax.set_yticks([])
plt.show()

In [None]:
num_samples = 100
chunk_size = 100

# Sample the amortized posterior for every pixel of the grid, `chunk_size` pixels per call.
# Masked pixels are included so the flat result can be reshaped back to the image.
posterior_samples_real = {'log_tau_L': [], 'log_delta_tau_L': [], 'a_l': []}
for chunk_start in tqdm(range(0, grid_data**2, chunk_size)):
    chunk_stop = min(chunk_start + chunk_size, grid_data**2)
    chunk_draws = workflow.sample(
        conditions={'sim': real_data[0, chunk_start:chunk_stop, :, 0]},
        num_samples=num_samples,
    )
    for name, collected in posterior_samples_real.items():
        collected.append(chunk_draws[name])

# Stitch the per-chunk arrays back together along the pixel axis.
posterior_samples_real = {
    name: np.concatenate(collected) for name, collected in posterior_samples_real.items()
}

In [None]:
# Reshape each posterior array to (num_samples, grid_data, grid_data) and map it back
# to the raw parameters with the prior's inverse transform.
tau, tau_2, A = prior.transform_raw_params(
    log_tau=posterior_samples_real['log_tau_L'].T[0].reshape(num_samples, grid_data, grid_data),
    log_delta_tau=posterior_samples_real['log_delta_tau_L'].T[0].reshape(num_samples, grid_data, grid_data),
    a=posterior_samples_real['a_l'].T[0].reshape(num_samples, grid_data, grid_data),
)
# Parameters on a trailing axis (assuming transform_raw_params preserves the input shape).
ps = np.stack((tau, tau_2, A), axis=-1)
transf_local_param_names = [r'$\tau_1^L$', r'$\tau_2^L$', r'$A^L$']

# Per-pixel posterior median and (unscaled) median absolute deviation over the draws.
med = np.median(ps, axis=0)
posterior_mad = mad(ps, axis=0)
pixel_mask_2d = binary_mask.reshape(grid_data, grid_data)
visualize_simulation_output(med,
                            mask=pixel_mask_2d,
                            title_prefix=[f'Posterior Median {p}' for p in transf_local_param_names],
                            cmap='turbo', scales=[(0, 1), (0, 2), (0, 1)])
visualize_simulation_output(posterior_mad,
                            mask=pixel_mask_2d,
                            title_prefix=[f'Posterior MAD {p}' for p in transf_local_param_names],
                            cmap='turbo', scales=[(0, 1), (0, 2), (0, 1)])

In [None]:
# Posterior predictive check for a few randomly chosen valid pixels: the observed
# (normalised) decay vs. the median and central 95% band of decays simulated from
# every posterior draw at that pixel.
n_panels = 5
fig, axis = plt.subplots(1, n_panels, figsize=(10, 3), tight_layout=True, sharex=True, sharey=True)
axis = axis.flatten()

pixel_mask_2d = binary_mask.reshape(grid_data, grid_data)
valid_rows, valid_cols = np.nonzero(pixel_mask_2d)
if valid_rows.size == 0:
    raise ValueError('binary_mask selects no pixels; nothing to plot')
# Pick directly among valid pixels (a rejection loop would never end on an empty mask),
# without repeats whenever there are enough valid pixels.
picked = np.random.choice(valid_rows.size, size=n_panels, replace=valid_rows.size < n_panels)
real_grid = real_data.reshape(grid_data, grid_data, 256)

for ax, pick in zip(axis, picked):
    row, col = valid_rows[pick], valid_cols[pick]
    simulations = np.array([
        prior.simulator.decay_gen_single(
            tau_L=tau[post_index, row, col],
            tau_L_2=tau_2[post_index, row, col],
            A_L=A[post_index, row, col],
        ) for post_index in range(tau.shape[0])
    ])

    ax.plot(real_grid[row, col], label='data')
    ax.plot(np.median(simulations, axis=0), label='posterior median', alpha=0.8, color='orange')
    ax.fill_between(
        np.arange(simulations.shape[1]),
        np.quantile(simulations, 0.025, axis=0),
        np.quantile(simulations, 0.975, axis=0),
        alpha=0.4,
        color='orange',
        label='posterior 95% CI',
    )
    ax.set_xlabel('Time')
axis[0].set_ylabel('Normalized Photon Count')
# Take handles from one panel so legend entries are tied to the labelled artists
# (a labels-only fig.legend pairs labels with artists purely by order).
handles, labels = axis[0].get_legend_handles_labels()
fig.legend(handles, labels, bbox_to_anchor=(0.5, -0.07), ncol=3, loc='lower center')
plt.show()

In [None]:
n_post_samples_sim = 20
# MLE maps cropped to the same spatial window as the data, so [pixel_i, pixel_j]
# refers to the same pixel in the data, the posterior and the MLE fit for any
# offset / grid size (identical to the full map when offsets are 0 and grid is full).
mle_window = mle_estimates[x_offset:x_offset + grid_data, y_offset:y_offset + grid_data]
tau_mle = mle_window[:, :, 0]
tau_2_mle = mle_window[:, :, 1]
A_mle = mle_window[:, :, 2]
pixel_mask_2d = binary_mask.reshape(grid_data, grid_data)


def _simulate_row(row_mask, tau_row, tau_2_row, A_row, tau_mle_row, tau_2_mle_row, A_mle_row):
    """Simulate posterior-predictive and MLE decays for one image row.

    row_mask: (n_cols,) bool validity of each pixel in the row.
    tau_row, tau_2_row, A_row: (n_draws, n_cols) posterior draws for the row.
    tau_mle_row, tau_2_mle_row, A_mle_row: (n_cols,) MLE estimates for the row.
    Returns (sims, sims_mle) with shapes (n_draws, 1, n_cols, 256) and
    (1, 1, n_cols, 256); pixels outside the mask stay NaN.
    """
    n_draws, n_cols = tau_row.shape
    sims = np.full((n_draws, 1, n_cols, 256), np.nan)
    sims_mle = np.full((1, 1, n_cols, 256), np.nan)
    for pixel_j in np.flatnonzero(row_mask):  # only valid pixels are simulated
        sims[:, 0, pixel_j, :] = [
            prior.simulator.decay_gen_single(
                tau_L=tau_row[post_index, pixel_j],
                tau_L_2=tau_2_row[post_index, pixel_j],
                A_L=A_row[post_index, pixel_j]
            ) for post_index in range(n_draws)
        ]
        sims_mle[:, 0, pixel_j, :] = [
            prior.simulator.decay_gen_single(
                tau_L=tau_mle_row[pixel_j],
                tau_L_2=tau_2_mle_row[pixel_j],
                A_L=A_mle_row[pixel_j]
            )
        ]
    return sims, sims_mle


def wrapper(pixel_i):
    """Build the joblib task for image row ``pixel_i``.

    Only that row's slices are passed as task arguments, instead of the task
    closing over the full posterior arrays.
    """
    return delayed(_simulate_row)(
        pixel_mask_2d[pixel_i],
        tau[:n_post_samples_sim, pixel_i],
        tau_2[:n_post_samples_sim, pixel_i],
        A[:n_post_samples_sim, pixel_i],
        tau_mle[pixel_i],
        tau_2_mle[pixel_i],
        A_mle[pixel_i],
    )


row_results = Parallel(n_jobs=10, verbose=1)(wrapper(pixel_i) for pixel_i in range(grid_data))
simulations_mle = np.concatenate([r[1] for r in row_results], axis=1)  # shape: (1, grid_data, grid_data, 256)
simulations = np.concatenate([r[0] for r in row_results], axis=1)  # shape: (n_post_samples, grid_data, grid_data, 256)

In [None]:
def ppc_chi2(real_flat, simulations):
    """Posterior predictive check with a chi-square-like discrepancy, per pixel/group.

    The predictive median and the across-draw variance (ddof=1) at each time point
    define a standardized squared distance, summed over time. The p-value is the
    fraction of replicated series at least as far from the median as the data.

    Args:
        real_flat: (G, T) observed series, one row per pixel/group.
        simulations: (S, G, T) posterior-predictive draws.

    Returns:
        D_obs: (G,) discrepancy of the observed series.
        p_ppc: (G,) Bayesian posterior predictive p-values.
    """
    center = np.median(simulations, axis=0)                          # (G, T)
    # Late, nearly empty time bins can have ~zero spread across draws; clamp the
    # variance from below so the standardized residuals stay finite.
    spread = np.maximum(np.var(simulations, axis=0, ddof=1), 1e-10)  # (G, T)

    def discrepancy(series):
        # Sum of squared standardized residuals over the trailing (time) axis.
        return np.sum((series - center) ** 2 / spread, axis=-1)

    D_obs = discrepancy(real_flat)    # (G,)
    D_rep = discrepancy(simulations)  # (S, G)
    p_ppc = np.mean(D_rep >= D_obs, axis=0)
    return D_obs, p_ppc

In [None]:
def compute_pixel_metrics_ppc(real_flat, simulations, alphas=(0.01, 0.05, 0.1, 0.2)):
    """Goodness-of-fit metrics of simulated series against observed ones.

    Args:
        real_flat: (G, T) observed series, one row per pixel/group.
        simulations: (S, G, T) simulated series per pixel/group. These are
            posterior-predictive draws, or S == 1 for a single point estimate
            such as the MLE.
        alphas: miscoverage levels. For each alpha, the central (1 - alpha)
            predictive interval is taken from the empirical alpha/2 and
            1 - alpha/2 percentiles of the draws.

    Returns:
        If S == 1: only ``r2s`` (array of length G). Coverage and the PPC need
        more than one draw.
        Otherwise the tuple ``(r2s, coverages, p_ppc)``:
          - r2s: (G,) R² of the per-pixel median simulation against the data.
          - coverages: dict alpha -> float, the fraction of all (pixel, time)
            points inside the central (1 - alpha) interval (pooled, not per pixel).
          - p_ppc: (G,) posterior-predictive p-values from ``ppc_chi2``.
    """
    S, G, T = simulations.shape

    print('Computing R²')
    median_pred = np.median(simulations, axis=0)  # (G, T)
    if G == 0:
        r2s = np.zeros(0)
    elif T < 2:
        # R² is undefined for fewer than two time points (r2_score would return a scalar NaN).
        r2s = np.full(G, np.nan)
    else:
        # One vectorised call instead of G separate calls: with multioutput='raw_values'
        # every column is scored independently, so transposing to (T, G) gives exactly
        # one R² per pixel.
        r2s = r2_score(real_flat.T, median_pred.T, multioutput='raw_values')

    if S == 1:
        return r2s

    print('Computing coverages')
    coverages = {}
    for alpha in alphas:
        lo = np.percentile(simulations, 100 * (alpha / 2), axis=0)      # (G, T)
        hi = np.percentile(simulations, 100 * (1 - alpha / 2), axis=0)  # (G, T)
        covered = (real_flat >= lo) & (real_flat <= hi)
        coverages[alpha] = covered.mean()  # pooled over pixels and time bins

    print('Computing posterior‐predictive check')
    _, p_ppc = ppc_chi2(real_flat, simulations)
    return r2s, coverages, p_ppc

In [None]:
# Flatten the image grid to one row per pixel and keep only the pixels selected by binary_mask.
real_flat = real_data[0].reshape(-1, 256)[binary_mask.ravel()]                      # (G_active, T)
sim_flat = simulations.reshape(n_post_samples_sim, -1, 256)[:, binary_mask.ravel()]  # (S, G_active, T)
sim_flat_mle = simulations_mle.reshape(1, -1, 256)[:, binary_mask.ravel()]           # (1, G_active, T)

In [None]:
# Metrics for the posterior predictive (n_post_samples_sim draws per pixel) and for the
# MLE fit (one simulation per pixel, so only R² comes back).
r2s, coverages, ppc = compute_pixel_metrics_ppc(real_flat, sim_flat)
r2s_mle = compute_pixel_metrics_ppc(real_flat, sim_flat_mle)
# 2/3 of the gates are sufficient to capture the whole decay, if it is in the range of 1 ns
# (with T = 256 gates this keeps the first 170).
r2s_short, coverages_short, ppc_short = compute_pixel_metrics_ppc(
    real_flat[:, :2 * real_flat.shape[1] // 3],
    sim_flat[..., :2 * real_flat.shape[1] // 3],
)

In [None]:
def _print_bayesian_fit_summary(r2_values, coverage_by_alpha, p_values):
    """Print nominal vs. empirical coverage, R² statistics and the mean PPC p-value."""
    for alpha, coverage in coverage_by_alpha.items():
        print(f"Nominal CI = {(1-alpha)*100:.1f}% → empirical coverage = {coverage:.3f}")
    print(f"Mean R² over valid pixels:           {r2_values.mean():.3f} ({r2_values.std():.3f})")
    print(f'Min/Max R²:                         {np.min(r2_values):.3f}, {np.max(r2_values):.3f}')
    print(f"Mean PPC over valid pixels:   {p_values.mean():.3f} ({p_values.std():.3f})")


# Print summary:
print(f"Mean R² over valid pixels (MLE):           {r2s_mle.mean():.3f} ({r2s_mle.std():.3f})")
print(f'Min/Max R²  (MLE):                         {np.min(r2s_mle):.3f}, {np.max(r2s_mle):.3f}')
print('\n')

_print_bayesian_fit_summary(r2s, coverages, ppc)
print('\n')

print('Using only first 2/3 of the gates:')
_print_bayesian_fit_summary(r2s_short, coverages_short, ppc_short)

In [None]:
def _scatter_to_grid(values):
    """Place per-valid-pixel `values` back on the (grid_data, grid_data) image; masked pixels are NaN."""
    image = np.full(grid_data * grid_data, np.nan)
    image[binary_mask.flatten()] = values
    return image.reshape(grid_data, grid_data)


r2_map = _scatter_to_grid(r2s)
r2_map_mle = _scatter_to_grid(r2s_mle)
# PPC maps are not plotted below; kept for interactive inspection.
p_ppc_map = _scatter_to_grid(ppc)
p_ppc_map_short = _scatter_to_grid(ppc_short)

# Masked pixels (NaN) and R² below vmin are both drawn black. The colorbars use
# extend='min' so the clipping at 0.8 is visible rather than mistaken for masking.
cmap = plt.get_cmap('jet').copy()
cmap.set_bad(color="black")
cmap.set_under(color='black')

# Scale bar: 135 µm per 512 pixels; a 10 µm bar near the upper-right corner.
microns_per_pixel = 135./512
scale_bar_length_um = 10
scale_bar_length_px = int(scale_bar_length_um / microns_per_pixel)
x0 = r2_map_mle.shape[1] - scale_bar_length_px - 40
y0 = 40

fig, ax = plt.subplots(1, 2, layout='constrained', figsize=(6, 2))
ax[1].sharex(ax[0])
for panel, r2_image, title in zip(ax, (r2_map, r2_map_mle), (r'Flat Bayesian', r'MLE')):
    image = panel.imshow(r2_image, cmap=cmap, vmin=0.8, vmax=1)
    colorbar = fig.colorbar(image, ax=panel, extend='min')
    colorbar.set_label(r'$R^2$')
    panel.set_title(title)
    panel.set_xticks([])
    panel.set_yticks([])
    panel.hlines(y=y0, xmin=x0, xmax=x0 + scale_bar_length_px, color='white', linewidth=2)
    # Optional scale-bar label (disabled in the published figure):
    # panel.text(x0 + scale_bar_length_px / 2, y0 - 5, f'{scale_bar_length_um} µm',
    #            color='white', ha='center', va='bottom', fontsize=8)

#plt.savefig(f'plots/vanilla_real_data_fit_diagnostics.pdf', transparent=True, bbox_inches='tight')
plt.show()