optax and tensorflow's Adam optimizer's setting #571
Comments
Thanks for raising this! Rather than including an alternative adam version in optax, I would instead suggest we add a warning in the doc-string of the adam and scale_by_adam gradient transformations. Once they are aware, it's easy for people who need to reproduce old results to fork and modify the transformation. E.g. we could add:
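A sketch of what such a docstring note might look like (the wording below is my own, not an agreed-upon text from the thread; `ADAM_TF_NOTE` is a hypothetical name):

```python
# Hypothetical wording for a note appended to the `adam` / `scale_by_adam`
# docstrings; not actual optax source, just a sketch of the suggested warning.
ADAM_TF_NOTE = """
WARNING: this implementation follows Algorithm 1 of Kingma and Ba's Adam
paper, as does PyTorch. TensorFlow instead uses the formulation just before
Section 2.1 of that paper, where the bias correction is folded into the step
size and ``epsilon`` plays the role of ``epsilon_hat``. Results obtained with
TensorFlow's Adam can therefore differ from optax/pytorch for the same
hyperparameters.
"""
```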
Can you put together a short PR with this fix?
Thanks Costa @vwxyzjn for the observation regarding optax/pytorch vs. tensorflow Adam and for pointing out this issue to me.
This is a keen observation, and we can indeed confirm it analytically. Following the notation of Kingma and Ba's Adam paper, let's compare the update equations implemented by optax/pytorch and tensorflow.

optax/pytorch (Algorithm 1 of the paper):

$$\hat m_t = \frac{m_t}{1-\beta_1^t},\qquad \hat v_t = \frac{v_t}{1-\beta_2^t},\qquad \theta_t = \theta_{t-1} - \alpha\,\frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon}$$

tensorflow (the formulation just before Section 2.1 of the paper):

$$\alpha_t = \alpha\,\frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t},\qquad \theta_t = \theta_{t-1} - \alpha_t\,\frac{m_t}{\sqrt{v_t} + \hat\epsilon}$$

The equations above highlight that the distinction between the optax and tensorflow implementations is how the epsilon term is normalized. So, what if we set $\hat\epsilon = \epsilon\sqrt{1-\beta_2^t}$? Substituting $m_t$ and $v_t$ into the optax update shows that the two rules then coincide exactly. [Figure: learning curves comparing the two implementations, which match once $\hat\epsilon$ is set this way.]

Tianlin
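The analytical claim is easy to check numerically. Below is a small sketch of my own (not code from the thread): two scalar update rules written in the optax/pytorch style and the tensorflow style, which agree once `eps_hat = eps * sqrt(1 - b2**t)`:

```python
import math

def optax_style_update(m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """Algorithm 1 of the Adam paper: bias-correct moments, then step."""
    m_hat = m / (1 - b1**t)
    v_hat = v / (1 - b2**t)
    return lr * m_hat / (math.sqrt(v_hat) + eps)

def tf_style_update(m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps_hat=1e-8):
    """Pre-Section-2.1 formulation: bias correction folded into the step size."""
    lr_t = lr * math.sqrt(1 - b2**t) / (1 - b1**t)
    return lr_t * m / (math.sqrt(v) + eps_hat)

# With eps_hat = eps * sqrt(1 - b2**t), the two updates coincide.
m, v, t, eps, b2 = 0.5, 0.25, 10, 1e-8, 0.999
assert math.isclose(
    optax_style_update(m, v, t, eps=eps),
    tf_style_update(m, v, t, eps_hat=eps * math.sqrt(1 - b2**t)),
    rel_tol=1e-9,
)
```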
Currently, `optax.scale_by_adam` should be equivalent to `torch.optim.Adam`. However, Tensorflow has a different implementation. In short, if we change https://github.com/deepmind/optax/blob/cebdeff4a1922113a96c520e7a81b5bf79825b77/optax/_src/transform.py#L345-L348 to the following, then the adam optimizer would be the same as tensorflow's implementation.
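For reference, a self-contained sketch of a TF-style Adam step (my own illustration, not the actual optax change proposed above or real optax/tensorflow source): the bias correction is folded into the step size and `eps_hat` is added outside the square root.

```python
import math

def tf_style_adam_step(param, grad, m, v, t, lr=1e-3,
                       b1=0.9, b2=0.999, eps_hat=1e-8):
    """One scalar Adam step in the tensorflow style (illustrative sketch).

    m, v are the UNcorrected moment estimates; t is the 1-based step count.
    """
    m = b1 * m + (1 - b1) * grad          # first-moment EMA
    v = b2 * v + (1 - b2) * grad**2       # second-moment EMA
    lr_t = lr * math.sqrt(1 - b2**t) / (1 - b1**t)  # folded bias correction
    param = param - lr_t * m / (math.sqrt(v) + eps_hat)
    return param, m, v
```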
More context
Basically, PyTorch and optax's adam follow Algorithm 1 of Kingma and Ba's Adam paper (arxiv/1412.6980), but TensorFlow uses the formulation just before Section 2.1 of the paper, and the epsilon it refers to is epsilon hat in the paper.
This was a relevant issue in my recent reproduction of openai's work in https://github.com/openai/lm-human-preferences. Long story short, below is an end-to-end experiment with torch's adam (`adam_pt`) and tensorflow-style adam (`adam_tf`). While the final performance (`objective/scores`) looks the same, the learning curves differ in a non-trivial way. E.g., the torch adam version had a much higher `clipfrac` initially, causing a more significant initial update.

The "initial aggressive update" issue gets aggravated in larger models (e.g., gpt2-large). You can see that `objective/kl` had a spike with `adam_tf`, so this could be a reproducibility issue.

Desired solution
Include a tensorflow-style adam variant in optax (e.g., something like `adam_tf`; obviously this is bad naming, but I figure you'd have much better ideas :)