
Implements LAMBOptimizer (tensorflow#491)
* Implements LAMBOptimizer
junjiek authored and seanpmorgan committed Oct 23, 2019
1 parent 5edc422 commit a28d42f
Showing 4 changed files with 692 additions and 0 deletions.
14 changes: 14 additions & 0 deletions tensorflow_addons/optimizers/BUILD
@@ -7,6 +7,7 @@ py_library(
srcs = [
"__init__.py",
"conditional_gradient.py",
"lamb.py",
"lazy_adam.py",
"lookahead.py",
"moving_average.py",
@@ -19,6 +20,19 @@ py_library(
],
)

py_test(
name = "lamb_test",
size = "small",
srcs = [
"lamb_test.py",
],
main = "lamb_test.py",
srcs_version = "PY2AND3",
deps = [
":optimizers",
],
)

py_test(
name = "conditional_gradient_test",
size = "small",
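For reference, the new lamb_test target can be run the same way as the existing optimizer tests in this package; assuming the repository's usual Bazel workflow, the invocation would look like:

bazel test //tensorflow_addons/optimizers:lamb_test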
2 changes: 2 additions & 0 deletions tensorflow_addons/optimizers/README.md
@@ -4,6 +4,7 @@
| Submodule | Maintainers | Contact Info |
|:---------- |:------------- |:--------------|
| conditional_gradient | Pengyu Kan, Vishnu Lokhande | pkan2@wisc.edu, lokhande@cs.wisc.edu |
| lamb | Jing Li, Junjie Ke | jingli@google.com, junjiek@google.com |
| lazy_adam | Saishruthi Swaminathan | saishruthi.tn@gmail.com |
| lookahead | Zhao Hanguang | cyberzhg@gmail.com |
| moving_average | Dheeraj R. Reddy | dheeraj98reddy@gmail.com |
@@ -15,6 +16,7 @@
| Submodule | Optimizer | Reference |
|:--------- |:---------- |:---------|
| conditional_gradient | ConditionalGradient | https://arxiv.org/pdf/1803.06453.pdf |
| lamb | LAMB | https://arxiv.org/abs/1904.00962 |
| lazy_adam | LazyAdam | https://arxiv.org/abs/1412.6980 |
| lookahead | Lookahead | https://arxiv.org/abs/1907.08610v1 |
| moving_average | MovingAverage | |
248 changes: 248 additions & 0 deletions tensorflow_addons/optimizers/lamb.py
@@ -0,0 +1,248 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""LAMB (Layer-wise Adaptive Moments) optimizer as TF2 tf.keras.optimizers.
See paper [Large Batch Optimization for Deep Learning: Training BERT in
76 minutes](https://arxiv.org/abs/1904.00962).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

import tensorflow as tf
from tensorflow_addons.utils import keras_utils


@keras_utils.register_keras_custom_object
class LAMB(tf.keras.optimizers.Optimizer):
"""Optimizer that implements the LAMB (Layer-wise Adaptive Moments)
optimizer as TF2 tf.keras.optimizers.
See paper [Large Batch Optimization for Deep Learning: Training BERT
in 76 minutes](https://arxiv.org/abs/1904.00962).
"""

def __init__(self,
learning_rate=0.001,
beta_1=0.9,
beta_2=0.999,
epsilon=1e-6,
weight_decay_rate=0.0,
exclude_from_weight_decay=None,
exclude_from_layer_adaptation=None,
name='LAMB',
**kwargs):
"""
learning_rate: A `Tensor` or a floating point value.
The learning rate.
beta_1: A `float` value or a constant `float` tensor.
The exponential decay rate for the 1st moment estimates.
beta_2: A `float` value or a constant `float` tensor.
The exponential decay rate for the 2nd moment estimates.
epsilon: A small constant for numerical stability.
weight_decay_rate: A `float` value for the weight decay rate.
exclude_from_weight_decay: A list of regex patterns of variable names
for which weight decay is skipped. A variable is excluded if any
pattern matches a substring of its name.
exclude_from_layer_adaptation: A list of regex patterns of variable
names for which layer-wise adaptation is skipped. A variable is
excluded if any pattern matches a substring of its name.
name: Optional name for the operations created when applying
gradients. Defaults to "LAMB".
**kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
`lr`, `decay`}. `clipnorm` clips gradients by norm; `clipvalue`
clips gradients by value; `decay` is included for backward
compatibility to allow time inverse decay of learning rate; `lr`
is included for backward compatibility, and `learning_rate` is
recommended instead.
"""
super(LAMB, self).__init__(name, **kwargs)

# Just adding the square of the weights to the loss function is *not*
# the correct way of using L2 regularization/weight decay with Adam,
# since that will interact with the m and v parameters in strange ways.
#
# Instead we want to decay the weights in a manner that doesn't interact
# with the m/v parameters.
self._set_hyper('weight_decay_rate', weight_decay_rate)
self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))

# This is learning rate decay for using keras learning rate schedule.
self._set_hyper('decay', self._initial_decay)
self._set_hyper('beta_1', beta_1)
self._set_hyper('beta_2', beta_2)
self.epsilon = epsilon or tf.keras.backend.epsilon()
self.exclude_from_weight_decay = exclude_from_weight_decay
# exclude_from_layer_adaptation is set to exclude_from_weight_decay if
# the arg is None.
if exclude_from_layer_adaptation:
self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
else:
self.exclude_from_layer_adaptation = exclude_from_weight_decay

def _create_slots(self, var_list):
# Create slots for the first and second moments.
# Separate for-loops to respect the ordering of slot variables from v1.
for var in var_list:
self.add_slot(var, 'm')
for var in var_list:
self.add_slot(var, 'v')

def _prepare_local(self, var_device, var_dtype, apply_state):
super(LAMB, self)._prepare_local(var_device, var_dtype, apply_state)

local_step = tf.cast(self.iterations + 1, var_dtype)
beta_1_t = tf.identity(self._get_hyper('beta_1', var_dtype))
beta_2_t = tf.identity(self._get_hyper('beta_2', var_dtype))
weight_decay_rate = tf.identity(
self._get_hyper('weight_decay_rate', var_dtype))
beta_1_power = tf.pow(beta_1_t, local_step)
beta_2_power = tf.pow(beta_2_t, local_step)
apply_state[(var_device, var_dtype)].update(
dict(
weight_decay_rate=weight_decay_rate,
epsilon=tf.convert_to_tensor(self.epsilon, var_dtype),
beta_1_t=beta_1_t,
beta_1_power=beta_1_power,
one_minus_beta_1_t=1 - beta_1_t,
beta_2_t=beta_2_t,
beta_2_power=beta_2_power,
one_minus_beta_2_t=1 - beta_2_t))

def _resource_apply_dense(self, grad, var, apply_state=None):
var_device, var_dtype = var.device, var.dtype.base_dtype
coefficients = ((apply_state or {}).get((var_device, var_dtype))
or self._fallback_apply_state(var_device, var_dtype))

# m_t = beta1 * m + (1 - beta1) * g_t
m = self.get_slot(var, 'm')
m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
m_t = m * coefficients['beta_1_t'] + m_scaled_g_values
m_t = m.assign(m_t, use_locking=self._use_locking)
# v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
v = self.get_slot(var, 'v')
v_scaled_g_values = (grad * grad) * coefficients['one_minus_beta_2_t']
v_t = v * coefficients['beta_2_t'] + v_scaled_g_values
v_t = v.assign(v_t, use_locking=self._use_locking)

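# Bias-corrected first and second moment estimates, as in Adam:
# divide by (1 - beta^t) so early steps are not biased toward zero.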
m_t_hat = m_t / (1. - coefficients['beta_1_power'])
v_t_hat = v_t / (1. - coefficients['beta_2_power'])

v_sqrt = tf.sqrt(v_t_hat)
update = m_t_hat / (v_sqrt + coefficients['epsilon'])

var_name = self._get_variable_name(var.name)
if self._do_use_weight_decay(var_name):
update += coefficients['weight_decay_rate'] * var

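# Layer-wise adaptation (the "trust ratio" in the LAMB paper): scale the
# update by ||w|| / ||update|| when both norms are positive, otherwise 1.0.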
ratio = 1.0
if self._do_layer_adaptation(var_name):
w_norm = tf.norm(var, ord=2)
g_norm = tf.norm(update, ord=2)
ratio = tf.where(
tf.greater(w_norm, 0),
tf.where(tf.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0)

var_update = var - ratio * coefficients['lr_t'] * update
return var.assign(var_update, use_locking=self._use_locking).op

def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
var_device, var_dtype = var.device, var.dtype.base_dtype
coefficients = ((apply_state or {}).get((var_device, var_dtype))
or self._fallback_apply_state(var_device, var_dtype))

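# For sparse gradients, decay each slot in full first, then scatter-add
# the scaled gradient values only at the provided indices.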
# m_t = beta1 * m + (1 - beta1) * g_t
m = self.get_slot(var, 'm')
m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
m_t = m.assign(
m * coefficients['beta_1_t'], use_locking=self._use_locking)
with tf.control_dependencies([m_t]):
m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)

# v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
v = self.get_slot(var, 'v')
v_scaled_g_values = (grad * grad) * coefficients['one_minus_beta_2_t']
v_t = v.assign(
v * coefficients['beta_2_t'], use_locking=self._use_locking)
with tf.control_dependencies([v_t]):
v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)

m_t_hat = m_t / (1. - coefficients['beta_1_power'])
v_t_hat = v_t / (1. - coefficients['beta_2_power'])

v_sqrt = tf.sqrt(v_t_hat)
update = m_t_hat / (v_sqrt + coefficients['epsilon'])

var_name = self._get_variable_name(var.name)
if self._do_use_weight_decay(var_name):
update += coefficients['weight_decay_rate'] * var

ratio = 1.0
if self._do_layer_adaptation(var_name):
w_norm = tf.norm(var, ord=2)
g_norm = tf.norm(update, ord=2)
ratio = tf.where(
tf.greater(w_norm, 0),
tf.where(tf.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0)

var_update = var.assign_sub(
ratio * coefficients['lr_t'] * update,
use_locking=self._use_locking)
return tf.group(*[var_update, m_t, v_t])

def get_config(self):
config = super(LAMB, self).get_config()
config.update({
'learning_rate':
self._serialize_hyperparameter('learning_rate'),
'weight_decay_rate':
self._serialize_hyperparameter('weight_decay_rate'),
'decay':
self._serialize_hyperparameter('decay'),
'beta_1':
self._serialize_hyperparameter('beta_1'),
'beta_2':
self._serialize_hyperparameter('beta_2'),
'epsilon':
self.epsilon,
})
return config

def _do_use_weight_decay(self, param_name):
"""Whether to use L2 weight decay for `param_name`."""
if self.exclude_from_weight_decay:
for r in self.exclude_from_weight_decay:
if re.search(r, param_name) is not None:
return False
return True

def _do_layer_adaptation(self, param_name):
"""Whether to do layer-wise learning rate adaptation for
`param_name`."""
if self.exclude_from_layer_adaptation:
for r in self.exclude_from_layer_adaptation:
if re.search(r, param_name) is not None:
return False
return True

def _get_variable_name(self, param_name):
"""Get the variable name from the tensor name."""
m = re.match('^(.*):\\d+$', param_name)
if m is not None:
param_name = m.group(1)
return param_name
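As a quick illustration of how the new class might be used, here is a minimal sketch; the toy model, data, and hyperparameter values are assumptions for illustration only, while the `LAMB` class, its module path, and its constructor arguments come from lamb.py above.

import tensorflow as tf
from tensorflow_addons.optimizers.lamb import LAMB

# Hypothetical toy model; any Keras model works the same way.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dense(1),
])

# exclude_from_weight_decay takes regex patterns matched against variable
# names; skipping bias variables here is an illustrative choice, not a
# recommendation from the commit.
optimizer = LAMB(
    learning_rate=1e-3,
    weight_decay_rate=0.01,
    exclude_from_weight_decay=['bias'])

model.compile(optimizer=optimizer, loss='mse')
model.fit(tf.random.normal((32, 4)), tf.random.normal((32, 1)),
          epochs=1, verbose=0)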
428 changes: 428 additions & 0 deletions tensorflow_addons/optimizers/lamb_test.py
Large diffs are not rendered by default.
