In [135]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchtext import data, datasets
import spacy
from matplotlib import pyplot as plt
import numpy as np
from tqdm.notebook import tqdm

dev = 'cuda' if torch.cuda.is_available() else 'cpu'

In [144]:
tokenize_en = data.get_tokenizer("spacy", language='en_core_web_sm')
tokenize_de = data.get_tokenizer("spacy", language='de_core_news_sm')

src = data.Field(tokenize_en)
tgt = data.Field(tokenize_de)

train, val, test = datasets.Multi30k.splits(
    ('.en', '.de'), fields=(src, tgt) , root='./downloads')

src_list, trg_list = [], []
for dt_pnt in train:
    src_list.append(dt_pnt.src)
    trg_list.append(dt_pnt.trg)

train.fields['src'].build_vocab(src_list)
train.fields['trg'].build_vocab(trg_list)
train.fields['src'].numericalize([['hello', 'how', 'are', 'you', '<pad>']])

tensor([[6869],
        [ 898],
        [  12],
        [1751],
        [   1]])

In [139]:
spacy.load("de_core_news_sm").tokenizer('hola').text

'hola'

In [123]:
train.fields['src'].pad_token

'<pad>'

In [8]:
len(train.fields['src'].vocab)

15458

In [10]:
def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(subsequent_mask) == 0

In [196]:
def collate_fn(batch):
    src_list, trg_list = [], []
    for dt_pnt in batch:
        src_list.append(dt_pnt.src)
        trg_list.append(dt_pnt.trg)

    src_list = train.fields['src'].pad(src_list)
    trg_list = train.fields['trg'].pad(trg_list)

    src_list = train.fields['src'].numericalize(src_list)
    trg_list = train.fields['trg'].numericalize(trg_list)
    pad = 1
    src_mask = (src_list != pad).unsqueeze(-2)
    trg_mask = (trg_list != pad).unsqueeze(-2)

    trg_mask = trg_mask & subsequent_mask(
        trg_list.size(-1)).type_as(trg_mask.data)
    
    return src_list, trg_list, src_mask, trg_mask


dl = DataLoader(
    train, shuffle=False, batch_size=8, collate_fn=collate_fn)
[a.shape for a in next(iter(dl))]

[torch.Size([14, 8]),
 torch.Size([14, 8]),
 torch.Size([14, 1, 8]),
 torch.Size([14, 8, 8])]

In [5]:
from transformers.model import EncoderDecoder

model = EncoderDecoder(
    len(train.fields['src'].vocab),
    len(train.fields['trg'].vocab))