Closed
Description
Hi!
I am trying to save a trained ContinuousApproximator model on the dev branch using approximator.save("approximator.keras"), but I am unsure if this is the correct approach. Here's a minimal example based on the Linear Regression example notebook:
import os

if "KERAS_BACKEND" not in os.environ:
    # set this to "torch", "tensorflow", or "jax"
    os.environ["KERAS_BACKEND"] = "torch"

import bayesflow as bf
import keras
import numpy as np


def prior():
    # beta: regression coefficients (intercept, slope)
    beta = np.random.normal([2, 0], [3, 1])
    return dict(beta=beta)


def likelihood(beta):
    # x: predictor variable
    x = np.random.normal(0, 1, size=10)
    # y: response variable
    y = np.random.normal(beta[0] + beta[1] * x, size=10)
    return dict(y=y, x=x)


simulator = bf.simulators.make_simulator([prior, likelihood])

adapter = (
    bf.Adapter()
    .as_set(["x", "y"])
    .standardize()
    .concatenate(["beta"], into="inference_variables")
    .concatenate(["x", "y"], into="summary_variables")
)

inference_network = bf.networks.FlowMatching()
summary_network = bf.networks.DeepSet(depth=2)

approximator = bf.ContinuousApproximator(
    inference_network=inference_network,
    summary_network=summary_network,
    adapter=adapter,
)

epochs = 1
num_batches = 1
batch_size = 1

optimizer = keras.optimizers.Adam(learning_rate=5e-4, clipnorm=1.0)
approximator.compile(optimizer=optimizer)

history = approximator.fit(
    epochs=epochs,
    num_batches=num_batches,
    batch_size=batch_size,
    simulator=simulator,
)
approximator.save("approximator.keras")

Error Message:
NotImplementedError Traceback (most recent call last)
Cell In[12], line 1
----> 1 approximator.save("approximator.keras")
File ~/.conda/envs/sabi_env/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
119 filtered_tb = _process_traceback_frames(e.__traceback__)
120 # To get the full stack trace, call:
121 # `keras.config.disable_traceback_filtering()`
--> 122 raise e.with_traceback(filtered_tb) from None
123 finally:
124 del filtered_tb
File ~/.conda/envs/sabi_env/lib/python3.11/site-packages/bayesflow/approximators/continuous_approximator.py:127, in ContinuousApproximator.get_config(self)
124 def get_config(self):
125 base_config = super().get_config()
126 config = {
--> 127 "adapter": serialize(self.adapter),
128 "inference_network": serialize(self.inference_network),
129 "summary_network": serialize(self.summary_network),
130 }
132 return base_config | config
File ~/.conda/envs/sabi_env/lib/python3.11/site-packages/bayesflow/adapters/adapter.py:55, in Adapter.get_config(self)
54 def get_config(self) -> dict:
...
File ~/.conda/envs/sabi_env/lib/python3.11/site-packages/bayesflow/adapters/transforms/elementwise_transform.py:20, in ElementwiseTransform.get_config(self)
19 def get_config(self) -> dict:
---> 20 raise NotImplementedError
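Reading the traceback, the failure comes from nested get_config calls: ContinuousApproximator.get_config serializes the adapter, and the chain continues through Adapter.get_config (intermediate frames elided above) down to the base ElementwiseTransform.get_config, which raises NotImplementedError. The pure-Python sketch below mimics that pattern to show how a single transform that does not override get_config is enough to make saving the whole model fail. It is only an illustration under that assumption: BaseTransform, MissingConfigTransform, ConcatTransform, ToyAdapter and try_serialize are hypothetical names, not BayesFlow's real classes or code.

```python
# Hypothetical, simplified stand-ins for the classes in the traceback.
# They are NOT BayesFlow's real implementations.


class BaseTransform:
    """Base transform: subclasses must override get_config to be saveable."""

    def get_config(self) -> dict:
        # Mirrors the base-class behaviour seen in the traceback.
        raise NotImplementedError


class MissingConfigTransform(BaseTransform):
    """A transform that does not override get_config (the failing case)."""


class ConcatTransform(BaseTransform):
    """A transform that implements get_config."""

    def __init__(self, keys, into):
        self.keys = list(keys)
        self.into = into

    def get_config(self) -> dict:
        return {"keys": self.keys, "into": self.into}


class ToyAdapter:
    """Container whose config is built from the configs of its transforms."""

    def __init__(self, transforms):
        self.transforms = transforms

    def get_config(self) -> dict:
        # Serializing the container serializes every child transform,
        # so one missing override propagates up to the caller (e.g. save()).
        return {
            "transforms": [
                {"class_name": type(t).__name__, "config": t.get_config()}
                for t in self.transforms
            ]
        }


def try_serialize(adapter: ToyAdapter) -> str:
    """Return "ok" if the adapter serializes, else the exception name."""
    try:
        adapter.get_config()
        return "ok"
    except NotImplementedError:
        return "NotImplementedError"


good = ToyAdapter([ConcatTransform(["beta"], into="inference_variables")])
bad = ToyAdapter(
    [MissingConfigTransform(), ConcatTransform(["x", "y"], into="summary_variables")]
)

print(try_serialize(good))  # ok
print(try_serialize(bad))   # NotImplementedError
```

Presumably, saving to the .keras format only works once every transform used by the adapter overrides get_config (plus a matching from_config for loading). The truncated traceback does not show which transform in this adapter is missing it.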