In [113]:
from functools import partial

import os
if "KERAS_BACKEND" not in os.environ:
    # set this to "torch", "tensorflow", or "jax"
    os.environ["KERAS_BACKEND"] = "jax"

import numpy as np
import bayesflow as bf
import keras
from numba import jit, njit, prange
import numba.random


ModuleNotFoundError: No module named 'numba.random'

In [16]:
RNG = np.random.default_rng(2024)

In [116]:
def simulator_fun():
    num_accumulators = 2
    num_obs = 500
    dt = 1
    t_max = 1_000
    t = np.linspace(dt, t_max, int(t_max/dt))
    amp = 20
    tau = RNG.gamma(12, 6)
    a_shape = 2
    b = 100
    t0 = 300
    s = 4
    mu_c = RNG.gamma(1, 1, size=2)

    eq4 = (
        amp
        * np.exp(-t / tau)
        * (np.exp(1) * t / (a_shape - 1) / tau)
        ** (a_shape - 1)
    )

    mu = mu_c[:, None] + eq4 * ((a_shape - 1) / t - 1 / tau)

    # fpt = np.zeros((batch_size, num_accumulators, num_obs))
    
    xt = mu[:, None] + (s * RNG.normal(size=(num_accumulators, num_obs, t_max)))
    fpt = t[(np.cumsum(xt, axis=2) > b).argmax(axis=2)]

    resp = fpt.argmin(axis=0)
    rt = (fpt.min(axis=0) + t0)

    return {"x": np.c_[rt, resp], "mu_c": mu_c, "tau": tau}


In [117]:
%%timeit

simulator_fun()["x"]

13.4 ms ± 125 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [114]:
@njit
def simulator_fun():
    num_accumulators = 2
    num_obs = 500
    dt = 1.0
    t_max = 1000
    t = np.linspace(dt, t_max, int(t_max / dt))
    amp = 20.0
    tau = RNG.gamma(12, 6)  # Random tau from gamma distribution
    a_shape = 2.0
    b = 100.0
    t0 = 300.0
    s = 4.0
    mu_c = RNG.gamma(1, 1, size=num_accumulators)

    eq4 = (
        amp
        * np.exp(-t / tau)
        * (np.exp(1) * t / (a_shape - 1) / tau) ** (a_shape - 1)
    )

    mu = mu_c[:, None] + eq4 * ((a_shape - 1) / t - 1 / tau)

    # Initialize arrays for results
    fpt = np.zeros((num_accumulators, num_obs))
    xt = np.zeros((num_accumulators, num_obs, t_max))

    # Fill the xt array with random values
    for i in range(num_accumulators):
        xt[i, :, :] = mu[i, None, :] + (s * RNG.standard_normal(size=(num_obs, t_max)))

    # Compute fpt manually
    for i in range(num_accumulators):
        for j in range(num_obs):
            cumulative_sum = 0.0
            for k in range(t_max):
                cumulative_sum += xt[i, j, k]
                if cumulative_sum > b:
                    fpt[i, j] = t[k]
                    break

    # Compute resp and rt
    resp = np.empty(num_obs, dtype=np.int64)
    rt = np.empty(num_obs)

    for j in range(num_obs):
        min_fpt = np.inf
        min_idx = -1
        for i in range(num_accumulators):
            if fpt[i, j] < min_fpt:
                min_fpt = fpt[i, j]
                min_idx = i
        resp[j] = min_idx
        rt[j] = min_fpt + t0

    return {"x": rt, "mu_c": mu_c, "tau": tau}

In [115]:
simulator_fun()

NumbaNotImplementedError: Failed in nopython mode pipeline (step: native lowering)
[1m[1m<numba.core.base.OverloadSelector object at 0x7cb03417d150>, (NumPyRandomGeneratorType,)[0m
[0m[1mDuring: lowering "$44load_global.15 = global(RNG: Generator(PCG64))" at /tmp/ipykernel_191116/3431294076.py (9)[0m

In [19]:
simulator = bf.simulators.CompositeLambdaSimulator(sample_fns=[simulator_fun])

In [20]:
forward_batch = simulator.sample((10,))

In [21]:
data_adapter = bf.ContinuousApproximator.build_data_adapter(
    inference_variables=["mu_c", "tau"],
    # inference_conditions=["num_obs"],
    summary_variables=["x"],
    transforms=[
        bf.data_adapters.transforms.Standardize(["mu_c", "tau"])
    ]
)

In [22]:
summary_network = bf.networks.SetTransformer()

In [23]:
inference_network = bf.networks.FlowMatching(
    subnet="mlp",
    subnet_kwargs=dict(
        depth=6,
        width=256,
    ),
    use_optimal_transport=False,
)

In [24]:
approximator = bf.ContinuousApproximator(
    summary_network=summary_network,
    inference_network=inference_network,
    data_adapter=data_adapter,
)

In [25]:
import keras

learning_rate = 1e-4
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

In [26]:
approximator.compile(optimizer=optimizer)

In [27]:
history = approximator.fit(
    epochs=5,
    num_batches=500,
    batch_size=64,
    # memory_budget="8 GiB",
    simulator=simulator
)

INFO:bayesflow:Building dataset from simulator instance of CompositeLambdaSimulator.
INFO:bayesflow:Using 32 data loading workers.
INFO:bayesflow:Building on a test batch.


Epoch 1/5
[1m135/500[0m [32m━━━━━[0m[37m━━━━━━━━━━━━━━━[0m [1m5:23[0m 887ms/step - loss: 1.5400 - loss/inference_loss: 1.5400

KeyboardInterrupt: 

In [None]:
from bayesflow_plots import plot_z_score_contraction, plot_recovery
from utils import convert_samples_posterior, convert_samples_prior

In [None]:
param_names = ["v_intercept", "v_slope", "s_true", "b", "t0"]

In [None]:
forward_dict = simulator.sample(
    batch_shape=(100,), num_obs=np.tile([500], (100,))
)

prior_samples = convert_samples_prior(forward_dict, param_names)

sample_dict = {k: v for k, v in forward_dict.items() if k not in data_adapter.keys["inference_variables"]}

posterior_samples_sens = convert_samples_posterior(approximator.sample(
    conditions=sample_dict, num_samples=100
), param_names)

In [None]:
plot_z_score_contraction(posterior_samples_sens, prior_samples, param_names=param_names)

In [None]:
plot_recovery(np.swapaxes(posterior_samples_sens, 0, 1), prior_samples, param_names=param_names)