In [1]:
%load_ext autoreload
%autoreload 2
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchtext import data, datasets
import spacy
from matplotlib import pyplot as plt
import numpy as np
from tqdm.notebook import tqdm

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
dev = 'cpu'
if torch.cuda.is_available():
    dev = 'cuda'

In [2]:
# spaCy tokenizers for the English source and the German target.
tokenize_en = data.get_tokenizer("spacy", language='en_core_web_sm')
tokenize_de = data.get_tokenizer("spacy", language='de_core_news_sm')

# `tokenize` must be passed by keyword: the first positional parameter of the
# torchtext Field is `sequential`, so passing the tokenizer positionally would
# leave the default whitespace tokenizer in place.
src = data.Field(tokenize=tokenize_en)
tgt = data.Field(tokenize=tokenize_de)

# Multi30k attaches the two fields to every example as `.src` and `.trg`.
train, val, test = datasets.Multi30k.splits(
    ('.en', '.de'), fields=(src, tgt), root='./downloads')

# Vocabularies are built from the training split only.
src_list = [dt_pnt.src for dt_pnt in train]
trg_list = [dt_pnt.trg for dt_pnt in train]

# '<s>' / '</s>' are inserted manually by the batch collation code, so they
# must exist in both vocabularies.
specials = ['<pad>', '<s>', '</s>', "<blank>", "<unk>"]
train.fields['src'].build_vocab(src_list, specials=specials)
train.fields['trg'].build_vocab(trg_list, specials=specials)

def itos(t, field='trg'):
    """Convert an iterable of token ids (e.g. a 1-D tensor) to a space-joined string.

    `field` names which of `train`'s fields supplies the vocabulary
    ('src' or 'trg').
    """
    return ' '.join(train.fields[field].vocab.itos[idx] for idx in t)

In [3]:
def subsequent_mask(size):
    """Causal (look-ahead) mask for a sequence of length `size`.

    Returns a bool tensor of shape (1, size, size) that is True on and below the
    diagonal (position i may attend to positions j <= i) and False above it.
    """
    shape = (1, size, size)
    above_diagonal = torch.ones(shape).triu(diagonal=1).to(torch.uint8)
    return above_diagonal.eq(0)

In [4]:
def collate_fn(batch):
    """Turn a list of Multi30k examples into a padded, masked batch on `dev`.

    Each source and target sentence is wrapped in '<s>' ... '</s>', padded to the
    longest sentence in the batch with its field's pad token, numericalized, and
    laid out batch-first.

    Returns a dict with:
      'src'      -- (batch, src_len) source token ids
      'trg'      -- (batch, trg_len - 1) decoder input (target minus its last token)
      'src_mask' -- (batch, 1, 1, src_len) bool, True where the source is not padding
      'trg_mask' -- (batch, 1, trg_len - 1, trg_len - 1) bool, target padding mask
                    combined with the causal mask from `subsequent_mask`
      'trg_y'    -- (batch, trg_len - 1) decoder targets (target shifted left by one)
    """
    src_field = train.fields['src']
    trg_field = train.fields['trg']

    src_tokens = [['<s>'] + dt_pnt.src + ['</s>'] for dt_pnt in batch]
    trg_tokens = [['<s>'] + dt_pnt.trg + ['</s>'] for dt_pnt in batch]

    # numericalize yields (seq_len, batch) here; `.T` makes it batch-first.
    src_ids = src_field.numericalize(src_field.pad(src_tokens)).T.to(dev)
    trg_ids = trg_field.numericalize(trg_field.pad(trg_tokens)).T.to(dev)

    # Teacher forcing: the decoder reads tokens [0, n-1) and predicts [1, n).
    trg = trg_ids[:, :-1]
    trg_y = trg_ids[:, 1:]

    # Look up each side's pad id in its own vocabulary instead of assuming the
    # source and target vocabularies agree on it.
    src_pad = src_field.vocab.stoi[src_field.pad_token]
    trg_pad = trg_field.vocab.stoi[trg_field.pad_token]

    src_mask = (src_ids != src_pad)[:, None, None, :]
    trg_mask = (trg != trg_pad)[:, None, None, :]
    # subsequent_mask is created on the CPU; move it next to the batch, then
    # broadcast (1, T, T) against (batch, 1, 1, T) -> (batch, 1, T, T).
    trg_mask = trg_mask & subsequent_mask(trg.size(-1)).to(trg_mask.device)

    return {'src': src_ids,
            'trg': trg,
            'src_mask': src_mask,
            'trg_mask': trg_mask,
            'trg_y': trg_y}


# Shuffled training batches of 128 sentence pairs, built by `collate_fn`.
dl = DataLoader(train, batch_size=128, shuffle=True, collate_fn=collate_fn)

# Pull a single batch and show the shape of every tensor in it.
example = next(iter(dl))
{key: value.shape for key, value in example.items()}

{'src': torch.Size([128, 26]),
 'trg': torch.Size([128, 24]),
 'src_mask': torch.Size([128, 1, 1, 26]),
 'trg_mask': torch.Size([128, 1, 24, 24]),
 'trg_y': torch.Size([128, 24])}

In [5]:
from transformers.model import EncoderDecoder, Trainer
import os

def get_trainer(chkpnt_dir='./chkpnts'):
    """Build a fresh EncoderDecoder + Trainer and resume from the newest checkpoint.

    Checkpoints are looked up in `chkpnt_dir` and must be named
    ``checkpnt_step-<N>k.pt``. The one with the largest N is passed to
    `trainer.load`, and files that do not match the pattern are ignored.
    Resuming is best-effort: if the directory cannot be listed, holds no
    checkpoints, or loading fails, a message is printed and the freshly
    initialised model is kept.

    Returns:
        (trainer, model) -- the Trainer and the EncoderDecoder it wraps.
    """
    model = EncoderDecoder(
        len(train.fields['src'].vocab),
        len(train.fields['trg'].vocab)).to(dev)
    trainer = Trainer(model)

    prefix, suffix = 'checkpnt_step-', 'k.pt'
    try:
        names = os.listdir(chkpnt_dir)
    except OSError as e:
        print('error: ', e)
        return trainer, model

    # Parse N out of every well-formed checkpoint name; skip anything else.
    steps = []
    for name in names:
        if name.startswith(prefix) and name.endswith(suffix):
            middle = name[len(prefix):-len(suffix)]
            if middle.isdigit():
                steps.append(int(middle))
    if not steps:
        print(f'error:  No checkpoints in {chkpnt_dir}')
        return trainer, model

    path = os.path.join(chkpnt_dir, f'{prefix}{max(steps)}{suffix}')
    print(f'loading {path}')
    try:
        trainer.load(path)
    except Exception as e:  # best-effort resume: keep the fresh model on failure
        print('error: ', e)
    return trainer, model

trainer, model = get_trainer()

checkpnt_epoch-0k.pt


In [6]:
# Long run: the first argument is presumably the number of training steps (the
# progress bar counts to 1,000,000). Interrupt the kernel to stop early.
trainer.train_loop(10 ** 6, dl)

  0%|          | 0/1000000 [00:00<?, ?it/s]


KeyboardInterrupt



In [13]:
# Sanity check on the sample batch: run the model with teacher forcing (the
# decoder sees the gold target prefix) and decode the argmax for the first
# sentence. Inference runs in eval mode without autograd, and the previous
# train/eval mode is restored afterwards so later training is unaffected.
# Names are kept distinct from the `src`/`tgt` Field objects defined earlier.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        output = model(example['src'], example['trg'],
                       example['src_mask'], example['trg_mask'])
finally:
    model.train(was_training)

# NOTE(review): assumes output[0] holds per-position scores over the target
# vocabulary -- confirm against EncoderDecoder.forward (if it returns decoder
# hidden states, a generator/projection must be applied before argmax).
phrase = torch.argmax(output[0], dim=-1)
itos(phrase)

'Ein Ein Ein <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>'