## RNNSearch - "Neural Machine Translation by Jointly Learning to Align and Translate"

This notebook is a quick implementation of `RNNSearch` from the paper [Neural Machine Translation by Jointly Learning to Align and Translate](https://arxiv.org/abs/1409.0473).

In [None]:
from typing import Tuple

import torch
from torch import nn
from torch.nn import functional as F
from transformers import AutoTokenizer
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from tokenizers.models import WordPiece
from transformers import BertTokenizerFast

import datasets

## Model Configuration

For all the models used in this paper:
* the size of a hidden layer $n$ is 1000
* the word embedding dimensionality $m$ is 620
* size of the maxout hidden layer in the deep output $l$ is 500.
* The number of hidden units in the alignment model $n'$ is 1000.

I've had to reduce these sizes so the model fits within my GPU memory limits.

In [None]:
#embed_size = 620
#hidden_size = 1000
#maxout_size = 500

embed_size = 128
hidden_size = 128
maxout_size = 128
vocab_size = 32000
max_length = 10
batch_size = 4

## Encoder

In [None]:
# A single toy sentence of four token ids used to probe the encoder pieces.
token_ids = torch.tensor([ [0,1,2,3] ]).long() # Batch x Sequence

In [None]:
# Stand-alone embedding table and bidirectional, batch-first GRU, built here
# only to inspect tensor shapes before wrapping them in the Encoder class.
encoder_embedding = nn.Embedding(vocab_size, embed_size)
encoder = nn.GRU(embed_size, hidden_size, batch_first=True, bidirectional=True)

In [None]:
# With the toy input and the config above this displays [1, 4, embed_size].
embedding = encoder_embedding(token_ids)  # Batch x Sequence x Embedding Dimension
embedding.shape 

In [None]:
# encoder_out concatenates both directions on the last axis
# ([1, 4, 2 * hidden_size]); hidden stacks one final state per direction
# ([2, 1, hidden_size]). encoder_out is reused by the attention demo below.
encoder_out, hidden = encoder(embedding)
encoder_out.shape, hidden.shape

In [None]:
class Encoder(nn.Module):
    """Bidirectional GRU encoder over embedded source tokens.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Width of each token embedding.
        hidden_dim: Hidden size of each GRU direction.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True, bidirectional=True)

    def forward(self, src):
        """Encode a batch of token ids.

        Args:
            src: Integer token ids. The GRU is batch-first, so this is
                presumably ``[batch, seq]`` -- confirm against callers.

        Returns:
            A pair ``(annotations, reverse_state)``: the per-position GRU
            outputs with both directions concatenated on the last axis, and
            the entry at index 1 of the GRU's final hidden states.
        """
        annotations, final_states = self.rnn(self.embedding(src))

        # Index 1 of h_n belongs to the reverse direction of the single
        # bidirectional layer, i.e. the state after reading back to position 0.
        return annotations, final_states[1]

In [None]:
# Smoke-test the Encoder on the toy sentence: the outputs are
# [1, 4, 2 * hidden_size] and the returned hidden state is [1, hidden_size].
enc = Encoder(vocab_size, embed_size, hidden_size)
encoder_outputs, encoder_hidden = enc(token_ids)
encoder_outputs.shape, encoder_hidden.shape

## Attention Mechanism

In [None]:
class Attention(nn.Module):
    """Additive attention: ``score = v^T tanh(W s + U h)``.

    Args:
        encoder_hidden_size: Width of each encoder output vector ``h``.
        decoder_hidden_size: Width of the decoder state ``s``.
        alignment_hidden_size: Width of the shared projection space in which
            the decoder state and encoder outputs are added.
    """

    def __init__(self, encoder_hidden_size, decoder_hidden_size, alignment_hidden_size):
        super().__init__()

        self.decoder_hidden_layer = nn.Linear(decoder_hidden_size, alignment_hidden_size)
        self.encoder_outputs_layer = nn.Linear(encoder_hidden_size, alignment_hidden_size)
        self.score_layer = nn.Linear(alignment_hidden_size, 1)

    def forward(self, decoder_hidden_state, encoder_outputs):
        """Compute a context vector for one decoding step.

        Args:
            decoder_hidden_state: Either ``[batch, decoder_hidden]`` or the
                GRU layout ``[layers, batch, decoder_hidden]``; in the latter
                case the last entry along the first axis is used.
            encoder_outputs: ``[batch, seq, encoder_hidden]``.

        Returns:
            ``(context_vector, alignment_weights)`` with shapes
            ``[batch, encoder_hidden]`` and ``[batch, seq]``. The weights sum
            to 1 over the sequence axis.
        """
        # Select by rank instead of squeeze(0): squeeze(0) would silently
        # remove the batch axis of a 2-D state whose batch size is 1.
        if decoder_hidden_state.dim() == 3:
            decoder_hidden_state = decoder_hidden_state[-1]

        projected_decoder_state = self.decoder_hidden_layer(decoder_hidden_state)
        projected_decoder_state = projected_decoder_state.unsqueeze(1)  # [batch, 1, align]

        projected_encoder_outputs = self.encoder_outputs_layer(encoder_outputs)  # [batch, seq, align]

        # The decoder projection broadcasts across the sequence axis.
        alignment_scores = torch.tanh(projected_decoder_state + projected_encoder_outputs)
        alignment_scores = self.score_layer(alignment_scores).squeeze(2)  # [batch, seq]

        # Normalise over source positions.
        alignment_weights = F.softmax(alignment_scores, dim=1)

        # Weighted sum of the encoder outputs.
        context_vector = torch.sum(encoder_outputs * alignment_weights.unsqueeze(2), dim=1)

        return context_vector, alignment_weights

In [None]:
# Random stand-in for 80 decoder hidden states, one row per batch element.
dec_hidden = torch.randn(80, hidden_size)

In [None]:
# Smoke-test attention. encoder_out comes from the single toy sentence, so its
# batch axis of 1 is broadcast against the 80 decoder states; the displayed
# shapes are [80, 2 * hidden_size] and [80, 4].
attn = Attention(hidden_size * 2, hidden_size, hidden_size)
context_vector, alignment_weights = attn(dec_hidden, encoder_out)
context_vector.shape, alignment_weights.shape

## Maxout Layer

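The final layer of the decoder is a Maxout layer: it applies a linear projection, splits the result into `num_pieces` groups (two here), and takes the element-wise maximum across the groups. Maxout is a learned piecewise-linear activation that was introduced as a companion to dropout, a form of regularisation.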

In [None]:
class MaxoutLayer(nn.Module):
    """Linear projection followed by an element-wise max over pieces.

    Args:
        input_size: Width of the last axis of the input.
        output_size: Width of the last axis of the result.
        num_pieces: Number of linear pieces competing for each output unit.
    """

    def __init__(self, input_size, output_size, num_pieces=2):
        super().__init__()
        self.linear = nn.Linear(input_size, output_size * num_pieces)
        self.num_pieces = num_pieces

    def forward(self, x):
        """Apply maxout along the last axis of ``x``.

        Args:
            x: Tensor of shape ``[..., input_size]``.

        Returns:
            Tensor of shape ``[..., output_size]``.
        """
        projected = self.linear(x)

        # Split only the feature axis into (pieces, output_size) so inputs with
        # any number of leading axes keep them. Piece j covers the contiguous
        # columns [j * output_size, (j + 1) * output_size).
        pieces = projected.reshape(*projected.shape[:-1], self.num_pieces, -1)
        output, _ = torch.max(pieces, dim=-2)

        return output

In [None]:
# Smoke-test the maxout layer with the input width the decoder will feed it
# (decoder state + context + embedding) and a vocabulary-sized output.
maxout_layer = MaxoutLayer(
    input_size=hidden_size * 3 + embed_size,
    output_size=vocab_size,
    num_pieces=2
)

# One random feature row; the displayed shape is [1, vocab_size].
hidden_states = torch.randn(1, hidden_size * 3 + embed_size)
maxout_layer(hidden_states).shape

## Decoder

In [None]:
class Decoder(nn.Module):
    """One-step attention decoder: embed, attend, GRU update, maxout scores.

    Args:
        vocab_size: Target vocabulary size; also the width of the scores.
        embed_size: Width of the target token embeddings.
        hidden_size: Decoder GRU hidden size. Encoder outputs are expected to
            be ``2 * hidden_size`` wide.
        dropout: Accepted for interface compatibility; not used by this class.
    """

    def __init__(self, vocab_size: int, embed_size: int, hidden_size: int, dropout: float = 0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

        # Width of one encoder output (forward and reverse states concatenated).
        annotation_size = hidden_size * 2

        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.attention = Attention(annotation_size, hidden_size, hidden_size)
        self.gru = nn.GRU(annotation_size + embed_size, hidden_size, batch_first=True)
        self.maxout = MaxoutLayer(
            input_size=hidden_size + annotation_size + embed_size,
            output_size=vocab_size,
            num_pieces=2,
        )

    def forward(self, input: torch.Tensor, hidden: torch.Tensor, encoder_outputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run a single decoding step.

        Args:
            input: Previous target token ids, ``[batch]``.
            hidden: Previous decoder state, ``[1, batch, hidden_size]``.
            encoder_outputs: ``[batch, seq, 2 * hidden_size]``.

        Returns:
            ``(prediction_scores, hidden, attn_weights)``: unnormalised scores
            ``[batch, vocab_size]``, the updated decoder state, and the
            attention weights ``[batch, seq]``.
        """
        prev_token_emb = self.embedding(input)

        # Attention is conditioned on the state from before this step's update.
        context, attn_weights = self.attention(hidden, encoder_outputs)

        # A length-1 sequence per batch element: [batch, 1, 2 * hidden + embed].
        step_input = torch.cat((prev_token_emb, context), dim=1).unsqueeze(1)
        _, hidden = self.gru(input=step_input, hx=hidden)

        # Scores see the new state, the context, and the previous token.
        features = torch.cat((hidden[0], context, prev_token_emb), dim=1)

        # No softmax here; the loss function applies it.
        return self.maxout(features), hidden, attn_weights

In [None]:
# Decoder instance used only by the shape check in the next cells.
dec = Decoder(vocab_size, embed_size, hidden_size)

In [None]:
# NOTE(review): this rebinds the notebook-wide batch_size (4 in the
# configuration cell); cells run after this one see 80 -- confirm intended.
batch_size = 80

# Random stand-ins for one decoding step. Every size is derived from
# batch_size so the three tensors cannot drift apart.
input_tensor = torch.tensor([[0, 1, 2, 3]] * batch_size, dtype=torch.long)  # [batch, 4]
hidden_tensor = torch.randn(1, batch_size, hidden_size)  # [1, batch, hidden]
encoder_outputs = torch.randn(batch_size, 4, hidden_size * 2)  # [batch, seq, 2 * hidden]

In [None]:
# Decode one step from the first token column. The displayed shapes are
# [batch, vocab_size], [1, batch, hidden_size] and [batch, 4].
prediction, hidden, attn_weights = dec(input=input_tensor[:,0], hidden=hidden_tensor, encoder_outputs=encoder_outputs)
prediction.shape, hidden.shape, attn_weights.shape

## Model Implementation

In [None]:
class RNNSearch(nn.Module):
    """Encoder-decoder with attention, trained with teacher forcing.

    Args:
        vocab_size: Vocabulary size shared by the encoder and decoder.
        hidden_size: GRU hidden size for the encoder (per direction) and decoder.
        output_size: Stored on the model; the training code uses it to flatten
            the scores, so it should equal ``vocab_size``.
        sos_token: Token id fed to the decoder at the first step.
        embedding_size: Embedding width for both encoder and decoder. When
            omitted, the module-level ``embed_size`` is used, as before.
    """

    def __init__(self, vocab_size, hidden_size, output_size, sos_token, embedding_size=None):
        super().__init__()

        # Fall back to the notebook-level setting so existing callers that
        # never passed an embedding width keep working unchanged.
        if embedding_size is None:
            embedding_size = embed_size

        self.encoder = Encoder(vocab_size, embedding_size, hidden_size)
        self.decoder = Decoder(vocab_size, embedding_size, hidden_size)

        self.decoder_init = nn.Linear(hidden_size, hidden_size)

        self.hidden_size = hidden_size
        self.output_size = output_size
        self.sos_token = sos_token

    def init_decoder(self, encoder_hidden):
        """Map the encoder summary ``[batch, hidden]`` to an initial decoder
        state ``[1, batch, hidden]`` via a tanh-activated linear layer."""
        return torch.tanh(self.decoder_init(encoder_hidden)).unsqueeze(0)

    def forward(self, input, target, target_length):
        """Score every target position given the gold previous tokens.

        Args:
            input: Source token ids, ``[batch, src_seq]``.
            target: Target token ids, ``[batch, tgt_seq]``.
            target_length: Accepted for interface compatibility; not used.

        Returns:
            Unnormalised scores of shape ``[batch, tgt_seq, vocab]``. The
            scores at step ``i`` are produced from the SOS token (``i == 0``)
            or from ``target[:, i - 1]``, so they line up with ``target[:, i]``.
        """
        batch_size = input.shape[0]

        # Encoding
        encoder_outputs, encoder_hidden = self.encoder(input)

        # Initialise hidden state.
        decoder_hidden = self.init_decoder(encoder_hidden)

        # Initial input is the SOS token, allocated directly on the input's
        # device instead of building a Python list and copying it over.
        decoder_input = torch.full(
            (batch_size,), self.sos_token, dtype=torch.long, device=input.device
        )

        outputs = []
        for step in range(target.shape[-1]):
            decoder_output, decoder_hidden, _ = self.decoder(
                input=decoder_input,
                hidden=decoder_hidden,
                encoder_outputs=encoder_outputs
            )
            outputs.append(decoder_output)

            # Teacher forcing: the gold token at this step feeds the next step.
            decoder_input = target[:, step]

        return torch.stack(outputs, dim=1)

In [None]:
# Create random input data
source = torch.randint(0, vocab_size, (batch_size, 80))  # Random source sentences
target = torch.randint(0, vocab_size, (batch_size, 80))  # Random target sentences

# Initialize model
model = RNNSearch(
    vocab_size=vocab_size,
    hidden_size=hidden_size,
    output_size=vocab_size,
    sos_token=vocab_size-1
)

# Test forward pass. Only the output shape is inspected, so skip autograd
# bookkeeping rather than keep a graph over every step's vocab-sized scores.
with torch.no_grad():
    output = model(
        input=source,
        target=target,
        target_length=10
    )

# Drop the reference so the demo model can be reclaimed before training.
model = None

# Print shapes
print("Input shape:", source.shape)
print("Target shape:", target.shape)
print("Output shape:", output.shape)

## Data

The paper demonstrates the approach on an English to French translation task, using the data provided as part of the [Workshop on Statistical Machine Translation in 2014](https://aclanthology.org/W14-3302.pdf). I've found a version of that on HuggingFace. Not sure exactly how closely it mirrors the paper, but I'm not too concerned.

In [None]:
# Load the English-French corpus from the Hugging Face hub and print its
# splits. Later cells read the 'train' and 'validation' splits.
dataset = datasets.load_dataset("presencesw/wmt14_fr_en")
print("Dataset structure:", dataset)

In the paper, they "concat news-test-2012 and news-test-2013" for the validation set, but I'm using the validation set kindly provided by presencesw.

## Tokeniser

The paper uses the Moses tokeniser; however, I'm going to use a multilingual tokeniser from HuggingFace, as it comes with a few features that make life easier.

In [None]:
# Peek at the first training record; later cells read its "en" and "fr" fields.
example = dataset['train'][0]
example

In [None]:
# Pretrained mBART tokenizer, with the notebook's max_length passed as the
# tokenizer's model_max_length.
tokenizer = AutoTokenizer.from_pretrained(
    "facebook/mbart-large-cc25", model_max_length=max_length)

In [None]:
# This value is what the training setup passes as vocab_size and output_size.
len(tokenizer)

In [None]:
# Decode each id on its own to show how the English side is split into tokens.
print([tokenizer.decode(t) for t in tokenizer(example["en"])["input_ids"]])

In [None]:
# Same per-token view for the French side of the example.
print([tokenizer.decode(t) for t in tokenizer(example["fr"])["input_ids"]])

## Dataset and Dataloader



In [None]:
# Create dataset and dataloader
class TranslationDataset(Dataset):
    """Map-style dataset yielding fixed-length source/target id tensors.

    Args:
        data: Indexable collection whose items expose ``'en'`` (source) and
            ``'fr'`` (target) text.
        tokenizer: Callable accepting ``padding``, ``truncation``,
            ``max_length`` and ``return_tensors`` keyword arguments and
            returning a mapping with ``'input_ids'`` of shape ``[1, max_len]``
            and, optionally, ``'attention_mask'``.
        max_len: Length every sequence is padded or truncated to.
    """

    def __init__(self, data, tokenizer, max_len):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def _encode(self, text):
        """Tokenise ``text`` to exactly ``max_len`` ids (padded or truncated)."""
        return self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors='pt',
        )

    def __getitem__(self, idx):
        """Return ``{'src', 'tgt', 'tgt_len'}`` for the pair at ``idx``.

        ``src`` and ``tgt`` are 1-D id tensors of length ``max_len``.
        ``tgt_len`` is the number of non-padding target tokens when the
        tokenizer supplies an attention mask, and the padded length otherwise.
        """
        item = self.data[idx]
        src_tokens = self._encode(item['en'])
        tgt_tokens = self._encode(item['fr'])

        # squeeze(0) removes only the tokenizer's batch axis; a bare squeeze()
        # would also collapse the sequence axis when max_len == 1.
        src_ids = src_tokens['input_ids'].squeeze(0)
        tgt_ids = tgt_tokens['input_ids'].squeeze(0)

        # Measuring the padded row always gives max_len, so count the real
        # tokens through the attention mask when one is available.
        mask = tgt_tokens.get('attention_mask')
        if mask is None:
            tgt_len = tgt_ids.size(0)
        else:
            tgt_len = int(mask.sum())

        return {
            'src': src_ids,
            'tgt': tgt_ids,
            'tgt_len': tgt_len
        }

# Create dataloaders
# Training uses only the first 10,000 pairs; validation uses the whole split.
train_dataset = TranslationDataset(dataset['train'].select(range(10000)), tokenizer, max_len=max_length)
val_dataset = TranslationDataset(dataset['validation'], tokenizer, max_len=max_length)

# NOTE(review): if the cells ran top to bottom, batch_size is 80 here (rebound
# by the decoder demo cell), not the 4 from the configuration cell -- confirm.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

## Training

In [None]:
# Training setup
# The model is sized from the tokenizer rather than the demo vocab_size, and
# decoding starts from the tokenizer's BOS id.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RNNSearch(
    vocab_size=len(tokenizer),
    hidden_size=hidden_size,
    output_size=len(tokenizer),
    sos_token=tokenizer.bos_token_id
).to(device)

# Adadelta optimizer as used in the paper
optimizer = torch.optim.Adadelta(model.parameters(), lr=1.0, rho=0.95, eps=1e-6)
# Padding positions in the target do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

def validate(model, dataloader, criterion, device):
    """Return the mean per-batch loss of ``model`` over ``dataloader``.

    The model is switched to eval mode and gradients are disabled. Each batch
    must provide ``'src'``, ``'tgt'`` and ``'tgt_len'`` entries.
    """
    model.eval()
    running_loss = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validating"):
            source_ids = batch['src'].to(device)
            target_ids = batch['tgt'].to(device)

            logits = model(source_ids, target_ids, batch['tgt_len'])

            # Flatten to (positions, vocab) against (positions,) for the loss.
            batch_loss = criterion(
                logits.view(-1, model.output_size),
                target_ids.view(-1),
            )
            running_loss += batch_loss.item()

    return running_loss / len(dataloader)

def train_epoch(model, dataloader, optimizer, criterion, device, clip_value=5.0):
    """Train ``model`` for one pass over ``dataloader``.

    Gradients are clipped to a total norm of ``clip_value`` before every
    optimizer step. Returns the mean per-batch loss.
    """
    model.train()
    running_loss = 0

    for batch in tqdm(dataloader, desc="Training"):
        source_ids = batch['src'].to(device)
        target_ids = batch['tgt'].to(device)
        target_lengths = batch['tgt_len']

        optimizer.zero_grad()

        logits = model(source_ids, target_ids, target_lengths)

        # Flatten to (positions, vocab) against (positions,) for the loss.
        batch_loss = criterion(
            logits.view(-1, model.output_size),
            target_ids.view(-1),
        )
        batch_loss.backward()

        # Clip before stepping so one bad batch cannot blow up the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
        optimizer.step()

        running_loss += batch_loss.item()

    return running_loss / len(dataloader)

# Training loop with checkpointing on the best validation loss and early
# stopping after `patience` consecutive epochs without improvement.
best_valid_loss = float('inf')
n_epochs = 100
patience = 5
no_improvement = 0

print(f"Training on {device}")

for epoch in range(n_epochs):
    train_loss = train_epoch(model, train_dataloader, optimizer, criterion, device)
    valid_loss = validate(model, val_dataloader, criterion, device)

    print(f'Epoch: {epoch+1:02}')
    print(f'\tTrain Loss: {train_loss:.3f}')
    print(f'\tValid Loss: {valid_loss:.3f}')

    if valid_loss < best_valid_loss:
        # New best: checkpoint the weights and reset the patience counter.
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'best-model.pt')
        no_improvement = 0
    else:
        no_improvement += 1
        # The counter only grows on this branch, so the early-stopping check
        # is needed only here.
        if no_improvement >= patience:
            print("Early stopping triggered")
            break