In [1]:
from functools import partial

import os
if "KERAS_BACKEND" not in os.environ:
    # set this to "torch", "tensorflow", or "jax"
    os.environ["KERAS_BACKEND"] = "jax"

import numpy as np
import bayesflow as bf
import keras
from numba import jit, njit, prange



In [2]:
RNG = np.random.default_rng(2024)

In [3]:
@njit # (nopython=True, parallel=True)
def find_min_and_argmin(arr):
    """
    Combined minimum value and index finder for 2D array along axis 0.
    More efficient than separate min and argmin operations.
    """
    rows, cols = arr.shape
    min_vals = np.empty(cols, dtype=arr.dtype)
    min_idxs = np.empty(cols, dtype=np.int64)
    
    for j in prange(cols):
        min_val = arr[0, j]
        min_idx = 0
        for i in range(1, rows):
            if arr[i, j] < min_val:
                min_val = min_val
                min_idx = i
        min_vals[j] = min_val
        min_idxs[j] = min_idx
    
    return min_vals, min_idxs

In [4]:
@njit #(nopython=True, parallel=True)
def dmc_experiment_simple(mu, b, s, t0, num_obs, t_max=1000):
    num_accumulators = mu.shape[0]

    fpt = np.zeros((num_accumulators, num_obs))
    
    for n in prange(num_obs):
        for i in prange(num_accumulators):
            xt = 0.0
            for t in range(t_max):
                xt += mu[i, t] + (s * np.random.randn())
                if xt > b:
                    fpt[i, n] = t
                    break

    rt, resp = find_min_and_argmin(fpt)
    rt += t0

    return resp, rt

In [5]:
def prior_fun(mu_c_loc=2, mu_c_scale=1, rng=None):
    num_accumulators = 2

    amp = 20.0
    tau = 30
    a_shape = 2.0
    b = 100.0
    t0 = 300.0
    s = 4.0
    mu_c = rng.gamma(mu_c_loc, mu_c_scale, size=num_accumulators)

    return {"mu_c": mu_c, "amp": amp, "tau": tau, "a_shape": a_shape, "b": b, "s": s, "t0": t0}

In [6]:
def random_number_obs(batch_shape, rng=None):
    return {"num_obs": np.tile(rng.integers(200, 300), batch_shape)}

In [7]:
def simulator_fun(mu_c, b, s, t0, amp, tau, a_shape, num_obs, t_max=1000, rng=None):
    t = np.arange(1, t_max + 1, 1)

    eq4 = (
        amp
        * np.exp(-t / tau)
        * (np.exp(1) * t / (a_shape - 1) / tau) ** (a_shape - 1)
    )

    data = np.zeros((num_obs, 3))

    num_obs_h = int(num_obs/2)

    for m, k in enumerate((1, -1)):
        mu = k * mu_c[:, None] + eq4 * ((a_shape - 1) / t - 1 / tau)
        rt, resp = dmc_experiment_simple(mu, b, s, t0, num_obs_h, t_max)
        data[(m * num_obs_h):((m + 1) * num_obs_h), 0] = rt
        data[(m * num_obs_h):((m + 1) * num_obs_h), 1] = resp
        data[(m * num_obs_h):((m + 1) * num_obs_h), 2] = k

    return {"x": data}

In [8]:
from bayesflow.utils import batched_call, tree_stack

def batch_simulator(batch_shape, simulator_fun, **kwargs):
    data = batched_call(simulator_fun, batch_shape, kwargs=kwargs, flatten=True)
    data = tree_stack(data, axis=0, numpy=True)
    return data

In [9]:
from typing import Callable
from bayesflow.types import Shape

class DmcSimulator(bf.simulators.Simulator):
    def __init__(self, prior_fun: Callable, design_fun: Callable, simulator_fun: Callable):
        self.prior_fun = prior_fun
        self.design_fun = design_fun
        self.simulator_fun = simulator_fun


    def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
        prior_dict = self.prior_fun(batch_shape)

        design_dict = self.design_fun(batch_shape)

        design_dict.update(**kwargs)

        sims_dict = self.simulator_fun(batch_shape, **prior_dict, **design_dict)

        data = prior_dict | design_dict | sims_dict

        data = {
            key: np.expand_dims(value, axis=-1) if np.ndim(value) == 1 else value for key, value in data.items()
        }

        return data

In [10]:
simulator = DmcSimulator(
    prior_fun=partial(batch_simulator, simulator_fun=partial(prior_fun, rng=RNG)),
    design_fun=partial(random_number_obs, rng=RNG),
    simulator_fun=partial(batch_simulator, simulator_fun=simulator_fun)
)

In [11]:
%%timeit

sample_data = simulator.sample((64,))

369 ms ± 17.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
sample_data = simulator.sample((10,))

In [13]:
print("Type of sample_data:\n\t", type(sample_data))
print("Keys of sample_data:\n\t", sample_data.keys())
print("Types of sample_data values:\n\t", {k: type(v) for k, v in sample_data.items()})
print("Shapes of sample_data values:\n\t", {k: v.shape for k, v in sample_data.items()})

Type of sample_data:
	 <class 'dict'>
Keys of sample_data:
	 dict_keys(['mu_c', 'amp', 'tau', 'a_shape', 'b', 's', 't0', 'num_obs', 'x'])
Types of sample_data values:
	 {'mu_c': <class 'numpy.ndarray'>, 'amp': <class 'numpy.ndarray'>, 'tau': <class 'numpy.ndarray'>, 'a_shape': <class 'numpy.ndarray'>, 'b': <class 'numpy.ndarray'>, 's': <class 'numpy.ndarray'>, 't0': <class 'numpy.ndarray'>, 'num_obs': <class 'numpy.ndarray'>, 'x': <class 'numpy.ndarray'>}
Shapes of sample_data values:
	 {'mu_c': (10, 2), 'amp': (10, 1), 'tau': (10, 1), 'a_shape': (10, 1), 'b': (10, 1), 's': (10, 1), 't0': (10, 1), 'num_obs': (10, 1), 'x': (10, 237, 3)}


In [14]:
data_adapter = (
    bf.data_adapters.DataAdapter()
    .to_array()
    .convert_dtype("float64", "float32")
    .concatenate(["mu_c"], into="inference_variables")
)

data_adapter = data_adapter.concatenate(["num_obs"], into="inference_conditions")

data_adapter = data_adapter.as_set(["x"]).concatenate(
    ["x"], into="summary_variables"
)

data_adapter = data_adapter.keep(
    ["inference_variables", "inference_conditions", "summary_variables"]
)

In [15]:
summary_network = bf.networks.SetTransformer()

In [16]:
inference_network = bf.networks.FlowMatching(
    subnet="mlp",
    use_optimal_transport=False,
)

In [17]:
approximator = bf.ContinuousApproximator(
    summary_network=summary_network,
    inference_network=inference_network,
    data_adapter=data_adapter,
)

In [18]:
import keras

learning_rate = 1e-4
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

In [19]:
approximator.compile(optimizer=optimizer)

In [20]:
history = approximator.fit(
    epochs=5,
    num_batches=500,
    batch_size=64,
    # memory_budget="8 GiB",
    simulator=simulator
)

INFO:bayesflow:Building dataset from simulator instance of DmcSimulator.
INFO:bayesflow:Using 32 data loading workers.
INFO:bayesflow:Building on a test batch.


Epoch 1/5
[1m  2/500[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m1:41:44[0m 12s/step - loss: 8.9624 - loss/inference_loss: 8.9624

In [None]:
from bayesflow_plots import plot_z_score_contraction, plot_recovery
from utils import convert_samples_posterior, convert_samples_prior

In [None]:
param_names = ["v_intercept", "v_slope", "s_true", "b", "t0"]

In [None]:
forward_dict = simulator.sample(
    batch_shape=(100,), num_obs=np.tile([500], (100,))
)

prior_samples = convert_samples_prior(forward_dict, param_names)

sample_dict = {k: v for k, v in forward_dict.items() if k not in data_adapter.keys["inference_variables"]}

posterior_samples_sens = convert_samples_posterior(approximator.sample(
    conditions=sample_dict, num_samples=100
), param_names)

In [None]:
plot_z_score_contraction(posterior_samples_sens, prior_samples, param_names=param_names)

In [None]:
plot_recovery(np.swapaxes(posterior_samples_sens, 0, 1), prior_samples, param_names=param_names)