
Compiling local sampler is slow #30

Closed
kazewong opened this issue Aug 17, 2022 · 2 comments · Fixed by #33

Comments

@kazewong
Owner

Compiling the sampler for a complicated likelihood function seems pretty slow (I'm testing on a gravitational-wave example now).

I think this is related to defining the main loop inside the sampler.

I'm experimenting with abstracting that out to improve performance.

@kazewong
Owner Author

@dfm mentioned that replacing fori_loop with scan should help reduce the compilation time.
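A sketch of what that swap might look like, using a hypothetical toy `step` (a Gaussian random-walk move standing in for the real sampler kernel). Both loops apply the same steps with the same keys. The `scan` version also returns every intermediate state, so the chain can be collected without a preallocated buffer. JAX's documentation notes that `fori_loop` with a static trip count is already implemented in terms of `scan`, so any compile-time gain is worth measuring rather than assuming.

```python
from functools import partial

import jax
import jax.numpy as jnp


def step(position, key):
    # Hypothetical toy kernel: Gaussian random-walk move (stand-in for a MALA step)
    return position + 0.1 * jax.random.normal(key, shape=position.shape)


@partial(jax.jit, static_argnames="n_steps")
def run_fori(position, rng_key, n_steps):
    keys = jax.random.split(rng_key, n_steps)

    def body(i, pos):
        # fori_loop only carries the state; intermediate positions are discarded
        return step(pos, keys[i])

    return jax.lax.fori_loop(0, n_steps, body, position)


@partial(jax.jit, static_argnames="n_steps")
def run_scan(position, rng_key, n_steps):
    keys = jax.random.split(rng_key, n_steps)

    def body(pos, key):
        new_pos = step(pos, key)
        return new_pos, new_pos  # (carry, per-step output stacked into the chain)

    final, chain = jax.lax.scan(body, position, keys)
    return final, chain


key = jax.random.PRNGKey(0)
x0 = jnp.zeros(3)
final_fori = run_fori(x0, key, n_steps=100)
final_scan, chain = run_scan(x0, key, n_steps=100)
print(chain.shape)  # (100, 3)
```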

@kazewong
Owner Author

kazewong commented Aug 25, 2022

I did some more testing on MALA.py compilation time. I think the main compilation overhead comes from JAX compiling the derivative of the log likelihood multiple times.

As an example, here is the MALA kernel for one proposal:


```python
    ...

    key1, key2 = jax.random.split(rng_key)
    proposal = position + dt * d_logpdf(position)
    proposal += jnp.sqrt(2 * dt) * jax.random.normal(key1, shape=position.shape)
    ratio = logpdf(proposal) - logpdf(position)
    ratio -= ((position - proposal - dt * d_logpdf(proposal)) ** 2 / (4 * dt)).sum()
    ratio += ((proposal - position - dt * d_logpdf(position)) ** 2 / (4 * dt)).sum()
    proposal_log_prob = logpdf(proposal)

    log_uniform = jnp.log(jax.random.uniform(key2))
    do_accept = log_uniform < ratio

    position = jnp.where(do_accept, proposal, position)
    log_prob = jnp.where(do_accept, proposal_log_prob, log_prob)
    return position, log_prob, do_accept
```
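The kernel above calls `d_logpdf` three times (twice at `position`, once at `proposal`). One way to probe the "compiled multiple times" hypothesis is to count how often the log-density body is traced. The sketch below assumes a toy standard-normal `logpdf` and a hypothetical `trace_count` dictionary. Python side effects run only while JAX traces, and a plain `jax.grad` function re-traces its input on every call, so the counter ends at 3 here: one trace of `logpdf` per `d_logpdf` call.

```python
import jax
import jax.numpy as jnp

# Hypothetical counter: Python side effects run only while JAX traces a function.
trace_count = {"logpdf": 0}


def logpdf(x):
    trace_count["logpdf"] += 1
    # Toy standard-normal log density (stand-in for an expensive likelihood)
    return -0.5 * jnp.sum(x ** 2)


d_logpdf = jax.grad(logpdf)


def three_gradients(position, proposal):
    # Same call pattern as the MALA kernel: d_logpdf at position, proposal, position
    return d_logpdf(position) + d_logpdf(proposal) + d_logpdf(position)


jaxpr = jax.make_jaxpr(three_gradients)(jnp.ones(3), jnp.zeros(3))
print("logpdf traced", trace_count["logpdf"], "times")
```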

I tried this example with a gravitational-wave likelihood, where one compilation of d_logpdf takes around 70 s.
If I jit the entire kernel, the compilation time is around 300 s.
This seems to indicate that JAX does not reuse the cached compilation of d_logpdf when jitting the kernel. Even assuming `proposal` and `position` somehow trigger separate recompilations, that would account for only two compilations (about 140 s), so the total compilation time should still be smaller.
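To separate compilation time from run time when comparing d_logpdf against a function that calls it several times, JAX's ahead-of-time path `jax.jit(f).lower(*args).compile()` can be timed directly. This is a minimal sketch with a toy `logpdf` and a hypothetical `time_compile` helper. The numbers it prints will differ from the 70 s and 300 s reported for the real likelihood.

```python
import time

import jax
import jax.numpy as jnp


def logpdf(x):
    # Toy log density; substitute the real likelihood to reproduce the issue
    return -0.5 * jnp.sum(x ** 2)


d_logpdf = jax.grad(logpdf)


def three_gradients(position, proposal):
    # Mimics the kernel calling d_logpdf three times
    return d_logpdf(position) + d_logpdf(proposal) + d_logpdf(position)


def time_compile(fn, *args):
    """Lower and compile fn for the given arguments; return (compiled, seconds)."""
    start = time.perf_counter()
    compiled = jax.jit(fn).lower(*args).compile()
    return compiled, time.perf_counter() - start


x = jnp.ones(3)
compiled_grad, t_grad = time_compile(d_logpdf, x)
compiled_three, t_three = time_compile(three_gradients, x, x)
print(f"d_logpdf: {t_grad:.3f} s, three calls: {t_three:.3f} s")
```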

Another possibility is that JAX inlines the entire mala_kernel computation graph without recognizing that logpdf and d_logpdf are called multiple times and could therefore share the same graph.
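If repeated evaluations are part of the cost, one option is to make each density and gradient evaluation happen exactly once per step. `jax.value_and_grad` returns both from a single traced call, and carrying the current gradient in the sampler state avoids re-evaluating it at `position`. This is a sketch under assumptions (a toy standard-normal target and a hypothetical extra `grad` field in the state). It is not necessarily what the fix in #33 does. The acceptance ratio keeps the same terms as the kernel above.

```python
import jax
import jax.numpy as jnp


def logpdf(x):
    # Toy standard-normal target (stand-in for the gravitational-wave likelihood)
    return -0.5 * jnp.sum(x ** 2)


# One traced call returns both the log density and its gradient
logpdf_and_grad = jax.value_and_grad(logpdf)


@jax.jit
def mala_kernel(rng_key, position, log_prob, grad, dt):
    key1, key2 = jax.random.split(rng_key)
    noise = jax.random.normal(key1, shape=position.shape)
    proposal = position + dt * grad + jnp.sqrt(2 * dt) * noise
    # The only density/gradient evaluation in the kernel
    proposal_log_prob, proposal_grad = logpdf_and_grad(proposal)

    # log q(position | proposal) and log q(proposal | position)
    log_q_reverse = -jnp.sum((position - proposal - dt * proposal_grad) ** 2) / (4 * dt)
    log_q_forward = -jnp.sum((proposal - position - dt * grad) ** 2) / (4 * dt)
    ratio = proposal_log_prob - log_prob + log_q_reverse - log_q_forward

    do_accept = jnp.log(jax.random.uniform(key2)) < ratio
    position = jnp.where(do_accept, proposal, position)
    log_prob = jnp.where(do_accept, proposal_log_prob, log_prob)
    grad = jnp.where(do_accept, proposal_grad, grad)
    return position, log_prob, grad, do_accept


key = jax.random.PRNGKey(0)
position = jnp.zeros(2)
log_prob, grad = logpdf_and_grad(position)  # initial state, evaluated once
for step_key in jax.random.split(key, 100):
    position, log_prob, grad, accepted = mala_kernel(step_key, position, log_prob, grad, 0.1)
```

Carrying `grad` costs one extra array in the state but removes the second `d_logpdf(position)` call. Whether this shrinks compile time for the real likelihood would need to be checked with a timing comparison.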

@kazewong kazewong linked a pull request Aug 29, 2022 that will close this issue