
Latent Alignment and Variational Attention

This is a PyTorch implementation of the paper Latent Alignment and Variational Attention, built on a fork of OpenNMT.


The code was tested with Python 3.6 and PyTorch 0.4. To install the dependencies, run

pip install -r requirements.txt

Running the code

All commands are in the script

Preprocessing the data

To preprocess the data, run

source && preprocess_bpe

The raw data in data/iwslt14-de-en was obtained from the fairseq repo with BPE_TOKENS=14000.

Training the model

To train a model, run one of the following commands:

  • Soft attention
source && CUDA_VISIBLE_DEVICES=0 train_soft_b6
  • Categorical attention with exact evidence
source && CUDA_VISIBLE_DEVICES=0 train_exact_b6
  • Variational categorical attention with exact ELBO
source && CUDA_VISIBLE_DEVICES=0 train_cat_enum_b6
  • Variational categorical attention with REINFORCE
source && CUDA_VISIBLE_DEVICES=0 train_cat_sample_b6
  • Variational categorical attention with Gumbel-Softmax
source && CUDA_VISIBLE_DEVICES=0 train_cat_gumbel_b6
  • Variational categorical attention using the Wake-Sleep algorithm (Ba et al., 2015)
source && CUDA_VISIBLE_DEVICES=0 train_cat_wsram_b6
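As a rough illustration of how the enumeration-based objectives differ, here is a framework-free sketch for a single target token. It is not the repository's code: the toy output layer `output_dist` (a linear map followed by a softmax), the list-of-vectors encoding of the source, and every other name are hypothetical, and the real models also condition on the decoder state. Soft attention averages the source vectors under the alignment prior p(z) before predicting; exact marginalization computes log sum_z p(z) p(y | x_z); the enumerated ELBO computes E_q[log p(y | x_z)] - KL(q || p) exactly by summing over every source position.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def output_dist(context: Vector, W: Sequence[Vector]) -> List[float]:
    """Hypothetical toy output layer: softmax(W @ context) over a tiny vocabulary."""
    return softmax([sum(w * c for w, c in zip(row, context)) for row in W])


def soft_attention_logprob(prior: Vector, sources: Sequence[Vector], W, y: int) -> float:
    """Soft attention: average the source vectors under p(z), then predict y."""
    dim = len(sources[0])
    context = [sum(a * s[d] for a, s in zip(prior, sources)) for d in range(dim)]
    return math.log(output_dist(context, W)[y])


def exact_marginal_logprob(prior: Vector, sources: Sequence[Vector], W, y: int) -> float:
    """Exact evidence: log sum_z p(z) p(y | x_z)."""
    return math.log(sum(a * output_dist(s, W)[y] for a, s in zip(prior, sources)))


def exact_posterior(prior: Vector, sources: Sequence[Vector], W, y: int) -> List[float]:
    """Posterior p(z | y), proportional to p(z) p(y | x_z)."""
    joint = [a * output_dist(s, W)[y] for a, s in zip(prior, sources)]
    total = sum(joint)
    return [j / total for j in joint]


def enumerated_elbo(q: Vector, prior: Vector, sources: Sequence[Vector], W, y: int) -> float:
    """E_q[log p(y | x_z)] - KL(q || p), summed exactly over every source position z."""
    expected_ll = sum(qz * math.log(output_dist(s, W)[y]) for qz, s in zip(q, sources) if qz > 0)
    kl = sum(qz * math.log(qz / pz) for qz, pz in zip(q, prior) if qz > 0)
    return expected_ll - kl


if __name__ == "__main__":
    sources = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]  # three toy source positions
    W = [[2.0, -1.0], [-1.0, 2.0], [0.0, 0.0]]      # toy three-word vocabulary
    prior = softmax([0.2, 1.0, -0.5])               # alignment prior p(z)
    y = 0
    post = exact_posterior(prior, sources, W, y)
    print("soft attention  :", soft_attention_logprob(prior, sources, W, y))
    print("exact marginal  :", exact_marginal_logprob(prior, sources, W, y))
    print("ELBO, q = prior :", enumerated_elbo(prior, prior, sources, W, y))
    print("ELBO, q = post  :", enumerated_elbo(post, prior, sources, W, y))
```

The ELBO equals the exact log marginal when q is the true posterior and is a lower bound otherwise; soft attention optimizes a different quantity that is not, in general, a bound on the marginal.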
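The sampling-based variants replace the exact sum over alignments with samples from q. The sketch below is again a toy with hypothetical names rather than the repository's implementation. For a single categorical variable it shows a REINFORCE (score-function) estimate of the gradient of E_q[f(z)] with respect to the logits of q, checked against its closed form, and a Gumbel-Softmax relaxed sample. In training, f(z) would play the role of log p(y | x, z); the KL term and any learned baseline are omitted here.

```python
import math
import random
from typing import Callable, List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_categorical(probs: Sequence[float], rng: random.Random) -> int:
    """Inverse-CDF sampling of an index z ~ Cat(probs)."""
    u, cum = rng.random(), 0.0
    for k, p in enumerate(probs):
        cum += p
        if u < cum:
            return k
    return len(probs) - 1  # guard against floating-point round-off


def reinforce_grad(logits: Sequence[float], f: Callable[[int], float],
                   rng: random.Random, num_samples: int = 1,
                   baseline: float = 0.0) -> List[float]:
    """Score-function estimate of d/d(logits) E_{z~q}[f(z)], with q = softmax(logits).

    For a softmax, d log q(z) / d logit_k = 1[z == k] - q_k.
    """
    q = softmax(logits)
    grad = [0.0] * len(q)
    for _ in range(num_samples):
        z = sample_categorical(q, rng)
        reward = f(z) - baseline
        for k in range(len(q)):
            grad[k] += reward * ((1.0 if k == z else 0.0) - q[k])
    return [g / num_samples for g in grad]


def exact_grad(logits: Sequence[float], f: Callable[[int], float]) -> List[float]:
    """Closed form for comparison: q_k * (f(k) - E_q[f])."""
    q = softmax(logits)
    mean_f = sum(qk * f(k) for k, qk in enumerate(q))
    return [qk * (f(k) - mean_f) for k, qk in enumerate(q)]


def gumbel_softmax_sample(logits: Sequence[float], temperature: float,
                          rng: random.Random) -> List[float]:
    """Relaxed one-hot sample: softmax((logits + Gumbel noise) / temperature)."""
    noisy = []
    for v in logits:
        u = rng.random()
        while u == 0.0:  # avoid log(0)
            u = rng.random()
        noisy.append((v - math.log(-math.log(u))) / temperature)
    return softmax(noisy)


if __name__ == "__main__":
    rng = random.Random(0)
    logits = [0.5, -0.3, 1.2]
    f = lambda z: [-1.0, -2.5, -0.4][z]  # stand-in for log p(y | x, z)
    print("REINFORCE:", reinforce_grad(logits, f, rng, num_samples=20000))
    print("exact    :", exact_grad(logits, f))
    print("relaxed  :", gumbel_softmax_sample(logits, temperature=0.5, rng=rng))
```

The constant `baseline` argument stands in for the variance-reduction baselines usually paired with REINFORCE; subtracting it keeps the estimator unbiased because the expected score E_q[d log q(z)] is zero.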

Checkpoints will be saved to the project's root directory.

Evaluating on test

To compute the exact perplexity of the generative model, run the following command, replacing $model with the path to a saved checkpoint.

source && CUDA_VISIBLE_DEVICES=0 eval_cat $model
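Perplexity is the exponential of the average negative log-likelihood per target token. A minimal sketch, assuming natural-log per-token probabilities pooled over the whole test set (a corpus-level average); for the categorical-attention models, each per-token log-probability would be the log of the exact sum over source positions. The function name is hypothetical. As a worked example, a test PPL of 6.08 corresponds to about ln 6.08 ≈ 1.81 nats per target token.

```python
import math
from typing import Iterable


def corpus_perplexity(token_logprobs: Iterable[float]) -> float:
    """exp of the mean negative log-likelihood (natural log) over all target tokens."""
    total, count = 0.0, 0
    for lp in token_logprobs:
        total += lp
        count += 1
    if count == 0:
        raise ValueError("need at least one token")
    return math.exp(-total / count)


if __name__ == "__main__":
    # Four tokens each predicted with probability 0.5 give a perplexity of about 2.0.
    print(corpus_perplexity([math.log(0.5)] * 4))
```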

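The model can also be used to translate the test set. The second command below strips the BPE separators from the output and scores it with BLEU against the reference translations: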

source && CUDA_VISIBLE_DEVICES=0 gen_cat $model
sed -e "s/@@ //g" $model.out | perl tools/multi-bleu.perl data/iwslt14-de-en/test.en
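The `sed` step joins BPE subword units back into words: the segmenter marks every non-final piece of a word with a trailing `@@`, so deleting each `@@ ` merges the pieces before BLEU is computed. An equivalent Python sketch (the function name is hypothetical):

```python
from typing import Iterable, Iterator


def remove_bpe(lines: Iterable[str], separator: str = "@@ ") -> Iterator[str]:
    """Join BPE subword units, mirroring `sed -e "s/@@ //g"`."""
    for line in lines:
        yield line.replace(separator, "")
```

Applied line by line to `$model.out`, it yields the same text that the shell pipeline feeds into `multi-bleu.perl`.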

Trained Models

Models with the lowest validation perplexity (PPL) were selected for evaluation on the test set. The numbers differ slightly from those reported in the paper because this is a re-implementation.

Model                                Test PPL  Test BLEU
Soft Attention                           7.17      32.77
Exact Marginalization                    6.34      33.29
Variational Attention + Enumeration      6.08      33.69
Variational Attention + Sampling         6.17      33.30