In [None]:
from collections import Counter

import datasets
import spacy
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn
from torch.utils.data import DataLoader

In [None]:
# Load the local Multi30k copy with datasets.load_dataset() and split it
dataset = datasets.load_dataset("data/Multi30k_HuggingFace")
train_set = dataset['train']
val_set = dataset['validation']
test_set = dataset['test']
# Peek at the first training example
train_set[0]

In [None]:
# spaCy pipelines for English and German; later cells only use their `.tokenizer`
en_nlp, de_nlp = spacy.load("en_core_web_sm"), spacy.load("de_core_news_sm")

In [None]:
# Build token -> id vocabularies from the training split only.
# Special tokens take ids 0-3 (in the order below); every other lowercased
# token seen at least `min_freq` times gets the next id, in first-seen order.
unk, pad, sos, eos = '<unk>', '<pad>', '<sos>', '<eos>'
special_tokens = [unk, pad, sos, eos]
min_freq = 2


def build_vocab(counts, min_freq, specials):
    """Map tokens to contiguous ids.

    `specials` get ids 0..len(specials)-1; then every token of `counts`
    (in its iteration order) whose count is >= `min_freq` gets the next id.
    A frequent token that is already a special token is skipped so ids stay
    unique and contiguous.

    Args:
        counts: mapping token -> frequency (e.g. a `Counter`).
        min_freq: minimum frequency for a non-special token to be kept.
        specials: special tokens placed first.

    Returns:
        dict mapping token -> id.
    """
    vocab = {token: index for index, token in enumerate(specials)}
    for token, freq in counts.items():
        if freq >= min_freq and token not in vocab:
            vocab[token] = len(vocab)
    return vocab


# Separate names for the raw frequency counts so the vocab names below
# always refer to the final token -> id dicts.
en_counts = Counter()
de_counts = Counter()
for example in train_set:
    en_counts.update(token.text.lower() for token in en_nlp.tokenizer(example['en']))
    de_counts.update(token.text.lower() for token in de_nlp.tokenizer(example['de']))

en_token_dict = build_vocab(en_counts, min_freq, special_tokens)
de_token_dict = build_vocab(de_counts, min_freq, special_tokens)

In [None]:
# Create token lists and token ids for each sentence pair in the dataset
def _encode_text(text, nlp, vocab, unk_token):
    """Tokenize `text` with `nlp.tokenizer`, lowercase each token, and replace
    any token that is not a key of `vocab` with `unk_token`."""
    lowered = (token.text.lower() for token in nlp.tokenizer(text))
    return [token if token in vocab else unk_token for token in lowered]


def tokenize_example(example, en_nlp, de_nlp, sos, eos,
                     en_vocab=None, de_vocab=None, unk_token=None):
    """Add token lists and token-id lists for both languages to `example`.

    The English side (model input) gets only a trailing `eos`; the German
    side (model output) is wrapped in `sos` ... `eos`. Tokens are lowercased
    and out-of-vocabulary tokens are replaced by the unknown token.

    Args:
        example: dict-like row with 'en' and 'de' text; it is mutated in place.
        en_nlp, de_nlp: objects exposing `.tokenizer(text)` (spaCy pipelines
            in this notebook).
        sos, eos: start/end-of-sequence token strings; must be vocab keys.
        en_vocab, de_vocab: token -> id dicts. Default to the notebook-level
            `en_token_dict` / `de_token_dict`.
        unk_token: replacement for out-of-vocabulary tokens; must be a vocab
            key. Defaults to the notebook-level `unk`.

    Returns:
        The same `example` with 'en_tokens', 'en_ids', 'de_tokens' and
        'de_ids' set.

    Raises:
        KeyError: if `sos`, `eos` or `unk_token` is missing from a vocab.
    """
    # Resolve defaults at call time so the notebook vocabularies are read
    # when the function runs, not when it is defined.
    if en_vocab is None:
        en_vocab = en_token_dict
    if de_vocab is None:
        de_vocab = de_token_dict
    if unk_token is None:
        unk_token = unk

    # Input only needs eos; output needs both sos and eos.
    en_tokens = _encode_text(example['en'], en_nlp, en_vocab, unk_token) + [eos]
    de_tokens = [sos] + _encode_text(example['de'], de_nlp, de_vocab, unk_token) + [eos]

    example['en_tokens'] = en_tokens
    example['en_ids'] = [en_vocab[token] for token in en_tokens]
    example['de_tokens'] = de_tokens
    example['de_ids'] = [de_vocab[token] for token in de_tokens]

    return example


In [None]:
# Keyword arguments forwarded to tokenize_example by Dataset.map
fn_kwargs = dict(en_nlp=en_nlp, de_nlp=de_nlp, sos=sos, eos=eos)
# Tokenize every split, in order: train, validation, test
train_set, val_set, test_set = (
    split.map(tokenize_example, fn_kwargs=fn_kwargs)
    for split in (train_set, val_set, test_set)
)

In [None]:
# Sanity-check the first training example: text, tokens and ids per language
first_example = train_set[0]
for column in ('en', 'en_tokens', 'en_ids', 'de', 'de_tokens', 'de_ids'):
    print(first_example[column])

In [None]:
# Build a collate_fn that pads variable-length id sequences into batch tensors for DataLoader
def get_collate_fn(pad_index, batch_first=True):
    """Return a DataLoader `collate_fn` that pads 'en_ids' / 'de_ids'.

    Args:
        pad_index: token id written into positions past each sequence's end.
        batch_first: if True (default) tensors are (batch, max_len), matching
            the `batch_first=True` GRUs in the model cells; if False they are
            (max_len, batch), `pad_sequence`'s own default layout.

    Returns:
        A function mapping a list of examples (each with 'en_ids' and
        'de_ids' integer sequences) to a tuple `(en_ids, de_ids)` of padded
        LongTensors.
    """
    def collate_fn(batch):
        # Dataset rows usually hold Python lists (unless a torch format is
        # set); pad_sequence needs tensors, and as_tensor accepts either.
        en_ids = [torch.as_tensor(example['en_ids'], dtype=torch.long) for example in batch]
        de_ids = [torch.as_tensor(example['de_ids'], dtype=torch.long) for example in batch]
        en_ids = rnn.pad_sequence(en_ids, batch_first=batch_first, padding_value=pad_index)
        de_ids = rnn.pad_sequence(de_ids, batch_first=batch_first, padding_value=pad_index)
        return en_ids, de_ids

    # The caller hands this closure to DataLoader(collate_fn=...).
    return collate_fn


In [None]:
# Both vocabularies are built with the special tokens first, so <pad> should
# have the same id in each; the collate function pads en and de with a single
# index, so check that assumption explicitly.
pad_idx = en_token_dict[pad]
assert de_token_dict[pad] == pad_idx, "en/de vocabularies disagree on the <pad> index"

BATCH_SIZE = 64
collate_fn = get_collate_fn(pad_idx)
train_dl = DataLoader(train_set, collate_fn=collate_fn, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Pull one batch to check the padded tensor shapes
for en, de in train_dl:
    print(en.size(), de.size(), sep='\n')
    break

In [None]:
class Encoder(nn.Module):
    """Embed source token ids and run them through a stacked GRU.

    The GRU is ``batch_first`` and uses a bias; its depth and width come from
    ``rnn_num_layers`` and ``rnn_hidden_dim``.
    NOTE(review): no ``padding_idx`` or sequence packing is used, so padded
    positions are encoded like ordinary tokens — TODO confirm this is intended.
    """

    def __init__(self, token_count, embedding_dim, rnn_hidden_dim, rnn_num_layers):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(num_embeddings=token_count, embedding_dim=embedding_dim)
        self.encoder = nn.GRU(
            input_size=embedding_dim,
            hidden_size=rnn_hidden_dim,
            num_layers=rnn_num_layers,
            bias=True,
            batch_first=True,
        )

    def forward(self, x):
        """Encode ``x`` and return the GRU's ``(outputs, h_n)`` pair.

        For batched input ``x`` of shape (batch, seq):
            outputs: top-layer hidden state at every step, (batch, seq, hidden).
            h_n: final hidden state of every layer, (layers, batch, hidden).
        """
        return self.encoder(self.embedding(x))

In [None]:
class Decoder(nn.Module):
    """GRU decoder over target token ids.

    Args:
        token_count: target vocabulary size (embedding rows and output width).
        embedding_dim: size of each token embedding.
        rnn_num_layers: number of stacked GRU layers.
        rnn_hidden_dim: optional GRU hidden width. When None (default) the
            GRU hidden width equals ``token_count`` and its states are
            returned directly. When given, the GRU uses this width and a
            linear layer projects every state to ``token_count`` scores, so
            the hidden width (and hence the encoder's) is no longer forced to
            equal the vocabulary size.
    """

    def __init__(self, token_count, embedding_dim, rnn_num_layers, rnn_hidden_dim=None):
        super().__init__()
        hidden_dim = token_count if rnn_hidden_dim is None else rnn_hidden_dim
        self.embedding = nn.Embedding(token_count, embedding_dim)
        self.decoder = nn.GRU(embedding_dim, hidden_dim, num_layers=rnn_num_layers, batch_first=True, bias=True)
        # Only add a projection when the hidden width is decoupled from the
        # vocabulary size, so the default configuration's parameters are unchanged.
        self.fc_out = None if rnn_hidden_dim is None else nn.Linear(hidden_dim, token_count)

    def forward(self, x, latent):
        """Decode ``x`` starting from hidden state ``latent``.

        Args:
            x: target token ids, (batch, seq) for batched input (GRU is batch_first).
            latent: initial hidden state, (rnn_num_layers, batch, hidden width),
                e.g. an encoder's final per-layer state.

        Returns:
            (outputs, state_layer): outputs is (batch, seq, token_count) —
            raw GRU states by default, projected scores when ``rnn_hidden_dim``
            was given; state_layer is the final hidden state of every layer.
        """
        x = self.embedding(x)
        state_sequence, state_layer = self.decoder(x, latent)
        if self.fc_out is not None:
            state_sequence = self.fc_out(state_sequence)
        return state_sequence, state_layer


In [None]:
class Seq2Seq(nn.Module):
    """Encoder-decoder model: the encoder's final per-layer hidden state is
    the decoder GRU's initial state, and ``de`` is fed to the decoder directly.

    Args:
        token_count: vocabulary size passed to BOTH the Encoder and Decoder.
            NOTE(review): the en and de vocabularies in this notebook are
            built separately and likely differ in size — TODO confirm a
            shared count is intended.
        embedding_dim: embedding size for both sides.
        rnn_hidden_dim: encoder GRU hidden width.
        rnn_num_layers: number of GRU layers on both sides.

    Raises:
        ValueError: if ``rnn_hidden_dim`` differs from the decoder GRU's hidden
            size, since the encoder state could not initialise the decoder.
    """

    def __init__(self, token_count, embedding_dim, rnn_hidden_dim, rnn_num_layers):
        super().__init__()
        self.encoder = Encoder(token_count, embedding_dim, rnn_hidden_dim, rnn_num_layers)
        self.decoder = Decoder(token_count, embedding_dim, rnn_num_layers)
        # The encoder's final hidden state becomes the decoder's initial state,
        # so both GRUs need the same hidden width; fail here instead of with a
        # shape error on the first forward pass.
        decoder_hidden_dim = self.decoder.decoder.hidden_size
        if decoder_hidden_dim != rnn_hidden_dim:
            raise ValueError(
                f"rnn_hidden_dim ({rnn_hidden_dim}) must equal the decoder GRU hidden size "
                f"({decoder_hidden_dim}) so the encoder state can initialise the decoder"
            )

    def forward(self, en, de):
        """Encode ``en`` and decode ``de`` (teacher forcing).

        Args:
            en: source token ids, (batch, src_len) — both GRUs are batch_first.
            de: decoder input ids, (batch, tgt_len). Fed as-is; any shift
                between decoder input and training target (e.g. de[:, :-1]
                vs de[:, 1:]) is left to the caller — TODO confirm convention.

        Returns:
            The decoder's per-step outputs, (batch, tgt_len, decoder output width).
        """
        # Use the final hidden state of every encoder layer as h0: it has the
        # (num_layers, batch, hidden) shape a multi-layer GRU expects, whereas
        # the top layer's last-step output is only (batch, hidden).
        _, hidden = self.encoder(en)
        outputs, _ = self.decoder(de, hidden)
        return outputs


