Add code for mixture of experts (#521)
Summary:
Code for the paper: [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](https://arxiv.org/abs/1902.07816).
Pull Request resolved: #521

Differential Revision: D14188021

Pulled By: myleott

fbshipit-source-id: ed5b1ed5ad9a582359bd5215fa2ea26dc76c673e
myleott authored and facebook-github-bot committed Feb 22, 2019
1 parent b65c579 commit 4294c4f
Showing 10 changed files with 435 additions and 18 deletions.
14 changes: 8 additions & 6 deletions README.md
@@ -5,19 +5,20 @@ developers to train custom models for translation, summarization, language
modeling and other text generation tasks. It provides reference implementations
of various sequence-to-sequence models, including:
- **Convolutional Neural Networks (CNN)**
- [Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](https://arxiv.org/abs/1612.08083)
- [Gehring et al. (2017): Convolutional Sequence to Sequence Learning](https://arxiv.org/abs/1705.03122)
- [Edunov et al. (2018): Classical Structured Prediction Losses for Sequence to Sequence Learning](https://arxiv.org/abs/1711.04956)
- [Fan et al. (2018): Hierarchical Neural Story Generation](https://arxiv.org/abs/1805.04833)
- [Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](examples/conv_lm/README.md)
- [Gehring et al. (2017): Convolutional Sequence to Sequence Learning](examples/conv_seq2seq/README.md)
- [Edunov et al. (2018): Classical Structured Prediction Losses for Sequence to Sequence Learning](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
- [Fan et al. (2018): Hierarchical Neural Story Generation](examples/stories/README.md)
- **LightConv and DynamicConv models**
- **_New_** [Wu et al. (2019): Pay Less Attention with Lightweight and Dynamic Convolutions](https://openreview.net/pdf?id=SkVhlh09tX)
- **_New_** [Wu et al. (2019): Pay Less Attention with Lightweight and Dynamic Convolutions](examples/pay_less_attention_paper/README.md)
- **Long Short-Term Memory (LSTM) networks**
- [Luong et al. (2015): Effective Approaches to Attention-based Neural Machine Translation](https://arxiv.org/abs/1508.04025)
- [Wiseman and Rush (2016): Sequence-to-Sequence Learning as Beam-Search Optimization](https://arxiv.org/abs/1606.02960)
- **Transformer (self-attention) networks**
- [Vaswani et al. (2017): Attention Is All You Need](https://arxiv.org/abs/1706.03762)
- [Ott et al. (2018): Scaling Neural Machine Translation](https://arxiv.org/abs/1806.00187)
- [Ott et al. (2018): Scaling Neural Machine Translation](examples/scaling_nmt/README.md)
- [Edunov et al. (2018): Understanding Back-Translation at Scale](https://arxiv.org/abs/1808.09381)
- **_New_** [Shen et al. (2019): Mixture Models for Diverse Machine Translation: Tricks of the Trade](examples/translation_moe/README.md)

Fairseq features:
- multi-GPU (distributed) training on one machine or across multiple machines
@@ -74,6 +75,7 @@ as well as example training and evaluation commands.
- [Language Modeling](examples/language_model/README.md): convolutional models are available

We also have more detailed READMEs to reproduce results from specific papers:
- [Shen et al. (2019): Mixture Models for Diverse Machine Translation: Tricks of the Trade](examples/translation_moe/README.md)
- [Wu et al. (2019): Pay Less Attention with Lightweight and Dynamic Convolutions](examples/pay_less_attention_paper/README.md)
- [Edunov et al. (2018): Classical Structured Prediction Losses for Sequence to Sequence Learning](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
- [Fan et al. (2018): Hierarchical Neural Story Generation](examples/stories/README.md)
8 changes: 4 additions & 4 deletions examples/translation/README.md
@@ -112,22 +112,22 @@ $ bash prepare-wmt14en2de.sh
$ cd ../..
# Binarize the dataset:
$ TEXT=examples/translation/wmt14_en_de
$ TEXT=examples/translation/wmt17_en_de
$ fairseq-preprocess --source-lang en --target-lang de \
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
--destdir data-bin/wmt14_en_de --thresholdtgt 0 --thresholdsrc 0
--destdir data-bin/wmt17_en_de --thresholdtgt 0 --thresholdsrc 0
# Train the model:
# If it runs out of memory, try to set --max-tokens 1500 instead
$ mkdir -p checkpoints/fconv_wmt_en_de
$ fairseq-train data-bin/wmt14_en_de \
$ fairseq-train data-bin/wmt17_en_de \
--lr 0.5 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--lr-scheduler fixed --force-anneal 50 \
--arch fconv_wmt_en_de --save-dir checkpoints/fconv_wmt_en_de
# Generate:
$ fairseq-generate data-bin/wmt14_en_de \
$ fairseq-generate data-bin/wmt17_en_de \
--path checkpoints/fconv_wmt_en_de/checkpoint_best.pt --beam 5 --remove-bpe
```
5 changes: 4 additions & 1 deletion examples/translation/prepare-wmt14en2de.sh
@@ -41,6 +41,9 @@ if [ "$1" == "--icml17" ]; then
URLS[2]="http://statmt.org/wmt14/training-parallel-nc-v9.tgz"
FILES[2]="training-parallel-nc-v9.tgz"
CORPORA[2]="training/news-commentary-v9.de-en"
OUTDIR=wmt14_en_de
else
OUTDIR=wmt17_en_de
fi

if [ ! -d "$SCRIPTS" ]; then
@@ -51,7 +54,7 @@ fi
src=en
tgt=de
lang=en-de
prep=wmt14_en_de
prep=$OUTDIR
tmp=$prep/tmp
orig=orig
dev=dev/newstest2013
87 changes: 87 additions & 0 deletions examples/translation_moe/README.md
@@ -0,0 +1,87 @@
# Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)

This page includes instructions for reproducing results from the paper [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](https://arxiv.org/abs/1902.07816).

## Training a new model on WMT'17 En-De

First, follow the [instructions to download and preprocess the WMT'17 En-De dataset](../translation#prepare-wmt14en2desh).
Make sure to learn a joint vocabulary by passing the `--joined-dictionary` option to `fairseq-preprocess`.

Then we can train a mixture of experts model using the `translation_moe` task.
Use the `--method` option to choose the MoE variant; we support hard mixtures with a learned or uniform prior (`--method hMoElp` and `hMoEup`, respectively) and soft mixtures (`--method sMoElp` and `sMoEup`).
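
As a rough summary (ours, not wording from this README or the paper): every `--method` choice models translation with a latent expert $z \in \{1, \dots, K\}$,

```
p(y \mid x) = \sum_{z=1}^{K} p(z \mid x) \, p(y \mid z, x)
```

The `lp` variants learn the prior $p(z \mid x)$ with a gating network (the mean-pool gating network added in this commit), while the `up` variants fix it to $1/K$. The hard variants assign each training sentence to a single best expert (hard EM); the soft variants keep the full sum over experts. See the paper for the exact training procedures.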

To train a hard mixture of experts model with a learned prior (`hMoElp`) on 1 GPU:
```
$ CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/wmt17_en_de \
--max-update 100000 \
--task translation_moe \
--method hMoElp --mean-pool-gating-network \
--num-experts 3 \
--arch transformer_vaswani_wmt_en_de --share-all-embeddings \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
--lr 0.0007 --min-lr 1e-09 \
--dropout 0.1 --weight-decay 0.0 --criterion cross_entropy \
--max-tokens 3584 \
--update-freq 8
```

**Note**: the above command assumes a single GPU and accumulates gradients over 8 forward/backward passes (`--update-freq 8`) to simulate training on 8 GPUs.
You can train faster on up to 8 GPUs by adjusting `CUDA_VISIBLE_DEVICES` and decreasing `--update-freq` accordingly (e.g. `--update-freq 1` when using 8 GPUs).

Once a model is trained, we can generate translations from different experts using the `--gen-expert` option.
For example, to generate from expert 0:
```
$ fairseq-generate data-bin/wmt17_en_de \
--path checkpoints/checkpoint_best.pt \
--beam 1 --remove-bpe \
--task translation_moe \
--method hMoElp --mean-pool-gating-network \
--num-experts 3 \
--gen-expert 0
```

You can also use `scripts/score_moe.py` to compute pairwise BLEU and average oracle BLEU.
We'll first download a tokenized version of the multi-reference WMT'14 En-De dataset:
```
$ wget dl.fbaipublicfiles.com/fairseq/data/wmt14-en-de.extra_refs.tok
```

Next, apply BPE on the fly and run generation for each expert:
```
$ BPEROOT=examples/translation/subword-nmt/
$ BPE_CODE=examples/translation/wmt17_en_de/code
$ for EXPERT in $(seq 0 2); do \
cat wmt14-en-de.extra_refs.tok | grep ^S | cut -f 2 | \
python $BPEROOT/apply_bpe.py -c $BPE_CODE | \
fairseq-interactive data-bin/wmt17_en_de \
--path checkpoints/checkpoint_best.pt \
--beam 1 --remove-bpe \
--buffer 500 --max-tokens 6000 \
--task translation_moe \
--method hMoElp --mean-pool-gating-network \
--num-experts 3 \
--gen-expert $EXPERT \
done > wmt14-en-de.extra_refs.tok.gen.3experts
```

Finally, compute pairwise BLEU and average oracle BLEU:
```
$ python scripts/score_moe.py --sys wmt14-en-de.extra_refs.tok.gen.3experts --ref wmt14-en-de.extra_refs.tok
pairwise BLEU: 48.26
avg oracle BLEU: 49.50
#refs covered: 2.11
```

This reproduces row 3 from Table 7 in the paper.
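
As a rough sketch of what "pairwise BLEU" measures (this is **not** `scripts/score_moe.py`; it assumes `sacrebleu` is installed and that each expert's detokenized hypotheses are already loaded as parallel lists of strings):

```python
import itertools

import sacrebleu  # assumed dependency; the real script may compute BLEU differently


def pairwise_bleu(hyps_per_expert):
    """Average corpus BLEU of each expert's hypotheses scored against every
    other expert's hypotheses (lower = more diverse experts).

    `hyps_per_expert` is a list of K lists of sentence strings, all in the
    same source-sentence order.
    """
    scores = []
    for i, j in itertools.permutations(range(len(hyps_per_expert)), 2):
        scores.append(sacrebleu.corpus_bleu(hyps_per_expert[i], [hyps_per_expert[j]]).score)
    return sum(scores) / len(scores)
```

Oracle BLEU is computed in a similar spirit, except that for each source sentence only the hypothesis closest to its references is kept before scoring; consult the script and the paper for the exact definitions.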

## Citation

```bibtex
@article{shen2019mixture,
title = {Mixture Models for Diverse Machine Translation: Tricks of the Trade},
author = {Tianxiao Shen and Myle Ott and Michael Auli and Marc'Aurelio Ranzato},
journal = {arXiv preprint arXiv:1902.07816},
year = 2019,
}
```
14 changes: 9 additions & 5 deletions fairseq/criterions/cross_entropy.py
@@ -28,11 +28,7 @@ def forward(self, model, sample, reduce=True):
3) logging outputs to display while training
"""
net_output = model(**sample['net_input'])
lprobs = model.get_normalized_probs(net_output, log_probs=True)
lprobs = lprobs.view(-1, lprobs.size(-1))
target = model.get_targets(sample, net_output).view(-1)
loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx,
reduce=reduce)
loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce)
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
@@ -42,6 +38,14 @@
}
return loss, sample_size, logging_output

def compute_loss(self, model, net_output, sample, reduce=True):
lprobs = model.get_normalized_probs(net_output, log_probs=True)
lprobs = lprobs.view(-1, lprobs.size(-1))
target = model.get_targets(sample, net_output).view(-1)
loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx,
reduce=reduce)
return loss, loss

@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
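
The point of this refactor is to expose `compute_loss()` as an overridable hook: `forward()` keeps the batching, sample-size, and logging logic, while subclasses only swap out the loss. A minimal, hypothetical subclass (not part of this commit) might look like:

```python
from fairseq.criterions import register_criterion
from fairseq.criterions.cross_entropy import CrossEntropyCriterion


@register_criterion('scaled_cross_entropy')  # hypothetical name, illustration only
class ScaledCrossEntropyCriterion(CrossEntropyCriterion):
    """Reuses CrossEntropyCriterion.forward(); only the loss is changed."""

    def compute_loss(self, model, net_output, sample, reduce=True):
        # reuse the parent's NLL computation, then rescale by an arbitrary
        # factor, purely to show that forward() picks up whatever this hook returns
        loss, nll_loss = super().compute_loss(model, net_output, sample, reduce=reduce)
        return 0.5 * loss, nll_loss
```

This refactor is presumably what lets the new `translation_moe` task compute per-expert losses without duplicating the boilerplate in `forward()`.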
4 changes: 4 additions & 0 deletions fairseq/modules/__init__.py
@@ -17,6 +17,8 @@
from .learned_positional_embedding import LearnedPositionalEmbedding
from .lightweight_convolution import LightweightConv1dTBC
from .linearized_convolution import LinearizedConvolution
from .logsumexp_moe import LogSumExpMoE
from .mean_pool_gating_network import MeanPoolGatingNetwork
from .multihead_attention import MultiheadAttention
from .scalar_bias import ScalarBias
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
@@ -35,6 +37,8 @@
'LearnedPositionalEmbedding',
'LightweightConv1dTBC',
'LinearizedConvolution',
'LogSumExpMoE',
'MeanPoolGatingNetwork',
'MultiheadAttention',
'ScalarBias',
'SinusoidalPositionalEmbedding',
28 changes: 28 additions & 0 deletions fairseq/modules/logsumexp_moe.py
@@ -0,0 +1,28 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch


class LogSumExpMoE(torch.autograd.Function):
"""Standard LogSumExp forward pass, but use *posterior* for the backward.
See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
(Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_.
"""

@staticmethod
def forward(ctx, logp, posterior, dim=-1):
ctx.save_for_backward(posterior)
ctx.dim = dim
return torch.logsumexp(logp, dim=dim)

@staticmethod
def backward(ctx, grad_output):
posterior, = ctx.saved_tensors
grad_logp = grad_output.unsqueeze(ctx.dim) * posterior
return grad_logp, None, None
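
A minimal illustration of the custom backward (hypothetical shapes and values, not code from the `translation_moe` task): the forward pass is an ordinary `logsumexp` over the expert dimension, but the gradient is distributed according to a caller-supplied `posterior`, e.g. a one-hot assignment as in hard EM:

```python
import torch

from fairseq.modules import LogSumExpMoE

bsz, num_experts = 4, 3
# per-expert log-likelihoods log p(y, z | x) (random stand-ins)
logp = torch.randn(bsz, num_experts, requires_grad=True)

# a hard, one-hot posterior over experts, as in hard EM
posterior = torch.zeros(bsz, num_experts)
posterior.scatter_(1, logp.argmax(dim=1, keepdim=True), 1.0)

# forward: plain logsumexp over the expert dimension (dim=1)
loss = -LogSumExpMoE.apply(logp, posterior, 1).sum()
loss.backward()

# backward: gradient w.r.t. logp is -posterior, i.e. only the assigned
# expert receives gradient (a plain torch.logsumexp would instead
# back-propagate the softmax over experts)
print(torch.allclose(logp.grad, -posterior))  # True
```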
53 changes: 53 additions & 0 deletions fairseq/modules/mean_pool_gating_network.py
@@ -0,0 +1,53 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch
import torch.nn.functional as F


class MeanPoolGatingNetwork(torch.nn.Module):
"""A simple mean-pooling gating network for selecting experts.
This module applies mean pooling over an encoder's output and returns
responsibilities for each expert. The encoder format is expected to match
:class:`fairseq.models.transformer.TransformerEncoder`.
"""

def __init__(self, embed_dim, num_experts, dropout=None):
super().__init__()
self.embed_dim = embed_dim
self.num_experts = num_experts

self.fc1 = torch.nn.Linear(embed_dim, embed_dim)
self.dropout = torch.nn.Dropout(dropout) if dropout is not None else None
self.fc2 = torch.nn.Linear(embed_dim, num_experts)

def forward(self, encoder_out):
if not (
isinstance(encoder_out, dict)
and 'encoder_out' in encoder_out
and 'encoder_padding_mask' in encoder_out
and encoder_out['encoder_out'].size(2) == self.embed_dim
):
raise ValueError('Unexpected format for encoder_out')

# mean pooling over time
encoder_padding_mask = encoder_out['encoder_padding_mask'] # B x T
encoder_out = encoder_out['encoder_out'].transpose(0, 1) # B x T x C
if encoder_padding_mask is not None:
encoder_out = encoder_out.clone() # required because of transpose above
encoder_out[encoder_padding_mask] = 0
ntokens = torch.sum(1 - encoder_padding_mask, dim=1, keepdim=True)
x = torch.sum(encoder_out, dim=1) / ntokens.type_as(encoder_out)
else:
x = torch.mean(encoder_out, dim=1)

x = torch.tanh(self.fc1(x))
if self.dropout is not None:
x = self.dropout(x)
x = self.fc2(x)
return F.log_softmax(x, dim=-1, dtype=torch.float32).type_as(x)
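
A quick, hypothetical smoke test of the expected input format (`T x B x C` encoder states plus a `B x T` padding mask or `None`, matching `TransformerEncoder`); the module returns log responsibilities over experts:

```python
import torch

from fairseq.modules import MeanPoolGatingNetwork

embed_dim, num_experts = 512, 3
seq_len, bsz = 7, 2

gating = MeanPoolGatingNetwork(embed_dim, num_experts, dropout=0.1)

encoder_out = {
    'encoder_out': torch.randn(seq_len, bsz, embed_dim),  # T x B x C
    'encoder_padding_mask': None,  # or a B x T mask marking padded positions
}

log_prior = gating(encoder_out)      # B x num_experts log-probabilities
print(log_prior.shape)               # torch.Size([2, 3])
print(log_prior.exp().sum(dim=-1))   # each row sums to ~1
```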
11 changes: 9 additions & 2 deletions fairseq/sequence_generator.py
@@ -98,7 +98,14 @@ def __init__(
self.search = search.BeamSearch(tgt_dict)

@torch.no_grad()
def generate(self, models, sample=None, net_input=None, prefix_tokens=None, **kwargs):
def generate(
self,
models,
sample,
prefix_tokens=None,
bos_token=None,
**kwargs
):
"""Generate a batch of translations.
Args:
@@ -143,7 +150,7 @@ def generate(self, models, sample=None, net_input=None, prefix_tokens=None, **kwargs):
scores_buf = scores.clone()
tokens = src_tokens.new(bsz * beam_size, max_len + 2).fill_(self.pad)
tokens_buf = tokens.clone()
tokens[:, 0] = self.eos
tokens[:, 0] = bos_token or self.eos
attn, attn_buf = None, None
nonpad_idxs = None

