Why local_energy is not jit'ed? #21

Closed
connection-on-fiber-bundles opened this issue Mar 1, 2021 · 3 comments

@connection-on-fiber-bundles

Hi there. Got a quick question on the JAX implementation:

Why isn't local_energy jit'ed? More specifically, I mean the local energy defined in https://github.com/deepmind/ferminet/blob/bf0d06eb05e3a17063551e8573a129568e99beac/ferminet/train.py#L111-L112

In fact, jit does not appear anywhere in train.py.

Note that in the tests, local_energy is indeed jit'ed before comparison. See https://github.com/deepmind/ferminet/blob/bf0d06eb05e3a17063551e8573a129568e99beac/ferminet/tests/hamiltonian_test.py#L147-L149

This may be a JAX 101 question (sorry in advance!). It would be very helpful if you could share the rationale for not jitting here and/or some related performance tips. Thanks!

BTW, I am quite interested in how the Laplacian can be calculated in JAX, and I measured the performance of local_kinetic_energy. I found that jitting significantly improves performance (which makes me wonder why you didn't jit local_energy). A related implementation-detail question:

  1. Why use a fori_loop to add up the second-order derivatives instead of first computing the Hessian, for instance with jit(jacfwd(jacrev(fun))), and then summing its diagonal and the squared gradient? Is the concern more about speed or memory consumption?
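For reference, here is a minimal sketch of the Hessian-based alternative the question describes. It assumes a hypothetical toy log-amplitude `log_psi` (an isotropic Gaussian, not FermiNet's network) and uses the identity (∇²ψ)/ψ = ∇²log|ψ| + |∇log|ψ||², so the local kinetic energy is -1/2 (tr H + |∇f|²) with f = log|ψ|.

```python
import jax
import jax.numpy as jnp


def log_psi(x):
    # Hypothetical toy log-amplitude: isotropic Gaussian, log|psi| = -|x|^2 / 2.
    return -0.5 * jnp.sum(x ** 2)


def make_kinetic_energy_hessian(f):
    """Local kinetic energy -1/2 (lap psi)/psi, using the full Hessian of f = log|psi|."""
    grad_f = jax.grad(f)
    hess_f = jax.jacfwd(jax.jacrev(f))  # materializes the full n x n Hessian

    def ke(x):
        lap_f = jnp.trace(hess_f(x))  # sum of the Hessian diagonal
        return -0.5 * (lap_f + jnp.sum(grad_f(x) ** 2))

    return jax.jit(ke)


ke_hessian = make_kinetic_energy_hessian(log_psi)
print(ke_hessian(jnp.array([1.0, 0.0, 0.0])))  # -0.5 * (-3 + 1) = 1.0
```

This is simple to write, but for n coordinates it builds an n × n matrix only to read its diagonal.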
@dpfau
Collaborator

dpfau commented Mar 1, 2021 via email

@jsspencer
Collaborator

pmap also triggers an XLA compilation (see e.g. https://jax.readthedocs.io/en/latest/jax.html?highlight=pmap#jax.pmap and google/jax#5307). vmap does not trigger a compilation, which is why the tests use an explicit jit call.
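One way to see the difference is to count how often JAX traces a function, since a compiled (jit or pmap) function is traced once and then reused, while a bare vmap is re-traced and executed op by op on every call. The sketch below uses a hypothetical `toy_energy` standing in for a per-walker local energy and a Python-side counter that only increments at trace time.

```python
import jax
import jax.numpy as jnp

TRACE_COUNT = {"n": 0}


def toy_energy(x):
    # Hypothetical stand-in for a per-walker local energy. The Python counter
    # only runs while JAX traces this function, not when compiled code executes.
    TRACE_COUNT["n"] += 1
    return jnp.sum(x ** 2)


def count_traces(fn, x, calls=3):
    """Call fn repeatedly and report how many times toy_energy was traced."""
    TRACE_COUNT["n"] = 0
    for _ in range(calls):
        fn(x).block_until_ready()
    return TRACE_COUNT["n"]


walkers = jnp.ones((4, 3))  # 4 walkers, 3 coordinates each
sharded = jnp.ones((jax.local_device_count(), 4, 3))  # leading axis = devices

counts = {
    # vmap alone: re-traced (and run op by op) on every call.
    "vmap only": count_traces(jax.vmap(toy_energy), walkers),
    # Explicit jit, as in the tests: traced once, compiled executable reused.
    "jit(vmap)": count_traces(jax.jit(jax.vmap(toy_energy)), walkers),
    # pmap compiles with XLA itself, so no extra jit is needed.
    "pmap(vmap)": count_traces(jax.pmap(jax.vmap(toy_energy)), sharded),
}
print(counts)  # {'vmap only': 3, 'jit(vmap)': 1, 'pmap(vmap)': 1}
```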

The kinetic energy implementation (a fori_loop over second derivatives rather than a full Hessian) is largely a design choice motivated by memory consumption. The tests actually include a Hessian-based version for comparison.
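A minimal sketch of the loop-based approach, assuming a hypothetical toy `log_psi` (not FermiNet's network). It uses `jax.linearize` on the gradient to get Hessian-vector products and `lax.fori_loop` to accumulate the diagonal one entry at a time, so only a single Hessian column exists at any point. This illustrates the idea and is not the repository's exact code.

```python
import jax
import jax.numpy as jnp
from jax import lax


def log_psi(x):
    # Hypothetical toy log-amplitude (not FermiNet's network).
    return jnp.sum(jnp.log(jnp.cosh(x))) - 0.25 * jnp.sum(x ** 4)


def make_kinetic_energy_loop(f):
    """Local kinetic energy -1/2 (lap f + |grad f|^2) without forming the Hessian."""
    grad_f = jax.grad(f)

    def ke(x):
        n = x.shape[0]
        # linearize evaluates grad_f(x) once and returns a function v -> H @ v
        # (forward-over-reverse Hessian-vector product).
        grad_x, hvp = jax.linearize(grad_f, x)

        def body(i, acc):
            e_i = jnp.zeros_like(x).at[i].set(1.0)  # i-th basis vector
            return acc + hvp(e_i)[i]                # add H[i, i]; one column live at a time

        lap_f = lax.fori_loop(0, n, body, jnp.zeros((), x.dtype))
        return -0.5 * (lap_f + jnp.sum(grad_x ** 2))

    return jax.jit(ke)


ke_loop = make_kinetic_energy_loop(log_psi)
x = jnp.array([0.3, -1.2, 0.7])
print(ke_loop(x))
```

The trade-off is n sequential Hessian-vector products instead of one large Hessian evaluation: peak memory scales with a single column rather than the full n × n matrix.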

@connection-on-fiber-bundles
Author

Got it! Thanks for the great explanations! You guys are awesome!
