Skip to content

Commit

Permalink
Implementation of the paper "Jointly Learning to Align and Translate …
Browse files Browse the repository at this point in the history
…with Transformer Models" (#877)

Summary:
Pull Request resolved: fairinternal/fairseq-py#877

This PR implements guided alignment training described in  "Jointly Learning to Align and Translate with Transformer Models (https://arxiv.org/abs/1909.02074)".

In summary, it allows for training selected heads of the Transformer Model with external alignments computed by Statistical Alignment Toolkits. During inference, attention probabilities from the trained heads can be used to extract reliable alignments. In our work, we did not see any regressions in the translation performance because of guided alignment training.
Pull Request resolved: #1095

Differential Revision: D17170337

Pulled By: myleott

fbshipit-source-id: daa418bef70324d7088dbb30aa2adf9f95774859
  • Loading branch information
sarthakgarg authored and facebook-github-bot committed Sep 30, 2019
1 parent acb6fba commit 1c66792
Show file tree
Hide file tree
Showing 20 changed files with 899 additions and 61 deletions.
2 changes: 2 additions & 0 deletions README.md
Expand Up @@ -33,6 +33,7 @@ Fairseq provides reference implementations of various sequence-to-sequence model
- [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
- [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
- [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
- [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md )
- **Non-autoregressive Transformers**
- Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
- Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018)
Expand Down Expand Up @@ -100,6 +101,7 @@ as well as example training and evaluation commands.
- [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available

We also have more detailed READMEs to reproduce results from specific papers:
- [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md )
- [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
- [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
- [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
Expand Down
89 changes: 89 additions & 0 deletions examples/joint_alignment_translation/README.md
@@ -0,0 +1,89 @@
# Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)

This page includes instructions for training models described in [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](https://arxiv.org/abs/1909.02074).

## Training a joint alignment-translation model on WMT'18 En-De

##### 1. Extract and preprocess the WMT'18 En-De data
```bash
./prepare-wmt18en2de_no_norm_no_escape_no_agressive.sh
```

##### 2. Generate alignments from statistical alignment toolkits e.g. Giza++/FastAlign.
In this example, we use FastAlign.
```bash
git clone git@github.com:clab/fast_align.git
pushd fast_align
mkdir build
cd build
cmake ..
make
popd
ALIGN=fast_align/build/fast_align
paste bpe.32k/train.en bpe.32k/train.de | awk -F '\t' '{print $1 " ||| " $2}' > bpe.32k/train.en-de
$ALIGN -i bpe.32k/train.en-de -d -o -v > bpe.32k/train.align
```

##### 3. Preprocess the dataset with the above generated alignments.
```bash
fairseq-preprocess \
--source-lang en --target-lang de \
--trainpref bpe.32k/train \
--validpref bpe.32k/valid \
--testpref bpe.32k/test \
--align-suffix align \
--destdir binarized/ \
--joined-dictionary \
--workers 32
```

##### 4. Train a model
```bash
fairseq-train \
binarized \
--arch transformer_wmt_en_de_big_align --share-all-embeddings \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --activation-fn relu\
--lr 0.0002 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
--dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
--max-tokens 3500 --label-smoothing 0.1 \
--save-dir ./checkpoints --log-interval 1000 --max-update 60000 \
--keep-interval-updates -1 --save-interval-updates 0 \
--load-alignments --criterion label_smoothed_cross_entropy_with_alignment \
--fp16
```

Note that the `--fp16` flag requires you have CUDA 9.1 or greater and a Volta GPU or newer.

If you want to train the above model with big batches (assuming your machine has 8 GPUs):
- add `--update-freq 8` to simulate training on 8x8=64 GPUs
- increase the learning rate; 0.0007 works well for big batches

##### 5. Evaluate and generate the alignments (BPE level)
```bash
fairseq-generate \
binarized --gen-subset test --print-alignment \
--source-lang en --target-lang de \
--path checkpoints/checkpoint_best.pt --beam 5 --nbest 1
```

##### 6. Other resources.
The code for:
1. preparing alignment test sets
2. converting BPE level alignments to token level alignments
3. symmetrizing bidirectional alignments
4. evaluating alignments using AER metric
can be found [here](https://github.com/lilt/alignment-scripts)

## Citation

```bibtex
@inproceedings{garg2019jointly,
title = {Jointly Learning to Align and Translate with Transformer Models},
author = {Garg, Sarthak and Peitz, Stephan and Nallasamy, Udhyakumar and Paulik, Matthias},
booktitle = {Conference on Empirical Methods in Natural Language Processing (EMNLP)},
address = {Hong Kong},
month = {November},
url = {https://arxiv.org/abs/1909.02074},
year = {2019},
}
```
@@ -0,0 +1,118 @@
#!/bin/bash

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

echo 'Cloning Moses github repository (for tokenization scripts)...'
git clone https://github.com/moses-smt/mosesdecoder.git

SCRIPTS=mosesdecoder/scripts
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
CLEAN=$SCRIPTS/training/clean-corpus-n.perl
REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl

URLS=(
"http://statmt.org/wmt13/training-parallel-europarl-v7.tgz"
"http://statmt.org/wmt13/training-parallel-commoncrawl.tgz"
"http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz"
"http://data.statmt.org/wmt18/translation-task/rapid2016.tgz"
"http://data.statmt.org/wmt17/translation-task/dev.tgz"
"http://statmt.org/wmt14/test-full.tgz"
)
CORPORA=(
"training/europarl-v7.de-en"
"commoncrawl.de-en"
"training-parallel-nc-v13/news-commentary-v13.de-en"
"rapid2016.de-en"
)

if [ ! -d "$SCRIPTS" ]; then
echo "Please set SCRIPTS variable correctly to point to Moses scripts."
exit
fi

src=en
tgt=de
lang=en-de
prep=wmt18_en_de
tmp=$prep/tmp
orig=orig
dev=dev/newstest2012
codes=32000
bpe=bpe.32k

mkdir -p $orig $tmp $prep $bpe

cd $orig

for ((i=0;i<${#URLS[@]};++i)); do
url=${URLS[i]}
file=$(basename $url)
if [ -f $file ]; then
echo "$file already exists, skipping download"
else
wget "$url"
if [ -f $file ]; then
echo "$url successfully downloaded."
else
echo "$url not successfully downloaded."
exit 1
fi
if [ ${file: -4} == ".tgz" ]; then
tar zxvf $file
elif [ ${file: -4} == ".tar" ]; then
tar xvf $file
fi
fi
done
cd ..

echo "pre-processing train data..."
for l in $src $tgt; do
rm -rf $tmp/train.tags.$lang.tok.$l
for f in "${CORPORA[@]}"; do
cat $orig/$f.$l | \
perl $REM_NON_PRINT_CHAR | \
perl $TOKENIZER -threads 8 -l $l -no-escape >> $tmp/train.tags.$lang.tok.$l
done
done

echo "pre-processing test data..."
for l in $src $tgt; do
if [ "$l" == "$src" ]; then
t="src"
else
t="ref"
fi
grep '<seg id' $orig/test-full/newstest2014-deen-$t.$l.sgm | \
sed -e 's/<seg id="[0-9]*">\s*//g' | \
sed -e 's/\s*<\/seg>\s*//g' | \
sed -e "s/\’/\'/g" | \
perl $TOKENIZER -threads 8 -l $l -no-escape > $tmp/test.$l
echo ""
done

# apply length filtering before BPE
perl $CLEAN -ratio 1.5 $tmp/train.tags.$lang.tok $src $tgt $tmp/train 1 100

# use newstest2012 for valid
echo "pre-processing valid data..."
for l in $src $tgt; do
rm -rf $tmp/valid.$l
cat $orig/$dev.$l | \
perl $REM_NON_PRINT_CHAR | \
perl $TOKENIZER -threads 8 -l $l -no-escape >> $tmp/valid.$l
done

mkdir output
mv $tmp/{train,valid,test}.{$src,$tgt} output

#BPE
git clone git@github.com:glample/fastBPE.git
pushd fastBPE
g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast
popd
fastBPE/fast learnbpe $codes output/train.$src output/train.$tgt > $bpe/codes
for split in {train,valid,test}; do for lang in {en,de}; do fastBPE/fast applybpe $bpe/$split.$lang output/$split.$lang $bpe/codes; done; done
16 changes: 16 additions & 0 deletions fairseq/binarizer.py
Expand Up @@ -52,6 +52,22 @@ def replaced_consumer(word, idx):
line = f.readline()
return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': replaced}

@staticmethod
def binarize_alignments(filename, alignment_parser, consumer, offset=0, end=-1):
nseq = 0

with open(filename, 'r') as f:
f.seek(offset)
line = safe_readline(f)
while line:
if end > 0 and f.tell() > end:
break
ids = alignment_parser(line)
nseq += 1
consumer(ids)
line = f.readline()
return {'nseq': nseq}

@staticmethod
def find_offsets(filename, num_chunks):
with open(filename, 'r', encoding='utf-8') as f:
Expand Down
90 changes: 90 additions & 0 deletions fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py
@@ -0,0 +1,90 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

from fairseq import utils

from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
from . import register_criterion


@register_criterion('label_smoothed_cross_entropy_with_alignment')
class LabelSmoothedCrossEntropyCriterionWithAlignment(LabelSmoothedCrossEntropyCriterion):

def __init__(self, args, task):
super().__init__(args, task)
self.alignment_lambda = args.alignment_lambda

@staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
super(LabelSmoothedCrossEntropyCriterionWithAlignment,
LabelSmoothedCrossEntropyCriterionWithAlignment).add_args(parser)
parser.add_argument('--alignment-lambda', default=0.05, type=float, metavar='D',
help='weight for the alignment loss')

def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample['net_input'])
loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
'ntokens': sample['ntokens'],
'nsentences': sample['target'].size(0),
'sample_size': sample_size,
}

alignment_loss = None

# Compute alignment loss only for training set and non dummy batches.
if 'alignments' in sample and sample['alignments'] is not None:
alignment_loss = self.compute_alignment_loss(sample, net_output)

if alignment_loss is not None:
logging_output['alignment_loss'] = utils.item(alignment_loss.data)
loss += self.alignment_lambda * alignment_loss

return loss, sample_size, logging_output

def compute_alignment_loss(self, sample, net_output):
attn_prob = net_output[1]['attn']
bsz, tgt_sz, src_sz = attn_prob.shape
attn = attn_prob.view(bsz * tgt_sz, src_sz)

align = sample['alignments']
align_weights = sample['align_weights'].float()

if len(align) > 0:
# Alignment loss computation. align (shape [:, 2]) contains the src-tgt index pairs corresponding to
# the alignments. align_weights (shape [:]) contains the 1 / frequency of a tgt index for normalizing.
loss = -((attn[align[:, 1][:, None], align[:, 0][:, None]]).log() * align_weights[:, None]).sum()
else:
return None

return loss

@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
return {
'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2) if sample_size > 0 else 0.,
'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / ntokens / math.log(2) if ntokens > 0 else 0.,
'alignment_loss': sum(log.get('alignment_loss', 0) for log in logging_outputs) / sample_size / math.log(2) if sample_size > 0 else 0.,
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size,
}

1 comment on commit 1c66792

@yotam319
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does not compile, error on line 242 with **extra

Please sign in to comment.