
Offline dataset not possible with a meta function #363

@philippreiser


Hi!

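When using `bayesflow.OfflineDataset` with a simulator that includes a `meta_fn`, initialization fails with an `AttributeError`. `OfflineDataset.__init__` infers the number of samples from the first value in the data dictionary (`next(iter(data.values())).shape[0]`), so it assumes that value is an array with a `.shape` attribute. Values returned by the meta function can be plain Python scalars. In the example below, the first value is the meta output `N`, which is an `int`.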
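To illustrate the failure mode without BayesFlow, here is a minimal NumPy sketch. The dict only mimics the structure suggested by the traceback (a scalar `N` first, then a per-sample array `w`). `infer_num_samples` is a hypothetical helper, not part of BayesFlow; it shows one way the sample count could be inferred by skipping scalar entries instead of trusting the first value.

```python
import numpy as np

def infer_num_samples(data: dict) -> int:
    """Hypothetical helper (not a BayesFlow API).

    Returns the leading (sample) dimension shared by all entries that have
    at least one dimension. Scalar entries, such as ints produced by a
    meta function, are skipped instead of raising AttributeError.
    """
    sizes = {np.shape(v)[0] for v in data.values() if np.ndim(v) >= 1}
    if not sizes:
        raise ValueError("data contains no entry with a sample dimension")
    if len(sizes) > 1:
        raise ValueError(f"inconsistent sample dimensions: {sorted(sizes)}")
    return sizes.pop()

# Mimics the issue: the meta value N comes first and is a plain int
data = {"N": 11, "w": np.random.normal(1, 0.2, size=(8, 11))}

# The expression used in OfflineDataset.__init__ fails on the int
try:
    next(iter(data.values())).shape[0]
except AttributeError as err:
    print(err)  # 'int' object has no attribute 'shape'

print(infer_num_samples(data))  # 8
```

Requiring all array entries to agree on the leading dimension also catches malformed data early, rather than silently using whichever value happens to come first.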

Code:

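```python
import os
os.environ["KERAS_BACKEND"] = "jax"
os.environ["JAX_PLATFORMS"] = "cpu"
import bayesflow as bf
import numpy as np

# Adapter with a single to_array() transform
ad = bf.Adapter()
ad.to_array()

# Prior whose output length depends on the meta parameter N
def prior(N):
    return {"w": np.random.normal(1, 0.2, N)}

# Meta function: draws a single integer N for the whole batch
def meta(batch_size):
    return {"N": np.random.randint(5, 15)}

simulator = bf.make_simulator([prior], meta_fn=meta)

data = simulator.sample(8)
# Raises AttributeError: the meta output "N" is a plain int without .shape
dataset = bf.OfflineDataset(data, batch_size=2, adapter=ad)
dataset[0]
```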

Error Message:

```
Traceback (most recent call last):
  File "/data/homes/reiser/projects/bayesflow/offline_dataset.py", line 20, in <module>
    dataset = bf.OfflineDataset(data, batch_size=2, adapter=ad)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/homes/reiser/projects/bayesflow/bayesflow/datasets/offline_dataset.py", line 19, in __init__
    self.num_samples = next(iter(data.values())).shape[0]
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'int' object has no attribute 'shape'
```
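Until this is handled inside `OfflineDataset`, one possible user-side workaround is to repeat every scalar entry along a new leading axis before building the dataset. The sketch below is hypothetical: `broadcast_scalars` is not a BayesFlow function, and it only uses NumPy. It assumes the caller knows the number of samples (8 in the example above).

```python
import numpy as np

def broadcast_scalars(data: dict, num_samples: int) -> dict:
    """Hypothetical workaround helper (not a BayesFlow API).

    Repeats every scalar entry (e.g. an int produced by meta_fn) along a
    new leading axis of length num_samples, so that every value has a
    .shape whose first dimension is the number of samples. Entries that
    already have at least one dimension are returned unchanged.
    """
    return {
        key: np.full(num_samples, value) if np.ndim(value) == 0 else value
        for key, value in data.items()
    }

# Mimics the structure of simulator.sample(8) in the issue
data = {"N": 11, "w": np.random.normal(1, 0.2, size=(8, 11))}
fixed = broadcast_scalars(data, num_samples=8)

print(fixed["N"].shape)                     # (8,)
print(next(iter(fixed.values())).shape[0])  # 8
```

Assuming `OfflineDataset` slices every entry along its first axis, `fixed` could then be passed as `bf.OfflineDataset(fixed, batch_size=2, adapter=ad)`. Note that each batch would then carry `N` as a length-2 array instead of a scalar, which downstream code may not expect.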
