In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [2]:
import numpy as np
import tensorflow as tf
import json

In [3]:
with open('dataset-bpe.json') as fopen:
    data = json.load(fopen)

In [4]:
train_X = data['train_X']
train_Y = data['train_Y']
test_X = data['test_X']
test_Y = data['test_Y']

In [5]:
EOS = 2
GO = 1
vocab_size = 32000

In [6]:
train_Y = [i + [2] for i in train_Y]
test_Y = [i + [2] for i in test_Y]

In [30]:
from tensor2tensor.utils import beam_search

def pad_second_dim(x, desired_size):
    padding = tf.tile([[[0.0]]], tf.stack([tf.shape(x)[0], desired_size - tf.shape(x)[1], tf.shape(x)[2]], 0))
    return tf.concat([x, padding], 1)

def hop_forward(memory_o, memory_i, response_proj, inputs_len, questions_len):
    match = memory_i
    match = pre_softmax_masking(match, inputs_len)
    match = tf.nn.softmax(match)
    match = post_softmax_masking(match, questions_len)
    response = tf.multiply(match, memory_o)
    return response_proj(response)


def pre_softmax_masking(x, seq_len):
    paddings = tf.fill(tf.shape(x), float('-inf'))
    T = tf.shape(x)[1]
    max_seq_len = tf.shape(x)[2]
    masks = tf.sequence_mask(seq_len, max_seq_len, dtype = tf.float32)
    masks = tf.tile(tf.expand_dims(masks, 1), [1, T, 1])
    return tf.where(tf.equal(masks, 0), paddings, x)


def post_softmax_masking(x, seq_len):
    T = tf.shape(x)[2]
    max_seq_len = tf.shape(x)[1]
    masks = tf.sequence_mask(seq_len, max_seq_len, dtype = tf.float32)
    masks = tf.tile(tf.expand_dims(masks, -1), [1, 1, T])
    return x * masks

def embed_seq(x, vocab_size, zero_pad = True):
    lookup_table = tf.get_variable(
        'lookup_table', [vocab_size, size_layer], tf.float32
    )
    if zero_pad:
        lookup_table = tf.concat(
            (tf.zeros([1, size_layer]), lookup_table[1:, :]), axis = 0
        )
    return tf.nn.embedding_lookup(lookup_table, x)

def sinusoidal_position_encoding(inputs, mask, repr_dim):
    T = tf.shape(inputs)[1]
    pos = tf.reshape(tf.range(0.0, tf.to_float(T), dtype=tf.float32), [-1, 1])
    i = np.arange(0, repr_dim, 2, np.float32)
    denom = np.reshape(np.power(10000.0, i / repr_dim), [1, -1])
    enc = tf.expand_dims(tf.concat([tf.sin(pos / denom), tf.cos(pos / denom)], 1), 0)
    return tf.tile(enc, [tf.shape(inputs)[0], 1, 1]) * tf.expand_dims(tf.to_float(mask), -1)

def quest_mem(x, vocab_size, size_layer):
    en_masks = tf.sign(x)
    x = embed_seq(x, vocab_size)
    x += sinusoidal_position_encoding(x, en_masks, size_layer)
    return x

class Translator:
    def __init__(self, size_layer, num_layers, embedded_size, learning_rate,
                beam_width = 5, n_hops = 3):
        
        self.X = tf.placeholder(tf.int32, [None, None])
        self.Y = tf.placeholder(tf.int32, [None, None])
        
        self.X_seq_len = tf.count_nonzero(self.X, 1, dtype = tf.int32)
        self.Y_seq_len = tf.count_nonzero(self.Y, 1, dtype = tf.int32)
        batch_size = tf.shape(self.X)[0]
        
        lookup_table = tf.get_variable('lookup_table', [vocab_size, size_layer], tf.float32)
        
        with tf.variable_scope('memory_o'):
            memory_o = quest_mem(self.X, vocab_size, size_layer)
        
        with tf.variable_scope('memory_i'):
            memory_i = quest_mem(self.X, vocab_size, size_layer)
            
        with tf.variable_scope('interaction'):
            response_proj = tf.layers.Dense(size_layer)
            for _ in range(n_hops):
                answer = hop_forward(memory_o,
                                     memory_i,
                                     response_proj,
                                     self.X_seq_len,
                                     self.X_seq_len)
                memory_i = answer
                
        def cells(reuse=False):
            return tf.nn.rnn_cell.LSTMCell(size_layer,initializer=tf.orthogonal_initializer(),reuse=reuse)
        
        main = tf.strided_slice(self.Y, [0, 0], [batch_size, -1], [1, 1])
        decoder_input = tf.concat([tf.fill([batch_size, 1], GO), main], 1)
        dense = tf.layers.Dense(vocab_size)
        decoder_cells = tf.nn.rnn_cell.MultiRNNCell([cells() for _ in range(num_layers)])

        init_state = answer[:,-1]
        encoder_state = tf.nn.rnn_cell.LSTMStateTuple(c=init_state, h=init_state)
        encoder_state = tuple([encoder_state] * num_layers)
        
        print(encoder_state)
        vocab_proj = tf.layers.Dense(vocab_size)
        
        helper = tf.contrib.seq2seq.TrainingHelper(
            inputs = tf.nn.embedding_lookup(lookup_table, decoder_input),
            sequence_length = tf.to_int32(self.Y_seq_len))
        
        decoder = tf.contrib.seq2seq.BasicDecoder(cell = decoder_cells,
                                                  helper = helper,
                                                  initial_state = encoder_state,
                                                  output_layer = vocab_proj)
        
        decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder = decoder,
                                                                maximum_iterations = tf.reduce_max(self.Y_seq_len))
        
        helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(embedding = lookup_table,
                                                          start_tokens = tf.tile(
                                                              tf.constant([GO], 
                                                                          dtype=tf.int32), 
                                                              [tf.shape(init_state)[0]]),
                                                          end_token = EOS)
        decoder = tf.contrib.seq2seq.BasicDecoder(
            cell = decoder_cells,
            helper = helper,
            initial_state = encoder_state,
            output_layer = vocab_proj)
        predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
            decoder = decoder,
            maximum_iterations = 2 * tf.reduce_max(self.X_seq_len))
        self.training_logits = decoder_output.rnn_output
        self.logits = decoder_output.sample_id
        self.fast_result = predicting_decoder_output.sample_id
        
        masks = tf.sequence_mask(self.Y_seq_len, tf.reduce_max(self.Y_seq_len), dtype=tf.float32)
        self.cost = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits,
                                                     targets = self.Y,
                                                     weights = masks)
        self.optimizer = tf.train.AdamOptimizer(learning_rate).minimize(self.cost)
        y_t = tf.argmax(self.training_logits,axis=2)
        y_t = tf.cast(y_t, tf.int32)
        self.prediction = tf.boolean_mask(y_t, masks)
        mask_label = tf.boolean_mask(self.Y, masks)
        correct_pred = tf.equal(self.prediction, mask_label)
        correct_index = tf.cast(correct_pred, tf.float32)
        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

In [31]:
size_layer = 512
num_layers = 2
embedded_size = 256
learning_rate = 1e-3
batch_size = 128
epoch = 20

In [32]:
tf.reset_default_graph()
sess = tf.InteractiveSession()
model = Translator(size_layer, num_layers, embedded_size, learning_rate)
sess.run(tf.global_variables_initializer())



(LSTMStateTuple(c=<tf.Tensor 'strided_slice_1:0' shape=(?, 512) dtype=float32>, h=<tf.Tensor 'strided_slice_1:0' shape=(?, 512) dtype=float32>), LSTMStateTuple(c=<tf.Tensor 'strided_slice_1:0' shape=(?, 512) dtype=float32>, h=<tf.Tensor 'strided_slice_1:0' shape=(?, 512) dtype=float32>))


In [28]:
pad_sequences = tf.keras.preprocessing.sequence.pad_sequences

In [33]:
batch_x = pad_sequences(train_X[:10], padding='post')
batch_y = pad_sequences(train_Y[:10], padding='post')

sess.run([model.fast_result, model.cost, model.accuracy], 
         feed_dict = {model.X: batch_x, model.Y: batch_y})

[array([[19397, 19397, 27768, 27768, 17696, 17696, 27768, 25642, 25642,
         27768, 28355, 28503, 14535, 17902, 17902, 14535, 17254, 24473,
         24473,  7048, 15065, 17168, 17168, 21212, 21212, 22429, 22429,
         22429, 25737,  3915,  3915, 11557, 11557, 23311, 10254,  6953,
          6953, 10254,  3712,  3712, 30643, 22712, 22712, 22712,  4579,
          4579, 31011, 31011, 31011, 31011,  6226,   822, 23311, 25129,
         25129, 20665,  9644,  9644, 31653, 31653, 31653, 31653, 21142,
         21142, 21142, 23120, 23095, 26751, 13780, 13780, 13780, 16678],
        [ 2823, 21088, 14140, 14140, 26778,  1178,  1178, 28964, 28964,
         28964, 11199, 11199, 24198, 10398, 27551, 10398, 18156, 18529,
          8502,  8502, 18146,  2039, 21595, 21595,  2379,  2379,  2379,
          9760,  9760,  9760, 24497, 17615, 17615,  4110,  8817,  8817,
         24513, 24513, 10747, 27780, 27780, 25660, 25660, 25660, 27313,
         27313,  7231, 23454,  9132,  9132,  9132, 25565, 13134

In [34]:
import tqdm

for e in range(epoch):
    pbar = tqdm.tqdm(
        range(0, len(train_X), batch_size), desc = 'minibatch loop')
    train_loss, train_acc, test_loss, test_acc = [], [], [], []
    for i in pbar:
        index = min(i + batch_size, len(train_X))
        batch_x = pad_sequences(train_X[i : index], padding='post')
        batch_y = pad_sequences(train_Y[i : index], padding='post')
        feed = {model.X: batch_x,
                model.Y: batch_y}
        accuracy, loss, _ = sess.run([model.accuracy,model.cost,model.optimizer],
                                    feed_dict = feed)
        train_loss.append(loss)
        train_acc.append(accuracy)
        pbar.set_postfix(cost = loss, accuracy = accuracy)
    
    
    pbar = tqdm.tqdm(
        range(0, len(test_X), batch_size), desc = 'minibatch loop')
    for i in pbar:
        index = min(i + batch_size, len(test_X))
        batch_x = pad_sequences(test_X[i : index], padding='post')
        batch_y = pad_sequences(test_Y[i : index], padding='post')
        feed = {model.X: batch_x,
                model.Y: batch_y,}
        accuracy, loss = sess.run([model.accuracy,model.cost],
                                    feed_dict = feed)

        test_loss.append(loss)
        test_acc.append(accuracy)
        pbar.set_postfix(cost = loss, accuracy = accuracy)
    
    print('epoch %d, training avg loss %f, training avg acc %f'%(e+1,
                                                                 np.mean(train_loss),np.mean(train_acc)))
    print('epoch %d, testing avg loss %f, testing avg acc %f'%(e+1,
                                                              np.mean(test_loss),np.mean(test_acc)))

minibatch loop: 100%|██████████| 1563/1563 [06:46<00:00,  3.85it/s, accuracy=0.136, cost=5.99]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.31it/s, accuracy=0.151, cost=5.5] 
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 1, training avg loss 6.897701, training avg acc 0.094511
epoch 1, testing avg loss 5.886279, testing avg acc 0.144228


minibatch loop: 100%|██████████| 1563/1563 [06:45<00:00,  3.85it/s, accuracy=0.176, cost=5.32]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.51it/s, accuracy=0.204, cost=4.79]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 2, training avg loss 5.496175, training avg acc 0.165414
epoch 2, testing avg loss 5.236630, testing avg acc 0.178004


minibatch loop: 100%|██████████| 1563/1563 [06:45<00:00,  3.86it/s, accuracy=0.206, cost=4.95]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.67it/s, accuracy=0.22, cost=4.51] 
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 3, training avg loss 5.038396, training avg acc 0.190332
epoch 3, testing avg loss 4.943071, testing avg acc 0.197108


minibatch loop: 100%|██████████| 1563/1563 [06:44<00:00,  3.87it/s, accuracy=0.225, cost=4.7] 
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.63it/s, accuracy=0.242, cost=4.34]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 4, training avg loss 4.765566, training avg acc 0.208819
epoch 4, testing avg loss 4.766912, testing avg acc 0.211949


minibatch loop: 100%|██████████| 1563/1563 [06:45<00:00,  3.85it/s, accuracy=0.239, cost=4.5] 
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.68it/s, accuracy=0.253, cost=4.22]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 5, training avg loss 4.574376, training avg acc 0.223101
epoch 5, testing avg loss 4.658642, testing avg acc 0.221847


minibatch loop: 100%|██████████| 1563/1563 [06:45<00:00,  3.85it/s, accuracy=0.246, cost=4.35]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.56it/s, accuracy=0.253, cost=4.14]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 6, training avg loss 4.432821, training avg acc 0.233963
epoch 6, testing avg loss 4.589426, testing avg acc 0.228429


minibatch loop: 100%|██████████| 1563/1563 [06:45<00:00,  3.85it/s, accuracy=0.253, cost=4.21]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.66it/s, accuracy=0.28, cost=4.09] 
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 7, training avg loss 4.322593, training avg acc 0.242401
epoch 7, testing avg loss 4.546140, testing avg acc 0.233800


minibatch loop: 100%|██████████| 1563/1563 [06:46<00:00,  3.85it/s, accuracy=0.255, cost=4.09]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.67it/s, accuracy=0.29, cost=4.05] 
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 8, training avg loss 4.231865, training avg acc 0.249445
epoch 8, testing avg loss 4.521348, testing avg acc 0.236877


minibatch loop: 100%|██████████| 1563/1563 [06:46<00:00,  3.85it/s, accuracy=0.265, cost=3.98]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.67it/s, accuracy=0.296, cost=4]   
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 9, training avg loss 4.154636, training avg acc 0.255596
epoch 9, testing avg loss 4.501642, testing avg acc 0.239343


minibatch loop: 100%|██████████| 1563/1563 [06:46<00:00,  3.85it/s, accuracy=0.275, cost=3.89]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.64it/s, accuracy=0.296, cost=3.98]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 10, training avg loss 4.086490, training avg acc 0.261151
epoch 10, testing avg loss 4.492503, testing avg acc 0.241737


minibatch loop: 100%|██████████| 1563/1563 [06:46<00:00,  3.85it/s, accuracy=0.284, cost=3.8] 
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.66it/s, accuracy=0.301, cost=3.98]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 11, training avg loss 4.025933, training avg acc 0.266182
epoch 11, testing avg loss 4.490067, testing avg acc 0.243361


minibatch loop: 100%|██████████| 1563/1563 [06:46<00:00,  3.85it/s, accuracy=0.285, cost=3.73]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.65it/s, accuracy=0.301, cost=3.97]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 12, training avg loss 3.971081, training avg acc 0.270903
epoch 12, testing avg loss 4.493936, testing avg acc 0.244468


minibatch loop: 100%|██████████| 1563/1563 [06:46<00:00,  3.85it/s, accuracy=0.296, cost=3.66]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.67it/s, accuracy=0.296, cost=3.97]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 13, training avg loss 3.920249, training avg acc 0.275454
epoch 13, testing avg loss 4.502479, testing avg acc 0.245590


minibatch loop: 100%|██████████| 1563/1563 [06:46<00:00,  3.85it/s, accuracy=0.303, cost=3.6] 
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.69it/s, accuracy=0.285, cost=3.97]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 14, training avg loss 3.873923, training avg acc 0.279623
epoch 14, testing avg loss 4.512751, testing avg acc 0.245827


minibatch loop: 100%|██████████| 1563/1563 [06:46<00:00,  3.85it/s, accuracy=0.312, cost=3.54]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.65it/s, accuracy=0.285, cost=3.97]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 15, training avg loss 3.831019, training avg acc 0.283554
epoch 15, testing avg loss 4.526254, testing avg acc 0.245474


minibatch loop: 100%|██████████| 1563/1563 [06:46<00:00,  3.85it/s, accuracy=0.322, cost=3.48]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.68it/s, accuracy=0.28, cost=3.96] 
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 16, training avg loss 3.789867, training avg acc 0.287467
epoch 16, testing avg loss 4.540810, testing avg acc 0.245914


minibatch loop: 100%|██████████| 1563/1563 [07:19<00:00,  3.56it/s, accuracy=0.33, cost=3.42] 
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.74it/s, accuracy=0.285, cost=3.96]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 17, training avg loss 3.750946, training avg acc 0.291335
epoch 17, testing avg loss 4.559640, testing avg acc 0.246204


minibatch loop: 100%|██████████| 1563/1563 [07:18<00:00,  3.56it/s, accuracy=0.336, cost=3.36]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.64it/s, accuracy=0.274, cost=3.97]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 18, training avg loss 3.714676, training avg acc 0.294952
epoch 18, testing avg loss 4.579117, testing avg acc 0.245720


minibatch loop: 100%|██████████| 1563/1563 [06:46<00:00,  3.85it/s, accuracy=0.341, cost=3.33]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.64it/s, accuracy=0.29, cost=3.95] 
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 19, training avg loss 3.680817, training avg acc 0.298401
epoch 19, testing avg loss 4.608824, testing avg acc 0.246128


minibatch loop: 100%|██████████| 1563/1563 [06:46<00:00,  3.85it/s, accuracy=0.349, cost=3.26]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.65it/s, accuracy=0.29, cost=3.93] 

epoch 20, training avg loss 3.650501, training avg acc 0.301656
epoch 20, testing avg loss 4.604999, testing avg acc 0.246473





In [35]:
from tensor2tensor.utils import bleu_hook

In [36]:
results = []
for i in tqdm.tqdm(range(0, len(test_X), batch_size)):
    index = min(i + batch_size, len(test_X))
    batch_x = pad_sequences(test_X[i : index], padding='post')
    feed = {model.X: batch_x}
    p = sess.run(model.fast_result,feed_dict = feed)
    result = []
    for row in p:
        result.append([i for i in row if i > 3])
    results.extend(result)

100%|██████████| 40/40 [00:08<00:00,  4.62it/s]


In [37]:
rights = []
for r in test_Y:
    rights.append([i for i in r if i > 3])