In [1]:
# !pip install http://download.pytorch.org/whl/cu80/torch-0.3.0.post4-cp36-cp36m-linux_x86_64.whl numpy matplotlib torchtext 

In [3]:
# Standard PyTorch imports
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy
from torch.autograd import Variable

# For plots
%matplotlib inline
import matplotlib.pyplot as plt


import tensorflow as tf

#!conda install torchtext spacy
# !python -m spacy download en
# !python -m spacy download de

from torchtext import data
from torchtext import datasets

import re
import spacy

spacy_de = spacy.load('de')
spacy_en = spacy.load('en')

url = re.compile('(<url>.*</url>)')


def tokenize_de(text):
    return [tok.text for tok in spacy_de.tokenizer(url.sub('@URL@', text))]


def tokenize_en(text):
    return [tok.text for tok in spacy_en.tokenizer(url.sub('@URL@', text))]


# Testing IWSLT
DE = data.Field(tokenize=tokenize_de, init_token='<bos>', eos_token='<eos>', include_lengths=True)
EN = data.Field(tokenize=tokenize_en, init_token='<bos>', eos_token='<eos>', include_lengths=True)

train, val, test = datasets.IWSLT.splits(exts=('.de', '.en'), fields=(DE, EN))


train_it = data.Iterator(train, batch_size=4, sort_within_batch=True, train=True, repeat=False, shuffle=True)
MIN_WORD_FREQ = 10
MAX_NUM_WORDS = 1000
DE.build_vocab(train.src, min_freq=MIN_WORD_FREQ, max_size=MAX_NUM_WORDS)
EN.build_vocab(train.trg, min_freq=MIN_WORD_FREQ, max_size=MAX_NUM_WORDS)

num_wds_input = len(DE.vocab.itos)
num_wds_output = len(EN.vocab.itos)

num_wds_input, num_wds_output



(1004, 1004)

In [7]:

from tensorflow.contrib.layers import layer_norm
import nn_utils

In [131]:

class masked_softmax:
    def __init__(self, v, mask, dim=1):
        #bs, query dimension, key dimension
        v_mask = v * mask
        v_max = tf.reduce_max(v_mask, dim, keep_dims=True)
        v_stable = v_mask - v_max

        v_exp = tf.exp(v_stable) * mask
        v_exp_sum = tf.reduce_sum(v_exp, dim, keep_dims=True)
        self.v_mask, self.v_max, self.v_stable, self.v_exp, self.v_exp_sum = \
            v_mask, v_max, v_stable, v_exp, v_exp_sum
        self.output =  v_exp / (v_exp_sum + 1e-20)


class Encoder:
  def __init__(self, num_wds, wd_ind, mask, ndims=20, n_layers=1):
    self.num_wds = num_wds
    self.wd_ind = wd_ind
    self.mask = mask
    self.length = tf.shape(self.wd_ind)[1]
    self.wd_emb = tf.Variable(
        tf.random_uniform([self.num_wds, ndims], minval=-1, maxval=1.))
    self.wd_vec = tf.nn.embedding_lookup(self.wd_emb, wd_ind)
    self.position = tf.reshape(
        tf.range(tf.cast(self.length, tf.float32), dtype=tf.float32) / 10000,
        (1, -1, 1))
    self.w_tilde = embedding = self.wd_vec + self.position
    self.encoding = []
    self.attentionLayers = []
    for _ in range(n_layers):
      attentionLayer = AttentionLayer(embedding, mask)
      embedding = attentionLayer.output
      self.encoding.append(embedding)
      self.attentionLayers.append(attentionLayer)


class AttentionLayer:
  def __init__(self, X, mask, X_decode=None, decode_mask=None, ff_layer=True):
    bs, length, ndim = [v.value for v in X.shape]
    self.X = X
    if X_decode is None:
      self.q, self.k, self.v = [
          tf.tanh(tf.layers.dense(X, ndim)) for _ in range(3)
      ]
      decode_mask = mask
    else:
      self.q = tf.tanh(tf.layers.dense(X_decode, ndim))
      self.k, self.v = [tf.tanh(tf.layers.dense(X, ndim)) for _ in range(2)]
    #batch, attention queries, attention keys, embeddings
    self.q_expanded = tf.expand_dims(self.q, 2)
    self.k_expanded = tf.expand_dims(self.k, 1)
    self.v_expanded = tf.expand_dims(self.v, 1)
    self.s_raw = tf.reduce_sum(self.q_expanded * self.k_expanded, -1)
    self.mask = tf.expand_dims(decode_mask, 2) * tf.expand_dims(mask, 1)
    self.masked_softmax = masked_softmax(self.s_raw, self.mask)
    self.s = self.masked_softmax.output
    self.a = tf.expand_dims(self.s * self.mask, -1) * self.v_expanded
    #A is shape bs, query, key, emb
    self.a_compressed = tf.reduce_sum(self.a, 2)
    if X_decode is None:
        self.e = layer_norm(self.a_compressed + X)
    else:
        self.e = layer_norm(self.a_compressed + X_decode)
    if ff_layer:
      self.output = layer_norm(tf.layers.dense(self.e, ndim) + self.e)
    else:
      self.output = self.e


class Decoder:
  def __init__(self, num_wds, wd_ind, mask, encoder, ndims=20, n_layers=1):
    self.num_wds = num_wds
    self.wd_ind = wd_ind
    self.mask = mask
    self.encoder = encoder
    self.length = tf.shape(self.wd_ind)[1]
    self.wd_emb = tf.Variable(
        tf.random_uniform([self.num_wds, ndims], minval=-1, maxval=1.))
    self.wd_vec = tf.nn.embedding_lookup(self.wd_emb, wd_ind)
    self.position = tf.reshape(
        tf.range(tf.cast(self.length, tf.float32), dtype=tf.float32) / 10000,
        (1, -1, 1))
    self.w_tilde = embedding = self.wd_vec + self.position
    self.decoding = []
    self.self_attentions = []
    self.encoder_attentions = []
    for l_idx in range(n_layers):
      attn = AttentionLayer(embedding, mask, ff_layer=False)
      self.self_attentions.append(attn)
      encode_attn = AttentionLayer(encoder.encoding[l_idx], encoder.mask,
                                   attn.output, mask)
      self.encoder_attentions.append(encode_attn)
      embedding = encode_attn.output

    self.output_raw = tf.layers.dense(embedding, num_wds)
    #bs, word in sentence of target, embedding
    
    self.masked_softmax = masked_softmax(self.output_raw, mask)
    self.output = self.masked_softmax.output


class Transformer:
  def __init__(self, num_wds):
    self.num_wds = num_wds
    self.learning_rate = tf.placeholder(tf.float32, None)
    self.wd_ind_src = wd_ind_src = tf.placeholder(tf.int32, (None, None))
    self.wd_ind_trg = wd_ind_trg = tf.placeholder(tf.int32, (None, None))
    self.input_lengths = tf.placeholder(tf.int32, [None])
    self.output_lengths = tf.placeholder(tf.int32, [None])
    self.input_mask = tf.sequence_mask(
        self.input_lengths,
        maxlen=tf.shape(self.wd_ind_src)[-1],
        dtype=tf.float32)
    self.output_mask = tf.sequence_mask(
        self.output_lengths,
        maxlen=tf.shape(self.wd_ind_trg)[-1],
        dtype=tf.float32)
    self.encoder = Encoder(num_wds, wd_ind_src, self.input_mask)
    self.decoder = Decoder(num_wds, wd_ind_trg, self.output_mask, self.encoder)
    opt = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
    self.loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=self.wd_ind_trg, logits=self.decoder.output_raw) * self.output_mask)
    self.optimizer, self.grad_norm_total = nn_utils.apply_clipped_optimizer(
        opt, self.loss)



In [136]:


transformer = Transformer(num_wds_input)


sess = tf.Session()
sess.run(tf.global_variables_initializer())
for train_batch in train_it:
    src_tensor  = train_batch.src[0].data.cpu().numpy().transpose()
    src_len = train_batch.src[1].cpu().numpy()
    trg_tensor  = train_batch.trg[0].data.cpu().numpy().transpose()
    trg_len = train_batch.trg[1].cpu().numpy()
#     print(src_tensor.shape, src_len.shape, trg_tensor.shape, trg_len.shape)
#     print(src_tensor, src_len, trg_tensor, trg_len)
    trn_feed_dict = {transformer.wd_ind_src : src_tensor, transformer.input_lengths : src_len,
                    transformer.wd_ind_trg : trg_tensor, transformer.output_lengths : trg_len,
                    transformer.learning_rate : 1e-2}
    _,loss = sess.run([transformer.optimizer, transformer.loss], trn_feed_dict)
    print(loss)


5.2241344
4.1677203
3.0242639
4.125008
3.0598974
4.186294
3.8097103
4.3658805
3.446637
3.403835
3.1183062
3.3539703
2.508477
2.9842038
2.8408532
3.0314736
3.408704
2.0514636
1.8326241
2.6074464
2.327334
2.2641985
2.8061223
1.3473037
2.389687
2.2487276
1.4788653
1.3561007
2.4342747
1.5998297
2.4982827
2.049074
1.788495
1.2488563
1.3435767
1.3553156
1.8865074
1.5290121
1.5308939
2.064324
1.3196985
1.0255867
2.3608482
1.9188416
1.5824369
0.97924995
1.5048205
1.1546017
1.4323416
0.9919289
1.514226
1.0640409
1.4528157
1.3348244
0.7683308
1.1942271
1.3846107
0.6223112
1.2278935
1.2060142
1.0063374
1.1859266
0.6871687
1.3848866
1.2084869
1.060476
0.99177545
0.55127287
0.8096385
0.67267025
1.0020802
1.3161674
0.9653521
0.5362376
0.6868691
1.161632
0.5444901
0.72492146
0.9198456
1.0695945
1.0160395
0.77016884
0.8900901
0.48960805
0.652707
0.6958576
0.6533941
0.5528134
0.49366683
0.53368205
0.564244
0.5127789
0.47876883
0.6521485
0.993804
0.84442204
0.47907633
1.1945776
0.43566784
0.4702915
1.01

0.0073305275
0.00526249
0.005677248
0.0038444211
0.023071691
0.010023771
0.02665178
0.0037841436
0.0011397912
0.069947906
0.0010886745
0.0003613947
0.0034772835
0.0021976505
0.011086265
0.00024881837
0.008097199
0.006793519
0.053471413
0.0061328295
0.0755811
0.0027963328
0.0017145023
0.0028453367
0.005497146
0.017357497
0.0020964083
0.020921413
0.003566575
0.0731629
0.0022570626
0.0041529164
0.00340041
0.00062050694
0.0013398404
0.00053863885
0.00044605604
0.023197623
0.015108791
0.006029957
0.006234276
0.093518704
0.004030966
0.052772164
0.0012220616
0.0068859304
0.015309802
0.016544366
0.01240046
0.017381577
0.0011561422
0.0023517981
0.021097803
0.00075036124
0.0051371823
0.0018710006
0.0021657147
0.0028947964
0.0014999441
0.00033876844
0.0012194318
0.004286629
0.050379734
0.004905951
0.00191205
0.0036475935
0.040779635
0.043435562
0.0017944172
0.00302316
0.00070584944
0.0047501978
0.013034207
0.011882876
0.002284447
0.0114504555
0.00013739316
0.0010584623
0.008302085
0.0017695504
0.

0.00041336202
0.0008682358
0.00039313757
0.0001646451
0.00015138114
0.001002277
0.0001958813
0.0005880952
0.00025457528
0.00024751603
0.005431298
0.0002086362
0.00041601612
4.6567176e-05
0.00026340157
0.0004947261
4.9169703e-05
6.218848e-05
0.5511606
0.0003390398
0.00035679733
0.0003706657
0.0004525237
0.00017091562
0.0016932226
0.0024641799
0.0009223542
0.00049170357
0.00025385496
0.0094900625
0.00024277299
0.0008217217
0.0007859924
0.00026096983
0.00012231224
0.0006766619
0.00032053213
0.00037615886
0.0026059707
0.00078822894
0.00034576782
0.0013885832
0.0005596405
0.00096716767
0.00017411615
0.00016946155
0.00024514404
0.023184365
0.0006325484
0.012389659
0.00035957378
8.956113e-05
0.000636048
0.00050809555
0.0010458638
0.000340077
0.0001565141
0.0018854719
0.0055719553
0.00018505339
0.00043734838
0.001223454
0.001456819
0.05179802
0.00050398847
0.00020151946
0.00042180455
0.0009176832
0.0008593436
0.0002485548
0.0005547124
0.00022883088
0.00029219108
0.00015851675
0.00023563203
0.0

9.207658e-05
0.00011356792
0.00033812269
7.376567e-05
1.964106e-05
0.0007049819
7.587248e-05
5.3524665e-05
0.0015853893
7.4511176e-05
7.459551e-05
4.23496e-05
0.00024049864
2.0968653e-05
4.8809354e-05
0.0002831305
2.7631526e-05
0.00014280966
8.3139064e-05
7.51793e-05
0.00011695015
0.00018381859
0.015586635
7.651671e-05
4.5518813e-05
7.761083e-05
2.0305979e-05
0.00018099324
0.00023646175
0.0002525717
0.00012734097
0.00016181522
0.00057983026
7.7004726e-05
0.00021913157
0.00023846736
2.058113e-05
0.00027199028
0.00017344984
0.000496827
0.005367055
8.89361e-05
0.0001487192
0.0002662687
0.0006814251
0.00043374195
0.00011672608
0.00011437977
0.0001460023
0.00014509441
0.0013254085
0.00020573517
0.00014300182
0.00015770602
6.453152e-05
0.00028347707
0.00032068885
0.00014764411
5.2654355e-05
0.009528777
0.00027243295
0.00032572087
0.0002922667
0.0054493817
0.00033742026
0.000187685
0.0005528265
0.00030417065
0.00046052248
0.0016849424
0.00013610724
0.00031171626
0.0001327208
0.00018191013
0.0

KeyboardInterrupt: 

In [133]:
self = transformer

In [134]:
s, m = sess.run([self.encoder.attentionLayers[0].s, self.encoder.attentionLayers[0].mask], trn_feed_dict)

In [129]:
s[-1,-1,-1]

nan

In [130]:
m[-1,-1,-1]

0.0

In [128]:
(s * m)[-1, -1, -1]

nan

In [83]:
sess.run(self.wd_ind_src, trn_feed_dict).shape

(4, 33)

In [81]:
sess.run(self.encoder.attentionLayers[0].masked_softmax.v_exp_sum, trn_feed_dict).shape

(4, 1, 33)

In [90]:
sess.run(self.encoder.attentionLayers[0].masked_softmax.output, trn_feed_dict).shape

(4, 33, 33)

In [36]:
sess.run(tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=self.wd_ind_trg, logits=self.decoder.output_raw) * self.output_mask, trn_feed_dict)

array([[6.8806496, 7.2538576, 6.836245 , 6.8354425, 6.8945065, 7.029778 ,
        7.261009 , 7.1071267, 7.1146455, 6.8747187, 6.6536813, 7.0499544,
        7.261133 , 7.0428286, 7.2384424, 6.873402 , 7.1073694, 7.1073966,
        6.6534758, 7.0587964, 6.6961355, 6.5816846, 6.873296 , 7.0611773,
        6.934389 , 7.4266753, 6.873235 , 7.107667 , 7.1076937, 7.2384634,
        6.8879128, 7.261527 , 6.8138814, 7.107829 , 7.107856 , 7.2384715,
        6.873083 , 7.107937 , 7.107964 , 7.238476 , 7.0586863, 7.1080456,
        6.871643 , 6.872977 , 7.1081266, 6.7329335, 7.104479 ],
       [      nan,       nan,       nan,       nan,       nan,       nan,
              nan,       nan,       nan,       nan,       nan,       nan,
              nan,       nan,       nan,       nan,       nan,       nan,
              nan,       nan,       nan,       nan,       nan,       nan,
              nan,       nan,       nan,       nan,       nan,       nan,
              nan,       nan,       nan,       n