In [1]:
from model import Transformer

In [2]:
import torch

In [3]:
import io
import spacy
import numpy as np

from torchtext import data
from torchtext import datasets

In [4]:
spacy_de = spacy.load('de')
spacy_en = spacy.load('en')


def tokenize_en(text):
    return [tok.text for tok in spacy_en.tokenizer(text)]
def tokenize_de(text):
    return [tok.text for tok in spacy_de.tokenizer(text)]

DE = data.Field(tokenize=tokenize_de,
                init_token='<SOS>',
                eos_token='<EOS>',
                fix_length=20,
                lower=True,
                batch_first=True)
EN = data.Field(tokenize=tokenize_en,
                init_token='<SOS>',
                eos_token='<EOS>',
                lower=True,
                fix_length=20,
                batch_first=True)

In [5]:
train = datasets.TranslationDataset(path='./data/train', 
                                    exts=('.de.txt', '.en.txt'),
                                    fields=(DE, EN))

test = datasets.TranslationDataset(path='./data/test', 
                                   exts=('.de.txt', '.en.txt'),
                                   fields=(DE, EN))

In [6]:
DE.build_vocab(train.src, min_freq=3)
EN.build_vocab(train, max_size=50000)

In [13]:
test_iter = data.BucketIterator(dataset=test, batch_size=15)

In [8]:
model = Transformer(len(DE.vocab),len(EN.vocab))

In [9]:
model.load_state_dict(torch.load('model-50000.pth'))

In [10]:
model.cuda()

Transformer(
  (src_emb): Input_Embedding(
    (word_emb): Embedding(39112, 512, padding_idx=2)
    (position_emb): Embedding(20, 512)
  )
  (trg_emb): Input_Embedding(
    (word_emb): Embedding(50004, 512, padding_idx=2)
    (position_emb): Embedding(20, 512)
  )
  (encoder): Encoder(
    (model): ModuleList(
      (0): EncoderBlock(
        (attention): MultiAttention(
          (fc_q): Linear(in_features=512, out_features=512, bias=True)
          (fc_k): Linear(in_features=512, out_features=512, bias=True)
          (fc_v): Linear(in_features=512, out_features=512, bias=True)
          (normalize): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
        )
        (fc): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): ReLU()
          (2): Linear(in_features=2048, out_features=512, bias=True)
        )
        (normalize): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
      )
      (1): EncoderBlock(


In [11]:
def make_batch(batch):
    src = batch.src
    src_mask = 1-(src==1)
    src_mask.unsqueeze_(2)
    src_position = torch.LongTensor([range(src.size(1))]*src.size(0)).to(src.device)
    
    trg = batch.trg
    trg_mask = 1-(trg==1)
    trg_mask.unsqueeze_(2)
    trg_position = torch.LongTensor([range(trg.size(1))]*trg.size(0)).to(trg.device)
    return src,src_mask,src_position,trg,trg_mask,trg_position

In [14]:
test_batch = next(iter(test_iter))

In [15]:
src,src_mask,src_position,trg,trg_mask,trg_position = make_batch(test_batch)

In [37]:
trg_ = torch.zeros_like(src)
trg_[:,0] = 2
trg_mask_ = torch.zeros_like(src_mask)
trg_mask_[:,0,:] = 1
trg_position_ = src_position

In [38]:
output = model.inference(src,src_mask,src_position.cuda(),trg_,trg_mask_,trg_position_.cuda())

In [42]:
' '.join([EN.vocab.itos[i] for i in trg[0]])

u"<SOS> it 's the symbol of all that we are and all that we can be as an astonishingly <EOS>"

In [43]:
' '.join([EN.vocab.itos[i] for i in output[0]])

u"<SOS> it 's the symbol of all of what we are and what we 're an amazing network for <EOS>"