
How to correctly load a trained ModelComparisonApproximator model? #506

@dizyd

Description


Dear BayesFlow Team,

As described in this forum post, I have trouble saving and loading models when running model comparisons.

To Reproduce
Below are minimal steps to reproduce the issue; I mostly used the code from this example: https://bayesflow.org/main/_examples/One_Sample_TTest.html#training

import numpy as np

import os
if "KERAS_BACKEND" not in os.environ:
    # set this to "torch", "tensorflow", or "jax"
    os.environ["KERAS_BACKEND"] = "jax"

import keras
import bayesflow as bf




# Null model: mu fixed at 0
def prior_null():
    return dict(mu=0.0)

# Alternative model: mu ~ Normal(0, 1)
def prior_alternative():
    mu = np.random.normal(loc=0, scale=1)
    return dict(mu=mu)

# Shared likelihood: 100 observations per simulated dataset
def likelihood(mu):
    x = np.random.normal(loc=mu, scale=1, size=100)
    return dict(x=x)

simulator_null = bf.make_simulator([prior_null, likelihood])
simulator_alternative = bf.make_simulator([prior_alternative, likelihood])

# Wrap both simulators for model comparison; with mixed batches, each
# batch contains simulations from both models
simulator = bf.simulators.ModelComparisonSimulator(
    simulators=[simulator_null, simulator_alternative],
    use_mixed_batches=True
)
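
For reference, a sampled batch is a dict of arrays; the "x" and "model_indices" entries are the ones used in the evaluation further below:

# Peek at one simulated batch (a hypothetical sanity check, not part of the
# original training script): variable names mapped to arrays
batch = simulator.sample(100)
print({k: np.shape(v) for k, v in batch.items()})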


# Treat "x" as a set of exchangeable observations and route it to the summary network
adapter = (
    bf.Adapter()
    .as_set("x")
    .rename("x", "summary_variables")
    .drop("mu")
    .convert_dtype("float64", "float32")
)


# Permutation-invariant summary network plus an MLP classifier head
summary_network = bf.networks.DeepSet(summary_dim=8, dropout=None)
classifier_network = bf.networks.MLP(widths=[32] * 4, activation="silu", dropout=None)


approximator = bf.approximators.ModelComparisonApproximator(
    num_models=2,
    classifier_network=classifier_network,
    summary_network=summary_network,
    adapter=adapter,
)

num_batches_per_epoch = 64
batch_size = 512
epochs = 32
learning_rate = keras.optimizers.schedules.CosineDecay(1e-4, decay_steps=epochs * num_batches_per_epoch)
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
approximator.compile(optimizer=optimizer)


history = approximator.fit(
    epochs=epochs,
    num_batches=num_batches_per_epoch,
    batch_size=batch_size,
    simulator=simulator,
    adapter=adapter,
)

df = simulator.sample(100)
pred_models = approximator.predict(conditions=df)
f = bf.diagnostics.plots.mc_confusion_matrix(
    pred_models=pred_models,
    true_models=df["model_indices"],
    normalize="true"
)

This results in the following confusion matrix:

(Figure: confusion matrix before saving and reloading.)

I then save the model, reload it, and run the evaluation code again:
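
# Save the trained approximator and replace it with the reloaded copy
approximator.save("test.keras")
approximator = keras.saving.load_model("test.keras")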

df = simulator.sample(100)
pred_models = approximator.predict(conditions=df)
f = bf.diagnostics.plots.mc_confusion_matrix(
    pred_models=pred_models,
    true_models=df["model_indices"],
    normalize="true"
)

I get a worse confusion matrix:

(Figure: confusion matrix after reloading — noticeably degraded.)

Expected Behaviour

I would expect to get the same (or a very similar) confusion matrix back, i.e., comparable model classification performance.
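
For reference, this is the minimal round-trip check I would expect to pass. This is only a sketch: the tolerance is arbitrary, and it assumes predict returns plain arrays of per-model probabilities.

# Hypothetical round-trip check: predictions on a fixed batch should agree
# (up to numerical noise) before and after serialization
fixed = simulator.sample(100)
probs_before = approximator.predict(conditions=fixed)

approximator.save("test.keras")
restored = keras.saving.load_model("test.keras")
probs_after = restored.predict(conditions=fixed)

# If all weights survive the save/load cycle, this should hold
np.testing.assert_allclose(probs_before, probs_after, atol=1e-5)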

System:

I am running BayesFlow 2.0.3 (installed from the dev branch) in a Jupyter notebook inside Visual Studio Code on Windows 11 with Python 3.11.9. These are the packages in my environment:

absl-py         - 2.2.2        - active 
arviz           - 0.21.0       - active 
asttokens       - 3.0.0        - active 
bayesflow       - 2.0.3        - active 
colorama        - 0.4.6        - active 
contourpy       - 1.3.2        - active
cycler          - 0.12.1       - active
debugpy         - 1.8.14       - active
decorator       - 5.2.1        - active
executing       - 2.2.0        - active
fonttools       - 4.58.0       - active
h5netcdf        - 1.6.1        - active
h5py            - 3.13.0       - active
ipykernel       - 6.29.5       - active
ipython_pygments_lexers - 1.1.1        - active
ipython         - 9.2.0        - active
jedi            - 0.19.2       - active
jupyter_client  - 8.6.3        - active
jupyter_core    - 5.7.2        - active
kiwisolver      - 1.4.8        - active
llvmlite        - 0.44.0       - active
matplotlib-inline - 0.1.7        - active
matplotlib      - 3.10.3       - active
mdurl           - 0.1.2        - active
mizani          - 0.13.5       - active
numba           - 0.61.2       - active
numpy           - 1.26.4       - active
opt_einsum      - 3.4.0        - active
optree          - 0.15.0       - active
packaging       - 25.0         - active
pandas          - 2.2.3        - active
parso           - 0.8.4        - active
patsy           - 1.0.1        - active
pillow          - 11.2.1       - active
pip             - 25.1.1       - active
platformdirs    - 4.3.8        - active
plotnine        - 0.14.5       - active
prompt_toolkit  - 3.0.51       - active
psutil          - 7.0.0        - active
pure_eval       - 0.2.3        - active
Pygments        - 2.19.1       - active
pyparsing       - 3.2.3        - active
python-dateutil - 2.9.0.post0  - active
pytz            - 2025.2       - active
pywin32         - 310          - active
pyzmq           - 26.4.0       - active
rich            - 14.0.0       - active
scipy           - 1.15.3       - active
seaborn         - 0.13.2       - active
setuptools      - 65.5.0       - active
six             - 1.17.0       - active
stack-data      - 0.6.3        - active 
statsmodels     - 0.14.4       - active 
tornado         - 6.5.1        - active 
tqdm            - 4.67.1       - active 
traitlets       - 5.14.3       - active 
typing_extensions - 4.13.2       - active 
tzdata          - 2025.2       - active 
wcwidth         - 0.2.13       - active 
xarray-einstats - 0.9.0        - active 
xarray          - 2025.4.0     - active 
yolk3k          - 0.9          - active
