KFAC Open sourcing #24

Closed
n-gao opened this issue Mar 19, 2021 · 8 comments

@n-gao
Contributor

n-gao commented Mar 19, 2021

Hi,

I was wondering whether you are planning on releasing the KFAC optimizer used in both papers as well?
I know that the TensorFlow version is available on GitHub. Is the JAX version also going open-source?

Thank you!

@jsspencer
Collaborator

Yes, we hope to release a research-level preview of KFAC soon!

@n-gao
Contributor Author

n-gao commented Mar 19, 2021

Great to hear, thanks! Is there any ETA for this?

@kngwyu

kngwyu commented Apr 14, 2021

@jsspencer
Collaborator

KFAC is now integrated (and the default optimiser) in the JAX branch.

@connection-on-fiber-bundles

Hey @jsspencer, thanks a lot for open-sourcing the KFAC implementation. Great work!

However, when I train on Mg with 8 V100 GPUs (batch size 512), I get the following error:

terminate called after throwing an instance of 'std::runtime_error'
  what():  cuSolver execution failed
terminate called recursively
terminate called recursively
terminate called recursively
terminate called recursively
terminate called recursively
terminate called recursively
Fatal Python error: Aborted

Thread 0x00007f3da9bf2b80 (most recent call first):
  File "/usr/local/lib/python3.7/dist-packages/jax/interpreters/pxla.py", line 1204 in execute_replicated
  File "/usr/local/lib/python3.7/dist-packages/jax/interpreters/pxla.py", line 648 in xla_pmap_impl
  File "/usr/local/lib/python3.7/dist-packages/jax/core.py", line 631 in process_call
  File "/usr/local/lib/python3.7/dist-packages/jax/core.py", line 1305 in process
  File "/usr/local/lib/python3.7/dist-packages/jax/core.py", line 1266 in call_bind
  File "/usr/local/lib/python3.7/dist-packages/jax/core.py", line 1302 in bind
  File "/usr/local/lib/python3.7/dist-packages/jax/api.py", line 1574 in f_pmapped
  File "/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py", line 139 in reraise_with_filtered_traceback
  File "/home/tiger/.local/lib/python3.7/site-packages/kfac_ferminet_alpha/optimizer.py", line 567 in step
  File "/opt/tiger/ferminet_jax/ferminet/train.py", line 497 in train
  File "./bin/ferminet", line 35 in main
  File "/usr/local/lib/python3.7/dist-packages/absl/app.py", line 251 in _run_main
  File "/usr/local/lib/python3.7/dist-packages/absl/app.py", line 303 in run
  File "./bin/ferminet", line 39 in <module>
Aborted (core dumped)

Any clue?

BTW, I'm using jax 0.2.9 and jaxlib 0.1.59; not sure if that's related.

@connection-on-fiber-bundles

BTW, I can successfully train the network with KFAC on smaller atoms like O and F, but not on Na or Mg.

@jsspencer
Collaborator

Hard to know. My suspicion is that the batch size is too small, so the curvature estimates KFAC relies on are noisy. KFAC requires solving linear systems Ax = b, which is done via a Cholesky decomposition that assumes A is symmetric and positive-definite. The positive-definiteness requirement might not be met for noisy estimates.
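To make the positive-definiteness point concrete, here is a minimal NumPy sketch. It is not the kfac_ferminet_alpha code: the function names `sample_covariance` and `damped_cholesky_solve` are hypothetical, and the `damping` term (adding λI before factorizing, a common remedy in K-FAC-style optimizers) is an assumption that is not discussed in this thread. The sketch shows that a second-moment estimate built from n samples of a d-dimensional vector has rank at most n. When n < d the estimate is therefore singular rather than positive-definite, so a Cholesky factorization is not guaranteed to succeed. Adding λI with λ > 0 restores positive-definiteness.

```python
import numpy as np


def sample_covariance(samples: np.ndarray) -> np.ndarray:
    """Uncentred second-moment estimate A = (1/n) * sum_i g_i g_i^T.

    `samples` has shape (n, d): n samples of a d-dimensional vector.
    The result is symmetric positive semi-definite with rank <= min(n, d),
    so it is singular (not positive-definite) whenever n < d.
    """
    n = samples.shape[0]
    return samples.T @ samples / n


def damped_cholesky_solve(a: np.ndarray, b: np.ndarray, damping: float) -> np.ndarray:
    """Solve (A + damping * I) x = b via a Cholesky factorization (hypothetical helper).

    For symmetric PSD A and damping > 0, every eigenvalue of A + damping * I
    is at least `damping`, so the factorization exists.
    """
    d = a.shape[0]
    a_damped = a + damping * np.eye(d)
    # Lower-triangular L with L @ L.T == a_damped; raises LinAlgError if not PD.
    lower = np.linalg.cholesky(a_damped)
    # Two solves: L y = b, then L^T x = y. Generic solves are used for brevity;
    # a triangular solver would exploit the structure of L.
    y = np.linalg.solve(lower, b)
    return np.linalg.solve(lower.T, y)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    d = 8
    small_batch = rng.standard_normal((4, d))  # n = 4 < d = 8
    a_small = sample_covariance(small_batch)
    print("rank:", np.linalg.matrix_rank(a_small))  # at most 4, so A is singular
    b = rng.standard_normal(d)
    x = damped_cholesky_solve(a_small, b, damping=1e-3)
    print("residual ok:", np.allclose((a_small + 1e-3 * np.eye(d)) @ x, b))
```

The damping value trades accuracy of the curvature estimate for numerical stability. A larger λ makes the solve better conditioned but moves the update closer to plain (scaled) gradient descent.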

@connection-on-fiber-bundles

@jsspencer Got it, will give it a try, thanks!
