Add code for mixture of experts (#521)
Summary: Code for the paper: [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](https://arxiv.org/abs/1902.07816).

Pull Request resolved: #521
Differential Revision: D14188021
Pulled By: myleott
fbshipit-source-id: ed5b1ed5ad9a582359bd5215fa2ea26dc76c673e
1 parent b65c579 · commit 4294c4f · Showing 10 changed files with 435 additions and 18 deletions.
examples/translation_moe/README.md
@@ -0,0 +1,87 @@
# Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)

This page includes instructions for reproducing results from the paper [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](https://arxiv.org/abs/1902.07816).

## Training a new model on WMT'17 En-De

First, follow the [instructions to download and preprocess the WMT'17 En-De dataset](../translation#prepare-wmt14en2desh).
Make sure to learn a joint vocabulary by passing the `--joined-dictionary` option to `fairseq-preprocess`.
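
For reference, the preprocessing call might look like the following (a sketch; the paths assume the layout produced by the linked preparation script):
```
$ TEXT=examples/translation/wmt17_en_de
$ fairseq-preprocess --source-lang en --target-lang de \
    --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
    --joined-dictionary \
    --destdir data-bin/wmt17_en_de
```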

Then we can train a mixture of experts model using the `translation_moe` task.
Use the `--method` option to choose the MoE variant; we support hard mixtures with a learned or uniform prior (`--method hMoElp` and `hMoEup`, respectively) and soft mixtures (`--method sMoElp` and `sMoEup`).

To train a hard mixture of experts model with a learned prior (`hMoElp`) on 1 GPU:
```
$ CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/wmt17_en_de \
  --max-update 100000 \
  --task translation_moe \
  --method hMoElp --mean-pool-gating-network \
  --num-experts 3 \
  --arch transformer_vaswani_wmt_en_de --share-all-embeddings \
  --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
  --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
  --lr 0.0007 --min-lr 1e-09 \
  --dropout 0.1 --weight-decay 0.0 --criterion cross_entropy \
  --max-tokens 3584 \
  --update-freq 8
```

**Note**: the above command assumes 1 GPU, but accumulates gradients from 8 fwd/bwd passes to simulate training on 8 GPUs.
You can accelerate training on up to 8 GPUs by adjusting the `CUDA_VISIBLE_DEVICES` and `--update-freq` options accordingly.
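
For example, the same run spread over 8 GPUs with gradient accumulation disabled keeps the effective batch size per update (8 × 3584 max source tokens) unchanged (a sketch of the adjustment described above, not a command from the commit):
```
$ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 fairseq-train data-bin/wmt17_en_de \
  --max-update 100000 \
  --task translation_moe \
  --method hMoElp --mean-pool-gating-network \
  --num-experts 3 \
  --arch transformer_vaswani_wmt_en_de --share-all-embeddings \
  --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
  --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
  --lr 0.0007 --min-lr 1e-09 \
  --dropout 0.1 --weight-decay 0.0 --criterion cross_entropy \
  --max-tokens 3584 \
  --update-freq 1
```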

Once a model is trained, we can generate translations from different experts using the `--gen-expert` option.
For example, to generate from expert 0:
```
$ fairseq-generate data-bin/wmt17_en_de \
  --path checkpoints/checkpoint_best.pt \
  --beam 1 --remove-bpe \
  --task translation_moe \
  --method hMoElp --mean-pool-gating-network \
  --num-experts 3 \
  --gen-expert 0
```

You can also use `scripts/score_moe.py` to compute pairwise BLEU and average oracle BLEU.
We'll first download a tokenized version of the multi-reference WMT'14 En-De dataset:
```
$ wget dl.fbaipublicfiles.com/fairseq/data/wmt14-en-de.extra_refs.tok
```

Next, apply BPE on the fly and run generation for each expert:
```
$ BPEROOT=examples/translation/subword-nmt/
$ BPE_CODE=examples/translation/wmt17_en_de/code
$ for EXPERT in $(seq 0 2); do \
    cat wmt14-en-de.extra_refs.tok | grep ^S | cut -f 2 | \
    python $BPEROOT/apply_bpe.py -c $BPE_CODE | \
    fairseq-interactive data-bin/wmt17_en_de \
      --path checkpoints/checkpoint_best.pt \
      --beam 1 --remove-bpe \
      --buffer-size 500 --max-tokens 6000 \
      --task translation_moe \
      --method hMoElp --mean-pool-gating-network \
      --num-experts 3 \
      --gen-expert $EXPERT ; \
  done > wmt14-en-de.extra_refs.tok.gen.3experts
```

Finally, compute pairwise BLEU and average oracle BLEU:
```
$ python scripts/score_moe.py --sys wmt14-en-de.extra_refs.tok.gen.3experts --ref wmt14-en-de.extra_refs.tok
pairwise BLEU: 48.26
avg oracle BLEU: 49.50
#refs covered: 2.11
```

This reproduces row 3 from Table 7 in the paper.
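
For intuition, pairwise BLEU is corpus-level BLEU averaged over ordered pairs of distinct experts, so lower values indicate more diverse hypotheses. Here is a rough sketch of that one metric using `sacrebleu` (`pairwise_bleu` is a hypothetical helper; the exact definitions of all three reported numbers live in `scripts/score_moe.py`):
```python
import itertools
import sacrebleu

def pairwise_bleu(expert_hyps):
    """expert_hyps: K lists of detokenized sentences, aligned by source line."""
    scores = [
        # score one expert's output against another's, as if it were a reference
        sacrebleu.corpus_bleu(sys, [ref]).score
        for sys, ref in itertools.permutations(expert_hyps, 2)
    ]
    return sum(scores) / len(scores)
```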

## Citation

```bibtex
@article{shen2019mixture,
  title = {Mixture Models for Diverse Machine Translation: Tricks of the Trade},
  author = {Tianxiao Shen and Myle Ott and Michael Auli and Marc'Aurelio Ranzato},
  journal = {arXiv preprint arXiv:1902.07816},
  year = 2019,
}
```
fairseq/modules/logsumexp_moe.py
@@ -0,0 +1,28 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch


class LogSumExpMoE(torch.autograd.Function):
    """Standard LogSumExp forward pass, but use *posterior* for the backward.

    See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
    (Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_.
    """

    @staticmethod
    def forward(ctx, logp, posterior, dim=-1):
        ctx.save_for_backward(posterior)
        ctx.dim = dim
        return torch.logsumexp(logp, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        posterior, = ctx.saved_tensors
        # route the upstream gradient through the fixed posterior
        # (responsibilities) rather than the true softmax gradient of logsumexp
        grad_logp = grad_output.unsqueeze(ctx.dim) * posterior
        # no gradients w.r.t. *posterior* or *dim*
        return grad_logp, None, None
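
A minimal sketch of how this function behaves (the shapes and the posterior computation below are illustrative assumptions, not taken from the commit):
```python
import torch

# log-likelihood of each target sequence under K experts: batch x K
logp = torch.randn(8, 3, requires_grad=True)
# expert responsibilities p(z | x, y), treated as constant in the backward pass
posterior = torch.softmax(logp.detach(), dim=-1)

loss = -LogSumExpMoE.apply(logp, posterior, -1).sum()
loss.backward()
# the gradient routes through the fixed posterior, not the softmax of logp
assert torch.allclose(logp.grad, -posterior)
```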
fairseq/modules/mean_pool_gating_network.py
@@ -0,0 +1,53 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch
import torch.nn.functional as F


class MeanPoolGatingNetwork(torch.nn.Module):
    """A simple mean-pooling gating network for selecting experts.

    This module applies mean pooling over an encoder's output and returns
    responsibilities for each expert. The encoder format is expected to match
    :class:`fairseq.models.transformer.TransformerEncoder`.
    """

    def __init__(self, embed_dim, num_experts, dropout=None):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_experts = num_experts

        self.fc1 = torch.nn.Linear(embed_dim, embed_dim)
        self.dropout = torch.nn.Dropout(dropout) if dropout is not None else None
        self.fc2 = torch.nn.Linear(embed_dim, num_experts)

    def forward(self, encoder_out):
        if not (
            isinstance(encoder_out, dict)
            and 'encoder_out' in encoder_out
            and 'encoder_padding_mask' in encoder_out
            and encoder_out['encoder_out'].size(2) == self.embed_dim
        ):
            raise ValueError('Unexpected format for encoder_out')

        # mean pooling over time
        encoder_padding_mask = encoder_out['encoder_padding_mask']  # B x T
        encoder_out = encoder_out['encoder_out'].transpose(0, 1)    # B x T x C
        if encoder_padding_mask is not None:
            # zero out padded positions, then average over the real tokens only
            encoder_out = encoder_out.clone()  # required because of transpose above
            encoder_out[encoder_padding_mask] = 0
            ntokens = torch.sum(1 - encoder_padding_mask, dim=1, keepdim=True)
            x = torch.sum(encoder_out, dim=1) / ntokens.type_as(encoder_out)
        else:
            x = torch.mean(encoder_out, dim=1)

        x = torch.tanh(self.fc1(x))
        if self.dropout is not None:
            x = self.dropout(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=-1, dtype=torch.float32).type_as(x)
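
A quick smoke test with a fake encoder output (the T x B x C layout follows `TransformerEncoder`; all values here are illustrative):
```python
import torch

T, B, C, K = 7, 2, 16, 3  # time steps, batch size, embedding dim, experts
gating = MeanPoolGatingNetwork(embed_dim=C, num_experts=K, dropout=0.1)
encoder_out = {
    'encoder_out': torch.randn(T, B, C),
    'encoder_padding_mask': None,  # no padding in this toy batch
}
log_prior = gating(encoder_out)      # B x K log-responsibilities over experts
print(log_prior.exp().sum(dim=-1))   # each row sums to ~1
```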