From 1c6679294848f303a361cba7b306b760e299bd9c Mon Sep 17 00:00:00 2001 From: Sarthak Garg Date: Mon, 30 Sep 2019 06:56:15 -0700 Subject: [PATCH] Implementation of the paper "Jointly Learning to Align and Translate with Transformer Models" (#877) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/877 This PR implements guided alignment training described in "Jointly Learning to Align and Translate with Transformer Models (https://arxiv.org/abs/1909.02074)". In summary, it allows for training selected heads of the Transformer Model with external alignments computed by Statistical Alignment Toolkits. During inference, attention probabilities from the trained heads can be used to extract reliable alignments. In our work, we did not see any regressions in the translation performance because of guided alignment training. Pull Request resolved: https://github.com/pytorch/fairseq/pull/1095 Differential Revision: D17170337 Pulled By: myleott fbshipit-source-id: daa418bef70324d7088dbb30aa2adf9f95774859 --- README.md | 2 + .../joint_alignment_translation/README.md | 89 ++++++++++ ...t18en2de_no_norm_no_escape_no_agressive.sh | 118 +++++++++++++ fairseq/binarizer.py | 16 ++ ...l_smoothed_cross_entropy_with_alignment.py | 90 ++++++++++ fairseq/data/language_pair_dataset.py | 65 +++++++- fairseq/models/fairseq_model.py | 3 + fairseq/models/transformer.py | 155 ++++++++++++++++-- fairseq/modules/multihead_attention.py | 50 ++++-- fairseq/modules/transformer_layer.py | 16 +- fairseq/options.py | 2 + fairseq/sequence_generator.py | 142 +++++++++++++--- fairseq/sequence_scorer.py | 9 +- fairseq/tasks/fairseq_task.py | 8 +- fairseq/tasks/translation.py | 12 +- fairseq/utils.py | 45 +++++ generate.py | 5 +- interactive.py | 5 +- preprocess.py | 87 +++++++++- tests/test_binaries.py | 41 ++++- 20 files changed, 899 insertions(+), 61 deletions(-) create mode 100644 examples/joint_alignment_translation/README.md create mode 100755 examples/joint_alignment_translation/prepare-wmt18en2de_no_norm_no_escape_no_agressive.sh create mode 100644 fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py diff --git a/README.md b/README.md index c39ff22c97..e05af20c6c 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,7 @@ Fairseq provides reference implementations of various sequence-to-sequence model - [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md) - [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) - [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) + - [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md ) - **Non-autoregressive Transformers** - Non-Autoregressive Neural Machine Translation (Gu et al., 2017) - Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018) @@ -100,6 +101,7 @@ as well as example training and evaluation commands. - [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available We also have more detailed READMEs to reproduce results from specific papers: +- [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md ) - [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md) - [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) - [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) diff --git a/examples/joint_alignment_translation/README.md b/examples/joint_alignment_translation/README.md new file mode 100644 index 0000000000..cd9c0ea65f --- /dev/null +++ b/examples/joint_alignment_translation/README.md @@ -0,0 +1,89 @@ +# Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019) + +This page includes instructions for training models described in [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](https://arxiv.org/abs/1909.02074). + +## Training a joint alignment-translation model on WMT'18 En-De + +##### 1. Extract and preprocess the WMT'18 En-De data +```bash +./prepare-wmt18en2de_no_norm_no_escape_no_agressive.sh +``` + +##### 2. Generate alignments from statistical alignment toolkits e.g. Giza++/FastAlign. +In this example, we use FastAlign. +```bash +git clone git@github.com:clab/fast_align.git +pushd fast_align +mkdir build +cd build +cmake .. +make +popd +ALIGN=fast_align/build/fast_align +paste bpe.32k/train.en bpe.32k/train.de | awk -F '\t' '{print $1 " ||| " $2}' > bpe.32k/train.en-de +$ALIGN -i bpe.32k/train.en-de -d -o -v > bpe.32k/train.align +``` + +##### 3. Preprocess the dataset with the above generated alignments. +```bash +fairseq-preprocess \ + --source-lang en --target-lang de \ + --trainpref bpe.32k/train \ + --validpref bpe.32k/valid \ + --testpref bpe.32k/test \ + --align-suffix align \ + --destdir binarized/ \ + --joined-dictionary \ + --workers 32 +``` + +##### 4. Train a model +```bash +fairseq-train \ + binarized \ + --arch transformer_wmt_en_de_big_align --share-all-embeddings \ + --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --activation-fn relu\ + --lr 0.0002 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \ + --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ + --max-tokens 3500 --label-smoothing 0.1 \ + --save-dir ./checkpoints --log-interval 1000 --max-update 60000 \ + --keep-interval-updates -1 --save-interval-updates 0 \ + --load-alignments --criterion label_smoothed_cross_entropy_with_alignment \ + --fp16 +``` + +Note that the `--fp16` flag requires you have CUDA 9.1 or greater and a Volta GPU or newer. + +If you want to train the above model with big batches (assuming your machine has 8 GPUs): +- add `--update-freq 8` to simulate training on 8x8=64 GPUs +- increase the learning rate; 0.0007 works well for big batches + +##### 5. Evaluate and generate the alignments (BPE level) +```bash +fairseq-generate \ + binarized --gen-subset test --print-alignment \ + --source-lang en --target-lang de \ + --path checkpoints/checkpoint_best.pt --beam 5 --nbest 1 +``` + +##### 6. Other resources. +The code for: +1. preparing alignment test sets +2. converting BPE level alignments to token level alignments +3. symmetrizing bidirectional alignments +4. evaluating alignments using AER metric +can be found [here](https://github.com/lilt/alignment-scripts) + +## Citation + +```bibtex +@inproceedings{garg2019jointly, + title = {Jointly Learning to Align and Translate with Transformer Models}, + author = {Garg, Sarthak and Peitz, Stephan and Nallasamy, Udhyakumar and Paulik, Matthias}, + booktitle = {Conference on Empirical Methods in Natural Language Processing (EMNLP)}, + address = {Hong Kong}, + month = {November}, + url = {https://arxiv.org/abs/1909.02074}, + year = {2019}, +} +``` diff --git a/examples/joint_alignment_translation/prepare-wmt18en2de_no_norm_no_escape_no_agressive.sh b/examples/joint_alignment_translation/prepare-wmt18en2de_no_norm_no_escape_no_agressive.sh new file mode 100755 index 0000000000..e78ed66a15 --- /dev/null +++ b/examples/joint_alignment_translation/prepare-wmt18en2de_no_norm_no_escape_no_agressive.sh @@ -0,0 +1,118 @@ +#!/bin/bash + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +echo 'Cloning Moses github repository (for tokenization scripts)...' +git clone https://github.com/moses-smt/mosesdecoder.git + +SCRIPTS=mosesdecoder/scripts +TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl +CLEAN=$SCRIPTS/training/clean-corpus-n.perl +REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl + +URLS=( + "http://statmt.org/wmt13/training-parallel-europarl-v7.tgz" + "http://statmt.org/wmt13/training-parallel-commoncrawl.tgz" + "http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz" + "http://data.statmt.org/wmt18/translation-task/rapid2016.tgz" + "http://data.statmt.org/wmt17/translation-task/dev.tgz" + "http://statmt.org/wmt14/test-full.tgz" +) +CORPORA=( + "training/europarl-v7.de-en" + "commoncrawl.de-en" + "training-parallel-nc-v13/news-commentary-v13.de-en" + "rapid2016.de-en" +) + +if [ ! -d "$SCRIPTS" ]; then + echo "Please set SCRIPTS variable correctly to point to Moses scripts." + exit +fi + +src=en +tgt=de +lang=en-de +prep=wmt18_en_de +tmp=$prep/tmp +orig=orig +dev=dev/newstest2012 +codes=32000 +bpe=bpe.32k + +mkdir -p $orig $tmp $prep $bpe + +cd $orig + +for ((i=0;i<${#URLS[@]};++i)); do + url=${URLS[i]} + file=$(basename $url) + if [ -f $file ]; then + echo "$file already exists, skipping download" + else + wget "$url" + if [ -f $file ]; then + echo "$url successfully downloaded." + else + echo "$url not successfully downloaded." + exit 1 + fi + if [ ${file: -4} == ".tgz" ]; then + tar zxvf $file + elif [ ${file: -4} == ".tar" ]; then + tar xvf $file + fi + fi +done +cd .. + +echo "pre-processing train data..." +for l in $src $tgt; do + rm -rf $tmp/train.tags.$lang.tok.$l + for f in "${CORPORA[@]}"; do + cat $orig/$f.$l | \ + perl $REM_NON_PRINT_CHAR | \ + perl $TOKENIZER -threads 8 -l $l -no-escape >> $tmp/train.tags.$lang.tok.$l + done +done + +echo "pre-processing test data..." +for l in $src $tgt; do + if [ "$l" == "$src" ]; then + t="src" + else + t="ref" + fi + grep '\s*//g' | \ + sed -e 's/\s*<\/seg>\s*//g' | \ + sed -e "s/\’/\'/g" | \ + perl $TOKENIZER -threads 8 -l $l -no-escape > $tmp/test.$l + echo "" +done + +# apply length filtering before BPE +perl $CLEAN -ratio 1.5 $tmp/train.tags.$lang.tok $src $tgt $tmp/train 1 100 + +# use newstest2012 for valid +echo "pre-processing valid data..." +for l in $src $tgt; do + rm -rf $tmp/valid.$l + cat $orig/$dev.$l | \ + perl $REM_NON_PRINT_CHAR | \ + perl $TOKENIZER -threads 8 -l $l -no-escape >> $tmp/valid.$l +done + +mkdir output +mv $tmp/{train,valid,test}.{$src,$tgt} output + +#BPE +git clone git@github.com:glample/fastBPE.git +pushd fastBPE +g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast +popd +fastBPE/fast learnbpe $codes output/train.$src output/train.$tgt > $bpe/codes +for split in {train,valid,test}; do for lang in {en,de}; do fastBPE/fast applybpe $bpe/$split.$lang output/$split.$lang $bpe/codes; done; done diff --git a/fairseq/binarizer.py b/fairseq/binarizer.py index 44dcb256c4..744c5e3fc8 100644 --- a/fairseq/binarizer.py +++ b/fairseq/binarizer.py @@ -52,6 +52,22 @@ def replaced_consumer(word, idx): line = f.readline() return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': replaced} + @staticmethod + def binarize_alignments(filename, alignment_parser, consumer, offset=0, end=-1): + nseq = 0 + + with open(filename, 'r') as f: + f.seek(offset) + line = safe_readline(f) + while line: + if end > 0 and f.tell() > end: + break + ids = alignment_parser(line) + nseq += 1 + consumer(ids) + line = f.readline() + return {'nseq': nseq} + @staticmethod def find_offsets(filename, num_chunks): with open(filename, 'r', encoding='utf-8') as f: diff --git a/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py b/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py new file mode 100644 index 0000000000..2cb5621498 --- /dev/null +++ b/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py @@ -0,0 +1,90 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +from fairseq import utils + +from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion +from . import register_criterion + + +@register_criterion('label_smoothed_cross_entropy_with_alignment') +class LabelSmoothedCrossEntropyCriterionWithAlignment(LabelSmoothedCrossEntropyCriterion): + + def __init__(self, args, task): + super().__init__(args, task) + self.alignment_lambda = args.alignment_lambda + + @staticmethod + def add_args(parser): + """Add criterion-specific arguments to the parser.""" + super(LabelSmoothedCrossEntropyCriterionWithAlignment, + LabelSmoothedCrossEntropyCriterionWithAlignment).add_args(parser) + parser.add_argument('--alignment-lambda', default=0.05, type=float, metavar='D', + help='weight for the alignment loss') + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample['net_input']) + loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce) + sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] + logging_output = { + 'loss': utils.item(loss.data) if reduce else loss.data, + 'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data, + 'ntokens': sample['ntokens'], + 'nsentences': sample['target'].size(0), + 'sample_size': sample_size, + } + + alignment_loss = None + + # Compute alignment loss only for training set and non dummy batches. + if 'alignments' in sample and sample['alignments'] is not None: + alignment_loss = self.compute_alignment_loss(sample, net_output) + + if alignment_loss is not None: + logging_output['alignment_loss'] = utils.item(alignment_loss.data) + loss += self.alignment_lambda * alignment_loss + + return loss, sample_size, logging_output + + def compute_alignment_loss(self, sample, net_output): + attn_prob = net_output[1]['attn'] + bsz, tgt_sz, src_sz = attn_prob.shape + attn = attn_prob.view(bsz * tgt_sz, src_sz) + + align = sample['alignments'] + align_weights = sample['align_weights'].float() + + if len(align) > 0: + # Alignment loss computation. align (shape [:, 2]) contains the src-tgt index pairs corresponding to + # the alignments. align_weights (shape [:]) contains the 1 / frequency of a tgt index for normalizing. + loss = -((attn[align[:, 1][:, None], align[:, 0][:, None]]).log() * align_weights[:, None]).sum() + else: + return None + + return loss + + @staticmethod + def aggregate_logging_outputs(logging_outputs): + """Aggregate logging outputs from data parallel training.""" + ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) + nsentences = sum(log.get('nsentences', 0) for log in logging_outputs) + sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) + return { + 'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2) if sample_size > 0 else 0., + 'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / ntokens / math.log(2) if ntokens > 0 else 0., + 'alignment_loss': sum(log.get('alignment_loss', 0) for log in logging_outputs) / sample_size / math.log(2) if sample_size > 0 else 0., + 'ntokens': ntokens, + 'nsentences': nsentences, + 'sample_size': sample_size, + } diff --git a/fairseq/data/language_pair_dataset.py b/fairseq/data/language_pair_dataset.py index 5fc1371aae..09c7193ab4 100644 --- a/fairseq/data/language_pair_dataset.py +++ b/fairseq/data/language_pair_dataset.py @@ -22,6 +22,28 @@ def merge(key, left_pad, move_eos_to_beginning=False): pad_idx, eos_idx, left_pad, move_eos_to_beginning, ) + def check_alignment(alignment, src_len, tgt_len): + if alignment is None or len(alignment) == 0: + return False + if alignment[:, 0].max().item() >= src_len - 1 or alignment[:, 1].max().item() >= tgt_len - 1: + print("| alignment size mismatch found, skipping alignment!") + return False + return True + + def compute_alignment_weights(alignments): + """ + Given a tensor of shape [:, 2] containing the source-target indices + corresponding to the alignments, a weight vector containing the + inverse frequency of each target index is computed. + For e.g. if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then + a tensor containing [1., 0.5, 0.5, 1] should be returned (since target + index 3 is repeated twice) + """ + align_tgt = alignments[:, 1] + _, align_tgt_i, align_tgt_c = torch.unique(align_tgt, return_inverse=True, return_counts=True) + align_weights = align_tgt_c[align_tgt_i[np.arange(len(align_tgt))]] + return 1. / align_weights.float() + id = torch.LongTensor([s['id'] for s in samples]) src_tokens = merge('source', left_pad=left_pad_source) # sort by descending source length @@ -35,6 +57,7 @@ def merge(key, left_pad, move_eos_to_beginning=False): if samples[0].get('target', None) is not None: target = merge('target', left_pad=left_pad_target) target = target.index_select(0, sort_order) + tgt_lengths = torch.LongTensor([s['target'].numel() for s in samples]).index_select(0, sort_order) ntokens = sum(len(s['target']) for s in samples) if input_feeding: @@ -61,6 +84,32 @@ def merge(key, left_pad, move_eos_to_beginning=False): } if prev_output_tokens is not None: batch['net_input']['prev_output_tokens'] = prev_output_tokens + + if samples[0].get('alignment', None) is not None: + bsz, tgt_sz = batch['target'].shape + src_sz = batch['net_input']['src_tokens'].shape[1] + + offsets = torch.zeros((len(sort_order), 2), dtype=torch.long) + offsets[:, 1] += (torch.arange(len(sort_order), dtype=torch.long) * tgt_sz) + if left_pad_source: + offsets[:, 0] += (src_sz - src_lengths) + if left_pad_target: + offsets[:, 1] += (tgt_sz - tgt_lengths) + + alignments = [ + alignment + offset + for align_idx, offset, src_len, tgt_len in zip(sort_order, offsets, src_lengths, tgt_lengths) + for alignment in [samples[align_idx]['alignment'].view(-1, 2)] + if check_alignment(alignment, src_len, tgt_len) + ] + + if len(alignments) > 0: + alignments = torch.cat(alignments, dim=0) + align_weights = compute_alignment_weights(alignments) + + batch['alignments'] = alignments + batch['align_weights'] = align_weights + return batch @@ -91,6 +140,8 @@ class LanguagePairDataset(FairseqDataset): of source if it's present (default: False). append_eos_to_target (bool, optional): if set, appends eos to end of target if it's absent (default: False). + align_dataset (torch.utils.data.Dataset, optional): dataset + containing alignments. """ def __init__( @@ -98,7 +149,9 @@ def __init__( tgt=None, tgt_sizes=None, tgt_dict=None, left_pad_source=True, left_pad_target=False, max_source_positions=1024, max_target_positions=1024, - shuffle=True, input_feeding=True, remove_eos_from_source=False, append_eos_to_target=False, + shuffle=True, input_feeding=True, + remove_eos_from_source=False, append_eos_to_target=False, + align_dataset=None, ): if tgt_dict is not None: assert src_dict.pad() == tgt_dict.pad() @@ -118,6 +171,9 @@ def __init__( self.input_feeding = input_feeding self.remove_eos_from_source = remove_eos_from_source self.append_eos_to_target = append_eos_to_target + self.align_dataset = align_dataset + if self.align_dataset is not None: + assert self.tgt_sizes is not None, "Both source and target needed when alignments are provided" def __getitem__(self, index): tgt_item = self.tgt[index] if self.tgt is not None else None @@ -136,11 +192,14 @@ def __getitem__(self, index): if self.src[index][-1] == eos: src_item = self.src[index][:-1] - return { + example = { 'id': index, 'source': src_item, 'target': tgt_item, } + if self.align_dataset is not None: + example['alignment'] = self.align_dataset[index] + return example def __len__(self): return len(self.src) @@ -212,3 +271,5 @@ def prefetch(self, indices): self.src.prefetch(indices) if self.tgt is not None: self.tgt.prefetch(indices) + if self.align_dataset is not None: + self.align_dataset.prefetch(indices) diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index fc53a7c9d7..674de01310 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -222,6 +222,9 @@ def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs): decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs) return decoder_out + def forward_decoder(self, prev_output_tokens, **kwargs): + return self.decoder(prev_output_tokens, **kwargs) + def extract_features(self, src_tokens, src_lengths, prev_output_tokens, **kwargs): """ Similar to *forward* but only return features. diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 910c2eda09..f5f23f1b95 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -68,6 +68,7 @@ def hub_models(cls): def __init__(self, encoder, decoder): super().__init__(encoder, decoder) + self.supports_align_args = True @staticmethod def add_args(parser): @@ -195,6 +196,69 @@ def build_decoder(cls, args, tgt_dict, embed_tokens): ) +@register_model('transformer_align') +class TransformerAlignModel(TransformerModel): + """ + See "Jointly Learning to Align and Translate with Transformer + Models" (Garg et al., EMNLP 2019). + """ + + def __init__(self, encoder, decoder, args): + super().__init__(encoder, decoder) + self.alignment_heads = args.alignment_heads + self.alignment_layer = args.alignment_layer + self.full_context_alignment = args.full_context_alignment + + @staticmethod + def add_args(parser): + # fmt: off + super(TransformerAlignModel, TransformerAlignModel).add_args(parser) + parser.add_argument('--alignment-heads', type=int, metavar='D', + help='Number of cross attention heads per layer to supervised with alignments') + parser.add_argument('--alignment-layer', type=int, metavar='D', + help='Layer number which has to be supervised. 0 corresponding to the bottommost layer.') + parser.add_argument('--full-context-alignment', type=bool, metavar='D', + help='Whether or not alignment is supervised conditioned on the full target context.') + # fmt: on + + @classmethod + def build_model(cls, args, task): + # set any default arguments + transformer_align(args) + + transformer_model = TransformerModel.build_model(args, task) + return TransformerAlignModel(transformer_model.encoder, transformer_model.decoder, args) + + def forward(self, src_tokens, src_lengths, prev_output_tokens): + encoder_out = self.encoder(src_tokens, src_lengths) + return self.forward_decoder(prev_output_tokens, encoder_out) + + def forward_decoder( + self, + prev_output_tokens, + encoder_out=None, + incremental_state=None, + features_only=False, + **extra_args, + ): + attn_args = {'alignment_layer': self.alignment_layer, 'alignment_heads': self.alignment_heads} + decoder_out = self.decoder( + prev_output_tokens, + encoder_out, + **attn_args, + **extra_args, + ) + + if self.full_context_alignment: + attn_args['full_context_alignment'] = self.full_context_alignment + _, alignment_out = self.decoder( + prev_output_tokens, encoder_out, features_only=True, **attn_args, **extra_args, + ) + decoder_out[1]['attn'] = alignment_out['attn'] + + return decoder_out + + class TransformerEncoder(FairseqEncoder): """ Transformer encoder consisting of *args.encoder_layers* layers. Each layer @@ -423,7 +487,14 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): else: self.layer_norm = None - def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused): + def forward( + self, + prev_output_tokens, + encoder_out=None, + incremental_state=None, + features_only=False, + **extra_args, + ): """ Args: prev_output_tokens (LongTensor): previous decoder outputs of shape @@ -432,25 +503,53 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, encoder-side attention incremental_state (dict): dictionary used for storing state during :ref:`Incremental decoding` + features_only (bool, optional): only return features without + applying output layer (default: False). Returns: tuple: - the decoder's output of shape `(batch, tgt_len, vocab)` - a dictionary with any model-specific outputs """ - x, extra = self.extract_features(prev_output_tokens, encoder_out, incremental_state) - x = self.output_layer(x) + x, extra = self.extract_features( + prev_output_tokens, encoder_out, incremental_state, **extra_args, + ) + if not features_only: + x = self.output_layer(x) return x, extra - def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused): + def extract_features( + self, + prev_output_tokens, + encoder_out=None, + incremental_state=None, + full_context_alignment=False, + alignment_layer=None, + alignment_heads=None, + **unused, + ): """ Similar to *forward* but only return features. + Includes several features from "Jointly Learning to Align and + Translate with Transformer Models" (Garg et al., EMNLP 2019). + + Args: + full_context_alignment (bool, optional): don't apply + auto-regressive mask to self-attention (default: False). + alignment_layer (int, optional): return mean alignment over + heads at this layer (default: last layer). + alignment_heads (int, optional): only average alignment over + this many heads (default: all heads). + Returns: tuple: - the decoder's features of shape `(batch, tgt_len, embed_dim)` - a dictionary with any model-specific outputs """ + if alignment_layer is None: + alignment_layer = len(self.layers) - 1 + # embed positions positions = self.embed_positions( prev_output_tokens, @@ -474,15 +573,14 @@ def extract_features(self, prev_output_tokens, encoder_out=None, incremental_sta # B x T x C -> T x B x C x = x.transpose(0, 1) - attn = None - - inner_states = [x] self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) if not self_attn_padding_mask.any() and not self.cross_self_attention: self_attn_padding_mask = None # decoder layers + attn = None + inner_states = [x] for idx, layer in enumerate(self.layers): encoder_state = None if encoder_out is not None: @@ -491,15 +589,32 @@ def extract_features(self, prev_output_tokens, encoder_out=None, incremental_sta else: encoder_state = encoder_out['encoder_out'] - x, attn = layer( + if incremental_state is None and not full_context_alignment: + self_attn_mask = self.buffered_future_mask(x) + else: + self_attn_mask = None + + x, layer_attn = layer( x, encoder_state, encoder_out['encoder_padding_mask'] if encoder_out is not None else None, incremental_state, - self_attn_mask=self.buffered_future_mask(x) if incremental_state is None else None, + self_attn_mask=self_attn_mask, self_attn_padding_mask=self_attn_padding_mask, + need_attn=(idx == alignment_layer), + need_head_weights=(idx == alignment_layer), ) + inner_states.append(x) + if layer_attn is not None and idx == alignment_layer: + attn = layer_attn.float() + + if attn is not None: + if alignment_heads is not None: + attn = attn[:alignment_heads] + + # average probabilities over heads + attn = attn.mean(dim=0) if self.layer_norm: x = self.layer_norm(x) @@ -531,7 +646,12 @@ def max_positions(self): def buffered_future_mask(self, tensor): dim = tensor.size(0) - if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device or self._future_mask.size(0) < dim: + if ( + not hasattr(self, '_future_mask') + or self._future_mask is None + or self._future_mask.device != tensor.device + or self._future_mask.size(0) < dim + ): self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1) return self._future_mask[:dim, :dim] @@ -668,3 +788,18 @@ def transformer_wmt_en_de_big_t2t(args): args.attention_dropout = getattr(args, 'attention_dropout', 0.1) args.activation_dropout = getattr(args, 'activation_dropout', 0.1) transformer_vaswani_wmt_en_de_big(args) + + +@register_model_architecture('transformer_align', 'transformer_align') +def transformer_align(args): + args.alignment_heads = getattr(args, 'alignment_heads', 1) + args.alignment_layer = getattr(args, 'alignment_layer', 4) + args.full_context_alignment = getattr(args, 'full_context_alignment', False) + base_architecture(args) + + +@register_model_architecture('transformer_align', 'transformer_wmt_en_de_big_align') +def transformer_wmt_en_de_big_align(args): + args.alignment_heads = getattr(args, 'alignment_heads', 1) + args.alignment_layer = getattr(args, 'alignment_layer', 4) + transformer_wmt_en_de_big(args) diff --git a/fairseq/modules/multihead_attention.py b/fairseq/modules/multihead_attention.py index 9aaea82484..96849790f6 100644 --- a/fairseq/modules/multihead_attention.py +++ b/fairseq/modules/multihead_attention.py @@ -90,15 +90,37 @@ def reset_parameters(self): if self.bias_v is not None: nn.init.xavier_normal_(self.bias_v) - def forward(self, query, key, value, key_padding_mask=None, incremental_state=None, - need_weights=True, static_kv=False, attn_mask=None, before_softmax=False): + def forward( + self, + query, key, value, + key_padding_mask=None, + incremental_state=None, + need_weights=True, + static_kv=False, + attn_mask=None, + before_softmax=False, + need_head_weights=False, + ): """Input shape: Time x Batch x Channel - Timesteps can be masked by supplying a T x T mask in the - `attn_mask` argument. Padding elements can be excluded from - the key by passing a binary ByteTensor (`key_padding_mask`) with shape: - batch x src_len, where padding elements are indicated by 1s. + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. """ + if need_head_weights: + need_weights = True + tgt_len, bsz, embed_dim = query.size() assert embed_dim == self.embed_dim assert list(query.size()) == [tgt_len, bsz, embed_dim] @@ -249,12 +271,11 @@ def forward(self, query, key, value, key_padding_mask=None, incremental_state=No if before_softmax: return attn_weights, v - attn_weights = utils.softmax( - attn_weights, dim=-1, onnx_trace=self.onnx_trace, - ).type_as(attn_weights) - attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training) + attn_weights_float = utils.softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training) - attn = torch.bmm(attn_weights, v) + attn = torch.bmm(attn_probs, v) assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] if (self.onnx_trace and attn.size(1) == 1): # when ONNX tracing a single decoder step (sequence length == 1) @@ -265,9 +286,10 @@ def forward(self, query, key, value, key_padding_mask=None, incremental_state=No attn = self.out_proj(attn) if need_weights: - # average attention weights over heads - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.sum(dim=1) / self.num_heads + attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) else: attn_weights = None diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py index 63c6cdf552..a3579fb990 100644 --- a/fairseq/modules/transformer_layer.py +++ b/fairseq/modules/transformer_layer.py @@ -195,16 +195,25 @@ def forward( prev_attn_state=None, self_attn_mask=None, self_attn_padding_mask=None, + need_attn=False, + need_head_weights=False, ): """ Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` - encoder_padding_mask (ByteTensor): binary ByteTensor of shape - `(batch, src_len)` where padding elements are indicated by ``1``. + encoder_padding_mask (ByteTensor, optional): binary + ByteTensor of shape `(batch, src_len)` where padding + elements are indicated by ``1``. + need_attn (bool, optional): return attention weights + need_head_weights (bool, optional): return attention weights + for each head (default: return average over heads). Returns: encoded output of shape `(seq_len, batch, embed_dim)` """ + if need_head_weights: + need_attn = True + residual = x x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True) if prev_self_attn_state is not None: @@ -259,7 +268,8 @@ def forward( key_padding_mask=encoder_padding_mask, incremental_state=incremental_state, static_kv=True, - need_weights=(not self.training and self.need_attn), + need_weights=need_attn or (not self.training and self.need_attn), + need_head_weights=need_head_weights, ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x diff --git a/fairseq/options.py b/fairseq/options.py index bb1e27aeb7..06a52b62ba 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -224,6 +224,8 @@ def add_preprocess_args(parser): help="comma separated, valid file prefixes") group.add_argument("--testpref", metavar="FP", default=None, help="comma separated, test file prefixes") + group.add_argument("--align-suffix", metavar="FP", default=None, + help="alignment file suffix") group.add_argument("--destdir", metavar="DIR", default="data-bin", help="destination dir") group.add_argument("--thresholdtgt", metavar="N", default=0, type=int, diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index 3b100b9615..dd3fb86f7b 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -7,7 +7,8 @@ import torch -from fairseq import search +from fairseq import search, utils +from fairseq.data import data_utils from fairseq.models import FairseqIncrementalDecoder @@ -81,7 +82,6 @@ def __init__( self.temperature = temperature self.match_source_len = match_source_len self.no_repeat_ngram_size = no_repeat_ngram_size - assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling' assert sampling_topp < 0 or sampling, '--sampling-topp requires --sampling' assert temperature > 0, '--temperature must be greater than 0' @@ -98,14 +98,7 @@ def __init__( self.search = search.BeamSearch(tgt_dict) @torch.no_grad() - def generate( - self, - models, - sample, - prefix_tokens=None, - bos_token=None, - **kwargs - ): + def generate(self, models, sample, **kwargs): """Generate a batch of translations. Args: @@ -113,8 +106,21 @@ def generate( sample (dict): batch prefix_tokens (torch.LongTensor, optional): force decoder to begin with these tokens + bos_token (int, optional): beginning of sentence token + (default: self.eos) """ model = EnsembleModel(models) + return self._generate(model, sample, **kwargs) + + @torch.no_grad() + def _generate( + self, + model, + sample, + prefix_tokens=None, + bos_token=None, + **kwargs + ): if not self.retain_dropout: model.eval() @@ -155,7 +161,6 @@ def generate( tokens_buf = tokens.clone() tokens[:, 0] = self.eos if bos_token is None else bos_token attn, attn_buf = None, None - nonpad_idxs = None # The blacklist indicates candidates that should be ignored. # For example, suppose we're sampling and have already finalized 2/5 @@ -251,17 +256,15 @@ def get_hypo(): if attn_clone is not None: # remove padding tokens from attn scores - hypo_attn = attn_clone[i][nonpad_idxs[sent]] - _, alignment = hypo_attn.max(dim=0) + hypo_attn = attn_clone[i] else: hypo_attn = None - alignment = None return { 'tokens': tokens_clone[i], 'score': score, 'attention': hypo_attn, # src_len x tgt_len - 'alignment': alignment, + 'alignment': None, 'positional_scores': pos_scores[i], } @@ -345,7 +348,6 @@ def replicate_first_beam(tensor, mask): if attn is None: attn = scores.new(bsz * beam_size, src_tokens.size(1), max_len + 2) attn_buf = attn.clone() - nonpad_idxs = src_tokens.ne(self.pad) attn[:, :, step + 1].copy_(avg_attn_scores) scores = scores.type_as(lprobs) @@ -512,7 +514,6 @@ def calculate_banned_tokens(bbsz_idx): # sort by score descending for sent in range(len(finalized)): finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True) - return finalized @@ -577,9 +578,11 @@ def _decode_one( temperature=1., ): if self.incremental_states is not None: - decoder_out = list(model.decoder(tokens, encoder_out, incremental_state=self.incremental_states[model])) + decoder_out = list(model.forward_decoder( + tokens, encoder_out=encoder_out, incremental_state=self.incremental_states[model], + )) else: - decoder_out = list(model.decoder(tokens, encoder_out)) + decoder_out = list(model.forward_decoder(tokens, encoder_out=encoder_out)) decoder_out[0] = decoder_out[0][:, -1:, :] if temperature != 1.: decoder_out[0].div_(temperature) @@ -605,3 +608,104 @@ def reorder_incremental_state(self, new_order): return for model in self.models: model.decoder.reorder_incremental_state(self.incremental_states[model], new_order) + + +class SequenceGeneratorWithAlignment(SequenceGenerator): + + def __init__(self, tgt_dict, left_pad_target=False, **kwargs): + """Generates translations of a given source sentence. + + Produces alignments following "Jointly Learning to Align and + Translate with Transformer Models" (Garg et al., EMNLP 2019). + + Args: + left_pad_target (bool, optional): Whether or not the + hypothesis should be left padded or not when they are + teacher forced for generating alignments. + """ + super().__init__(tgt_dict, **kwargs) + self.left_pad_target = left_pad_target + + @torch.no_grad() + def generate(self, models, sample, **kwargs): + model = EnsembleModelWithAlignment(models) + finalized = super()._generate(model, sample, **kwargs) + + src_tokens = sample['net_input']['src_tokens'] + bsz = src_tokens.shape[0] + beam_size = self.beam_size + src_tokens, src_lengths, prev_output_tokens, tgt_tokens = \ + self._prepare_batch_for_alignment(sample, finalized) + if any(getattr(m, 'full_context_alignment', False) for m in model.models): + attn = model.forward_align(src_tokens, src_lengths, prev_output_tokens) + else: + attn = [ + finalized[i // beam_size][i % beam_size]['attention'].transpose(1, 0) + for i in range(bsz * beam_size) + ] + + # Process the attn matrix to extract hard alignments. + for i in range(bsz * beam_size): + alignment = utils.extract_hard_alignment(attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos) + finalized[i // beam_size][i % beam_size]['alignment'] = alignment + return finalized + + def _prepare_batch_for_alignment(self, sample, hypothesis): + src_tokens = sample['net_input']['src_tokens'] + bsz = src_tokens.shape[0] + src_tokens = src_tokens[:, None, :].expand(-1, self.beam_size, -1).contiguous().view(bsz * self.beam_size, -1) + src_lengths = sample['net_input']['src_lengths'] + src_lengths = src_lengths[:, None].expand(-1, self.beam_size).contiguous().view(bsz * self.beam_size) + prev_output_tokens = data_utils.collate_tokens( + [beam['tokens'] for example in hypothesis for beam in example], + self.pad, self.eos, self.left_pad_target, move_eos_to_beginning=True, + ) + tgt_tokens = data_utils.collate_tokens( + [beam['tokens'] for example in hypothesis for beam in example], + self.pad, self.eos, self.left_pad_target, move_eos_to_beginning=False, + ) + return src_tokens, src_lengths, prev_output_tokens, tgt_tokens + + +class EnsembleModelWithAlignment(EnsembleModel): + """A wrapper around an ensemble of models.""" + + def __init__(self, models): + super().__init__(models) + + def forward_align(self, src_tokens, src_lengths, prev_output_tokens): + avg_attn = None + for model in self.models: + decoder_out = model(src_tokens, src_lengths, prev_output_tokens) + attn = decoder_out[1]['attn'] + if avg_attn is None: + avg_attn = attn + else: + avg_attn.add_(attn) + if len(self.models) > 1: + avg_attn.div_(len(self.models)) + return avg_attn + + def _decode_one( + self, tokens, model, encoder_out, incremental_states, log_probs, + temperature=1., + ): + if self.incremental_states is not None: + decoder_out = list(model.forward_decoder( + tokens, + encoder_out=encoder_out, + incremental_state=self.incremental_states[model], + )) + else: + decoder_out = list(model.forward_decoder(tokens, encoder_out=encoder_out)) + decoder_out[0] = decoder_out[0][:, -1:, :] + if temperature != 1.: + decoder_out[0].div_(temperature) + attn = decoder_out[1] + if type(attn) is dict: + attn = attn.get('attn', None) + if attn is not None: + attn = attn[:, -1, :] + probs = model.get_normalized_probs(decoder_out, log_probs=log_probs) + probs = probs[:, -1, :] + return probs, attn diff --git a/fairseq/sequence_scorer.py b/fairseq/sequence_scorer.py index d125422340..75ff4cf051 100644 --- a/fairseq/sequence_scorer.py +++ b/fairseq/sequence_scorer.py @@ -14,6 +14,7 @@ class SequenceScorer(object): def __init__(self, tgt_dict, softmax_batch=None): self.pad = tgt_dict.pad() + self.eos = tgt_dict.eos() self.softmax_batch = softmax_batch or sys.maxsize assert self.softmax_batch > 0 @@ -44,6 +45,7 @@ def gather_target_probs(probs, target): ) return probs + orig_target = sample['target'] # compute scores for each model in the ensemble @@ -53,6 +55,8 @@ def gather_target_probs(probs, target): model.eval() decoder_out = model.forward(**net_input) attn = decoder_out[1] + if type(attn) is dict: + attn = attn.get('attn', None) batched = batch_for_softmax(decoder_out, orig_target) probs, idx = None, 0 @@ -100,8 +104,9 @@ def gather_target_probs(probs, target): avg_probs_i = avg_probs[i][start_idxs[i]:start_idxs[i] + tgt_len] score_i = avg_probs_i.sum() / tgt_len if avg_attn is not None: - avg_attn_i = avg_attn[i, start_idxs[i]:] - _, alignment = avg_attn_i.max(dim=0) + avg_attn_i = avg_attn[i] + alignment = utils.extract_hard_alignment(avg_attn_i, sample['net_input']['src_tokens'][i], + sample['target'][i], self.pad, self.eos) else: avg_attn_i = alignment = None hypos.append([{ diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index ba5695785d..538532b20e 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -198,8 +198,12 @@ def build_generator(self, args): from fairseq.sequence_scorer import SequenceScorer return SequenceScorer(self.target_dictionary) else: - from fairseq.sequence_generator import SequenceGenerator - return SequenceGenerator( + from fairseq.sequence_generator import SequenceGenerator, SequenceGeneratorWithAlignment + if getattr(args, 'print_alignment', False): + seq_gen_cls = SequenceGeneratorWithAlignment + else: + seq_gen_cls = SequenceGenerator + return seq_gen_cls( self.target_dictionary, beam_size=getattr(args, 'beam', 5), max_len_a=getattr(args, 'max_len_a', 0), diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py index f3d60403ba..353e640bf6 100644 --- a/fairseq/tasks/translation.py +++ b/fairseq/tasks/translation.py @@ -24,7 +24,7 @@ def load_langpair_dataset( tgt, tgt_dict, combine, dataset_impl, upsample_primary, left_pad_source, left_pad_target, max_source_positions, - max_target_positions, prepend_bos=False, + max_target_positions, prepend_bos=False, load_alignments=False, ): def split_exists(split, src, tgt, lang, data_path): filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang)) @@ -74,6 +74,12 @@ def split_exists(split, src, tgt, lang, data_path): src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) + align_dataset = None + if load_alignments: + align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt)) + if indexed_dataset.dataset_exists(align_path, impl=dataset_impl): + align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl) + return LanguagePairDataset( src_dataset, src_dataset.sizes, src_dict, tgt_dataset, tgt_dataset.sizes, tgt_dict, @@ -81,6 +87,7 @@ def split_exists(split, src, tgt, lang, data_path): left_pad_target=left_pad_target, max_source_positions=max_source_positions, max_target_positions=max_target_positions, + align_dataset=align_dataset, ) @@ -120,6 +127,8 @@ def add_args(parser): help='load the dataset lazily') parser.add_argument('--raw-text', action='store_true', help='load raw text dataset') + parser.add_argument('--load-alignments', action='store_true', + help='load the binarized alignments') parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL', help='pad the source on the left') parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL', @@ -193,6 +202,7 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs): left_pad_target=self.args.left_pad_target, max_source_positions=self.args.max_source_positions, max_target_positions=self.args.max_target_positions, + load_alignments=self.args.load_alignments, ) def build_dataset_for_inference(self, src_tokens, src_lengths): diff --git a/fairseq/utils.py b/fairseq/utils.py index 80ecb6d083..9dd41fbfea 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -16,6 +16,7 @@ import torch import torch.nn.functional as F +from itertools import accumulate from fairseq.modules import gelu, gelu_accurate @@ -367,3 +368,47 @@ def set_torch_seed(seed): assert isinstance(seed, int) torch.manual_seed(seed) torch.cuda.manual_seed(seed) + + +def parse_alignment(line): + """ + Parses a single line from the alingment file. + + Args: + line (str): String containing the alignment of the format: + - - .. + -. All indices are 0 indexed. + + Returns: + torch.IntTensor: packed alignments of shape (2 * m). + """ + alignments = line.strip().split() + parsed_alignment = torch.IntTensor(2 * len(alignments)) + for idx, alignment in enumerate(alignments): + src_idx, tgt_idx = alignment.split('-') + parsed_alignment[2 * idx] = int(src_idx) + parsed_alignment[2 * idx + 1] = int(tgt_idx) + return parsed_alignment + + +def get_token_to_word_mapping(tokens, exclude_list): + n = len(tokens) + word_start = [int(token not in exclude_list) for token in tokens] + word_idx = list(accumulate(word_start)) + token_to_word = {i: word_idx[i] for i in range(n)} + return token_to_word + + +def extract_hard_alignment(attn, src_sent, tgt_sent, pad, eos): + tgt_valid = ((tgt_sent != pad) & (tgt_sent != eos)).nonzero().squeeze(dim=-1) + src_invalid = ((src_sent == pad) | (src_sent == eos)).nonzero().squeeze(dim=-1) + src_token_to_word = get_token_to_word_mapping(src_sent, [eos, pad]) + tgt_token_to_word = get_token_to_word_mapping(tgt_sent, [eos, pad]) + alignment = [] + if len(tgt_valid) != 0 and len(src_invalid) < len(src_sent): + attn_valid = attn[tgt_valid] + attn_valid[:, src_invalid] = float('-inf') + _, src_indices = attn_valid.max(dim=1) + for tgt_idx, src_idx in zip(tgt_valid, src_indices): + alignment.append((src_token_to_word[src_idx.item()] - 1, tgt_token_to_word[tgt_idx.item()] - 1)) + return alignment diff --git a/generate.py b/generate.py index 6de1a69abd..aba611d4b0 100644 --- a/generate.py +++ b/generate.py @@ -137,7 +137,7 @@ def main(args): hypo_tokens, hypo_str, alignment = utils.post_process_prediction( hypo_tokens=hypo['tokens'].int().cpu(), src_str=src_str, - alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None, + alignment=hypo['alignment'], align_dict=align_dict, tgt_dict=tgt_dict, remove_bpe=args.remove_bpe, @@ -156,7 +156,7 @@ def main(args): if args.print_alignment: print('A-{}\t{}'.format( sample_id, - ' '.join(map(lambda x: str(utils.item(x)), alignment)) + ' '.join(['{}-{}'.format(src_idx, tgt_idx) for src_idx, tgt_idx in alignment]) )) if args.print_step: @@ -180,6 +180,7 @@ def main(args): num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg)) if has_target: print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string())) + return scorer diff --git a/interactive.py b/interactive.py index d9d547a974..36e2bd0ca9 100644 --- a/interactive.py +++ b/interactive.py @@ -162,7 +162,7 @@ def decode_fn(x): hypo_tokens, hypo_str, alignment = utils.post_process_prediction( hypo_tokens=hypo['tokens'].int().cpu(), src_str=src_str, - alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None, + alignment=hypo['alignment'], align_dict=align_dict, tgt_dict=tgt_dict, remove_bpe=args.remove_bpe, @@ -174,9 +174,10 @@ def decode_fn(x): ' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist())) )) if args.print_alignment: + alignment_str = " ".join(["{}-{}".format(src, tgt) for src, tgt in alignment]) print('A-{}\t{}'.format( id, - ' '.join(map(lambda x: str(utils.item(x)), alignment)) + alignment_str )) # update running id counter diff --git a/preprocess.py b/preprocess.py index a157feeb68..538ff2b006 100644 --- a/preprocess.py +++ b/preprocess.py @@ -157,6 +157,60 @@ def merge_result(worker_result): ) ) + def make_binary_alignment_dataset(input_prefix, output_prefix, num_workers): + nseq = [0] + + def merge_result(worker_result): + nseq[0] += worker_result['nseq'] + + input_file = input_prefix + offsets = Binarizer.find_offsets(input_file, num_workers) + pool = None + if num_workers > 1: + pool = Pool(processes=num_workers - 1) + for worker_id in range(1, num_workers): + prefix = "{}{}".format(output_prefix, worker_id) + pool.apply_async( + binarize_alignments, + ( + args, + input_file, + utils.parse_alignment, + prefix, + offsets[worker_id], + offsets[worker_id + 1] + ), + callback=merge_result + ) + pool.close() + + ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, None, "bin"), + impl=args.dataset_impl) + + merge_result( + Binarizer.binarize_alignments( + input_file, utils.parse_alignment, lambda t: ds.add_item(t), + offset=0, end=offsets[1] + ) + ) + if num_workers > 1: + pool.join() + for worker_id in range(1, num_workers): + prefix = "{}{}".format(output_prefix, worker_id) + temp_file_path = dataset_dest_prefix(args, prefix, None) + ds.merge_file_(temp_file_path) + os.remove(indexed_dataset.data_file_path(temp_file_path)) + os.remove(indexed_dataset.index_file_path(temp_file_path)) + + ds.finalize(dataset_dest_file(args, output_prefix, None, "idx")) + + print( + "| [alignments] {}: parsed {} alignments".format( + input_file, + nseq[0] + ) + ) + def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1): if args.dataset_impl == "raw": # Copy original text file to destination folder @@ -180,9 +234,19 @@ def make_all(lang, vocab): outprefix = "test{}".format(k) if k > 0 else "test" make_dataset(vocab, testpref, outprefix, lang, num_workers=args.workers) + def make_all_alignments(): + if args.trainpref and os.path.exists(args.trainpref + "." + args.align_suffix): + make_binary_alignment_dataset(args.trainpref + "." + args.align_suffix, "train.align", num_workers=args.workers) + if args.validpref and os.path.exists(args.validpref + "." + args.align_suffix): + make_binary_alignment_dataset(args.validpref + "." + args.align_suffix, "valid.align", num_workers=args.workers) + if args.testpref and os.path.exists(args.testpref + "." + args.align_suffix): + make_binary_alignment_dataset(args.testpref + "." + args.align_suffix, "test.align", num_workers=args.workers) + make_all(args.source_lang, src_dict) if target: make_all(args.target_lang, tgt_dict) + if args.align_suffix: + make_all_alignments() print("| Wrote preprocessed data to {}".format(args.destdir)) @@ -242,11 +306,28 @@ def consumer(tensor): return res +def binarize_alignments(args, filename, parse_alignment, output_prefix, offset, end): + ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, None, "bin"), + impl=args.dataset_impl, vocab_size=None) + + def consumer(tensor): + ds.add_item(tensor) + + res = Binarizer.binarize_alignments(filename, parse_alignment, consumer, offset=offset, + end=end) + ds.finalize(dataset_dest_file(args, output_prefix, None, "idx")) + return res + + def dataset_dest_prefix(args, output_prefix, lang): base = "{}/{}".format(args.destdir, output_prefix) - lang_part = ( - ".{}-{}.{}".format(args.source_lang, args.target_lang, lang) if lang is not None else "" - ) + if lang is not None: + lang_part = ".{}-{}.{}".format(args.source_lang, args.target_lang, lang) + elif args.only_source: + lang_part = "" + else: + lang_part = ".{}-{}".format(args.source_lang, args.target_lang) + return "{}{}".format(base, lang_part) diff --git a/tests/test_binaries.py b/tests/test_binaries.py index f77806bd6a..113901ab06 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -266,6 +266,27 @@ def test_mixture_of_experts(self): '--gen-expert', '0' ]) + def test_alignment(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory('test_alignment') as data_dir: + create_dummy_data(data_dir, alignment=True) + preprocess_translation_data(data_dir, ['--align-suffix', 'align']) + train_translation_model( + data_dir, + 'transformer_align', + [ + '--encoder-layers', '2', + '--decoder-layers', '2', + '--encoder-embed-dim', '8', + '--decoder-embed-dim', '8', + '--load-alignments', + '--alignment-layer', '1', + '--criterion', 'label_smoothed_cross_entropy_with_alignment' + ], + run_validation=True, + ) + generate_main(data_dir) + class TestStories(unittest.TestCase): @@ -484,7 +505,7 @@ def test_optimizers(self): generate_main(data_dir) -def create_dummy_data(data_dir, num_examples=1000, maxlen=20): +def create_dummy_data(data_dir, num_examples=1000, maxlen=20, alignment=False): def _create_dummy_data(filename): data = torch.rand(num_examples * maxlen) @@ -497,6 +518,20 @@ def _create_dummy_data(filename): print(ex_str, file=h) offset += ex_len + def _create_dummy_alignment_data(filename_src, filename_tgt, filename): + with open(os.path.join(data_dir, filename_src), 'r') as src_f, \ + open(os.path.join(data_dir, filename_tgt), 'r') as tgt_f, \ + open(os.path.join(data_dir, filename), 'w') as h: + for src, tgt in zip(src_f, tgt_f): + src_len = len(src.split()) + tgt_len = len(tgt.split()) + avg_len = (src_len + tgt_len) // 2 + num_alignments = random.randint(avg_len // 2, 2 * avg_len) + src_indices = torch.floor(torch.rand(num_alignments) * src_len).int() + tgt_indices = torch.floor(torch.rand(num_alignments) * tgt_len).int() + ex_str = ' '.join(["{}-{}".format(src, tgt) for src, tgt in zip(src_indices, tgt_indices)]) + print(ex_str, file=h) + _create_dummy_data('train.in') _create_dummy_data('train.out') _create_dummy_data('valid.in') @@ -504,6 +539,10 @@ def _create_dummy_data(filename): _create_dummy_data('test.in') _create_dummy_data('test.out') + if alignment: + _create_dummy_alignment_data('train.in', 'train.out', 'train.align') + _create_dummy_alignment_data('valid.in', 'valid.out', 'valid.align') + _create_dummy_alignment_data('test.in', 'test.out', 'test.align') def preprocess_translation_data(data_dir, extra_flags=None): preprocess_parser = options.get_preprocessing_parser()