Add warning for varying simulator output sizes #370

@LarsKue

Description

Varying simulator output sizes commonly occur when the number of samples differs between calls to simulator.sample():

import numpy as np
import bayesflow as bf

def context(batch_size):
    n = np.random.randint(10, 101)
    return dict(n=n)

def prior():
    mu = np.random.normal()
    sigma = np.random.gamma(shape=2)
    return dict(mu=mu, sigma=sigma)

def likelihood(n, mu, sigma):
    y = np.random.normal(mu, sigma, size=n)
    return dict(y=y)

simulator = bf.make_simulator([prior, likelihood], meta_fn=context)

However, varying output sizes can trigger excessive compile times in JAX, where each distinct value of n triggers a recompilation. For a wide range of n, compilation can dominate the training time.
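A minimal sketch of the recompilation behavior (not from the issue, just an illustration): the Python body of a jitted function runs only while tracing, so a Python-side counter increments once per distinct input shape.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def summarize(y):
    # This Python code executes only during tracing,
    # which happens once per distinct input shape.
    global trace_count
    trace_count += 1
    return jnp.mean(y)

for n in (10, 50, 100, 50):  # 50 appears twice, but is only traced once
    summarize(jnp.zeros(n))

# three distinct shapes -> three traces/compilations
```

With many distinct values of n, nearly every batch would pay the compilation cost.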

The current best-practice fix for users is to use padded tensors:

def likelihood(n, mu, sigma):
    y = np.random.normal(mu, sigma, size=100)  # uses fixed maximum size
    y[n:] = 0  # set unused entries to zero, or some other placeholder value
    return dict(y=y)

When we detect that compile times dominate, we should emit a warning to the user with a suggested fix. We could also improve support for padded simulator output in general. Further, we could investigate whether there are better ways to mask out unused values than setting them to placeholder values as above.
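One possible mask-based variant (a sketch, not part of the issue; the max_n parameter and the returned mask key are assumptions): instead of only overwriting unused entries with a placeholder, the simulator could also return an explicit validity mask that downstream networks can use to ignore padded positions.

```python
import numpy as np

def likelihood(n, mu, sigma, max_n=100):
    # Always sample at the fixed maximum size so every output has the same shape.
    y = np.random.normal(mu, sigma, size=max_n)
    mask = np.arange(max_n) < n   # True for valid entries, False for padding
    y = np.where(mask, y, 0.0)    # zero out padded positions
    return dict(y=y, mask=mask.astype(np.float32))
```

A summary network could then, for example, multiply attention scores or per-sample embeddings by the mask, so the placeholder values never influence the result.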

Metadata
Assignees

No one assigned

    Labels

    efficiency (Some code needs to be optimized), user interface (Changes to the user interface and improvements in usability)
