How to save a trained ContinuousApproximator model? #302

@philippreiser

Hi!

Description

I am trying to save a trained ContinuousApproximator model on the dev branch using approximator.save("approximator.keras"), but I am unsure if this is the correct approach. Here's a minimal example based on the Linear Regression example notebook:

import os

if "KERAS_BACKEND" not in os.environ:
    # set this to "torch", "tensorflow", or "jax"
    os.environ["KERAS_BACKEND"] = "torch"

import bayesflow as bf
import keras
import numpy as np

def prior():
    # beta: regression coefficients (intercept, slope)
    beta = np.random.normal([2, 0], [3, 1])
    return dict(beta=beta)

def likelihood(beta):
    # x: predictor variable
    x = np.random.normal(0, 1, size=10)
    # y: response variable
    y = np.random.normal(beta[0] + beta[1] * x, size=10)
    return dict(y=y, x=x)

simulator = bf.simulators.make_simulator([prior, likelihood])
adapter = (
    bf.Adapter()
    .as_set(["x", "y"])
    .standardize()
    .concatenate(["beta"], into="inference_variables")
    .concatenate(["x", "y"], into="summary_variables")
)
inference_network = bf.networks.FlowMatching()
summary_network = bf.networks.DeepSet(depth=2)
approximator = bf.ContinuousApproximator(
    inference_network=inference_network,
    summary_network=summary_network,
    adapter=adapter,
)
epochs = 1
num_batches = 1
batch_size = 1
optimizer = keras.optimizers.Adam(learning_rate=5e-4, clipnorm=1.0)
approximator.compile(optimizer=optimizer)
history = approximator.fit(
    epochs=epochs,
    num_batches=num_batches,
    batch_size=batch_size,
    simulator=simulator,
)
approximator.save("approximator.keras")

Error Message:

NotImplementedError                       Traceback (most recent call last)
Cell In[12], line 1
----> 1 approximator.save("approximator.keras")

File ~/.conda/envs/sabi_env/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    119     filtered_tb = _process_traceback_frames(e.__traceback__)
    120     # To get the full stack trace, call:
    121     # `keras.config.disable_traceback_filtering()`
--> 122     raise e.with_traceback(filtered_tb) from None
    123 finally:
    124     del filtered_tb

File ~/.conda/envs/sabi_env/lib/python3.11/site-packages/bayesflow/approximators/continuous_approximator.py:127, in ContinuousApproximator.get_config(self)
    124 def get_config(self):
    125     base_config = super().get_config()
    126     config = {
--> 127         "adapter": serialize(self.adapter),
    128         "inference_network": serialize(self.inference_network),
    129         "summary_network": serialize(self.summary_network),
    130     }
    132     return base_config | config

File ~/.conda/envs/sabi_env/lib/python3.11/site-packages/bayesflow/adapters/adapter.py:55, in Adapter.get_config(self)
     54 def get_config(self) -> dict:
...
File ~/.conda/envs/sabi_env/lib/python3.11/site-packages/bayesflow/adapters/transforms/elementwise_transform.py:20, in ElementwiseTransform.get_config(self)
     19 def get_config(self) -> dict:
---> 20     raise NotImplementedError
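
Reading the traceback, the failure seems to come from a transform in the adapter that falls back to a base-class get_config which only raises NotImplementedError, so collecting the adapter's config aborts as soon as that transform is reached. The sketch below reproduces that pattern in plain Python. All class and function names in it (BaseTransform, SerializableTransform, UnserializableTransform, serialize_all, can_serialize) are hypothetical stand-ins for illustration and are not BayesFlow's actual code.

```python
class BaseTransform:
    """Hypothetical base class whose get_config is left unimplemented."""

    def get_config(self) -> dict:
        raise NotImplementedError


class SerializableTransform(BaseTransform):
    """Hypothetical transform that overrides get_config."""

    def __init__(self, factor: float):
        self.factor = factor

    def get_config(self) -> dict:
        return {"factor": self.factor}


class UnserializableTransform(BaseTransform):
    """Hypothetical transform that inherits the unimplemented get_config."""


def serialize_all(transforms) -> list:
    # Mimics a container that collects the config of every transform it holds.
    return [t.get_config() for t in transforms]


def can_serialize(transforms) -> bool:
    # True only if every transform in the chain provides a config.
    try:
        serialize_all(transforms)
        return True
    except NotImplementedError:
        return False
```

Under this assumption, a single transform without its own get_config is enough to make saving the whole model fail, even if every other component serializes fine.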
