
rmsprop implementation in optax and torch #532

Closed
vwxyzjn opened this issue Jun 1, 2023 · 9 comments · Fixed by #595

Comments

@vwxyzjn (Contributor) commented Jun 1, 2023

Hello, thanks for this helpful library.

I was wondering if optax's rmsprop implementation is equivalent to torch's rmsprop implementation.

Recently, I have been working on a distributed reinforcement learning library called Cleanba, which replicates IMPALA. I then compared Cleanba's IMPALA against torchbeast's IMPALA, considering the following four optimizer settings:

torch setting in torchbeast:

```python
optimizer = torch.optim.Adam(
    learner_model.parameters(),
    lr=0.00025,
    eps=1e-5,
)
# ...
nn.utils.clip_grad_norm_(learner_model.parameters(), 0.5)
```

optax setting in our cleanba:

```python
optax.chain(
    optax.clip_by_global_norm(0.5),
    optax.adam(
        learning_rate=0.00025,
        eps=1e-5,
    ),
)
```
torch setting in torchbeast:

```python
optimizer = torch.optim.RMSprop(
    learner_model.parameters(),
    lr=0.0006,
    momentum=0,
    eps=0.01,
    alpha=0.99,
)
# ...
nn.utils.clip_grad_norm_(learner_model.parameters(), 40)
```

optax setting in our cleanba:

```python
optax.chain(
    optax.clip_by_global_norm(40),
    optax.rmsprop(
        learning_rate=0.0006,
        eps=0.01,
        decay=0.99,  # decay corresponds to torch's `alpha`, right?
    ),
)
```
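On the decay/alpha question: in both libraries this coefficient drives an exponential moving average of squared gradients. A minimal pure-Python sketch of that accumulator (illustrative only, not either library's actual code):

```python
def update_second_moment(n, g, decay):
    # EMA of squared gradients; torch.optim.RMSprop calls this
    # coefficient `alpha`, optax.rmsprop calls it `decay`.
    return decay * n + (1 - decay) * g ** 2

# Accumulate over a few made-up gradient values.
n = 0.0
for g in [0.5, -0.25, 0.1]:
    n = update_second_moment(n, g, decay=0.99)
print(n)
```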

The results are presented in the figure below. When using the same Adam optimizer settings in torchbeast and cleanba, torchbeast's IMPALA and cleanba's IMPALA have similar performance. However, when using the same RMSprop optimizer settings, the performance differs significantly, with torchbeast's IMPALA obtaining a much higher median human normalized score. This leads me to wonder if optax has the same RMSprop implementation as torch... Was wondering if you would have any thoughts on the performance discrepancies. Thanks!


The experiments were run for three random seeds, and the individual learning curves suggest torchbeast's IMPALA seems to perform qualitatively better.

[figure: main_10CPU learning curves]

@Rupt (Contributor) commented Aug 13, 2023

@vwxyzjn both of your "cleanba" examples use optax.adam, although the second has an invalid decay= argument. Is this an error in your reproduction in this issue?

@Rupt (Contributor) commented Aug 13, 2023

> This leads me to wonder if optax has the same RMSprop implementation as torch... Was wondering if you would have any thoughts on the performance discrepancies. Thanks!

@vwxyzjn Yes, the implementations differ. They treat eps differently.

In the denominator, optax uses $\sqrt{v + \varepsilon}$ (see here and here), and torch uses $(\sqrt v + \varepsilon)$ (see here).
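The difference is easy to see numerically when the second-moment estimate is small. A quick illustrative check (the values below are made up, using the eps=0.01 from the RMSprop settings above, not taken from either library's internals):

```python
import math

# A small second-moment estimate v, with eps = 0.01 as in the settings above.
v, eps = 1e-8, 0.01

optax_denom = math.sqrt(v + eps)   # optax convention: sqrt(v + eps)
torch_denom = math.sqrt(v) + eps   # torch convention: sqrt(v) + eps

print(optax_denom, torch_denom)
```

For small v, optax's denominator is near sqrt(eps) = 0.1 while torch's is near eps = 0.01, so the resulting update magnitudes differ by roughly an order of magnitude.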

@vwxyzjn (Contributor, Author) commented Aug 13, 2023

Hi @Rupt, thanks for the reply. The snippet I shared had a typo; it should be optax.rmsprop, which is what was actually used in the experiment.


Also, thanks for checking the differences. So to create a torch-equivalent optimizer, I just need to do the following?

```diff
  def update_fn(updates, state, params=None):
    del params
    nu = update_moment_per_elem_norm(updates, state.nu, decay, 2)
    updates = jax.tree_util.tree_map(
-        lambda g, n: g * jax.lax.rsqrt(n + eps), updates, nu)
+        lambda g, n: g * jax.lax.rsqrt(n) + eps, updates, nu)
    return updates, ScaleByRmsState(nu=nu)
```

@Rupt (Contributor) commented Aug 13, 2023

Very welcome @vwxyzjn.

Not quite: the eps should still be in the denominator. I think you want this:

```diff
-        lambda g, n: g * jax.lax.rsqrt(n + eps), updates, nu)
+        lambda g, n: g / (jax.lax.sqrt(n) + eps), updates, nu)
```
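As a sanity check on the corrected form, here is a plain-Python comparison of the two update rules with made-up numbers (a sketch independent of optax and torch):

```python
import math

def optax_style(g, n, eps):
    # optax's scale_by_rms denominator: sqrt(n + eps)
    return g / math.sqrt(n + eps)

def torch_style(g, n, eps):
    # torch's RMSprop denominator: sqrt(n) + eps
    return g / (math.sqrt(n) + eps)

g, n, eps = 0.5, 0.04, 0.01
print(optax_style(g, n, eps), torch_style(g, n, eps))
```

With these values the torch-style update is 0.5 / (0.2 + 0.01), noticeably larger than the optax-style 0.5 / sqrt(0.05).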

@vwxyzjn (Contributor, Author) commented Aug 13, 2023

Thanks so much @Rupt I will give it a try :)

@mtthss (Collaborator) commented Aug 14, 2023

@vwxyzjn did this work?

@vwxyzjn (Contributor, Author) commented Aug 16, 2023

@mtthss Actually yeah, I used the optimizer @Rupt suggested and can now match the performance of monobeast with RMSprop (the purple curve now matches the blue curve). The PyTorch RMSprop setting seems quite good, giving a performance boost in almost all the games tested.

Would you like a PR to add a warning like #571?

https://github.com/vwxyzjn/cleanba/blob/d0f5edebe8539231855d657e57e46daf7c590bc7/cleanba/cleanba_impala_envpool_machado_atari_wrapper_rmsprop_pt.py#L135-L179

[figures: main_10CPU_sample_walltime_efficiency, main_10CPU]

@Rupt (Contributor) commented Aug 16, 2023

@vwxyzjn Glad it worked, thanks for sharing this nice reproduction.

> The PyTorch RMSprop setting seems quite good, giving a performance boost in almost all the games tested.

Note that the two implementations will prefer different eps values. You might get more similar results by using eps**2 in the optax version, because for small nu we get $1/\sqrt{0+\varepsilon^2} = 1/(\sqrt{0} + \varepsilon)$.
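That equivalence at nu = 0 can be checked directly (a quick illustrative check, not library code):

```python
import math

eps = 0.01

# With eps**2 inside the square root, the optax convention matches the
# torch convention exactly when the second moment is zero: both are 1/eps.
inside = 1 / math.sqrt(0 + eps ** 2)   # optax form with eps**2
outside = 1 / (math.sqrt(0) + eps)     # torch form with eps

print(inside, outside)
```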

@mtthss (Collaborator) commented Oct 10, 2023

> Would you like a PR to add a warning like #571?

@vwxyzjn that would be great thanks!
