In [1]:
import os
# Restrict TensorFlow to a single GPU. setdefault keeps a value that was
# already exported in the environment (e.g. `CUDA_VISIBLE_DEVICES=0 jupyter ...`)
# and only falls back to GPU 2 when none was given. This has to be set before
# TensorFlow initialises CUDA (i.e. before any session is created).
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '2')

In [2]:
import numpy as np
import tensorflow as tf
import json

In [3]:
# Load the pre-split BPE dataset. JSON text is UTF-8 by specification, so pass
# the encoding explicitly instead of relying on the platform/locale default.
with open('dataset-bpe.json', encoding='utf-8') as fopen:
    data = json.load(fopen)

In [4]:
# Pull the four splits out of the loaded JSON
# (presumably lists of token-id lists — confirm against the dataset builder).
train_X, train_Y, test_X, test_Y = (
    data[split] for split in ('train_X', 'train_Y', 'test_X', 'test_Y')
)

In [5]:
# Special token ids fed to / expected from the decoder.
GO, EOS = 1, 2
# Rows of the embedding table and width of the decoder's output projection.
vocab_size = 32000

In [6]:
# Terminate every target sequence with EOS so the decoder learns to stop.
# Rebuilt from the raw `data` splits (not from train_Y/test_Y themselves) so
# that re-running this cell does not append a second EOS.
train_Y = [seq + [EOS] for seq in data['train_Y']]
test_Y = [seq + [EOS] for seq in data['test_Y']]

In [16]:
from tensor2tensor.utils import beam_search

def pad_second_dim(x, desired_size):
    """Zero-pad a rank-3 tensor along axis 1 up to ``desired_size``.

    Args:
        x: rank-3 tensor; axes 0 and 2 are kept unchanged.
        desired_size: target length of axis 1 (Python int or scalar int
            tensor); must be >= ``tf.shape(x)[1]``.

    Returns:
        ``x`` concatenated with zeros along axis 1, so axis 1 has length
        ``desired_size``. The result has the same dtype as ``x``.
    """
    shape = tf.shape(x)
    # Build the zeros in x's own dtype; tf.concat requires matching dtypes,
    # so a float32-only pad would fail for e.g. int or float16 inputs.
    padding = tf.zeros(
        tf.stack([shape[0], desired_size - shape[1], shape[2]], 0),
        dtype = x.dtype)
    return tf.concat([x, padding], 1)

class Translator:
    """Bidirectional-RNN encoder / RNN decoder seq2seq model (TF1 graph mode).

    Builds, in the current default graph:
      * ``X`` / ``Y``: int32 placeholders of shape ``[batch, time]`` for source
        and target ids. Lengths are counted as non-zero ids, so 0 acts as padding.
      * ``training_logits``: teacher-forced decoder outputs.
      * ``fast_result``: greedily decoded ids, up to ``2 * max(X_seq_len)`` steps.
      * ``cost``: masked sequence cross-entropy.
      * ``optimizer``: one Adam step on ``cost``.
      * ``prediction`` / ``accuracy``: argmax tokens and token accuracy over
        non-padding target positions.

    Args:
        size_layer: hidden units per decoder cell. Each encoder direction uses
            ``size_layer // 2`` units, so the concatenated encoder state is
            ``2 * (size_layer // 2)`` wide. It matches the decoder only when
            ``size_layer`` is even.
        num_layers: number of stacked encoder bi-RNN layers and decoder cells;
            must be >= 1.
        embedded_size: width of the single embedding table shared by encoder
            input, decoder input and greedy decoding.
        learning_rate: Adam learning rate.

    Reads the module-level ``vocab_size``, ``GO`` and ``EOS``.
    """

    def __init__(self, size_layer, num_layers, embedded_size, learning_rate):
        if num_layers < 1:
            # Without at least one encoder layer there is no encoder state.
            raise ValueError('num_layers must be >= 1')

        def cells(size_layer, reuse=False):
            return tf.nn.rnn_cell.BasicRNNCell(size_layer,reuse=reuse)

        self.X = tf.placeholder(tf.int32, [None, None])
        self.Y = tf.placeholder(tf.int32, [None, None])

        # Sequence length = number of non-zero ids per row (0 is padding).
        self.X_seq_len = tf.count_nonzero(self.X, 1, dtype = tf.int32)
        self.Y_seq_len = tf.count_nonzero(self.Y, 1, dtype = tf.int32)
        batch_size = tf.shape(self.X)[0]

        embeddings = tf.Variable(tf.random_uniform([vocab_size, embedded_size], -1, 1))
        encoder_embedded = tf.nn.embedding_lookup(embeddings, self.X)

        # Stack bidirectional RNN layers; each layer's fw/bw outputs are
        # concatenated on the feature axis and fed to the next layer.
        for n in range(num_layers):
            (out_fw, out_bw), (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(
                cell_fw = cells(size_layer // 2),
                cell_bw = cells(size_layer // 2),
                inputs = encoder_embedded,
                sequence_length = self.X_seq_len,
                dtype = tf.float32,
                scope = 'bidirectional_rnn_%d'%(n))
            encoder_embedded = tf.concat((out_fw, out_bw), 2)

        # Only the LAST encoder layer's final fw/bw states are used, and that
        # same state initialises every decoder layer.
        bi_state = tf.concat((state_fw, state_bw), -1)
        encoder_state = tuple([bi_state] * num_layers)

        # Teacher forcing: decoder input = GO followed by Y minus its last column.
        main = tf.strided_slice(self.Y, [0, 0], [batch_size, -1], [1, 1])
        decoder_input = tf.concat([tf.fill([batch_size, 1], GO), main], 1)
        dense = tf.layers.Dense(vocab_size)
        decoder_cells = tf.nn.rnn_cell.MultiRNNCell([cells(size_layer) for _ in range(num_layers)])

        training_helper = tf.contrib.seq2seq.TrainingHelper(
                inputs = tf.nn.embedding_lookup(embeddings, decoder_input),
                sequence_length = self.Y_seq_len,
                time_major = False)
        training_decoder = tf.contrib.seq2seq.BasicDecoder(
                cell = decoder_cells,
                helper = training_helper,
                initial_state = encoder_state,
                output_layer = dense)
        training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder = training_decoder,
                impute_finished = True,
                maximum_iterations = tf.reduce_max(self.Y_seq_len))
        self.training_logits = training_decoder_output.rnn_output

        # Inference path: same cells, embedding table and output projection,
        # but feeding back the argmax token at each step until EOS.
        predicting_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                embedding = embeddings,
                start_tokens = tf.tile(tf.constant([GO], dtype=tf.int32), [batch_size]),
                end_token = EOS)
        predicting_decoder = tf.contrib.seq2seq.BasicDecoder(
                cell = decoder_cells,
                helper = predicting_helper,
                initial_state = encoder_state,
                output_layer = dense)
        predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder = predicting_decoder,
                impute_finished = True,
                maximum_iterations = 2 * tf.reduce_max(self.X_seq_len))
        self.fast_result = predicting_decoder_output.sample_id

        # NOTE(review): sequence_loss needs Y's width to equal the logits' time
        # axis (max(Y_seq_len)); this holds when each batch is post-padded to
        # its own longest sequence and contains no interior zeros.
        masks = tf.sequence_mask(self.Y_seq_len, tf.reduce_max(self.Y_seq_len), dtype=tf.float32)
        self.cost = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits,
                                                     targets = self.Y,
                                                     weights = masks)
        self.optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate).minimize(self.cost)
        y_t = tf.argmax(self.training_logits,axis=2)
        y_t = tf.cast(y_t, tf.int32)
        # tf.boolean_mask documents a boolean mask; `masks` is float32 because
        # it doubles as the loss weights, so cast it explicitly here.
        bool_masks = tf.cast(masks, tf.bool)
        self.prediction = tf.boolean_mask(y_t, bool_masks)
        mask_label = tf.boolean_mask(self.Y, bool_masks)
        correct_pred = tf.equal(self.prediction, mask_label)
        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

In [17]:
# Model / optimisation hyperparameters.
size_layer, num_layers = 512, 2   # decoder RNN width (encoder uses half per direction) and depth
embedded_size = 256               # token embedding width
learning_rate = 1e-3              # Adam step size
batch_size, epoch = 128, 20       # sequences per minibatch, passes over the training set

In [18]:
# Fixed graph-level seed so that variable initialisation (e.g. the
# random_uniform embedding table) is repeatable across Restart & Run All.
# GPU kernels may still introduce some nondeterminism.
SEED = 0

tf.reset_default_graph()
# The graph-level seed must be set on the fresh graph, before any op is created.
tf.set_random_seed(SEED)
sess = tf.InteractiveSession()
model = Translator(size_layer, num_layers, embedded_size, learning_rate)
sess.run(tf.global_variables_initializer())

Instructions for updating:
This class is equivalent as tf.keras.layers.StackedRNNCells, and will be replaced by that in Tensorflow 2.0.
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



In [19]:
# Short alias for Keras' list-of-sequences -> padded 2-D array helper.
_keras_sequence = tf.keras.preprocessing.sequence
pad_sequences = _keras_sequence.pad_sequences
del _keras_sequence

In [20]:
# Sanity-check the graph on the first 10 training pairs: greedy decode,
# loss and token accuracy. Run before training, so the values reflect the
# freshly initialised weights.
batch_x, batch_y = (
    pad_sequences(seqs[:10], padding='post') for seqs in (train_X, train_Y)
)
sess.run(
    [model.fast_result, model.cost, model.accuracy],
    feed_dict = dict([(model.X, batch_x), (model.Y, batch_y)]),
)

[array([[ 7041, 18531, 30711,  1739, 12146,  2835,  7798,  2890,  2698,
          4719, 18187,  4328, 13869, 31295,  9404, 11027,  2934, 14448,
         30491, 27088, 12032, 24876, 19813, 29033, 21136, 22400, 15029,
         12060, 25025, 25914, 27492, 24334, 11805,  9301,  7573, 19687,
          6376,  3552,  9032, 23418, 16544, 17757,  9092, 29078,  9458,
          5691, 19373, 30539, 27605, 19633, 25595, 22955, 23336, 30068,
          4279,  4003, 13061, 29064, 30123, 13727,  2609, 12456,  7578,
          6108, 17477, 10421, 12575, 18437, 25522, 25121, 22630, 21740],
        [17956,  3084,  1602, 11595, 18838, 11749, 19924, 10231,  1185,
         15923,  7877, 29854,  8062, 14342, 23727, 18701, 25000,  4564,
          1755, 13671,  5859, 30610, 29094, 13241, 17423,  6377,  3206,
          1816, 29632, 24088, 30386, 16604, 30927,  5336, 20420, 12963,
         16035,   826, 15211,    19, 11203,  8852, 29168, 15080,   548,
          9467, 27338, 26813, 23767, 26826, 23329,  5731, 12611

In [21]:
import tqdm

def _run_minibatches(X, Y, train):
    """Run the model over (X, Y) in chunks of the global ``batch_size``.

    Args:
        X: list of source id sequences.
        Y: list of target id sequences, parallel to ``X``. Each chunk of
            ``X`` and ``Y`` is post-padded with ``pad_sequences`` before feeding.
        train: if true, also run ``model.optimizer`` (one update per chunk);
            otherwise only evaluate ``model.accuracy`` and ``model.cost``.

    Returns:
        ``(losses, accuracies)``: per-chunk values, in chunk order.
    """
    losses, accuracies = [], []
    pbar = tqdm.tqdm(range(0, len(X), batch_size), desc = 'minibatch loop')
    for start in pbar:
        # Slicing past the end of a list is safe, so no explicit min() is needed.
        batch_x = pad_sequences(X[start : start + batch_size], padding='post')
        batch_y = pad_sequences(Y[start : start + batch_size], padding='post')
        feed = {model.X: batch_x, model.Y: batch_y}
        if train:
            accuracy, loss, _ = sess.run([model.accuracy, model.cost, model.optimizer],
                                         feed_dict = feed)
        else:
            accuracy, loss = sess.run([model.accuracy, model.cost],
                                      feed_dict = feed)
        losses.append(loss)
        accuracies.append(accuracy)
        pbar.set_postfix(cost = loss, accuracy = accuracy)
    return losses, accuracies


for e in range(epoch):
    train_loss, train_acc = _run_minibatches(train_X, train_Y, train = True)
    test_loss, test_acc = _run_minibatches(test_X, test_Y, train = False)

    # These are unweighted means of per-batch values: the final, smaller
    # batch counts as much as a full one.
    print('epoch %d, training avg loss %f, training avg acc %f'%(e+1,
                                                                 np.mean(train_loss),np.mean(train_acc)))
    print('epoch %d, testing avg loss %f, testing avg acc %f'%(e+1,
                                                              np.mean(test_loss),np.mean(test_acc)))

minibatch loop: 100%|██████████| 1563/1563 [08:26<00:00,  3.09it/s, accuracy=0.187, cost=5.31]
minibatch loop: 100%|██████████| 40/40 [00:06<00:00,  5.75it/s, accuracy=0.199, cost=4.67]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 1, training avg loss 6.094989, training avg acc 0.149891
epoch 1, testing avg loss 5.148298, testing avg acc 0.200459


minibatch loop: 100%|██████████| 1563/1563 [08:23<00:00,  3.10it/s, accuracy=0.234, cost=4.69]
minibatch loop: 100%|██████████| 40/40 [00:06<00:00,  5.84it/s, accuracy=0.226, cost=4.32]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 2, training avg loss 4.863626, training avg acc 0.221084
epoch 2, testing avg loss 4.732797, testing avg acc 0.230729


minibatch loop: 100%|██████████| 1563/1563 [08:28<00:00,  3.07it/s, accuracy=0.261, cost=4.32]
minibatch loop: 100%|██████████| 40/40 [00:06<00:00,  5.84it/s, accuracy=0.247, cost=4.13]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 3, training avg loss 4.487107, training avg acc 0.248027
epoch 3, testing avg loss 4.561285, testing avg acc 0.247583


minibatch loop: 100%|██████████| 1563/1563 [08:31<00:00,  3.06it/s, accuracy=0.277, cost=4.09]
minibatch loop: 100%|██████████| 40/40 [00:06<00:00,  5.75it/s, accuracy=0.247, cost=4.08]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 4, training avg loss 4.269172, training avg acc 0.264541
epoch 4, testing avg loss 4.518294, testing avg acc 0.253836


minibatch loop: 100%|██████████| 1563/1563 [08:33<00:00,  3.04it/s, accuracy=0.284, cost=3.92]
minibatch loop: 100%|██████████| 40/40 [00:06<00:00,  5.80it/s, accuracy=0.269, cost=4]   
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 5, training avg loss 4.123998, training avg acc 0.275203
epoch 5, testing avg loss 4.479871, testing avg acc 0.258289


minibatch loop: 100%|██████████| 1563/1563 [08:33<00:00,  3.04it/s, accuracy=0.294, cost=3.76]
minibatch loop: 100%|██████████| 40/40 [00:06<00:00,  5.82it/s, accuracy=0.269, cost=3.95]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 6, training avg loss 4.019932, training avg acc 0.283128
epoch 6, testing avg loss 4.476587, testing avg acc 0.260980


minibatch loop: 100%|██████████| 1563/1563 [08:34<00:00,  3.04it/s, accuracy=0.306, cost=3.65]
minibatch loop: 100%|██████████| 40/40 [00:06<00:00,  5.80it/s, accuracy=0.29, cost=3.92] 
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 7, training avg loss 3.936753, training avg acc 0.289535
epoch 7, testing avg loss 4.484711, testing avg acc 0.261932


minibatch loop: 100%|██████████| 1563/1563 [08:34<00:00,  3.04it/s, accuracy=0.319, cost=3.52]
minibatch loop: 100%|██████████| 40/40 [00:06<00:00,  5.77it/s, accuracy=0.274, cost=3.98]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 8, training avg loss 3.868007, training avg acc 0.294768
epoch 8, testing avg loss 4.497402, testing avg acc 0.261676


minibatch loop: 100%|██████████| 1563/1563 [08:31<00:00,  3.06it/s, accuracy=0.329, cost=3.42]
minibatch loop: 100%|██████████| 40/40 [00:06<00:00,  5.84it/s, accuracy=0.269, cost=3.97]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 9, training avg loss 3.813459, training avg acc 0.298996
epoch 9, testing avg loss 4.509614, testing avg acc 0.262880


minibatch loop: 100%|██████████| 1563/1563 [08:28<00:00,  3.07it/s, accuracy=0.329, cost=3.37]
minibatch loop: 100%|██████████| 40/40 [00:06<00:00,  5.86it/s, accuracy=0.263, cost=4]   
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 10, training avg loss 3.764877, training avg acc 0.302706
epoch 10, testing avg loss 4.537440, testing avg acc 0.261739


minibatch loop: 100%|██████████| 1563/1563 [08:29<00:00,  3.07it/s, accuracy=0.336, cost=3.3] 
minibatch loop: 100%|██████████| 40/40 [00:06<00:00,  5.84it/s, accuracy=0.29, cost=4.04] 
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 11, training avg loss 3.720549, training avg acc 0.306112
epoch 11, testing avg loss 4.539324, testing avg acc 0.263286


minibatch loop: 100%|██████████| 1563/1563 [08:29<00:00,  3.07it/s, accuracy=0.34, cost=3.21] 
minibatch loop: 100%|██████████| 40/40 [00:06<00:00,  5.87it/s, accuracy=0.312, cost=3.95]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 12, training avg loss 3.671926, training avg acc 0.310362
epoch 12, testing avg loss 4.551347, testing avg acc 0.264103


minibatch loop: 100%|██████████| 1563/1563 [08:32<00:00,  3.05it/s, accuracy=0.351, cost=3.22]
minibatch loop: 100%|██████████| 40/40 [00:06<00:00,  5.79it/s, accuracy=0.269, cost=4.14]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 13, training avg loss 3.639534, training avg acc 0.313105
epoch 13, testing avg loss 4.582617, testing avg acc 0.263138


minibatch loop: 100%|██████████| 1563/1563 [08:31<00:00,  3.06it/s, accuracy=0.349, cost=3.15]
minibatch loop: 100%|██████████| 40/40 [00:06<00:00,  5.82it/s, accuracy=0.274, cost=4.17]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 14, training avg loss 3.607992, training avg acc 0.315521
epoch 14, testing avg loss 4.609492, testing avg acc 0.262273


minibatch loop: 100%|██████████| 1563/1563 [08:30<00:00,  3.06it/s, accuracy=0.37, cost=3.1]  
minibatch loop: 100%|██████████| 40/40 [00:06<00:00,  5.85it/s, accuracy=0.263, cost=4.16]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 15, training avg loss 3.581495, training avg acc 0.317952
epoch 15, testing avg loss 4.626431, testing avg acc 0.262556


minibatch loop: 100%|██████████| 1563/1563 [08:31<00:00,  3.06it/s, accuracy=0.362, cost=3.13]
minibatch loop: 100%|██████████| 40/40 [00:06<00:00,  5.93it/s, accuracy=0.306, cost=4.03]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 16, training avg loss 3.552525, training avg acc 0.320681
epoch 16, testing avg loss 4.659593, testing avg acc 0.262065


minibatch loop: 100%|██████████| 1563/1563 [08:31<00:00,  3.06it/s, accuracy=0.36, cost=3.04] 
minibatch loop: 100%|██████████| 40/40 [00:06<00:00,  5.84it/s, accuracy=0.285, cost=4.21]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 17, training avg loss 3.536075, training avg acc 0.322026
epoch 17, testing avg loss 4.667953, testing avg acc 0.261820


minibatch loop: 100%|██████████| 1563/1563 [08:30<00:00,  3.06it/s, accuracy=0.362, cost=3.07]
minibatch loop: 100%|██████████| 40/40 [00:06<00:00,  5.80it/s, accuracy=0.312, cost=4.22]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 18, training avg loss 3.529873, training avg acc 0.322200
epoch 18, testing avg loss 4.683846, testing avg acc 0.261584


minibatch loop: 100%|██████████| 1563/1563 [08:30<00:00,  3.06it/s, accuracy=0.377, cost=3.01]
minibatch loop: 100%|██████████| 40/40 [00:06<00:00,  5.80it/s, accuracy=0.312, cost=4.23]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 19, training avg loss 3.527786, training avg acc 0.321770
epoch 19, testing avg loss 4.694051, testing avg acc 0.261892


minibatch loop: 100%|██████████| 1563/1563 [08:31<00:00,  3.06it/s, accuracy=0.377, cost=2.95]
minibatch loop: 100%|██████████| 40/40 [00:06<00:00,  5.83it/s, accuracy=0.317, cost=4.1] 

epoch 20, training avg loss 3.473880, training avg acc 0.327499
epoch 20, testing avg loss 4.705862, testing avg acc 0.263245





In [22]:
from tensor2tensor.utils import bleu_hook

# Greedy-decode the test set batch by batch, then score against the references.
# Only ids > 3 are kept on both sides. The low ids include padding (0), GO (1)
# and EOS (2); id 3 is presumably another special token (e.g. UNK) — confirm.
results = []
for start in tqdm.tqdm(range(0, len(test_X), batch_size)):
    stop = min(start + batch_size, len(test_X))
    padded = pad_sequences(test_X[start : stop], padding='post')
    decoded = sess.run(model.fast_result, feed_dict = {model.X: padded})
    results.extend([tok for tok in hyp if tok > 3] for hyp in decoded)

rights = [[tok for tok in ref if tok > 3] for ref in test_Y]

bleu_hook.compute_bleu(reference_corpus = rights,
                       translation_corpus = results)

100%|██████████| 40/40 [00:17<00:00,  2.24it/s]


0.019748569