Merge internal changes #428

Closed
wants to merge 12 commits
9 changes: 6 additions & 3 deletions README.md
@@ -19,10 +19,13 @@ of various sequence-to-sequence models, including:

Fairseq features:
- multi-GPU (distributed) training on one machine or across multiple machines
- fast beam search generation on both CPU and GPU
- fast generation on both CPU and GPU with multiple search algorithms implemented:
- beam search
- Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424))
- sampling (unconstrained and top-k)
- large mini-batch training even on a single GPU via delayed updates
- fast half-precision floating point (FP16) training
- extensible: easily register new models, criterions, and tasks
- extensible: easily register new models, criterions, tasks, optimizers and learning rate schedulers

We also provide [pre-trained models](#pre-trained-models) for several benchmark
translation and language modeling datasets.
@@ -34,7 +37,7 @@ translation and language modeling datasets.
* For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
* Python version 3.6

Currently fairseq requires PyTorch version >= 0.4.0.
Currently fairseq requires PyTorch version >= 1.0.0.
Please follow the instructions here: https://github.com/pytorch/pytorch#installation.

If you use Docker make sure to increase the shared memory size either with
2 changes: 1 addition & 1 deletion distributed_train.py
@@ -28,7 +28,7 @@ def main(args):
args.device_id = int(os.environ.get('SLURM_LOCALID'))
except subprocess.CalledProcessError as e: # scontrol failed
raise e
except FileNotFoundError as e: # Slurm is not installed
except FileNotFoundError: # Slurm is not installed
pass
if args.distributed_init_method is None and args.distributed_port is None:
raise ValueError('--distributed-init-method or --distributed-port '
18 changes: 18 additions & 0 deletions docs/criterions.rst
@@ -6,8 +6,26 @@
Criterions
==========

Criterions compute the loss function given the model and batch, roughly::

loss = criterion(model, batch)
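
As an illustration only (a minimal sketch, not part of this change), a new
criterion follows the same pattern as the built-in ones: subclass
:class:`FairseqCriterion`, register it, and return ``(loss, sample_size,
logging_output)`` from ``forward``. The exact hook signatures should be
checked against ``fairseq/criterions/cross_entropy.py``::

    import torch.nn.functional as F

    from fairseq.criterions import FairseqCriterion, register_criterion

    @register_criterion('simple_cross_entropy')
    class SimpleCrossEntropyCriterion(FairseqCriterion):

        def forward(self, model, sample, reduce=True):
            # run the model and compute token-level negative log-likelihood
            net_output = model(**sample['net_input'])
            lprobs = model.get_normalized_probs(net_output, log_probs=True)
            lprobs = lprobs.view(-1, lprobs.size(-1))
            target = model.get_targets(sample, net_output).view(-1)
            loss = F.nll_loss(
                lprobs, target, ignore_index=self.padding_idx,
                reduction='sum' if reduce else 'none',
            )
            sample_size = sample['ntokens']
            logging_output = {
                'loss': loss.item() if reduce else loss,
                'ntokens': sample['ntokens'],
                'sample_size': sample_size,
            }
            return loss, sample_size, logging_output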

.. automodule:: fairseq.criterions
:members:

.. autoclass:: fairseq.criterions.FairseqCriterion
:members:
:undoc-members:

.. autoclass:: fairseq.criterions.adaptive_loss.AdaptiveLoss
:members:
:undoc-members:
.. autoclass:: fairseq.criterions.composite_loss.CompositeLoss
:members:
:undoc-members:
.. autoclass:: fairseq.criterions.cross_entropy.CrossEntropyCriterion
:members:
:undoc-members:
.. autoclass:: fairseq.criterions.label_smoothed_cross_entropy.LabelSmoothedCrossEntropyCriterion
:members:
:undoc-members:
16 changes: 16 additions & 0 deletions docs/data.rst
@@ -21,6 +21,20 @@ mini-batches.
.. autoclass:: fairseq.data.MonolingualDataset
:members:

**Helper Datasets**

These datasets wrap other :class:`fairseq.data.FairseqDataset` instances and
provide additional functionality:

.. autoclass:: fairseq.data.BacktranslationDataset
:members:
.. autoclass:: fairseq.data.ConcatDataset
:members:
.. autoclass:: fairseq.data.RoundRobinZipDatasets
:members:
.. autoclass:: fairseq.data.TransformEosDataset
:members:
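
As a usage sketch (illustrative only; ``dataset_a`` and ``dataset_b`` stand in
for any two :class:`fairseq.data.FairseqDataset` instances), ``ConcatDataset``
concatenates datasets and can oversample individual ones via ``sample_ratios``::

    from fairseq.data import ConcatDataset

    # concatenate two datasets, repeating the second one twice per epoch
    combined = ConcatDataset([dataset_a, dataset_b], sample_ratios=[1, 2])
    assert len(combined) == len(dataset_a) + 2 * len(dataset_b)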


Dictionary
----------
@@ -32,6 +46,8 @@ Dictionary
Iterators
---------

.. autoclass:: fairseq.data.BufferedIterator
:members:
.. autoclass:: fairseq.data.CountingIterator
:members:
.. autoclass:: fairseq.data.EpochBatchIterator
17 changes: 8 additions & 9 deletions docs/getting_started.rst
@@ -27,21 +27,20 @@ interactively. Here, we use a beam size of 5:
> MODEL_DIR=wmt14.en-fr.fconv-py
> python interactive.py \
--path $MODEL_DIR/model.pt $MODEL_DIR \
--beam 5
--beam 5 --source-lang en --target-lang fr
| loading model(s) from wmt14.en-fr.fconv-py/model.pt
| [en] dictionary: 44206 types
| [fr] dictionary: 44463 types
| Type the input sentence and press return:
> Why is it rare to discover new marine mam@@ mal species ?
O Why is it rare to discover new marine mam@@ mal species ?
H -0.06429661810398102 Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins ?
A 0 1 3 3 5 6 6 8 8 8 7 11 12

This generation script produces four types of outputs: a line prefixed
with *S* shows the supplied source sentence after applying the
vocabulary; *O* is a copy of the original source sentence; *H* is the
hypothesis along with an average log-likelihood; and *A* is the
attention maxima for each word in the hypothesis, including the
H -0.1525060087442398 Pourquoi est @-@ il rare de découvrir de nouvelles espèces de mammifères marins ?
P -0.2221 -0.3122 -0.1289 -0.2673 -0.1711 -0.1930 -0.1101 -0.1660 -0.1003 -0.0740 -0.1101 -0.0814 -0.1238 -0.0985 -0.1288

This generation script produces three types of outputs: a line prefixed
with *O* is a copy of the original source sentence; *H* is the
hypothesis along with an average log-likelihood; and *P* is the
positional score per token position, including the
end-of-sentence marker which is omitted from the text.
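
As a quick sanity check (an illustrative snippet using the numbers copied from
the output above), the *H* score is just the average of the *P* positional
scores::

    pos_scores = [
        -0.2221, -0.3122, -0.1289, -0.2673, -0.1711, -0.1930, -0.1101, -0.1660,
        -0.1003, -0.0740, -0.1101, -0.0814, -0.1238, -0.0985, -0.1288,
    ]
    print(sum(pos_scores) / len(pos_scores))  # ~ -0.1525, matching the H line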

See the `README <https://github.com/pytorch/fairseq#pre-trained-models>`__ for a
24 changes: 23 additions & 1 deletion docs/lr_scheduler.rst
@@ -6,7 +6,29 @@
Learning Rate Schedulers
========================

TODO
Learning Rate Schedulers update the learning rate over the course of training.
Learning rates can be updated after each update via :func:`step_update` or at
epoch boundaries via :func:`step`.
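
As an illustration only (a minimal sketch; the exact hook signatures should be
checked against the built-in schedulers such as ``fixed_schedule.py``), a new
scheduler subclasses :class:`FairseqLRScheduler` and is registered by name::

    from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler

    @register_lr_scheduler('decay_per_epoch')
    class DecayPerEpochSchedule(FairseqLRScheduler):
        """Hypothetical scheduler that decays the learning rate each epoch."""

        def __init__(self, args, optimizer):
            super().__init__(args, optimizer)
            self.decay = 0.9  # hypothetical decay factor

        def step(self, epoch, val_loss=None):
            # called at epoch boundaries
            super().step(epoch, val_loss)
            self.optimizer.set_lr(self.optimizer.get_lr() * self.decay)
            return self.optimizer.get_lr()

        def step_update(self, num_updates):
            # called after every update; keep the rate constant within an epoch
            return self.optimizer.get_lr()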

.. automodule:: fairseq.optim.lr_scheduler
:members:

.. autoclass:: fairseq.optim.lr_scheduler.FairseqLRScheduler
:members:
:undoc-members:

.. autoclass:: fairseq.optim.lr_scheduler.cosine_lr_scheduler.CosineSchedule
:members:
:undoc-members:
.. autoclass:: fairseq.optim.lr_scheduler.fixed_schedule.FixedSchedule
:members:
:undoc-members:
.. autoclass:: fairseq.optim.lr_scheduler.inverse_square_root_schedule.InverseSquareRootSchedule
:members:
:undoc-members:
.. autoclass:: fairseq.optim.lr_scheduler.reduce_lr_on_plateau.ReduceLROnPlateau
:members:
:undoc-members:
.. autoclass:: fairseq.optim.lr_scheduler.triangular_lr_scheduler.TriangularSchedule
:members:
:undoc-members:
4 changes: 2 additions & 2 deletions docs/modules.rst
@@ -1,8 +1,8 @@
Modules
=======

Fairseq provides several stand-alone :class:`torch.nn.Module` s that may be
helpful when implementing a new :class:`FairseqModel`.
Fairseq provides several stand-alone :class:`torch.nn.Module` classes that may
be helpful when implementing a new :class:`~fairseq.models.FairseqModel`.

.. automodule:: fairseq.modules
:members:
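
For instance (an illustrative sketch; tensor shapes follow the
``(time, batch, channel)`` convention used throughout fairseq, and the exact
constructor arguments should be checked against
``fairseq/modules/multihead_attention.py``), the attention module can be used
on its own::

    import torch

    from fairseq.modules import MultiheadAttention

    attn = MultiheadAttention(embed_dim=16, num_heads=4, dropout=0.1)
    x = torch.randn(10, 2, 16)  # (time, batch, embed_dim)
    # self-attention: query, key and value are the same tensor
    out, attn_weights = attn(query=x, key=x, value=x)
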
22 changes: 22 additions & 0 deletions docs/optim.rst
@@ -6,5 +6,27 @@
Optimizers
==========

Optimizers update the Model parameters based on the gradients.
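
In the high-level training flow (see :doc:`overview`), the optimizer is driven
roughly as follows (a schematic sketch, not the exact trainer code)::

    loss = criterion(model, batch)
    optimizer.backward(loss)                  # accumulate gradients
    optimizer.clip_grad_norm(args.clip_norm)  # optional gradient clipping
    optimizer.step()                          # update the model parameters
    optimizer.zero_grad()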

.. automodule:: fairseq.optim
:members:

.. autoclass:: fairseq.optim.FairseqOptimizer
:members:
:undoc-members:

.. autoclass:: fairseq.optim.adagrad.Adagrad
:members:
:undoc-members:
.. autoclass:: fairseq.optim.adam.FairseqAdam
:members:
:undoc-members:
.. autoclass:: fairseq.optim.fp16_optimizer.FP16Optimizer
:members:
:undoc-members:
.. autoclass:: fairseq.optim.nag.FairseqNAG
:members:
:undoc-members:
.. autoclass:: fairseq.optim.sgd.SGD
:members:
:undoc-members:
10 changes: 8 additions & 2 deletions docs/overview.rst
@@ -22,12 +22,18 @@ fairseq implements the following high-level training flow::
for epoch in range(num_epochs):
itr = task.get_batch_iterator(task.dataset('train'))
for num_updates, batch in enumerate(itr):
loss = criterion(model, batch)
optimizer.backward(loss)
task.train_step(batch, model, criterion, optimizer)
average_and_clip_gradients()
optimizer.step()
lr_scheduler.step_update(num_updates)
lr_scheduler.step(epoch)

where the default implementation for ``task.train_step`` is roughly::

def train_step(self, batch, model, criterion, optimizer):
loss = criterion(model, batch)
optimizer.backward(loss)

**Registering new plug-ins**

New plug-ins are *registered* through a set of ``@register`` function
5 changes: 2 additions & 3 deletions docs/tutorial_classifying_names.rst
@@ -353,17 +353,16 @@ The model files should appear in the :file:`checkpoints/` directory.
-------------------------------

Finally we can write a short script to evaluate our model on new inputs. Create
a new file named :file:`eval_classify.py` with the following contents::
a new file named :file:`eval_classifier.py` with the following contents::

from fairseq import data, options, tasks, utils
from fairseq.tokenizer import Tokenizer

# Parse command-line arguments for generation
parser = options.get_generation_parser()
parser = options.get_generation_parser(default_task='simple_classification')
args = options.parse_args_and_arch(parser)

# Setup task
args.task = 'simple_classification'
task = tasks.setup_task(args)

# Load model
7 changes: 5 additions & 2 deletions eval_lm.py
@@ -55,7 +55,9 @@ def main(parsed_args):

# Load ensemble
print('| loading model(s) from {}'.format(parsed_args.path))
models, args = utils.load_ensemble_for_inference(parsed_args.path.split(':'), task, model_arg_overrides=eval(parsed_args.model_overrides))
models, args = utils.load_ensemble_for_inference(
parsed_args.path.split(':'), task, model_arg_overrides=eval(parsed_args.model_overrides),
)

for arg in vars(parsed_args).keys():
if arg not in {'self_target', 'future_target', 'past_target', 'tokens_per_sample', 'output_size_dictionary'}:
@@ -83,9 +85,10 @@ def main(parsed_args):
max_positions=utils.resolve_max_positions(*[
model.max_positions() for model in models
]),
ignore_invalid_inputs=True,
num_shards=args.num_shards,
shard_id=args.shard_id,
ignore_invalid_inputs=True,
num_workers=args.num_workers,
).next_epoch_itr(shuffle=False)

gen_timer = StopwatchMeter()
3 changes: 1 addition & 2 deletions fairseq/data/__init__.py
@@ -9,7 +9,7 @@
from .fairseq_dataset import FairseqDataset
from .backtranslation_dataset import BacktranslationDataset
from .concat_dataset import ConcatDataset
from .indexed_dataset import IndexedDataset, IndexedCachedDataset, IndexedInMemoryDataset, IndexedRawTextDataset
from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset
from .language_pair_dataset import LanguagePairDataset
from .monolingual_dataset import MonolingualDataset
from .round_robin_zip_datasets import RoundRobinZipDatasets
@@ -33,7 +33,6 @@
'GroupedIterator',
'IndexedCachedDataset',
'IndexedDataset',
'IndexedInMemoryDataset',
'IndexedRawTextDataset',
'LanguagePairDataset',
'MonolingualDataset',
48 changes: 24 additions & 24 deletions fairseq/data/backtranslation_dataset.py
@@ -56,6 +56,28 @@ def update_sample(sample, generated_source):


class BacktranslationDataset(FairseqDataset):
"""
Sets up a backtranslation dataset which takes a tgt batch, generates
a src using a tgt-src backtranslation function (*backtranslation_fn*),
and returns the corresponding `{generated src, input tgt}` batch.

Args:
tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be
backtranslated. Only the source side of this dataset will be used.
After backtranslation, the source sentences in this dataset will be
returned as the targets.
backtranslation_fn (callable): function to call to generate
backtranslations. This is typically the `generate` method of a
:class:`~fairseq.sequence_generator.SequenceGenerator` object.
max_len_a, max_len_b (int, int): will be used to compute
`maxlen = max_len_a * src_len + max_len_b`, which will be passed
into *backtranslation_fn*.
output_collater (callable, optional): function to call on the
backtranslated samples to create the final batch
(default: ``tgt_dataset.collater``).
cuda: use GPU for generation
"""

def __init__(
self,
tgt_dataset,
@@ -66,27 +88,6 @@ def __init__(
cuda=True,
**kwargs
):
"""
Sets up a backtranslation dataset which takes a tgt batch, generates
a src using a tgt-src backtranslation function (*backtranslation_fn*),
and returns the corresponding `{generated src, input tgt}` batch.

Args:
tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be
backtranslated. Only the source side of this dataset will be
used. After backtranslation, the source sentences in this
dataset will be returned as the targets.
backtranslation_fn (callable): function to call to generate
backtranslations. This is typically the `generate` method of a
:class:`~fairseq.sequence_generator.SequenceGenerator` object.
max_len_a, max_len_b (int, int): will be used to compute
`maxlen = max_len_a * src_len + max_len_b`, which will be
passed into *backtranslation_fn*.
output_collater (callable, optional): function to call on the
backtranslated samples to create the final batch (default:
``tgt_dataset.collater``)
cuda: use GPU for generation
"""
self.tgt_dataset = tgt_dataset
self.backtranslation_fn = backtranslation_fn
self.max_len_a = max_len_a
@@ -166,11 +167,10 @@ def size(self, index):
"""
tgt_size = self.tgt_dataset.size(index)[0]
return (tgt_size, tgt_size)

@property
def supports_prefetch(self):
return self.tgt_dataset.supports_prefetch()
return getattr(self.tgt_dataset, 'supports_prefetch', False)

def prefetch(self, indices):
return self.tgt_dataset.prefetch(indices)

10 changes: 5 additions & 5 deletions fairseq/data/concat_dataset.py
@@ -29,18 +29,18 @@ def __init__(self, datasets, sample_ratios=1):
if isinstance(sample_ratios, int):
sample_ratios = [sample_ratios] * len(self.datasets)
self.sample_ratios = sample_ratios
self.cummulative_sizes = self.cumsum(self.datasets, sample_ratios)
self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
self.real_sizes = [len(d) for d in self.datasets]

def __len__(self):
return self.cummulative_sizes[-1]
return self.cumulative_sizes[-1]

def __getitem__(self, idx):
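# map a global index to (constituent dataset, local index) using the
# ratio-weighted cumulative sizes computed in cumsum()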
dataset_idx = bisect.bisect_right(self.cummulative_sizes, idx)
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cummulative_sizes[dataset_idx - 1]
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
sample_idx = sample_idx % self.real_sizes[dataset_idx]
return self.datasets[dataset_idx][sample_idx]

@@ -54,7 +54,7 @@ def supports_prefetch(self):

def prefetch(self, indices):
frm = 0
for to, ds in zip(self.cummulative_sizes, self.datasets):
for to, ds in zip(self.cumulative_sizes, self.datasets):
real_size = len(ds)
ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
frm = to