rmsprop implementation in optax and torch #532
Comments
@vwxyzjn both of your "cleanba" examples use
@vwxyzjn Yes, the implementations differ in how they treat eps. In the denominator, optax uses sqrt(nu + eps), while torch uses sqrt(nu) + eps.
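The effect of this eps placement can be checked with a tiny plain-Python sketch. The values below are illustrative assumptions, not numbers from the thread; they show that the two denominators diverge sharply when the second-moment estimate is small:

```python
import math

# Illustrative values: a tiny second-moment estimate, as happens
# early in training or with sparse gradients.
eps = 1e-8
n = 1e-16   # second-moment estimate (nu)
g = 1.0     # gradient value

# eps inside the square root (optax-style scale_by_rms)
optax_step = g / math.sqrt(n + eps)   # ≈ 1e4

# eps outside the square root (torch-style RMSprop)
torch_step = g / (math.sqrt(n) + eps)  # ≈ 5e7

# When n >> eps the two agree; when n << eps they differ by orders
# of magnitude, so the same learning rate behaves very differently.
```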
Hi @Rupt thanks for the reply. The snippet I shared had a typo. Also thanks for checking the differences. So to create a torch-style rmsprop, should the update be changed like this?

```diff
def update_fn(updates, state, params=None):
  del params
  nu = update_moment_per_elem_norm(updates, state.nu, decay, 2)
  updates = jax.tree_util.tree_map(
-     lambda g, n: g * jax.lax.rsqrt(n + eps), updates, nu)
+     lambda g, n: g * jax.lax.rsqrt(n) + eps, updates, nu)
  return updates, ScaleByRmsState(nu=nu)
```
Very welcome @vwxyzjn. Not quite, the eps should still be in the denominator. I think you want this:

```diff
-     lambda g, n: g * jax.lax.rsqrt(n + eps), updates, nu)
+     lambda g, n: g / (jax.lax.sqrt(n) + eps), updates, nu)
```
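To see why this correction matters in practice, here is a minimal scalar sketch of both variants in plain Python. The learning rate, decay, and toy gradient stream are illustrative assumptions; the point is only that the two eps conventions trace visibly different parameter trajectories when gradients are tiny:

```python
import math

def rmsprop_path(eps_in_sqrt, grads, lr=0.01, decay=0.9, eps=1e-8):
    """Run a scalar rmsprop-style update over a stream of gradients.

    eps_in_sqrt=True mimics the optax-style denominator sqrt(nu + eps);
    eps_in_sqrt=False mimics the torch-style denominator sqrt(nu) + eps.
    """
    param, nu = 1.0, 0.0
    for g in grads:
        nu = decay * nu + (1 - decay) * g * g  # second-moment estimate
        if eps_in_sqrt:
            denom = math.sqrt(nu + eps)        # optax-style
        else:
            denom = math.sqrt(nu) + eps        # torch-style
        param -= lr * g / denom
    return param

# Tiny gradients make the eps placement matter most.
grads = [1e-6] * 10
optax_like = rmsprop_path(True, grads)   # barely moves from 1.0
torch_like = rmsprop_path(False, grads)  # moves much further
```

With larger gradients (nu much bigger than eps) the two variants nearly coincide, which is consistent with the Adam runs matching while the RMSprop runs diverge.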
Thanks so much @Rupt I will give it a try :)
@vwxyzjn did this work?
@mtthss Actually yeah, I used the optimizer @Rupt suggested and can now match the performance. Would you like a PR to add a warning like #571?
@vwxyzjn Glad it worked, thanks for sharing this nice reproduction. Note that the two implementations will prefer different
Hello, thanks for this helpful library.
I was wondering if optax's rmsprop implementation is equivalent to torch's rmsprop implementation. Recently, I have been working on a distributed reinforcement learning library called Cleanba which replicates IMPALA. I then compared Cleanba's IMPALA against torchbeast's IMPALA. I considered the following four optimizer settings:
The results are presented in the figure below. When using the same Adam optimizer settings, torchbeast's IMPALA and Cleanba's IMPALA have similar performance. However, when using the same RMSprop optimizer settings, the performance differs significantly, with torchbeast's IMPALA obtaining a much higher median human-normalized score. This leads me to wonder whether optax has the same RMSprop implementation as torch. I was wondering if you have any thoughts on the performance discrepancy. Thanks!
The experiments were run for three random seeds, and the individual learning curves suggest that torchbeast's IMPALA performs qualitatively better.