Description
Dear BayesFlow Team,
As described in this post in the forum, I am having trouble saving and loading models when running model comparisons.
To Reproduce
Minimal steps to reproduce the issue are below. I mostly used the code from this example: https://bayesflow.org/main/_examples/One_Sample_TTest.html#training
import numpy as np
import numpy.random as rng
import os
if "KERAS_BACKEND" not in os.environ:
    # set this to "torch", "tensorflow", or "jax"
    os.environ["KERAS_BACKEND"] = "jax"
import keras
import bayesflow as bf
def prior_null():
    return dict(mu=0.0)
def prior_alternative():
    mu = np.random.normal(loc=0, scale=1)
    return dict(mu=mu)
def likelihood(mu):
    x = np.random.normal(loc=mu, scale=1, size=100)
    return dict(x=x)
simulator_null = bf.make_simulator([prior_null, likelihood])
simulator_alternative = bf.make_simulator([prior_alternative, likelihood])
simulator = bf.simulators.ModelComparisonSimulator(
    simulators=[simulator_null, simulator_alternative],
    use_mixed_batches=True
)
adapter = (
    bf.Adapter()
    .as_set("x")
    .rename("x", "summary_variables")
    .drop("mu")
    .convert_dtype("float64", "float32")
)
summary_network = bf.networks.DeepSet(summary_dim=8, dropout=None)
classifier_network = bf.networks.MLP(widths=[32] * 4, activation="silu", dropout=None)
approximator = bf.approximators.ModelComparisonApproximator(
    num_models=2,
    classifier_network=classifier_network,
    summary_network=summary_network,
    adapter=adapter,
)
num_batches_per_epoch = 64
batch_size = 512
epochs = 32
learning_rate = keras.optimizers.schedules.CosineDecay(1e-4, decay_steps=epochs * num_batches_per_epoch)
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
approximator.compile(optimizer=optimizer)
history = approximator.fit(
    epochs=epochs,
    num_batches=num_batches_per_epoch,
    batch_size=batch_size,
    simulator=simulator,
    adapter=adapter,
)
df = simulator.sample(100)
pred_models = approximator.predict(conditions=df)
f = bf.diagnostics.plots.mc_confusion_matrix(
    pred_models=pred_models,
    true_models=df["model_indices"],
    normalize="true"
)
This results in this confusion matrix:
After saving the model with approximator.save("test.keras"), reloading it with approximator = keras.saving.load_model("test.keras"), and running the evaluation code again,
df = simulator.sample(100)
pred_models = approximator.predict(conditions=df)
f = bf.diagnostics.plots.mc_confusion_matrix(
    pred_models=pred_models,
    true_models=df["model_indices"],
    normalize="true"
)
I get a worse confusion matrix:
Expected Behaviour
I would expect to get the same (or very similar) confusion matrix out again (i.e., similar model categorization performance).
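Note that the two evaluations above each draw a fresh df = simulator.sample(100), so part of the difference between the two matrices could be sampling noise. A stricter check would reuse one fixed df and compare the predicted model probabilities before and after reloading. The following is only a minimal sketch of such a check, using NumPy and synthetic stand-in arrays. The helper names row_normalized_confusion and compare_predictions are hypothetical (not part of BayesFlow), and it assumes that the true model indices are one-hot encoded and that predictions are arrays of shape (num_samples, num_models).

```python
import numpy as np


def row_normalized_confusion(true_onehot, pred_probs):
    """Confusion matrix with rows = true model and columns = predicted model.

    Each row is normalized to sum to 1; rows without samples stay all zero.
    Assumes one-hot true indices and per-model probabilities (hypothetical helper).
    """
    true_idx = np.argmax(true_onehot, axis=-1)
    pred_idx = np.argmax(pred_probs, axis=-1)
    num_models = pred_probs.shape[-1]
    counts = np.zeros((num_models, num_models))
    np.add.at(counts, (true_idx, pred_idx), 1)  # count each (true, predicted) pair
    row_sums = counts.sum(axis=1, keepdims=True)
    return np.divide(counts, row_sums, out=np.zeros_like(counts), where=row_sums > 0)


def compare_predictions(pred_before, pred_after, atol=1e-5):
    """Summarize how far two sets of predicted probabilities differ (hypothetical helper)."""
    pred_before = np.asarray(pred_before, dtype=float)
    pred_after = np.asarray(pred_after, dtype=float)
    return {
        "max_abs_diff": float(np.max(np.abs(pred_before - pred_after))),
        "same_argmax": float(np.mean(pred_before.argmax(-1) == pred_after.argmax(-1))),
        "allclose": bool(np.allclose(pred_before, pred_after, atol=atol)),
    }


# Synthetic stand-ins for the real arrays (illustrative values, not results).
true_models = np.array([[1, 0], [1, 0], [0, 1], [0, 1]])
pred_before = np.array([[0.9, 0.1], [0.6, 0.4], [0.2, 0.8], [0.7, 0.3]])
pred_after = pred_before.copy()  # what an exact save/load round trip would give

confusion = row_normalized_confusion(true_models, pred_before)
report = compare_predictions(pred_before, pred_after)
```

In the setup above, pred_before would be approximator.predict(conditions=df) computed before saving, and pred_after the same call on the reloaded approximator with the same df. If the weights were restored correctly, the report should show a maximum difference close to zero; a large difference on identical inputs would point to the save/load step rather than to sampling variability.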
System:
I am running BayesFlow 2.0.3, installed from the dev branch, with Python 3.11.9 on Windows 11, in Visual Studio Code and in a Jupyter notebook. These are the packages in my environment:
absl-py - 2.2.2 - active
arviz - 0.21.0 - active
asttokens - 3.0.0 - active
bayesflow - 2.0.3 - active
colorama - 0.4.6 - active
contourpy - 1.3.2 - active
cycler - 0.12.1 - active
debugpy - 1.8.14 - active
decorator - 5.2.1 - active
executing - 2.2.0 - active
fonttools - 4.58.0 - active
h5netcdf - 1.6.1 - active
h5py - 3.13.0 - active
ipykernel - 6.29.5 - active
ipython_pygments_lexers - 1.1.1 - active
ipython - 9.2.0 - active
jedi - 0.19.2 - active
jupyter_client - 8.6.3 - active
jupyter_core - 5.7.2 - active
kiwisolver - 1.4.8 - active
llvmlite - 0.44.0 - active
matplotlib-inline - 0.1.7 - active
matplotlib - 3.10.3 - active
mdurl - 0.1.2 - active
mizani - 0.13.5 - active
numba - 0.61.2 - active
numpy - 1.26.4 - active
opt_einsum - 3.4.0 - active
optree - 0.15.0 - active
packaging - 25.0 - active
pandas - 2.2.3 - active
parso - 0.8.4 - active
patsy - 1.0.1 - active
pillow - 11.2.1 - active
pip - 25.1.1 - active
platformdirs - 4.3.8 - active
plotnine - 0.14.5 - active
prompt_toolkit - 3.0.51 - active
psutil - 7.0.0 - active
pure_eval - 0.2.3 - active
Pygments - 2.19.1 - active
pyparsing - 3.2.3 - active
python-dateutil - 2.9.0.post0 - active
pytz - 2025.2 - active
pywin32 - 310 - active
pyzmq - 26.4.0 - active
rich - 14.0.0 - active
scipy - 1.15.3 - active
seaborn - 0.13.2 - active
setuptools - 65.5.0 - active
six - 1.17.0 - active
stack-data - 0.6.3 - active
statsmodels - 0.14.4 - active
tornado - 6.5.1 - active
tqdm - 4.67.1 - active
traitlets - 5.14.3 - active
typing_extensions - 4.13.2 - active
tzdata - 2025.2 - active
wcwidth - 0.2.13 - active
xarray-einstats - 0.9.0 - active
xarray - 2025.4.0 - active
yolk3k - 0.9 - active

