In [None]:
!pip install torch matplotlib seaborn datasets transformers tqdm  tokenizers

In [None]:
!brew install boost

## Reproduce "Neural Machine Translation by Jointly Learning to Align and Translate"

In this notebook, I attempt to reproduce the results from the paper [Neural Machine Translation by Jointly Learning to Align and Translate](https://arxiv.org/abs/1409.0473) by implementing RNNSearch.

In [99]:
import torch
from torch import nn
from torch.nn import functional as F
from transformers import AutoTokenizer
from tqdm import tqdm

from tokenizers.models import WordPiece
from transformers import BertTokenizerFast

import datasets

## Dataset

The paper demonstrates the approach on an English to French translation task, using the data provided as part of the [Workshop on Statistical Machine Translation in 2014](https://aclanthology.org/W14-3302.pdf). I've found a version of that on HuggingFace. Not sure exactly how closely it mirrors the paper, but I'm not too concerned.

In [100]:
dataset = datasets.load_dataset("presencesw/wmt14_fr_en")
print("Dataset structure:", dataset)

Dataset structure: DatasetDict({
    train: Dataset({
        features: ['en', 'fr'],
        num_rows: 40836876
    })
    validation: Dataset({
        features: ['en', 'fr'],
        num_rows: 3000
    })
    test: Dataset({
        features: ['en', 'fr'],
        num_rows: 3003
    })
})


In the paper, they "concat news-test-2012 and news-test-2013" for the validation set, but I'm using the validation set kindly provided by presencesw.

## Tokeniser

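The paper uses the tokenisation script from the open-source Moses package.

A Python wrapper, `mosestokenizer`, exists (`pip install mosestokenizer`), and I've installed it. The cells below use the pretrained `facebook/mbart-large-cc25` tokenizer instead.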

In [101]:
example = dataset['train'][0]
example

{'en': 'In his briefing on economic development, Al Horner will give you details of programs we fund to foster partnerships between the private sector and First Nations and Inuit communities, in areas like resource development projects, for example.',
 'fr': "Dans sa présentation sur le développement économique, M. Al Horner vous donnera des détails sur les programmes que nous finançons pour favoriser l'établissement de partenariats entre le secteur privé et les collectivités des Premières nations et inuites dans des domaines comme celui de l'exploitation des ressources naturelles."}

In [102]:
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")

In [103]:
len(tokenizer)

250027

In [104]:
print([tokenizer.decode(t) for t in tokenizer(example["en"])["input_ids"]])

['In', 'his', 'brief', 'ing', 'on', 'economic', 'development', ',', 'Al', 'Horn', 'er', 'will', 'give', 'you', 'details', 'of', 'programs', 'we', 'fund', 'to', 'fost', 'er', 'partnership', 's', 'between', 'the', 'private', 'sector', 'and', 'First', 'Nations', 'and', 'In', 'uit', 'communities', ',', 'in', 'areas', 'like', 'resource', 'development', 'projects', ',', 'for', 'example', '.', '</s>', 'en_XX']


In [105]:
print([tokenizer.decode(t) for t in tokenizer(example["fr"])["input_ids"]])

['Dans', 'sa', 'présentation', 'sur', 'le', 'développement', 'économique', ',', 'M', '.', 'Al', 'Horn', 'er', 'vous', 'donner', 'a', 'des', 'détails', 'sur', 'les', 'programme', 's', 'que', 'nous', 'fina', 'nç', 'ons', 'pour', 'favoriser', 'l', "'", 'établissement', 'de', 'partenariat', 's', 'entre', 'le', 'secteur', 'privé', 'et', 'les', 'collectivités', 'des', 'Premi', 'ères', 'na', 'tions', 'et', 'in', 'uit', 'es', 'dans', 'des', 'domaine', 's', 'comme', 'celui', 'de', 'l', "'", 'exploitation', 'des', 'ressources', 'naturelle', 's', '.', '</s>', 'en_XX']


From the paper:

> After a usual tokenization, we use a shortlist of 30,000 most frequent words in each language to train our models.
> Any word not included in the shortlist is mapped to a special token ([UNK]).
> We do not apply any other special preprocessing, such as lowercasing or stemming, to the data.

To achieve that, I'll create a counter of words. Then we'll cull anything that falls out of the most frequent words.
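A minimal sketch of that shortlist idea, assuming plain whitespace splitting as a stand-in for the Moses tokenisation. The helper names `build_shortlist` and `encode` are hypothetical, and the cutoff is shrunk to 2 words so the effect is visible on a toy corpus.

```python
from collections import Counter

UNK = "[UNK]"

def build_shortlist(sentences, max_words=30_000):
    """Count whitespace-separated words and keep the most frequent ones."""
    counts = Counter(word for s in sentences for word in s.split())
    # Reserve id 0 for the unknown token; the other ids follow frequency order.
    vocab = {UNK: 0}
    for word, _ in counts.most_common(max_words):
        vocab[word] = len(vocab)
    return vocab

def encode(sentence, vocab):
    """Map each word to its id, falling back to the [UNK] id."""
    return [vocab.get(word, vocab[UNK]) for word in sentence.split()]

corpus = ["the cat sat", "the dog sat", "the bird flew"]
vocab = build_shortlist(corpus, max_words=2)  # keeps "the" and "sat"
ids = encode("the cat sat", vocab)            # "cat" falls outside the shortlist
```

In the real setting this would be run once per language with `max_words=30_000`.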

## Models

From the paper:

> We train two types of models.
>
> The first one is an RNN Encoder–Decoder (RNNencdec, Cho et al., 2014a), and the other is the proposed model, to which we refer as RNNsearch.
>
> We train each model twice: first with the sentences of length up to 30 words (RNNencdec-30, RNNsearch-30) and then with the sentences of length up to 50 word (RNNencdec-50, RNNsearch-50).
>
> The encoder and decoder of the RNNencdec have 1000 hidden units each.
>
> The encoder of the RNNsearch consists of forward and backward recurrent neural networks (RNN) each having 1000 hidden units. Its decoder has 1000 hidden units.
>
> We use a minibatch stochastic gradient descent (SGD) algorithm together with Adadelta (Zeiler, 2012) to train each model.
>
> Each SGD update direction is computed using a minibatch of 80 sentences. We trained each model for approximately 5 days.
>
> Once a model is trained, we use a beam search to find a translation that approximately maximizes the conditional probability (see, e.g., Graves, 2012; Boulanger-Lewandowski et al., 2013). Sutskever et al. (2014) used this approach to generate translations from their neural machine translation model. For more details on the architectures of the models and training procedure used in the experiments, see Appendices A and B.
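To make the beam search step concrete, here is a rough sketch in plain Python. It is not the paper's implementation: `beam_search`, `step_fn` and the toy probability table are all hypothetical, and there is no length normalisation. In the real model, `step_fn` would run one decoder step and return log-probabilities over the target vocabulary.

```python
import math

def beam_search(step_fn, bos, eos, beam_size=3, max_len=10):
    """Return the best (tokens, log_prob) pair found by a simple beam search.

    step_fn(prefix) must return a dict {token: log_prob} for the next token.
    """
    beams = [([bos], 0.0)]  # (tokens, cumulative log-probability)
    finished = []
    for _ in range(max_len):
        candidates = []
        for tokens, score in beams:
            for token, logp in step_fn(tokens).items():
                candidates.append((tokens + [token], score + logp))
        # Keep only the beam_size best hypotheses.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates[:beam_size]:
            (finished if tokens[-1] == eos else beams).append((tokens, score))
        if not beams:
            break
    finished.extend(beams)  # fall back to unfinished hypotheses at max_len
    return max(finished, key=lambda c: c[1])

def toy_step(prefix):
    # Hypothetical next-token distribution that depends only on the last token.
    table = {
        "<s>": {"a": math.log(0.6), "b": math.log(0.4)},
        "a": {"</s>": math.log(0.5), "a": math.log(0.5)},
        "b": {"</s>": math.log(0.9), "b": math.log(0.1)},
    }
    return table[prefix[-1]]

best_tokens, best_score = beam_search(toy_step, "<s>", "</s>", beam_size=2)
```

In this toy example a greedy decoder would pick `a` first (probability 0.6) and end with a total probability of 0.3, while the beam keeps `b` alive and finds `<s> b </s>` with probability 0.36.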

### Sizes

From the paper:

> For all the models used in this paper, the size of a hidden layer $n$ is 1000, the word embedding dimensionality $m$ is 620 and the size of the maxout hidden layer in the deep output $l$ is 500. The number of hidden units in the alignment model $n'$ is 1000.

In [106]:
embed_size = 620
hidden_size = 1000
maxout_size = 500
vocab_size = len(tokenizer)

In [108]:
maxout_layer = MaxoutLayer(hidden_size + embed_size, maxout_size)

In [109]:
maxout_layer(torch.rand(1, 1620)).shape

torch.Size([1, 500])

### Encoder

The paper uses a bidirectional RNN, which consists of a forward RNN and a backward RNN.

Effectively, one RNN reads the sequence in its normal order and another reads it reversed.

The two hidden states for each token are then concatenated, so that each position has context from both the preceding and the following words.

That gives us a bidirectional context vector for each word.

PyTorch's RNN modules already support this via `bidirectional=True`. The output contains the encoding in both directions, already concatenated along the last dimension.

In [119]:
token_ids = torch.tensor([ [0,1,2,3] ]).long() # Batch x Sequence

In [120]:
encoder_embedding = nn.Embedding(vocab_size, embed_size)
encoder = nn.GRU(embed_size, hidden_size, batch_first=True, bidirectional=True)

In [121]:
embedding = encoder_embedding(token_ids)  # Batch x Sequence x Embedding Dimension
embedding.shape 

torch.Size([1, 4, 620])

In [124]:
encoder_out, hidden = encoder(embedding)
encoder_out.shape

torch.Size([1, 4, 2000])

In [125]:
hidden.shape

torch.Size([2, 1, 1000])

From https://pytorch.org/docs/stable/generated/torch.nn.GRU.html

For bidirectional GRUs, forward and backward are directions 0 and 1 respectively.

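In the paper, the decoder's initial hidden state is computed from $\overleftarrow{h}_1$, the backward RNN's hidden state at the first source position. Because the backward RNN reads the sentence from right to left, this is the final state of the backward pass, which PyTorch returns as direction 1 of `hidden`.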
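A quick check of how the two directions are laid out, as a sketch with small made-up sizes (8 hidden units instead of 1000). It compares the per-position outputs with the final states that `nn.GRU` returns.

```python
import torch
from torch import nn

torch.manual_seed(0)
hidden_size = 8  # small stand-in for the notebook's 1000
gru = nn.GRU(4, hidden_size, batch_first=True, bidirectional=True)

x = torch.rand(1, 5, 4)   # Batch x Sequence x Features
outputs, h_n = gru(x)     # outputs: [1, 5, 16], h_n: [2, 1, 8]

forward_states = outputs[:, :, :hidden_size]    # left-to-right states
backward_states = outputs[:, :, hidden_size:]   # right-to-left states

# The final forward state sits at the last position ...
forward_matches = torch.allclose(h_n[0], forward_states[:, -1])
# ... while the final backward state sits at the first position.
backward_matches = torch.allclose(h_n[1], backward_states[:, 0])
```

Both flags come out `True`, so direction 1 of the returned hidden state is the backward state aligned with the first token.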

In [134]:
hidden = hidden[1]  # final state of the backward direction: Batch x Hidden
hidden.shape

torch.Size([1, 1000])

Here I wrap it all into a PyTorch module.

In [135]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.GRU(
            embed_dim,
            hidden_dim,
            batch_first=True,
            bidirectional=True
        )
        
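    def forward(self, src):
        # src: Batch x Sequence of token ids
        embedded = self.embedding(src)

        # outputs: Batch x Sequence x (2 * hidden_dim); hidden: 2 x Batch x hidden_dim
        outputs, hidden = self.rnn(embedded)
        # Keep only the final state of the backward direction (index 1)
        hidden = hidden[1]

        return outputs, hidden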

In [139]:
enc = Encoder(vocab_size, embed_size, hidden_size)
encoder_outputs, encoder_hidden = enc(token_ids)
encoder_outputs.shape, encoder_hidden.shape

(torch.Size([1, 4, 2000]), torch.Size([1, 1000]))

## Decoder

The initial hidden state $s_0$ is computed by $s_0 = \tanh \left( W_s \overleftarrow{h}_1 \right)$ where $W_s \in \mathbb{R}^{n \times n}$.

In [159]:
init_state = nn.Linear(hidden_size, hidden_size)
dec_hidden = torch.tanh(init_state(encoder_hidden)).unsqueeze(0)
dec_hidden.shape

torch.Size([1, 1, 1000])

### Attention

The encoder outputs are passed through a linear projection that maps them from `2 * hidden_size` down to the alignment model's hidden size.

In [172]:
attention_context = nn.Linear(hidden_size * 2, hidden_size)

# The decoder state has hidden_size units (it is initialised from the backward
# encoder state only), so this input size is not doubled.
attention_hidden = nn.Linear(hidden_size, hidden_size)

In [181]:
context_proj = attention_context(encoder_outputs)
context_proj.shape

torch.Size([1, 4, 1000])

In [182]:
hidden_proj = attention_hidden(dec_hidden)
hidden_proj.shape

torch.Size([1, 1, 1000])

In [199]:
attention_vector = torch.tanh(hidden_proj + context_proj)
attention_vector.shape

torch.Size([1, 4, 1000])

In [200]:
attention_alignment = nn.Linear(hidden_size, 1)

In [201]:
attention_scores = attention_alignment(attention_vector)
attention_scores.shape

torch.Size([1, 4, 1])

In [202]:
attention_scores = attention_alignment(attention_vector).squeeze(2)
attention_scores.shape

torch.Size([1, 4])

In [189]:
attention_scores

tensor([[0.1133, 0.1081, 0.0164, 0.0054]], grad_fn=<SqueezeBackward1>)

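The decoder hidden state gets its own projection (`attention_hidden` above), which is added to the projected encoder outputs before the tanh.

Finally, a softmax converts the scores into a probability distribution over the source positions.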

In [191]:
attention_weights = F.softmax(attention_scores, dim=1)
attention_weights

tensor([[0.2632, 0.2618, 0.2388, 0.2362]], grad_fn=<SoftmaxBackward0>)

In [198]:
attention_weights.shape

torch.Size([1, 4])

Now a batched matrix multiplication of the attention weights with the encoder outputs gives us the final weighted sum, the context vector.

In [195]:
context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)
context.shape

torch.Size([1, 1, 2000])
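For reference, the cells above follow the paper's alignment model. This is a sketch of the correspondence, assuming `attention_hidden` plays the role of $W_a$, `attention_context` of $U_a$ and `attention_alignment` of $v_a$. Note that the `nn.Linear` layers also add bias terms, which the paper's formula does not include.

$$
e_{ij} = v_a^\top \tanh\left( W_a s_{i-1} + U_a h_j \right)
$$

$$
\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{T_x} \exp(e_{ik})}
$$

$$
c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
$$

Here $s_{i-1}$ is the previous decoder state (`dec_hidden`), $h_j$ is the encoder output at source position $j$, and $T_x$ is the source length. The context vector $c_i$ is a weighted sum of encoder outputs, which is why `context` has $2n = 2000$ features.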

In [208]:
class Attention(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.attention_context = nn.Linear(hidden_size * 2, hidden_size)
        self.attention_hidden = nn.Linear(hidden_size, hidden_size)
        self.attention_alignment = nn.Linear(hidden_size, 1)
    
    def forward(self, decoder_state, encoder_outputs, attention_mask=None):
        hidden_projection = self.attention_hidden(decoder_state)
        context_projection = self.attention_context(encoder_outputs)
        
        attention_vector = torch.tanh(hidden_projection + context_projection)
        attention_scores = self.attention_alignment(attention_vector).squeeze(2)
        
        # Apply attention mask
        if attention_mask is not None:
            attention_scores = attention_scores.masked_fill(
                ~attention_mask.bool(), 
                float('-inf')
            )
        
        attention_weights = F.softmax(attention_scores, dim=1)
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)
        
        return context, attention_weights

In [212]:
attn = Attention(hidden_size)
context, attention_weights = attn(dec_hidden, encoder_out)
context.shape

torch.Size([1, 1, 2000])

### Maxout Layer

From the paper:

> In both cases, we use a multilayer network with a single maxout (Goodfellow et al., 2013) hidden layer to compute the conditional probability of each target word (Pascanu et al., 2014).

The final layer of the decoder is a maxout layer: a linear layer produces `num_pieces` candidate values (two by default) for each output unit, and the layer keeps the maximum of them. Maxout is a learned piecewise-linear activation rather than a regulariser, although it was designed to work well with dropout.

In [107]:
class MaxoutLayer(nn.Module):
    def __init__(self, in_features, out_features, num_pieces=2):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_pieces = num_pieces
        self.linear = nn.Linear(in_features, out_features * num_pieces)
    
    def forward(self, x):
        # x: [B, in_features]
        shape = [x.shape[0], self.out_features, self.num_pieces]
        # Project to out_features * num_pieces pre-activations
        x = self.linear(x)  # B, F * P
        # Group the pre-activations into num_pieces candidates per output unit
        x = x.view(*shape)  # B, F, P
        # Keep the largest candidate for each output unit
        x, _ = torch.max(x, -1)  # B, F
        return x  # B, F
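To see what the reshape-and-max step does, here is a tiny worked example with hand-picked numbers. It skips the linear layer and starts from made-up pre-activations: 3 output units with 2 pieces each.

```python
import torch

# One example with 6 pre-activations = 3 output units x 2 pieces.
pre_activations = torch.tensor([[1.0, 5.0, 2.0, 0.0, -1.0, -3.0]])

# Consecutive pairs become the candidates for one output unit.
pieces = pre_activations.view(1, 3, 2)   # [[1, 5], [2, 0], [-1, -3]]
maxout = pieces.max(dim=-1).values       # tensor([[ 5.,  2., -1.]])
```

Each output unit keeps the larger of its two candidates, so the layer's output has `out_features` values even though the linear layer produced twice as many.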

### Decoder class

In [213]:
class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=1, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.dropout = nn.Dropout(dropout)
        self.attention = Attention(hidden_size)
        
        # Input is the previous target embedding plus the context vector
        # (2 * hidden_size, from the bidirectional encoder). The decoder state
        # keeps hidden_size units so that it matches Attention and s_0.
        self.rnn = nn.GRU(
            embed_size + hidden_size * 2,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,
            # nn.GRU only applies dropout between stacked layers
            dropout=dropout if num_layers > 1 else 0.0
        )
        
        self.output = MaxoutLayer(hidden_size, vocab_size)
        
    def forward(self, input, hidden, encoder_outputs, attention_mask=None):
        # input: [batch_size, 1], hidden: [batch_size, hidden_size]
        # Note: the unsqueeze/squeeze calls below assume num_layers == 1
        embedded = self.dropout(self.embedding(input))
        
        # Get context vector from attention: [batch_size, 1, 2 * hidden_size]
        context, attention_weights = self.attention(
            hidden.unsqueeze(1), encoder_outputs, attention_mask
        )
        
        # Combine embedding and context vector
        rnn_input = torch.cat([embedded, context], dim=2)

        # Pass through RNN
        output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))
        
        # Generate unnormalised scores over the vocabulary
        prediction = self.output(output.squeeze(1))
        
        return prediction, hidden.squeeze(0), attention_weights

In [214]:
def create_sample_batch(batch_size=32, max_src_len=20, max_tgt_len=20, 
                       src_vocab_size=1000, tgt_vocab_size=1000):
    src = torch.randint(0, src_vocab_size, (batch_size, max_src_len))
    tgt = torch.randint(0, tgt_vocab_size, (batch_size, max_tgt_len))
    src_lengths = torch.randint(5, max_src_len + 1, (batch_size,))
    return src, tgt, src_lengths

Let's create a sample batch first.

In [215]:
batch_size=32
max_src_len=20
max_tgt_len=20
src_vocab_size=1000
tgt_vocab_size=1000

source = torch.randint(0, src_vocab_size, (batch_size, max_src_len))
target = torch.randint(0, tgt_vocab_size, (batch_size, max_tgt_len))
source_lengths = torch.randint(5, max_src_len + 1, (batch_size,))

In [216]:
source.shape

torch.Size([32, 20])

In [217]:
target.shape

torch.Size([32, 20])

In [218]:
source_lengths.shape

torch.Size([32])
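The lengths can be turned into the boolean `attention_mask` that `Attention.forward` accepts. Here is a small sketch of the idea, assuming padding sits at the end of each sentence and using zeros as stand-in alignment scores.

```python
import torch
from torch.nn import functional as F

source_lengths = torch.tensor([4, 2])   # true lengths of two padded sentences
max_src_len = 4

# True where a position holds a real token, False where it is padding.
positions = torch.arange(max_src_len).unsqueeze(0)          # [1, 4]
attention_mask = positions < source_lengths.unsqueeze(1)    # [2, 4]

scores = torch.zeros(2, max_src_len)    # stand-in for alignment scores
scores = scores.masked_fill(~attention_mask, float("-inf"))
weights = F.softmax(scores, dim=1)
```

The second sentence has only two real tokens, so its weights are `[0.5, 0.5, 0.0, 0.0]`: the padded positions receive no attention and each row still sums to 1.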