
optax and tensorflow's Adam optimizer's setting. #571

Closed
vwxyzjn opened this issue Aug 13, 2023 · 3 comments · Fixed by #572

Comments

@vwxyzjn (Contributor) commented Aug 13, 2023

Currently, optax.scale_by_adam should be equivalent to torch.optim.Adam. However, TensorFlow has a different implementation.

In short, if we change https://github.com/deepmind/optax/blob/cebdeff4a1922113a96c520e7a81b5bf79825b77/optax/_src/transform.py#L345-L348 to the following, the Adam optimizer becomes the same as TensorFlow's implementation.

updates = jax.tree_util.tree_map(
    lambda m, v: (jnp.sqrt(1 - b2**count_inc) / (1 - b1**count_inc))
    * m / (jnp.sqrt(v + eps_root) + eps),
    mu, nu)

More context

Basically, PyTorch and optax's adam follow Algorithm 1 of Kingma and Ba's Adam paper (arXiv:1412.6980), but TensorFlow uses the formulation given just before Section 2.1 of the paper, so the epsilon in TensorFlow's API is the epsilon hat of the paper.
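
Concretely, in the paper's notation (see the full derivation in the comments below):

$$\text{Algorithm 1 (optax/pytorch):}\quad \theta_t = \theta_{t-1} - \alpha\, \hat{m}_t / (\sqrt{\hat{v}_t} + \varepsilon)$$

$$\text{Before Section 2.1 (tensorflow):}\quad \theta_t = \theta_{t-1} - \alpha_t\, m_t / (\sqrt{v_t} + \hat{\varepsilon}), \quad \alpha_t = \alpha \sqrt{1-\beta_2^t} / (1-\beta_1^t)$$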

This was a relevant issue in my recent reproduction of OpenAI's work in https://github.com/openai/lm-human-preferences. Long story short, below is an end-to-end experiment with torch-style adam (adam_pt) and tensorflow-style adam (adam_tf). While the final performance (objective/scores) looks the same, the learning curves differ in a non-trivial way. E.g., the torch adam version had a much higher clipfrac initially, causing a more significant initial update.

[Figure: learning curves from the end-to-end experiment, comparing adam_pt and adam_tf]

The "initial aggressive update" issue gets aggravated in larger models (e.g., gpt2-large). You can see that objective/kl had a spike with adam_tf, so this could be a reproducibility issue.

[Figures: gpt2-large learning curves, showing a spike in objective/kl with adam_tf]

Desired solution

Include a TF-style Adam variant, for example:

import jax
import jax.numpy as jnp
from optax import ScaleByAdamState
from optax._src import base, combine, numerics, utils
from optax._src.alias import _scale_by_learning_rate
from optax._src.transform import update_moment, update_moment_per_elem_norm


def scale_by_adam_tf_style(
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype=None,
) -> base.GradientTransformation:
  """Rescale updates according to the Adam algorithm.
  References:
    [Kingma et al, 2014](https://arxiv.org/abs/1412.6980)
  WARNING: This is a TensorFlow-style Adam optimizer that uses the
    formulation just before Section 2.1 of the Kingma and Ba paper
    rather than the formulation in Algorithm 1; the "epsilon" referred
    to here is "epsilon hat" in the paper.
  Args:
    b1: Decay rate for the exponentially weighted average of grads.
    b2: Decay rate for the exponentially weighted average of squared grads.
    eps: Term added to the denominator to improve numerical stability
      (this is epsilon hat in the paper).
    eps_root: Term added to the denominator inside the square-root to improve
      numerical stability when backpropagating gradients through the rescaling.
    mu_dtype: Optional `dtype` to be used for the first order accumulator; if
      `None` then the `dtype` is inferred from `params` and `updates`.
  Returns:
    A `GradientTransformation` object.
  """

  mu_dtype = utils.canonicalize_dtype(mu_dtype)

  def init_fn(params):
    mu = jax.tree_util.tree_map(  # First moment
        lambda t: jnp.zeros_like(t, dtype=mu_dtype), params)
    nu = jax.tree_util.tree_map(jnp.zeros_like, params)  # Second moment
    return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)

  def update_fn(updates, state, params=None):
    del params
    mu = update_moment(updates, state.mu, b1, 1)
    nu = update_moment_per_elem_norm(updates, state.nu, b2, 2)
    count_inc = numerics.safe_int32_increment(state.count)

    ### `optax` default adam implementation
    # mu_hat = bias_correction(mu, b1, count_inc)
    # nu_hat = bias_correction(nu, b2, count_inc)
    # updates = jax.tree_util.tree_map(
    #     lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat)
    ### Tensorflow adam implementation
    # Fold the bias correction into a per-step scale factor, so that `eps`
    # sits outside the uncorrected square root (i.e. plays the role of
    # epsilon hat in the paper).
    updates = jax.tree_util.tree_map(
        lambda m, v: (jnp.sqrt(1 - b2**count_inc) / (1 - b1**count_inc))
        * m / (jnp.sqrt(v + eps_root) + eps),
        mu, nu)
    mu = utils.cast_tree(mu, mu_dtype)
    return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu)

  return base.GradientTransformation(init_fn, update_fn)


def adam_tf_style(
    learning_rate,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype=None,
):
  return combine.chain(
      scale_by_adam_tf_style(
          b1=b1, b2=b2, eps=eps, eps_root=eps_root, mu_dtype=mu_dtype),
      _scale_by_learning_rate(learning_rate),
  )
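
For reference, a minimal usage sketch of the transformation above (the params and grads pytrees here are hypothetical placeholders, just for illustration):

import jax
import jax.numpy as jnp

params = {"w": jnp.ones((3,)), "b": jnp.zeros((3,))}           # hypothetical parameters
grads = {"w": jnp.full((3,), 0.1), "b": jnp.full((3,), -0.2)}  # hypothetical gradients

optimizer = adam_tf_style(learning_rate=1e-4)
opt_state = optimizer.init(params)
updates, opt_state = optimizer.update(grads, opt_state)
# `_scale_by_learning_rate` already negates the updates, so they are added.
new_params = jax.tree_util.tree_map(lambda p, u: p + u, params, updates)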

Obviously this is bad naming, but I figure you'd have much better ideas :)

@mtthss (Collaborator) commented Aug 14, 2023

Thanks for raising this!

Rather than including an alternative Adam version in optax, I would instead suggest we add a warning to the doc-strings of the adam and scale_by_adam gradient transformations. Once people who need to reproduce old results are aware of the difference, it's easy for them to fork and modify the transformation.

E.g. we could add:

WARNING: PyTorch and optax's adam follow Algorithm 1 of Kingma and Ba's Adam paper; if reproducing old results, note that TensorFlow instead used the formulation just before Section 2.1 of the paper.

Can you put together a short PR with this fix?

@vwxyzjn (Contributor, Author) commented Aug 14, 2023

@mtthss with pleasure. Thanks for the suggestion and see #572.

P.S. big fan of your work and it made my day seeing your response :D

@liutianlin0121 commented Sep 5, 2023

Hi @vwxyzjn @mtthss ,

Thanks Costa @vwxyzjn for the observation regarding optax/pytorch vs. tensorflow Adam and for pointing out this issue to me.

> While the final performance (objective/scores) looks the same, the learning curves differ in a non-trivial way. E.g., the torch adam version had a much higher clipfrac initially, causing a more significant initial update.

This is a keen observation, and we can indeed confirm it analytically. Following the notation of Kingma and Ba’s Adam paper, let's compare the update equations implemented by optax/pytorch and tensorflow.

  • Gradient update rules for optax/pytorch adam (Algorithm 1 of Kingma and Ba’s paper):
$$\begin{aligned} \theta_t & =\theta_{t-1}-\alpha \cdot \hat{m}_t /\left(\sqrt{\hat{v}_t}+\varepsilon\right) \\ & =\theta_{t-1}-\alpha \underbrace{\left[m_t /\left(1-\beta_1^t\right)\right]}_{=\hat{m}_t} /\left[\sqrt{\underbrace{v_t /\left(1-\beta_2^t\right)}_{=\hat{v}_t}}+\varepsilon\right] \\ & =\theta_{t-1}-\alpha\left[m_t /\left(1-\beta_1^t\right)\right] \frac{\sqrt{1-\beta_2^t}}{\sqrt{v_t}+\color{blue}{\varepsilon \sqrt{1-\beta_2^t}}} \end{aligned}$$
  • Gradient update rules for tensorflow adam (the formulation just before Section 2.1 of Kingma and Ba’s paper):
$$\begin{aligned} \theta_t & =\theta_{t-1}-\alpha_t m_t /\left(\sqrt{v_t}+\hat{\varepsilon}\right) \\ & =\theta_{t-1}-\underbrace{\left[\alpha \sqrt{1-\beta_2^t} /\left(1-\beta_1^t\right)\right]}_{=\alpha_t} m_t /\left(\sqrt{v_t}+\hat{\varepsilon}\right) \\ & =\theta_{t-1}-\alpha\left[m_t /\left(1-\beta_1^t\right)\right] \frac{\sqrt{1-\beta_2^t}}{\sqrt{v_t}+\color{blue}{\hat{\varepsilon}}} \end{aligned}$$

The equations above highlight that the distinction between the optax and tensorflow implementations lies in their normalization terms: $\color{blue}{\varepsilon \sqrt{1-\beta_2^t}}$ versus $\color{blue}{\hat{\varepsilon}}$. The two versions are equivalent if we set $\hat{\varepsilon} = \varepsilon \sqrt{1-\beta_2^t}$. However, in the APIs we can only set $\varepsilon$ (optax and pytorch) or $\hat{\varepsilon}$ (tensorflow) via the eps argument, causing differences in their update equations.
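
As a quick sanity check, this equivalence can also be verified numerically; below is a minimal sketch with arbitrary example values for $m_t$, $v_t$, and $t$:

import jax.numpy as jnp

b1, b2, eps, t = 0.9, 0.999, 1e-5, 10
m = jnp.array([0.3, -0.7])   # example first moment
v = jnp.array([0.02, 0.05])  # example second moment

# optax/pytorch style: bias-correct both moments, add eps outside the sqrt.
optax_update = (m / (1 - b1**t)) / (jnp.sqrt(v / (1 - b2**t)) + eps)

# tensorflow style, with eps_hat = eps * sqrt(1 - b2**t): matches exactly.
eps_hat = eps * jnp.sqrt(1 - b2**t)
tf_update = (jnp.sqrt(1 - b2**t) / (1 - b1**t)) * m / (jnp.sqrt(v) + eps_hat)

assert jnp.allclose(optax_update, tf_update)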

So, what if we set $\varepsilon$ and $\hat{\varepsilon}$ to the same value, say, 1e-5? Then for tensorflow adam, the normalization term $\hat{\varepsilon} = \text{1e-5}$ is just a constant. But for optax/pytorch adam, the normalization term $\varepsilon \sqrt{1-\beta_2^t}$ changes over time: it is much smaller than 1e-5 when the timestep $t$ is small, and gradually approaches 1e-5 as timesteps increase. The plot below compares these two normalization terms over timesteps:

[Figure: the optax/pytorch normalization term $\varepsilon \sqrt{1-\beta_2^t}$ starts near zero and rises toward the constant tensorflow term $\hat{\varepsilon} = \text{1e-5}$ as $t$ grows]
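
A minimal sketch to reproduce this comparison (assuming $\beta_2 = 0.999$ and eps = 1e-5 as above, and matplotlib for plotting):

import jax.numpy as jnp
import matplotlib.pyplot as plt

b2, eps = 0.999, 1e-5
t = jnp.arange(1, 10_000)

# optax/pytorch: effective normalization term grows with t.
plt.plot(t, eps * jnp.sqrt(1 - b2**t), label="optax/pytorch: eps * sqrt(1 - b2**t)")
# tensorflow: eps_hat is constant over timesteps.
plt.plot(t, jnp.full(t.shape, eps), label="tensorflow: eps_hat")
plt.xlabel("timestep t")
plt.ylabel("normalization term")
plt.legend()
plt.show()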

The figure above shows that if we set the same eps in optax/pytorch adam and tensorflow adam, then optax/pytorch adam uses a much smaller normalization term than tensorflow adam in the early phase of training. In other words, optax/pytorch adam makes more aggressive gradient updates early in training. This aligns nicely with Costa's empirical results shared earlier!

Tianlin
