Why is local_energy not jit'ed? #21
connection-on-fiber-bundles opened this issue:
Hi there. Got a quick question on the JAX implementation: why is local_energy not jit'ed? To be more specific, I mean the local energy defined in
https://github.com/deepmind/ferminet/blob/bf0d06eb05e3a17063551e8573a129568e99beac/ferminet/train.py#L111-L112
In fact, jit does not show up in train.py at all.
Note that in the tests, local_energy is indeed jit'ed before comparison. See
https://github.com/deepmind/ferminet/blob/bf0d06eb05e3a17063551e8573a129568e99beac/ferminet/tests/hamiltonian_test.py#L147-L149
This could be a JAX 101 question (sorry in advance!), but it would be very helpful if you could share the rationale for not jit'ing here and/or some related performance tips. Thanks!
BTW, I am quite interested in how the Laplacian can be calculated in JAX, and I measured the performance of local_kinetic_energy
(https://github.com/deepmind/ferminet/blob/bf0d06eb05e3a17063551e8573a129568e99beac/ferminet/hamiltonian.py#L23).
I did find that jit can significantly improve the performance (a timing sketch follows below), which makes me wonder why you didn't jit local_energy. A related implementation-detail question:
1. Why use a fori_loop to add up the second-order derivatives instead of first computing the Hessian, for instance via jit(jacfwd(jacrev(fun))), and then summing its diagonal and the square of the gradient? Is the concern more about speed or memory consumption?
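A sketch of that kind of micro-benchmark (f below is a toy stand-in for local_kinetic_energy, not the real FermiNet function; block_until_ready is needed because JAX dispatches work asynchronously):

```python
import time

import jax
import jax.numpy as jnp

# Toy stand-in for a per-walker quantity like local_kinetic_energy.
def f(x):
    return jnp.sum(jnp.sin(x) * jnp.cos(x), axis=-1)

f_jit = jax.jit(f)
x = jnp.ones((4096, 12))
f_jit(x).block_until_ready()  # warm-up call, so compilation time is excluded

for name, fn in (("eager", f), ("jit", f_jit)):
    start = time.perf_counter()
    for _ in range(100):
        fn(x).block_until_ready()  # wait for the async dispatch to finish
    print(name, (time.perf_counter() - start) / 100)
```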
Reply:
It's because pmap automatically calls jit. If you are calling any of these functions outside of pmap, you should jit them manually, but it is not needed in the main training loop thanks to pmap.
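In other words, the per-device function passed to pmap is already XLA-compiled, so an inner jit would be redundant. A minimal sketch (the local_energy below is a toy stand-in, not the actual FermiNet function):

```python
import jax
import jax.numpy as jnp

# Toy stand-in for a per-walker quantity such as local_energy.
def local_energy(x):
    return jnp.sum(x ** 2, axis=-1)

# In the training loop: pmap compiles its argument with XLA (like jit)
# and maps it over the leading (device) axis, so no explicit jit is needed.
pmapped_energy = jax.pmap(local_energy)

# Outside pmap (tests, one-off evaluation): jit manually, otherwise the
# function runs op-by-op without XLA fusion.
jitted_energy = jax.jit(local_energy)

n_devices = jax.local_device_count()
walkers = jnp.ones((n_devices, 16, 12))  # [device, batch, coords]
print(pmapped_energy(walkers).shape)     # (n_devices, 16)
print(jitted_energy(walkers[0]).shape)   # (16,)
```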
The kinetic energy is largely a design choice motivated by memory consumption. The tests actually include a Hessian-based version for comparison.
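Concretely, the tradeoff looks like the sketch below (log_psi is a toy stand-in for the network's log-magnitude output; this illustrates the two strategies, not the actual FermiNet code). The loop form extracts one diagonal Hessian entry per step from a Hessian-vector product and never materializes the full n x n matrix; the jacfwd(jacrev(...)) form does.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Toy stand-in for log|psi(x)|: the real network maps electron
# coordinates to a scalar log-magnitude.
def log_psi(x):
    return jnp.sum(jnp.sin(x))

def kinetic_loop(x):
    """-(1/2) * (laplacian(log_psi) + |grad(log_psi)|^2), loop version."""
    n = x.shape[0]
    grad_f = jax.grad(log_psi)
    eye = jnp.eye(n)

    def body(i, accum):
        # Forward-over-reverse: a jvp of the gradient along e_i gives the
        # i-th Hessian column; keep only its diagonal entry.
        _, hess_col = jax.jvp(grad_f, (x,), (eye[i],))
        return accum + hess_col[i]

    laplacian = lax.fori_loop(0, n, body, jnp.zeros(()))
    return -0.5 * (laplacian + jnp.sum(grad_f(x) ** 2))

def kinetic_hessian(x):
    """Same quantity via the full Hessian: simpler, but O(n^2) memory."""
    hessian = jax.jacfwd(jax.jacrev(log_psi))(x)  # materializes n x n
    gradient = jax.grad(log_psi)(x)
    return -0.5 * (jnp.trace(hessian) + jnp.sum(gradient ** 2))

x = jnp.arange(6.0)
print(jax.jit(kinetic_loop)(x))     # matches the Hessian version
print(jax.jit(kinetic_hessian)(x))
```

Here n is 3 times the number of electrons, and since the local energy is evaluated over a whole batch of walkers at once, the O(n^2) Hessian buffer is multiplied by the batch size; that is the memory pressure the fori_loop version avoids.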
Got it! Thanks for the awesome explanation!