In [None]:
import datasets
import spacy
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn
import torch.nn.functional as F
import torch.optim as optim
from collections import Counter
from torch.utils.data import DataLoader

In [None]:
# Load the local Multi30k copy (HuggingFace `datasets` layout) and split it
dataset = datasets.load_dataset("data/Multi30k_HuggingFace")
train_set = dataset['train']
val_set = dataset['validation']
test_set = dataset['test']
train_set[0]

In [None]:
# spaCy pipelines for English and German; only their `.tokenizer` is used below
en_nlp, de_nlp = spacy.load("en_core_web_sm"), spacy.load("de_core_news_sm")

In [None]:
# Build token -> id vocabularies from the training split.
# Tokens are lower-cased spaCy tokens; tokens seen fewer than `min_freq` times
# are dropped here and will be encoded as <unk> later.
unk, pad, sos, eos = '<unk>', '<pad>', '<sos>', '<eos>'
special_tokens = [unk, pad, sos, eos]
min_freq = 2


def build_vocab(counter, min_freq, specials):
    """Map tokens to integer ids.

    Args:
        counter: Counter of token frequencies.
        min_freq: minimum frequency for a non-special token to be kept.
        specials: special tokens; they receive ids 0..len(specials)-1 in order.

    Returns:
        dict token -> id. Non-special tokens follow the specials in the
        counter's iteration (first-seen) order. A special token that also
        occurs in the corpus is not re-added, so the special ids stay fixed.
    """
    frequent = [tok for tok, freq in counter.items()
                if freq >= min_freq and tok not in specials]
    return {tok: idx for idx, tok in enumerate(specials + frequent)}


en_counter, de_counter = Counter(), Counter()
for example in train_set:
    en_counter.update(token.text.lower() for token in en_nlp.tokenizer(example['en']))
    de_counter.update(token.text.lower() for token in de_nlp.tokenizer(example['de']))

en_token_dict = build_vocab(en_counter, min_freq, special_tokens)
de_token_dict = build_vocab(de_counter, min_freq, special_tokens)

In [None]:
# Create token list and token IDs for each sentence in the dataset
def _encode_sentence(text, nlp, vocab, unk_token, sos, eos):
    """Tokenise one sentence and map it to vocabulary ids.

    Tokens are lower-cased spaCy tokens. Any token missing from `vocab` is
    replaced by `unk_token` in BOTH the token list and the id list. The
    sequence is wrapped as [sos] + tokens + [eos].

    Returns:
        (tokens, ids): parallel lists of token strings and integer ids.
    """
    tokens = []
    for token in nlp.tokenizer(text):
        word = token.text.lower()
        tokens.append(word if word in vocab else unk_token)
    tokens = [sos] + tokens + [eos]
    return tokens, [vocab[tok] for tok in tokens]


def tokenize_example(example, en_nlp, de_nlp, sos, eos,
                     en_vocab=None, de_vocab=None, unk_token=None):
    """Add token and id columns for both languages to one dataset row.

    Intended for `Dataset.map`. Both sequences get <sos> and <eos>; collate_fn
    later decides which of them each model input/target keeps.

    Args:
        example: row dict with raw sentences under 'en' and 'de'.
        en_nlp, de_nlp: spaCy pipelines whose `.tokenizer` is used.
        sos, eos: start/end-of-sequence token strings.
        en_vocab, de_vocab: token -> id dicts; default to the module-level
            `en_token_dict` / `de_token_dict`.
        unk_token: replacement for out-of-vocabulary tokens; defaults to the
            module-level `unk`.

    Returns:
        The same dict with 'en_tokens', 'en_ids', 'de_tokens', 'de_ids' added.
    """
    # Fall back to the notebook-level vocabularies so existing calls keep working
    en_vocab = en_token_dict if en_vocab is None else en_vocab
    de_vocab = de_token_dict if de_vocab is None else de_vocab
    unk_token = unk if unk_token is None else unk_token

    example['en_tokens'], example['en_ids'] = _encode_sentence(
        example['en'], en_nlp, en_vocab, unk_token, sos, eos)
    example['de_tokens'], example['de_ids'] = _encode_sentence(
        example['de'], de_nlp, de_vocab, unk_token, sos, eos)
    return example


In [None]:
# Extra keyword arguments that Dataset.map forwards to tokenize_example
fn_kwargs = dict(en_nlp=en_nlp, de_nlp=de_nlp, sos=sos, eos=eos)
train_set, val_set, test_set = [
    split.map(tokenize_example, fn_kwargs=fn_kwargs)
    for split in (train_set, val_set, test_set)
]

In [None]:
# Inspect the first training pair: raw text, tokens and ids for each language
first_pair = train_set[0]
for field in ('en', 'en_tokens', 'en_ids', 'de', 'de_tokens', 'de_ids'):
    print(first_pair[field])

In [None]:
# Write a collate_fn to pad sequences with variable length into a batch of tensors for Dataloader
def get_collate_fn(pad_index):
    """Return a DataLoader collate_fn that pads id sequences with `pad_index`.

    The returned function takes a list of rows (each with 'en_ids' and
    'de_ids', both wrapped in <sos> ... <eos>). It returns three tensors,
    each padded to (batch, max_len) because batch_first=True:
        encoder_input  -- 'en_ids' without the leading <sos>
        decoder_input  -- 'de_ids' without the trailing <eos>
        decoder_output -- 'de_ids' without the leading <sos> (shifted targets)
    """
    def _stack(batch, key, window):
        # Slice every row's id list, convert to tensors, and right-pad to a common length
        pieces = [torch.tensor(row[key][window]) for row in batch]
        return rnn.pad_sequence(pieces, batch_first=True, padding_value=pad_index)

    def collate_fn(batch):
        source = _stack(batch, 'en_ids', slice(1, None))
        target_in = _stack(batch, 'de_ids', slice(None, -1))
        target_out = _stack(batch, 'de_ids', slice(1, None))
        return source, target_in, target_out

    return collate_fn

In [None]:
# collate_fn pads source AND target with this single id. That is only valid if
# both vocabularies place <pad> at the same index (the specials are prepended in
# the same order to each), so check the assumption explicitly.
pad_idx = en_token_dict[pad]
assert pad_idx == de_token_dict[pad], "en/de vocabularies disagree on the <pad> id"
collate_fn = get_collate_fn(pad_idx)
batch_size = 128
train_dl = DataLoader(train_set, collate_fn=collate_fn, batch_size=batch_size, shuffle=True)

In [None]:
class Encoder(nn.Module):
    """Embedding + stacked GRU encoder returning the final hidden state of every layer.

    Args:
        input_dim: size of the source vocabulary.
        embedding_dim: dimensionality of the token embeddings.
        rnn_hidden_dim: hidden size of each GRU layer.
        rnn_num_layers: number of stacked GRU layers.
        pad_index: optional id of the padding token. When given, the padding
            embedding is fixed at zero (`padding_idx`) and padded positions are
            skipped by packing the sequence. The final hidden state of a short
            sentence is then not affected by the padding appended to it in a
            batch. `None` (default) keeps the original behaviour, where padding
            runs through the GRU like any other token.
    """

    def __init__(self, input_dim, embedding_dim, rnn_hidden_dim, rnn_num_layers, pad_index=None):
        super().__init__()
        self.pad_index = pad_index
        self.embedding = nn.Embedding(input_dim, embedding_dim, padding_idx=pad_index)
        self.encoder = nn.GRU(embedding_dim, rnn_hidden_dim, num_layers=rnn_num_layers, batch_first=True, bias=True)

    def forward(self, x):
        """Encode a batch of token ids.

        Args:
            x: LongTensor of ids, (batch, seq_len) since the GRU is batch_first.

        Returns:
            Final hidden state of every GRU layer, (rnn_num_layers, batch, rnn_hidden_dim).
        """
        embedded = self.embedding(x)
        if self.pad_index is None:
            _, state_layer = self.encoder(embedded)
            return state_layer
        # Assumes padding is trailing (as produced by pad_sequence). Lengths are
        # clamped to >= 1 because packing rejects empty sequences, and they must be on CPU.
        lengths = (x != self.pad_index).sum(dim=1).clamp(min=1).cpu()
        packed = rnn.pack_padded_sequence(embedded, lengths, batch_first=True, enforce_sorted=False)
        # With enforce_sorted=False the GRU restores h_n to the original batch order
        _, state_layer = self.encoder(packed)
        return state_layer

In [None]:
class Decoder(nn.Module):
    """Embedding + stacked GRU + linear head producing per-position vocabulary logits.

    Args:
        output_dim: size of the target vocabulary (also the number of logits).
        embedding_dim: dimensionality of the token embeddings.
        rnn_hidden_dim: hidden size of each GRU layer.
        rnn_num_layers: number of stacked GRU layers.
    """

    def __init__(self, output_dim, embedding_dim, rnn_hidden_dim, rnn_num_layers):
        super().__init__()
        # Construction order is embedding -> GRU -> head, as it determines parameter init order
        self.embedding = nn.Embedding(num_embeddings=output_dim, embedding_dim=embedding_dim)
        self.decoder = nn.GRU(
            input_size=embedding_dim,
            hidden_size=rnn_hidden_dim,
            num_layers=rnn_num_layers,
            bias=True,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=rnn_hidden_dim, out_features=output_dim)

    def forward(self, x, latent):
        """Return logits of shape (batch, seq_len, output_dim).

        `latent` is used as the GRU's initial hidden state, so it must be
        (rnn_num_layers, batch, rnn_hidden_dim); only the per-step outputs of
        the last layer are projected, and the final hidden state is discarded.
        """
        per_step = self.decoder(self.embedding(x), latent)[0]
        return self.fc(per_step)


In [None]:
class Seq2Seq(nn.Module):
    """Chain an encoder and a decoder.

    The encoder's output (its final hidden states) is passed to the decoder as
    the second argument, i.e. as the decoder's initial hidden state.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        # Registration order (encoder, then decoder) fixes the order of parameters()
        self.encoder, self.decoder = encoder, decoder

    def forward(self, encoder_input, decoder_input):
        """Encode `encoder_input`, then decode `decoder_input` from that context."""
        return self.decoder(decoder_input, self.encoder(encoder_input))


In [None]:
# Model / optimisation hyper-parameters
input_dim = len(en_token_dict)    # source (English) vocabulary size
output_dim = len(de_token_dict)   # target (German) vocabulary size
encoder_embedding_dim = 256
decoder_embedding_dim = 256
rnn_hidden_dim = 512
rnn_num_layers = 2
epochs = 20
seed = 0

torch.manual_seed(seed)  # reproducible weight init and DataLoader shuffling

encoder = Encoder(input_dim, encoder_embedding_dim, rnn_hidden_dim, rnn_num_layers)
decoder = Decoder(output_dim, decoder_embedding_dim, rnn_hidden_dim, rnn_num_layers)
seq2seq = Seq2Seq(encoder, decoder)
optimizer = optim.Adam(seq2seq.parameters(), lr=1e-3)

seq2seq.train()
for epoch in range(epochs):
    epoch_loss = 0.0
    for encoder_input, decoder_input, decoder_output in train_dl:
        logits = seq2seq(encoder_input, decoder_input)  # (batch, tgt_len, output_dim)
        # cross_entropy wants the class dim second: (batch, output_dim, tgt_len).
        # collate_fn pads the targets with pad_idx; exclude those positions from the loss.
        loss = F.cross_entropy(logits.permute(0, 2, 1), decoder_output,
                               ignore_index=pad_idx, reduction='mean')

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
    # Report the mean batch loss so the number does not scale with the number of batches
    print(f"Epoch={epoch}: Loss={epoch_loss / len(train_dl):.4f}")
