Hi!
When using bayesflow.OfflineDataset with a simulator that includes a meta_fn, an error occurs because the dataset initialization reads .shape from the first value in the data dictionary. Here that value is the plain integer N returned by meta_fn, which has no .shape attribute.
Code:
import os
os.environ["KERAS_BACKEND"] = "jax"
os.environ["JAX_PLATFORMS"] = "cpu"
import bayesflow as bf
import numpy as np
ad = bf.Adapter()
ad.to_array()
def prior(N):
    return {"w": np.random.normal(1, 0.2, N)}
def meta(batch_size):
    return {"N": np.random.randint(5, 15)}
simulator = bf.make_simulator([prior], meta_fn=meta)
data = simulator.sample(8)
dataset = bf.OfflineDataset(data, batch_size=2, adapter=ad)
dataset[0]
Error Message:
Traceback (most recent call last):
  File "/data/homes/reiser/projects/bayesflow/offline_dataset.py", line 20, in <module>
    dataset = bf.OfflineDataset(data, batch_size=2, adapter=ad)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/homes/reiser/projects/bayesflow/bayesflow/datasets/offline_dataset.py", line 19, in __init__
    self.num_samples = next(iter(data.values())).shape[0]
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'int' object has no attribute 'shape'
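One possible direction for a fix, as a minimal sketch only: instead of taking the first dictionary value, the dataset could look for the first value that actually has a non-empty shape. The helper name infer_num_samples is hypothetical and not part of bayesflow, and the data dictionary below is a hand-written stand-in for the simulator output, not what simulator.sample(8) necessarily returns.

```python
import numpy as np


def infer_num_samples(data: dict) -> int:
    """Hypothetical helper (not part of bayesflow).

    Returns the leading dimension of the first value that has a
    non-empty .shape, skipping plain Python scalars and 0-d arrays.
    """
    for value in data.values():
        shape = getattr(value, "shape", None)
        if shape is not None and len(shape) > 0:
            return int(shape[0])
    raise ValueError("Could not infer the number of samples: no array-valued entry found.")


# Stand-in for simulator output: the scalar meta value comes first.
data = {"N": 7, "w": np.zeros((8, 7))}
print(infer_num_samples(data))  # prints 8
```

This assumes that the first array-valued entry is batched along its first axis. A meta value that is itself a non-batched array would still be misread, so the real fix may need to treat meta keys explicitly.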