Add a module to apply updates every k steps (and accumulate them othe… #2350
Looks great!
jax/experimental/optix.py
reset = state.count % k
emit = reset == (k - 1)
grad_acc = tree_multimap(
    lambda g, ga: (reset == 0) * ga + g, updates, state.grad_acc)
The name reset doesn't actually reflect what's in the variable, and I think the guard is wrong (afaik reset == 0 should actually be reset != 0).
I'd suggest:
c = state.count % k
acc = c != 0
grad_acc = tree_multimap(lambda g, ga: acc * ga + g, updates, state.grad_acc)
emit = c == (k - 1)
updates = tree_multimap(lambda ga: emit * ga, grad_acc)
It would also be great if there was a test for this to avoid regressions..
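To make the difference between the two guards concrete, here is a minimal plain-Python sketch. It assumes scalar gradients instead of pytrees, and the helper name `accumulate` and its `keep_acc` parameter are hypothetical; it mirrors the logic of the snippets above but is not the optix implementation.

```python
# Sketch only: floats stand in for pytrees, a loop stands in for the optimizer state.
def accumulate(grads, k, keep_acc):
    """Return the update emitted at each step.

    keep_acc(c) decides whether the previous accumulator is kept
    when the step counter modulo k equals c.
    """
    count, acc, emitted = 0, 0.0, []
    for g in grads:
        c = count % k
        acc = keep_acc(c) * acc + g   # the bool acts as a 0/1 mask
        emit = c == (k - 1)
        emitted.append(emit * acc)    # zero update unless this is the k-th step
        count += 1
    return emitted

grads = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
fixed = accumulate(grads, k=3, keep_acc=lambda c: c != 0)  # suggested guard
buggy = accumulate(grads, k=3, keep_acc=lambda c: c == 0)  # guard under review
```

With the suggested guard, every third step emits the sum of the last three gradients. With the guard under review, the accumulator is discarded on every step except the first of each window, so the emitted value is only the most recent gradient.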
Thanks Tom for spotting that problem. I'd be happy to do some tests but where is optix tested?
6321a78 to 3b34ab3
Tests added! The tolerance is quite high: the absolute tolerance is 1e-6 and the relative tolerance is 100. I am not sure this solution is good, though. What is recommended to handle the lower numerical precision on TPU?
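For a sense of how permissive a relative tolerance of 100 is, here is a small sketch. It assumes the comparison follows the inequality documented for numpy.testing.assert_allclose, |actual - desired| <= atol + rtol * |desired|; the helper `is_close` is hypothetical, and the check actually used in the test may differ.

```python
def is_close(actual, desired, rtol, atol):
    # Assumed rule: |actual - desired| <= atol + rtol * |desired|
    return abs(actual - desired) <= atol + rtol * abs(desired)

# A result that is 50x too large still passes with rtol=100 ...
loose = is_close(50.0, 1.0, rtol=100.0, atol=1e-6)
# ... but fails with a conventional relative tolerance.
tight = is_close(50.0, 1.0, rtol=1e-3, atol=1e-6)
```

Under that assumption, rtol=100 accepts almost any value of roughly the right order of magnitude, so the test would mostly guard against zero or wildly wrong outputs.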
We can just skip the test on the TPU; look for the
I'll merge, then mark this test as skipped on TPU.
I confirmed internal tests pass after merging this (and after cc53aa9).
…rwise)