In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time

In [2]:
from torchtext import data, datasets
import spacy

In [3]:
%load_ext autoreload
%autoreload 2
%aimport train
%aimport optimizer
%aimport transformer

Create the German→English IWSLT dataset (spaCy tokenizers, torchtext fields, vocabularies):

In [4]:
# Load the English and German spaCy pipelines; the tokenize_* helpers use their tokenizers.
spacy_en, spacy_de = (spacy.load(lang_code) for lang_code in ('en', 'de'))

In [5]:
def tokenize_de(text):
    """Split ``text`` with the German spaCy tokenizer and return the token strings as a list."""
    pieces = []
    for token in spacy_de.tokenizer(text):
        pieces.append(token.text)
    return pieces

def tokenize_en(text):
    """Split ``text`` with the English spaCy tokenizer and return the token strings as a list."""
    pieces = []
    for token in spacy_en.tokenizer(text):
        pieces.append(token.text)
    return pieces

# Special tokens: start/end markers (used by the target field only) and padding.
BOS_WORD = '<s>'
EOS_WORD = '</s>'
BLANK_WORD = "<blank>"

# Source field: German tokenizer, padding token only.
SRC = data.Field(tokenize=tokenize_de, pad_token=BLANK_WORD)
# Target field: English tokenizer, additionally given start and end tokens.
TGT = data.Field(
    tokenize=tokenize_en,
    init_token=BOS_WORD,
    eos_token=EOS_WORD,
    pad_token=BLANK_WORD,
)

MAX_LEN = 50  # upper bound on token count for both sides of an example


def _within_max_len(example):
    """Return True when neither the 'src' nor the 'trg' entry exceeds MAX_LEN items."""
    entries = vars(example)
    return not (len(entries['src']) > MAX_LEN or len(entries['trg']) > MAX_LEN)


train_data, val_data, test_data = datasets.IWSLT.splits(
    exts=('.de', '.en'),
    fields=(SRC, TGT),
    filter_pred=_within_max_len,
)

MIN_FREQ = 2  # forwarded to build_vocab as min_freq
# Vocabularies are built from the training split only.
SRC.build_vocab(train_data.src, min_freq=MIN_FREQ)
TGT.build_vocab(train_data.trg, min_freq=MIN_FREQ)

Loss wrapper (applies the generator, computes the normalized loss, and optionally steps the optimizer):

In [6]:
class SimpleLossCompute:
    """Compute the normalized loss for a batch and optionally update the model.

    Args:
        generator: callable applied to the model output before the criterion.
        criterion: callable taking ``(predictions, targets)``; predictions are
            flattened to ``(-1, x.size(-1))`` and targets to 1-D before the call.
        opt: optional optimizer wrapper exposing ``step()`` and an ``optimizer``
            attribute with ``zero_grad()``. With ``None`` no update is performed.
    """

    def __init__(self, generator, criterion, opt=None):
        self.generator = generator
        self.criterion = criterion
        self.opt = opt

    def __call__(self, x, y, norm):
        """Return the batch loss (before division by ``norm``) as a Python float.

        Args:
            x: model output handed to ``generator``.
            y: target tensor, flattened before being passed to ``criterion``.
            norm: normalization factor; a single-element tensor or a plain number.
        """
        # Accept both tensors and plain numbers so callers need not wrap the count.
        norm_value = norm.item() if isinstance(norm, torch.Tensor) else norm

        scores = self.generator(x)
        loss = self.criterion(scores.contiguous().view(-1, scores.size(-1)),
                              y.contiguous().view(-1)) / norm_value

        # Without an autograd graph (e.g. inside torch.no_grad()) backward() would
        # raise, and there would be no fresh gradients to step on; in that case
        # only report the loss.
        if loss.requires_grad:
            loss.backward()
            if self.opt is not None:
                self.opt.step()
                self.opt.optimizer.zero_grad()
        return loss.item() * norm_value

Build the model, criterion and batch iterators, then train:

In [7]:
DEVICE = torch.device('cuda')
BATCH_SIZE = 64

# Look the padding index up through BLANK_WORD so it cannot drift from the
# pad_token the TGT field was created with.
pad_idx = TGT.vocab.stoi[BLANK_WORD]

model = transformer.make_model(len(SRC.vocab), len(TGT.vocab), d_model=512, d_ff=2048, N=6)
model.cuda()
criterion = train.LabelSmoothing(size=len(TGT.vocab), padding_idx=pad_idx, smoothing=0.1)
criterion.cuda()


def _make_iterator(dataset, is_train):
    """Build a train.WrapperIterator over ``dataset`` with the shared batching settings.

    ``is_train`` is forwarded as the iterator's ``train`` flag; everything else
    (batch size, device, sort key, batch-size function) is identical for both splits.
    """
    return train.WrapperIterator(dataset, batch_size=BATCH_SIZE, device=DEVICE,
                                 repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
                                 batch_size_fn=train.batch_size_fn, train=is_train)


train_iter = _make_iterator(train_data, is_train=True)
valid_iter = _make_iterator(val_data, is_train=False)



In [8]:
NUM_EPOCHS = 10

# Adam starts at lr=0; the WrapperOpt wrapper is stepped instead of Adam directly.
# NOTE(review): the positional arguments (d_model, 1, 2000) look like model size,
# scale factor and warm-up steps -- confirm against optimizer.WrapperOpt.
model_opt = optimizer.WrapperOpt(model.src_embed[0].d_model, 1, 2000,
            torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
for epoch in range(NUM_EPOCHS):
    model.train()
    train.run_epoch((train.rebatch(pad_idx, b) for b in train_iter),
                    model,
                    SimpleLossCompute(model.generator, criterion, model_opt))
    model.eval()
    # Pass opt=None for validation: handing over model_opt here would step the
    # optimizer on validation batches, i.e. train on the validation set.
    loss = train.run_epoch((train.rebatch(pad_idx, b) for b in valid_iter),
                           model,
                           SimpleLossCompute(model.generator, criterion, None))
    # The loss wrapper may still have called backward() during validation; drop
    # those gradients so they are not added into the next epoch's first step.
    model.zero_grad()
    print(loss)

Epoch Step: 1 Loss: 9.073149
Epoch Step: 51 Loss: 8.501863
Epoch Step: 101 Loss: 7.678742


KeyboardInterrupt: 