Training gives nan after the first iteration #5
Hello, when I try to run training with the command
It seems like the train loss becomes NaN immediately after the first iteration. Is this something you have encountered before?
It seems like I was able to fix the error by uncommenting this line in train. Maybe this is something you would want to consider updating in your repo?
I still don't know why it gave an error, though. It seems like this is something you have encountered before. Would you mind explaining what this fix does?
I ran into this issue before with JAX 0.4.13 and fixed it by downgrading to 0.4.11. It's very interesting that disabling remat fixes it; I never figured that out. The commented-out line is just a remnant from when I was experimenting with memory usage. You might want to try different JAX versions instead of disabling remat, since (at least in theory) remat can give significant speed and memory savings.
I see. Thanks for getting back to me! Will update that :) Very cool paper, by the way!
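The thread traces the NaN to remat (activation rematerialization) interacting with the JAX version, not to the model itself. A minimal sketch, assuming a toy model: remat is made a switch instead of a commented-out line, so gradients with and without it can be compared on the same inputs. The names mlp_block, make_loss, grads_for and the use_remat flag are hypothetical stand-ins, not the repo's code; only jax.checkpoint (also exported as jax.remat) is a real JAX API.

```python
import jax
import jax.numpy as jnp


def mlp_block(params, x):
    """Hypothetical two-layer block standing in for one layer of the real model."""
    w1, w2 = params
    return jnp.tanh(x @ w1) @ w2


def make_loss(use_remat: bool):
    # jax.checkpoint recomputes the block's activations during the backward
    # pass instead of storing them, trading extra compute for lower memory.
    block = jax.checkpoint(mlp_block) if use_remat else mlp_block

    def loss_fn(params, x, y):
        pred = block(params, x)
        return jnp.mean((pred - y) ** 2)

    return loss_fn


def grads_for(use_remat: bool, params, x, y):
    """Gradient of the toy loss with remat switched on or off."""
    return jax.grad(make_loss(use_remat))(params, x, y)


if __name__ == "__main__":
    key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
    params = (jax.random.normal(key1, (4, 8)), jax.random.normal(key2, (8, 1)))
    x = jax.random.normal(key3, (16, 4))
    y = jnp.zeros((16, 1))
    g_remat = grads_for(True, params, x, y)
    g_plain = grads_for(False, params, x, y)
    # The two gradients should agree; a NaN on only one side implicates remat.
    for a, b in zip(g_remat, g_plain):
        print(bool(jnp.all(jnp.isfinite(a))), bool(jnp.allclose(a, b, atol=1e-5)))
```

If only the remat path goes non-finite on a given JAX release, that matches the version-specific behavior described in the thread; pinning JAX (e.g. 0.4.11, as the maintainer did) keeps remat's memory savings instead of disabling it.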
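When the loss turns NaN right after the first iteration, it helps to stop at the exact step where a non-finite loss or gradient first appears. A minimal sketch, assuming a generic loss_fn(params, x, y) and plain SGD; train and check_finite are hypothetical helpers, not part of the repo. JAX's built-in jax.config.update("jax_debug_nans", True) is a heavier alternative that raises at the first operation producing a NaN, at a significant speed cost.

```python
import jax
import jax.numpy as jnp

# Heavier alternative (debug only, slow): make JAX itself raise
# FloatingPointError at the first operation that produces a NaN.
# jax.config.update("jax_debug_nans", True)


def check_finite(step, loss, grads):
    """Raise as soon as the loss or any gradient leaf is NaN/Inf."""
    leaves = jax.tree_util.tree_leaves(grads)
    grads_ok = all(bool(jnp.all(jnp.isfinite(g))) for g in leaves)
    if not (bool(jnp.isfinite(loss)) and grads_ok):
        raise FloatingPointError(
            f"non-finite value at step {step}: loss={float(loss)}, grads finite={grads_ok}"
        )


def train(loss_fn, params, batches, lr=1e-2):
    """Hypothetical SGD loop that checks every step before updating params."""
    value_and_grad = jax.jit(jax.value_and_grad(loss_fn))
    for step, (x, y) in enumerate(batches):
        loss, grads = value_and_grad(params, x, y)
        check_finite(step, loss, grads)
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params
```

Checking before the update means the reported step is the first one whose loss or gradients were bad, and the last finite params are still available for inspection.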