Description
Varying simulator output sizes are a common occurrence when the number of samples varies between calls to simulator.sample():
```python
import numpy as np
import bayesflow as bf

def context(batch_size):
    n = np.random.randint(10, 101)
    return dict(n=n)

def prior():
    mu = np.random.normal()
    sigma = np.random.gamma(shape=2)
    return dict(mu=mu, sigma=sigma)

def likelihood(n, mu, sigma):
    y = np.random.normal(mu, sigma, size=n)
    return dict(y=y)

simulator = bf.make_simulator([prior, likelihood], meta_fn=context)
```

However, these varying sizes can trigger excessive compile times on the JAX backend, where each distinct value of n triggers a recompilation. For a wide range of n, compilation can dominate the training time.
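A minimal sketch of why this happens, assuming some jitted function downstream consumes the simulated y (the `summarize` function below is hypothetical, not part of the BayesFlow API): JAX traces and compiles once per distinct input shape, so every new n produces a fresh compilation.

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def summarize(y):
    # stand-in for any jitted computation over the simulated data
    return jnp.mean(y), jnp.std(y)

for n in (10, 57, 101):
    y = np.random.normal(size=n)
    summarize(jnp.asarray(y))  # each new shape of y triggers a new trace + compile

# With padding to a fixed maximum size, the shape stays (100,)
# and compilation happens only once.
```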
The current best-practice fix for users is to use padded tensors:
```python
def likelihood(n, mu, sigma):
    y = np.random.normal(mu, sigma, size=100)  # use a fixed maximum size
    y[n:] = 0  # set unused entries to zero, or some other placeholder value
    return dict(y=y)
```

When we detect that compile times dominate, we should output a warning to the user with a suggested fix. We could also improve support for padded simulator output in general. Further, we could look into whether there are better ways to mask out unused values than just setting them to placeholder values as above.