In [48]:
import random

import numpy as np
import yaml
import torch
from experiment import EXPERIMENT_CATALOG

In [2]:
# Name shared by the YAML config file and the experiment-catalog entry; kept in
# one place so the two cannot drift apart.
EXPERIMENT_NAME = 'pretrain_baseline'

# Read the config as UTF-8 explicitly so parsing does not depend on the
# platform's default text encoding.
with open(f'configs/{EXPERIMENT_NAME}.yaml', encoding='utf-8') as fin:
    config = yaml.safe_load(fin)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the experiment object from the catalog entry of the same name.
experiment = EXPERIMENT_CATALOG[EXPERIMENT_NAME](config, device)




In [3]:
# Grab the model and the test split from the experiment in one step.
model, test_iter = experiment.model, experiment.task.test

# Switch to evaluation mode; as the cell's last expression this also displays
# the module structure.
model.eval()

Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(5851, 256)
    (dropout): Dropout(p=0.5, inplace=False)
    (rnn): LSTM(256, 256, num_layers=2, dropout=0.5)
    (out): Linear(in_features=256, out_features=5851, bias=True)
  )
  (attention): Attention(
    (attn_combine): Linear(in_features=512, out_features=256, bias=True)
  )
  (decoder): Decoder(
    (embedding): Embedding(4244, 256)
    (dropout): Dropout(p=0.5, inplace=False)
    (rnn): LSTM(256, 256, num_layers=2, dropout=0.5)
    (out): Linear(in_features=256, out_features=4244, bias=True)
  )
)

In [70]:
# Example sentence pair whose attention pattern is inspected below.
src_sentence = 'В стильном ресторане отеля подают блюда средиземноморской кухни из органических местных продуктов.'
trg_sentence = 'Organic, locally sourced ingredients are blended in the Mediterranean dishes of its stylish restaurant.'

# Source side: tokenize, then turn the single-sentence batch into a tensor on `device`.
src_str = experiment.source.tokenize(src_sentence)
src = experiment.source.process([src_str]).to(device)

# Target side: same two steps with the target field.
trg_str = experiment.target.tokenize(trg_sentence)
trg = experiment.target.process([trg_str]).to(device)

In [71]:
# Replay the reference translation through the decoder one step at a time and
# record the attention scores returned at every step.
batch_size = 1
max_len = trg.shape[0]
trg_vocab_size = model.decoder.output_dim

# Probability that the next decoder input is the reference token rather than the
# model's own prediction. 1.0 feeds the reference token at every step, which is
# what the previous version of this cell did in practice (its flag was always
# False and the False branch selected trg[t]).
teacher_forcing_ratio = 1.0

# tensor to store decoder outputs
outputs = torch.zeros(max_len, batch_size, trg_vocab_size, device=device)

attn_list = []

# Inference only: skip autograd bookkeeping for the encoder and decoder passes.
with torch.no_grad():
    # last hidden state of the encoder is used as the initial hidden state of the decoder
    _, enc_outputs, hidden, cell = model.encoder(src)

    # first input to the decoder is the <sos> tokens
    dec_input = trg[0, :]
    for t in range(1, max_len):
        output, hidden, cell = model.decoder(dec_input, hidden, cell)
        # The attention module returns an updated hidden state together with
        # the scores that are collected for inspection.
        hidden, attn_scores = model.attention(enc_outputs, hidden, return_map=True)
        attn_list.append(attn_scores)
        outputs[t] = output

        # random.random() is in [0, 1), so a ratio of 1.0 always teacher-forces.
        teacher_force = random.random() < teacher_forcing_ratio
        top1 = output.max(1)[1]
        dec_input = trg[t] if teacher_force else top1

In [72]:
# Number of recorded attention maps: the decoding loop appends one per step,
# i.e. max_len - 1 entries.
len(attn_list)

17

In [73]:
# Number of tokens returned by the target tokenizer for the reference sentence.
# NOTE(review): this is one less than len(attn_list), which suggests process()
# adds two extra positions (presumably <sos>/<eos>) — confirm.
len(trg_str)

16

In [74]:
# Convert every step's attention tensor to a NumPy array on the host and keep
# index 0 of its first axis (presumably the single batch element — TODO confirm).
a = [step_attn.detach().cpu().numpy()[0] for step_attn in attn_list]

In [75]:
# Pack the per-step arrays into one ndarray whose first axis is the decoding step.
a = np.asarray(a)

In [76]:
# Token lists padded with the boundary markers so that positions can be looked
# up by index.
src_os = ['<sos>', *src_str, '<eos>']
trg_os = ['<sos>', *trg_str, '<eos>']

# For every decoding step take the arg-max over axis 1 and keep column 1 of the
# result. NOTE(review): axis 1 is presumably the source position and column 1
# one of the hidden-state layers — confirm against the Attention module.
best_src_pos = np.argmax(a, axis=1)[:, 1]

# Pair the target token at each step index with the source token selected above.
alignment = []
for step, src_pos in enumerate(best_src_pos):
    alignment.append((trg_os[step], src_os[src_pos]))
alignment

[('<sos>', '.'),
 ('organic', '.'),
 (',', '.'),
 ('locally', '.'),
 ('sourced', '.'),
 ('ingredients', 'ресторане'),
 ('are', 'ресторане'),
 ('blended', 'подают'),
 ('in', 'подают'),
 ('the', 'органических'),
 ('mediterranean', 'органических'),
 ('dishes', 'органических'),
 ('of', '<eos>'),
 ('its', 'подают'),
 ('stylish', 'подают'),
 ('restaurant', '<eos>'),
 ('.', '<eos>')]

In [77]:
# Display the tokenized source sentence for reference next to the alignment.
src_str

['в',
 'стильном',
 'ресторане',
 'отеля',
 'подают',
 'блюда',
 'средиземноморской',
 'кухни',
 'из',
 'органических',
 'местных',
 'продуктов',
 '.']

In [66]:
# NOTE(review): duplicate of the previous display cell. Its execution count (66)
# predates the cells above, so the saved output is stale — re-run or delete.
src_str

12