In [None]:
# coding: utf-8

# In[1]:

# !pip install http://download.pytorch.org/whl/cu80/torch-0.3.0.post4-cp36-cp36m-linux_x86_64.whl numpy matplotlib torchtext

# In[2]:

# Standard PyTorch imports
import numpy as np

from tensorflow.contrib.layers import layer_norm
import nn_utils

# For plots
#get_ipython().run_line_magic('matplotlib', 'inline')
import matplotlib.pyplot as plt
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '3'
import tensorflow as tf
real = 1
BATCH_SIZE = 16
#!conda install torchtext spacy
# !python -m spacy download en
# !python -m spacy download de
def detect_end(choice, eos_token=None):
    """Return whether the first element of `choice` equals `eos_token`.

    `choice` is any array-like exposing `.flatten()` (e.g. the numpy argmax
    produced by one decoding step); only its first element is inspected.
    With the default `eos_token=None`, numeric choices never match.
    """
    leading = choice.flatten()[0]
    return leading == eos_token

if real:
  from torchtext import data
  from torchtext import datasets
  import tqdm
  import re
  import spacy

  # spaCy pipelines are used only for their tokenizers.
  spacy_de = spacy.load('de')
  spacy_en = spacy.load('en')

  # Each <url>...</url> span is replaced by a single '@URL@' placeholder before
  # tokenization. Non-greedy so that two URLs on one line do not swallow the
  # text between them (the greedy '.*' matched from the first <url> to the
  # last </url>).
  url = re.compile('(<url>.*?</url>)')

  def tokenize_de(text):
    """Tokenize German text with spaCy after URL replacement."""
    return [tok.text for tok in spacy_de.tokenizer(url.sub('@URL@', text))]

  def tokenize_en(text):
    """Tokenize English text with spaCy after URL replacement."""
    return [tok.text for tok in spacy_en.tokenizer(url.sub('@URL@', text))]

  # IWSLT German -> English. include_lengths=True makes each batch field a
  # (padded_ids, lengths) pair.
  DE = data.Field(
      tokenize=tokenize_de,
      init_token='<bos>',
      eos_token='<eos>',
      include_lengths=True)
  EN = data.Field(
      tokenize=tokenize_en,
      init_token='<bos>',
      eos_token='<eos>',
      include_lengths=True)

  train, val, test = datasets.IWSLT.splits(
      exts=('.de', '.en'), fields=(DE, EN))

  train_it = data.Iterator(
      train,
      batch_size=BATCH_SIZE,
      sort_within_batch=True,
      train=True,
      repeat=False,
      shuffle=True)
  # Vocabularies keep words seen at least MIN_WORD_FREQ times, capped at
  # MAX_NUM_WORDS entries (special tokens come on top of that cap).
  MIN_WORD_FREQ = 10
  MAX_NUM_WORDS = 10000
  DE.build_vocab(train.src, min_freq=MIN_WORD_FREQ, max_size=MAX_NUM_WORDS)
  EN.build_vocab(train.trg, min_freq=MIN_WORD_FREQ, max_size=MAX_NUM_WORDS)

  num_wds_input = len(DE.vocab.itos)
  num_wds_output = len(EN.vocab.itos)
else:
  # Synthetic smoke-test mode: random ids are drawn from one shared range, so
  # the output vocabulary size is the same; define it so both branches expose
  # the same names.
  num_wds_input = 1004
  num_wds_output = num_wds_input


class masked_softmax:
  """Softmax over axis `dim` that ignores positions where `mask` is zero.

  Masked positions get probability exactly 0, and a slice that is masked
  everywhere along `dim` yields all zeros rather than NaN.

  Args:
    v: float logits tensor.
    mask: float 0/1 tensor broadcastable against `v` (in this file: a
      (batch, query, key) product of two sequence masks for attention, and
      `expand_dims(mask, -1)` for the vocabulary softmax).
    dim: axis to normalise over.

  Attributes:
    v_mask, v_max, v_stable, v_exp, v_exp_sum: intermediate tensors.
    output: the normalised probabilities.
  """

  def __init__(self, v, mask, dim=2):
    # Logits with masked slots zeroed.
    v_mask = v * mask
    # Shift by the max over *unmasked* entries only; masked slots are pushed to
    # -1e30 so they can never be the max. Taking the max of `v * mask` lets the
    # zeros in masked slots win whenever every real logit is negative, and for
    # very negative logits exp() then underflows to 0 for every entry, so the
    # distribution collapses to all zeros.
    v_max = tf.reduce_max(v_mask + (1. - mask) * -1e30, dim, keepdims=True)
    # Apply the shift only where mask == 1 so masked and fully-masked slots stay
    # at exactly 0 (instead of +1e30, which would overflow exp()).
    v_stable = v_mask - v_max * mask

    v_exp = tf.exp(v_stable) * mask
    v_exp_sum = tf.reduce_sum(v_exp, dim, keepdims=True)
    self.v_mask, self.v_max, self.v_stable, self.v_exp, self.v_exp_sum = \
        v_mask, v_max, v_stable, v_exp, v_exp_sum
    # The epsilon turns fully masked slices into 0 / eps = 0 instead of 0 / 0.
    self.output = v_exp / (v_exp_sum + 1e-20)


class Encoder:
  """Source-side stack: token embedding plus sinusoidal positions, then self-attention layers.

  Args:
    num_wds: source vocabulary size (rows of the embedding table).
    wd_ind: int token-id tensor; dim 1 is used as the time axis
      (presumably (batch, time) -- confirm against the feeding placeholder).
    mask: float mask over source positions, handed to every layer.
    ndims: embedding width. Must be even: the position signal has
      2 * (ndims // 2) channels and is added to the ndims-wide embedding.
    n_layers: number of stacked `AttentionLayer`s.

  Attributes:
    encoding: output tensor of each layer, in order.
    attentionLayers: the `AttentionLayer` objects, in order.
  """

  def __init__(self, num_wds, wd_ind, mask, ndims=64, n_layers=6):
    self.num_wds, self.wd_ind, self.mask = num_wds, wd_ind, mask
    self.length = tf.shape(wd_ind)[1]

    # Trainable embedding table, initialised uniformly in [-1, 1].
    self.wd_emb = tf.Variable(
        tf.random_uniform([num_wds, ndims], minval=-1, maxval=1.))
    self.wd_vec = tf.nn.embedding_lookup(self.wd_emb, wd_ind)

    # Position signal: angle(t, i) = t / 10000 ** (2 * i / ndims); the sin and
    # cos halves are concatenated along the last axis.
    time_steps = tf.range(tf.cast(self.length, tf.float32), dtype=tf.float32)
    self.pos = tf.reshape(time_steps, (1, -1, 1))
    channel_ids = tf.range(tf.cast(ndims // 2, tf.float32))
    self.divider_exponent = (
        tf.reshape(channel_ids, (1, 1, -1)) * 2. / tf.cast(ndims, tf.float32))
    self.divider = tf.pow(10000., self.divider_exponent)
    self.input_to_sinusoids = self.pos / self.divider
    self.pos_sin = tf.sin(self.input_to_sinusoids)
    self.pos_cos = tf.cos(self.input_to_sinusoids)
    self.position = tf.concat((self.pos_sin, self.pos_cos), -1)
    self.w_tilde = self.wd_vec + self.position

    self.encoding = []
    self.attentionLayers = []
    layer_input = self.w_tilde
    for _ in range(n_layers):
      layer = AttentionLayer(layer_input, mask)
      self.attentionLayers.append(layer)
      self.encoding.append(layer.output)
      layer_input = layer.output


class AttentionLayer:
  """Single-head self- or cross-attention with residual + layer norm.

  Queries, keys and values are `tanh(dense(.))` projections of width `ndim`.
  Scores are plain dot products (no 1/sqrt(ndim) scaling).

  Args:
    X: key/value source, rank 3 with a static last dim `ndim`
      (presumably (batch, key_len, ndim)).
    mask: float mask over key positions (batch, key_len).
    X_decode: optional query source (batch, query_len, ndim). When None this is
      self-attention and X also provides the queries.
    decode_mask: mask over query positions (batch, query_len). It is replaced by
      `mask` for self-attention; it must be given when `X_decode` is given.
    ff_layer: if True, append `layer_norm(dense(e) + e)`.
  """

  def __init__(self, X, mask, X_decode=None, decode_mask=None, ff_layer=True):
    ndim = X.shape[-1].value
    self.X = X
    if X_decode is None:
      self.q, self.k, self.v = [
          tf.tanh(tf.layers.dense(X, ndim)) for _ in range(3)
      ]
      decode_mask = mask
      residual = X
    else:
      self.q = tf.tanh(tf.layers.dense(X_decode, ndim))
      self.k, self.v = [tf.tanh(tf.layers.dense(X, ndim)) for _ in range(2)]
      residual = X_decode
    # Broadcast views kept as attributes for compatibility. They are not on the
    # compute path below, so in graph mode they are only evaluated if fetched.
    self.q_expanded = tf.expand_dims(self.q, 2)
    self.k_expanded = tf.expand_dims(self.k, 1)
    self.v_expanded = tf.expand_dims(self.v, 1)
    # (batch, query, key) scores. Same as reduce_sum(q_expanded * k_expanded, -1)
    # but without materialising a (batch, query, key, ndim) tensor.
    self.s_raw = tf.matmul(self.q, self.k, transpose_b=True)
    self.mask = tf.expand_dims(decode_mask, 2) * tf.expand_dims(mask, 1)
    self.masked_softmax = masked_softmax(self.s_raw, self.mask)
    self.s = self.masked_softmax.output
    weights = self.s * self.mask
    # 4-D (batch, query, key, emb) weighted values, kept for compatibility only.
    self.a = tf.expand_dims(weights, -1) * self.v_expanded
    # Same as reduce_sum(self.a, 2), computed as a batched matmul.
    self.a_compressed = tf.matmul(weights, self.v)
    self.e = layer_norm(self.a_compressed + residual)
    if ff_layer:
      self.output = layer_norm(tf.layers.dense(self.e, ndim) + self.e)
    else:
      self.output = self.e


class Decoder:
  """Target-side stack: per layer, self-attention then cross-attention into the encoder.

  Layer `i` cross-attends to `encoder.encoding[i]` (the matching encoder layer),
  so the encoder must expose at least `n_layers` outputs. Every layer except the
  last also emits auxiliary vocabulary logits (`early_outputs`); the last layer
  produces `output_raw`, and `output` is its softmax masked by `mask`.

  Args:
    num_wds: target vocabulary size (embedding rows and logit width).
    wd_ind: int token-id tensor; dim 1 is used as the time axis
      (presumably (batch, time) -- confirm against the feeding placeholder).
    mask: float mask over target positions.
    encoder: object exposing `encoding` (list of tensors) and `mask`.
    ndims: embedding width; must be even for the sin/cos concat to match.
    n_layers: number of decoder layers.
  """

  def __init__(self, num_wds, wd_ind, mask, encoder, ndims=20, n_layers=6):
    self.num_wds, self.wd_ind, self.mask, self.encoder = (
        num_wds, wd_ind, mask, encoder)
    self.length = tf.shape(wd_ind)[1]

    # Trainable embedding table, initialised uniformly in [-1, 1].
    self.wd_emb = tf.Variable(
        tf.random_uniform([num_wds, ndims], minval=-1, maxval=1.))
    self.wd_vec = tf.nn.embedding_lookup(self.wd_emb, wd_ind)

    # Position signal: angle(t, i) = t / 10000 ** (2 * i / ndims).
    time_steps = tf.range(tf.cast(self.length, tf.float32), dtype=tf.float32)
    self.pos = tf.reshape(time_steps, (1, -1, 1))
    channel_ids = tf.range(tf.cast(ndims // 2, tf.float32))
    self.divider_exponent = (
        tf.reshape(channel_ids, (1, 1, -1)) * 2. / tf.cast(ndims, tf.float32))
    self.divider = tf.pow(10000., self.divider_exponent)
    self.input_to_sinusoids = self.pos / self.divider
    self.pos_sin = tf.sin(self.input_to_sinusoids)
    self.pos_cos = tf.cos(self.input_to_sinusoids)
    self.position = tf.concat((self.pos_sin, self.pos_cos), -1)
    self.w_tilde = self.wd_vec + self.position

    self.decoding = []  # never populated; kept as an attribute
    self.self_attentions = []
    self.encoder_attentions = []
    self.early_outputs = []
    hidden = self.w_tilde
    for layer_no in range(n_layers):
      self_attn = AttentionLayer(hidden, mask, ff_layer=False)
      self.self_attentions.append(self_attn)
      cross_attn = AttentionLayer(encoder.encoding[layer_no], encoder.mask,
                                  self_attn.output, mask)
      self.encoder_attentions.append(cross_attn)
      hidden = cross_attn.output
      is_last_layer = layer_no == n_layers - 1
      if not is_last_layer:
        # Auxiliary logits used as deep supervision by the training loss.
        self.early_outputs.append(tf.layers.dense(hidden, num_wds))

    self.output_raw = tf.layers.dense(hidden, num_wds)
    self.outsoftmax = masked_softmax(
        self.output_raw, tf.expand_dims(mask, -1), dim=2)
    self.output = self.outsoftmax.output

class Transformer:
  """Builds placeholders, encoder, decoder, loss and optimizer in one TF1 graph.

  Feed `wd_ind_src`/`input_lengths` for the source and `wd_ind_trg`/
  `output_lengths` for the target. `output_lengths` is the length of the
  *visible* target prefix. The loss makes the logits at the last visible
  position (`output_lengths - 1`) predict the next target token
  (`wd_ind_trg[:, output_lengths]`). Greedy decoding reads exactly those logits
  (the final fed position).

  Args:
    num_wds: vocabulary size of the encoder, and of the decoder unless
      `num_wds_output` is given.
    num_wds_output: optional separate target vocabulary size.
  """

  def __init__(self, num_wds, num_wds_output=None):
    self.num_wds = num_wds
    if num_wds_output is None:
      num_wds_output = num_wds
    n_layers = 6
    ndims = 256
    self.learning_rate = tf.placeholder(tf.float32, None)
    self.wd_ind_src = wd_ind_src = tf.placeholder(tf.int32, (None, None))
    self.wd_ind_trg = wd_ind_trg = tf.placeholder(tf.int32, (None, None))
    self.input_lengths = tf.placeholder(tf.int32, [None])
    self.output_lengths = tf.placeholder(tf.int32, [None])
    self.input_mask = tf.sequence_mask(
        self.input_lengths,
        maxlen=tf.shape(self.wd_ind_src)[-1],
        dtype=tf.float32)
    self.output_mask = tf.sequence_mask(
        self.output_lengths,
        maxlen=tf.shape(self.wd_ind_trg)[-1],
        dtype=tf.float32)
    self.encoder = Encoder(
        num_wds, wd_ind_src, self.input_mask, n_layers=n_layers, ndims=ndims)
    self.decoder = Decoder(
        num_wds_output,
        wd_ind_trg,
        self.output_mask,
        self.encoder,
        n_layers=n_layers,
        ndims=ndims)
    opt = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
    # output_mask is a prefix mask, so this difference is 1 exactly at
    # j == output_length - 1 and 0 elsewhere (all zeros when the whole row is
    # visible). The previous version prepended a zero column, which moved the
    # target to j == output_length. That position is masked out of every
    # attention layer and its input token *is* the label, so the model only
    # learned to copy its input instead of predicting the next word.
    self.prediction_mask = self.output_mask[:, :-1] - self.output_mask[:, 1:]
    # Next-token labels: logits at position j are scored against token j + 1.
    next_tokens = self.wd_ind_trg[:, 1:]

    def next_token_loss(logits):
      """Batch mean of the squared cross-entropy at the prediction position."""
      xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=next_tokens, logits=logits[:, :-1])
      # prediction_mask has at most one 1 per row, so the sum picks out that
      # single entry (and stays well-defined when the time axis is empty).
      picked = tf.reduce_sum(xent * self.prediction_mask, 1)
      return tf.reduce_mean(tf.square(picked))

    # Deep supervision: average the loss over every layer's logits.
    all_logits = self.decoder.early_outputs + [self.decoder.output_raw]
    self.losses = tf.reduce_mean([next_token_loss(logits) for logits in all_logits])
    # Final-layer loss only; this is what the training loop reports.
    self.loss = next_token_loss(self.decoder.output_raw)
    self.optimizer, self.grad_norm_total = nn_utils.apply_clipped_optimizer(
        opt, self.losses)


# In[71]:

if real:
  # The model is sized by the single vocabulary size passed below, which also
  # sizes the decoder's output layer. Every target (EN) id must be smaller
  # than it; otherwise sparse cross-entropy gets out-of-range labels (an error
  # on CPU, NaN losses on GPU).
  assert num_wds_output <= num_wds_input, (
      'target vocab (%d) exceeds the vocab size passed to Transformer (%d)'
      % (num_wds_output, num_wds_input))
transformer = Transformer(num_wds_input)
# Maximum number of tokens kept per source/target sentence during training.
MAX_LEN = 20


def predict_one(src_tensor, src_len, trg_tensor, max_decode=10):
    """Greedily decode the first sentence of a batch and print its tokens.

    Feeds the source sentence and a growing target prefix, starting from the
    batch's first target token (the '<bos>' init token). At each step it appends
    the argmax of the softmax at the last fed position. It stops after '<eos>'
    is produced or after `max_decode` steps.

    Uses the module-level `sess`, `transformer` and `EN` (the latter only
    exists when `real` is set).

    Args:
      src_tensor: int array of source ids, batch-major.
      src_len: int array of source lengths, one per batch row.
      trg_tensor: int array of target ids; only trg_tensor[0, 0] is used.
      max_decode: maximum number of tokens to generate.

    Returns:
      The list of decoded token strings (also printed).
    """
    eos_id = EN.vocab.stoi[EN.eos_token]
    src_len_decode = src_len[0:1]
    src_decode = src_tensor[0:1, :src_len_decode[0]]
    autoregressive = trg_tensor[0:1, 0:1]
    for _ in range(max_decode):
        feed = {
            transformer.wd_ind_src: src_decode,
            transformer.input_lengths: src_len_decode,
            transformer.wd_ind_trg: autoregressive,
            transformer.output_lengths: np.array(
                [autoregressive.shape[1]], dtype=np.int32),
        }
        # Fetch the existing softmax tensor and slice in numpy. Slicing the
        # tensor inside sess.run would add new ops to the graph on every call.
        pred = sess.run(transformer.decoder.output, feed)[:, -1, :]
        choice = pred.argmax(1)
        autoregressive = np.concatenate(
            (autoregressive, np.expand_dims(choice, 0)), -1)
        # Stop on the real end-of-sentence id (passing None never matched).
        if detect_end(choice, eos_id):
            break
    translation = [EN.vocab.itos[a] for a in autoregressive.flatten()]
    print(translation)
    return translation
    
    
sess = tf.Session()
sess.run(tf.global_variables_initializer())
itr = 0
print_freq = 100
running_losses = []
MAX_ITERS = 2000000  # global step budget across all epochs
if real:
  for ep in range(100):
    for train_batch in tqdm.tqdm_notebook(train_it):
      itr += 1
      # Fields come as (ids, lengths); the id tensors are transposed so that
      # the last axis is time, as the length-based sequence masks expect.
      src_tensor = train_batch.src[0].data.cpu().numpy().transpose()
      src_len = train_batch.src[1].cpu().numpy()
      trg_tensor = train_batch.trg[0].data.cpu().numpy().transpose()
      trg_len = train_batch.trg[1].cpu().numpy()
      # Truncate both sides to MAX_LEN tokens and clip their lengths to match.
      src_tensor, trg_tensor = [t[:, :MAX_LEN] for t in [src_tensor, trg_tensor]]
      src_len, trg_len = [np.clip(t, 0, MAX_LEN) for t in [src_len, trg_len]]
      # Reveal a random target prefix per sentence: ceil(U * (len - 1)), i.e.
      # 1..len-1 (0 only if U == 0). The token right after it is the target.
      trg_len = np.ceil(
          np.random.uniform(size=trg_len.shape[0]) * (trg_len - 1)).astype(int)
      trn_feed_dict = {
          transformer.wd_ind_src: src_tensor,
          transformer.input_lengths: src_len,
          transformer.wd_ind_trg: trg_tensor,
          transformer.output_lengths: trg_len,
          transformer.learning_rate: 1e-2 / (np.sqrt(itr + 3))
      }
      _, loss = sess.run([transformer.optimizer, transformer.loss],
                         trn_feed_dict)
      running_losses.append(loss)
      if itr % print_freq == 0:
        window = np.array(running_losses)
        print(itr, 'loss_mean', window.mean(), 'loss_std', window.std(), 'loss_min', window.min(),
              'loss_max', window.max())
        running_losses = []
        if itr % 1000 == 0:
          predict_one(src_tensor, src_len, trg_tensor)
      if itr > MAX_ITERS:
        break
    # The inner `break` only ends the current epoch; without this check every
    # remaining epoch would still run one more training step.
    if itr > MAX_ITERS:
      break
else:
  # Synthetic smoke test: one optimisation step on random ids and lengths.
  src_tensor = np.random.randint(low=0, high=num_wds_input, size=(BATCH_SIZE, 81))
  src_len = np.random.randint(2, 81, BATCH_SIZE)
  trg_tensor = np.random.randint(low=0, high=num_wds_input, size=(BATCH_SIZE, 84))
  trg_len = np.random.randint(2, 84, BATCH_SIZE)

  fd = {
      transformer.wd_ind_src: src_tensor,
      transformer.wd_ind_trg: trg_tensor,
      transformer.input_lengths: src_len,
      transformer.output_lengths: trg_len,
      transformer.learning_rate: 1e-2}
  sess.run([transformer.optimizer, transformer.loss], fd)
# In[75]:


  from ._conv import register_converters as _register_converters




    Only loading the 'de' tokenizer.



    Only loading the 'en' tokenizer.

Instructions for updating:
keep_dims is deprecated, use keepdims instead
Instructions for updating:
keep_dims is deprecated, use keepdims instead


HBox(children=(IntProgress(value=0, max=12306), HTML(value='')))

100 loss_mean 51.476025 loss_std 12.626099 loss_min 26.642902 loss_max 85.44374
200 loss_mean 37.293224 loss_std 8.756538 loss_min 21.133072 loss_max 62.912003
300 loss_mean 37.59685 loss_std 9.864247 loss_min 17.859806 loss_max 62.316414
400 loss_mean 33.84445 loss_std 8.606287 loss_min 16.358826 loss_max 59.507896
500 loss_mean 34.02646 loss_std 9.37392 loss_min 10.199059 loss_max 59.48942
600 loss_mean 32.443905 loss_std 9.06887 loss_min 12.982786 loss_max 54.585754
700 loss_mean 31.449831 loss_std 8.041597 loss_min 11.362991 loss_max 52.382874
800 loss_mean 31.390495 loss_std 8.292342 loss_min 12.162487 loss_max 52.828674
900 loss_mean 28.854658 loss_std 8.7589445 loss_min 9.765327 loss_max 53.9932
1000 loss_mean 28.692131 loss_std 8.167957 loss_min 10.482076 loss_max 54.172657
['<bos>', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',']
1100 loss_mean 29.801691 loss_std 9.740038 loss_min 6.784786 loss_max 56.10309
1200 loss_mean 28.420649 loss_std 8.766122 loss_min 11.031422 loss_m

9300 loss_mean 14.700266 loss_std 6.527357 loss_min 0.46171954 loss_max 30.16943
9400 loss_mean 13.518267 loss_std 6.994354 loss_min 0.9813534 loss_max 31.448917
9500 loss_mean 14.548695 loss_std 7.0417233 loss_min 3.2677855 loss_max 36.77241
9600 loss_mean 14.738674 loss_std 5.9524465 loss_min 2.7007308 loss_max 31.728958
9700 loss_mean 14.599072 loss_std 7.622949 loss_min 1.5258963 loss_max 34.438484
9800 loss_mean 15.007189 loss_std 7.5095916 loss_min 1.3961005 loss_max 37.79969
9900 loss_mean 14.841008 loss_std 8.603515 loss_min 0.8038865 loss_max 38.003654
10000 loss_mean 14.834167 loss_std 7.307744 loss_min 0.60740757 loss_max 34.54245
['<bos>', ',', 'to', 'to', 'that', 'that', 'that', 'dangers', 'is', 'is', 'is']
10100 loss_mean 14.249538 loss_std 7.151873 loss_min 0.61586386 loss_max 31.221218
10200 loss_mean 14.523155 loss_std 6.9445753 loss_min 0.23923245 loss_max 36.593807
10300 loss_mean 16.047361 loss_std 7.524457 loss_min 0.15171422 loss_max 35.19431
10400 loss_mean 14.51

HBox(children=(IntProgress(value=0, max=12306), HTML(value='')))

12400 loss_mean 14.736526 loss_std 6.814895 loss_min 0.053014867 loss_max 32.84482
12500 loss_mean 14.506157 loss_std 6.6233845 loss_min 3.6254475 loss_max 33.253746
12600 loss_mean 14.34903 loss_std 6.174605 loss_min 0.5172762 loss_max 30.965567
12700 loss_mean 14.865951 loss_std 6.469075 loss_min 1.4860168 loss_max 33.172398
12800 loss_mean 14.609281 loss_std 7.657356 loss_min 0.23216556 loss_max 36.552063
12900 loss_mean 14.743555 loss_std 5.9788723 loss_min 1.4544481 loss_max 35.74065
13000 loss_mean 14.593563 loss_std 6.0436907 loss_min 3.478807 loss_max 36.51232
['<bos>', '.', '.', '.', ',', '<unk>', 'more', 'to', '<unk>', '<unk>', 'kind']
13100 loss_mean 13.130924 loss_std 7.464578 loss_min 0.08939794 loss_max 35.02022
13200 loss_mean 13.993592 loss_std 6.5846157 loss_min 1.6950214 loss_max 31.72578
13300 loss_mean 12.759207 loss_std 6.151493 loss_min 2.0911474 loss_max 30.15445
13400 loss_mean 14.020524 loss_std 6.6179433 loss_min 2.6937296 loss_max 32.607193
13500 loss_mean 13

21400 loss_mean 12.852788 loss_std 6.1127768 loss_min 0.08050847 loss_max 30.489552
21500 loss_mean 13.619551 loss_std 6.8504324 loss_min 1.2747835 loss_max 33.144196
21600 loss_mean 13.172037 loss_std 6.159949 loss_min 0.5850219 loss_max 29.750801
21700 loss_mean 12.661195 loss_std 6.2308564 loss_min 0.21816522 loss_max 31.570555
21800 loss_mean 12.467774 loss_std 6.614501 loss_min 0.4162619 loss_max 31.930447
21900 loss_mean 13.179471 loss_std 6.5771346 loss_min 0.14309283 loss_max 41.341274
22000 loss_mean 12.436319 loss_std 5.4508567 loss_min 0.2870277 loss_max 26.949848
['<bos>', 'a', 'to', 'to', 'in', 'in', 'in', '.', 'That', ',', '<unk>']
22100 loss_mean 13.50678 loss_std 6.5044003 loss_min 0.2546462 loss_max 29.000555
22200 loss_mean 14.185298 loss_std 6.655943 loss_min 0.79388607 loss_max 34.51637
22300 loss_mean 12.637961 loss_std 6.911795 loss_min 0.23218146 loss_max 34.763626
22400 loss_mean 13.319734 loss_std 5.87173 loss_min 0.81664413 loss_max 27.775196
22500 loss_mean 1

HBox(children=(IntProgress(value=0, max=12306), HTML(value='')))

24700 loss_mean 13.6941395 loss_std 6.166013 loss_min 0.09461602 loss_max 35.378975
24800 loss_mean 13.522131 loss_std 7.2986217 loss_min 0.077760935 loss_max 41.778767
24900 loss_mean 12.9293 loss_std 7.5191035 loss_min 0.60505366 loss_max 30.613993
25000 loss_mean 13.129961 loss_std 6.5929947 loss_min 0.07855506 loss_max 34.934807
['<bos>', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'is', '<unk>']
25100 loss_mean 12.51868 loss_std 6.8567204 loss_min 0.24519914 loss_max 37.27026
25200 loss_mean 13.589473 loss_std 7.359705 loss_min 1.0915486 loss_max 35.44741
25300 loss_mean 12.3281555 loss_std 6.990785 loss_min 0.24308649 loss_max 33.291374
25400 loss_mean 12.369857 loss_std 6.27489 loss_min 0.07141957 loss_max 34.587288
25500 loss_mean 12.366999 loss_std 7.828096 loss_min 0.025999984 loss_max 38.341312
25600 loss_mean 11.997226 loss_std 6.833388 loss_min 0.63391477 loss_max 36.02401
25700 loss_mean 12.878513 loss_std 7.9988866 loss_min 0.044315625 loss_max 34.614765
25800 loss_m

33800 loss_mean 12.807307 loss_std 6.899493 loss_min 0.5586353 loss_max 28.117334
33900 loss_mean 11.740256 loss_std 6.3129935 loss_min 0.11103764 loss_max 27.306847
34000 loss_mean 12.362658 loss_std 6.6178865 loss_min 0.16476645 loss_max 35.926037
['<bos>', ',', "'s", '<eos>', 'I', 'were', 'as', ',', 'were', 'I', '<unk>']
34100 loss_mean 12.683462 loss_std 5.955844 loss_min 0.4034228 loss_max 29.793713
34200 loss_mean 13.453871 loss_std 6.367729 loss_min 0.17687216 loss_max 28.671017
34300 loss_mean 12.569189 loss_std 5.946906 loss_min 0.009566019 loss_max 26.371254
34400 loss_mean 12.5308895 loss_std 6.269864 loss_min 0.24361321 loss_max 29.955656
34500 loss_mean 11.910586 loss_std 6.0561185 loss_min 0.11476501 loss_max 32.80216
34600 loss_mean 13.550953 loss_std 6.960074 loss_min 0.15983522 loss_max 36.30722
34700 loss_mean 12.2127495 loss_std 6.4776673 loss_min 0.060363054 loss_max 26.553928
34800 loss_mean 13.208958 loss_std 6.5603733 loss_min 0.046347998 loss_max 30.801651
34900

HBox(children=(IntProgress(value=0, max=12306), HTML(value='')))

37000 loss_mean 13.958851 loss_std 7.049064 loss_min 1.3451111 loss_max 31.178764
['<bos>', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',']
37100 loss_mean 12.365937 loss_std 6.41912 loss_min 0.1784303 loss_max 37.067146
37200 loss_mean 12.329975 loss_std 6.2329783 loss_min 0.42283463 loss_max 30.192047
37300 loss_mean 13.355958 loss_std 7.134558 loss_min 0.042868145 loss_max 37.763535
37400 loss_mean 12.493601 loss_std 5.9889216 loss_min 0.29457933 loss_max 29.749636
37500 loss_mean 12.903881 loss_std 5.904778 loss_min 1.5073617 loss_max 27.758694
37600 loss_mean 13.332389 loss_std 6.7638683 loss_min 0.07309822 loss_max 31.569733
37700 loss_mean 12.658224 loss_std 6.6222157 loss_min 0.18387923 loss_max 30.764368
37800 loss_mean 13.067207 loss_std 6.8078704 loss_min 0.5991127 loss_max 30.797268
37900 loss_mean 12.935075 loss_std 6.0380573 loss_min 0.030503364 loss_max 32.64238
38000 loss_mean 12.997198 loss_std 6.217144 loss_min 0.22607088 loss_max 30.47718
['<bos>', '<eos>', '<eos>'

In [None]:
NUM_MAX_DECODE = MAX_LEN

In [None]:

src_len_decode = src_len[0:1]
src_decode = src_tensor[0:1, :src_len_decode[0]]
autoregressive = trg_tensor[0:1, 0:1]


In [None]:
for dec_idx in range(10):
    pred = sess.run(
        transformer.decoder.outsoftmax.output[:,-1,:], {
            transformer.wd_ind_src: src_decode,
            transformer.input_lengths: src_len_decode,
            transformer.wd_ind_trg: autoregressive,
            transformer.output_lengths: np.ones(1)*autoregressive.shape[1],
        })
    choice = pred.argmax(1)
    autoregressive = np.concatenate((autoregressive, np.expand_dims(choice, 0)), -1)
    if detect_end(choice, None):
        break

In [None]:
pred.argmax(1)

In [None]:
pred

In [None]:
translation = [EN.vocab.itos[a] for a in autoregressive.flatten()]

In [None]:
translation

In [None]:
pred[:,12:20]