Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a module to apply updates every k steps (and accumulate them othe… #2350

Merged
merged 1 commit into from Mar 10, 2020

Conversation

perolat
Copy link
Contributor

@perolat perolat commented Mar 4, 2020

…rwise)

@mtthss
Copy link
Contributor

mtthss commented Mar 4, 2020

Looks great!

reset = state.count % k
emit = reset == (k - 1)
grad_acc = tree_multimap(
lambda g, ga: (reset == 0) * ga + g, updates, state.grad_acc)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name reset doesn't actually reflect what's in the variable and I think the guard is wrong (afaik reset == 0 should actually be reset != 0).

I'd suggest:

c = state.count % k
acc = c != 0
grad_acc = tree_multimap(lambda g, ga: acc * ga + g, updates, state.grad_acc)
emit = c == (k - 1)
updates = tree_multimap(lambda ga: emit * ga, grad_acc)

It would also be great if there was a test for this to avoid regressions..

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Tom for spotting that problem. I'd be happy to do some tests but where is optix tested?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@perolat
Copy link
Contributor Author

perolat commented Mar 5, 2020

Tests added!

The tolerance is quite high absolute tolerance is 1e-6 and the relative tolerance is 100.
I tried lower tolerance high absolute tolerance is 1e-10 and the relative tolerance is 1e-5 but the test would fail on TPU.

I am not sure this solution is good though. What is recommended to handle the lower numerical precision on TPU?

@mattjj
Copy link
Member

mattjj commented Mar 10, 2020

We can just skip the test on the TPU; look for the jtu.skip_on_devices helper.

@mattjj
Copy link
Member

mattjj commented Mar 10, 2020

I'll merge, then mark this test as skipped on TPU.

@mattjj mattjj merged commit 5c3b478 into google:master Mar 10, 2020
mattjj added a commit that referenced this pull request Mar 10, 2020
@mattjj
Copy link
Member

mattjj commented Mar 10, 2020

I confirmed internal tests pass after merging this (and after cc53aa9).

srvasude pushed a commit to srvasude/jax that referenced this pull request May 5, 2020
srvasude pushed a commit to srvasude/jax that referenced this pull request May 5, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants