<a href="https://colab.research.google.com/github/jackshiels/UsefulLLMTutorials/blob/main/1_EncoderDecoderModel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Encoder-Decoder Models

The following tutorial takes a look at basic encoder-decoder and generates a model for machine translation. The recommended reading for this tutorial is Chapter 2 of Large Language Models: A Deep Dive. You can find it here for under $15: [purchase](https://link.springer.com/book/10.1007/978-3-031-65647-7)

We will be implementing the torch Gated Recurrent Unit (GRU), which is a choice against using the traditional Long Short-Term Memory (LSTM) model. By the end of the softmax layer, we are implementing greedy searching for tokens. An alternative attention approach is provided, too.

In [78]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import random
import math
import time
from collections import Counter

# Seed torch and random
SEED = 1234
random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device {device}')

Using device cpu


In [79]:
# Options for training
bidirectional_gru = True

### Toy Dataset

In [80]:
SRC_LANGUAGE = 'en'
TGT_LANGUAGE = 'fr'

raw_data_pairs = [
    ("hello world", "bonjour le monde"),
    ("how are you", "comment allez vous"),
    ("i am fine", "je vais bien"),
    ("good morning", "bonjour"),
    ("thank you", "merci"),
    ("i love pytorch", "j aime pytorch"),
    ("machine translation is cool", "la traduction automatique est cool"),
    ("see you later", "a plus tard"),
    ("what is your name", "quel est votre nom"),
    ("my name is model", "mon nom est modele"),
    ("nice to meet you", "ravi de vous rencontrer"),
    ("good night", "bonne nuit"),
    ("have a nice day", "bonne journee"),
    ("where are you from", "d ou venez vous"),
    ("i am from canada", "je viens du canada"),
    ("do you speak english", "parlez vous anglais"),
    ("yes i speak english", "oui je parle anglais"),
    ("no i don't understand", "non je ne comprends pas"),
    ("can you help me", "pouvez vous m aider"),
    ("i need assistance", "j ai besoin d aide"),
    ("what time is it", "quelle heure est il"),
    ("it is five o'clock", "il est cinq heures"),
    ("where is the station", "ou est la gare"),
    ("i am learning french", "j apprends le francais"),
    ("this is a cat", "c est un chat"),
    ("that is a dog", "c est un chien"),
    ("the weather is nice", "il fait beau"),
    ("it is raining", "il pleut"),
    ("i am hungry", "j ai faim"),
    ("i am thirsty", "j ai soif"),
    ("let's go", "allons y"),
    ("come here", "viens ici"),
    ("open the door", "ouvre la porte"),
    ("close the window", "ferme la fenetre"),
    ("i am tired", "je suis fatigue"),
    ("i don't know", "je ne sais pas"),
    ("i agree", "je suis d accord"),
    ("i disagree", "je ne suis pas d accord"),
    ("i like this song", "j aime cette chanson"),
    ("this is my friend", "c est mon ami"),
    ("do you like coffee", "aimes tu le cafe"),
    ("yes i like coffee", "oui j aime le cafe"),
    ("no i prefer tea", "non je prefere le the"),
    ("what do you do", "que fais tu dans la vie"),
    ("i am a student", "je suis etudiant"),
    ("i am a teacher", "je suis professeur"),
    ("deep learning is interesting", "l apprentissage profond est interessant"),
    ("we are training a model", "nous entrainons un modele"),
    ("this dataset is large", "ce jeu de donnees est grand"),
    ("the model is overfitting", "le modele fait du surapprentissage"),
    ("we need more data", "nous avons besoin de plus de donnees"),
    ("good afternoon", "bon apres-midi"),
    ("good evening", "bonsoir"),
    ("excuse me", "excusez-moi"),
    ("please", "s il vous plait"),
    ("you're welcome", "de rien"),
    ("how much does it cost", "combien ca coute"),
    ("where is the bathroom", "ou sont les toilettes"),
    ("i need a doctor", "j ai besoin d un medecin"),
    ("i am lost", "je suis perdu"),
    ("can you repeat that", "pouvez vous repeter ca"),
    ("speak slower please", "parlez plus lentement s il vous plait"),
    ("write it down", "ecrivez le"),
    ("what is this", "qu est ce que c est"),
    ("how do you say this in french", "comment dit on ca en francais"),
    ("i understand", "je comprends"),
    ("i don't understand french", "je ne comprends pas le francais"),
    ("it's too expensive", "c est trop cher"),
    ("i'll take it", "je le prends"),
    ("where is the exit", "ou est la sortie"),
    ("where is the entrance", "ou est l entree"),
    ("is there a restaurant nearby", "y a t il un restaurant a proximite"),
    ("i want to eat", "je veux manger"),
    ("i want to drink", "je veux boire"),
    ("the bill please", "l addition s il vous plait"),
    ("it was delicious", "c etait delicieux"),
    ("i would like water", "je voudrais de l eau"),
    ("i would like a coffee", "je voudrais un cafe"),
    ("i would like a beer", "je voudrais une biere"),
    ("do you have a table for two", "avez vous une table pour deux"),
    ("i have a reservation", "j ai une reservation"),
    ("what is the weather like", "quel temps fait il"),
    ("it is sunny", "il fait soleil"),
    ("it is cloudy", "il fait nuageux"),
    ("it is cold", "il fait froid"),
    ("it is hot", "il fait chaud"),
    ("it is windy", "il fait du vent"),
    ("what day is it today", "quel jour sommes nous aujourd hui"),
    ("today is monday", "aujourd hui c est lundi"),
    ("tomorrow is tuesday", "demain c est mardi"),
    ("yesterday was sunday", "hier c etait dimanche"),
    ("see you soon", "a bientot"),
    ("have a good trip", "bon voyage"),
    ("be careful", "fais attention"),
    ("no problem", "pas de probleme"),
    ("i am busy", "je suis occupe"),
    ("i am happy", "je suis content"),
    ("i am sad", "je suis triste"),
    ("i am bored", "je m ennuie"),
    ("i am excited", "je suis excite"),
    ("i am sick", "je suis malade"),
    ("i have a headache", "j ai mal a la tete"),
    ("i have a stomach ache", "j ai mal au ventre"),
    ("i feel good", "je me sens bien"),
    ("i feel bad", "je me sens mal"),
    ("what time do you open", "a quelle heure ouvrez vous"),
    ("what time do you close", "a quelle heure fermez vous"),
    ("is it open", "est ce ouvert"),
    ("is it closed", "est ce ferme"),
    ("can i pay by card", "puis je payer par carte"),
    ("can i pay cash", "puis je payer en especes"),
    ("where is the bank", "ou est la banque"),
    ("where is the post office", "ou est la poste"),
    ("how far is it", "a quelle distance est ce"),
    ("it is far", "c est loin"),
    ("it is near", "c est pres"),
    ("turn left", "tournez a gauche"),
    ("turn right", "tournez a droite"),
    ("go straight ahead", "allez tout droit"),
    ("stop here", "arretez vous ici"),
    ("take me to this address", "emmenez moi a cette adresse"),
    ("i want a ticket to paris", "je veux un billet pour paris"),
    ("one way or round trip", "aller simple ou aller retour"),
    ("how long does it take", "combien de temps ca prend"),
    ("when does the train leave", "quand part le train"),
    ("when does the bus arrive", "quand arrive le bus"),
    ("i am here on vacation", "je suis ici en vacances"),
    ("i am here for work", "je suis ici pour le travail"),
    ("i like france", "j aime la france"),
    ("i don't like it", "je n aime pas ca"),
    ("can i try it on", "puis je l essayer"),
    ("what size is this", "quelle taille est ce"),
    ("do you have a bigger size", "avez vous une taille plus grande"),
    ("do you have a smaller size", "avez vous une taille plus petite"),
    ("i need help with my luggage", "j ai besoin d aide avec mes bagages"),
    ("where is the information desk", "ou est le bureau d information"),
    ("what's your phone number", "quel est votre numero de telephone"),
    ("what's your email address", "quelle est votre adresse e-mail"),
    ("can i call you", "puis je vous appeler"),
    ("please wait", "veuillez patienter"),
    ("come in", "entrez"),
    ("sit down", "asseyez vous"),
    ("stand up", "levez vous"),
    ("listen to me", "ecoutez moi"),
    ("look at this", "regardez ca"),
    ("i am learning a lot", "j apprends beaucoup"),
    ("it is difficult", "c est difficile"),
    ("it is easy", "c est facile"),
    ("it is very interesting", "c est tres interessant"),
    ("i need more practice", "j ai besoin de plus de pratique"),
    ("what are you doing", "que faites vous"),
    ("i am reading a book", "je lis un livre"),
    ("i am watching tv", "je regarde la tele"),
    ("i am listening to music", "j ecoute de la musique"),
    ("i am cooking", "je cuisine"),
    ("i am working", "je travaille"),
    ("i am studying", "j etudie"),
    ("i am going home", "je rentre a la maison"),
    ("i am going to bed", "je vais me coucher"),
    ("i am waking up", "je me reveille"),
    ("have a good meal", "bon appetit"),
    ("cheers", "sante"),
    ("happy birthday", "joyeux anniversaire"),
    ("merry christmas", "joyeux noel"),
    ("happy new year", "bonne annee"),
    ("congratulations", "felicitations"),
    ("good luck", "bonne chance"),
    ("i am sorry", "je suis desole"),
    ("it's okay", "c est bon"),
    ("never mind", "laisse tomber"),
    ("i totally agree", "je suis entierement d accord"),
    ("i think so", "je pense que oui"),
    ("i don't think so", "je ne pense pas"),
    ("it's important", "c est important"),
    ("it's urgent", "c est urgent"),
    ("i need help with this exercise", "j ai besoin d aide pour cet exercice"),
    ("this is a challenging problem", "c est un probleme difficile"),
    ("we need to optimize the code", "nous devons optimiser le code"),
    ("the algorithm is complex", "l algorithme est complexe"),
    ("what is a neural network", "qu est ce qu un reseau de neurones"),
    ("how does backpropagation work", "comment fonctionne la retropropagation"),
    ("we are collecting more data", "nous collectons plus de donnees"),
    ("the training loss is decreasing", "la perte d entrainement diminue"),
    ("the validation accuracy is stable", "la precision de validation est stable"),
    ("we need to adjust the hyperparameters", "nous devons ajuster les hyperparametres"),
    ("this model is production-ready", "ce modele est pret pour la production"),
    ("data preprocessing is crucial", "le pre-traitement des donnees est crucial"),
    ("feature engineering is important", "l ingenierie des caracteristiques est importante"),
    ("we are debugging the script", "nous deboguons le script"),
    ("what is the learning rate", "quel est le taux d apprentissage"),
    ("gradient descent is an optimization algorithm", "la descente de gradient est un algorithme d optimisation"),
    ("we use GPUs for faster training", "nous utilisons des GPU pour un entrainement plus rapide"),
    ("this is a classification task", "c est une tache de classification"),
    ("this is a regression task", "c est une tache de regression"),
    ("we need to fine-tune the model", "nous devons affiner le modele"),
    ("transfer learning is effective", "l apprentissage par transfert est efficace"),
    ("explain the attention mechanism", "expliquez le mecanisme d attention"),
    ("what are transformers in nlp", "que sont les transformers en tnl"),
    ("the model predicts the next word", "le modele predit le mot suivant"),
    ("it's an end-to-end system", "c est un systeme de bout en bout"),
    ("we are evaluating the performance", "nous evaluons la performance"),
    ("the results are promising", "les resultats sont prometteurs"),
    ("we need to document the code", "nous devons documenter le code"),
    ("version control is essential", "le controle de version est essentiel"),
    ("what is your favorite programming language", "quel est votre langage de programmation prefere"),
    ("i prefer python", "je prefere python"),
    ("this is a good example", "c est un bon exemple"),
    ("it's a difficult question", "c est une question difficile"),
    ("i'm thinking about it", "j y reflechis"),
    ("can you explain more", "pouvez vous expliquer plus"),
    ("i agree with you", "je suis d accord avec vous"),
    ("i understand what you mean", "je comprends ce que vous voulez dire"),
    ("how was your day", "comment etait ta journee"),
    ("it was good", "c etait bien"),
    ("it was bad", "c etait mauvais"),
    ("i had a busy day", "j ai eu une journee occupee"),
    ("what are your hobbies", "quels sont vos loisirs"),
    ("i like to travel", "j aime voyager"),
    ("i like to read", "j aime lire"),
    ("i like to cook", "j aime cuisiner"),
    ("what is your favorite food", "quel est votre plat prefere"),
    ("i like french food", "j aime la cuisine francaise"),
    ("can you recommend a good book", "pouvez vous me recommander un bon livre"),
    ("i will try my best", "je ferai de mon mieux"),
    ("i hope so", "j espere que oui"),
    ("i hope not", "j espere que non"),
    ("it's a pleasure", "c est un plaisir"),
    ("take care", "prends soin de toi"),
    ("what's new", "quoi de neuf"),
    ("nothing much", "pas grand chose"),
    ("do you have any questions", "avez vous des questions"),
    ("i have no questions", "je n ai pas de questions"),
    ("thank you for your time", "merci pour votre temps"),
    ("see you tomorrow", "a demain"),
    ("have a good weekend", "bon weekend"),
    ("i want to learn more", "je veux en savoir plus"),
    ("this is very useful", "c est tres utile"),
    ("can you show me", "pouvez vous me montrer"),
    ("i am sure", "je suis sur"),
    ("i am not sure", "je ne suis pas sur"),
    ("it's not fair", "ce n est pas juste"),
    ("it's wonderful", "c est merveilleux"),
    ("it's terrible", "c est terrible"),
    ("i need a break", "j ai besoin d une pause"),
    ("let's take a break", "faisons une pause"),
    ("what's the problem", "quel est le probleme"),
    ("there is no problem", "il n y a pas de probleme"),
    ("i can't find it", "je ne le trouve pas"),
    ("i found it", "je l ai trouve"),
    ("it's getting late", "il se fait tard"),
    ("i must go now", "je dois partir maintenant"),
    ("this is too much", "c est trop"),
    ("this is not enough", "ce n est pas assez"),
    ("i'm ready", "je suis pret"),
    ("are you ready", "etes vous pret"),
    ("i'm waiting for you", "je vous attends"),
    ("don't worry", "ne t inquiete pas"),
    ("it's going to be okay", "ca va aller"),
    ("what's your opinion", "quel est votre avis"),
    ("in my opinion", "a mon avis"),
    ("i think that", "je pense que"),
    ("it seems that", "il semble que"),
    ("i would like to know", "je voudrais savoir"),
    ("can you explain to me", "pouvez vous m expliquer"),
    ("i'm trying to learn", "j essaie d apprendre")
]

In [81]:
random.shuffle(raw_data_pairs)
split_idx = int(len(raw_data_pairs) * 0.9)
train_data = raw_data_pairs[:split_idx]
valid_data = raw_data_pairs[split_idx:]

print(f"Training examples: {len(train_data)}")
print(f"Validation examples: {len(valid_data)}")

Training examples: 238
Validation examples: 27


### Tokenizer
We build a simple tokenizer to introduce the concept. This tokenizer does not reduce to lemmas or perform any vector distributions for the inputs. Instead, each unique word is given a numeric representation that counts upward as new words are added.

In [82]:
PAD_TOKEN = "<pad>"
SOS_TOKEN = "<sos>"
EOS_TOKEN = "<eos>"
UNK_TOKEN = "<unk>"

class CustomTokenizer:
  def __init__(self, language_name):
    self.language_name = language_name
    self.word2index = {}
    self.index2word = {}
    self.n_count = 0
    self.word_counts = Counter()
    self.add_word(PAD_TOKEN)
    self.add_word(SOS_TOKEN)
    self.add_word(EOS_TOKEN)
    self.add_word(UNK_TOKEN)

    self.PAD_IDX = self.add_word(PAD_TOKEN)
    self.SOS_IDX = self.add_word(SOS_TOKEN)
    self.EOS_IDX = self.add_word(EOS_TOKEN)
    self.UNK_IDX = self.add_word(UNK_TOKEN)

  # a linear tokenizer (count -> index)
  def add_word(self, word):
    if word not in self.word2index:
      self.word2index[word] = self.n_count
      self.index2word[self.n_count] = word
      self.n_count += 1
    return self.word2index[word]

  def add_sentence(self, sentence):
    for word in sentence.lower().split(' '):
      self.word_counts[word] += 1

  def build_vocab(self, sentences):
    # Build up a count for each word
    for sentence in sentences:
      self.add_sentence(sentence)

    # Add each unique key (word) to the word2index / index2word dicts
    for word in sorted(self.word_counts.keys()):
      self.add_word(word)

  def sentence_to_indices(self, sentence):
    tokens = [SOS_TOKEN] + sentence.lower().split(' ') + [EOS_TOKEN]
    indices = [self.word2index.get(token, self.UNK_IDX) for token in tokens]
    return indices

  def indices_to_sentence(self, indices):
    if hasattr(indices, 'tolist'):
      indices = indices.tolist()
    return ' '.join(self.index2word.get(index, UNK_TOKEN) for index in indices
                    if index not in [self.SOS_IDX, self.EOS_IDX, self.PAD_IDX])


Create the tokenizers and input vocabularies

In [83]:
src_tokenizer = CustomTokenizer(SRC_LANGUAGE)
tgt_tokenizer = CustomTokenizer(TGT_LANGUAGE)

src_sentences = [pair[0] for pair in train_data]
tgt_sentences = [pair[1] for pair in train_data]

src_tokenizer.build_vocab(src_sentences)
tgt_tokenizer.build_vocab(tgt_sentences)

Test their behaviour

In [84]:
# Vocabulary
print("\nSource Vocabulary (EN):")
print(src_tokenizer.word2index)
print(f"PAD_IDX: {src_tokenizer.PAD_IDX}, SOS_IDX: {src_tokenizer.SOS_IDX},"
      f"EOS_IDX: {src_tokenizer.EOS_IDX}, UNK_IDX: {src_tokenizer.UNK_IDX}")

print(f"\nTarget Vocabulary (FR)")
print(tgt_tokenizer.word2index)
print(f"PAD_IDX: {tgt_tokenizer.PAD_IDX}, SOS_IDX: {tgt_tokenizer.SOS_IDX}"
      f"EOS_IDX: {tgt_tokenizer.EOS_IDX}, UNK_IDX: {tgt_tokenizer.UNK_IDX}")

# Test the tokenizer
test_src_sent = "hello world"
test_src_indices = src_tokenizer.sentence_to_indices(test_src_sent)
print(f"\n'{test_src_sent}' -> {test_src_indices}")
print(f"'{test_src_indices}' -> '{src_tokenizer.indices_to_sentence(test_src_indices)}'\n")

test_tgt_sent = "bonjour le monde"
test_tgt_indices = tgt_tokenizer.sentence_to_indices(test_tgt_sent)
print(f"'{test_tgt_sent}' -> {test_tgt_indices}")
print(f"'{test_tgt_indices}' -> '{tgt_tokenizer.indices_to_sentence(test_tgt_indices)}'")


Source Vocabulary (EN):
{'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 3, 'a': 4, 'about': 5, 'accuracy': 6, 'ache': 7, 'address': 8, 'adjust': 9, 'afternoon': 10, 'agree': 11, 'algorithm': 12, 'am': 13, 'an': 14, 'any': 15, 'are': 16, 'arrive': 17, 'assistance': 18, 'at': 19, 'attention': 20, 'backpropagation': 21, 'bad': 22, 'bank': 23, 'bathroom': 24, 'be': 25, 'bed': 26, 'beer': 27, 'best': 28, 'bigger': 29, 'bill': 30, 'birthday': 31, 'book': 32, 'bored': 33, 'break': 34, 'bus': 35, 'busy': 36, 'by': 37, 'call': 38, 'can': 39, "can't": 40, 'canada': 41, 'card': 42, 'care': 43, 'careful': 44, 'cash': 45, 'cat': 46, 'challenging': 47, 'cheers': 48, 'christmas': 49, 'classification': 50, 'close': 51, 'closed': 52, 'cloudy': 53, 'code': 54, 'coffee': 55, 'cold': 56, 'come': 57, 'complex': 58, 'control': 59, 'cook': 60, 'cooking': 61, 'cool': 62, 'cost': 63, 'crucial': 64, 'data': 65, 'dataset': 66, 'day': 67, 'debugging': 68, 'delicious': 69, 'descent': 70, 'desk': 71, 'difficult': 72

### Padding
We need to pad so that different length sentences can be accepted by the RNN in batches. We use padding tokens to achieve this.

In [85]:
def collate_fn(batch, src_tokenizer, tgt_tokenizer, device):
  src_batch, tgt_batch = [], []
  src_lens, tgt_lens = [], []
  for src_sample, tgt_sample in batch:
    src_indices = src_tokenizer.sentence_to_indices(src_sample)
    tgt_indices = tgt_tokenizer.sentence_to_indices(tgt_sample)

    src_batch.append(torch.tensor(src_indices, dtype=torch.long))
    tgt_batch.append(torch.tensor(tgt_indices, dtype=torch.long))

    src_lens.append(len(src_indices))
    tgt_lens.append(len(tgt_indices))

  src_padded = nn.utils.rnn.pad_sequence(src_batch, padding_value=src_tokenizer.PAD_IDX, batch_first=False)
  tgt_padded = nn.utils.rnn.pad_sequence(tgt_batch, padding_value=tgt_tokenizer.PAD_IDX, batch_first=False)

  return src_padded.to(device), tgt_padded.to(device), torch.tensor(src_lens), torch.tensor(tgt_lens)

Create a sample dataloader

In [86]:
BATCH_SIZE = 2
def get_data_iterator(data, src_tokenizer, tgt_tokenizer, batch_size, device, shuffle=True):
  if shuffle:
    data_copy = list(data)
    random.shuffle(data_copy)
  else:
    data_copy = data

  for i in range(0, len(data_copy), batch_size):
    batch = data_copy[i:i+batch_size]
    yield collate_fn(batch, src_tokenizer, tgt_tokenizer, device)

In [87]:
print("\nTesting data iterator:")
data_iter = get_data_iterator(train_data,src_tokenizer, tgt_tokenizer, BATCH_SIZE, device)
for i, (src_batch, tgt_batch, src_lens, tgt_lens) in enumerate(data_iter):
  print(f"Batch {i+1}:")
  print("Source batch shape: ", src_batch.shape)
  print("target batch shape: ", tgt_batch.shape)
  print("Source lengths: ", src_lens)
  print("Target lengths: ", tgt_lens)
  print("Source batch (first example):\n", src_batch[:, 0])
  print("Target batch (first example):\n", tgt_batch[:, 0])
  if i == 0: break


Testing data iterator:
Batch 1:
Source batch shape:  torch.Size([6, 2])
target batch shape:  torch.Size([8, 2])
Source lengths:  tensor([6, 6])
Target lengths:  tensor([8, 8])
Source batch (first example):
 tensor([  1, 137, 181,   4,  34,   2])
Target batch (first example):
 tensor([  1, 174,   9,  44,  86, 342, 241,   2])


### Encoder

We'll write the encoder using an nn.GRU unit. The encoder unit has several stages to its learning:

* Timestep states *h* are computed as h$_{t}$ = f(h$_{t-1}$, x$_{t}$).
* The final state h$_{t}$ is the context variable.
* The context variable is given by some mapping *m* such that c = m(h$_{1}$, h$_{2}$, ... , h$_{t}$)
* Encoders may be bidirectional, such that h$_{t}$ is a function of h$_{t-1}$ and h$_{t+1}$

We build the unidirectional encoder here.

Note the dimensionality of the various matrices.

In [88]:
class Encoder(nn.Module):
  def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout_p):
    super().__init__()
    self.hid_dim = hid_dim
    self.n_layers = n_layers
    # Here is the interesting bit: input dim is the size of your vocabulary.
    # emb_dim is the arbitrary learning layer - you can select a size and experiment.
    # We therefore train the rnn to work at a token level
    self.embedding = nn.Embedding(input_dim, emb_dim)
    self.rnn = nn.GRU(emb_dim,
                      hid_dim,
                      n_layers,
                      bidirectional=bidirectional_gru,
                      dropout=dropout_p if n_layers > 1 else 0)
    self.dropout = nn.Dropout(dropout_p)

  def forward(self, src_seq):
    embedded = self.dropout(self.embedding(src_seq))
    outputs, hidden = self.rnn(embedded)
    return outputs, hidden

### Decoder

The decoder takes the context variable from the encoder and creates its own hidden state. This hidden state is not only a function of the last hidden state, but also the previously decoded token.

* s$_{t'}$ = g(s$_{t-1}$, y$_{t'-1}$, c)
* y$_{t'}$ is a probability distribution of P(y$_{t'}$| y$_{t-1}$, ..., y$_{1}$, c) = softmax(s$_{t-1}$, y$_{t'-1}$, c).
* In a sense, this means that the last N output tokens from the decoder influence the latest token and the current hidden state.

In [89]:
class Decoder(nn.Module):
  def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout_p):
    super().__init__()
    self.output_dim = output_dim
    self.hid_dim = hid_dim
    self.n_layers = n_layers

    # Components
    self.embedding = nn.Embedding(output_dim, emb_dim)
    self.rnn = nn.GRU(emb_dim,
                      hid_dim,
                      n_layers,
                      dropout=dropout_p if n_layers > 1 else 0)
    self.fc_out = nn.Linear(hid_dim, output_dim)
    self.dropout = nn.Dropout(dropout_p)

  def forward(self, input_token, hidden_state):
    # turns [batch_size] into [1, batch_size]
    input_token = input_token.unsqueeze(0)
    embedded = self.dropout(self.embedding(input_token))

    # Per token decoding
    output, new_hidden_state = self.rnn(embedded, hidden_state)
    prediction = self.fc_out(output.squeeze(0))

    return prediction, new_hidden_state

### Seq2Seq Implementation

The Seq2Seq component implements and handles the encoder/decoder architecture.

In [90]:
def resize_bidirectional_hidden(n_layers, batch_size, hid, enc_hidden, bridge):
  enc_hidden = enc_hidden.view(n_layers, 2, batch_size, hid)
  cat = torch.cat((enc_hidden[:, 0, :, :], enc_hidden[:, 1, :, :]), dim=2)
  return torch.tanh(bridge(cat))

class Seq2Seq(nn.Module):
  def __init__(self, encoder, decoder, device):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder
    self.device = device
    self.bridge = nn.Linear(encoder.hid_dim * 2, decoder.hid_dim)

    # Sanity check for dimensionality
    assert encoder.hid_dim == decoder.hid_dim, "hidden dims must be equal"
    assert encoder.n_layers == decoder.n_layers, "layers must be equal"

  def _init_decoder_hidden(self, enc_hidden):
    n_layers, batch_size, hid = self.encoder.n_layers, enc_hidden.size(1), self.encoder.hid_dim
    return resize_bidirectional_hidden(n_layers, batch_size, hid, enc_hidden, self.bridge)

  def forward(self, src_seq, tgt_seq, teacher_forcing_ratio=0.5):
    batch_size = src_seq.shape[1]
    tgt_len = tgt_seq.shape[0]
    tgt_vocab_size = self.decoder.output_dim
    outputs = torch.zeros(tgt_len, batch_size, tgt_vocab_size)

    # Encode
    enc_out, hidden = self.encoder(src_seq)
    if (bidirectional_gru):
      hidden = self._init_decoder_hidden(hidden)

    # Decode
    dec_in = tgt_seq[0, :]
    for t in range(1, tgt_len):
      dec_out, hidden = self.decoder(dec_in, hidden)
      outputs[t] = dec_out
      teacher_force = random.random() < teacher_forcing_ratio
      top1 = dec_out.argmax(1)
      dec_in = tgt_seq[t] if teacher_force else top1

    return outputs

### Training

In [91]:
# hyperparams
INPUT_DIM = src_tokenizer.n_count
OUTPUT_DIM = tgt_tokenizer.n_count
ENC_EMB_DIM = 128
DEC_EMB_DIM = 128
HID_DIM = 256
N_LAYERS = 2
ENC_DROPOUT = 0.2
DEC_DROPOUT = 0.2
LEARNING_RATE = 0.001
N_EPOCHS = 5
CLIP = 1
BIDIR_MODEL_NAME = "language_enc_dec_bidir.pt"
UNIDIR_MODEL_NAME = "language_enc_dec.pt"
MODEL_NAME = BIDIR_MODEL_NAME if bidirectional_gru else UNIDIR_MODEL_NAME

In [92]:
# components
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT).to(device)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT).to(device)
model_bidir = Seq2Seq(enc, dec, device)

In [93]:
def count_parameters(model):
  return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model_bidir):,} trainable parameters')

The model has 2,781,806 trainable parameters


In [94]:
# optim and learn
optimizer = optim.Adam(model_bidir.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss(ignore_index = tgt_tokenizer.PAD_IDX)

In [95]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [96]:
def train_epoch(model, iterator, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0
    batch_n = 0
    for i, (src, tgt, _, _) in enumerate(iterator): # src_lens, tgt_lens not used directly here
        batch_n += 1
        optimizer.zero_grad()
        # output = [tgt_len, batch_size, output_vocab_size]
        output = model(src, tgt)
        # get vocab length for next step
        output_dim = output.shape[-1]
        # remove <sos> tag by enforcing [1:]
        # turn [tgt_len, batch, vocab] into [(tgt_len-1 * batch), vocab] so it fits into loss
        output_flat = output[1:].view(-1, output_dim)
        # since we know the vocab, this doesn't have the V dimension [tgt_len-1, batch]
        tgt_flat = tgt[1:].view(-1)
        # now that they are equal dim, compute the loss between out and tgt
        loss = criterion(output_flat, tgt_flat)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()

    return epoch_loss / batch_n

def evaluate_epoch(model, iterator, criterion):
  # the same setup, but we loop without adam use
  model.eval()
  epoch_loss = 0
  batch_n = 0
  with torch.no_grad():
    for i, (src, tgt, _, _) in enumerate(iterator):
      batch_n += 1
      output = model(src, tgt, 0)
      output_dim = output.shape[-1]
      output_flat = output[1:].view(-1, output_dim)
      tgt_flat = tgt[1:].view(-1)
      loss = criterion(output_flat, tgt_flat)
      epoch_loss += loss.item()

  return epoch_loss / batch_n

### Train Execute

In [97]:
best_valid_loss = float('inf')

print("Starting training...")
for epoch in range(N_EPOCHS):
  start_time = time.time()
  train_iter = get_data_iterator(train_data,
                               src_tokenizer,
                               tgt_tokenizer,
                               BATCH_SIZE,
                               device,
                               shuffle=True)
  valid_iter = get_data_iterator(valid_data,
                               src_tokenizer,
                               tgt_tokenizer,
                               BATCH_SIZE,
                               device,
                               shuffle=False)
  train_loss = train_epoch(model_bidir, train_iter, optimizer, criterion, CLIP)
  valid_loss = evaluate_epoch(model_bidir, valid_iter, criterion)

  end_time = time.time()
  epoch_mins, epoch_secs = epoch_time(start_time, end_time)
  print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
  print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
  print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

torch.save(model_bidir.state_dict(), MODEL_NAME)

Starting training...
Epoch: 01 | Time: 0m 9s
	Train Loss: 4.659 | Train PPL: 105.531
	 Val. Loss: 4.530 |  Val. PPL:  92.789
Epoch: 02 | Time: 0m 9s
	Train Loss: 3.890 | Train PPL:  48.931
	 Val. Loss: 4.623 |  Val. PPL: 101.823
Epoch: 03 | Time: 0m 8s
	Train Loss: 3.484 | Train PPL:  32.587
	 Val. Loss: 4.659 |  Val. PPL: 105.556
Epoch: 04 | Time: 0m 8s
	Train Loss: 3.111 | Train PPL:  22.443
	 Val. Loss: 4.803 |  Val. PPL: 121.906
Epoch: 05 | Time: 0m 9s
	Train Loss: 2.774 | Train PPL:  16.030
	 Val. Loss: 4.897 |  Val. PPL: 133.900


### Test the Model

In [98]:
model_bidir.load_state_dict(torch.load(MODEL_NAME, weights_only=True))
model_bidir.eval()

Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(334, 128)
    (rnn): GRU(128, 256, num_layers=2, dropout=0.2, bidirectional=True)
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (decoder): Decoder(
    (embedding): Embedding(366, 128)
    (rnn): GRU(128, 256, num_layers=2, dropout=0.2)
    (fc_out): Linear(in_features=256, out_features=366, bias=True)
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (bridge): Linear(in_features=512, out_features=256, bias=True)
)

In [99]:
def translate_sentence(sentence, src_tokenizer, tgt_tokenizer, model, device, max_len=50):
    model.eval()

    # Tokenize
    if isinstance(sentence, str):
        tokens = src_tokenizer.sentence_to_indices(sentence)
    else:
        tokens = sentence

    # Shape: [src_len, 1]
    src_tensor = torch.tensor(tokens, dtype=torch.long).unsqueeze(1).to(device)

    with torch.no_grad():
        encoder_outputs, hidden = model.encoder(src_tensor)
        # Use model's bridge logic to reconcile bidirectional hidden state
        if (bidirectional_gru):
          hidden = model._init_decoder_hidden(hidden)

    tgt_indices = [tgt_tokenizer.SOS_IDX]
    for _ in range(max_len):
        # Use the last generated token as input
        tgt_tensor = torch.tensor([tgt_indices[-1]], dtype=torch.long).to(device)

        with torch.no_grad():
            output, hidden = model.decoder(tgt_tensor, hidden)

        pred_token = output.argmax(1).item()
        tgt_indices.append(pred_token)

        if pred_token == tgt_tokenizer.EOS_IDX:
            break

    translated_sentence = tgt_tokenizer.indices_to_sentence(tgt_indices)
    return translated_sentence

Implement BLEU to test accuracy

In [100]:
import math
from collections import Counter
from typing import List, Tuple

def ngrams(seq: List[str], n: int) -> List[tuple[str, ...]]:
  return [tuple(seq[i:i+n]) for i in range(len(seq) - n + 1)]

def modified_precision(candidate: List[str],
                       references: List[List[str]],
                       n: int):
  cand_ngrams = Counter(ngrams(candidate, n))
  max_reference_counts = Counter()

  for ref in references:
    ref_counts = Counter(ngrams(ref, n))
    for ngram, count in ref_counts.items():
      max_reference_counts[ngram] = max(max_reference_counts[ngram], count)

  clipped_counts = {ngram: min(count, max_reference_counts[ngram])
    for ngram, count in cand_ngrams.items()}

  numerator = sum(clipped_counts.values())
  denominator = sum(cand_ngrams.values())

  return numerator, denominator

def brevity_penalty(c: int, r: int) -> float:
  return 1.0 if c > r else math.exp(1 - r / c)

def closest_ref_len(c: int, ref_lens: List[int]) -> int:
  return min(ref_lens, key=lambda rl: (abs(rl - c), rl))

def bleu(candidate: List[str],
         references: List[List[str]],
         max_n: int = 4) -> float:
  weights = [1/max_n] * max_n
  precisions = []

  for n in range(1, max_n+1):
    num, den = modified_precision(candidate, references, n)
    if num == 0:
      num, den = 1, 2
    precisions.append((num, den))

  geo_mean = math.exp(sum(w * math.log(num/den)
    for (num, den), w in zip(precisions, weights)))

  c = len(candidate)
  r = closest_ref_len(c, [len(r) for r in references])
  bp = brevity_penalty(c, r)
  return bp * geo_mean

def tokenize_sequence(sequence):
  return sequence.lower().split(' ')

In [101]:
# get BLEU candidates / references
references_bidir = [tokenize_sequence(text_pair[1]) for text_pair in valid_data]
candidates_bidir = [tokenize_sequence(translate_sentence(text_pair[0], src_tokenizer, tgt_tokenizer, model_bidir, device)) for text_pair in valid_data]
print(references_bidir)
print(candidates_bidir)

[['je', 'voudrais', 'savoir'], ['quels', 'sont', 'vos', 'loisirs'], ['j', 'ai', 'soif'], ['la', 'perte', 'd', 'entrainement', 'diminue'], ['il', 'semble', 'que'], ['j', 'apprends', 'le', 'francais'], ['j', 'aime', 'cette', 'chanson'], ['allez', 'tout', 'droit'], ['je', 'rentre', 'a', 'la', 'maison'], ['combien', 'de', 'temps', 'ca', 'prend'], ['je', 'me', 'reveille'], ['felicitations'], ['aujourd', 'hui', 'c', 'est', 'lundi'], ['merci'], ['les', 'resultats', 'sont', 'prometteurs'], ['ce', 'n', 'est', 'pas', 'assez'], ['a', 'plus', 'tard'], ['quel', 'est', 'votre', 'nom'], ['je', 'veux', 'un', 'billet', 'pour', 'paris'], ['nous', 'collectons', 'plus', 'de', 'donnees'], ['nous', 'avons', 'besoin', 'de', 'plus', 'de', 'donnees'], ['non', 'je', 'prefere', 'le', 'the'], ['non', 'je', 'ne', 'comprends', 'pas'], ['l', 'apprentissage', 'profond', 'est', 'interessant'], ['bonjour'], ['je', 'suis', 'perdu'], ['j', 'espere', 'que', 'non']]
[['je', 'aime', 'voyager'], ['quel', 'est', 'votre'], ['j

In [102]:
bleu_results = []
for candidate in candidates_bidir:
  bleu_results.append(bleu(candidate, references_bidir))

print(bleu_results)
sum_of_bleu = sum(bleu_results)
len_of_bleu = len(bleu_results)
print(f'The average bidir BLEU is: {sum_of_bleu / len_of_bleu}')

[0.537284965911771, 0.8408964152537145, 0.7071067811865476, 0.5, 0.5, 0.537284965911771, 0.537284965911771, 0.5, 0.7071067811865476, 0.537284965911771, 0.7071067811865476, 0.5, 0.537284965911771, 0.4518010018049224, 0.5, 0.5946035575013605, 0.5946035575013605, 0.5946035575013605, 0.5623413251903491, 0.5, 0.4728708045015879, 0.537284965911771, 0.537284965911771, 0.537284965911771, 0.5, 0.7071067811865476, 0.5623413251903491]
The average bidir BLEU is: 0.5666951257957542


In [103]:
for text_valid in valid_data:
    translation = translate_sentence(text_valid[0], src_tokenizer, tgt_tokenizer, model, device)
    print(f"Original (EN): {text_valid}")
    # Find the ground truth if available
    gt_fr = "N/A"
    for en_s, fr_s in raw_data_pairs: # search in all raw data
        if en_s == text_valid[0]:
            gt_fr = fr_s
            break
    print(f"Ground Truth (FR): {gt_fr}")
    print(f"Translated (FR): {translation}\n")

Original (EN): ('i would like to know', 'je voudrais savoir')
Ground Truth (FR): je voudrais savoir
Translated (FR): je aime voyager

Original (EN): ('what are your hobbies', 'quels sont vos loisirs')
Ground Truth (FR): quels sont vos loisirs
Translated (FR): quel est votre

Original (EN): ('i am thirsty', 'j ai soif')
Ground Truth (FR): j ai soif
Translated (FR): je suis

Original (EN): ('the training loss is decreasing', 'la perte d entrainement diminue')
Ground Truth (FR): la perte d entrainement diminue
Translated (FR): la est modele est

Original (EN): ('it seems that', 'il semble que')
Ground Truth (FR): il semble que
Translated (FR): c etait

Original (EN): ('i am learning french', 'j apprends le francais')
Ground Truth (FR): j apprends le francais
Translated (FR): je suis sens

Original (EN): ('i like this song', 'j aime cette chanson')
Ground Truth (FR): j aime cette chanson
Translated (FR): je aime voyager

Original (EN): ('go straight ahead', 'allez tout droit')
Ground Truth