In [None]:
from model import Transformer

In [None]:
import torch
import torch.nn.functional as F
import torch.optim as optim

In [None]:
import io
import spacy
import numpy as np

from torchtext import data
from torchtext import datasets

In [None]:
# Load the German and English spaCy pipelines once; the tokenizer helpers
# below reuse these module-level objects.
spacy_de, spacy_en = spacy.load('de'), spacy.load('en')


def tokenize_en(text):
    """Split `text` with the English spaCy tokenizer and return the token strings."""
    tokens = spacy_en.tokenizer(text)
    return [token.text for token in tokens]
def tokenize_de(text):
    """Split `text` with the German spaCy tokenizer and return the token strings."""
    tokens = spacy_de.tokenizer(text)
    return [token.text for token in tokens]

# Options shared by both languages: wrap each sentence in <SOS>/<EOS>,
# lower-case it, use a fixed length of 20 and produce batch-first tensors.
_field_options = dict(
    init_token='<SOS>',
    eos_token='<EOS>',
    fix_length=20,
    lower=True,
    batch_first=True,
)

# The two fields differ only in the tokenizer they use.
DE = data.Field(tokenize=tokenize_de, **_field_options)
EN = data.Field(tokenize=tokenize_en, **_field_options)

In [None]:
# Parallel training corpus: ./data/train.de.txt (source, DE field) paired
# with ./data/train.en.txt (target, EN field).
train = datasets.TranslationDataset(
    path='./data/train',
    exts=('.de.txt', '.en.txt'),
    fields=(DE, EN),
)

In [None]:
# Parallel test corpus: ./data/test.de.txt (source, DE field) paired with
# ./data/test.en.txt (target, EN field).
test = datasets.TranslationDataset(
    path='./data/test',
    exts=('.de.txt', '.en.txt'),
    fields=(DE, EN),
)

In [None]:
# German vocabulary: built from the training sources, keeping tokens that
# occur at least 3 times.
DE.build_vocab(train.src, min_freq=3)
# English vocabulary: built from the training targets (the only column of
# `train` that uses the EN field), limited by max_size=50000.
EN.build_vocab(train.trg, max_size=50000)

In [None]:
def _src_trg_length_key(example):
    """Sort key combining the source and target lengths of one example."""
    return data.interleave_keys(len(example.src), len(example.trg))


# Shuffled training batches of 32 examples, bucketed with the key above.
train_iter = data.BucketIterator(
    dataset=train,
    batch_size=32,
    shuffle=True,
    sort_key=_src_trg_length_key,
)

In [None]:
# Small batches of 5 test examples for eyeballing translations; every other
# iterator option is left at the library default.
test_iter = data.BucketIterator(
    dataset=test,
    batch_size=5,
)

In [None]:
def _mask_and_positions(tokens, pad_idx):
    """Return (tokens, non-padding mask, per-row position indices) for one side."""
    # `tokens != pad_idx` replaces `1 - (tokens == pad_idx)`: comparisons
    # return bool tensors on recent PyTorch, and subtracting from a bool
    # tensor raises a RuntimeError there.
    mask = tokens.ne(pad_idx).unsqueeze(2)
    rows, length = tokens.size(0), tokens.size(1)
    # Build 0..length-1 directly on the tokens' device and copy it per row,
    # instead of materialising a Python list of ranges first.
    position = torch.arange(length, dtype=torch.long, device=tokens.device)
    position = position.unsqueeze(0).repeat(rows, 1)
    return tokens, mask, position


def make_batch(batch, pad_idx=1):
    """Unpack a batch into the six tensors handed to the model.

    Args:
        batch: object exposing 2-D index tensors ``src`` and ``trg``
            (indexed here as (row, position)).
        pad_idx: token index treated as padding. Defaults to 1, the value
            this function has always compared against.

    Returns:
        ``(src, src_mask, src_position, trg, trg_mask, trg_position)``.
        Each mask has shape (rows, positions, 1) and is non-zero exactly
        where the token differs from ``pad_idx``. Each position tensor is a
        LongTensor of shape (rows, positions) holding 0..positions-1 in
        every row, on the same device as the corresponding tokens.
    """
    src, src_mask, src_position = _mask_and_positions(batch.src, pad_idx)
    trg, trg_mask, trg_position = _mask_and_positions(batch.trg, pad_idx)
    return src, src_mask, src_position, trg, trg_mask, trg_position

## Training

In [None]:
# The model is sized from the two vocabularies: German first, English second.
de_vocab_size = len(DE.vocab)
en_vocab_size = len(EN.vocab)
model = Transformer(de_vocab_size, en_vocab_size)

In [None]:
# Prefer the GPU but fall back to the CPU, so the notebook can still run on a
# machine without CUDA. The name `cuda` is kept because other cells use it.
cuda = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Move the model onto the device stored in `cuda` before the optimizer is
# created from its parameters.
model = model.to(cuda)

In [None]:
# Adam over every model parameter with a fixed learning rate of 1e-4.
# Despite its name, `train_step` is the optimizer object.
train_step = optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
# Train for up to 30000 steps, printing the mean loss of each 1000-step
# window and saving the weights when step 30000 is reached.
# NOTE(review): the checkpoint is only written if train_iter yields more than
# 30000 batches (i.e. it repeats over the data) - confirm for the installed
# torchtext version.
loss_ = []

for i, batch in enumerate(train_iter):
    src, src_mask, src_position, trg, trg_mask, trg_position = make_batch(batch)
    logit = model(src, src_mask, src_position, trg, trg_mask, trg_position)

    # The output at position t is scored against target token t+1, so the
    # last prediction and the first target token are dropped.
    # The flattening width is read from the logits instead of the former
    # hard-coded 50004, which fails whenever the real vocabulary is smaller.
    vocab_size = logit.size(-1)
    loss = F.cross_entropy(logit[:, :-1, :].contiguous().view(-1, vocab_size),
                           trg[:, 1:].contiguous().view(-1),
                           ignore_index=1)  # 1 = index make_batch treats as padding

    train_step.zero_grad()
    loss.backward()
    train_step.step()
    loss_.append(loss.item())

    if i % 1000 == 0:
        # Parenthesised form works as a statement on Python 2 and as a
        # function call on Python 3. `i` counts batches, so it is reported
        # as a step rather than an epoch.
        print("step:%s , loss:%s" % (i, np.mean(loss_)))
        loss_ = []

    if i == 30000:
        torch.save(model.state_dict(), "model-30000.pth")
        break

In [None]:
# Most likely token index at every position of the last training batch,
# taken along the final (vocabulary) axis of the logits.
# argmax replaces `_, p = logit.max(-1)` so that IPython's `_` (the last cell
# output) is not rebound by this cell.
p = logit.argmax(dim=-1)

In [None]:
# Reference English tokens for row 3 of the last training batch.
' '.join(EN.vocab.itos[token_id] for token_id in trg[3])

In [None]:
# Predicted English tokens for the same row (row 3) of the last training batch.
' '.join(EN.vocab.itos[token_id] for token_id in p[3])

In [None]:
# Pull a single batch from the test iterator to inspect translations by hand.
test_batch = next(iter(test_iter))

In [None]:
# Unpack the test batch. This rebinds the same names the training loop used
# (src, trg, masks, positions), so from here on they refer to test data.
src,src_mask,src_position,trg,trg_mask,trg_position = make_batch(test_batch)

In [None]:
# Placeholder decoder inputs for inference: all-zero tensors shaped like the
# source tokens and the source mask.
trg_ = torch.zeros_like(src)
trg_mask_ = torch.zeros_like(src_mask)
# Mark only the first target position in the mask.
# NOTE(review): make_batch only produces 0/1 (false/true) mask values, so the
# value 2 here looks unusual - confirm what model.inference expects (2 may
# have been intended as a token index for trg_ rather than a mask value).
trg_mask_[:,0,:] = 2
# Reuse the source position indices for the target side (same tensor object,
# not a copy).
trg_position_ = src_position

In [None]:
# Decode the test batch.
# - eval() switches off training-only behaviour (e.g. dropout) for decoding.
#   NOTE(review): confirm model.inference does not already do this itself.
# - no_grad() avoids building an autograd graph that is never used.
# - The position tensors follow src's device instead of a hard-coded .cuda();
#   make_batch already placed them there, so this is a no-op on the GPU path.
model.eval()
with torch.no_grad():
    o = model.inference(src, src_mask, src_position.to(src.device),
                        trg_, trg_mask_, trg_position_.to(src.device))

In [None]:
# Reference English tokens for row 1 of the test batch.
' '.join(EN.vocab.itos[token_id] for token_id in trg[1])

In [None]:
# Tokens produced by model.inference for the same row (row 1) of the test batch.
' '.join(EN.vocab.itos[token_id] for token_id in o[1])