In [1]:
from functools import partial

import os
if "KERAS_BACKEND" not in os.environ:
    # set this to "torch", "tensorflow", or "jax"
    os.environ["KERAS_BACKEND"] = "jax"

import numpy as np
import bayesflow as bf
import keras
import pydmc

import warnings
warnings.filterwarnings("ignore")

In [2]:
RNG = np.random.default_rng(2024)

In [14]:
%%timeit

dat = pydmc.Sim(full_data=False, n_trls=500, n_trls_data=500)

5.61 ms ± 147 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [30]:
def simulator_fun():
    params = pydmc.Prms(drc=RNG.uniform(0.2, 0.6),  tau=RNG.gamma(shape=12, scale=6))
    data = pydmc.Sim(prms=params, n_trls=500, n_trls_data=500)
    return {"mu_c": params.drc, "tau": params.tau, "x": np.array(data.data).reshape(500, 4), "num_obs": 500}

In [31]:
simulator = bf.simulators.CompositeLambdaSimulator(sample_fns=[simulator_fun])

In [32]:
forward_batch = simulator.sample((64,))

In [33]:
forward_batch["mu_c"]

array([[0.47033253],
       [0.5197864 ],
       [0.23149021],
       [0.2678477 ],
       [0.24215427],
       [0.3860477 ],
       [0.43872896],
       [0.37709022],
       [0.28526294],
       [0.30734614],
       [0.38688353],
       [0.31452733],
       [0.38720763],
       [0.23161373],
       [0.56218994],
       [0.5335588 ],
       [0.29134023],
       [0.33474112],
       [0.3005375 ],
       [0.3702391 ],
       [0.4341549 ],
       [0.5775771 ],
       [0.40757254],
       [0.2965027 ],
       [0.32883912],
       [0.20539059],
       [0.43563   ],
       [0.35172117],
       [0.30428642],
       [0.48115647],
       [0.28740823],
       [0.27878287],
       [0.58720326],
       [0.44809264],
       [0.28301612],
       [0.40218595],
       [0.29873008],
       [0.35341045],
       [0.53782016],
       [0.56202984],
       [0.45619163],
       [0.25061688],
       [0.42512283],
       [0.35656154],
       [0.22077344],
       [0.34798983],
       [0.29228386],
       [0.229

In [34]:
forward_batch["x"].reshape((64, 4, 500)).shape

(64, 4, 500)

In [35]:
data_adapter = bf.ContinuousApproximator.build_data_adapter(
    inference_variables=["mu_c", "tau"],
    inference_conditions=["num_obs"],
    summary_variables=["x"],
    transforms=[
        bf.data_adapters.transforms.Standardize(["mu_c", "tau"])
    ]
)

In [36]:
summary_network = bf.networks.SetTransformer()

In [37]:
inference_network = bf.networks.FlowMatching(
    subnet="mlp",
    subnet_kwargs=dict(
        depth=6,
        width=256,
    ),
    use_optimal_transport=False,
)

In [38]:
approximator = bf.ContinuousApproximator(
    summary_network=summary_network,
    inference_network=inference_network,
    data_adapter=data_adapter,
)

In [39]:
import keras

learning_rate = 1e-4
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

In [40]:
approximator.compile(optimizer=optimizer)

In [41]:
history = approximator.fit(
    epochs=5,
    num_batches=500,
    batch_size=64,
    # memory_budget="8 GiB",
    simulator=simulator
)

INFO:bayesflow:Building dataset from simulator instance of CompositeLambdaSimulator.
INFO:bayesflow:Using 32 data loading workers.
INFO:bayesflow:Building on a test batch.


Epoch 1/5
[1m500/500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m285s[0m 522ms/step - loss: 5.7747 - loss/inference_loss: 5.7747
Epoch 2/5
[1m500/500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m280s[0m 524ms/step - loss: 1.5624 - loss/inference_loss: 1.5624
Epoch 3/5
[1m 97/500[0m [32m━━━[0m[37m━━━━━━━━━━━━━━━━━[0m [1m3:48[0m 566ms/step - loss: 1.5496 - loss/inference_loss: 1.5496

KeyboardInterrupt: 

In [29]:
len(dat.data)

2

In [26]:
dat.data[1]

array([[4.70952738e+02, 3.94759226e+02, 5.47416556e+02, 3.70465686e+02,
        4.82546275e+02, 6.82721510e+02, 4.12059325e+02, 4.11478151e+02,
        3.95528375e+02, 3.97841335e+02, 4.54144905e+02, 5.76271212e+02,
        4.73412241e+02, 4.12751360e+02, 4.55635217e+02, 3.70425029e+02,
        5.65937361e+02, 4.33766064e+02, 5.83830420e+02, 3.67544203e+02,
        5.41247623e+02, 3.92000589e+02, 3.76643433e+02, 5.20069677e+02,
        3.70446778e+02, 4.43413024e+02, 4.33997228e+02, 3.14617351e+02,
        4.60766945e+02, 4.06860214e+02, 1.05664063e+03, 4.58339949e+02,
        3.75236537e+02, 4.36626178e+02, 4.05653611e+02, 3.91570018e+02,
        4.24384582e+02, 4.96497138e+02, 6.75207749e+02, 4.04373555e+02,
        4.33250931e+02, 6.07910261e+02, 4.00387889e+02, 3.94573871e+02,
        6.15015664e+02, 3.65753250e+02, 4.27103578e+02, 5.83653761e+02,
        4.48182126e+02, 3.85681799e+02, 3.82302725e+02, 3.88428423e+02,
        4.46124211e+02, 3.74537440e+02, 4.36425643e+02, 4.608820