Skip to content

Commit

Permalink
Adadelta optimizer
Browse files Browse the repository at this point in the history
Summary: Adding Adadelta optimizer to fairseq as wrapper around torch.optim.Adadelta

Reviewed By: myleott

Differential Revision: D14418635

fbshipit-source-id: 6bf5ec008e905a4a2cbf7415e9492f5eea3ff07f
  • Loading branch information
Dmytro Okhonko authored and facebook-github-bot committed Mar 12, 2019
1 parent 9e1c880 commit d17fa85
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 1 deletion.
40 changes: 40 additions & 0 deletions fairseq/optim/adadelta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch.optim

from . import FairseqOptimizer, register_optimizer


@register_optimizer('adadelta')
class Adadelta(FairseqOptimizer):
def __init__(self, args, params):
super().__init__(args, params)
self._optimizer = torch.optim.Adadelta(params, **self.optimizer_config)

@staticmethod
def add_args(parser):
"""Add optimizer-specific arguments to the parser."""
parser.add_argument('--adadelta-rho', type=float, default=0.9, metavar='RHO',
help='coefficient used for computing a running average of squared gradients')
parser.add_argument('--adadelta-eps', type=float, default=1e-6, metavar='EPS',
help='term added to the denominator to improve numerical stability')

@property
def optimizer_config(self):
"""
Return a kwarg dictionary that will be used to override optimizer
args stored in checkpoints. This allows us to load a checkpoint and
resume training using a different set of optimizer args, e.g., with a
different learning rate.
"""
return {
'lr': self.args.lr[0],
'rho': self.args.adadelta_rho,
'eps': self.args.adadelta_eps,
'weight_decay': self.args.weight_decay,
}
23 changes: 22 additions & 1 deletion tests/test_binaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,28 @@ def test_fconv_lm(self):
eval_lm_main(data_dir)


class TestCommonOptions(unittest.TestCase):

def test_optimizers(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_optimizers') as data_dir:
# Use just a bit of data and tiny model to keep this test runtime reasonable
create_dummy_data(data_dir, num_examples=10, maxlen=5)
preprocess_translation_data(data_dir)
optimizers = ['adafactor', 'adam', 'nag', 'adagrad', 'sgd', 'adadelta']
last_checkpoint = os.path.join(data_dir, 'checkpoint_last.pt')
for optimizer in optimizers:
if os.path.exists(last_checkpoint):
os.remove(last_checkpoint)
train_translation_model(data_dir, 'lstm', [
'--encoder-layers', '1',
'--encoder-hidden-size', '32',
'--decoder-layers', '1',
'--optimizer', optimizer,
])
generate_main(data_dir)


def create_dummy_data(data_dir, num_examples=1000, maxlen=20):

def _create_dummy_data(filename):
Expand Down Expand Up @@ -267,7 +289,6 @@ def train_translation_model(data_dir, arch, extra_flags=None):
data_dir,
'--save-dir', data_dir,
'--arch', arch,
'--optimizer', 'nag',
'--lr', '0.05',
'--max-tokens', '500',
'--max-epoch', '1',
Expand Down

0 comments on commit d17fa85

Please sign in to comment.