In [1]:
import warnings
warnings.filterwarnings(action='ignore')

import torch
from tensor2tensor.visualization import attention
from tokenization import Tokenizer, PretrainedTokenizer
from model import Transformer

# Tokenizer assets: the source side uses the 'fra.*' files, the target side the 'eng.*' files.
pretrained_model_src, vocab_file_src = 'fra.model', 'fra.vocab'
pretrained_model_tgt, vocab_file_tgt = 'eng.model', 'eng.vocab'

# Input parameters
max_seq_len = 80             # source is truncated/padded to this length; also bounds the decoding loop
no_cuda = False              # True forces CPU even when CUDA is available
model = '.model/model.ep15'  # checkpoint path handed to torch.load in the next cell

In [2]:
# Load the pretrained tokenizers: source side from the 'fra.*' files, target side from 'eng.*'.
tokenizer_src = PretrainedTokenizer(pretrained_model=pretrained_model_src, vocab_file=vocab_file_src)
tokenizer_tgt = PretrainedTokenizer(pretrained_model=pretrained_model_tgt, vocab_file=vocab_file_tgt)

# Load model
device = 'cuda' if torch.cuda.is_available() and not no_cuda else 'cpu'
# `model` holds the checkpoint path from the config cell. Bind the network to a
# separate name so re-running this cell does not hand a Transformer to torch.load.
# map_location lets a checkpoint saved from a GPU run load on a CPU-only machine.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted checkpoints.
nmt_model = torch.load(model, map_location=device).to(device)
nmt_model.eval()

# Make input: tokenize, truncate to max_seq_len, then right-pad with pad_token_id
# so the encoder always receives a (1, max_seq_len) batch.
text = 'Je ferai n\'importe quoi pour lui.'
src_pieces = tokenizer_src.tokenize(text)[:max_seq_len]

src_id_list = tokenizer_src.convert_tokens_to_ids(src_pieces)
src_id_list = src_id_list + [tokenizer_src.pad_token_id] * (max_seq_len - len(src_id_list))
src_ids = torch.tensor([src_id_list], device=device)
# Decoding starts from the target BOS token alone.
tgt_ids = torch.tensor([[tokenizer_tgt.bos_token_id]], device=device)

print('--------------------------------------------------------')
print('tokens: {}'.format(src_pieces))
print('src_ids: {}'.format(src_ids))
print('tgt_ids: {}'.format(tgt_ids))
print('--------------------------------------------------------')

# Inference: greedy decoding -- append the arg-max next token until EOS is predicted.
# no_grad() avoids building an autograd graph at every step; the attention maps
# returned here are only used for visualization.
with torch.no_grad():
    # At most max_seq_len - 1 appends, so tgt_ids never exceeds max_seq_len tokens
    # (the longest decoder input the model is ever fed, as before).
    for _ in range(max_seq_len - 1):
        outputs, encoder_attns, decoder_attns, enc_dec_attns = nmt_model(src_ids, tgt_ids)
        next_token_id = outputs[:, -1, :].argmax(dim=-1).item()
        if next_token_id == tokenizer_tgt.eos_token_id:
            break
        tgt_ids = torch.cat((tgt_ids, torch.tensor([[next_token_id]], device=device)), dim=-1)
    else:
        # No EOS within the budget. The last token was appended after its forward
        # pass, so run one more pass to make the attention maps cover every token
        # in tgt_ids. Otherwise `tokens` and the attentions visualized later disagree.
        outputs, encoder_attns, decoder_attns, enc_dec_attns = nmt_model(src_ids, tgt_ids)

ids = tgt_ids[0].tolist()
# `tokens` is read by the attention-visualization cell; it includes the leading BOS.
tokens = tokenizer_tgt.convert_ids_to_tokens(ids)
print(tokenizer_tgt.detokenize(tokens)) # I'll do anything for him.

--------------------------------------------------------
tokens: ['▁Je', '▁ferai', '▁n', "'", 'importe', '▁quoi', '▁pour', '▁lui', '.']
src_ids: tensor([[  34, 1746,   25, 4910, 1404,  400,   98,  223, 4906,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0]], device='cuda:0')
tgt_ids: tensor([[2]], device='cuda:0')
--------------------------------------------------------
I'll do anything for him.


In [3]:
# Label every source position (padding included) so the labels line up one-to-one with src_ids.
src_tokens = tokenizer_src.convert_ids_to_tokens(src_ids.tolist()[0])

In [4]:
# Detach each attention tensor from autograd and move it to the CPU before visualizing.
# NOTE(review): each group is presumably a per-layer list of attention tensors -- confirm in model.py.
encoder_attns, decoder_attns, enc_dec_attns = (
    [layer_attn.detach().cpu() for layer_attn in attn_group]
    for attn_group in (encoder_attns, decoder_attns, enc_dec_attns)
)

In [5]:
# Render the interactive attention views (encoder self-attention, decoder
# self-attention, encoder-decoder attention) for the translation above.
# NOTE(review): the `%%javascript` require.config cell sits below this one (both
# are labelled In [5]); on a fresh kernel it presumably must run first -- consider moving it up.
attention.show(
    src_tokens,     # source labels, padding included
    tokens,         # target labels from the decoding cell, BOS included
    encoder_attns,
    decoder_attns,
    enc_dec_attns,
)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [5]:
%%javascript
require.config({
  paths: {
      d3: '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min',
      jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
  }
});

<IPython.core.display.Javascript object>