In [2]:
# Display the path of the Python interpreter running this kernel, which makes
# it easy to check which environment the notebook is executing in.
import sys
sys.executable

'/Users/b/ml/seq2seq/venv/bin/python'

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
import spacy
import datasets
import tqdm
import evaluate
import matplotlib.pyplot as pyplot
import matplotlib.ticker as ticker

In [16]:
seed = 1234

# Request deterministic cuDNN algorithms, then seed every random number
# generator this notebook touches: PyTorch (CPU and CUDA), NumPy and the
# standard-library `random` module.
torch.backends.cudnn.deterministic = True
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

In [22]:
# Load the dataset published under the id "bentrevett/multi30k".
dataset = datasets.load_dataset("bentrevett/multi30k")

# Bind each split to its own name for the cells that follow.
train_data = dataset["train"]
valid_data = dataset["validation"]
test_data = dataset["test"]

In [27]:
# spaCy pipelines for the two languages; the English one is loaded first,
# then the German one.
en_nlp, de_nlp = spacy.load("en_core_web_sm"), spacy.load("de_core_news_sm")

In [30]:
def tokenize_example(example, en_nlp, de_nlp, max_length, lower, sos_token, eos_token):
    """Tokenize one English/German sentence pair.

    Args:
        example: Mapping holding the raw sentences under the keys ``"en"``
            and ``"de"``.
        en_nlp: Object whose ``tokenizer`` callable splits the English
            sentence into items exposing a ``.text`` attribute.
        de_nlp: Same as ``en_nlp``, for the German sentence.
        max_length: Maximum number of tokens kept per sentence. The limit is
            applied before the start/end markers are added, so a returned
            list holds at most ``max_length + 2`` entries. ``None`` disables
            truncation.
        lower: When truthy, every token is lower-cased.
        sos_token: Marker placed at the start of each token list.
        eos_token: Marker placed at the end of each token list.

    Returns:
        dict: ``{"en_tokens": [...], "de_tokens": [...]}``, each list wrapped
        in ``sos_token`` / ``eos_token``.
    """

    def _tokenize(nlp, text):
        # Truncate first so the markers are never cut off by the limit.
        tokens = [token.text for token in nlp.tokenizer(text)][:max_length]
        if lower:
            tokens = [token.lower() for token in tokens]
        return [sos_token] + tokens + [eos_token]

    return {
        "en_tokens": _tokenize(en_nlp, example["en"]),
        "de_tokens": _tokenize(de_nlp, example["de"]),
    }
        

In [36]:
# Tokenization settings, forwarded to tokenize_example under the same names.
max_length = 1000
lower = True
sos_token = "<sos>"
eos_token = "<eos>"

# Keyword arguments handed to tokenize_example for every example.
fn_kwargs = dict(
    en_nlp=en_nlp,
    de_nlp=de_nlp,
    max_length=max_length,
    lower=lower,
    sos_token=sos_token,
    eos_token=eos_token,
)

# Run the tokenizer over each split; every split name is rebound to the
# result of its own map call.
train_data = train_data.map(tokenize_example, fn_kwargs=fn_kwargs)
valid_data = valid_data.map(tokenize_example, fn_kwargs=fn_kwargs)
test_data = test_data.map(tokenize_example, fn_kwargs=fn_kwargs)

Map:   0%|          | 0/29000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1014 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [55]:
class Vocab:
    """Two-way lookup between tokens and integer indices.

    Tokens missing from the vocabulary resolve to ``unk_index``; indices
    missing from the vocabulary resolve to ``unk_token``.
    """

    def __init__(self, token_to_index, unk_token="<unk>"):
        """Build the vocabulary from a token -> index mapping.

        Args:
            token_to_index: Mapping from token to integer index. It must
                contain ``unk_token``.
            unk_token: Token that stands in for anything unknown.

        Raises:
            KeyError: If ``unk_token`` is not a key of ``token_to_index``.
        """
        # Work on a copy so set_default_index never rewrites the caller's dict.
        self.token_to_index = dict(token_to_index)
        # items() yields (token, index) pairs; flip them for the reverse map.
        self.index_to_token = {idx: token for token, idx in self.token_to_index.items()}
        self.unk_token = unk_token
        self.unk_index = self.token_to_index[unk_token]

    def __len__(self):
        """Return the number of tokens in the vocabulary."""
        return len(self.token_to_index)

    def __getitem__(self, token):
        """Return the index of ``token``, or ``unk_index`` if it is unknown."""
        return self.token_to_index.get(token, self.unk_index)

    def token_to_idx(self, token):
        """Return the index of ``token``, or ``unk_index`` if it is unknown."""
        return self[token]

    def idx_to_token(self, idx):
        """Return the token stored at ``idx``, or ``unk_token`` if there is none."""
        return self.index_to_token.get(idx, self.unk_token)

    def set_default_index(self, index):
        """Make ``index`` the value returned for unknown tokens.

        If another token already occupies ``index``, it trades places with
        ``unk_token`` so that both lookup tables stay consistent. If ``index``
        is not used by any token, only the fallback value changes.
        """
        if index == self.unk_index:
            return

        # Position the unknown token really occupies. This can differ from
        # unk_index after an earlier call with an out-of-vocabulary index.
        unk_position = self.token_to_index[self.unk_token]
        # Compare against None rather than truthiness so that an empty-string
        # token is still swapped correctly.
        displaced = self.index_to_token.get(index)

        if displaced is not None and displaced != self.unk_token:
            self.token_to_index[self.unk_token] = index
            self.token_to_index[displaced] = unk_position
            self.index_to_token[index] = self.unk_token
            self.index_to_token[unk_position] = displaced

        self.unk_index = index

    def get_itos(self):
        """Return the index -> token dictionary (the internal object, not a copy)."""
        return self.index_to_token

    def get_stoi(self):
        """Return the token -> index dictionary (the internal object, not a copy)."""
        return self.token_to_index

    def lookup_indices(self, tokens):
        """Return the list of indices for an iterable of tokens."""
        return [self.token_to_idx(token) for token in tokens]

    def lookup_tokens(self, indices):
        """Return the list of tokens for an iterable or tensor of indices."""
        # Tensor elements are tensors themselves and would miss the dictionary
        # lookup, so convert to plain Python numbers first.
        if torch.is_tensor(indices):
            indices = indices.tolist()

        return [self.idx_to_token(index) for index in indices]

        