In [1]:
from functools import partial

import os
if "KERAS_BACKEND" not in os.environ:
    # set this to "torch", "tensorflow", or "jax"
    os.environ["KERAS_BACKEND"] = "jax"

import numpy as np
import bayesflow as bf
import keras

from priors import rdm_prior_simple
from design import random_num_obs_discrete
from simulation import rdm_experiment_simple

from bayesflow.utils import batched_call, tree_stack

In [2]:
prior = partial(rdm_prior_simple, drift_intercept_loc=2,
    drift_intercept_scale=1,
    drift_slope_loc=2,
    drift_slope_scale=1,
    sd_true_shape=10,
    sd_true_scale=0.2,
    threshold_shape=10,
    threshold_scale=0.2,
    t0_loc=0.2,
    t0_scale=0.1,
    t0_lower=0)

In [3]:
prior(batch_shape=(20,), rng=np.random.default_rng())

{'v_intercept': array([1.95537526, 2.03643292, 2.38406905, 2.06980881, 3.34972961,
        2.83652213, 4.4253601 , 2.88190308, 0.14754334, 2.15344113,
        1.34355198, 1.40525777, 4.32912583, 0.52175012, 1.58420339,
        1.62959361, 1.95521059, 2.97169343, 2.54474749, 0.16199953]),
 'v_slope': array([2.32131513, 2.9085572 , 1.62775466, 3.28825945, 2.34459281,
        2.27360641, 3.12837419, 3.46645412, 2.19978514, 0.99126966,
        2.79991913, 1.54100459, 1.19205535, 1.59248276, 1.06537696,
        2.43574518, 2.12081884, 1.42249506, 2.58951566, 3.1441317 ]),
 's_true': array([1.84747074, 1.86614251, 2.43687233, 2.55151262, 2.00513866,
        1.08257172, 0.73584575, 1.35227617, 2.07134563, 1.93312845,
        3.36361351, 2.34877772, 2.69851691, 2.11172094, 1.38261597,
        2.21375078, 1.98589538, 1.59637723, 1.75036982, 0.99244079]),
 'b': array([1.71040488, 1.73630572, 3.29784917, 2.17377695, 1.81187374,
        2.40236884, 2.31030707, 1.78729236, 1.50653086, 3.39872245,
 

In [4]:
num_obs = partial(random_num_obs_discrete, values=[100, 250, 500, 1000])

In [5]:
sim = partial(rdm_experiment_simple, s_false=1)

In [6]:
def batched_sim(batch_shape, **kwargs):
    data = batched_call(sim, batch_shape, kwargs=kwargs, flatten=True)
    data = tree_stack(data, axis=0, numpy=True)
    return data

In [7]:
simulator = bf.simulators.CompositeLambdaSimulator([prior, num_obs, batched_sim], **{"is_batched": True})

In [8]:
sample_data = simulator.sample((10,))

In [9]:
print("Type of sample_data:\n\t", type(sample_data))
print("Keys of sample_data:\n\t", sample_data.keys())
print("Types of sample_data values:\n\t", {k: type(v) for k, v in sample_data.items()})
print("Shapes of sample_data values:\n\t", {k: v.shape for k, v in sample_data.items()})

Type of sample_data:
	 <class 'dict'>
Keys of sample_data:
	 dict_keys(['v_intercept', 'v_slope', 's_true', 'b', 't0', 'num_obs', 'x'])
Types of sample_data values:
	 {'v_intercept': <class 'numpy.ndarray'>, 'v_slope': <class 'numpy.ndarray'>, 's_true': <class 'numpy.ndarray'>, 'b': <class 'numpy.ndarray'>, 't0': <class 'numpy.ndarray'>, 'num_obs': <class 'numpy.ndarray'>, 'x': <class 'numpy.ndarray'>}
Shapes of sample_data values:
	 {'v_intercept': (10, 1), 'v_slope': (10, 1), 's_true': (10, 1), 'b': (10, 1), 't0': (10, 1), 'num_obs': (10, 1), 'x': (10, 100, 2)}


In [10]:
data_adapter = bf.ContinuousApproximator.build_data_adapter(
    inference_variables=["v_intercept", "v_slope", "s_true", "b", "t0"],
    inference_conditions=["num_obs"],
    summary_variables=["x"],
    transforms=[
        bf.data_adapters.transforms.Standardize(["v_intercept", "v_slope", "s_true", "b", "t0"]),
        bf.data_adapters.transforms.NumpyTransform("num_obs", forward="sqrt", inverse="square")
    ]
)

In [11]:
summary_network = bf.networks.SetTransformer()

In [12]:
inference_network = bf.networks.FlowMatching(
    subnet="mlp",
    subnet_kwargs=dict(
        depth=6,
        width=256,
    ),
    use_optimal_transport=True,
)

In [13]:
approximator = bf.ContinuousApproximator(
    summary_network=summary_network,
    inference_network=inference_network,
    data_adapter=data_adapter,
)

In [14]:
import keras

learning_rate = 1e-4
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

In [15]:
approximator.compile(optimizer=optimizer)

In [16]:
history = approximator.fit(
    epochs=5,
    num_batches=500,
    batch_size=64,
    # memory_budget="8 GiB",
    simulator=simulator
)

INFO:bayesflow:Building dataset from simulator instance of CompositeLambdaSimulator.
INFO:bayesflow:Using 32 data loading workers.
INFO:bayesflow:Building on a test batch.


Epoch 1/5
[1m500/500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m57s[0m 90ms/step - loss: 1.6968 - loss/inference_loss: 1.6968
Epoch 2/5
[1m500/500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 32ms/step - loss: 1.3059 - loss/inference_loss: 1.3059
Epoch 3/5
[1m500/500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 34ms/step - loss: 1.0376 - loss/inference_loss: 1.0376
Epoch 4/5
[1m500/500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 34ms/step - loss: 0.8730 - loss/inference_loss: 0.8730
Epoch 5/5
[1m500/500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 34ms/step - loss: 0.8133 - loss/inference_loss: 0.8133


In [17]:
from bayesflow.types import Shape
from numpy import ndarray


class RdmSimulator(bf.simulators.Simulator):
    def sample(self, batch_shape: tuple[int, ...], rng, **kwargs) -> dict[str, ndarray]:
        print(kwargs)
        prior_dict = prior(batch_shape, rng=rng)
        num_obs_dict = num_obs(batch_shape, rng=rng, **kwargs)
        data = batched_sim(batch_shape, **prior_dict, rng=rng, **num_obs_dict, **kwargs)
        return data

In [18]:
simulator = RdmSimulator()

In [19]:
simulator.sample(batch_shape=(10,), rng=np.random.default_rng(2024))

{}


{'x': array([[[1.29160164, 0.        ],
         [1.23875701, 1.        ],
         [0.95308439, 1.        ],
         ...,
         [0.67960705, 1.        ],
         [1.09005947, 0.        ],
         [0.46358008, 1.        ]],
 
        [[0.73348642, 0.        ],
         [0.54917246, 1.        ],
         [0.6059088 , 0.        ],
         ...,
         [0.3276139 , 1.        ],
         [0.68810598, 1.        ],
         [0.4210034 , 1.        ]],
 
        [[0.38195942, 1.        ],
         [0.53835735, 0.        ],
         [0.38568983, 0.        ],
         ...,
         [0.39153377, 1.        ],
         [0.50812967, 1.        ],
         [0.28875576, 1.        ]],
 
        ...,
 
        [[1.24211333, 0.        ],
         [0.34163651, 1.        ],
         [0.4579847 , 1.        ],
         ...,
         [0.95142886, 1.        ],
         [0.45873321, 1.        ],
         [0.52265442, 1.        ]],
 
        [[0.89800966, 1.        ],
         [0.46239307, 1.        ],
  