-
Notifications
You must be signed in to change notification settings - Fork 6.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implemented applying dropout at inference time (#2308)
Summary: Pull Request resolved: #2308 Implemented Monte Carlo dropout. Added README to reproduce the results from our paper that applies this idea for unsupervised quality estimation of NMT (joint work of Facebook AI and the University of Sheffield): Marina Fomicheva, Shuo Sun, Lisa Yankovskaya, Frédéric Blain, Francisco Guzmán, Mark Fishel, Nikolaos Aletras, Vishrav Chaudhary, Lucia Specia. Unsupervised Quality Estimation for Neural Machine Translation. Accepted to TACL Retaining dropout at test time is not possible in the current code base. The statement ``` if not self.retain_dropout: model.eval() ``` in `SequenceGenerator` does not have any effect, since model `training` attribute is already set to False by the method `make_generate_fast_`, which is applied before initializing `SequenceGenerator` in `generate.py`. `make_generate_fast_` throws an exception when trying to set `training` to True after its application. Also, if I am not mistaken `self.training=True` can have other effects, so setting it to True only for the purpose of retaining dropout at test time might be confusing. I propose an alternative implementation where `retain_dropout` is an attribute of FairseqModel class. # Before submitting - [N] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [Y] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [Y] Did you make sure to update the docs? - [Y] Did you write any new necessary tests? ## What does this PR do? New feature. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding � Pull Request resolved: #2151 Reviewed By: ngoyal2707 Differential Revision: D22048889 Pulled By: myleott fbshipit-source-id: 0d0d4784a7314fc7a45b76341fd3b8232b3e2cf0
- Loading branch information
1 parent
625e501
commit 3b7cf75
Showing
42 changed files
with
698 additions
and
216 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
# Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020) | ||
|
||
This page includes instructions for reproducing results from the paper [Unsupervised Quality Estimation for Neural | ||
Machine Translation (Fomicheva et al., 2020)](https://arxiv.org/abs/2005.10608) | ||
|
||
## Requirements: | ||
|
||
* mosesdecoder: https://github.com/moses-smt/mosesdecoder | ||
* subword-nmt: https://github.com/rsennrich/subword-nmt | ||
* flores: https://github.com/facebookresearch/flores | ||
|
||
## Download Models and Test Data | ||
|
||
Download translation models and test data from [MLQE dataset repository](https://github.com/facebookresearch/mlqe). | ||
|
||
## Set up: | ||
|
||
Given a testset consisting of source sentences and reference translations: | ||
|
||
* `SRC_LANG`: source language | ||
* `TGT_LANG`: target language | ||
* `INPUT`: input prefix, such that the file `$INPUT.$SRC_LANG` contains source sentences and `$INPUT.$TGT_LANG` | ||
contains the reference sentences | ||
* `OUTPUT_DIR`: output path to store results | ||
* `MOSES_DECODER`: path to mosesdecoder installation | ||
* `BPE_ROOT`: path to subword-nmt installation | ||
* `BPE`: path to BPE model | ||
* `MODEL_DIR`: directory containing the NMT model `.pt` file as well as the source and target vocabularies. | ||
* `TMP`: directory for intermediate temporary files | ||
* `GPU`: if translating with GPU, id of the GPU to use for inference | ||
* `DROPOUT_N`: number of stochastic forward passes | ||
|
||
`$DROPOUT_N` is set to 30 in the experiments reported in the paper. However, we observed that increasing it beyond 10 | ||
does not bring substantial improvements. | ||
|
||
## Translate the data using standard decoding | ||
|
||
Preprocess the input data: | ||
``` | ||
for LANG in $SRC_LANG $TGT_LANG; do | ||
perl $MOSES_DECODER/scripts/tokenizer/tokenizer.perl -threads 80 -a -l $LANG < $INPUT.$LANG > $TMP/preprocessed.tok.$LANG | ||
python $BPE_ROOT/apply_bpe.py -c ${BPE} < $TMP/preprocessed.tok.$LANG > $TMP/preprocessed.tok.bpe.$LANG | ||
done | ||
``` | ||
|
||
Binarize the data for faster translation: | ||
|
||
``` | ||
fairseq-preprocess --srcdict $MODEL_DIR/dict.$SRC_LANG.txt --tgtdict $MODEL_DIR/dict.$TGT_LANG.txt | ||
--source-lang ${SRC_LANG} --target-lang ${TGT_LANG} --testpref $TMP/preprocessed.tok.bpe --destdir $TMP/bin --workers 4 | ||
``` | ||
|
||
Translate | ||
|
||
``` | ||
CUDA_VISIBLE_DEVICES=$GPU fairseq-generate $TMP/bin --path ${MODEL_DIR}/${SRC_LANG}-${TGT_LANG}.pt --beam 5 | ||
--source-lang $SRC_LANG --target-lang $TGT_LANG --no-progress-bar --unkpen 5 > $TMP/fairseq.out | ||
grep ^H $TMP/fairseq.out | cut -f3- > $TMP/mt.out | ||
``` | ||
|
||
Post-process | ||
|
||
``` | ||
sed -r 's/(@@ )| (@@ ?$)//g' < $TMP/mt.out | perl $MOSES_DECODER/scripts/tokenizer/detokenizer.perl | ||
-l $TGT_LANG > $OUTPUT_DIR/mt.out | ||
``` | ||
|
||
## Produce uncertainty estimates | ||
|
||
### Scoring | ||
|
||
Make temporary files to store the translations repeated N times. | ||
|
||
``` | ||
python ${SCRIPTS}/scripts/uncertainty/repeat_lines.py -i $TMP/preprocessed.tok.bpe.$SRC_LANG -n $DROPOUT_N | ||
-o $TMP/repeated.$SRC_LANG | ||
python ${SCRIPTS}/scripts/uncertainty/repeat_lines.py -i $TMP/mt.out -n $DROPOUT_N -o $TMP/repeated.$TGT_LANG | ||
fairseq-preprocess --srcdict ${MODEL_DIR}/dict.${SRC_LANG}.txt $TGT_DIC --source-lang ${SRC_LANG} | ||
--target-lang ${TGT_LANG} --testpref ${TMP}/repeated --destdir ${TMP}/bin-repeated | ||
``` | ||
|
||
Produce model scores for the generated translations using `--retain-dropout` option to apply dropout at inference time: | ||
|
||
``` | ||
CUDA_VISIBLE_DEVICES=${GPU} fairseq-generate ${TMP}/bin-repeated --path ${MODEL_DIR}/${LP}.pt --beam 5 | ||
--source-lang $SRC_LANG --target-lang $TGT_LANG --no-progress-bar --unkpen 5 --score-reference --retain-dropout | ||
--retain-dropout-modules TransformerModel TransformerEncoder TransformerDecoder TransformerEncoderLayer | ||
TransformerDecoderLayer --seed 46 > $TMP/dropout.scoring.out | ||
grep ^H $TMP/dropout.scoring.out | cut -f2- > $TMP/dropout.scores | ||
``` | ||
|
||
Use `--retain-dropout-modules` to specify the modules. By default, dropout is applied in the same places | ||
as for training. | ||
|
||
Compute the mean of the resulting output distribution: | ||
|
||
``` | ||
python $SCRIPTS/scripts/uncertainty/aggregate_scores.py -i $TMP/dropout.scores -o $OUTPUT_DIR/dropout.scores.mean | ||
-n $DROPOUT_N | ||
``` | ||
|
||
### Generation | ||
|
||
Produce multiple translation hypotheses for the same source using `--retain-dropout` option: | ||
|
||
``` | ||
CUDA_VISIBLE_DEVICES=${GPU} fairseq-generate ${TMP}/bin-repeated --path ${MODEL_DIR}/${LP}.pt | ||
--beam 5 --source-lang $SRC_LANG --target-lang $TGT_LANG --no-progress-bar --retain-dropout | ||
--unkpen 5 --retain-dropout-modules TransformerModel TransformerEncoder TransformerDecoder | ||
TransformerEncoderLayer TransformerDecoderLayer --seed 46 > $TMP/dropout.generation.out | ||
grep ^H $TMP/dropout.generation.out | cut -f3- > $TMP/dropout.hypotheses_ | ||
sed -r 's/(@@ )| (@@ ?$)//g' < $TMP/dropout.hypotheses_ | perl $MOSES_DECODER/scripts/tokenizer/detokenizer.perl | ||
-l $TGT_LANG > $TMP/dropout.hypotheses | ||
``` | ||
|
||
Compute similarity between multiple hypotheses corresponding to the same source sentence using Meteor | ||
evaluation metric: | ||
``` | ||
python meteor.py -i $TMP/dropout.hypotheses -m <path_to_meteor_installation> -n $DROPOUT_N -o | ||
$OUTPUT_DIR/dropout.gen.sim.meteor | ||
``` |
40 changes: 40 additions & 0 deletions
40
examples/unsupervised_quality_estimation/aggregate_scores.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import argparse | ||
import numpy as np | ||
import sys | ||
|
||
|
||
aggregate_funcs = { | ||
'std': np.std, | ||
'var': np.var, | ||
'median': np.median, | ||
'mean': np.mean, | ||
'min': np.min, | ||
'max': np.max, | ||
} | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('-i', '--input_file', required=True, type=str) | ||
parser.add_argument('-n', '--repeat_times', required=True, type=int) | ||
parser.add_argument('-o', '--output_file', required=False) | ||
parser.add_argument('-f', '--func', required=False, default='mean') | ||
args = parser.parse_args() | ||
|
||
stream = open(args.output_file, 'w') if args.output_file else sys.stdout | ||
|
||
segment_scores = [] | ||
for line in open(args.input_file): | ||
segment_scores.append(float(line.strip())) | ||
if len(segment_scores) == args.repeat_times: | ||
stream.write('{}\n'.format(aggregate_funcs[args.func](segment_scores))) | ||
segment_scores = [] | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import argparse | ||
import os | ||
import sys | ||
import subprocess | ||
import tempfile | ||
import math | ||
|
||
from itertools import combinations | ||
from collections import defaultdict | ||
|
||
|
||
def read_translations(path, n_repeats): | ||
segment_counter = 0 | ||
segment_translations = [] | ||
translations = defaultdict(list) | ||
for line in open(path): | ||
segment_translations.append(' '.join(line.split())) | ||
if len(segment_translations) == n_repeats: | ||
translations[segment_counter] = segment_translations | ||
segment_translations = [] | ||
segment_counter += 1 | ||
return translations | ||
|
||
|
||
def generate_input(translations, n_repeats): | ||
_, ref_path = tempfile.mkstemp() | ||
_, mt_path = tempfile.mkstemp() | ||
ref_fh = open(ref_path, 'w') | ||
mt_fh = open(mt_path, 'w') | ||
for segid in sorted(translations.keys()): | ||
assert len(translations[segid]) == n_repeats | ||
indexes = combinations(range(n_repeats), 2) | ||
for idx1, idx2 in indexes: | ||
mt_fh.write(translations[segid][idx1].strip() + '\n') | ||
ref_fh.write(translations[segid][idx2].strip() + '\n') | ||
sys.stderr.write('\nSaved translations to %s and %s' % (ref_path, mt_path)) | ||
return ref_path, mt_path | ||
|
||
|
||
def run_meteor(ref_path, mt_path, metric_path, lang='en'): | ||
_, out_path = tempfile.mkstemp() | ||
subprocess.call([ | ||
'java', '-Xmx2G', '-jar', metric_path, mt_path, ref_path, | ||
'-p', '0.5 0.2 0.6 0.75', # default parameters, only changed alpha to give equal weight to P and R | ||
'-norm', | ||
'-l', lang], stdout=open(out_path, 'w')) | ||
os.remove(ref_path) | ||
os.remove(mt_path) | ||
sys.stderr.write('\nSaved Meteor output to %s' % out_path) | ||
return out_path | ||
|
||
|
||
def read_output(meteor_output_path, n_repeats): | ||
n_combinations = math.factorial(n_repeats)/(math.factorial(2) * math.factorial(n_repeats - 2)) | ||
raw_scores = [] | ||
average_scores = [] | ||
for line in open(meteor_output_path): | ||
if not line.startswith('Segment '): | ||
continue | ||
score = float(line.strip().split('\t')[1]) | ||
raw_scores.append(score) | ||
if len(raw_scores) == n_combinations: | ||
average_scores.append(sum(raw_scores)/n_combinations) | ||
raw_scores = [] | ||
os.remove(meteor_output_path) | ||
return average_scores | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('-i', '--input') | ||
parser.add_argument('-n', '--repeat_times', type=int) | ||
parser.add_argument('-m', '--meteor') | ||
parser.add_argument('-o', '--output') | ||
args = parser.parse_args() | ||
|
||
translations = read_translations(args.infile, args.repetitions) | ||
sys.stderr.write('\nGenerating input for Meteor...') | ||
ref_path, mt_path = generate_input(translations, args.repetitions) | ||
sys.stderr.write('\nRunning Meteor...') | ||
out_path = run_meteor(ref_path, mt_path, args.meteor) | ||
sys.stderr.write('\nReading output...') | ||
scores = read_output(out_path, args.repetitions) | ||
sys.stderr.write('\nWriting results...') | ||
with open(args.output, 'w') as o: | ||
for scr in scores: | ||
o.write('{}\n'.format(scr)) | ||
o.close() | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import argparse | ||
import sys | ||
|
||
|
||
def _normalize_spaces(line): | ||
return ' '.join(line.split()) | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('-i', '--input_file', required=True, type=str) | ||
parser.add_argument('-n', '--repeat_times', required=True, type=int) | ||
parser.add_argument('-o', '--output_file', required=False, type=str) | ||
args = parser.parse_args() | ||
stream = open(args.output_file, 'w') if args.output_file else sys.stdout | ||
|
||
for line in open(args.input_file): | ||
for _ in range(args.repeat_times): | ||
stream.write(_normalize_spaces(line) + '\n') | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.