
Some inference nets do not allow inference_variables during sampling #328

@Kucharssim

Description

I am on this commit: 1dbaedd

It should be possible to easily obtain posterior samples using a dataset from a simulator, e.g.

dataset = simulator.sample(1000)
posterior_samples = amortizer.sample(num_samples=500, conditions=dataset)

Some inference nets, however, do not play nicely if inference_variables are present in the input dict during sampling. Specifically:

  • Flow Matching
TypeError: Exception encountered when calling FlowMatching.call().

rk45_step() got an unexpected keyword argument 'inference_variables'

This error presumably does not depend on which integrator is selected (I only tried euler as an alternative).

  • Consistency Model
ValueError: In a nested call() argument, you cannot mix tensors and non-tensors. Received invalid mixed argument: kwargs={'density': False, 'inference_variables': Array([[-0.19996756],
       [ 0.17675304],
       [ 0.27663922],
       [-1.2399167 ],
       [ 1.1492295 ],
       [-1.2453966 ],
       [ 0.35556716],
       [-1.3496909 ],
       [ 0.14714664],
       [ 0.46096817]], dtype=float32)}

Both networks run fine if the inference variables are removed from the conditions (see the workaround sketch below).

The affine and spline coupling networks run fine either way.
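
As a stopgap, one can simply drop the inference variable from the conditions dict before sampling. A minimal sketch, assuming the variable names and objects from the reproduction script below ("mu" is the inference variable, dataset and workflow are defined there):

# Workaround sketch: strip the inference variable "mu" from the conditions
# before calling sample(), so only "x" is passed as a condition.
conditions = {key: value for key, value in dataset.items() if key != "mu"}
posterior_samples = workflow.sample(num_samples=100, conditions=conditions)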

Full code:

import numpy as np
import bayesflow as bf


def prior():
    mu = np.random.normal(loc=0, scale=1)
    return dict(mu=mu)

def likelihood(mu):
    x = np.random.normal(loc=mu, scale=1)
    return dict(x=x)

simulator = bf.make_simulator([prior, likelihood])


workflow = bf.BasicWorkflow(
    simulator=simulator,
    # swap in whichever network you want to test
    inference_network=bf.networks.FlowMatching(),
    # inference_network=bf.networks.CouplingFlow(),
    # inference_network=bf.networks.CouplingFlow(transform="spline"),
    # inference_network=bf.networks.ConsistencyModel(total_steps=100),
    inference_variables="mu",
    inference_conditions="x",
)

history = workflow.fit_online(epochs=1)

dataset = simulator.sample(10)

# uncomment the next line to make FlowMatching and ConsistencyModel run
# dataset = dict(x=dataset["x"])

posterior_samples = workflow.sample(num_samples=100, conditions=dataset)
