Closed
Labels
serialization: Issues related to saving and loading (a.k.a. serialization).
Description
Describe the bug
The serialization of the broadcast transformation in the adapter does not work. After saving and loading a model, the broadcasting fails. With some debugging, I found that the problem is that self.expand is loaded as a list instead of a tuple. This error might also show up somewhere else, but I am not sure.
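The tuple-to-list change can be illustrated without BayesFlow. The sketch below assumes the transform's config passes through JSON when the model is saved (JSON has no tuple type), and uses a hypothetical `BroadcastLike` class, not BayesFlow's actual transform, to show one possible way to normalize the value on construction:

```python
import json


class BroadcastLike:
    """Hypothetical stand-in for a transform with an `expand` argument.

    This is NOT BayesFlow's class; it only mimics a get_config/from_config pair.
    """

    def __init__(self, expand):
        # A deserialized config delivers a list, so coerce it back to a tuple.
        if isinstance(expand, list):
            expand = tuple(expand)
        self.expand = expand

    def get_config(self):
        return {"expand": self.expand}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


original = BroadcastLike(expand=(2, 3))

# Assumed save/load path: the config is written to and read from JSON.
raw = json.loads(json.dumps(original.get_config()))
restored = BroadcastLike.from_config(raw)

print(type(raw["expand"]).__name__)       # the raw value after the round trip
print(type(restored.expand).__name__)     # the value after coercion in __init__
```

Coercing in the constructor rather than in `from_config` means the attribute has the same type regardless of how the object was created, so any later `isinstance(self.expand, tuple)` check would behave the same before and after loading.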
To Reproduce
Minimal steps to reproduce the behavior:
import os
os.environ['KERAS_BACKEND'] = 'tensorflow'
import numpy as np
import keras
import bayesflow as bf
#%%
def prior():
    return dict(alpha=np.random.uniform(0, 1), beta=np.random.uniform(0, 1))

def simulator(alpha, beta):
    # Simulate some data based on the alpha parameter
    data = np.random.normal(loc=alpha, scale=beta, size=100)
    inference_conditions = np.random.normal(loc=0, scale=1, size=(100, 1, 1))
    return dict(data=data, inference_conditions=inference_conditions)

adapter = (
    bf.adapters.Adapter()
    .to_array()
    .broadcast("data", to="inference_conditions", expand=(2, 3))
    .squeeze(keys="inference_conditions", axis=(2, 3))
    .concatenate(['alpha', 'beta'], into='inference_variables')
)
bf_sim = bf.make_simulator([prior, simulator])
#%%
print(adapter.forward(bf_sim.sample(5))['inference_variables'].shape, adapter.forward(bf_sim.sample(5))['inference_conditions'].shape)
#%%
print(adapter.forward(bf_sim.sample(5))['data'].shape)
#%%
workflow = bf.BasicWorkflow(
    simulator=bf_sim,
    adapter=adapter,
    inference_network=bf.networks.CouplingFlow()
)
#%%
history = workflow.fit_online(
    epochs=1,
    batch_size=32
)
workflow.approximator.save(filepath='test.keras')
#%%
workflow.plot_default_diagnostics(test_data=bf_sim.sample(100), num_samples=100)
#%%
workflow.approximator = keras.saving.load_model(filepath='test.keras')
#%%
workflow.plot_default_diagnostics(test_data=bf_sim.sample(100), num_samples=100)
Expected behavior
Saving and loading the adapter should not be a problem. If a transform worked before saving, it should also work after loading.
Traceback
ValueError: operands could not be broadcast together with remapped shapes [original->remapped]: (100,100) and requested shape (100,100,1,100)
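The shapes in this message are consistent with the expansion step being skipped before broadcasting. The NumPy-only sketch below reproduces an error of the same form. It assumes the failure reduces to a `np.broadcast_to` call on the unexpanded array; that is a guess about what happens inside the adapter, not something confirmed by the traceback.

```python
import numpy as np

data = np.zeros((100, 100))      # batch of 100 data sets with 100 observations each
target_shape = (100, 100, 1, 100)  # the requested shape from the traceback

# Without the two trailing axes, the shapes do not line up.
try:
    np.broadcast_to(data, target_shape)
    failed = False
except ValueError as err:
    failed = True
    print(err)

# With expand=(2, 3) applied first, the same broadcast is valid.
expanded = np.expand_dims(data, axis=(2, 3))
result = np.broadcast_to(expanded, target_shape)
print(expanded.shape, result.shape)
```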
Environment
- OS: MacOS
- Python Version: 3.11
- Backend: fails for all of them
- BayesFlow Version: 2.0.4 (dev branch from today)
vpratz