diff --git a/README.md b/README.md index c39ff22c97..e05af20c6c 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,7 @@ Fairseq provides reference implementations of various sequence-to-sequence model - [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md) - [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) - [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) + - [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md ) - **Non-autoregressive Transformers** - Non-Autoregressive Neural Machine Translation (Gu et al., 2017) - Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018) @@ -100,6 +101,7 @@ as well as example training and evaluation commands. - [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available We also have more detailed READMEs to reproduce results from specific papers: +- [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md ) - [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md) - [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) - [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) diff --git a/examples/joint_alignment_translation/README.md b/examples/joint_alignment_translation/README.md new file mode 100644 index 0000000000..cd9c0ea65f --- /dev/null +++ b/examples/joint_alignment_translation/README.md @@ -0,0 +1,89 @@ +# Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019) + +This page includes instructions for training models described in [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](https://arxiv.org/abs/1909.02074). + +## Training a joint alignment-translation model on WMT'18 En-De + +##### 1. Extract and preprocess the WMT'18 En-De data +```bash +./prepare-wmt18en2de_no_norm_no_escape_no_agressive.sh +``` + +##### 2. Generate alignments from statistical alignment toolkits e.g. Giza++/FastAlign. +In this example, we use FastAlign. +```bash +git clone git@github.com:clab/fast_align.git +pushd fast_align +mkdir build +cd build +cmake .. +make +popd +ALIGN=fast_align/build/fast_align +paste bpe.32k/train.en bpe.32k/train.de | awk -F '\t' '{print $1 " ||| " $2}' > bpe.32k/train.en-de +$ALIGN -i bpe.32k/train.en-de -d -o -v > bpe.32k/train.align +``` + +##### 3. Preprocess the dataset with the above generated alignments. +```bash +fairseq-preprocess \ + --source-lang en --target-lang de \ + --trainpref bpe.32k/train \ + --validpref bpe.32k/valid \ + --testpref bpe.32k/test \ + --align-suffix align \ + --destdir binarized/ \ + --joined-dictionary \ + --workers 32 +``` + +##### 4. 
Train a model +```bash +fairseq-train \ + binarized \ + --arch transformer_wmt_en_de_big_align --share-all-embeddings \ + --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --activation-fn relu\ + --lr 0.0002 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \ + --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ + --max-tokens 3500 --label-smoothing 0.1 \ + --save-dir ./checkpoints --log-interval 1000 --max-update 60000 \ + --keep-interval-updates -1 --save-interval-updates 0 \ + --load-alignments --criterion label_smoothed_cross_entropy_with_alignment \ + --fp16 +``` + +Note that the `--fp16` flag requires you have CUDA 9.1 or greater and a Volta GPU or newer. + +If you want to train the above model with big batches (assuming your machine has 8 GPUs): +- add `--update-freq 8` to simulate training on 8x8=64 GPUs +- increase the learning rate; 0.0007 works well for big batches + +##### 5. Evaluate and generate the alignments (BPE level) +```bash +fairseq-generate \ + binarized --gen-subset test --print-alignment \ + --source-lang en --target-lang de \ + --path checkpoints/checkpoint_best.pt --beam 5 --nbest 1 +``` + +##### 6. Other resources. +The code for: +1. preparing alignment test sets +2. converting BPE level alignments to token level alignments +3. symmetrizing bidirectional alignments +4. evaluating alignments using AER metric +can be found [here](https://github.com/lilt/alignment-scripts) + +## Citation + +```bibtex +@inproceedings{garg2019jointly, + title = {Jointly Learning to Align and Translate with Transformer Models}, + author = {Garg, Sarthak and Peitz, Stephan and Nallasamy, Udhyakumar and Paulik, Matthias}, + booktitle = {Conference on Empirical Methods in Natural Language Processing (EMNLP)}, + address = {Hong Kong}, + month = {November}, + url = {https://arxiv.org/abs/1909.02074}, + year = {2019}, +} +``` diff --git a/examples/joint_alignment_translation/prepare-wmt18en2de_no_norm_no_escape_no_agressive.sh b/examples/joint_alignment_translation/prepare-wmt18en2de_no_norm_no_escape_no_agressive.sh new file mode 100755 index 0000000000..e78ed66a15 --- /dev/null +++ b/examples/joint_alignment_translation/prepare-wmt18en2de_no_norm_no_escape_no_agressive.sh @@ -0,0 +1,118 @@ +#!/bin/bash + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +echo 'Cloning Moses github repository (for tokenization scripts)...' +git clone https://github.com/moses-smt/mosesdecoder.git + +SCRIPTS=mosesdecoder/scripts +TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl +CLEAN=$SCRIPTS/training/clean-corpus-n.perl +REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl + +URLS=( + "http://statmt.org/wmt13/training-parallel-europarl-v7.tgz" + "http://statmt.org/wmt13/training-parallel-commoncrawl.tgz" + "http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz" + "http://data.statmt.org/wmt18/translation-task/rapid2016.tgz" + "http://data.statmt.org/wmt17/translation-task/dev.tgz" + "http://statmt.org/wmt14/test-full.tgz" +) +CORPORA=( + "training/europarl-v7.de-en" + "commoncrawl.de-en" + "training-parallel-nc-v13/news-commentary-v13.de-en" + "rapid2016.de-en" +) + +if [ ! -d "$SCRIPTS" ]; then + echo "Please set SCRIPTS variable correctly to point to Moses scripts." 
+ exit +fi + +src=en +tgt=de +lang=en-de +prep=wmt18_en_de +tmp=$prep/tmp +orig=orig +dev=dev/newstest2012 +codes=32000 +bpe=bpe.32k + +mkdir -p $orig $tmp $prep $bpe + +cd $orig + +for ((i=0;i<${#URLS[@]};++i)); do + url=${URLS[i]} + file=$(basename $url) + if [ -f $file ]; then + echo "$file already exists, skipping download" + else + wget "$url" + if [ -f $file ]; then + echo "$url successfully downloaded." + else + echo "$url not successfully downloaded." + exit 1 + fi + if [ ${file: -4} == ".tgz" ]; then + tar zxvf $file + elif [ ${file: -4} == ".tar" ]; then + tar xvf $file + fi + fi +done +cd .. + +echo "pre-processing train data..." +for l in $src $tgt; do + rm -rf $tmp/train.tags.$lang.tok.$l + for f in "${CORPORA[@]}"; do + cat $orig/$f.$l | \ + perl $REM_NON_PRINT_CHAR | \ + perl $TOKENIZER -threads 8 -l $l -no-escape >> $tmp/train.tags.$lang.tok.$l + done +done + +echo "pre-processing test data..." +for l in $src $tgt; do + if [ "$l" == "$src" ]; then + t="src" + else + t="ref" + fi + grep '\s*//g' | \ + sed -e 's/\s*<\/seg>\s*//g' | \ + sed -e "s/\’/\'/g" | \ + perl $TOKENIZER -threads 8 -l $l -no-escape > $tmp/test.$l + echo "" +done + +# apply length filtering before BPE +perl $CLEAN -ratio 1.5 $tmp/train.tags.$lang.tok $src $tgt $tmp/train 1 100 + +# use newstest2012 for valid +echo "pre-processing valid data..." +for l in $src $tgt; do + rm -rf $tmp/valid.$l + cat $orig/$dev.$l | \ + perl $REM_NON_PRINT_CHAR | \ + perl $TOKENIZER -threads 8 -l $l -no-escape >> $tmp/valid.$l +done + +mkdir output +mv $tmp/{train,valid,test}.{$src,$tgt} output + +#BPE +git clone git@github.com:glample/fastBPE.git +pushd fastBPE +g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast +popd +fastBPE/fast learnbpe $codes output/train.$src output/train.$tgt > $bpe/codes +for split in {train,valid,test}; do for lang in {en,de}; do fastBPE/fast applybpe $bpe/$split.$lang output/$split.$lang $bpe/codes; done; done diff --git a/fairseq/binarizer.py b/fairseq/binarizer.py index 44dcb256c4..744c5e3fc8 100644 --- a/fairseq/binarizer.py +++ b/fairseq/binarizer.py @@ -52,6 +52,22 @@ def replaced_consumer(word, idx): line = f.readline() return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': replaced} + @staticmethod + def binarize_alignments(filename, alignment_parser, consumer, offset=0, end=-1): + nseq = 0 + + with open(filename, 'r') as f: + f.seek(offset) + line = safe_readline(f) + while line: + if end > 0 and f.tell() > end: + break + ids = alignment_parser(line) + nseq += 1 + consumer(ids) + line = f.readline() + return {'nseq': nseq} + @staticmethod def find_offsets(filename, num_chunks): with open(filename, 'r', encoding='utf-8') as f: diff --git a/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py b/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py new file mode 100644 index 0000000000..2cb5621498 --- /dev/null +++ b/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py @@ -0,0 +1,90 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +from fairseq import utils + +from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion +from . 
import register_criterion + + +@register_criterion('label_smoothed_cross_entropy_with_alignment') +class LabelSmoothedCrossEntropyCriterionWithAlignment(LabelSmoothedCrossEntropyCriterion): + + def __init__(self, args, task): + super().__init__(args, task) + self.alignment_lambda = args.alignment_lambda + + @staticmethod + def add_args(parser): + """Add criterion-specific arguments to the parser.""" + super(LabelSmoothedCrossEntropyCriterionWithAlignment, + LabelSmoothedCrossEntropyCriterionWithAlignment).add_args(parser) + parser.add_argument('--alignment-lambda', default=0.05, type=float, metavar='D', + help='weight for the alignment loss') + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample['net_input']) + loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce) + sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] + logging_output = { + 'loss': utils.item(loss.data) if reduce else loss.data, + 'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data, + 'ntokens': sample['ntokens'], + 'nsentences': sample['target'].size(0), + 'sample_size': sample_size, + } + + alignment_loss = None + + # Compute alignment loss only for training set and non dummy batches. + if 'alignments' in sample and sample['alignments'] is not None: + alignment_loss = self.compute_alignment_loss(sample, net_output) + + if alignment_loss is not None: + logging_output['alignment_loss'] = utils.item(alignment_loss.data) + loss += self.alignment_lambda * alignment_loss + + return loss, sample_size, logging_output + + def compute_alignment_loss(self, sample, net_output): + attn_prob = net_output[1]['attn'] + bsz, tgt_sz, src_sz = attn_prob.shape + attn = attn_prob.view(bsz * tgt_sz, src_sz) + + align = sample['alignments'] + align_weights = sample['align_weights'].float() + + if len(align) > 0: + # Alignment loss computation. align (shape [:, 2]) contains the src-tgt index pairs corresponding to + # the alignments. align_weights (shape [:]) contains the 1 / frequency of a tgt index for normalizing. 
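+            # attn has been flattened above to shape (bsz * tgt_len, src_len), and the collater already added
+            # each sentence's row offset (plus any left-padding offsets) to the alignment indices, so align[:, 1]
+            # selects the target row and align[:, 0] the source column of the supervised attention probability.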
+ loss = -((attn[align[:, 1][:, None], align[:, 0][:, None]]).log() * align_weights[:, None]).sum() + else: + return None + + return loss + + @staticmethod + def aggregate_logging_outputs(logging_outputs): + """Aggregate logging outputs from data parallel training.""" + ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) + nsentences = sum(log.get('nsentences', 0) for log in logging_outputs) + sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) + return { + 'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2) if sample_size > 0 else 0., + 'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / ntokens / math.log(2) if ntokens > 0 else 0., + 'alignment_loss': sum(log.get('alignment_loss', 0) for log in logging_outputs) / sample_size / math.log(2) if sample_size > 0 else 0., + 'ntokens': ntokens, + 'nsentences': nsentences, + 'sample_size': sample_size, + } diff --git a/fairseq/data/language_pair_dataset.py b/fairseq/data/language_pair_dataset.py index 5fc1371aae..09c7193ab4 100644 --- a/fairseq/data/language_pair_dataset.py +++ b/fairseq/data/language_pair_dataset.py @@ -22,6 +22,28 @@ def merge(key, left_pad, move_eos_to_beginning=False): pad_idx, eos_idx, left_pad, move_eos_to_beginning, ) + def check_alignment(alignment, src_len, tgt_len): + if alignment is None or len(alignment) == 0: + return False + if alignment[:, 0].max().item() >= src_len - 1 or alignment[:, 1].max().item() >= tgt_len - 1: + print("| alignment size mismatch found, skipping alignment!") + return False + return True + + def compute_alignment_weights(alignments): + """ + Given a tensor of shape [:, 2] containing the source-target indices + corresponding to the alignments, a weight vector containing the + inverse frequency of each target index is computed. + For e.g. if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then + a tensor containing [1., 0.5, 0.5, 1] should be returned (since target + index 3 is repeated twice) + """ + align_tgt = alignments[:, 1] + _, align_tgt_i, align_tgt_c = torch.unique(align_tgt, return_inverse=True, return_counts=True) + align_weights = align_tgt_c[align_tgt_i[np.arange(len(align_tgt))]] + return 1. 
/ align_weights.float() + id = torch.LongTensor([s['id'] for s in samples]) src_tokens = merge('source', left_pad=left_pad_source) # sort by descending source length @@ -35,6 +57,7 @@ def merge(key, left_pad, move_eos_to_beginning=False): if samples[0].get('target', None) is not None: target = merge('target', left_pad=left_pad_target) target = target.index_select(0, sort_order) + tgt_lengths = torch.LongTensor([s['target'].numel() for s in samples]).index_select(0, sort_order) ntokens = sum(len(s['target']) for s in samples) if input_feeding: @@ -61,6 +84,32 @@ def merge(key, left_pad, move_eos_to_beginning=False): } if prev_output_tokens is not None: batch['net_input']['prev_output_tokens'] = prev_output_tokens + + if samples[0].get('alignment', None) is not None: + bsz, tgt_sz = batch['target'].shape + src_sz = batch['net_input']['src_tokens'].shape[1] + + offsets = torch.zeros((len(sort_order), 2), dtype=torch.long) + offsets[:, 1] += (torch.arange(len(sort_order), dtype=torch.long) * tgt_sz) + if left_pad_source: + offsets[:, 0] += (src_sz - src_lengths) + if left_pad_target: + offsets[:, 1] += (tgt_sz - tgt_lengths) + + alignments = [ + alignment + offset + for align_idx, offset, src_len, tgt_len in zip(sort_order, offsets, src_lengths, tgt_lengths) + for alignment in [samples[align_idx]['alignment'].view(-1, 2)] + if check_alignment(alignment, src_len, tgt_len) + ] + + if len(alignments) > 0: + alignments = torch.cat(alignments, dim=0) + align_weights = compute_alignment_weights(alignments) + + batch['alignments'] = alignments + batch['align_weights'] = align_weights + return batch @@ -91,6 +140,8 @@ class LanguagePairDataset(FairseqDataset): of source if it's present (default: False). append_eos_to_target (bool, optional): if set, appends eos to end of target if it's absent (default: False). + align_dataset (torch.utils.data.Dataset, optional): dataset + containing alignments. 
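+            Each item is expected to be a flattened IntTensor of src-tgt
+            index pairs (as produced by ``utils.parse_alignment``); the
+            collater reshapes it to ``[:, 2]`` and offsets the indices.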
""" def __init__( @@ -98,7 +149,9 @@ def __init__( tgt=None, tgt_sizes=None, tgt_dict=None, left_pad_source=True, left_pad_target=False, max_source_positions=1024, max_target_positions=1024, - shuffle=True, input_feeding=True, remove_eos_from_source=False, append_eos_to_target=False, + shuffle=True, input_feeding=True, + remove_eos_from_source=False, append_eos_to_target=False, + align_dataset=None, ): if tgt_dict is not None: assert src_dict.pad() == tgt_dict.pad() @@ -118,6 +171,9 @@ def __init__( self.input_feeding = input_feeding self.remove_eos_from_source = remove_eos_from_source self.append_eos_to_target = append_eos_to_target + self.align_dataset = align_dataset + if self.align_dataset is not None: + assert self.tgt_sizes is not None, "Both source and target needed when alignments are provided" def __getitem__(self, index): tgt_item = self.tgt[index] if self.tgt is not None else None @@ -136,11 +192,14 @@ def __getitem__(self, index): if self.src[index][-1] == eos: src_item = self.src[index][:-1] - return { + example = { 'id': index, 'source': src_item, 'target': tgt_item, } + if self.align_dataset is not None: + example['alignment'] = self.align_dataset[index] + return example def __len__(self): return len(self.src) @@ -212,3 +271,5 @@ def prefetch(self, indices): self.src.prefetch(indices) if self.tgt is not None: self.tgt.prefetch(indices) + if self.align_dataset is not None: + self.align_dataset.prefetch(indices) diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index fc53a7c9d7..674de01310 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -222,6 +222,9 @@ def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs): decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs) return decoder_out + def forward_decoder(self, prev_output_tokens, **kwargs): + return self.decoder(prev_output_tokens, **kwargs) + def extract_features(self, src_tokens, src_lengths, prev_output_tokens, **kwargs): """ Similar to *forward* but only return features. diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 910c2eda09..f5f23f1b95 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -68,6 +68,7 @@ def hub_models(cls): def __init__(self, encoder, decoder): super().__init__(encoder, decoder) + self.supports_align_args = True @staticmethod def add_args(parser): @@ -195,6 +196,69 @@ def build_decoder(cls, args, tgt_dict, embed_tokens): ) +@register_model('transformer_align') +class TransformerAlignModel(TransformerModel): + """ + See "Jointly Learning to Align and Translate with Transformer + Models" (Garg et al., EMNLP 2019). + """ + + def __init__(self, encoder, decoder, args): + super().__init__(encoder, decoder) + self.alignment_heads = args.alignment_heads + self.alignment_layer = args.alignment_layer + self.full_context_alignment = args.full_context_alignment + + @staticmethod + def add_args(parser): + # fmt: off + super(TransformerAlignModel, TransformerAlignModel).add_args(parser) + parser.add_argument('--alignment-heads', type=int, metavar='D', + help='Number of cross attention heads per layer to supervised with alignments') + parser.add_argument('--alignment-layer', type=int, metavar='D', + help='Layer number which has to be supervised. 
0 corresponding to the bottommost layer.') + parser.add_argument('--full-context-alignment', type=bool, metavar='D', + help='Whether or not alignment is supervised conditioned on the full target context.') + # fmt: on + + @classmethod + def build_model(cls, args, task): + # set any default arguments + transformer_align(args) + + transformer_model = TransformerModel.build_model(args, task) + return TransformerAlignModel(transformer_model.encoder, transformer_model.decoder, args) + + def forward(self, src_tokens, src_lengths, prev_output_tokens): + encoder_out = self.encoder(src_tokens, src_lengths) + return self.forward_decoder(prev_output_tokens, encoder_out) + + def forward_decoder( + self, + prev_output_tokens, + encoder_out=None, + incremental_state=None, + features_only=False, + **extra_args, + ): + attn_args = {'alignment_layer': self.alignment_layer, 'alignment_heads': self.alignment_heads} + decoder_out = self.decoder( + prev_output_tokens, + encoder_out, + **attn_args, + **extra_args, + ) + + if self.full_context_alignment: + attn_args['full_context_alignment'] = self.full_context_alignment + _, alignment_out = self.decoder( + prev_output_tokens, encoder_out, features_only=True, **attn_args, **extra_args, + ) + decoder_out[1]['attn'] = alignment_out['attn'] + + return decoder_out + + class TransformerEncoder(FairseqEncoder): """ Transformer encoder consisting of *args.encoder_layers* layers. Each layer @@ -423,7 +487,14 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): else: self.layer_norm = None - def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused): + def forward( + self, + prev_output_tokens, + encoder_out=None, + incremental_state=None, + features_only=False, + **extra_args, + ): """ Args: prev_output_tokens (LongTensor): previous decoder outputs of shape @@ -432,25 +503,53 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, encoder-side attention incremental_state (dict): dictionary used for storing state during :ref:`Incremental decoding` + features_only (bool, optional): only return features without + applying output layer (default: False). Returns: tuple: - the decoder's output of shape `(batch, tgt_len, vocab)` - a dictionary with any model-specific outputs """ - x, extra = self.extract_features(prev_output_tokens, encoder_out, incremental_state) - x = self.output_layer(x) + x, extra = self.extract_features( + prev_output_tokens, encoder_out, incremental_state, **extra_args, + ) + if not features_only: + x = self.output_layer(x) return x, extra - def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused): + def extract_features( + self, + prev_output_tokens, + encoder_out=None, + incremental_state=None, + full_context_alignment=False, + alignment_layer=None, + alignment_heads=None, + **unused, + ): """ Similar to *forward* but only return features. + Includes several features from "Jointly Learning to Align and + Translate with Transformer Models" (Garg et al., EMNLP 2019). + + Args: + full_context_alignment (bool, optional): don't apply + auto-regressive mask to self-attention (default: False). + alignment_layer (int, optional): return mean alignment over + heads at this layer (default: last layer). + alignment_heads (int, optional): only average alignment over + this many heads (default: all heads). 
+ Returns: tuple: - the decoder's features of shape `(batch, tgt_len, embed_dim)` - a dictionary with any model-specific outputs """ + if alignment_layer is None: + alignment_layer = len(self.layers) - 1 + # embed positions positions = self.embed_positions( prev_output_tokens, @@ -474,15 +573,14 @@ def extract_features(self, prev_output_tokens, encoder_out=None, incremental_sta # B x T x C -> T x B x C x = x.transpose(0, 1) - attn = None - - inner_states = [x] self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) if not self_attn_padding_mask.any() and not self.cross_self_attention: self_attn_padding_mask = None # decoder layers + attn = None + inner_states = [x] for idx, layer in enumerate(self.layers): encoder_state = None if encoder_out is not None: @@ -491,15 +589,32 @@ def extract_features(self, prev_output_tokens, encoder_out=None, incremental_sta else: encoder_state = encoder_out['encoder_out'] - x, attn = layer( + if incremental_state is None and not full_context_alignment: + self_attn_mask = self.buffered_future_mask(x) + else: + self_attn_mask = None + + x, layer_attn = layer( x, encoder_state, encoder_out['encoder_padding_mask'] if encoder_out is not None else None, incremental_state, - self_attn_mask=self.buffered_future_mask(x) if incremental_state is None else None, + self_attn_mask=self_attn_mask, self_attn_padding_mask=self_attn_padding_mask, + need_attn=(idx == alignment_layer), + need_head_weights=(idx == alignment_layer), ) + inner_states.append(x) + if layer_attn is not None and idx == alignment_layer: + attn = layer_attn.float() + + if attn is not None: + if alignment_heads is not None: + attn = attn[:alignment_heads] + + # average probabilities over heads + attn = attn.mean(dim=0) if self.layer_norm: x = self.layer_norm(x) @@ -531,7 +646,12 @@ def max_positions(self): def buffered_future_mask(self, tensor): dim = tensor.size(0) - if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device or self._future_mask.size(0) < dim: + if ( + not hasattr(self, '_future_mask') + or self._future_mask is None + or self._future_mask.device != tensor.device + or self._future_mask.size(0) < dim + ): self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1) return self._future_mask[:dim, :dim] @@ -668,3 +788,18 @@ def transformer_wmt_en_de_big_t2t(args): args.attention_dropout = getattr(args, 'attention_dropout', 0.1) args.activation_dropout = getattr(args, 'activation_dropout', 0.1) transformer_vaswani_wmt_en_de_big(args) + + +@register_model_architecture('transformer_align', 'transformer_align') +def transformer_align(args): + args.alignment_heads = getattr(args, 'alignment_heads', 1) + args.alignment_layer = getattr(args, 'alignment_layer', 4) + args.full_context_alignment = getattr(args, 'full_context_alignment', False) + base_architecture(args) + + +@register_model_architecture('transformer_align', 'transformer_wmt_en_de_big_align') +def transformer_wmt_en_de_big_align(args): + args.alignment_heads = getattr(args, 'alignment_heads', 1) + args.alignment_layer = getattr(args, 'alignment_layer', 4) + transformer_wmt_en_de_big(args) diff --git a/fairseq/modules/multihead_attention.py b/fairseq/modules/multihead_attention.py index 9aaea82484..96849790f6 100644 --- a/fairseq/modules/multihead_attention.py +++ b/fairseq/modules/multihead_attention.py @@ -90,15 +90,37 @@ def reset_parameters(self): if self.bias_v is not None: nn.init.xavier_normal_(self.bias_v) - def forward(self, query, key, value, 
key_padding_mask=None, incremental_state=None, - need_weights=True, static_kv=False, attn_mask=None, before_softmax=False): + def forward( + self, + query, key, value, + key_padding_mask=None, + incremental_state=None, + need_weights=True, + static_kv=False, + attn_mask=None, + before_softmax=False, + need_head_weights=False, + ): """Input shape: Time x Batch x Channel - Timesteps can be masked by supplying a T x T mask in the - `attn_mask` argument. Padding elements can be excluded from - the key by passing a binary ByteTensor (`key_padding_mask`) with shape: - batch x src_len, where padding elements are indicated by 1s. + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. """ + if need_head_weights: + need_weights = True + tgt_len, bsz, embed_dim = query.size() assert embed_dim == self.embed_dim assert list(query.size()) == [tgt_len, bsz, embed_dim] @@ -249,12 +271,11 @@ def forward(self, query, key, value, key_padding_mask=None, incremental_state=No if before_softmax: return attn_weights, v - attn_weights = utils.softmax( - attn_weights, dim=-1, onnx_trace=self.onnx_trace, - ).type_as(attn_weights) - attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training) + attn_weights_float = utils.softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training) - attn = torch.bmm(attn_weights, v) + attn = torch.bmm(attn_probs, v) assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] if (self.onnx_trace and attn.size(1) == 1): # when ONNX tracing a single decoder step (sequence length == 1) @@ -265,9 +286,10 @@ def forward(self, query, key, value, key_padding_mask=None, incremental_state=No attn = self.out_proj(attn) if need_weights: - # average attention weights over heads - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.sum(dim=1) / self.num_heads + attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) else: attn_weights = None diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py index 63c6cdf552..a3579fb990 100644 --- a/fairseq/modules/transformer_layer.py +++ b/fairseq/modules/transformer_layer.py @@ -195,16 +195,25 @@ def forward( prev_attn_state=None, self_attn_mask=None, self_attn_padding_mask=None, + need_attn=False, + need_head_weights=False, ): """ Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` - encoder_padding_mask (ByteTensor): binary ByteTensor of shape - `(batch, src_len)` where padding elements are indicated by ``1``. 
+ encoder_padding_mask (ByteTensor, optional): binary + ByteTensor of shape `(batch, src_len)` where padding + elements are indicated by ``1``. + need_attn (bool, optional): return attention weights + need_head_weights (bool, optional): return attention weights + for each head (default: return average over heads). Returns: encoded output of shape `(seq_len, batch, embed_dim)` """ + if need_head_weights: + need_attn = True + residual = x x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True) if prev_self_attn_state is not None: @@ -259,7 +268,8 @@ def forward( key_padding_mask=encoder_padding_mask, incremental_state=incremental_state, static_kv=True, - need_weights=(not self.training and self.need_attn), + need_weights=need_attn or (not self.training and self.need_attn), + need_head_weights=need_head_weights, ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x diff --git a/fairseq/options.py b/fairseq/options.py index bb1e27aeb7..06a52b62ba 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -224,6 +224,8 @@ def add_preprocess_args(parser): help="comma separated, valid file prefixes") group.add_argument("--testpref", metavar="FP", default=None, help="comma separated, test file prefixes") + group.add_argument("--align-suffix", metavar="FP", default=None, + help="alignment file suffix") group.add_argument("--destdir", metavar="DIR", default="data-bin", help="destination dir") group.add_argument("--thresholdtgt", metavar="N", default=0, type=int, diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index 3b100b9615..dd3fb86f7b 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -7,7 +7,8 @@ import torch -from fairseq import search +from fairseq import search, utils +from fairseq.data import data_utils from fairseq.models import FairseqIncrementalDecoder @@ -81,7 +82,6 @@ def __init__( self.temperature = temperature self.match_source_len = match_source_len self.no_repeat_ngram_size = no_repeat_ngram_size - assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling' assert sampling_topp < 0 or sampling, '--sampling-topp requires --sampling' assert temperature > 0, '--temperature must be greater than 0' @@ -98,14 +98,7 @@ def __init__( self.search = search.BeamSearch(tgt_dict) @torch.no_grad() - def generate( - self, - models, - sample, - prefix_tokens=None, - bos_token=None, - **kwargs - ): + def generate(self, models, sample, **kwargs): """Generate a batch of translations. Args: @@ -113,8 +106,21 @@ def generate( sample (dict): batch prefix_tokens (torch.LongTensor, optional): force decoder to begin with these tokens + bos_token (int, optional): beginning of sentence token + (default: self.eos) """ model = EnsembleModel(models) + return self._generate(model, sample, **kwargs) + + @torch.no_grad() + def _generate( + self, + model, + sample, + prefix_tokens=None, + bos_token=None, + **kwargs + ): if not self.retain_dropout: model.eval() @@ -155,7 +161,6 @@ def generate( tokens_buf = tokens.clone() tokens[:, 0] = self.eos if bos_token is None else bos_token attn, attn_buf = None, None - nonpad_idxs = None # The blacklist indicates candidates that should be ignored. 
# For example, suppose we're sampling and have already finalized 2/5 @@ -251,17 +256,15 @@ def get_hypo(): if attn_clone is not None: # remove padding tokens from attn scores - hypo_attn = attn_clone[i][nonpad_idxs[sent]] - _, alignment = hypo_attn.max(dim=0) + hypo_attn = attn_clone[i] else: hypo_attn = None - alignment = None return { 'tokens': tokens_clone[i], 'score': score, 'attention': hypo_attn, # src_len x tgt_len - 'alignment': alignment, + 'alignment': None, 'positional_scores': pos_scores[i], } @@ -345,7 +348,6 @@ def replicate_first_beam(tensor, mask): if attn is None: attn = scores.new(bsz * beam_size, src_tokens.size(1), max_len + 2) attn_buf = attn.clone() - nonpad_idxs = src_tokens.ne(self.pad) attn[:, :, step + 1].copy_(avg_attn_scores) scores = scores.type_as(lprobs) @@ -512,7 +514,6 @@ def calculate_banned_tokens(bbsz_idx): # sort by score descending for sent in range(len(finalized)): finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True) - return finalized @@ -577,9 +578,11 @@ def _decode_one( temperature=1., ): if self.incremental_states is not None: - decoder_out = list(model.decoder(tokens, encoder_out, incremental_state=self.incremental_states[model])) + decoder_out = list(model.forward_decoder( + tokens, encoder_out=encoder_out, incremental_state=self.incremental_states[model], + )) else: - decoder_out = list(model.decoder(tokens, encoder_out)) + decoder_out = list(model.forward_decoder(tokens, encoder_out=encoder_out)) decoder_out[0] = decoder_out[0][:, -1:, :] if temperature != 1.: decoder_out[0].div_(temperature) @@ -605,3 +608,104 @@ def reorder_incremental_state(self, new_order): return for model in self.models: model.decoder.reorder_incremental_state(self.incremental_states[model], new_order) + + +class SequenceGeneratorWithAlignment(SequenceGenerator): + + def __init__(self, tgt_dict, left_pad_target=False, **kwargs): + """Generates translations of a given source sentence. + + Produces alignments following "Jointly Learning to Align and + Translate with Transformer Models" (Garg et al., EMNLP 2019). + + Args: + left_pad_target (bool, optional): Whether or not the + hypothesis should be left padded or not when they are + teacher forced for generating alignments. + """ + super().__init__(tgt_dict, **kwargs) + self.left_pad_target = left_pad_target + + @torch.no_grad() + def generate(self, models, sample, **kwargs): + model = EnsembleModelWithAlignment(models) + finalized = super()._generate(model, sample, **kwargs) + + src_tokens = sample['net_input']['src_tokens'] + bsz = src_tokens.shape[0] + beam_size = self.beam_size + src_tokens, src_lengths, prev_output_tokens, tgt_tokens = \ + self._prepare_batch_for_alignment(sample, finalized) + if any(getattr(m, 'full_context_alignment', False) for m in model.models): + attn = model.forward_align(src_tokens, src_lengths, prev_output_tokens) + else: + attn = [ + finalized[i // beam_size][i % beam_size]['attention'].transpose(1, 0) + for i in range(bsz * beam_size) + ] + + # Process the attn matrix to extract hard alignments. 
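+        # For every non-pad/eos target position, extract_hard_alignment (added in fairseq/utils.py in this
+        # patch) masks out pad/eos source positions and takes the argmax of the attention probabilities over
+        # the remaining source positions, yielding 0-indexed src-tgt index pairs.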
+ for i in range(bsz * beam_size): + alignment = utils.extract_hard_alignment(attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos) + finalized[i // beam_size][i % beam_size]['alignment'] = alignment + return finalized + + def _prepare_batch_for_alignment(self, sample, hypothesis): + src_tokens = sample['net_input']['src_tokens'] + bsz = src_tokens.shape[0] + src_tokens = src_tokens[:, None, :].expand(-1, self.beam_size, -1).contiguous().view(bsz * self.beam_size, -1) + src_lengths = sample['net_input']['src_lengths'] + src_lengths = src_lengths[:, None].expand(-1, self.beam_size).contiguous().view(bsz * self.beam_size) + prev_output_tokens = data_utils.collate_tokens( + [beam['tokens'] for example in hypothesis for beam in example], + self.pad, self.eos, self.left_pad_target, move_eos_to_beginning=True, + ) + tgt_tokens = data_utils.collate_tokens( + [beam['tokens'] for example in hypothesis for beam in example], + self.pad, self.eos, self.left_pad_target, move_eos_to_beginning=False, + ) + return src_tokens, src_lengths, prev_output_tokens, tgt_tokens + + +class EnsembleModelWithAlignment(EnsembleModel): + """A wrapper around an ensemble of models.""" + + def __init__(self, models): + super().__init__(models) + + def forward_align(self, src_tokens, src_lengths, prev_output_tokens): + avg_attn = None + for model in self.models: + decoder_out = model(src_tokens, src_lengths, prev_output_tokens) + attn = decoder_out[1]['attn'] + if avg_attn is None: + avg_attn = attn + else: + avg_attn.add_(attn) + if len(self.models) > 1: + avg_attn.div_(len(self.models)) + return avg_attn + + def _decode_one( + self, tokens, model, encoder_out, incremental_states, log_probs, + temperature=1., + ): + if self.incremental_states is not None: + decoder_out = list(model.forward_decoder( + tokens, + encoder_out=encoder_out, + incremental_state=self.incremental_states[model], + )) + else: + decoder_out = list(model.forward_decoder(tokens, encoder_out=encoder_out)) + decoder_out[0] = decoder_out[0][:, -1:, :] + if temperature != 1.: + decoder_out[0].div_(temperature) + attn = decoder_out[1] + if type(attn) is dict: + attn = attn.get('attn', None) + if attn is not None: + attn = attn[:, -1, :] + probs = model.get_normalized_probs(decoder_out, log_probs=log_probs) + probs = probs[:, -1, :] + return probs, attn diff --git a/fairseq/sequence_scorer.py b/fairseq/sequence_scorer.py index d125422340..75ff4cf051 100644 --- a/fairseq/sequence_scorer.py +++ b/fairseq/sequence_scorer.py @@ -14,6 +14,7 @@ class SequenceScorer(object): def __init__(self, tgt_dict, softmax_batch=None): self.pad = tgt_dict.pad() + self.eos = tgt_dict.eos() self.softmax_batch = softmax_batch or sys.maxsize assert self.softmax_batch > 0 @@ -44,6 +45,7 @@ def gather_target_probs(probs, target): ) return probs + orig_target = sample['target'] # compute scores for each model in the ensemble @@ -53,6 +55,8 @@ def gather_target_probs(probs, target): model.eval() decoder_out = model.forward(**net_input) attn = decoder_out[1] + if type(attn) is dict: + attn = attn.get('attn', None) batched = batch_for_softmax(decoder_out, orig_target) probs, idx = None, 0 @@ -100,8 +104,9 @@ def gather_target_probs(probs, target): avg_probs_i = avg_probs[i][start_idxs[i]:start_idxs[i] + tgt_len] score_i = avg_probs_i.sum() / tgt_len if avg_attn is not None: - avg_attn_i = avg_attn[i, start_idxs[i]:] - _, alignment = avg_attn_i.max(dim=0) + avg_attn_i = avg_attn[i] + alignment = utils.extract_hard_alignment(avg_attn_i, sample['net_input']['src_tokens'][i], + 
sample['target'][i], self.pad, self.eos) else: avg_attn_i = alignment = None hypos.append([{ diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index ba5695785d..538532b20e 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -198,8 +198,12 @@ def build_generator(self, args): from fairseq.sequence_scorer import SequenceScorer return SequenceScorer(self.target_dictionary) else: - from fairseq.sequence_generator import SequenceGenerator - return SequenceGenerator( + from fairseq.sequence_generator import SequenceGenerator, SequenceGeneratorWithAlignment + if getattr(args, 'print_alignment', False): + seq_gen_cls = SequenceGeneratorWithAlignment + else: + seq_gen_cls = SequenceGenerator + return seq_gen_cls( self.target_dictionary, beam_size=getattr(args, 'beam', 5), max_len_a=getattr(args, 'max_len_a', 0), diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py index f3d60403ba..353e640bf6 100644 --- a/fairseq/tasks/translation.py +++ b/fairseq/tasks/translation.py @@ -24,7 +24,7 @@ def load_langpair_dataset( tgt, tgt_dict, combine, dataset_impl, upsample_primary, left_pad_source, left_pad_target, max_source_positions, - max_target_positions, prepend_bos=False, + max_target_positions, prepend_bos=False, load_alignments=False, ): def split_exists(split, src, tgt, lang, data_path): filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang)) @@ -74,6 +74,12 @@ def split_exists(split, src, tgt, lang, data_path): src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) + align_dataset = None + if load_alignments: + align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt)) + if indexed_dataset.dataset_exists(align_path, impl=dataset_impl): + align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl) + return LanguagePairDataset( src_dataset, src_dataset.sizes, src_dict, tgt_dataset, tgt_dataset.sizes, tgt_dict, @@ -81,6 +87,7 @@ def split_exists(split, src, tgt, lang, data_path): left_pad_target=left_pad_target, max_source_positions=max_source_positions, max_target_positions=max_target_positions, + align_dataset=align_dataset, ) @@ -120,6 +127,8 @@ def add_args(parser): help='load the dataset lazily') parser.add_argument('--raw-text', action='store_true', help='load raw text dataset') + parser.add_argument('--load-alignments', action='store_true', + help='load the binarized alignments') parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL', help='pad the source on the left') parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL', @@ -193,6 +202,7 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs): left_pad_target=self.args.left_pad_target, max_source_positions=self.args.max_source_positions, max_target_positions=self.args.max_target_positions, + load_alignments=self.args.load_alignments, ) def build_dataset_for_inference(self, src_tokens, src_lengths): diff --git a/fairseq/utils.py b/fairseq/utils.py index 80ecb6d083..9dd41fbfea 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -16,6 +16,7 @@ import torch import torch.nn.functional as F +from itertools import accumulate from fairseq.modules import gelu, gelu_accurate @@ -367,3 +368,47 @@ def set_torch_seed(seed): assert isinstance(seed, int) torch.manual_seed(seed) torch.cuda.manual_seed(seed) + + +def parse_alignment(line): + """ + Parses a single line from the alingment 
file. + + Args: + line (str): String containing the alignment of the format: + - - .. + -. All indices are 0 indexed. + + Returns: + torch.IntTensor: packed alignments of shape (2 * m). + """ + alignments = line.strip().split() + parsed_alignment = torch.IntTensor(2 * len(alignments)) + for idx, alignment in enumerate(alignments): + src_idx, tgt_idx = alignment.split('-') + parsed_alignment[2 * idx] = int(src_idx) + parsed_alignment[2 * idx + 1] = int(tgt_idx) + return parsed_alignment + + +def get_token_to_word_mapping(tokens, exclude_list): + n = len(tokens) + word_start = [int(token not in exclude_list) for token in tokens] + word_idx = list(accumulate(word_start)) + token_to_word = {i: word_idx[i] for i in range(n)} + return token_to_word + + +def extract_hard_alignment(attn, src_sent, tgt_sent, pad, eos): + tgt_valid = ((tgt_sent != pad) & (tgt_sent != eos)).nonzero().squeeze(dim=-1) + src_invalid = ((src_sent == pad) | (src_sent == eos)).nonzero().squeeze(dim=-1) + src_token_to_word = get_token_to_word_mapping(src_sent, [eos, pad]) + tgt_token_to_word = get_token_to_word_mapping(tgt_sent, [eos, pad]) + alignment = [] + if len(tgt_valid) != 0 and len(src_invalid) < len(src_sent): + attn_valid = attn[tgt_valid] + attn_valid[:, src_invalid] = float('-inf') + _, src_indices = attn_valid.max(dim=1) + for tgt_idx, src_idx in zip(tgt_valid, src_indices): + alignment.append((src_token_to_word[src_idx.item()] - 1, tgt_token_to_word[tgt_idx.item()] - 1)) + return alignment diff --git a/generate.py b/generate.py index 6de1a69abd..aba611d4b0 100644 --- a/generate.py +++ b/generate.py @@ -137,7 +137,7 @@ def main(args): hypo_tokens, hypo_str, alignment = utils.post_process_prediction( hypo_tokens=hypo['tokens'].int().cpu(), src_str=src_str, - alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None, + alignment=hypo['alignment'], align_dict=align_dict, tgt_dict=tgt_dict, remove_bpe=args.remove_bpe, @@ -156,7 +156,7 @@ def main(args): if args.print_alignment: print('A-{}\t{}'.format( sample_id, - ' '.join(map(lambda x: str(utils.item(x)), alignment)) + ' '.join(['{}-{}'.format(src_idx, tgt_idx) for src_idx, tgt_idx in alignment]) )) if args.print_step: @@ -180,6 +180,7 @@ def main(args): num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. 
/ gen_timer.avg)) if has_target: print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string())) + return scorer diff --git a/interactive.py b/interactive.py index d9d547a974..36e2bd0ca9 100644 --- a/interactive.py +++ b/interactive.py @@ -162,7 +162,7 @@ def decode_fn(x): hypo_tokens, hypo_str, alignment = utils.post_process_prediction( hypo_tokens=hypo['tokens'].int().cpu(), src_str=src_str, - alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None, + alignment=hypo['alignment'], align_dict=align_dict, tgt_dict=tgt_dict, remove_bpe=args.remove_bpe, @@ -174,9 +174,10 @@ def decode_fn(x): ' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist())) )) if args.print_alignment: + alignment_str = " ".join(["{}-{}".format(src, tgt) for src, tgt in alignment]) print('A-{}\t{}'.format( id, - ' '.join(map(lambda x: str(utils.item(x)), alignment)) + alignment_str )) # update running id counter diff --git a/preprocess.py b/preprocess.py index a157feeb68..538ff2b006 100644 --- a/preprocess.py +++ b/preprocess.py @@ -157,6 +157,60 @@ def merge_result(worker_result): ) ) + def make_binary_alignment_dataset(input_prefix, output_prefix, num_workers): + nseq = [0] + + def merge_result(worker_result): + nseq[0] += worker_result['nseq'] + + input_file = input_prefix + offsets = Binarizer.find_offsets(input_file, num_workers) + pool = None + if num_workers > 1: + pool = Pool(processes=num_workers - 1) + for worker_id in range(1, num_workers): + prefix = "{}{}".format(output_prefix, worker_id) + pool.apply_async( + binarize_alignments, + ( + args, + input_file, + utils.parse_alignment, + prefix, + offsets[worker_id], + offsets[worker_id + 1] + ), + callback=merge_result + ) + pool.close() + + ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, None, "bin"), + impl=args.dataset_impl) + + merge_result( + Binarizer.binarize_alignments( + input_file, utils.parse_alignment, lambda t: ds.add_item(t), + offset=0, end=offsets[1] + ) + ) + if num_workers > 1: + pool.join() + for worker_id in range(1, num_workers): + prefix = "{}{}".format(output_prefix, worker_id) + temp_file_path = dataset_dest_prefix(args, prefix, None) + ds.merge_file_(temp_file_path) + os.remove(indexed_dataset.data_file_path(temp_file_path)) + os.remove(indexed_dataset.index_file_path(temp_file_path)) + + ds.finalize(dataset_dest_file(args, output_prefix, None, "idx")) + + print( + "| [alignments] {}: parsed {} alignments".format( + input_file, + nseq[0] + ) + ) + def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1): if args.dataset_impl == "raw": # Copy original text file to destination folder @@ -180,9 +234,19 @@ def make_all(lang, vocab): outprefix = "test{}".format(k) if k > 0 else "test" make_dataset(vocab, testpref, outprefix, lang, num_workers=args.workers) + def make_all_alignments(): + if args.trainpref and os.path.exists(args.trainpref + "." + args.align_suffix): + make_binary_alignment_dataset(args.trainpref + "." + args.align_suffix, "train.align", num_workers=args.workers) + if args.validpref and os.path.exists(args.validpref + "." + args.align_suffix): + make_binary_alignment_dataset(args.validpref + "." + args.align_suffix, "valid.align", num_workers=args.workers) + if args.testpref and os.path.exists(args.testpref + "." + args.align_suffix): + make_binary_alignment_dataset(args.testpref + "." 
+ args.align_suffix, "test.align", num_workers=args.workers) + make_all(args.source_lang, src_dict) if target: make_all(args.target_lang, tgt_dict) + if args.align_suffix: + make_all_alignments() print("| Wrote preprocessed data to {}".format(args.destdir)) @@ -242,11 +306,28 @@ def consumer(tensor): return res +def binarize_alignments(args, filename, parse_alignment, output_prefix, offset, end): + ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, None, "bin"), + impl=args.dataset_impl, vocab_size=None) + + def consumer(tensor): + ds.add_item(tensor) + + res = Binarizer.binarize_alignments(filename, parse_alignment, consumer, offset=offset, + end=end) + ds.finalize(dataset_dest_file(args, output_prefix, None, "idx")) + return res + + def dataset_dest_prefix(args, output_prefix, lang): base = "{}/{}".format(args.destdir, output_prefix) - lang_part = ( - ".{}-{}.{}".format(args.source_lang, args.target_lang, lang) if lang is not None else "" - ) + if lang is not None: + lang_part = ".{}-{}.{}".format(args.source_lang, args.target_lang, lang) + elif args.only_source: + lang_part = "" + else: + lang_part = ".{}-{}".format(args.source_lang, args.target_lang) + return "{}{}".format(base, lang_part) diff --git a/tests/test_binaries.py b/tests/test_binaries.py index f77806bd6a..113901ab06 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -266,6 +266,27 @@ def test_mixture_of_experts(self): '--gen-expert', '0' ]) + def test_alignment(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory('test_alignment') as data_dir: + create_dummy_data(data_dir, alignment=True) + preprocess_translation_data(data_dir, ['--align-suffix', 'align']) + train_translation_model( + data_dir, + 'transformer_align', + [ + '--encoder-layers', '2', + '--decoder-layers', '2', + '--encoder-embed-dim', '8', + '--decoder-embed-dim', '8', + '--load-alignments', + '--alignment-layer', '1', + '--criterion', 'label_smoothed_cross_entropy_with_alignment' + ], + run_validation=True, + ) + generate_main(data_dir) + class TestStories(unittest.TestCase): @@ -484,7 +505,7 @@ def test_optimizers(self): generate_main(data_dir) -def create_dummy_data(data_dir, num_examples=1000, maxlen=20): +def create_dummy_data(data_dir, num_examples=1000, maxlen=20, alignment=False): def _create_dummy_data(filename): data = torch.rand(num_examples * maxlen) @@ -497,6 +518,20 @@ def _create_dummy_data(filename): print(ex_str, file=h) offset += ex_len + def _create_dummy_alignment_data(filename_src, filename_tgt, filename): + with open(os.path.join(data_dir, filename_src), 'r') as src_f, \ + open(os.path.join(data_dir, filename_tgt), 'r') as tgt_f, \ + open(os.path.join(data_dir, filename), 'w') as h: + for src, tgt in zip(src_f, tgt_f): + src_len = len(src.split()) + tgt_len = len(tgt.split()) + avg_len = (src_len + tgt_len) // 2 + num_alignments = random.randint(avg_len // 2, 2 * avg_len) + src_indices = torch.floor(torch.rand(num_alignments) * src_len).int() + tgt_indices = torch.floor(torch.rand(num_alignments) * tgt_len).int() + ex_str = ' '.join(["{}-{}".format(src, tgt) for src, tgt in zip(src_indices, tgt_indices)]) + print(ex_str, file=h) + _create_dummy_data('train.in') _create_dummy_data('train.out') _create_dummy_data('valid.in') @@ -504,6 +539,10 @@ def _create_dummy_data(filename): _create_dummy_data('test.in') _create_dummy_data('test.out') + if alignment: + _create_dummy_alignment_data('train.in', 'train.out', 'train.align') + _create_dummy_alignment_data('valid.in', 
'valid.out', 'valid.align') + _create_dummy_alignment_data('test.in', 'test.out', 'test.align') def preprocess_translation_data(data_dir, extra_flags=None): preprocess_parser = options.get_preprocessing_parser()
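
A note on consuming the new generation output: with `--print-alignment`, `generate.py` and `interactive.py` (see their hunks above) emit one `A-<id>` line per hypothesis, containing space-separated, 0-indexed `srcpos-tgtpos` pairs produced by `utils.extract_hard_alignment`. Below is a minimal sketch, not part of this patch, of a reader for that format; the file name `gen.out` is only an assumed example of redirected `fairseq-generate` output.

```python
# Minimal sketch (assumes `fairseq-generate ... --print-alignment > gen.out`):
# each alignment line has the form "A-<sample_id>\t<src>-<tgt> <src>-<tgt> ...",
# with 0-indexed source and target positions.
def read_alignments(path):
    alignments = {}
    with open(path) as f:
        for line in f:
            if not line.startswith('A-'):
                continue
            key, pairs = line.rstrip('\n').split('\t')
            alignments[int(key[2:])] = [
                tuple(int(i) for i in pair.split('-')) for pair in pairs.split()
            ]
    return alignments
```

For example, `read_alignments('gen.out')[3]` might return `[(0, 0), (1, 2), (2, 2)]`; the pairs use the same `src-tgt` notation as the FastAlign output produced in step 2 of the example README.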