Align names and defaults with Adam paper, add reference.
PiperOrigin-RevId: 277942532
Change-Id: Ie660aae26b25a74ff55f06bf7bfb12790dcde80a
tomhennigan authored and sonnet-copybara committed Nov 1, 2019
1 parent c787bec commit e3ad61e
Showing 2 changed files with 51 additions and 34 deletions.
9 changes: 9 additions & 0 deletions docs/references.bib
@@ -93,3 +93,12 @@ @article{fortunato2017bayesian
year={2017},
url={https://arxiv.org/abs/1704.02798}
}

@misc{kingma2014adam,
title={Adam: A Method for Stochastic Optimization},
author={Diederik P. Kingma and Jimmy Ba},
year={2014},
eprint={1412.6980},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
76 changes: 42 additions & 34 deletions sonnet/src/optimizers/adam.py
@@ -29,49 +29,52 @@
from typing import Optional, Sequence, Text, Union


def adam_update(update, learning_rate, beta1, beta2, epsilon, step, m, v):
"""Computes the 'ADAM' update for a single parameter."""
m = beta1 * m + (1. - beta1) * update
v = beta2 * v + (1. - beta2) * tf.square(update)
debiased_m = m / (1. - tf.pow(beta1, step))
debiased_v = v / (1. - tf.pow(beta2, step))
update = learning_rate * debiased_m / (tf.sqrt(debiased_v) + epsilon)
def adam_update(g, alpha, beta_1, beta_2, epsilon, t, m, v):
"""Implements 'Algorithm 1' from :cite:`kingma2014adam`."""
m = beta_1 * m + (1. - beta_1) * g # Biased first moment estimate.
v = beta_2 * v + (1. - beta_2) * g * g # Biased second raw moment estimate.
m_hat = m / (1. - tf.pow(beta_1, t)) # Bias corrected 1st moment estimate.
v_hat = v / (1. - tf.pow(beta_2, t)) # Bias corrected 2nd moment estimate.
update = alpha * m_hat / (tf.sqrt(v_hat) + epsilon)
return update, m, v
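
To make the new `adam_update` concrete, here is a minimal sketch of a first step on a toy gradient with zero-initialised moments; the gradient values and hyperparameters below are illustrative, not taken from the diff.

import tensorflow as tf

# Toy first step (t = 1): one gradient vector, fresh zero moments.
g = tf.constant([0.1, -0.2])
m = tf.zeros_like(g)
v = tf.zeros_like(g)

update, m, v = adam_update(
    g=g, alpha=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8,
    t=tf.constant(1.0), m=m, v=v)

# On the first step the bias correction cancels the (1 - beta) factors,
# so the update is approximately alpha * sign(g), i.e. ~[0.001, -0.001].
print(update)

The parameter itself is then updated as `param.assign_sub(update)`, which is what `Adam.apply` does further down.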


class Adam(base.Optimizer):
"""Adaptive Moment Estimation (Adam) module.
"""Adaptive Moment Estimation (Adam) optimizer.
https://arxiv.org/abs/1412.6980
Adam is an algorithm for first-order gradient-based optimization of stochastic
objective functions, based on adaptive estimates of lower-order moments. See
:cite:`kingma2014adam` for more details.
Note: default parameter values have been taken from the paper.
Attributes:
learning_rate: Learning rate.
beta1: Beta1.
beta2: Beta2.
learning_rate: Step size (``alpha`` in the paper).
beta1: Exponential decay rate for first moment estimate.
beta2: Exponential decay rate for second moment estimate.
epsilon: Small value to avoid zero denominator.
step: Step count.
m: Accumulated m for each parameter.
v: Accumulated v for each parameter.
m: Biased first moment estimate (a list with one value per parameter).
v: Biased second raw moment estimate (a list with one value per parameter).
"""

def __init__(
self,
# TODO(petebu): Consider a default learning rate.
learning_rate: Union[types.FloatLike, tf.Variable],
learning_rate: Union[types.FloatLike, tf.Variable] = 0.001,
beta1: Union[types.FloatLike, tf.Variable] = 0.9,
beta2: Union[types.FloatLike, tf.Variable] = 0.999,
epsilon: Union[types.FloatLike, tf.Variable] = 1e-8,
name: Optional[Text] = None):
"""Constructs an `Adam` module.
Args:
learning_rate: Learning rate.
beta1: Beta1.
beta2: Beta2.
learning_rate: Step size (``alpha`` in the paper).
beta1: Exponential decay rate for first moment estimate.
beta2: Exponential decay rate for second moment estimate.
epsilon: Small value to avoid zero denominator.
name: Name of the module.
"""
super(Adam, self).__init__(name)
super(Adam, self).__init__(name=name)
self.learning_rate = learning_rate
self.beta1 = beta1
self.beta2 = beta2
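
With the new default, the optimizer can now be constructed with no arguments; a brief sketch, assuming `snt.optimizers.Adam` is the public alias for this class.

import sonnet as snt

opt = snt.optimizers.Adam()  # Paper defaults: learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8.
opt_small = snt.optimizers.Adam(learning_rate=3e-4)  # Individual defaults can still be overridden.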
@@ -83,6 +86,7 @@ def __init__(

@once.once
def _initialize(self, parameters: Sequence[tf.Variable]):
"""First and second order moments are initialized to zero."""
zero_var = lambda p: utils.variable_like(p, trainable=False)
with tf.name_scope("m"):
self.m.extend(zero_var(p) for p in parameters)
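
The zero initialisation above is roughly equivalent to the sketch below, assuming `utils.variable_like` creates a non-trainable variable with the parameter's shape and dtype; `init_moments` is a hypothetical helper used only for illustration.

import tensorflow as tf

def init_moments(parameters):
  # One non-trainable, zero-filled slot per parameter for m and for v.
  m = [tf.Variable(tf.zeros_like(p), trainable=False) for p in parameters]
  v = [tf.Variable(tf.zeros_like(p), trainable=False) for p in parameters]
  return m, v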
@@ -95,18 +99,20 @@ def apply(self, updates: Sequence[types.ParameterUpdate],
Applies the Adam update rule for each update, parameter pair:
m_t <- beta1 * m_{t-1} + (1 - beta1) * update
v_t <- beta2 * v_{t-1} + (1 - beta2) * update * update
\hat{m}_t <- m_t / (1 - beta1^t)
\hat{v}_t <- v_t / (1 - beta2^t)
scaled_update <- \hat{m}_t / (sqrt(\hat{v}_t) + epsilon)
.. math::
parameter <- parameter - learning_rate * scaled_update
\begin{array}{ll}
m_t = \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot update \\
v_t = \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot update^2 \\
\hat{m}_t = m_t / (1 - \beta_1^t) \\
\hat{v}_t = v_t / (1 - \beta_2^t) \\
delta = \alpha \cdot \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon) \\
param_t = param_{t-1} - delta \\
\end{array}
Args:
updates: A list of updates to apply to parameters. Updates are often
gradients as returned by `tf.GradientTape.gradient`.
gradients as returned by :tf:`GradientTape.gradient`.
parameters: A list of parameters.
Raises:
@@ -123,8 +129,8 @@ def apply(self, updates: Sequence[types.ParameterUpdate],

optimizer_utils.check_same_dtype(update, param)
learning_rate = tf.cast(self.learning_rate, update.dtype)
beta1 = tf.cast(self.beta1, update.dtype)
beta2 = tf.cast(self.beta2, update.dtype)
beta_1 = tf.cast(self.beta1, update.dtype)
beta_2 = tf.cast(self.beta2, update.dtype)
epsilon = tf.cast(self.epsilon, update.dtype)
step = tf.cast(self.step, update.dtype)

@@ -135,16 +141,18 @@ def apply(self, updates: Sequence[types.ParameterUpdate],
v = v_var.sparse_read(indices)

# Compute and apply a sparse update to our parameter and state.
update, m, v = adam_update(update, learning_rate, beta1, beta2, epsilon,
step, m, v)
update, m, v = adam_update(
g=update, alpha=learning_rate, beta_1=beta_1, beta_2=beta_2,
epsilon=epsilon, t=step, m=m, v=v)
param.scatter_sub(tf.IndexedSlices(update, indices))
m_var.scatter_update(tf.IndexedSlices(m, indices))
v_var.scatter_update(tf.IndexedSlices(v, indices))

else:
# Compute and apply a dense update to our parameter and state.
update, m, v = adam_update(update, learning_rate, beta1, beta2, epsilon,
step, m_var, v_var)
update, m, v = adam_update(
g=update, alpha=learning_rate, beta_1=beta_1, beta_2=beta_2,
epsilon=epsilon, t=step, m=m_var, v=v_var)
param.assign_sub(update)
m_var.assign(m)
v_var.assign(v)
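
End to end, a hedged sketch of a single dense update step with the new defaults; the variables and loss below are made up for illustration, and `snt.optimizers.Adam` is assumed to be the public alias of this class.

import sonnet as snt
import tensorflow as tf

w = tf.Variable(tf.random.normal([3, 2]))  # Illustrative parameters.
b = tf.Variable(tf.zeros([2]))

opt = snt.optimizers.Adam()  # Uses the new default learning_rate=0.001.

with tf.GradientTape() as tape:
  x = tf.ones([1, 3])
  loss = tf.reduce_sum(tf.square(tf.matmul(x, w) + b))

updates = tape.gradient(loss, [w, b])
opt.apply(updates, [w, b])  # Dense branch: assign_sub on each parameter; IndexedSlices take the scatter path.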
