In [2]:
import sys
# Show the path of the Python interpreter running this kernel
# (handy for checking which environment the notebook is using).
sys.executable

'/Users/b/ml/seq2seq/venv/bin/python'

In [6]:
import random
from collections import Counter

import datasets
import evaluate
import matplotlib.pyplot as pyplot
import matplotlib.ticker as ticker
import numpy as np
import spacy
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm

In [16]:
seed = 1234

# Seed Python's `random`, NumPy and PyTorch (default and CUDA generators)
# from the same value so repeated runs draw the same random numbers.
for _seed_rng in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
):
    _seed_rng(seed)

# Restrict cuDNN to deterministic algorithms.
torch.backends.cudnn.deterministic = True

In [22]:
# Multi30k dataset, found at bentrevett/multi30k; unpack its three splits.
dataset = datasets.load_dataset("bentrevett/multi30k")
train_data, valid_data, test_data = (
    dataset[split_name] for split_name in ("train", "validation", "test")
)

In [27]:
# Load the spaCy pipelines for English and German. tokenize_example calls
# their `.tokenizer` attribute to split the "en" and "de" sentences.
en_nlp = spacy.load("en_core_web_sm")
de_nlp = spacy.load("de_core_news_sm")

In [30]:
def tokenize_example(example, en_nlp, de_nlp, max_length, lower, sos_token, eos_token):
    """Tokenize one English/German sentence pair.

    Args:
        example: Mapping holding the raw sentences under the "en" and "de" keys.
        en_nlp: Object exposing `.tokenizer(text)` for English; every item it
            yields must have a `.text` attribute.
        de_nlp: Same as `en_nlp`, for German.
        max_length: Maximum number of tokens kept per sentence, counted before
            the start/end markers are added. `None` keeps every token.
        lower: If true, lowercase every token.
        sos_token: Marker prepended to each token list.
        eos_token: Marker appended to each token list.

    Returns:
        Dict with the keys "en_tokens" and "de_tokens", each a list of strings.
    """

    def _tokenize(nlp, text):
        # Truncate before adding the markers so <sos>/<eos> are never cut off.
        # (Previously max_length was accepted but silently ignored.)
        tokens = [token.text for token in nlp.tokenizer(text)][:max_length]
        if lower:
            tokens = [token.lower() for token in tokens]
        return [sos_token] + tokens + [eos_token]

    return {
        "en_tokens": _tokenize(en_nlp, example["en"]),
        "de_tokens": _tokenize(de_nlp, example["de"]),
    }
        

In [36]:
# Tokenization settings shared by every split.
max_length = 1000
lower = True
sos_token = "<sos>"
eos_token = "<eos>"

fn_kwargs = dict(
    en_nlp=en_nlp,
    de_nlp=de_nlp,
    max_length=max_length,
    lower=lower,
    sos_token=sos_token,
    eos_token=eos_token,
)

# Run tokenize_example over each split; its returned "en_tokens"/"de_tokens"
# entries become new columns of the mapped dataset.
train_data, valid_data, test_data = (
    split.map(tokenize_example, fn_kwargs=fn_kwargs)
    for split in (train_data, valid_data, test_data)
)

Map:   0%|          | 0/29000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1014 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [56]:
class Vocab:
    """Two-way token <-> index mapping with a fallback for unknown tokens.

    Args:
        token_to_index: Dict mapping each token string to its integer index.
            It is stored by reference (not copied) and must contain `unk_token`.
        unk_token: Token whose index is returned for out-of-vocabulary tokens.

    Raises:
        KeyError: If `unk_token` is not a key of `token_to_index`.
    """

    def __init__(self, token_to_index, unk_token="<unk>"):
        self.token_to_index = token_to_index
        self.index_to_token = {idx: token for token, idx in token_to_index.items()}
        self.unk_token = unk_token
        self.unk_index = self.token_to_index[unk_token]

    def __len__(self):
        """Return the number of tokens in the vocabulary."""
        return len(self.token_to_index)

    def __getitem__(self, token):
        """Return the index of `token`, or the unknown index if it is absent."""
        return self.token_to_index.get(token, self.unk_index)

    def token_to_idx(self, token):
        """Alias of `vocab[token]`."""
        return self[token]

    def idx_to_token(self, idx):
        """Return the token stored at `idx`, or the unknown token if none is."""
        return self.index_to_token.get(idx, self.unk_token)

    def set_default_index(self, index):
        """Make `index` the index returned for out-of-vocabulary tokens.

        If `index` is currently held by another token, that token and the
        unknown token trade indices so both mappings stay consistent. If no
        token holds `index`, the mappings are left alone and only the fallback
        index changes. Does nothing when `index` already is the unknown index.
        """
        if index == self.unk_index:
            return

        # index_to_token is a dict: it must be queried with .get(), not called.
        current_token = self.index_to_token.get(index)

        # `is not None` (rather than truthiness) so an empty-string token still swaps.
        if self.unk_token in self.token_to_index and current_token is not None:
            self.token_to_index[self.unk_token] = index
            self.token_to_index[current_token] = self.unk_index

        if self.unk_index in self.index_to_token and index in self.index_to_token:
            self.index_to_token[self.unk_index], self.index_to_token[index] = (
                self.index_to_token[index],
                self.index_to_token[self.unk_index],
            )

        self.unk_index = index

    def get_itos(self):
        """Return the index -> token dict (the live object, not a copy)."""
        return self.index_to_token

    def get_stoi(self):
        """Return the token -> index dict (the live object, not a copy)."""
        return self.token_to_index

    def lookup_indices(self, tokens):
        """Map an iterable of tokens to a list of indices."""
        return [self.token_to_idx(token) for token in tokens]

    def lookup_tokens(self, indices):
        """Map an iterable (or tensor) of indices to a list of tokens."""
        # torch.is_tensor is the real API name; `torch.istensor` does not exist.
        if torch.is_tensor(indices):
            indices = indices.tolist()

        return [self.idx_to_token(index) for index in indices]

        

In [65]:
def build_vocab_from_iterator(iterator, min_freq=1, specials=None):
    """Build a `Vocab` from an iterable of token sequences.

    Args:
        iterator: Iterable whose items are themselves iterables of tokens
            (for example a list of tokenized sentences).
        min_freq: Minimum number of occurrences a token needs to be included.
        specials: Optional sequence of special tokens. They receive the lowest
            indices, in the given order, regardless of `min_freq`.

    Returns:
        A `Vocab`. Its unknown token is "<unk>" when `specials` is empty or
        contains "<unk>"; otherwise it is the first entry of `specials`.
    """
    counter = Counter()
    for tokens in iterator:
        counter.update(tokens)

    token_to_index = {}
    for token in specials or ():
        # setdefault + len keeps indices dense even if a special is repeated;
        # enumerate() indices would later collide with len()-based ones.
        token_to_index.setdefault(token, len(token_to_index))

    # Remaining tokens are numbered in first-seen order.
    for token, freq in counter.items():
        if freq >= min_freq and token not in token_to_index:
            token_to_index[token] = len(token_to_index)

    if specials:
        # Prefer an explicit "<unk>" wherever it appears in `specials`; fall
        # back to the first special only when "<unk>" is not among them.
        unk_token = "<unk>" if "<unk>" in specials else specials[0]
    else:
        unk_token = "<unk>"
    if unk_token not in token_to_index:
        token_to_index[unk_token] = len(token_to_index)

    return Vocab(token_to_index, unk_token=unk_token)


# Quick check of the vocabulary builder on two toy sentences.
tokens = [
    ["death", "destruction", "domination", "weapons", "graves"],
    ["all", "you", "care", "about", "is", "love"],
]
min_freq = 1
specials = ["<unk>", "<pad>", "<sos>", "<eos>"]
vocab = build_vocab_from_iterator(tokens, min_freq=min_freq, specials=specials)
print(vocab.get_stoi())
print(vocab.idx_to_token(6))

    

{'<unk>': 0, '<pad>': 1, '<sos>': 2, '<eos>': 3, 'death': 4, 'destruction': 5, 'domination': 6, 'weapons': 7, 'graves': 8, 'all': 9, 'you': 10, 'care': 11, 'about': 12, 'is': 13, 'love': 14}
domination


In [73]:
# Vocabulary settings: tokens occurring fewer than `min_freq` times are left
# out of the vocabulary and therefore resolve to the unknown index.
min_freq = 2
unk_token = "<unk>"
pad_token = "<pad>"

special_tokens = [unk_token, pad_token, sos_token, eos_token]

# One vocabulary per language, both built from the training split.
en_vocab, de_vocab = (
    build_vocab_from_iterator(
        train_data[column],
        min_freq=min_freq,
        specials=special_tokens,
    )
    for column in ("en_tokens", "de_tokens")
)

In [77]:
# The English and German vocabularies must agree on the unknown and padding
# indices, because a single unk_index / pad_index is used for both below.
assert de_vocab[unk_token] == en_vocab[unk_token]
assert en_vocab[pad_token] == de_vocab[pad_token]

unk_index, pad_index = en_vocab[unk_token], en_vocab[pad_token]

In [79]:
# Make the unknown index the fallback for out-of-vocabulary tokens in both
# vocabularies. unk_index was read from en_vocab[unk_token] and asserted equal
# for de_vocab, so Vocab.set_default_index returns early here without
# changing either vocabulary.
en_vocab.set_default_index(unk_index)
de_vocab.set_default_index(unk_index)

In [82]:
def numericalize_example(example, en_vocab, de_vocab):
    """Convert the token lists of one example into lists of vocabulary indices.

    Args:
        example: Mapping with the token lists under "en_tokens" and "de_tokens"
            (the keys produced by `tokenize_example`).
        en_vocab: Object with `lookup_indices(tokens)` for the English tokens.
        de_vocab: Object with `lookup_indices(tokens)` for the German tokens.

    Returns:
        Dict with the keys "en_ids" and "de_ids", each a list of indices.
    """
    # Look up the *token lists*. The raw "en"/"de" strings would be iterated
    # character by character and yield one index per character.
    en_ids = en_vocab.lookup_indices(example["en_tokens"])
    de_ids = de_vocab.lookup_indices(example["de_tokens"])
    return {"en_ids": en_ids, "de_ids": de_ids}

In [86]:
fn_kwargs = dict(en_vocab=en_vocab, de_vocab=de_vocab)

# Run numericalize_example over each split; its returned "en_ids"/"de_ids"
# entries become new columns of the mapped dataset.
train_data, valid_data, test_data = (
    split.map(numericalize_example, fn_kwargs=fn_kwargs)
    for split in (train_data, valid_data, test_data)
)

Map:   0%|          | 0/29000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [89]:
def get_collate_fn(pad_index):
    """Return a collate function that pads a batch of index sequences.

    Args:
        pad_index: Value used to fill sequences shorter than the longest one
            in the batch.

    Returns:
        A function that takes a list of examples (each holding "en_ids" and
        "de_ids" sequences, as lists or tensors) and returns a dict with the
        same two keys. Each value is a long tensor of shape
        [longest sequence in the batch, batch size].
    """

    def _pad(batch, key):
        # pad_sequence requires tensors; as_tensor converts plain lists and
        # passes existing long tensors through unchanged.
        sequences = [
            torch.as_tensor(example[key], dtype=torch.long) for example in batch
        ]
        return nn.utils.rnn.pad_sequence(sequences, padding_value=pad_index)

    def collate_fn(batch):
        # Pad the numericalized columns, not the raw "en"/"de" text columns.
        return {
            "en_ids": _pad(batch, "en_ids"),
            "de_ids": _pad(batch, "de_ids"),
        }

    return collate_fn

In [94]:
def get_data_loader(dataset, batch_size, pad_index, shuffle=False):
    """Wrap `dataset` in a DataLoader that batches with the padding collate function.

    Args:
        dataset: Dataset handed to `torch.utils.data.DataLoader`.
        batch_size: Number of examples per batch.
        pad_index: Padding value forwarded to `get_collate_fn`.
        shuffle: Whether the loader reshuffles the data (default False).

    Returns:
        The configured `torch.utils.data.DataLoader`.
    """
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=get_collate_fn(pad_index),
    )

In [96]:
batch_size = 128

# Only the training loader shuffles; validation and test use the default
# shuffle=False of get_data_loader.
train_data_loader = get_data_loader(train_data, batch_size, pad_index, shuffle=True)
valid_data_loader = get_data_loader(valid_data, batch_size, pad_index)
test_data_loader = get_data_loader(test_data, batch_size, pad_index)

# The validation loader used to be bound only to the misspelled name
# `vaild_data_loader`; keep that name as an alias so cells that still
# reference it continue to work.
vaild_data_loader = valid_data_loader

In [97]:
class Encoder(nn.Module):
    """Bidirectional GRU encoder that also builds the decoder's initial hidden state.

    NOTE(review): a class with this same name is defined again in the next
    cell; whichever cell runs last is the definition in effect.

    Args:
        input_dim: Number of rows in the embedding table.
        embedding_dim: Size of each embedding vector.
        encoder_hidden_dim: Hidden size of the GRU, per direction.
        decoder_hidden_dim: Size of the hidden state returned by `forward`.
        dropout: Dropout probability applied to the embeddings.
    """

    def __init__(
        self, input_dim, embedding_dim, encoder_hidden_dim, decoder_hidden_dim, dropout
    ):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.rnn = nn.GRU(embedding_dim, encoder_hidden_dim, bidirectional=True)
        # The two directions are concatenated before the projection, hence * 2.
        self.fc = nn.Linear(encoder_hidden_dim * 2, decoder_hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """Encode `src` and return `(outputs, hidden)`.

        `outputs` is the GRU output for every position; `hidden` is the tanh
        of a linear projection of the last two final GRU states concatenated.
        """
        token_vectors = self.dropout(self.embedding(src))
        outputs, final_states = self.rnn(token_vectors)
        # The last two entries of the final states belong to the two directions.
        both_directions = torch.cat([final_states[-2], final_states[-1]], dim=1)
        return outputs, torch.tanh(self.fc(both_directions))
        

In [None]:
class Encoder(nn.Module):
    """Bidirectional GRU encoder; also derives the decoder's initial hidden state.

    NOTE(review): this redefines the `Encoder` class from the previous cell
    with the same layers; once this cell runs, this definition is the one in effect.
    """

    def __init__(
        self, input_dim, embedding_dim, encoder_hidden_dim, decoder_hidden_dim, dropout
    ):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.rnn = nn.GRU(embedding_dim, encoder_hidden_dim, bidirectional=True)
        # forward + backward states are concatenated, so the input width doubles
        self.fc = nn.Linear(encoder_hidden_dim * 2, decoder_hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        # src = [src length, batch size]
        embedded = self.embedding(src)
        embedded = self.dropout(embedded)
        # embedded = [src length, batch size, embedding dim]
        outputs, hidden = self.rnn(embedded)
        # outputs = [src length, batch size, encoder hidden dim * n directions]
        #           (always taken from the last layer)
        # hidden = [n layers * n directions, batch size, hidden dim],
        #          stacked as [forward_1, backward_1, forward_2, backward_2, ...]
        forward_final = hidden[-2]  # final state of the forwards RNN
        backward_final = hidden[-1]  # final state of the backwards RNN
        # The decoder's initial hidden state is both final states concatenated,
        # passed through a linear layer and squashed with tanh.
        combined = torch.cat((forward_final, backward_final), dim=1)
        hidden = torch.tanh(self.fc(combined))
        # outputs = [src length, batch size, encoder hidden dim * 2]
        # hidden = [batch size, decoder hidden dim]
        return outputs, hidden
        