In [1]:
%load_ext autoreload
%autoreload 2
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchtext import data, datasets
import spacy
from matplotlib import pyplot as plt
import numpy as np
from tqdm.notebook import tqdm

# Pick the compute device once for the whole notebook: the GPU when CUDA is
# usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    dev = 'cuda'
else:
    dev = 'cpu'

In [2]:
from utils import Multi30KEn2DeDatasetTokenizer
# Project helper imported from utils. Later cells read its .train data,
# collate_fn, src_vocab / trg_vocab, start_symbol, subsequent_mask and itos()
# from this one object.
# NOTE(review): presumably wraps the Multi30k English->German corpus (going by
# the class name) — confirm in utils.
dataset = Multi30KEn2DeDatasetTokenizer(dev=dev)

In [3]:
# Build a shuffled loader over the training data and pull a single batch to
# inspect the shape of every field that dataset.collate_fn produces.
dl = DataLoader(
    dataset.train,
    shuffle=True,
    batch_size=400,
    collate_fn=dataset.collate_fn)
example = next(iter(dl))

# Record the shapes first, then drop the batch itself so it does not stay
# referenced for the rest of the session.
example_shapes = {key: example[key].shape for key in example}
del example

# A cell only renders its final expression; the summary therefore has to be
# the last statement (previously it was followed by `del` and never shown).
example_shapes

In [4]:
# Mean of the per-batch 'ntokens' value over one full pass of the loader
# built in the previous cell (presumably the token count per batch — confirm
# against collate_fn).
a = np.mean(
    np.array([batch['ntokens'].cpu() for batch in dl])
)
del dl  # the loader is only needed for this statistic
a

4813.465753424657

In [5]:
from transformers.model import EncoderDecoder, Trainer
import os

def _latest_checkpoint(chk_dir):
    """Return the file name of the highest-numbered checkpoint in ``chk_dir``.

    Only names of the exact form ``checkpnt_step-<int>k.pt`` are considered.
    Anything else in the directory (editor backups, ``.ipynb_checkpoints``,
    ...) is skipped instead of aborting the whole lookup.

    Returns:
        The matching file name with the largest step number, or ``None`` when
        no file matches.

    Raises:
        OSError: if ``chk_dir`` cannot be listed (e.g. it does not exist).
    """
    prefix, suffix = 'checkpnt_step-', 'k.pt'
    candidates = []
    for name in os.listdir(chk_dir):
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        try:
            step = int(name[len(prefix):-len(suffix)])
        except ValueError:
            # Right prefix/suffix but a non-numeric step: not one of ours.
            continue
        candidates.append((step, name))
    if not candidates:
        return None
    # Tuples compare by step first, so max() picks the most advanced checkpoint.
    return max(candidates)[1]


def get_trainer(chk_dir='./chkpnts'):
    """Create the model and its trainer, resuming from the newest checkpoint.

    Reads the notebook-level ``dataset`` and ``dev``. The model is an
    ``EncoderDecoder`` sized by the source/target vocabularies with dropout
    0.1; the trainer uses the ``'cross_entropy'`` criterion.

    Resuming is best-effort: if the checkpoint directory is missing, holds no
    checkpoint, or loading fails, the problem is printed and the freshly
    initialised trainer is returned.

    Args:
        chk_dir: Directory searched for ``checkpnt_step-<N>k.pt`` files.

    Returns:
        A ``(trainer, model)`` tuple.
    """
    model = EncoderDecoder(
        len(dataset.src_vocab),
        len(dataset.trg_vocab),
        dropout=.1).to(dev)

    trainer = Trainer(
        model, dataset, dev=dev, criterion='cross_entropy')
    try:
        latest = _latest_checkpoint(chk_dir)
        if latest is None:
            raise Exception(f'No checkpoints in {chk_dir}')
        # Report the file that is actually loaded (the old message said
        # "epoch" although the files are named by step).
        print(latest)
        # Load by real file name rather than rebuilding it from the parsed
        # number, so names such as "checkpnt_step-05k.pt" still resolve.
        trainer.load(os.path.join(chk_dir, latest))

    except Exception as e:
        # Deliberately broad: any failure to resume falls back to a fresh run.
        print('error: ', e)
    return trainer, model

# Build (or resume) the trainer; keep a direct handle on the model for the
# inspection and translation cells below.
trainer, model = get_trainer()

checkpnt_epoch-0k.pt


In [6]:
# Total number of elements across all model parameters, reported in millions.
total_params = sum(param.numel() for param in model.parameters())
f'{total_params / 1e6} * 1e6 parameters'

'77.570878 * 1e6 parameters'

In [None]:
# Start the long training run. The first argument is the total shown by the
# progress bar (presumably the number of training steps — confirm in Trainer).
train_options = dict(batch_size=400, save=True, notify=True)
trainer.train_loop(500_000, **train_options)

  0%|          | 0/500000 [00:00<?, ?it/s]

In [None]:
# Qualitative check: translate one randomly drawn example and show the source
# next to the model output, both mapped back through dataset.itos.
# NOTE(review): despite its name, test_dl draws from dataset.train, so this is
# a training example — switch to a held-out split if the dataset exposes one.
test_dl = DataLoader(
    dataset.train, shuffle=True, batch_size=1, collate_fn=dataset.collate_fn)
example = next(iter(test_dl))

src = example['src']
src_mask = example['src_mask']

model.eval()  # evaluation mode, so the dropout set at construction is inactive
# Decoding needs no gradients; no_grad avoids building an autograd graph.
with torch.no_grad():
    trg = model.translate(
        src, src_mask, dataset.start_symbol, dataset.subsequent_mask, dev=dev)
dataset.itos(src[0], field='src'), dataset.itos(trg[0])

In [None]:
# Plot the losses recorded by the trainer. The figure is labelled so it can be
# read on its own, and plt.show() keeps the Line2D repr out of the cell output.
fig, ax = plt.subplots()
ax.plot(trainer.losses)
ax.set(
    title='Training loss',
    xlabel='index in trainer.losses',
    ylabel='loss')
plt.show()