Merge internal changes (#654)
Summary:
- Add --add-bos-token option to LM task
- Cleanup utils.py and options.py
Pull Request resolved: #654

Differential Revision: D15041794

Pulled By: myleott

fbshipit-source-id: 3ad00007769d5f48308052cfd40de39c5ffa1a6e
myleott authored and facebook-github-bot committed Apr 30, 2019
1 parent 89a6961 commit d45db80
Showing 34 changed files with 368 additions and 278 deletions.
6 changes: 6 additions & 0 deletions docs/optim.rst
@@ -15,9 +15,15 @@ Optimizers update the Model parameters based on the gradients.
    :members:
    :undoc-members:

.. autoclass:: fairseq.optim.adadelta.Adadelta
    :members:
    :undoc-members:

.. autoclass:: fairseq.optim.adagrad.Adagrad
    :members:
    :undoc-members:

.. autoclass:: fairseq.optim.adafactor.FairseqAdafactor
    :members:
    :undoc-members:

.. autoclass:: fairseq.optim.adam.FairseqAdam
    :members:
    :undoc-members:
3 changes: 2 additions & 1 deletion docs/overview.rst
@@ -28,11 +28,12 @@ fairseq implements the following high-level training flow::
            lr_scheduler.step_update(num_updates)
        lr_scheduler.step(epoch)

where the default implementation for ``train.train_step`` is roughly::
where the default implementation for ``task.train_step`` is roughly::

    def train_step(self, batch, model, criterion, optimizer):
        loss = criterion(model, batch)
        optimizer.backward(loss)
        return loss

**Registering new plug-ins**

4 changes: 2 additions & 2 deletions docs/tutorial_classifying_names.rst
@@ -354,7 +354,7 @@ The model files should appear in the :file:`checkpoints/` directory.
Finally we can write a short script to evaluate our model on new inputs. Create
a new file named :file:`eval_classifier.py` with the following contents::

    from fairseq import data, options, tasks, utils
    from fairseq import checkpoint_utils, data, options, tasks

    # Parse command-line arguments for generation
    parser = options.get_generation_parser(default_task='simple_classification')
@@ -365,7 +365,7 @@ a new file named :file:`eval_classifier.py` with the following contents::

    # Load model
    print('| loading model from {}'.format(args.path))
    models, _model_args = utils.load_ensemble_for_inference([args.path], task)
    models, _model_args = checkpoint_utils.load_model_ensemble([args.path], task=task)
    model = models[0]

    while True:
21 changes: 15 additions & 6 deletions eval_lm.py
@@ -13,11 +13,10 @@
import numpy as np
import torch

from fairseq import options, progress_bar, tasks, utils
from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.data import LMContextWindowDataset
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer
from fairseq.utils import import_user_module


class WordStat(object):
@@ -49,7 +48,7 @@ def __str__(self):
def main(parsed_args):
    assert parsed_args.path is not None, '--path required for evaluation!'

    import_user_module(parsed_args)
    utils.import_user_module(parsed_args)

    print(parsed_args)

@@ -59,12 +58,17 @@

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    models, args = utils.load_ensemble_for_inference(
        parsed_args.path.split(':'), task, model_arg_overrides=eval(parsed_args.model_overrides),
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(':'),
        arg_overrides=eval(parsed_args.model_overrides),
        task=task,
    )

    for arg in vars(parsed_args).keys():
        if arg not in {'self_target', 'future_target', 'past_target', 'tokens_per_sample', 'output_size_dictionary'}:
        if arg not in {
            'self_target', 'future_target', 'past_target', 'tokens_per_sample',
            'output_size_dictionary', 'add_bos_token',
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
@@ -151,6 +155,11 @@ def main(parsed_args):
                tgt_len = tokens.numel()
                pos_scores = hypo['positional_scores'].float()

                if args.add_bos_token:
                    assert hypo['tokens'][0].item() == task.target_dictionary.bos()
                    tokens = tokens[1:]
                    pos_scores = pos_scores[1:]

                skipped_toks = 0
                if bpe_toks is not None:
                    for i in range(tgt_len - 1):
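To see what the new `--add-bos-token` handling above does during scoring, here is a small self-contained sketch (the tensor values are made up for illustration, not taken from the commit): the prepended BOS position is dropped from both the hypothesis tokens and their positional scores so it does not count toward perplexity::

    import torch

    # hypothetical hypothesis: [bos, w1, w2, eos] with per-position log-probs
    tokens = torch.tensor([0, 42, 43, 2])
    pos_scores = torch.tensor([-0.1, -1.2, -0.9, -0.5])

    add_bos_token = True
    if add_bos_token:
        # drop the BOS position before accumulating score_sum / count
        tokens = tokens[1:]
        pos_scores = pos_scores[1:]

    print(tokens.tolist(), pos_scores.sum().item())  # [42, 43, 2] ~ -2.6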
2 changes: 1 addition & 1 deletion examples/language_model/README.md
@@ -39,7 +39,7 @@ $ fairseq-train --task language_modeling data-bin/wikitext-103 \
--save-dir checkpoints/transformer_wikitext-103 --arch transformer_lm_wiki103 \
--max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
--warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
--criterion adaptive_loss --max-tokens 3072 --update-freq 4 --tokens-per-sample 3072 --seed 1 \
--criterion adaptive_loss --max-tokens 3072 --update-freq 3 --tokens-per-sample 3072 --seed 1 \
--sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d
# Evaluate:
177 changes: 177 additions & 0 deletions fairseq/checkpoint_utils.py
@@ -0,0 +1,177 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from collections import OrderedDict
import logging
import os
import re
import traceback

import torch
from torch.serialization import default_restore_location

from fairseq import tasks


def load_checkpoint_to_cpu(path):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility)."""
    state = torch.load(
        path, map_location=lambda s, l: default_restore_location(s, 'cpu'),
    )
    state = _upgrade_state_dict(state)
    return state


def load_model_ensemble(filenames, arg_overrides=None, task=None):
    """Loads an ensemble of models.

    Args:
        filenames (List[str]): checkpoint files to load
        arg_overrides (Dict[str,Any], optional): override model args that
            were used during model training
        task (fairseq.tasks.FairseqTask, optional): task to use for loading
    """
    ensemble = []
    for filename in filenames:
        if not os.path.exists(filename):
            raise IOError('Model file not found: {}'.format(filename))
        state = load_checkpoint_to_cpu(filename)

        args = state['args']
        if arg_overrides is not None:
            for arg_name, arg_val in arg_overrides.items():
                setattr(args, arg_name, arg_val)

        if task is None:
            task = tasks.setup_task(args)

        # build model for ensemble
        model = task.build_model(args)
        model.load_state_dict(state['model'], strict=True)
        ensemble.append(model)

    return ensemble, args


def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
    """Retrieves all checkpoints found in `path` directory.

    Checkpoints are identified by matching filename to the specified pattern. If
    the pattern contains groups, the result will be sorted by the first group in
    descending order.
    """
    pt_regexp = re.compile(pattern)
    files = os.listdir(path)

    entries = []
    for i, f in enumerate(files):
        m = pt_regexp.fullmatch(f)
        if m is not None:
            idx = int(m.group(1)) if len(m.groups()) > 0 else i
            entries.append((idx, m.group(0)))
    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]


def torch_persistent_save(*args, **kwargs):
    for i in range(3):
        try:
            return torch.save(*args, **kwargs)
        except Exception:
            if i == 2:
                logging.error(traceback.format_exc())


def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
    if isinstance(state_dict, dict):
        cpu_dict = OrderedDict()
        for k, v in state_dict.items():
            cpu_dict[k] = convert_state_dict_type(v)
        return cpu_dict
    elif isinstance(state_dict, list):
        return [convert_state_dict_type(v) for v in state_dict]
    elif torch.is_tensor(state_dict):
        return state_dict.type(ttype)
    else:
        return state_dict


def save_state(
    filename, args, model_state_dict, criterion, optimizer, lr_scheduler,
    num_updates, optim_history=None, extra_state=None,
):
    if optim_history is None:
        optim_history = []
    if extra_state is None:
        extra_state = {}
    state_dict = {
        'args': args,
        'model': model_state_dict if model_state_dict else {},
        'optimizer_history': optim_history + [
            {
                'criterion_name': criterion.__class__.__name__,
                'optimizer_name': optimizer.__class__.__name__,
                'lr_scheduler_state': lr_scheduler.state_dict(),
                'num_updates': num_updates,
            }
        ],
        'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()),
        'extra_state': extra_state,
    }
    torch_persistent_save(state_dict, filename)


def _upgrade_state_dict(state):
    """Helper for upgrading old model checkpoints."""
    # add optimizer_history
    if 'optimizer_history' not in state:
        state['optimizer_history'] = [
            {
                'criterion_name': 'CrossEntropyCriterion',
                'best_loss': state['best_loss'],
            },
        ]
        state['last_optimizer_state'] = state['optimizer']
        del state['optimizer']
        del state['best_loss']
    # move extra_state into sub-dictionary
    if 'epoch' in state and 'extra_state' not in state:
        state['extra_state'] = {
            'epoch': state['epoch'],
            'batch_offset': state['batch_offset'],
            'val_loss': state['val_loss'],
        }
        del state['epoch']
        del state['batch_offset']
        del state['val_loss']
    # reduce optimizer history's memory usage (only keep the last state)
    if 'optimizer' in state['optimizer_history'][-1]:
        state['last_optimizer_state'] = state['optimizer_history'][-1]['optimizer']
        for optim_hist in state['optimizer_history']:
            del optim_hist['optimizer']
    # record the optimizer class name
    if 'optimizer_name' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['optimizer_name'] = 'FairseqNAG'
    # move best_loss into lr_scheduler_state
    if 'lr_scheduler_state' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['lr_scheduler_state'] = {
            'best': state['optimizer_history'][-1]['best_loss'],
        }
        del state['optimizer_history'][-1]['best_loss']
    # keep track of number of updates
    if 'num_updates' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['num_updates'] = 0
    # old model checkpoints may not have separate source/target positions
    if hasattr(state['args'], 'max_positions') and not hasattr(state['args'], 'max_source_positions'):
        state['args'].max_source_positions = state['args'].max_positions
        state['args'].max_target_positions = state['args'].max_positions
    # use stateful training data iterator
    if 'train_iterator' not in state['extra_state']:
        state['extra_state']['train_iterator'] = {
            'epoch': state['extra_state']['epoch'],
            'iterations_in_epoch': state['extra_state'].get('batch_offset', 0),
        }
    return state
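A minimal usage sketch of the new `checkpoint_utils.load_model_ensemble` helper (the checkpoint and data paths below are assumptions for illustration, not part of this commit)::

    from fairseq import checkpoint_utils

    # load two checkpoints as an ensemble; if no task is passed, one is set up
    # from the args stored in the first checkpoint (optionally overridden)
    models, args = checkpoint_utils.load_model_ensemble(
        ['checkpoints/checkpoint1.pt', 'checkpoints/checkpoint2.pt'],
        arg_overrides={'data': 'data-bin/wikitext-103'},
    )
    for model in models:
        model.eval()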
9 changes: 6 additions & 3 deletions fairseq/data/dictionary.py
@@ -18,13 +18,12 @@

class Dictionary(object):
    """A mapping from symbols to consecutive integers"""
    def __init__(self, pad='<pad>', eos='</s>', unk='<unk>'):
    def __init__(self, pad='<pad>', eos='</s>', unk='<unk>', bos='<s>'):
        self.unk_word, self.pad_word, self.eos_word = unk, pad, eos
        self.symbols = []
        self.count = []
        self.indices = {}
        # dictionary indexing starts at 1 for consistency with Lua
        self.add_symbol('<Lua heritage>')
        self.bos_index = self.add_symbol(bos)
        self.pad_index = self.add_symbol(pad)
        self.eos_index = self.add_symbol(eos)
        self.unk_index = self.add_symbol(unk)
@@ -143,6 +142,10 @@ def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
        self.symbols = list(new_symbols)
        self.indices = new_indices

    def bos(self):
        """Helper to get index of beginning-of-sentence symbol"""
        return self.bos_index

    def pad(self):
        """Helper to get index of pad symbol"""
        return self.pad_index
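With the `<Lua heritage>` placeholder removed and `<s>` registered first, the default special-symbol indices work out as below (a quick sketch; it assumes a freshly constructed dictionary with no extra symbols added)::

    from fairseq.data import Dictionary

    d = Dictionary()
    assert d.bos() == 0  # '<s>' is now the first symbol added
    assert d.pad() == 1
    assert d.eos() == 2
    assert d.unk() == 3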
12 changes: 11 additions & 1 deletion fairseq/data/monolingual_dataset.py
@@ -62,13 +62,14 @@ class MonolingualDataset(FairseqDataset):
    """

    def __init__(self, dataset, sizes, src_vocab, tgt_vocab, add_eos_for_other_targets, shuffle,
                 targets=None):
                 targets=None, add_bos_token=False):
        self.dataset = dataset
        self.sizes = np.array(sizes)
        self.vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.add_eos_for_other_targets = add_eos_for_other_targets
        self.shuffle = shuffle
        self.add_bos_token = add_bos_token

        assert targets is None or all(t in {'self', 'future', 'past'} for t in targets), \
            "targets must be none or one of 'self', 'future', 'past'"
@@ -91,6 +92,7 @@ def __getitem__(self, index):
        else:
            source = self.dataset[index]
            target = None
        source, target = self._maybe_add_bos(source, target)
        return {'id': index, 'source': source, 'target': target}

    def __len__(self):
@@ -129,6 +131,13 @@ def _make_source_target(self, source, future_target, past_target):

        return source, self._filter_vocab(target)

    def _maybe_add_bos(self, source, target):
        if self.add_bos_token:
            source = torch.cat([source.new([self.vocab.bos()]), source])
            if target is not None:
                target = torch.cat([target.new([self.tgt_vocab.bos()]), target])
        return source, target

    def _filter_vocab(self, target):
        if len(self.tgt_vocab) != len(self.vocab):
            def _filter(target):
@@ -173,6 +182,7 @@ def get_dummy_batch(self, num_tokens, max_positions, tgt_len=128):
        target = self.vocab.dummy_sentence(tgt_len + 2)
        source, past_target, future_target = target[1:-1], target[2:], target[:-2]
        source, target = self._make_source_target(source, past_target, future_target)
        source, target = self._maybe_add_bos(source, target)

        return self.collater([
            {'id': i, 'source': source, 'target': target}
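The effect of `add_bos_token=True` on a single example, sketched with a toy tensor (the indices are illustrative; 0 is the BOS index given the Dictionary change above)::

    import torch

    bos = 0                                           # vocab.bos()
    source = torch.tensor([11, 12, 13, 2])            # [w1, w2, w3, eos]
    source = torch.cat([source.new([bos]), source])   # -> [0, 11, 12, 13, 2]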
2 changes: 1 addition & 1 deletion fairseq/models/fconv.py
@@ -141,7 +141,7 @@ def build_model(cls, args, task):
        # make sure all arguments are present in older models
        base_lm_architecture(args)

        if hasattr(args, 'max_target_positions'):
        if hasattr(args, 'max_target_positions') and not hasattr(args, 'tokens_per_sample'):
            args.tokens_per_sample = args.max_target_positions

        decoder = FConvDecoder(
