In [1]:
%load_ext autoreload
%autoreload 2
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchtext import data, datasets
import spacy
from matplotlib import pyplot as plt
import numpy as np
from tqdm.notebook import tqdm

dev = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
tokenize_en = data.get_tokenizer("spacy", language='en_core_web_sm')
tokenize_de = data.get_tokenizer("spacy", language='de_core_news_sm')

src = data.Field(tokenize_en)
tgt = data.Field(tokenize_de)

train, val, test = datasets.Multi30k.splits(
    ('.en', '.de'), fields=(src, tgt) , root='./downloads')

src_list, trg_list = [], []
for dt_pnt in train:
    src_list.append(dt_pnt.src)
    trg_list.append(dt_pnt.trg)

train.fields['src'].build_vocab(src_list)
train.fields['trg'].build_vocab(trg_list)
train.fields['src'].numericalize([['hello', 'how', 'are', 'you', '<pad>']])

tensor([[6869],
        [ 898],
        [  12],
        [1751],
        [   1]])

In [3]:
def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(subsequent_mask) == 0

In [4]:
def collate_fn(batch):
    src_list, trg_list = [], []
    for dt_pnt in batch:
        src_list.append(dt_pnt.src)
        trg_list.append(['<pad>'] + dt_pnt.trg)

    src_list = train.fields['src'].pad(src_list)
    trg_list = train.fields['trg'].pad(trg_list)

    src_list = train.fields['src'].numericalize(src_list).T.to(dev)
    trg_list = train.fields['trg'].numericalize(trg_list).T.to(dev)
    trg = trg_list[:, :-1]
    trg_y = trg_list[:,1:]

    pad = 1
    src_mask = (src_list != pad).unsqueeze(-1).unsqueeze(-1)
    trg_mask = (trg != pad).unsqueeze(-2)

    trg_mask = trg_mask & subsequent_mask(
        trg.size(-1)).type_as(trg_mask.data)
    
    return [src_list,
            trg,
            src_mask.to(dev),
            trg_mask.to(dev)[:,:,0].unsqueeze(-1).unsqueeze(-1),
            trg_y]


dl = DataLoader(
    train, shuffle=False, batch_size=8, collate_fn=collate_fn)
example =  next(iter(dl))
[a.shape for a in example]

[torch.Size([8, 14]),
 torch.Size([8, 14]),
 torch.Size([8, 14, 1, 1]),
 torch.Size([8, 14, 1, 1]),
 torch.Size([8, 14])]

In [13]:
from transformers.model import EncoderDecoder

model = EncoderDecoder(
    len(train.fields['src'].vocab),
    len(train.fields['trg'].vocab)).to(dev)
model(*example[:-1])

tensor([[[3.4382e-05, 3.8034e-05, 2.9569e-05,  ..., 2.6254e-05,
          3.0642e-05, 3.6709e-05],
         [4.2532e-05, 4.2545e-05, 2.8504e-05,  ..., 3.2837e-05,
          4.5003e-05, 3.6043e-05],
         [3.1261e-05, 3.4890e-05, 3.0717e-05,  ..., 3.6743e-05,
          3.4505e-05, 4.2707e-05],
         ...,
         [3.2411e-05, 3.5660e-05, 3.1765e-05,  ..., 4.4888e-05,
          2.8996e-05, 5.5060e-05],
         [3.1013e-05, 4.0144e-05, 3.6139e-05,  ..., 3.1309e-05,
          2.6182e-05, 5.4794e-05],
         [3.1932e-05, 4.1850e-05, 4.1147e-05,  ..., 4.3117e-05,
          2.5125e-05, 6.8875e-05]],

        [[5.5289e-05, 3.6400e-05, 3.8453e-05,  ..., 3.2903e-05,
          3.1037e-05, 4.6313e-05],
         [3.5351e-05, 4.3255e-05, 3.1099e-05,  ..., 4.2300e-05,
          3.2589e-05, 3.6187e-05],
         [4.7280e-05, 3.3941e-05, 3.2998e-05,  ..., 3.6821e-05,
          3.3285e-05, 5.4245e-05],
         ...,
         [3.8535e-05, 4.1556e-05, 3.1715e-05,  ..., 3.9869e-05,
          2.991