In [None]:
import datasets
import spacy
from collections import Counter
import torch.nn as nn

In [None]:
# Read the dataset using dataset.load_dataset()
dataset = datasets.load_dataset("data/Multi30k_HuggingFace")
train_set, val_set, test_set = dataset['train'], dataset['validation'], dataset['test']
train_set[0]

In [None]:
# Use tokenizer from spacy
en_nlp = spacy.load("en_core_web_sm")
de_nlp = spacy.load("de_core_news_sm")

In [None]:
def tokenize_example(example, en_nlp, de_nlp, sos_token, eos_token, is_lower):
    en_tokens = [token.text.lower() if is_lower else token.text for token in en_nlp.tokenizer(example['en'])]
    de_tokens = [token.text.lower() if is_lower else token.text for token in de_nlp.tokenizer(example['de'])]

    return {'en_tokens': [sos_token] + en_tokens + [eos_token], # input only needs eos token
            'de_tokens': [sos_token] + de_tokens + [eos_token] # output needs both sos and eos tokens
            }


In [11]:
unk_token = '<unk>'
pad_token = '<pad>'
sos_token = '<sos>'
eos_token = '<eos>'
is_lower = True
fn_kwargs = {
    'en_nlp': en_nlp,
    'de_nlp': de_nlp,
    'sos_token': sos_token,
    'eos_token': eos_token,
    'is_lower': is_lower
}
train_set = train_set.map(tokenize_example, fn_kwargs=fn_kwargs)
train_set = val_set.map(tokenize_example, fn_kwargs=fn_kwargs)
train_set = test_set.map(tokenize_example, fn_kwargs=fn_kwargs)

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [7]:
print(train_set[0]['en'])
print(train_set[0]['en_tokens'])

A man in an orange hat starring at something.
['<sos>', 'a', 'man', 'in', 'an', 'orange', 'hat', 'starring', 'at', 'something', '.', '<eos>']


In [51]:
vocab_min_freq = 2
special_tokens = [unk_token, pad_token, sos_token, eos_token]
en_token_dict = Counter()
de_token_dict = Counter()
for sentence in train_set:
    en_token_dict.update(sentence['en_tokens'])
    de_token_dict.update(sentence['de_tokens'])

en_token_dict = {k: v for (k, v) in en_token_dict.items() if v > vocab_min_freq}
de_token_dict = {k: v for (k, v) in de_token_dict.items() if v > vocab_min_freq}

en_vocab = special_tokens + list(en_token_dict)
de_vocab = special_tokens + list(de_token_dict)

In [52]:
def lookup_tokenids(example, en_vocab, de_vocab):
    en_ids = []
    de_ids = []
    for w in example['en_tokens']:
        en_idx = en_vocab.index(w) if w in en_vocab else 0
        en_ids.append(en_idx)
        
    for w in example['de_tokens']:
        de_idx = de_vocab.index(w) if w in de_vocab else 0
        de_ids.append(de_idx)
    
    return {'en_ids': en_ids,
            'de_ids': de_ids}

In [53]:
train_set = train_set.map(lookup_tokenids, fn_kwargs={'en_vocab': en_vocab, 'de_vocab': de_vocab})

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [None]:
class Encoder(nn.Module):
    def __init__(self, input_dim, embedding_dim, rnn_hidden_dim, rnn_num_layers):
        super().__init__()
        # 1 layer Embedding
        # 2 layers GRU
        # the latent space is the same as the hidden space of the last layer of the GRU
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.encoder = nn.GRU(embedding_dim, rnn_hidden_dim, num_layers=rnn_num_layers, batch_first=True, bias=True)
    
    def forward(self, x):
        x = self.embedding(x)
        # hidden state at the last layer for every word in the sequence:
        #       batch, sequence, hidden_dim
        # final hidden state at every layer
        #       layer, batch, hidden_dim
        x = self.embedding(x)
        state_sequence, state_layer = self.encoder(x)
        return state_sequence, state_layer

In [None]:
class Decoder(nn.Module):
    def __init__(self, output_dim, embedding_dim, rnn_hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(embedding_dim, output_dim)
        self.decoder = nn.GRU(rnn_hidden_dim, embedding_dim)
    
    def forward(self, x):
        x = self.embedding(x)

In [None]:
class Seq2Seq(nn.Module):
    def __init__(self):
        super().__init__()
        self.decoder = Decoder()
        
    
    def forward(self, x):
