# In [None]:
import numpy as np
import pandas as pd
import spacy
import torch
from torch.utils.data import TensorDataset, DataLoader

# -Preprocess- #
# Locations of the Multi30k English/German train and test (2017 Flickr) files
# under the Colab Google Drive mount. Adjacent string literals are joined at
# compile time, so each name below still holds one complete path string.
train_path_en = (
    "/content/drive/MyDrive/Colab Notebooks/multi30k/"
    "train/train.lc.norm.tok.en.txt"
)
train_path_de = (
    "/content/drive/MyDrive/Colab Notebooks/multi30k/"
    "train/train.lc.norm.tok.de.txt"
)
test_path_en = (
    "/content/drive/MyDrive/Colab Notebooks/multi30k/"
    "test/test_2017_flickr.lc.norm.tok.en.txt"
)
test_path_de = (
    "/content/drive/MyDrive/Colab Notebooks/multi30k/"
    "test/test_2017_flickr.lc.norm.tok.de.txt"
)

# Read each corpus file into a list of raw lines (trailing newlines kept).
# The encoding is passed explicitly because open()'s default depends on the
# platform locale; on a non-UTF-8 locale (e.g. cp1252 on Windows) the German
# umlauts would be mis-decoded or raise UnicodeDecodeError.
# Assumes the Multi30k files are UTF-8 encoded.
with open(train_path_en, encoding="utf-8") as en_raw_train:
    en_parsed_train = en_raw_train.readlines()
with open(train_path_de, encoding="utf-8") as de_raw_train:
    de_parsed_train = de_raw_train.readlines()
with open(test_path_en, encoding="utf-8") as en_raw_test:
    en_parsed_test = en_raw_test.readlines()
with open(test_path_de, encoding="utf-8") as de_raw_test:
    de_parsed_test = de_raw_test.readlines()

# Split each line on whitespace. The files are presumably pre-tokenized
# (".tok" in the filenames), so tokens are separated by spaces.
# str.split() with no argument also drops the trailing newline, collapses runs
# of whitespace and returns [] for a blank line. The previous
# strip().split(" ") produced "" tokens in those cases, and they then entered
# the vocabulary as a real word.
en_train = [sent.split() for sent in en_parsed_train]
en_test = [sent.split() for sent in en_parsed_test]
de_train = [sent.split() for sent in de_parsed_train]
de_test = [sent.split() for sent in de_parsed_test]

def _index_tokens(datasets):
    """Assign an integer index to every distinct token in *datasets*.

    Indices 0-3 are reserved for "<UNK>", "<PAD>", "<SOS>" and "<EOS>"; every
    other token gets the next free index in first-seen order.

    Args:
        datasets: Iterable of datasets, each an iterable of token lists.

    Returns:
        ``(index2word, word2index)``: the index -> token list and its inverse
        token -> index dict.
    """
    index2word = ["<UNK>", "<PAD>", "<SOS>", "<EOS>"]
    # Membership is checked against the dict (O(1)). Checking
    # `token not in index2word` scanned the whole list for every token, which
    # is quadratic in the vocabulary size.
    word2index = {token: idx for idx, token in enumerate(index2word)}
    for ds in datasets:
        for sent in ds:
            for token in sent:
                if token not in word2index:
                    word2index[token] = len(index2word)
                    index2word.append(token)
    return index2word, word2index


# NOTE(review): test-split tokens are indexed too, so every token in the
# encoded splits has an index. The vocabulary therefore also contains words
# that never occur in training -- confirm this is intended.
en_index2word, en_word2index = _index_tokens([en_train, en_test])
de_index2word, de_word2index = _index_tokens([de_train, de_test])

# Total encoded length of every sentence, including <SOS> and <EOS>.
seq_length = 20
def build_vocab(vocab, sent, max_length):
    """Encode one tokenized sentence as a fixed-length list of indices.

    Despite its name, this does not build a vocabulary: it looks up each
    token of *sent* in an existing *vocab*.

    The result is ``[<SOS>] + tokens + [<EOS>]``, right-padded with ``<PAD>``
    to exactly *max_length* entries. A sentence with more than
    ``max_length - 2`` tokens is truncated so that <SOS> and <EOS> always fit.
    Tokens missing from *vocab* are encoded as ``<UNK>``.

    Args:
        vocab: Mapping token -> index. Must contain "<SOS>", "<EOS>" and
            "<PAD>", plus "<UNK>" if *sent* has out-of-vocabulary tokens.
        sent: Sequence of token strings.
        max_length: Length of the returned list.

    Returns:
        A list of exactly *max_length* ints.

    Raises:
        ValueError: If *max_length* is smaller than 2 (no room for <SOS>/<EOS>).
    """
    if max_length < 2:
        raise ValueError(f"max_length must be >= 2 to hold <SOS> and <EOS>, got {max_length}")

    body_length = max_length - 2  # slots left after <SOS> and <EOS>
    # Only the tokens that survive truncation are looked up. "<UNK>" is
    # resolved lazily, so vocabularies without it still work when every token
    # is known.
    encoded = [vocab[w] if w in vocab else vocab["<UNK>"] for w in sent[:body_length]]
    n_pads = body_length - len(encoded)  # never negative after truncation
    return [vocab["<SOS>"]] + encoded + [vocab["<EOS>"]] + [vocab["<PAD>"]] * n_pads

# Encode every sentence of each split into a fixed-length index sequence of
# seq_length entries. Each generator yields the train split, then the test split.
en_train_dataset, en_test_dataset = (
    [build_vocab(en_word2index, tokens, seq_length) for tokens in split]
    for split in (en_train, en_test)
)
de_train_dataset, de_test_dataset = (
    [build_vocab(de_word2index, tokens, seq_length) for tokens in split]
    for split in (de_train, de_test)
)

def load_data(batch_size, de_train_dataset, en_train_dataset, de_test_dataset, en_test_dataset,
              *, shuffle_test=True, drop_last_test=True):
    """Wrap encoded German->English sentence pairs in PyTorch DataLoaders.

    German sequences are the inputs (x) and English sequences the targets
    (y). Each dataset argument is a list of equal-length lists of token
    indices, such as the output of ``build_vocab`` for every sentence.

    Args:
        batch_size: Number of sentence pairs per batch.
        de_train_dataset: Source (German) training sequences.
        en_train_dataset: Target (English) training sequences.
        de_test_dataset: Source (German) test sequences.
        en_test_dataset: Target (English) test sequences.
        shuffle_test: Whether the test loader shuffles. Defaults to True,
            which matches the original behavior; pass False for a
            reproducible evaluation order.
        drop_last_test: Whether the final incomplete test batch is dropped.
            Defaults to True, which matches the original behavior; pass
            False to evaluate every test pair.

    Returns:
        ``(train_loader, test_loader)``. Each batch is a pair ``(x, y)`` of
        int64 tensors of shape ``(batch, seq_len)``. The training loader
        always shuffles and drops the last incomplete batch.
    """
    def to_loader(source, target, shuffle, drop_last):
        # The dtype is set explicitly: np.array's default integer type is
        # platform-dependent (int32 on Windows with NumPy < 2), while
        # class-index targets for CrossEntropyLoss must be torch.long.
        x = torch.from_numpy(np.asarray(source, dtype=np.int64))
        y = torch.from_numpy(np.asarray(target, dtype=np.int64))
        return DataLoader(TensorDataset(x, y), shuffle=shuffle,
                          batch_size=batch_size, drop_last=drop_last)

    train_loader = to_loader(de_train_dataset, en_train_dataset, True, True)
    test_loader = to_loader(de_test_dataset, en_test_dataset, shuffle_test, drop_last_test)
    return train_loader, test_loader

def load_and_preprocess_data(en_path, de_path):
    """Read a line-aligned English/German corpus and tokenize it with spaCy.

    Args:
        en_path: Path to the English text file, one sentence per line.
        de_path: Path to the German text file, line-aligned with *en_path*.

    Returns:
        A pandas DataFrame with columns ``'en'`` and ``'de'``. Row i holds the
        token-string lists for line i of each file.

    Notes:
        The spaCy models 'en_core_web_sm' and 'de_core_news_sm' must be
        installed. pandas raises ValueError if the two files have a different
        number of lines.
    """
    spacy_en = spacy.load('en_core_web_sm')
    spacy_de = spacy.load('de_core_news_sm')

    # Read the files passed in. The previous version ignored en_path/de_path
    # and always read the global test_path_en/test_path_de.
    with open(en_path, encoding="utf-8") as en_file:
        en_lines = en_file.readlines()
    with open(de_path, encoding="utf-8") as de_file:
        de_lines = de_file.readlines()

    # Each line is stripped first. Otherwise spaCy turns the trailing "\n" into
    # its own whitespace token, and that token ends every tokenized sentence.
    en_tokenized = [tokenize_en(line.strip(), spacy_en) for line in en_lines]
    de_tokenized = [tokenize_de(line.strip(), spacy_de) for line in de_lines]

    return pd.DataFrame({'en': en_tokenized, 'de': de_tokenized})

def tokenize_de(text, spacy_model):
    """Return the text of each token that *spacy_model* produces for *text*.

    The function is used for German input in load_and_preprocess_data. It is
    itself language-agnostic: it runs whatever model it is given.
    """
    doc = spacy_model(text)
    tokens = []
    for token in doc:
        tokens.append(token.text)
    return tokens

def tokenize_en(text, spacy_model):
    """Run *spacy_model* on *text* and return the tokens' surface strings in order.

    The function is used for English input in load_and_preprocess_data.
    """
    return list(map(lambda tok: tok.text, spacy_model(text)))

# Tokenize the test split with spaCy into a DataFrame with 'en' and 'de' columns.
test = load_and_preprocess_data(en_path=test_path_en, de_path=test_path_de)