In [17]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy
import datasets
#import torchtext
import tqdm
import evaluate

In [18]:
# Seed every random number generator the notebook relies on so that a
# "Restart Kernel -> Run All" produces the same results each time.
seed = 1234

random.seed(seed)
# The import cell brings NumPy in as `numpy` (no `np` alias), so the module
# has to be referenced by its full name here.
numpy.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Ask cuDNN to use deterministic algorithms.
torch.backends.cudnn.deterministic = True

In [20]:
# Fetch the "bentrevett/multi30k" dataset; the splits are pulled out of
# `dataset` by name in the next cell.
dataset = datasets.load_dataset(path="bentrevett/multi30k")

In [22]:
# Pull the three named splits out of the loaded dataset.
train_data = dataset["train"]
valid_data = dataset["validation"]
test_data = dataset["test"]

In [24]:
# spaCy is not part of the notebook's import cell, so import it here;
# without this the cell raises NameError on a fresh kernel.
import spacy

# spaCy pipelines whose `.tokenizer` is used by tokenize_example to split the
# English and German sentences into tokens.
en_nlp = spacy.load("en_core_web_sm")
de_nlp = spacy.load("de_core_news_sm")

In [26]:
def tokenize_example(example, en_nlp, de_nlp, max_length, lower, sos_token, eos_token):
    """Tokenize the English and German sentences of one example.

    Args:
        example: Mapping with the raw sentences under the keys "en" and "de".
        en_nlp: Object whose ``tokenizer(text)`` yields items with a ``.text``
            attribute; used for the English sentence.
        de_nlp: Same as ``en_nlp`` but used for the German sentence.
        max_length: Maximum number of tokens kept per sentence. Truncation
            happens before the start/end markers are added, so the returned
            lists can hold up to ``max_length + 2`` items.
        lower: When truthy, every token is lower-cased.
        sos_token: Marker placed at the start of each token list.
        eos_token: Marker placed at the end of each token list.

    Returns:
        A dict with the keys "en_tokens" and "de_tokens", each a list of
        strings wrapped in ``sos_token`` / ``eos_token``.
    """

    def _prepare(text, nlp):
        # Tokenize, truncate, optionally lower-case, then add the markers.
        words = [piece.text for piece in nlp.tokenizer(text)][:max_length]
        if lower:
            words = [word.lower() for word in words]
        return [sos_token, *words, eos_token]

    return {
        "en_tokens": _prepare(example["en"], en_nlp),
        "de_tokens": _prepare(example["de"], de_nlp),
    }

In [29]:
# Tokenisation settings shared by all three splits.
max_length = 1_000  # tokens kept per sentence before <sos>/<eos> are added
lower = True
sos_token, eos_token = "<sos>", "<eos>"

# Extra keyword arguments handed to tokenize_example on every call.
fn_kwargs = dict(
    en_nlp=en_nlp,
    de_nlp=de_nlp,
    max_length=max_length,
    lower=lower,
    sos_token=sos_token,
    eos_token=eos_token,
)

# tokenize_example returns "en_tokens"/"de_tokens" for each example.
train_data = train_data.map(tokenize_example, fn_kwargs=fn_kwargs)
valid_data = valid_data.map(tokenize_example, fn_kwargs=fn_kwargs)
test_data = test_data.map(tokenize_example, fn_kwargs=fn_kwargs)

Map:   0%|          | 0/29000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1014 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [65]:
from collections import Counter, defaultdict

class Vocab:
    """Two-way mapping between tokens and integer indices.

    Tokens missing from the mapping resolve to a default ("unknown") index,
    and indices missing from the mapping resolve to the unknown token.
    """

    def __init__(self, token_to_index, unk_token="<unk>"):
        """Build the vocabulary from an existing token -> index dict.

        Args:
            token_to_index: Dict mapping each token to its index. The dict is
                stored as-is (not copied).
            unk_token: Token standing in for unknown tokens; it must be a key
                of ``token_to_index``.

        Raises:
            KeyError: If ``unk_token`` is not in ``token_to_index``.
        """
        self.token_to_index = token_to_index
        self.index_to_token = {idx: token for token, idx in token_to_index.items()}
        self.unk_token = unk_token
        self.unk_index = token_to_index[unk_token]

    def __len__(self):
        """Return the number of tokens in the vocabulary."""
        return len(self.token_to_index)

    def __getitem__(self, token):
        """Return the index of ``token``, or the default index if it is unknown."""
        return self.token_to_index.get(token, self.unk_index)

    def token_to_idx(self, token):
        """Alias of ``vocab[token]``."""
        return self.__getitem__(token)

    def idx_to_token(self, idx):
        """Return the token stored at ``idx``, or the unknown token if none is."""
        return self.index_to_token.get(idx, self.unk_token)

    def get_itos(self):
        """Return the tokens as a list ordered by index.

        The list covers indices ``0`` through the largest index in use; any gap
        in the index range is filled with the unknown token. An empty mapping
        gives an empty list.
        """
        max_index = max(self.index_to_token, default=-1)
        # range() excludes its stop value, so add 1 to keep the last token.
        return [self.index_to_token.get(i, self.unk_token) for i in range(max_index + 1)]

    def get_stoi(self):
        """Return the token -> index dict (the internal object, not a copy)."""
        return self.token_to_index

    def set_default_index(self, unk_index):
        """Set the index returned when an out-of-vocabulary token is looked up.

        Only the fallback changes; the token <-> index mappings are left
        untouched. Rewriting them here would overwrite whichever token already
        owns ``unk_index`` and leave a stale entry at the unknown token's
        previous index.
        """
        self.unk_index = unk_index

    def lookup_indices(self, tokens):
        """Map an iterable of tokens to a list of indices."""
        return [self.token_to_idx(token) for token in tokens]

    def lookup_tokens(self, indices):
        """Map indices (an iterable or a torch tensor) to a list of tokens."""
        if torch.is_tensor(indices):
            # Convert to plain Python ints so they match the dict keys.
            indices = indices.tolist()
        return [self.idx_to_token(index) for index in indices]

def build_vocab_from_iterator(iterator, min_freq=1, specials=None):
    """Build a ``Vocab`` from an iterable of token sequences.

    Args:
        iterator: Iterable yielding token sequences; every token is counted.
        min_freq: Minimum count a token needs to be included.
        specials: Optional sequence of special tokens. They get the indices
            0, 1, 2, ... in the order given, and the first one is used as the
            unknown token.

    Returns:
        A ``Vocab`` holding the specials followed by the qualifying tokens in
        the order they were first counted. Without specials, "<unk>" is used
        as the unknown token and appended if it is not already present.
    """
    frequencies = Counter()
    for sequence in iterator:
        frequencies.update(sequence)

    # Specials occupy the lowest indices.
    mapping = {}
    for position, special in enumerate(specials or ()):
        mapping[special] = position

    # Remaining tokens are appended in first-seen order; setdefault leaves
    # tokens that already have an index (e.g. specials) untouched.
    for word, count in frequencies.items():
        if count >= min_freq:
            mapping.setdefault(word, len(mapping))

    fallback = specials[0] if specials else "<unk>"
    mapping.setdefault(fallback, len(mapping))

    return Vocab(mapping, fallback)
    



In [66]:
# Build one vocabulary per language from the training split only.
min_freq = 2  # tokens counted fewer times than this are left out
unk_token, pad_token = "<unk>", "<pad>"

# unk_token comes first: build_vocab_from_iterator uses specials[0] as the
# unknown token.
special_tokens = [unk_token, pad_token, sos_token, eos_token]

en_vocab, de_vocab = (
    build_vocab_from_iterator(
        train_data[column],
        min_freq=min_freq,
        specials=special_tokens,
    )
    for column in ("en_tokens", "de_tokens")
)


In [67]:
# Both vocabularies must agree on the unknown and padding indices, because a
# single index of each is used for both languages from here on.
assert en_vocab[unk_token] == de_vocab[unk_token]
assert en_vocab[pad_token] == de_vocab[pad_token]

unk_index, pad_index = en_vocab[unk_token], en_vocab[pad_token]


In [68]:
# Tokens missing from a vocabulary resolve to the shared unknown index.
en_vocab.set_default_index(unk_index=unk_index)
de_vocab.set_default_index(unk_index=unk_index)

In [69]:
def numericalize_example(example, en_vocab, de_vocab):
    """Convert the token lists of one example into index lists.

    Args:
        example: Mapping with token lists under "en_tokens" and "de_tokens".
        en_vocab: Object providing ``lookup_indices`` for the English tokens.
        de_vocab: Object providing ``lookup_indices`` for the German tokens.

    Returns:
        A dict with the keys "en_ids" and "de_ids" holding whatever the
        respective ``lookup_indices`` call returns.
    """
    return {
        "en_ids": en_vocab.lookup_indices(example["en_tokens"]),
        "de_ids": de_vocab.lookup_indices(example["de_tokens"]),
    }

In [70]:
# Extra keyword arguments handed to numericalize_example on every call.
fn_kwargs = dict(en_vocab=en_vocab, de_vocab=de_vocab)

# numericalize_example returns "en_ids"/"de_ids" for each example.
train_data, valid_data, test_data = [
    split.map(numericalize_example, fn_kwargs=fn_kwargs)
    for split in (train_data, valid_data, test_data)
]


Map:   0%|          | 0/29000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1014 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [72]:
# Request torch formatting for the id columns in every split;
# output_all_columns=True is passed so the remaining columns are not dropped.
data_type = "torch"
format_columns = ["en_ids", "de_ids"]

train_data, valid_data, test_data = [
    split.with_format(
        type=data_type,
        columns=format_columns,
        output_all_columns=True,
    )
    for split in (train_data, valid_data, test_data)
]


In [74]:
def get_collate_fn(pad_index):
    """Create a collate function that pads id sequences with ``pad_index``.

    Args:
        pad_index: Value used to fill the shorter sequences of a batch.

    Returns:
        A function taking a list of examples (each with "en_ids" and "de_ids"
        tensors) and returning a dict with the same two keys, where each value
        is the result of ``nn.utils.rnn.pad_sequence`` over that batch.
    """

    def collate_fn(batch):
        padded = {}
        for key in ("en_ids", "de_ids"):
            sequences = [example[key] for example in batch]
            # pad_sequence is called with its default batch_first=False, so the
            # result is laid out as (longest sequence, batch size).
            padded[key] = nn.utils.rnn.pad_sequence(sequences, padding_value=pad_index)
        return padded

    return collate_fn

