In [1]:
import os
# Expose only GPU index 1 to this process; set here, before tensorflow is
# imported in the next cell, so the setting is in place when TF starts up.
os.environ.update(CUDA_VISIBLE_DEVICES='1')

In [2]:
import numpy as np
import tensorflow as tf
import json

In [3]:
# Load the BPE-encoded parallel corpus. JSON text is UTF-8 by specification,
# so request that explicitly instead of relying on the platform locale default.
with open('dataset-bpe.json', encoding='utf-8') as fopen:
    data = json.load(fopen)

In [4]:
# Unpack the four corpus splits. Each presumably holds one list of token ids
# per sentence (the EOS-append cell below treats target rows as lists).
train_X, train_Y, test_X, test_Y = (
    data[split] for split in ('train_X', 'train_Y', 'test_X', 'test_Y')
)

In [5]:
# Reserved token ids used by the decoders: GO starts decoding, EOS ends it.
GO, EOS = 1, 2
# Number of rows in the embedding table and width of the output projection.
vocab_size = 32000

In [6]:
# Terminate every target sequence with EOS so the decoder learns when to stop.
# Rebuild from the raw `data` splits rather than from train_Y/test_Y themselves,
# so re-running this cell never appends a second EOS.
train_Y = [target + [EOS] for target in data['train_Y']]
test_Y = [target + [EOS] for target in data['test_Y']]

In [20]:
from tensor2tensor.utils import beam_search

def pad_second_dim(x, desired_size):
    """Right-pad a rank-3 tensor with zeros along axis 1 up to `desired_size`.

    Args:
        x: rank-3 tensor (axes 0 and 2 are kept unchanged); any dtype.
        desired_size: target length of axis 1; must be >= tf.shape(x)[1].

    Returns:
        A tensor with the same dtype as `x` whose axis 1 has length `desired_size`.
    """
    # tf.pad fills with zeros of x's own dtype. Tiling a float32 literal and
    # concatenating (the previous approach) only worked for float32 inputs.
    pad_len = desired_size - tf.shape(x)[1]
    return tf.pad(x, [[0, 0], [0, pad_len], [0, 0]])

class Translator:
    """Encoder-decoder translation model built from stacked BasicRNNCell layers (TF1 graph mode).

    Builds, in the current default graph:
      * placeholders `X` (source ids) and `Y` (target ids), int32 [batch, time], 0-padded;
      * a teacher-forced training decoder exposing `training_logits`, `cost`,
        `optimizer`, `prediction` and `accuracy`;
      * a greedy inference decoder exposing `fast_result`, which stops at EOS or
        after 2 * (longest source length) steps.

    Reads the module-level constants `vocab_size`, `GO` and `EOS`.
    """

    def __init__(self, size_layer, num_layers, embedded_size, learning_rate):
        """Build the graph.

        Args:
            size_layer: number of units in each BasicRNNCell.
            num_layers: number of stacked cells in the encoder and in the decoder.
            embedded_size: embedding dimension (one table shared by encoder and decoder).
            learning_rate: Adam learning rate.
        """

        def cells(reuse=False):
            return tf.nn.rnn_cell.BasicRNNCell(size_layer,reuse=reuse)

        self.X = tf.placeholder(tf.int32, [None, None])
        self.Y = tf.placeholder(tf.int32, [None, None])

        # Lengths come from the 0-padding: every non-zero id counts as a real token.
        self.X_seq_len = tf.count_nonzero(self.X, 1, dtype = tf.int32)
        self.Y_seq_len = tf.count_nonzero(self.Y, 1, dtype = tf.int32)
        batch_size = tf.shape(self.X)[0]

        # A single embedding table is looked up for source and target ids alike.
        embeddings = tf.Variable(tf.random_uniform([vocab_size, embedded_size], -1, 1))

        _, encoder_state = tf.nn.dynamic_rnn(
            cell = tf.nn.rnn_cell.MultiRNNCell([cells() for _ in range(num_layers)]),
            inputs = tf.nn.embedding_lookup(embeddings, self.X),
            sequence_length = self.X_seq_len,
            dtype = tf.float32)
        # Teacher forcing: the decoder input is GO followed by Y minus its last column.
        main = tf.strided_slice(self.Y, [0, 0], [batch_size, -1], [1, 1])
        decoder_input = tf.concat([tf.fill([batch_size, 1], GO), main], 1)
        dense = tf.layers.Dense(vocab_size)
        # The same cell stack and `dense` projection objects are handed to both
        # the training and the greedy decoder below.
        decoder_cells = tf.nn.rnn_cell.MultiRNNCell([cells() for _ in range(num_layers)])

        training_helper = tf.contrib.seq2seq.TrainingHelper(
                inputs = tf.nn.embedding_lookup(embeddings, decoder_input),
                sequence_length = self.Y_seq_len,
                time_major = False)
        training_decoder = tf.contrib.seq2seq.BasicDecoder(
                cell = decoder_cells,
                helper = training_helper,
                initial_state = encoder_state,
                output_layer = dense)
        training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder = training_decoder,
                impute_finished = True,
                maximum_iterations = tf.reduce_max(self.Y_seq_len))
        self.training_logits = training_decoder_output.rnn_output

        predicting_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                embedding = embeddings,
                start_tokens = tf.tile(tf.constant([GO], dtype=tf.int32), [batch_size]),
                end_token = EOS)
        predicting_decoder = tf.contrib.seq2seq.BasicDecoder(
                cell = decoder_cells,
                helper = predicting_helper,
                initial_state = encoder_state,
                output_layer = dense)
        predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder = predicting_decoder,
                impute_finished = True,
                maximum_iterations = 2 * tf.reduce_max(self.X_seq_len))
        # Greedy token ids; with impute_finished=True, steps after EOS are emitted as 0.
        self.fast_result = predicting_decoder_output.sample_id

        # NOTE(review): assumes Y has no padding columns beyond max(Y_seq_len), so
        # that training_logits and Y line up in time. This holds when each batch is
        # padded to its longest row.
        masks = tf.sequence_mask(self.Y_seq_len, tf.reduce_max(self.Y_seq_len), dtype=tf.float32)
        self.cost = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits,
                                                     targets = self.Y,
                                                     weights = masks)
        self.optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate).minimize(self.cost)
        y_t = tf.argmax(self.training_logits,axis=2)
        y_t = tf.cast(y_t, tf.int32)
        # tf.boolean_mask is documented to take a boolean mask; the float weights
        # above are kept for sequence_loss.
        bool_masks = tf.cast(masks, tf.bool)
        self.prediction = tf.boolean_mask(y_t, bool_masks)
        mask_label = tf.boolean_mask(self.Y, bool_masks)
        correct_pred = tf.equal(self.prediction, mask_label)
        correct_index = tf.cast(correct_pred, tf.float32)
        # Token-level accuracy over non-padding target positions (teacher-forced).
        self.accuracy = tf.reduce_mean(correct_index)

In [21]:
# Model / optimisation hyper-parameters.
size_layer, num_layers, embedded_size = 512, 2, 256  # RNN width, stacked layers, embedding dim
learning_rate = 1e-3                                  # Adam step size
batch_size, epoch = 128, 20                           # minibatch size, number of epochs

In [22]:
# Start from an empty graph so re-running this cell does not stack a second model.
tf.reset_default_graph()
# Graph-level seed, set on the fresh graph before any variable is created, so
# weight initialisation is reproducible (GPU kernels may still be nondeterministic).
SEED = 42
tf.set_random_seed(SEED)
sess = tf.InteractiveSession()
model = Translator(size_layer, num_layers, embedded_size, learning_rate)
sess.run(tf.global_variables_initializer())



In [23]:
# Alias Keras' pad_sequences. Later cells call it with padding='post', so rows are
# right-padded with the default pad value (0). Translator's count_nonzero-based
# sequence lengths rely on that.
pad_sequences = tf.keras.preprocessing.sequence.pad_sequences

In [24]:
# Smoke test on the first 10 training pairs: run greedy decoding, loss and
# accuracy once to check that the graph executes end to end.
batch_x, batch_y = (pad_sequences(seqs[:10], padding='post') for seqs in (train_X, train_Y))

sess.run(
    [model.fast_result, model.cost, model.accuracy],
    feed_dict={model.X: batch_x, model.Y: batch_y},
)

[array([[ 8092,   364, 23527, 16731, 12432, 24937, 21081, 21142, 14804,
         18234, 25024, 14908, 27933, 23633,  8088, 26961, 10696,  8350,
          2336,  1560,  9475,  7028,  1952,  9737, 27888, 27603, 21554,
          5376, 30761, 24453, 14154, 25407, 13988, 11466, 27134, 17576,
         10293,   435, 28450, 28138, 31434,  2669,  9231, 21043, 29167,
          3865, 12123, 26151, 22312, 20040, 16020,  7213, 14383,  6306,
         17745, 20872, 21499, 20713, 27365,  4323, 30281,   647, 23627,
         10903, 12831, 24343,   869, 27604, 15795, 15732, 28700,  7564],
        [13267, 25553,   398, 27889,  9162, 18628, 28777,  7077, 25680,
         10341, 31146, 16193, 15781,  5684, 20317,  1807,  3450,  9648,
         19283, 18304, 14013, 11349,  8453,  3195,   463, 20584,  8176,
         19834, 13095, 24631, 26027,  5502,  2789, 13208, 19939, 28314,
          7387, 16891, 23451, 19947, 17939,  7440,  4705,  5147, 29115,
          9981, 21912,  9427, 19677, 30367, 11677, 18783, 29427

In [25]:
import tqdm

def run_minibatches(X, Y, train):
    """Make one pass over (X, Y) in minibatches of `batch_size`.

    Each minibatch is right-padded with `pad_sequences` and fed to `model`.

    Args:
        X: source sequences (token-id lists).
        Y: target sequences (token-id lists), aligned with X.
        train: if True, also run `model.optimizer` so every batch updates the
            weights; otherwise only evaluate cost and accuracy.

    Returns:
        (losses, accuracies): per-batch values of `model.cost` and `model.accuracy`.
    """
    fetches = [model.accuracy, model.cost]
    if train:
        fetches.append(model.optimizer)
    losses, accuracies = [], []
    pbar = tqdm.tqdm(range(0, len(X), batch_size), desc = 'minibatch loop')
    for start in pbar:
        stop = start + batch_size  # slicing clamps the final, partial batch
        feed = {model.X: pad_sequences(X[start:stop], padding='post'),
                model.Y: pad_sequences(Y[start:stop], padding='post')}
        # The optimizer's result (train mode only) is discarded.
        accuracy, loss = sess.run(fetches, feed_dict = feed)[:2]
        losses.append(loss)
        accuracies.append(accuracy)
        pbar.set_postfix(cost = loss, accuracy = accuracy)
    return losses, accuracies


for e in range(epoch):
    train_loss, train_acc = run_minibatches(train_X, train_Y, train = True)
    test_loss, test_acc = run_minibatches(test_X, test_Y, train = False)

    print('epoch %d, training avg loss %f, training avg acc %f'%(e+1,
                                                                 np.mean(train_loss),np.mean(train_acc)))
    print('epoch %d, testing avg loss %f, testing avg acc %f'%(e+1,
                                                              np.mean(test_loss),np.mean(test_acc)))

minibatch loop: 100%|██████████| 1563/1563 [06:34<00:00,  3.96it/s, accuracy=0.149, cost=5.87]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  6.99it/s, accuracy=0.161, cost=5.22]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 1, training avg loss 6.435658, training avg acc 0.124611
epoch 1, testing avg loss 5.722437, testing avg acc 0.162499


minibatch loop: 100%|██████████| 1563/1563 [06:33<00:00,  3.98it/s, accuracy=0.182, cost=5.46]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.30it/s, accuracy=0.167, cost=4.85]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 2, training avg loss 5.501996, training avg acc 0.174138
epoch 2, testing avg loss 5.348543, testing avg acc 0.181107


minibatch loop: 100%|██████████| 1563/1563 [06:53<00:00,  3.78it/s, accuracy=0.183, cost=5.25]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.36it/s, accuracy=0.226, cost=4.72]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 3, training avg loss 5.225808, training avg acc 0.188166
epoch 3, testing avg loss 5.185829, testing avg acc 0.189879


minibatch loop: 100%|██████████| 1563/1563 [06:55<00:00,  3.77it/s, accuracy=0.197, cost=5.06]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  6.92it/s, accuracy=0.231, cost=4.58]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 4, training avg loss 5.058506, training avg acc 0.198194
epoch 4, testing avg loss 5.081503, testing avg acc 0.196356


minibatch loop: 100%|██████████| 1563/1563 [06:40<00:00,  3.90it/s, accuracy=0.208, cost=4.92]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  6.87it/s, accuracy=0.22, cost=4.57] 
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 5, training avg loss 4.939404, training avg acc 0.206569
epoch 5, testing avg loss 4.986885, testing avg acc 0.204661


minibatch loop: 100%|██████████| 1563/1563 [06:40<00:00,  3.90it/s, accuracy=0.226, cost=4.82]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  6.91it/s, accuracy=0.237, cost=4.57]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 6, training avg loss 4.844442, training avg acc 0.213995
epoch 6, testing avg loss 4.951038, testing avg acc 0.208642


minibatch loop: 100%|██████████| 1563/1563 [06:40<00:00,  3.90it/s, accuracy=0.235, cost=4.74]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  6.89it/s, accuracy=0.22, cost=4.54] 
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 7, training avg loss 4.767762, training avg acc 0.220543
epoch 7, testing avg loss 4.901244, testing avg acc 0.212730


minibatch loop: 100%|██████████| 1563/1563 [06:40<00:00,  3.90it/s, accuracy=0.234, cost=4.67]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  6.89it/s, accuracy=0.226, cost=4.55]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 8, training avg loss 4.699906, training avg acc 0.225791
epoch 8, testing avg loss 4.878349, testing avg acc 0.215059


minibatch loop: 100%|██████████| 1563/1563 [06:40<00:00,  3.90it/s, accuracy=0.244, cost=4.61]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  6.94it/s, accuracy=0.231, cost=4.46]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 9, training avg loss 4.651851, training avg acc 0.229174
epoch 9, testing avg loss 4.843753, testing avg acc 0.218257


minibatch loop: 100%|██████████| 1563/1563 [06:40<00:00,  3.90it/s, accuracy=0.252, cost=4.53]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  6.91it/s, accuracy=0.242, cost=4.49]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 10, training avg loss 4.612467, training avg acc 0.232239
epoch 10, testing avg loss 4.829657, testing avg acc 0.220272


minibatch loop: 100%|██████████| 1563/1563 [06:40<00:00,  3.90it/s, accuracy=0.256, cost=4.48]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  6.91it/s, accuracy=0.237, cost=4.45]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 11, training avg loss 4.579409, training avg acc 0.234620
epoch 11, testing avg loss 4.824563, testing avg acc 0.221824


minibatch loop: 100%|██████████| 1563/1563 [06:41<00:00,  3.90it/s, accuracy=0.249, cost=4.46]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  6.91it/s, accuracy=0.242, cost=4.46]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 12, training avg loss 4.544057, training avg acc 0.237470
epoch 12, testing avg loss 4.818059, testing avg acc 0.222746


minibatch loop: 100%|██████████| 1563/1563 [06:40<00:00,  3.90it/s, accuracy=0.259, cost=4.37]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.29it/s, accuracy=0.247, cost=4.45]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 13, training avg loss 4.518851, training avg acc 0.239499
epoch 13, testing avg loss 4.801131, testing avg acc 0.225936


minibatch loop: 100%|██████████| 1563/1563 [06:35<00:00,  3.95it/s, accuracy=0.258, cost=4.39]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.12it/s, accuracy=0.242, cost=4.45]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 14, training avg loss 4.487974, training avg acc 0.242044
epoch 14, testing avg loss 4.815407, testing avg acc 0.222855


minibatch loop: 100%|██████████| 1563/1563 [06:37<00:00,  3.94it/s, accuracy=0.267, cost=4.35]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.15it/s, accuracy=0.231, cost=4.48]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 15, training avg loss 4.474396, training avg acc 0.243263
epoch 15, testing avg loss 4.852764, testing avg acc 0.220577


minibatch loop: 100%|██████████| 1563/1563 [06:57<00:00,  3.75it/s, accuracy=0.27, cost=4.28] 
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.11it/s, accuracy=0.253, cost=4.42]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 16, training avg loss 4.440702, training avg acc 0.246357
epoch 16, testing avg loss 4.798965, testing avg acc 0.225011


minibatch loop: 100%|██████████| 1563/1563 [06:36<00:00,  3.94it/s, accuracy=0.265, cost=4.25]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.11it/s, accuracy=0.237, cost=4.46]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 17, training avg loss 4.406374, training avg acc 0.249040
epoch 17, testing avg loss 4.798480, testing avg acc 0.225326


minibatch loop: 100%|██████████| 1563/1563 [06:36<00:00,  3.94it/s, accuracy=0.271, cost=4.27]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.20it/s, accuracy=0.253, cost=4.48]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 18, training avg loss 4.431464, training avg acc 0.246378
epoch 18, testing avg loss 4.807458, testing avg acc 0.226043


minibatch loop: 100%|██████████| 1563/1563 [06:57<00:00,  3.75it/s, accuracy=0.273, cost=4.19]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.12it/s, accuracy=0.226, cost=4.49]
minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]

epoch 19, training avg loss 4.376853, training avg acc 0.251078
epoch 19, testing avg loss 4.796037, testing avg acc 0.226825


minibatch loop: 100%|██████████| 1563/1563 [06:36<00:00,  3.94it/s, accuracy=0.253, cost=4.28]
minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.13it/s, accuracy=0.242, cost=4.35]

epoch 20, training avg loss 4.349658, training avg acc 0.253415
epoch 20, testing avg loss 4.827734, testing avg acc 0.224312





In [29]:
from tensor2tensor.utils import bleu_hook

# Ids 0..3 are reserved (0 = padding, GO = 1, EOS = 2; 3 presumably UNK — TODO
# confirm). They are excluded from both sides before scoring.
MAX_SPECIAL_ID = 3


def strip_special(tokens):
    """Return the token ids before the first EOS, with reserved ids removed."""
    tokens = list(tokens)
    if EOS in tokens:
        tokens = tokens[:tokens.index(EOS)]
    return [token for token in tokens if token > MAX_SPECIAL_ID]


# Greedy-decode the test set batch by batch.
results = []
for start in tqdm.tqdm(range(0, len(test_X), batch_size)):
    batch_x = pad_sequences(test_X[start:start + batch_size], padding='post')
    predicted = sess.run(model.fast_result, feed_dict = {model.X: batch_x})
    results.extend(strip_special(row) for row in predicted)

rights = [strip_special(target) for target in test_Y]

bleu_hook.compute_bleu(reference_corpus = rights,
                       translation_corpus = results)

100%|██████████| 40/40 [00:23<00:00,  1.68it/s]


0.005418866