# Encoder-Decoder Models

The following tutorial explores basic encoder-decoder models and walks through building one for machine translation. For more foundational understanding, see Chapter 2 of *Large Language Models: A Deep Dive*. You can find it for under $15 here: [purchase](https://link.springer.com/book/10.1007/978-3-031-65647-7)

### How Encoder-Decoder Models Work

Encoder-decoder models are designed to handle **variable-length inputs and outputs**, a challenge in many natural language processing tasks. Unlike standard neural networks with fixed input/output sizes (e.g., classifiers), encoder-decoder models dynamically generate sequences of arbitrary length by learning token-by-token outputs.

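The **encoder** transforms an input sequence into a vector representation called a **context vector**. In the simplest models this is a single fixed-size vector; with attention (used later in this tutorial), the decoder can also draw on the encoder's hidden state at every input position. During training, the model learns to encode meaningful features of the input sequence into this representation.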

The **decoder** takes the context vector and generates the output sequence, one token at a time, typically using **autoregression** (feeding its previous output back in as input). This process relies on a hidden state that evolves as more tokens are generated.

The architecture involves two sets of hidden states:
- One in the encoder, which processes the input tokens.
- One in the decoder, which generates output tokens based on both the encoder's context and previously generated tokens.

The decoder's predictive capability comes from assigning a probability to every candidate next token and then selecting one. Schematically:

$P(y'_t \mid y'_1, \dots, y'_{t-1}, c) = \text{softmax}(s_{t-1}, y'_{t-1}, c)$

where the right-hand side is shorthand for a softmax over scores that the decoder computes from:
- $s_{t-1}$, the decoder's previous hidden state
- $y'_{t-1}$, the previously generated token
- $c$, the context vector from the encoder

**Softmax** transforms a vector of logits (raw model scores) into probabilities that sum to 1, enabling the model to sample or choose the most likely next token.
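
As a small illustration, the snippet below applies softmax to made-up logits for a hypothetical four-token vocabulary, then either picks the most likely token (`argmax`, i.e. greedy selection) or samples one with `torch.multinomial`:

```python
import torch
import torch.nn.functional as F

# Made-up logits for a hypothetical 4-token vocabulary at one decoding step
logits = torch.tensor([2.0, 1.0, 0.1, -1.0])

probs = F.softmax(logits, dim=0)   # non-negative values that sum to 1
greedy_token = probs.argmax().item()                            # choose the most likely token
sampled_token = torch.multinomial(probs, num_samples=1).item()  # or sample one

print(probs)
print(probs.sum())     # 1, up to floating-point rounding
print(greedy_token)    # 0, the position of the largest logit
print(sampled_token)   # usually 0, but any index can be drawn
```

Softmax preserves the ordering of the logits, so `argmax` over the probabilities picks the same token as `argmax` over the raw logits.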

### What This Means / TL;DR

In short: you can train a model to learn sequence probabilities so that it can output sequences of varying length. Here are a few examples of what you could do with such a network:

- Translate one language to another (a very common use case)
- Extract data from a sentence, such as tagging named entities (e.g., people, organizations, locations)
- Predict time series data

The crux of an encoder-decoder model is this: given one sequence, can we predict another sequence that depends on the data in it?

### What this Tutorial Does

In this tutorial we'll build an English-to-French translation model. We'll start by downloading a small subset of the `opus_books` English-French dataset from Hugging Face; with this little data, don't expect accurate translations. Remember to write the code out yourself, line by line.

Good luck!

# Perform Standard Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import re

import random
import math
import time
from collections import Counter

# Seed torch and random
SEED = 1234
random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# Import Hugging Face Datasets
try:
    from datasets import load_dataset
    HF_DATASETS_AVAILABLE = True
except ImportError:
    HF_DATASETS_AVAILABLE = False
    print("Hugging Face datasets library not found. Please install it: pip install datasets")
    print("Falling back to the toy dataset.")

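# Prefer CUDA, then Apple-silicon MPS, then CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')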
print(f'Using device {device}')

# Toy Dataset Download

In [None]:
# Define the tags for our source (SRC) and target (TGT) sequences
SRC_LANGUAGE = 'en'
TGT_LANGUAGE = 'fr'

MAX_DATASET_SIZE_TRAIN = 10000 # Keeping it small for quick demo
MAX_DATASET_SIZE_VALID = 1000

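# Start with empty lists so later cells still run if the download fails
train_data_list, valid_data_list = [], []

# You don't need to study or recreate this data-loading cell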
if HF_DATASETS_AVAILABLE:
    try:
        print("Attempting to load 'opus_books' dataset for en-fr...")
        hf_dataset = load_dataset("opus_books", "en-fr", split='train')
        print(f"Successfully loaded 'opus_books' dataset. Original size: {len(hf_dataset)}")

        # Define the indices where we'll cut out the train and valid data
        train_end_idx = 0
        if MAX_DATASET_SIZE_TRAIN is not None:
            train_end_idx = min(MAX_DATASET_SIZE_TRAIN, len(hf_dataset))
            dataset_subset_train = hf_dataset.select(range(train_end_idx))
        else:
            dataset_subset_train = hf_dataset

        # Validation starts where training ends
        valid_start_idx = train_end_idx
        valid_end_idx = valid_start_idx
        if MAX_DATASET_SIZE_VALID is not None:
            valid_end_idx = min(valid_start_idx + MAX_DATASET_SIZE_VALID, len(hf_dataset))
            if valid_start_idx < valid_end_idx: # Ensure there's data left for validation
                 dataset_subset_valid = hf_dataset.select(range(valid_start_idx, valid_end_idx))
            else:
                dataset_subset_valid = None
                print("Not enough data for a separate validation set with current MAX_DATASET_SIZE settings.")
        else:
            dataset_subset_valid = None

        # Process training data
        train_data_list = []
        for example in dataset_subset_train:
            src_text = example['translation'][SRC_LANGUAGE]
            tgt_text = example['translation'][TGT_LANGUAGE]
            if src_text and tgt_text:
                train_data_list.append((src_text, tgt_text))

        # Process validation data
        valid_data_list = []
        if dataset_subset_valid:
            for example in dataset_subset_valid:
                src_text = example['translation'][SRC_LANGUAGE]
                tgt_text = example['translation'][TGT_LANGUAGE]
                if src_text and tgt_text:
                    valid_data_list.append((src_text, tgt_text))

        # We apply some randomness to the data
        random.shuffle(train_data_list)

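    except Exception as e:
        print(f"Error loading dataset from Hugging Face: {e}")
        print("Falling back to the toy dataset.")
        HF_DATASETS_AVAILABLE = False

# Fallback: a tiny hand-written English-French toy dataset, so the rest of the
# notebook still runs without the download. It is far too small to learn translation.
if not HF_DATASETS_AVAILABLE or not train_data_list:
    toy_pairs = [
        ("the book", "le livre"),
        ("the cat", "le chat"),
        ("the dog", "le chien"),
        ("the house", "la maison"),
        ("i am happy", "je suis heureux"),
        ("the cat is black", "le chat est noir"),
        ("the dog is small", "le chien est petit"),
        ("the book is red", "le livre est rouge"),
        ("good morning", "bonjour"),
        ("thank you", "merci"),
        ("the house is big", "la maison est grande"),
        ("the dog is black", "le chien est noir"),
    ]
    train_data_list = toy_pairs[:-2]   # first 10 pairs for training
    valid_data_list = toy_pairs[-2:]   # last 2 pairs for validation

print(f"Total training examples: {len(train_data_list)}")
print(f"Total validation examples: {len(valid_data_list)}")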

if train_data_list:
    print("\nSample training data point:")
    print(f"Source: {train_data_list[0][0]}")
    print(f"Target: {train_data_list[0][1]}")
else:
    print("Error: No training data available!")

# Tokenizer

What is a tokenizer? A tokenizer determines how your input text is turned into numbers that a machine learning model can work with. One way to tokenize a sentence is to simply assign a unique number to each individual word. You could do this by counting up from 0 every time you encounter a new word and storing these values in a dictionary.

Other, more complex and better-performing tokenizers exist. Some reduce words to their 'lemmas' (dictionary base forms) or 'stems' (truncated base strings) to avoid producing overly large vocabularies. However, such tokenizers are becoming less common as models grow larger; modern large language models typically use subword tokenizers such as byte-pair encoding instead.

We build a simple tokenizer to introduce the concept. It does not reduce inputs to lemmas or stems. Instead, each unique word is given a numeric index that counts upward as new words are added. The same tokenizer is also used to decode the model's output back into French.
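
The cleaning step matters for French. A character class such as `[^a-z\s']` keeps only ASCII letters, so it deletes accented characters, while in Python 3 `\w` matches Unicode letters and keeps them. Splitting on a single space can also create empty "words" when two spaces end up next to each other, whereas `split()` with no argument does not. A quick check with an illustrative sentence:

```python
import re

sentence = "L'été est là."

# ASCII-only class: removes é and à along with the punctuation
ascii_only = re.sub(r"[^a-z\s']", "", sentence.lower())
# \w matches Unicode letters in Python 3, so accented characters survive
unicode_aware = re.sub(r"[^\w\s']", "", sentence.lower())

print(ascii_only)              # l't est l
print(unicode_aware)           # l'été est là
print("the  book".split(" "))  # ['the', '', 'book'] (an empty "word")
print("the  book".split())     # ['the', 'book']
```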

### Token Types

It is important to note *how* we deal with variable-length sequences in this model. We already described how encoder-decoder models can handle variable-length outputs, making them excellent for translation. However, we still need a way to signal the start and end of a sentence, and a way to fit sequences of different lengths into fixed-shape training batches.

We do this with four special tokens:

- PAD: the `<pad>` token is appended to shorter sequences so that every sequence in a batch has the same length.
- SOS: the start-of-sentence `<sos>` token signals that a sequence has started.
- EOS: the end-of-sentence `<eos>` token signals that a sequence has ended.
- UNK: the unknown `<unk>` token replaces any word that is not in the tokenizer's vocabulary.

Because every sequence in a batch must have the same length, the tokenizer reserves a padding token, which we'll add to the training data later on. For now, our tokenizer should support these four special tokens alongside all the words in our training dataset.
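
To make this concrete, here is a pure-Python sketch with a hypothetical eight-entry vocabulary. The special tokens take indices 0-3, in the same order the tokenizer below registers them. Each sentence is wrapped in `<sos>`/`<eos>`, unknown words become `<unk>`, and shorter sequences are right-padded with `<pad>` up to the longest one in the batch:

```python
# Hypothetical toy vocabulary; special tokens occupy indices 0-3
PAD, SOS, EOS, UNK = 0, 1, 2, 3
vocab = {"<pad>": PAD, "<sos>": SOS, "<eos>": EOS, "<unk>": UNK,
         "the": 4, "book": 5, "is": 6, "red": 7}

def encode(sentence):
    """Wrap a sentence in <sos>/<eos> and map unknown words to <unk>."""
    words = sentence.lower().split()
    return [SOS] + [vocab.get(w, UNK) for w in words] + [EOS]

def pad_batch(sequences):
    """Right-pad every sequence with <pad> up to the longest one in the batch."""
    max_len = max(len(s) for s in sequences)
    return [s + [PAD] * (max_len - len(s)) for s in sequences]

batch = [encode("the book"), encode("the book is red"), encode("the cat")]
for row in pad_batch(batch):
    print(row)
# [1, 4, 5, 2, 0, 0]
# [1, 4, 5, 6, 7, 2]
# [1, 4, 3, 2, 0, 0]   <- "cat" is not in the vocabulary, so it becomes <unk> (3)
```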

### TL;DR

The tokenizer converts words to numbers, and we have some special tokens to handle padding (ensuring all inputs are the same length), start and ends of sentences (so that when we decode, we know when to stop), and an unknown token for any out-of-vocabulary inputs that are provided to our model.

In [None]:
# Define the special tokens
PAD_TOKEN = "<pad>"
SOS_TOKEN = "<sos>"
EOS_TOKEN = "<eos>"
UNK_TOKEN = "<unk>"

class CustomTokenizer:
  def __init__(self, language_name):
    self.language_name = language_name

    # these two dictionaries are super important. We need them to get an index (token)
    # from a word (word2index) and we need to later turn the token index back into
    # a word again (index2word).
    self.word2index = {}
    self.index2word = {}
    self.n_count = 0
    self.word_counts = Counter()

    # Add the special tokens to the vocabulary first (indices 0-3). add_word returns
    # each token's index, which we keep as class attributes for later.
    self.PAD_IDX = self.add_word(PAD_TOKEN)
    self.SOS_IDX = self.add_word(SOS_TOKEN)
    self.EOS_IDX = self.add_word(EOS_TOKEN)
    self.UNK_IDX = self.add_word(UNK_TOKEN)

  # a linear tokenizer (count -> index)
  def add_word(self, word):
    if word not in self.word2index:
      self.word2index[word] = self.n_count
      self.index2word[self.n_count] = word
      self.n_count += 1
    return self.word2index[word]

  # Clean a sentence and split it into words. \w matches Unicode letters, so French
  # accented characters (é, è, à, ç) are kept; split() also avoids empty "words".
  @staticmethod
  def clean(sentence):
    return re.sub(r"[^\w\s']", '', sentence.lower()).split()

  def add_sentence(self, sentence):
    for word in self.clean(sentence):
      self.word_counts[word] += 1

  def build_vocab(self, sentences):
    # Build up a count for each word
    for sentence in sentences:
      self.add_sentence(sentence)

    # Add each unique key (word) to the word2index / index2word dicts
    for word in sorted(self.word_counts.keys()):
      self.add_word(word)

  # our outward-facing methods, converting sentences to indices and back
  def sentence_to_indices(self, sentence):
    tokens = [SOS_TOKEN] + self.clean(sentence) + [EOS_TOKEN]
    return [self.word2index.get(token, self.UNK_IDX) for token in tokens]

  def indices_to_sentence(self, indices):
    if hasattr(indices, 'tolist'):
      indices = indices.tolist()
    return ' '.join(self.index2word.get(index, UNK_TOKEN) for index in indices
                    if index not in [self.SOS_IDX, self.EOS_IDX, self.PAD_IDX])


Create the tokenizers and input vocabularies

In [None]:
src_tokenizer = CustomTokenizer(SRC_LANGUAGE)
tgt_tokenizer = CustomTokenizer(TGT_LANGUAGE)

# Print a few pairs
for pair in train_data_list[:3]:
  print(pair[0])
  print(pair[1])

src_sentences = [pair[0] for pair in train_data_list]
tgt_sentences = [pair[1] for pair in train_data_list]

src_tokenizer.build_vocab(src_sentences)
tgt_tokenizer.build_vocab(tgt_sentences)

Test their behaviour

In [None]:
# Vocabulary (first 20 entries only; the full dictionaries are large)
print("\nSource Vocabulary (EN):")
print(list(src_tokenizer.word2index.items())[:20])
print(f"PAD_IDX: {src_tokenizer.PAD_IDX}, SOS_IDX: {src_tokenizer.SOS_IDX}, "
      f"EOS_IDX: {src_tokenizer.EOS_IDX}, UNK_IDX: {src_tokenizer.UNK_IDX}")

print("\nTarget Vocabulary (FR):")
print(list(tgt_tokenizer.word2index.items())[:20])
print(f"PAD_IDX: {tgt_tokenizer.PAD_IDX}, SOS_IDX: {tgt_tokenizer.SOS_IDX}, "
      f"EOS_IDX: {tgt_tokenizer.EOS_IDX}, UNK_IDX: {tgt_tokenizer.UNK_IDX}")

# Test the tokenizer
test_src_sent = "the book"
test_src_indices = src_tokenizer.sentence_to_indices(test_src_sent)
print(f"\n'{test_src_sent}' -> {test_src_indices}")
print(f"'{test_src_indices}' -> '{src_tokenizer.indices_to_sentence(test_src_indices)}'\n")

test_tgt_sent = "le livre"
test_tgt_indices = tgt_tokenizer.sentence_to_indices(test_tgt_sent)
print(f"'{test_tgt_sent}' -> {test_tgt_indices}")
print(f"'{test_tgt_indices}' -> '{tgt_tokenizer.indices_to_sentence(test_tgt_indices)}'")

### Padding
As discussed above, we need to pad the training data to ensure each entry is of equal length. We do this with the collate_fn function, which we then apply to our data in the get_data_iterator function below that.
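
`collate_fn` relies on `torch.nn.utils.rnn.pad_sequence`. A quick look at what it returns, using made-up index sequences, shows why `batch_first=False` yields tensors of shape `[max_len, batch_size]`, the time-major layout that `nn.GRU` expects by default:

```python
import torch
import torch.nn as nn

# Three already-tokenized sequences of different lengths (made-up indices)
seqs = [torch.tensor([1, 4, 5, 2]),
        torch.tensor([1, 4, 5, 6, 7, 2]),
        torch.tensor([1, 4, 2])]

time_major = nn.utils.rnn.pad_sequence(seqs, batch_first=False, padding_value=0)
batch_major = nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=0)

print(time_major.shape)   # torch.Size([6, 3]) -> [max_len, batch]
print(batch_major.shape)  # torch.Size([3, 6]) -> [batch, max_len]
print(time_major[:, 2])   # third sequence, padded: tensor([1, 4, 2, 0, 0, 0])
```

Because of this layout, the data-iterator test cell below reads the first example with `src_batch[:, 0]` rather than `src_batch[0]`.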

In [None]:
def collate_fn(batch, src_tokenizer, tgt_tokenizer, device):
  src_batch, tgt_batch = [], []
  src_lens, tgt_lens = [], []
  for src_sample, tgt_sample in batch:
    src_indices = src_tokenizer.sentence_to_indices(src_sample)
    tgt_indices = tgt_tokenizer.sentence_to_indices(tgt_sample)

    # turn this sample's index list into a tensor
    # (the multi-dimensional array type PyTorch uses for training / inference)
    src_batch.append(torch.tensor(src_indices, dtype=torch.long))
    tgt_batch.append(torch.tensor(tgt_indices, dtype=torch.long))

    src_lens.append(len(src_indices))
    tgt_lens.append(len(tgt_indices))

  # Pad the tensors with a utility from torch.nn, filling with each tokenizer's PAD_IDX.
  # batch_first=False produces time-major tensors of shape [max_len, batch_size].
  # Many PyTorch code bases put the batch dimension B first instead, but nn.GRU
  # expects [seq_len, batch, features] by default, so we keep the sequence dimension first.
  src_padded = nn.utils.rnn.pad_sequence(src_batch, padding_value=src_tokenizer.PAD_IDX, batch_first=False)
  tgt_padded = nn.utils.rnn.pad_sequence(tgt_batch, padding_value=tgt_tokenizer.PAD_IDX, batch_first=False)
  src_lens = torch.tensor(src_lens)
  tgt_lens = torch.tensor(tgt_lens)

  # return on the computing device. Tensors must generally be moved onto the hardware first
  return src_padded.to(device), tgt_padded.to(device), src_lens.to(device), tgt_lens.to(device)

Next, we create a simple data loader: a generator function that hands out the training or validation data one batch at a time during the training loop. Note the `yield` statement: each time the training loop asks for the next batch, the function resumes, builds one padded batch with `collate_fn`, and returns it.
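
Stripped of tokenization and padding, the `yield` pattern looks like this (an illustrative sketch with a hypothetical `batches` helper). A generator is consumed after one pass, which is why the training loop builds fresh iterators every epoch, and the last batch can be smaller than `batch_size`:

```python
def batches(data, batch_size):
    """Yield consecutive slices of `data`; the final slice may be shorter."""
    for i in range(0, len(data), batch_size):
        yield data[i:i + batch_size]

gen = batches(list(range(10)), batch_size=4)
print(next(gen))   # [0, 1, 2, 3]
print(list(gen))   # [[4, 5, 6, 7], [8, 9]]
print(list(gen))   # []: the generator is already exhausted
```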

In [None]:
BATCH_SIZE = 64
def get_data_iterator(data, src_tokenizer, tgt_tokenizer, batch_size, device, shuffle=True):
  if shuffle:
    data_copy = list(data)
    random.shuffle(data_copy)
  else:
    data_copy = data

  for i in range(0, len(data_copy), batch_size):
    batch = data_copy[i:i+batch_size]
    yield collate_fn(batch, src_tokenizer, tgt_tokenizer, device)

In [None]:
print("\nTesting data iterator:")
data_iter = get_data_iterator(train_data_list, src_tokenizer, tgt_tokenizer, BATCH_SIZE, device)
for i, (src_batch, tgt_batch, src_lens, tgt_lens) in enumerate(data_iter):
  print(f"Batch {i+1}:")
  print("Source batch shape: ", src_batch.shape)
  print("target batch shape: ", tgt_batch.shape)
  print("Source lengths: ", src_lens)
  print("Target lengths: ", tgt_lens)
  print("Source batch (first example):\n", src_batch[:, 0])
  print("Target batch (first example):\n", tgt_batch[:, 0])
  if i == 0: break

# Building the Model Components

The following four code blocks build the components that make up this model. We start by defining a standard Encoder block, which uses `nn.GRU` to compute the hidden states that summarize the input sequence. The gated recurrent unit (GRU) involves some fairly complex mathematics, which we'll implement manually at a later point. You can read more about it here: https://en.wikipedia.org/wiki/Gated_recurrent_unit

The second component is the attention mechanism. A basic encoder-decoder (sometimes called Seq2Seq) model relies only on the encoder's context vector and the decoder's own hidden states, so the entire input has to be squeezed into one fixed-size vector. We can improve the model's accuracy by adding attention. At every decoding step, attention scores how well each encoder output aligns with the current decoder hidden state, turns those scores into weights, and uses them to form a weighted sum of the encoder outputs. This step-specific context vector is then fed into the decoder. More information is provided below.

Lastly, the decoder generates the output sequence step by step using this attention-weighted context. These three components make up the Seq2Seq class, which is the final code block. Pay close attention to the tensor shapes: how do they change as data passes through each component? Comments in the code note the shapes at the key steps.

### Encoder

We'll write the encoder using an `nn.GRU` unit. The encoder computes hidden states over time using:

- Hidden states $h_t$ are computed as $h_t = f(h_{t-1}, x_t)$.
- The context vector $c$ is derived from the hidden states: either as the final state $h_T$, or as a function $c = m(h_1, h_2, ..., h_T)$.
- Encoders may be bidirectional, meaning $h_t$ combines forward and backward passes: information from both past ($h_{t-1}$) and future ($h_{t+1}$).
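
To see what these hidden states look like in PyTorch, the sketch below runs a small bidirectional, two-layer `nn.GRU` (configured like the encoder below, but with illustrative sizes) over random token indices. `outputs` holds $h_t$ for every step, with the forward and backward directions concatenated, and `hidden` holds the final state of each layer and direction:

```python
import torch
import torch.nn as nn

# Illustrative sizes: vocab 20, embedding 8, hidden 16, 2 layers
vocab_size, emb_dim, hid_dim, n_layers = 20, 8, 16, 2
src_len, batch_size = 7, 3

embedding = nn.Embedding(vocab_size, emb_dim)
rnn = nn.GRU(emb_dim, hid_dim, n_layers, bidirectional=True)  # time-major by default

src = torch.randint(0, vocab_size, (src_len, batch_size))     # [src_len, B]
outputs, hidden = rnn(embedding(src))

print(outputs.shape)  # torch.Size([7, 3, 32]) -> [src_len, B, 2 * H]
print(hidden.shape)   # torch.Size([4, 3, 16]) -> [n_layers * 2, B, H]
```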

### How Does the GRU Work?

The Gated Recurrent Unit (GRU) uses gating mechanisms to control how information flows through its hidden state:

- **Update gate ($z_t$):** controls how much of the past state $h_{t-1}$ to retain.
- **Reset gate ($r_t$):** determines how much of the past to forget when generating new candidate states.
- **Candidate hidden state ($\tilde{h}_t$):** proposed new hidden state, computed using $r_t$ and a tanh activation.
- **Final hidden state ($h_t$):** a weighted combination of $\tilde{h}_t$ and $h_{t-1}$, using $z_t$.

Each gate produces an output of shape $[B, H]$ from two inputs: the embedding $x_t$, of shape $[B, E]$, and the previous hidden state $h_{t-1}$, of shape $[B, H]$. These mechanisms let the model retain or forget information dynamically over time, which helps it learn from long sequences.
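
The gate equations can be checked directly against PyTorch. The sketch below recomputes one `nn.GRUCell` step by hand, using PyTorch's documented GRU formulas and weight layout (the three gates' weights stacked in the order reset, update, candidate), and compares the result with the built-in cell. Sizes are illustrative:

```python
import torch
import torch.nn as nn

E, H, B = 4, 3, 2                 # embedding size, hidden size, batch size
cell = nn.GRUCell(E, H)           # reference implementation
x_t = torch.randn(B, E)           # input embedding x_t: [B, E]
h_prev = torch.randn(B, H)        # previous hidden state h_{t-1}: [B, H]

# PyTorch stacks each gate's weights as [reset | update | candidate]
W_ir, W_iz, W_in = cell.weight_ih.chunk(3, dim=0)
W_hr, W_hz, W_hn = cell.weight_hh.chunk(3, dim=0)
b_ir, b_iz, b_in = cell.bias_ih.chunk(3)
b_hr, b_hz, b_hn = cell.bias_hh.chunk(3)

r_t = torch.sigmoid(x_t @ W_ir.T + b_ir + h_prev @ W_hr.T + b_hr)   # reset gate
z_t = torch.sigmoid(x_t @ W_iz.T + b_iz + h_prev @ W_hz.T + b_hz)   # update gate
h_tilde = torch.tanh(x_t @ W_in.T + b_in + r_t * (h_prev @ W_hn.T + b_hn))  # candidate state
h_t = (1 - z_t) * h_tilde + z_t * h_prev   # z_t near 1 keeps more of h_{t-1}

print(torch.allclose(h_t, cell(x_t, h_prev), atol=1e-6))  # True
```

In PyTorch's formulation the update gate weights the *old* state, $h_t = (1 - z_t) \odot \tilde{h}_t + z_t \odot h_{t-1}$, so values of $z_t$ close to 1 retain more of the past.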

### TL;DR

GRUs are a compact and efficient way to learn how input embeddings evolve into hidden representations over time. These hidden states then feed into the next stage: the attention-based decoder.

In [None]:
class Encoder(nn.Module):
  def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout_p):
    super().__init__()

    # Size of the GRU hidden state. This is arbitrary: make it as big or small as you like (but test the results)
    self.hid_dim = hid_dim
    # Number of stacked GRU layers, also arbitrary (the hyperparameters below use 2)
    self.n_layers = n_layers

    # Here is the interesting bit: input_dim is the size of the source vocabulary,
    # and emb_dim is the size of each learned word embedding, which you can tune.
    # The embedding turns every token index into a vector, so the GRU works at a token level.
    self.embedding = nn.Embedding(input_dim, emb_dim)
    self.rnn = nn.GRU(emb_dim,
                      hid_dim,
                      n_layers,
                      bidirectional=True,
                      dropout=dropout_p if n_layers > 1 else 0)
    self.dropout = nn.Dropout(dropout_p)

  def forward(self, src_seq):
    # src_seq = [src_len, batch_size] -> embedded = [src_len, batch_size, emb_dim]
    embedded = self.dropout(self.embedding(src_seq))

    # outputs = [src_len, batch_size, hid_dim * 2]: the hidden state at every token step,
    #   with the forward and backward directions concatenated.
    # hidden = [n_layers * 2, batch_size, hid_dim]: the final state of each layer and direction.
    outputs, hidden = self.rnn(embedded)
    return outputs, hidden

### Attention Module

We implement an additive attention mechanism that computes attention weights from the relationship between the current decoder state and each of the encoder's hidden states. These weights determine how much the decoder should focus on each part of the input sequence at each timestep.

The resulting attention context vector is a weighted sum of the encoder outputs. In the decoder below, it is concatenated with the embedding of the current input token before entering the GRU, and concatenated again with the GRU output before the final prediction layer. This lets the model draw on context from the entire input sequence at every decoding step.
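
Written out, the additive (Bahdanau-style) attention implemented in the `Attention` class computes, for decoder step $t$ and each of the $T$ source positions $i$:

$$
\begin{aligned}
e_{t,i} &= v^\top \tanh\big(W_a [\, s_{t-1} ; h_i \,] + b_a\big), \qquad i = 1, \dots, T \\
\alpha_{t,i} &= \frac{\exp(e_{t,i})}{\sum_{j=1}^{T} \exp(e_{t,j})} \\
c_t &= \sum_{i=1}^{T} \alpha_{t,i}\, h_i
\end{aligned}
$$

Here $[\,\cdot\,;\,\cdot\,]$ is concatenation (the `torch.cat` in the code), $W_a$ and $b_a$ are the weight and bias of `self.attn`, $v$ is `self.v`, $h_i$ is the encoder output at position $i$, and $s_{t-1}$ is the top layer of the previous decoder hidden state. Unlike the single context vector $c$ above, $c_t$ is recomputed at every decoding step; the decoder forms the weighted sum with `torch.bmm`.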

In [None]:
class Attention(nn.Module):
  def __init__(self, enc_hid_dim, dec_hid_dim):
    super().__init__()
    # enc_hid_dim is the effective encoder output size, e.g. hid_dim * 2 for a bidirectional encoder
    # dec_hid_dim is the decoder hidden size
    self.attn = nn.Linear(enc_hid_dim + dec_hid_dim, dec_hid_dim)
    self.v = nn.Linear(dec_hid_dim, 1, bias=False)

  def forward(self, decoder_hidden_top_layer, encoder_outputs):
    # decoder_hidden_top_layer = [batch_size, dec_hid_dim] (top layer, current step only)
    # encoder_outputs = [src_len, batch_size, enc_hid_dim]
    src_len = encoder_outputs.shape[0]
    # Repeat the decoder state once per source position: [batch_size, src_len, dec_hid_dim]
    hidden_repeated = decoder_hidden_top_layer.unsqueeze(1).repeat(1, src_len, 1)
    # Make the encoder outputs batch-first: [batch_size, src_len, enc_hid_dim]
    encoder_outputs_permuted = encoder_outputs.permute(1, 0, 2)
    # energy = [batch_size, src_len, dec_hid_dim]
    energy = torch.tanh(self.attn(torch.cat((hidden_repeated, encoder_outputs_permuted), dim=2)))
    # self.v gives [batch_size, src_len, 1]; squeeze(2) just drops the trailing size-1 dim
    attention_scores = self.v(energy).squeeze(2)
    # Normalise over source positions so each row sums to 1: [batch_size, src_len]
    # (padding positions are not masked here, so they receive some weight too)
    return F.softmax(attention_scores, dim=1)

### Decoder

The decoder takes the context vector from the encoder and generates its own hidden states to produce the output sequence. Each decoder hidden state depends not only on the previous decoder state but also on the previously generated token and the context vector:

- $s_{t'} = g(s_{t'-1}, y_{t'-1}, c)$
- The output at timestep $t'$, $y_{t'}$, is computed as a probability distribution:  
  $P(y_{t'} \mid y_1, \dots, y_{t'-1}, c) = \text{softmax}(s_{t'-1}, y_{t'-1}, c)$

This means that each predicted token is influenced by all previous tokens as well as the encoder's context, allowing the model to maintain sequential coherence.

When using attention, the decoder accesses all encoder hidden states at every timestep. This enables the model to dynamically focus on relevant parts of the input sequence, providing richer contextual understanding and improving output quality during training and inference.

In [None]:
class Decoder(nn.Module):
  def __init__(self, output_dim, emb_dim, dec_hid_dim, n_layers, dropout_p, enc_hid_dim):
    super().__init__()
    self.output_dim = output_dim
    self.hid_dim = dec_hid_dim
    self.n_layers = n_layers
    self.enc_hid_dim = enc_hid_dim # This is HID_DIM * 2

    # Components
    self.embedding = nn.Embedding(output_dim, emb_dim)
    self.attention = Attention(self.enc_hid_dim, self.hid_dim)
    rnn_input_size = emb_dim + self.enc_hid_dim
    self.rnn = nn.GRU(rnn_input_size,
                          dec_hid_dim,
                          n_layers,
                          dropout=dropout_p if n_layers > 1 else 0)

    self.fc_out = nn.Linear(dec_hid_dim + enc_hid_dim + emb_dim, output_dim)
    self.dropout = nn.Dropout(dropout_p)

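  def forward(self, input_token, hidden_state, encoder_outputs):
    # input_token = [batch_size] -> [1, batch_size]: a single time step
    input_token = input_token.unsqueeze(0)
    # embedded = [1, batch_size, emb_dim]
    embedded = self.dropout(self.embedding(input_token))

    # Attend using the top layer of the previous decoder hidden state
    decoder_top_hidden = hidden_state[-1]                    # [batch_size, dec_hid_dim]
    a = self.attention(decoder_top_hidden, encoder_outputs)  # [batch_size, src_len]
    a = a.unsqueeze(1)                                       # [batch_size, 1, src_len]

    # Weighted sum of the encoder outputs = attention context vector
    encoder_outputs_permutated = encoder_outputs.permute(1, 0, 2)  # [batch_size, src_len, enc_hid_dim]
    context = torch.bmm(a, encoder_outputs_permutated)             # [batch_size, 1, enc_hid_dim]
    context = context.permute(1, 0, 2)                             # [1, batch_size, enc_hid_dim]

    # GRU input is emb_dim + enc_hid_dim wide, matching rnn_input_size in __init__
    rnn_input = torch.cat((embedded, context), dim=2)
    output, new_hidden_state = self.rnn(rnn_input, hidden_state)  # output = [1, batch_size, dec_hid_dim]

    # Recombine decoder output, context and embedding to predict logits over the target vocabulary
    combined = torch.cat((output.squeeze(0), context.squeeze(0), embedded.squeeze(0)), dim=1)
    prediction = self.fc_out(combined)  # [batch_size, output_dim]

    return prediction, new_hidden_state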

### Seq2Seq Implementation

The `Seq2Seq` module manages the full encoder–decoder architecture. It contains:

- An **encoder** that processes the input sequence.
- A **decoder** that generates the output sequence.
- A **bridge** (a linear layer) that transforms the bidirectional hidden states from the encoder into a format suitable for initializing the decoder.

Since the encoder is bidirectional and the decoder is unidirectional, the hidden states must be reshaped and projected to match the decoder's expected dimensions. The bridge performs this transformation using concatenation and a learned linear projection.
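
The reshaping relies on how PyTorch lays out the final hidden state of a bidirectional GRU: `[n_layers * 2, batch_size, hid_dim]`, ordered layer by layer with the forward direction before the backward one, so `.view(n_layers, 2, batch_size, hid_dim)` separates the two directions (the same `view` used by `resize_bidirectional_hidden`). The sketch checks that layout against the per-step outputs, then applies a bridge-style projection; all sizes are illustrative:

```python
import torch
import torch.nn as nn

n_layers, batch_size, hid = 2, 3, 16
rnn = nn.GRU(8, hid, n_layers, bidirectional=True)
outputs, hidden = rnn(torch.randn(5, batch_size, 8))   # input = [src_len=5, B, 8]

# hidden is laid out as [layer0_fwd, layer0_bwd, layer1_fwd, layer1_bwd]
per_dir = hidden.view(n_layers, 2, batch_size, hid)
fwd_top, bwd_top = per_dir[-1, 0], per_dir[-1, 1]

# The forward direction finishes at the last step, the backward one at the first
print(torch.allclose(fwd_top, outputs[-1, :, :hid]))   # True
print(torch.allclose(bwd_top, outputs[0, :, hid:]))    # True

# Bridge: concatenate both directions per layer, then project back to H
bridge = nn.Linear(2 * hid, hid)
dec_init = torch.tanh(bridge(torch.cat((per_dir[:, 0], per_dir[:, 1]), dim=2)))
print(dec_init.shape)  # torch.Size([2, 3, 16]) -> [n_layers, B, H]
```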

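During decoding, the model generates tokens step by step. At each timestep, the decoder outputs scores (logits) over the target vocabulary; during training, the softmax is applied inside the cross-entropy loss, and the highest-scoring token (`argmax`) is the model's prediction. With probability `teacher_forcing_ratio`, the next decoder input is the ground-truth target token rather than this prediction (teacher forcing). The outputs for every timestep are collected and returned as the full output sequence.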
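
`Seq2Seq.forward` uses *teacher forcing*: after each step, with probability `teacher_forcing_ratio` the decoder is fed the true target token instead of its own prediction, which typically makes early training easier; `evaluate_epoch` passes a ratio of 0, so the model always uses its own predictions there. The sketch below isolates that selection rule with a hypothetical `choose_next_inputs` helper and fixed, made-up predictions (in the real model each prediction depends on the inputs fed earlier):

```python
import random

def choose_next_inputs(targets, predictions, teacher_forcing_ratio, rng):
    """Simplified version of the input-selection rule in Seq2Seq.forward.

    After step t, the next decoder input is the ground-truth token targets[t]
    with probability teacher_forcing_ratio, otherwise predictions[t].
    """
    inputs = []
    for t in range(1, len(targets)):
        teacher_force = rng.random() < teacher_forcing_ratio
        inputs.append(targets[t] if teacher_force else predictions[t])
    return inputs

rng = random.Random(0)       # local RNG, leaves the notebook's seeded state alone
gold = [1, 10, 11, 12, 2]    # <sos> w1 w2 w3 <eos> (made-up indices)
guess = [1, 10, 99, 12, 2]   # the model got w2 wrong; guess[0] is unused
print(choose_next_inputs(gold, guess, 1.0, rng))  # [10, 11, 12, 2]: always ground truth
print(choose_next_inputs(gold, guess, 0.0, rng))  # [10, 99, 12, 2]: always its own guesses
```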

In [None]:
def resize_bidirectional_hidden(n_layers, batch_size, hid, enc_hidden, bridge):
  enc_hidden = enc_hidden.view(n_layers, 2, batch_size, hid)
  cat = torch.cat((enc_hidden[:, 0, :, :], enc_hidden[:, 1, :, :]), dim=2)
  return torch.tanh(bridge(cat))

class Seq2Seq(nn.Module):
  def __init__(self, encoder, decoder, device):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder
    self.device = device
    enc_bridge_input_dim = self.encoder.hid_dim * 2
    self.bridge = nn.Linear(enc_bridge_input_dim, self.decoder.hid_dim)

    # Sanity check for dimensionality
    assert encoder.hid_dim == decoder.hid_dim, "hidden dims must be equal"
    assert encoder.n_layers == decoder.n_layers, "layers must be equal"

  def _init_decoder_hidden(self, enc_hidden):
    n_layers, batch_size, hid = self.encoder.n_layers, enc_hidden.size(1), self.encoder.hid_dim
    return resize_bidirectional_hidden(n_layers, batch_size, hid, enc_hidden, self.bridge)

  def forward(self, src_seq, tgt_seq, teacher_forcing_ratio=0.5):
    batch_size = src_seq.shape[1]
    # Encode
    enc_out, hidden = self.encoder(src_seq)
    hidden = self._init_decoder_hidden(hidden)

    # Decode
    tgt_len = tgt_seq.shape[0]
    batch_size = tgt_seq.shape[1]
    tgt_vocab_size = self.decoder.output_dim

    outputs = torch.zeros(tgt_len, batch_size, tgt_vocab_size).to(self.device)

    dec_in = tgt_seq[0, :]
    for t in range(1, tgt_len):
      dec_out, hidden = self.decoder(dec_in, hidden, enc_out)
      outputs[t] = dec_out
      teacher_force = random.random() < teacher_forcing_ratio
      top1 = dec_out.argmax(1)
      dec_in = tgt_seq[t] if teacher_force else top1

    return outputs

# Training

Now, we perform the standard training process. By now, this format should be familiar to you.
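
One detail worth checking before training: the loss below is built with `nn.CrossEntropyLoss(ignore_index=tgt_tokenizer.PAD_IDX)`, so padded target positions do not contribute to the loss, and the training loop reports perplexity as $e^{\text{loss}}$. A small sketch with made-up logits, assuming the pad index is 0 as in our tokenizer:

```python
import torch
import torch.nn as nn

PAD_IDX = 0                        # assumption: <pad> is index 0, as in our tokenizer
logits = torch.randn(4, 5)         # 4 target positions, hypothetical 5-word vocabulary
targets = torch.tensor([3, 1, PAD_IDX, PAD_IDX])   # the last two positions are padding

masked = nn.CrossEntropyLoss(ignore_index=PAD_IDX)(logits, targets)
real_only = nn.CrossEntropyLoss()(logits[:2], targets[:2])   # loss on the real tokens only

print(torch.allclose(masked, real_only))   # True: padded positions contribute nothing
print(torch.exp(masked))                   # perplexity = exp(mean cross-entropy)
```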

In [None]:
# hyperparams
INPUT_DIM = src_tokenizer.n_count
OUTPUT_DIM = tgt_tokenizer.n_count
ENC_EMB_DIM = 128
DEC_EMB_DIM = 128
HID_DIM = 128
EFFECTIVE_ENC_HID_DIM = HID_DIM * 2
N_LAYERS = 2
ENC_DROPOUT = 0.2
DEC_DROPOUT = 0.2
LEARNING_RATE = 0.001
N_EPOCHS = 5
CLIP = 1
BIDIR_MODEL_NAME = "language_enc_dec_bidir_attn.pt"
UNIDIR_MODEL_NAME = "language_enc_dec.pt"
MODEL_NAME = BIDIR_MODEL_NAME

In [None]:
# components
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT).to(device)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT, EFFECTIVE_ENC_HID_DIM).to(device)
model_bidir = Seq2Seq(enc, dec, device).to(device)  # .to() also moves the bridge layer onto the device

In [None]:
def count_parameters(model):
  return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model_bidir):,} trainable parameters')

In [None]:
# optim and learn
optimizer = optim.Adam(model_bidir.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss(ignore_index = tgt_tokenizer.PAD_IDX)

In [None]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [None]:
def train_epoch(model, iterator, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0
    batch_n = 0
    for i, (src, tgt, _, _) in enumerate(iterator): # src_lens, tgt_lens not used directly here
        batch_n += 1
        optimizer.zero_grad()
        # output = [tgt_len, batch_size, output_vocab_size]
        output = model(src, tgt)
        # get vocab length for next step
        output_dim = output.shape[-1]
        # drop the <sos> position with [1:] (outputs[0] is never filled by the model)
        # reshape [tgt_len, batch, vocab] into [(tgt_len - 1) * batch, vocab] so it fits the loss
        output_flat = output[1:].view(-1, output_dim)
        # targets are class indices, so they have no vocab dimension: [(tgt_len - 1) * batch]
        tgt_flat = tgt[1:].view(-1)
        # now that they are equal dim, compute the loss between out and tgt
        loss = criterion(output_flat, tgt_flat)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()

    return epoch_loss / batch_n

def evaluate_epoch(model, iterator, criterion):
  # the same setup, but with no gradient updates and teacher forcing turned off (ratio 0)
  model.eval()
  epoch_loss = 0
  batch_n = 0
  with torch.no_grad():
    for i, (src, tgt, _, _) in enumerate(iterator):
      batch_n += 1
      output = model(src, tgt, 0)
      output_dim = output.shape[-1]
      output_flat = output[1:].view(-1, output_dim)
      tgt_flat = tgt[1:].view(-1)
      loss = criterion(output_flat, tgt_flat)
      epoch_loss += loss.item()

  return epoch_loss / batch_n if batch_n > 0 else float('nan')  # guard against an empty validation set

### Train Execute

In [None]:
best_valid_loss = float('inf')

print("Starting training...")
for epoch in range(N_EPOCHS):
  start_time = time.time()
  train_iter = get_data_iterator(train_data_list,
                               src_tokenizer,
                               tgt_tokenizer,
                               BATCH_SIZE,
                               device,
                               shuffle=True)
  valid_iter = get_data_iterator(valid_data_list,
                               src_tokenizer,
                               tgt_tokenizer,
                               BATCH_SIZE,
                               device,
                               shuffle=False)
  train_loss = train_epoch(model_bidir, train_iter, optimizer, criterion, CLIP)
  valid_loss = evaluate_epoch(model_bidir, valid_iter, criterion)

  end_time = time.time()
  epoch_mins, epoch_secs = epoch_time(start_time, end_time)
  print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
  print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
  print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

torch.save(model_bidir.state_dict(), MODEL_NAME)

### Test the Model

In [None]:
model_bidir.load_state_dict(torch.load(MODEL_NAME, weights_only=True))
model_bidir.eval()

### Beam Search Decoding

The `beam_search_translate` function implements beam search for sequence generation in a trained Seq2Seq model. It performs the following steps:

- **Tokenization and Encoding**: The input sentence is tokenized and encoded using the source tokenizer and encoder. The encoder outputs and hidden state are used to initialize the decoder.

- **Beam Initialization**: A beam is initialized with the start-of-sequence (`<sos>`) token and the decoder's initial hidden state. The algorithm maintains two lists: `active_beams` (currently growing sequences) and `completed` (finished sequences ending in `<eos>`).

- **Step-by-Step Decoding**: At each timestep:
  - Active beams are expanded by generating the log-probabilities of the next token.
  - The top `k` candidates (beam width) are selected based on their cumulative scores.
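  - Early `<eos>` tokens are discouraged by masking them out until a minimum number of tokens (`min_target_len`) has been generated.

- **Length Penalty**: During beam ranking, each hypothesis's summed log-probability is divided by a length-dependent penalty. Every additional token adds a negative log-probability, so raw scores systematically favor short outputs; the normalization offsets this bias so that longer hypotheses are not disproportionately penalized.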

- **Completion and Output**: Once the search terminates (via `<eos>` or `max_len`), the completed sequences are scored and the best one is selected. Start (`<sos>`) and end (`<eos>`) tokens are removed, and the remaining token indices are converted back into a sentence using the target tokenizer.

Beam search improves decoding quality over greedy approaches by exploring multiple hypotheses at each timestep and selecting the best overall sequence based on both likelihood and sequence length.
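To see why keeping several hypotheses helps, here is a minimal, framework-free sketch. `TOY_MODEL`, `next_log_probs`, `greedy_decode`, and `beam_decode` are hypothetical names made up for illustration: the hand-written probability table stands in for the trained decoder, and the sketch leaves out the length penalty and `<eos>` masking. Greedy decoding commits to `A` because it is the most likely first token, but the most likely full sequence starts with `B`.

```python
import math

EOS = "<eos>"

# Hypothetical toy "decoder": next-token probabilities keyed by the prefix so far.
# Any prefix not listed here always ends the sentence.
TOY_MODEL = {
    (): {"A": 0.6, "B": 0.4},
    ("A",): {"x": 0.4, "y": 0.35, "z": 0.25},
    ("B",): {"x": 0.9, "y": 0.05, "z": 0.05},
}

def next_log_probs(prefix):
    """Return {token: log-probability} for the token following `prefix`."""
    probs = TOY_MODEL.get(tuple(prefix), {EOS: 1.0})
    return {tok: math.log(p) for tok, p in probs.items()}

def greedy_decode(max_len=5):
    """Always take the single most likely next token."""
    seq, score = [], 0.0
    for _ in range(max_len):
        tok, lp = max(next_log_probs(seq).items(), key=lambda kv: kv[1])
        score += lp
        if tok == EOS:
            break
        seq.append(tok)
    return seq, score

def beam_decode(beam_width=2, max_len=5):
    """Keep the `beam_width` best partial sequences by cumulative log-probability."""
    beams, completed = [([], 0.0)], []
    for _ in range(max_len):
        candidates = []
        for seq, score in beams:
            for tok, lp in next_log_probs(seq).items():
                if tok == EOS:
                    completed.append((seq, score + lp))  # finished hypothesis
                else:
                    candidates.append((seq + [tok], score + lp))
        if not candidates:
            beams = []
            break
        beams = sorted(candidates, key=lambda b: b[1], reverse=True)[:beam_width]
    completed.extend(beams)  # unfinished beams, if max_len was reached
    return max(completed, key=lambda b: b[1])

for name, (seq, score) in [("greedy", greedy_decode()), ("beam", beam_decode())]:
    print(f"{name}: {seq} p={math.exp(score):.2f}")
# greedy: ['A', 'x'] p=0.24
# beam: ['B', 'x'] p=0.36
```

With `beam_width=1` the beam sketch reduces to greedy decoding, which is a handy sanity check for any beam search implementation.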

In [None]:
class BeamHypothesis:
  def __init__(self, tokens, log_prob, hidden_state):
    self.tokens = tokens
    self.log_prob = log_prob
    self.hidden_state = hidden_state

  def extend(self, token_idx, log_prob_token, new_hidden_state):
    return BeamHypothesis(
        self.tokens + [token_idx],
        self.log_prob + log_prob_token,
        new_hidden_state
    )

  @property
  def latest_token(self):
    return self.tokens[-1]

  def __len__(self):
    return len(self.tokens)

  def __lt__(self, other):
    return self.log_prob < other.log_prob

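  def normalized_score(self, alpha=0.3):
    # GNMT-style length penalty (Wu et al., 2016): lp = ((5 + L) / 6) ** alpha.
    # Dividing the (negative) cumulative log-prob by lp > 1 shrinks its magnitude,
    # offsetting the bias toward short hypotheses; alpha=0 disables normalization.
    length = len(self.tokens)
    lp = ((5 + length) ** alpha) / ((5 + 1) ** alpha)
    return self.log_prob / lp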
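To make the effect of `alpha` concrete, here is a standalone sketch that reproduces the formula in `normalized_score` and compares a short hypothesis against a longer one with a lower raw log-probability. The helper names `gnmt_length_penalty` and `normalized` are hypothetical, and the log-probabilities are made-up numbers for illustration.

```python
def gnmt_length_penalty(length, alpha):
    """Same penalty as normalized_score: ((5 + L) ** alpha) / (6 ** alpha)."""
    return ((5 + length) ** alpha) / ((5 + 1) ** alpha)

def normalized(log_prob, length, alpha):
    """Length-normalized score used to rank hypotheses (higher is better)."""
    return log_prob / gnmt_length_penalty(length, alpha)

# Made-up hypotheses: (cumulative log-prob, number of tokens including <sos>)
short_hyp = (-2.0, 3)
long_hyp = (-2.4, 8)

for alpha in (0.0, 0.6):
    s, l = normalized(*short_hyp, alpha), normalized(*long_hyp, alpha)
    winner = "short" if s > l else "long"
    print(f"alpha={alpha}: short={s:.3f}, long={l:.3f} -> {winner} wins")
```

With `alpha=0.0` the penalty is 1 and the raw log-probabilities decide, so the short hypothesis wins; at `alpha=0.6` (the default `len_penalty_alpha` in `beam_search_translate`) the longer hypothesis overtakes it.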

Build the search function

In [None]:
import torch
import torch.nn.functional as F

def beam_search_translate(sentence,
                          src_tokenizer,
                          tgt_tokenizer,
                          model,
                          device,
                          beam_width=3,
                          max_len=50,
                          len_penalty_alpha=0.6,
                          min_target_len=4):
    model.eval()

    # 1) Tokenize & batchify
    if isinstance(sentence, str):
        src_indices = src_tokenizer.sentence_to_indices(sentence)
    else:
        src_indices = sentence
    src_tensor = torch.tensor(src_indices, dtype=torch.long) \
                     .unsqueeze(1).to(device)  # [src_len, 1]

    # 2) Encode
    with torch.no_grad():
        enc_outputs, hidden = model.encoder(src_tensor)
        decoder_hidden = model._init_decoder_hidden(hidden)

    # 3) Initialize beams
    initial_beam = BeamHypothesis(
        tokens=[tgt_tokenizer.SOS_IDX],
        log_prob=0.0,
        hidden_state=decoder_hidden
    )
    active_beams   = [initial_beam]
    completed      = []

    def clone_hidden(h):
        if isinstance(h, tuple):            # LSTM case
            return tuple(t.clone() for t in h)
        return h.clone()                    # GRU case

    # 4) Beam search
    for step in range(max_len):
        if not active_beams:
            break

        candidates = []
        next_beams = []

        # move any EOS-ended beams to completed
        for beam in active_beams:
            if beam.latest_token == tgt_tokenizer.EOS_IDX:
                completed.append(beam)
            else:
                next_beams.append(beam)
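        active_beams = next_beams
        # stop once no beams remain or enough hypotheses have finished with <eos>
        if not active_beams or len(completed) >= beam_width:
            break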

        # expand each active beam
        for beam in active_beams:
            inp = torch.tensor([beam.latest_token], dtype=torch.long).to(device)  # [batch=1]
            with torch.no_grad():
                output, new_hidden = model.decoder(inp, beam.hidden_state, enc_outputs)
            # output: [batch=1, vocab_size]
            logits    = output.squeeze(0)                     # → [vocab_size]
            log_probs = F.log_softmax(logits, dim=-1)

            # forbid <eos> too early
            if step < min_target_len:
                log_probs[tgt_tokenizer.EOS_IDX] = -float('inf')

            topk_lp, topk_idx = torch.topk(log_probs, beam_width)
            for lp, idx in zip(topk_lp.tolist(), topk_idx.tolist()):
                # extend() takes (token_idx, log_prob_token, new_hidden_state) positionally;
                # clone the hidden state so candidates never share tensors
                new_beam = beam.extend(idx, lp, clone_hidden(new_hidden))
                candidates.append(new_beam)

        if not candidates:
            break

        # keep best K hypotheses
        active_beams = sorted(
            candidates,
            key=lambda b: b.normalized_score(len_penalty_alpha),
            reverse=True
        )[:beam_width]

    # gather finals
    completed.extend(active_beams)
    if not completed:
        return "error"

    # pick best
    best = max(
        completed,
        key=lambda b: b.normalized_score(len_penalty_alpha)
    )

    # strip SOS/EOS
    toks = best.tokens
    if toks and toks[0] == tgt_tokenizer.SOS_IDX:
        toks = toks[1:]
    if toks and toks[-1] == tgt_tokenizer.EOS_IDX:
        toks = toks[:-1]

    return tgt_tokenizer.indices_to_sentence(toks)

### BLEU Score Overview

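This implementation calculates the BLEU score, which measures how closely a generated sentence matches one or more reference sentences. You don't need to implement it yourself, since libraries such as NLTK (`nltk.translate.bleu_score`) and sacreBLEU provide tested implementations; I kept it here for interest's sake. Note that the evaluation below averages sentence-level scores, which is not the same as the corpus-level BLEU usually reported in papers.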

#### Core Concepts

- **N-grams**: Sub-sequences of `n` words used to compare local word patterns.
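- **Modified Precision**: The fraction of candidate n-grams that also appear in a reference, where each n-gram's count is clipped to the maximum number of times it occurs in any single reference, so repeating a correct word cannot inflate the score.
- **Brevity Penalty**: Penalizes candidates that are too short compared to references.
- **Geometric Mean**: Combines the n-gram precisions (with equal weights) into a single score.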
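Clipping is easiest to see on the well-known degenerate example from the original BLEU paper (Papineni et al., 2002). The sketch below is a hypothetical unigram-only helper, not the `modified_precision` function used later: without clipping, the candidate would score a perfect 7/7.

```python
from collections import Counter

def clipped_unigram_precision(candidate, references):
    """Unigram precision with each word's count clipped to its max count in any single reference."""
    cand_counts = Counter(candidate)
    max_ref = Counter()
    for ref in references:
        for word, count in Counter(ref).items():
            max_ref[word] = max(max_ref[word], count)
    clipped = sum(min(count, max_ref[word]) for word, count in cand_counts.items())
    return clipped, sum(cand_counts.values())

candidate = "the the the the the the the".split()
references = ["the cat is on the mat".split(), "there is a cat on the mat".split()]
print(clipped_unigram_precision(candidate, references))  # (2, 7)
```

"the" appears at most twice in a single reference, so only 2 of the 7 candidate tokens are credited.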

#### BLEU Function

The `bleu()` function:
1. Computes modified precision for 1- to 4-grams.
2. Applies smoothing when an n-gram order has no matches, so that the logarithm in the next step stays defined.
3. Calculates the geometric mean of precisions.
4. Multiplies by the brevity penalty.

Returns a score between 0 and 1, where higher means closer agreement and 1 is a perfect match with a reference.
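The four steps above correspond to the standard BLEU definition. Written with the quantities the code uses ($p_n$ is the modified precision for $n$-grams, $N$ is `max_n`, $c$ is the candidate length, and $r$ is the closest reference length):

$$
\text{BLEU} = \text{BP} \cdot \exp\left(\sum_{n=1}^{N} w_n \log p_n\right), \qquad w_n = \frac{1}{N}
$$

$$
\text{BP} = \begin{cases} 1 & \text{if } c > r \\ e^{\,1 - r/c} & \text{if } c \le r \end{cases}
$$

For example, a 6-token candidate whose closest reference has 8 tokens gets $\text{BP} = e^{1 - 8/6} = e^{-1/3} \approx 0.72$, so even perfect n-gram precision cannot push its score above about 0.72.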

In [None]:
import math
from collections import Counter
from typing import List

def ngrams(seq: List[str], n: int) -> List[tuple[str, ...]]:
  return [tuple(seq[i:i+n]) for i in range(len(seq) - n + 1)]

def modified_precision(candidate: List[str],
                       references: List[List[str]],
                       n: int):
  cand_ngrams = Counter(ngrams(candidate, n))
  max_reference_counts = Counter()

  for ref in references:
    ref_counts = Counter(ngrams(ref, n))
    for ngram, count in ref_counts.items():
      max_reference_counts[ngram] = max(max_reference_counts[ngram], count)

  clipped_counts = {ngram: min(count, max_reference_counts[ngram])
    for ngram, count in cand_ngrams.items()}

  numerator = sum(clipped_counts.values())
  denominator = sum(cand_ngrams.values())

  return numerator, denominator

def brevity_penalty(c: int, r: int) -> float:
  if c == 0:
    return 0.0  # an empty candidate gets no credit (also avoids dividing by zero)
  return 1.0 if c > r else math.exp(1 - r / c)

def closest_ref_len(c: int, ref_lens: List[int]) -> int:
  return min(ref_lens, key=lambda rl: (abs(rl - c), rl))

def bleu(candidate: List[str],
         references: List[List[str]],
         max_n: int = 4) -> float:
  weights = [1/max_n] * max_n
  precisions = []

  for n in range(1, max_n+1):
    num, den = modified_precision(candidate, references, n)
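    if num == 0:
      # Smoothing: a zero count would make log(num / den) undefined, so replace it
      # with a small epsilon (similar in spirit to NLTK's smoothing "method1").
      num, den = 0.1, max(den, 1)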
    precisions.append((num, den))

  geo_mean = math.exp(sum(w * math.log(num/den)
    for (num, den), w in zip(precisions, weights)))

  c = len(candidate)
  r = closest_ref_len(c, [len(r) for r in references])
  bp = brevity_penalty(c, r)
  return bp * geo_mean

def tokenize_sequence(sequence):
  return sequence.lower().split(' ')

### Last Step: Run It!

In [None]:
# get BLEU candidates / references
references_bidir = [tokenize_sequence(text_pair[1]) for text_pair in valid_data_list]
candidates_bidir = [tokenize_sequence(beam_search_translate(text_pair[0], src_tokenizer, tgt_tokenizer, model_bidir, device, beam_width=5, max_len=100)) for text_pair in valid_data_list]

In [None]:
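bleu_results = []
# score each candidate against its own reference (the ground truth for the same source sentence)
for candidate, reference in zip(candidates_bidir, references_bidir):
  bleu_results.append(bleu(candidate, [reference]))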

print(bleu_results)
sum_of_bleu = sum(bleu_results)
len_of_bleu = len(bleu_results)
print(f'The average bidir BLEU is: {sum_of_bleu / len_of_bleu}')
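The average above is a mean of sentence-level scores. For comparison, here is a sketch of corpus-level BLEU, which pools clipped n-gram matches and lengths over all sentences before combining them, as in the original BLEU definition. It is a simplified standalone version under stated assumptions: `corpus_bleu` and `ngram_counts` are hypothetical helpers, it takes exactly one reference per candidate, uses the tokens as given, and applies no smoothing.

```python
import math
from collections import Counter

def ngram_counts(tokens, n):
    """Count the n-grams of order n in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

def corpus_bleu(candidates, references, max_n=4):
    """Corpus-level BLEU with one reference per candidate (no smoothing)."""
    matches = [0] * max_n  # clipped n-gram matches, pooled over the corpus
    totals = [0] * max_n   # candidate n-gram counts, pooled over the corpus
    cand_len = ref_len = 0
    for cand, ref in zip(candidates, references):
        cand_len += len(cand)
        ref_len += len(ref)
        for n in range(1, max_n + 1):
            cand_counts, ref_counts = ngram_counts(cand, n), ngram_counts(ref, n)
            matches[n - 1] += sum(min(c, ref_counts[g]) for g, c in cand_counts.items())
            totals[n - 1] += sum(cand_counts.values())
    if cand_len == 0 or min(matches) == 0:
        return 0.0  # some n-gram order has no matches: unsmoothed BLEU is 0
    log_precision = sum(math.log(m / t) for m, t in zip(matches, totals)) / max_n
    bp = 1.0 if cand_len > ref_len else math.exp(1 - ref_len / cand_len)
    return bp * math.exp(log_precision)
```

With the lists built above, `corpus_bleu(candidates_bidir, references_bidir)` would give a single corpus-level score to set beside the averaged sentence-level number.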

In [None]:
print("\n--- Sample Translations from Validation Set ---")
for i, (src_text, tgt_text) in enumerate(valid_data_list[:10]): # Show first 10
    translation = beam_search_translate(src_text, src_tokenizer, tgt_tokenizer, model_bidir, device)
    print(f"Original (EN):     {src_text}")
    print(f"Ground Truth (FR): {tgt_text}")
    print(f"Translated (FR):   {translation}\n")