In [1]:
import torch
from torch import nn
from matplotlib import pyplot as plt
import numpy as np
from tqdm.notebook import tqdm

In [2]:
from pathlib import Path
from urllib.request import urlopen
import linecache
from itertools import count
import pickle

class En2DeDataset(torch.utils.data.Dataset):
    """Map-style dataset of line-aligned English/German sentence pairs.

    The text files come from the ``wmt14.en-de`` folder on nlp.stanford.edu
    and are stored in ``folder_path`` as ``train.en``/``train.de`` or
    ``test.en``/``test.de``. For the training split, a ``vocabs`` pickle
    holding ``{'en': set, 'de': set}`` is kept next to them.

    Args:
        folder_path: Directory holding (or receiving) the data files.
        transform: Optional callable applied separately to the English line
            and to the German line before they are returned.
        download: When true, fetch the split's files into ``folder_path``,
            replacing any existing copies.
        train: Select the training split (``True``) or the test split.

    Indexing is 0-based: ``dataset[i]`` returns line ``i + 1`` of both files,
    each still carrying its trailing newline. Negative indices count from the
    end; indices outside ``range(len(dataset))`` raise ``IndexError``.
    """

    def __init__(self, folder_path, transform=None, download=False, train=True):
        self.path = Path(folder_path)
        self.train = train
        self.transform = transform
        self.train_en_url = 'https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/train.en'
        self.train_de_url = 'https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/train.de'

        self.test_en_url = 'https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2015.en'
        self.test_de_url = 'https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2015.de'

        # Number of sentence pairs (lines) in the split. The download step
        # verifies that the remote files do not contain more lines than this.
        self.length = 4_468_840 if train else 2_169
        if download:
            self.__download()
        elif train:
            # NOTE(security): pickle.load can execute arbitrary code; only
            # point this at a 'vocabs' file written by this class.
            with (self.path / 'vocabs').open('rb') as f:
                self.vocabs = pickle.load(f)

    def __download(self):
        """Fetch both language files of the current split into ``self.path``.

        For the training split the vocabularies are collected while streaming
        and then pickled to ``self.path / 'vocabs'``. For the test split the
        ``vocabs`` file is left untouched, so a test download can no longer
        overwrite the training vocabularies with empty sets.

        Raises:
            RuntimeError: If a remote file has more than ``self.length`` lines.
        """
        self.path.mkdir(parents=True, exist_ok=True)
        if self.train:
            files = (('en', 'train.en', self.train_en_url),
                     ('de', 'train.de', self.train_de_url))
        else:
            files = (('en', 'test.en', self.test_en_url),
                     ('de', 'test.de', self.test_de_url))

        self.vocabs = {'en': set(), 'de': set()}
        for lang, file, url in files:
            localpath = self.path / file
            # The URL is opened first, so a failed request leaves any existing
            # local copy intact; 'wb' then truncates the old file.
            with urlopen(url) as webfile, localpath.open('wb') as localfile:
                for _ in tqdm(range(self.length)):
                    line = webfile.readline()
                    if self.train:
                        # NOTE(review): split(' ') keeps the trailing newline
                        # glued to the last token and yields '' for repeated
                        # spaces. Left as-is so freshly built vocabularies
                        # match previously pickled ones; confirm against the
                        # tokenisation the model expects.
                        self.vocabs[lang].update(line.decode("utf-8").casefold().split(' '))
                    localfile.write(line)
                # One extra read must hit EOF; anything else means the
                # hard-coded size is stale and the local copy is truncated.
                if webfile.readline():
                    raise RuntimeError(
                        f'{url} contains more than {self.length} lines')

        if self.train:
            with (self.path / 'vocabs').open('wb') as f:
                pickle.dump(self.vocabs, f)

    def __len__(self):
        """Return the number of sentence pairs in the split."""
        return self.length

    def __getitem__(self, idx):
        """Return the ``(english, german)`` pair stored at 0-based ``idx``.

        Raises:
            FileNotFoundError: If either data file is missing.
            IndexError: If ``idx`` is outside ``range(-len(self), len(self))``.
        """
        files = ('train.en', 'train.de') if self.train else ('test.en', 'test.de')
        line_path = self.path / files[0]
        label_path = self.path / files[1]
        if not line_path.exists() or not label_path.exists():
            raise FileNotFoundError('Set download to True to download the dataset')

        if idx < 0:
            idx += self.length
        if not 0 <= idx < self.length:
            raise IndexError(f'index out of range for dataset of size {self.length}')

        # linecache numbers lines from 1, dataset indices start at 0.
        lineno = idx + 1
        line = linecache.getline(str(line_path.absolute()), lineno)
        label = linecache.getline(str(label_path.absolute()), lineno)

        if self.transform:
            line = self.transform(line)
            label = self.transform(label)
        return line, label


# Build both splits from the files already present in ./downloads
# (download=False, so nothing is fetched here).
train_dataset = En2DeDataset(folder_path='./downloads', train=True, download=False)
test_dataset = En2DeDataset(folder_path='./downloads', train=False, download=False)

# Shuffled loaders: batches of 16 for training, single samples for testing.
train_dtld = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=16,
    shuffle=True,
)
test_dtld = torch.utils.data.DataLoader(
    dataset=test_dataset,
    shuffle=True,
)

# Pull one training batch to check that loading works end to end.
a = next(iter(train_dtld))

In [3]:
from transformers.model import MultiHeadAttention, Transformer

In [4]:
# Instantiate the Transformer with the English vocabulary on the input side
# and the German vocabulary on the output side.
model = Transformer(
    input_vocab=train_dataset.vocabs['en'],
    output_vocab=train_dataset.vocabs['de'],
)

In [5]:
# Number of entries in the German vocabulary loaded by the training dataset
# (presumably the set built by En2DeDataset's download step; verify).
len(train_dataset.vocabs['de'])

1530303

In [6]:
# Smoke test: run one short string through the model and inspect the shape
# of what it returns.
model('Hola').shape

torch.Size([1, 1, 512])