## Sequence Autoencoder with an LSTM Encoder-Decoder

In [1]:
import tensorflow as tf
import numpy as np
import random

random.seed(1337)

class EncoderDecoder:
    """LSTM encoder-decoder trained to reproduce its own input sequence.

    Constructing an instance builds a TensorFlow 1.x graph (placeholders,
    encoder and decoder LSTMs, a softmax projection, a length-masked
    cross-entropy loss and a gradient-clipped Adam update), opens a session
    on that graph and initializes all variables.

    Tokens are fed to both LSTMs as one-hot vectors of width
    ``len(vocabulary)``; no embedding matrix is learned.
    """

    def __init__(self, vocabulary=None, state_size=64, n_max_length=30):
        """
        Parameters
        ----------
        vocabulary: dict, token -> integer id. Only its size is used by this
            class. ``None`` (the default) is treated as an empty dict.
        state_size: int, number of units in each LSTM cell
        n_max_length: int, fixed time dimension of every fed batch
        """
        self.state_size = state_size
        self.n_max_length = n_max_length
        # A literal ``{}`` default would be a single dict object shared by
        # every instance created without an explicit vocabulary.
        self.vocabulary = {} if vocabulary is None else vocabulary
        vocab_size = len(self.vocabulary)

        ######################
        # Graph Construction #
        ######################
        self.graph = tf.Graph()
        with self.graph.as_default():
            # Token ids, padded to n_max_length, and the true length of each row.
            self.sen_en = tf.placeholder(tf.int32, shape=(None, self.n_max_length), name="sen_en")
            self.sen_de = tf.placeholder(tf.int32, shape=(None, self.n_max_length), name="sen_de")
            self.sen_en_length = tf.placeholder(tf.int32, shape=(None,), name="sen_en_length")
            self.sen_de_length = tf.placeholder(tf.int32, shape=(None,), name="sen_de_length")

            batch_size = tf.shape(self.sen_en)[0]

            # TODO sen_en_embedding could also be self-trained embedding: embedding_lookup
            self.sen_en_embedding = tf.one_hot(self.sen_en, vocab_size)
            self.sen_de_embedding = tf.one_hot(self.sen_de, vocab_size)

            # build encoder decoder structure: both cells are created first,
            # then each is unrolled in a second pass over the same scope names
            with tf.variable_scope("encoder"):
                self.cell_en = tf.contrib.rnn.BasicLSTMCell(self.state_size)
            with tf.variable_scope("decoder"):
                self.cell_de = tf.contrib.rnn.BasicLSTMCell(self.state_size)
            with tf.variable_scope("encoder"):
                self.cell_en_init = self.cell_en.zero_state(batch_size, tf.float32)
                self.h_state_en, self.final_state_en = tf.nn.dynamic_rnn(
                    self.cell_en,
                    self.sen_en_embedding,
                    sequence_length=self.sen_en_length,
                    initial_state=self.cell_en_init,
                )
            with tf.variable_scope("decoder"):
                # The decoder starts from the encoder's final state.
                self.cell_de_init = self.final_state_en
                self.h_state_de, self.final_state_de = tf.nn.dynamic_rnn(
                    self.cell_de,
                    self.sen_de_embedding,
                    sequence_length=self.sen_de_length,
                    initial_state=self.cell_de_init,
                )

            with tf.variable_scope("softmax"):
                W = tf.get_variable(
                    "W", [self.state_size, vocab_size],
                    initializer=tf.random_normal_initializer(seed=None)
                )
                b = tf.get_variable(
                    "b", [vocab_size],
                    initializer=tf.random_normal_initializer(seed=None)
                )
            # Flatten (batch, time) so a single matmul projects every decoder
            # output onto the vocabulary, then restore the time dimension.
            self.logits = tf.reshape(
                tf.add(tf.matmul(tf.reshape(self.h_state_de, (-1, self.state_size)), W), b),
                shape=(-1, self.n_max_length, vocab_size)
            )
            self.prediction = tf.nn.softmax(self.logits)

            # construct loss and train op
            # The decoder output at step t is scored against encoder input token t.
            self.cross_ent = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=self.sen_en,
                logits=self.logits
            )
            self.mask = tf.sequence_mask(self.sen_de_length, maxlen=self.n_max_length)
            # Per sentence: sum the cross-entropy over positions inside the
            # decoder length and divide by that length; then average over the batch.
            self.loss = tf.reduce_mean(
                tf.divide(
                    tf.reduce_sum(
                        tf.where(
                            self.mask,
                            self.cross_ent,
                            tf.zeros_like(self.cross_ent)
                        ), 1
                    ),
                    tf.to_float(self.sen_de_length)
                )
            )

            # Calculate and clip gradients (global norm limited to 1)
            params = tf.trainable_variables()
            gradients = tf.gradients(self.loss, params)
            self.clipped_gradients, _ = tf.clip_by_global_norm(gradients, 1)
            # Optimization
            optimizer = tf.train.AdamOptimizer()
            self.op_train = optimizer.apply_gradients(zip(self.clipped_gradients, params))

            # initializer
            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.1)
            self.sess = tf.Session(
                graph=self.graph,
                config=tf.ConfigProto(gpu_options=gpu_options)
            )
            self.init = tf.global_variables_initializer()
            self.sess.run(self.init)

    def _feed_dict(self, batch_sen_en, batch_sen_de, batch_sen_en_length, batch_sen_de_length):
        """Check that both id batches have the same row count and map them to the placeholders."""
        assert batch_sen_en.shape[0] == batch_sen_de.shape[0]
        return {
            self.sen_en: batch_sen_en,
            self.sen_de: batch_sen_de,
            self.sen_en_length: batch_sen_en_length,
            self.sen_de_length: batch_sen_de_length,
        }

    def train(self, batch_sen_en, batch_sen_de, batch_sen_en_length, batch_sen_de_length):
        """Run one optimizer step on a batch.

        Parameters
        ----------
        batch_sen_en: numpy, shape=(n, max_length), dtype=int
        batch_sen_de: numpy, shape=(n, max_length), dtype=int
        batch_sen_en_length: numpy, shape=(n,), dtype=int
        batch_sen_de_length: numpy, shape=(n,), dtype=int

        Returns
        -------
        (loss, prediction, sen_en_embedding, mask, cross_ent, clipped_gradients),
        the evaluated values of the corresponding graph tensors for this batch.
        """
        _, loss, prediction, sen_en_embedding, mask, cross_ent, clipped_gradients = self.sess.run(
            [
                self.op_train,
                self.loss,
                self.prediction,
                self.sen_en_embedding,
                self.mask,
                self.cross_ent,
                self.clipped_gradients,
            ],
            feed_dict=self._feed_dict(
                batch_sen_en, batch_sen_de, batch_sen_en_length, batch_sen_de_length
            )
        )
        return loss, prediction, sen_en_embedding, mask, cross_ent, clipped_gradients

    def predict(self, batch_sen_en, batch_sen_de, batch_sen_en_length, batch_sen_de_length):
        """Evaluate loss and per-position softmax output without updating weights.

        Parameters
        ----------
        batch_sen_en: numpy, shape=(n, max_length), dtype=int
        batch_sen_de: numpy, shape=(n, max_length), dtype=int
        batch_sen_en_length: numpy, shape=(n,), dtype=int
        batch_sen_de_length: numpy, shape=(n,), dtype=int

        Returns
        -------
        (loss, prediction); prediction has one probability row per position,
        i.e. shape (n, max_length, len(vocabulary)).
        """
        loss, prediction = self.sess.run(
            [self.loss, self.prediction],
            feed_dict=self._feed_dict(
                batch_sen_en, batch_sen_de, batch_sen_en_length, batch_sen_de_length
            )
        )
        return loss, prediction

    def close(self):
        """Close the TensorFlow session opened by the constructor."""
        self.sess.close()

    
def evaluate(batch_sen_en, batch_sen_en_length, batch_prediction, vocabulary):
    """Measure word accuracy and sentence-end accuracy for one batch.

    Only the first ``batch_sen_en_length[i]`` positions of sentence ``i`` are
    scored; the predicted word at a position is the argmax of its row.

    Parameters
    ----------
    batch_sen_en: numpy, shape=(n, max_length), dtype=int
    batch_sen_en_length: numpy, shape=(n,), dtype=int
    batch_prediction: numpy, shape=(n, max_length, len(vocabulary))
    vocabulary: dict, must contain the key "</s>"

    Returns
    -------
    (word accuracy, sentence-end accuracy): correct positions divided by the
    sum of lengths, and the fraction of sentences whose first predicted
    "</s>" is also a correct prediction.
    """
    assert batch_sen_en.shape[0] == batch_prediction.shape[0]
    n_sentences = batch_sen_en.shape[0]
    word_hits = 0
    end_hits = 0
    for row in range(n_sentences):
        end_already_predicted = False
        for col in range(batch_sen_en_length[row]):
            guess = np.argmax(batch_prediction[row, col])
            guess_is_right = guess == batch_sen_en[row, col]
            guess_is_end = guess == vocabulary["</s>"]
            if guess_is_right:
                word_hits += 1
                # A sentence end only counts if no "</s>" was predicted earlier.
                if guess_is_end and not end_already_predicted:
                    end_hits += 1
            if guess_is_end:
                end_already_predicted = True
    word_accuracy = 1. * word_hits / np.sum(batch_sen_en_length)
    end_accuracy = 1. * end_hits / n_sentences
    return word_accuracy, end_accuracy

In [2]:
def generate_data(n, max_length, origin_vocabulary, extend_vocabulary):
    """Create ``n`` random padded sentences and their shifted decoder inputs.

    Uses the module-level ``random`` generator, so results are reproducible
    after ``random.seed``.

    Parameters
    ----------
    n: int, number of sentences to generate
    max_length: int, width of every returned row
    origin_vocabulary: dict whose keys are the ordinary tokens to sample from
    extend_vocabulary: dict, token -> id; must contain every ordinary token
        plus "<pad>", "<s>" and "</s>"

    Returns
    -------
    sen_en: numpy, shape=(n, max_length), dtype=int32
        random token ids, then "</s>", then "<pad>" up to max_length
    sen_de: numpy, shape=(n, max_length), dtype=int32
        "<s>", then the same random token ids, then "<pad>"
    sen_en_length: numpy, shape=(n,), dtype=int32, token count + 1
    sen_de_length: numpy, shape=(n,), dtype=int32, token count + 1
    """
    pad_id = extend_vocabulary["<pad>"]
    start_id = extend_vocabulary["<s>"]
    end_id = extend_vocabulary["</s>"]
    # Resolve the candidate ids once instead of rebuilding the token list for
    # every sampled word. Order follows origin_vocabulary, so a seeded run
    # draws exactly the same ids as sampling a token and looking it up.
    token_ids = [extend_vocabulary[token] for token in origin_vocabulary]

    sen_en = np.full((n, max_length), pad_id, dtype=np.int32)
    sen_de = np.full((n, max_length), pad_id, dtype=np.int32)
    sen_en_length = np.zeros((n,), dtype=np.int32)
    sen_de_length = np.zeros((n,), dtype=np.int32)

    for i in range(n):
        # At most max_length - 1 tokens, so the end marker always fits.
        n_tokens = random.randint(max_length // 2, max_length - 1)
        for j in range(n_tokens):
            sen_en[i, j] = random.choice(token_ids)
        sen_en[i, n_tokens] = end_id
        # Decoder input is the encoder sentence shifted right by one,
        # with the start marker in front.
        sen_de[i, 0] = start_id
        sen_de[i, 1:n_tokens + 1] = sen_en[i, :n_tokens]
        sen_en_length[i] = n_tokens + 1
        sen_de_length[i] = n_tokens + 1

    return sen_en, sen_de, sen_en_length, sen_de_length

def get_total_accuracy(data_sen_en, data_sen_de, data_sen_en_length, data_sen_de_length, extend_vocabulary):
    """Score the global ``pretrained_lstm`` over a whole dataset.

    The data is walked in consecutive slices of the global ``n_batch_size``
    rows. Each slice is run through ``pretrained_lstm.predict`` and scored
    with ``evaluate``; the per-slice ratios are weighted by their
    denominators so the result is the accuracy over all rows.

    Returns
    -------
    (word accuracy, sentence-end accuracy) over the whole dataset
    """
    word_hits, word_count = 0, 0
    end_hits, sentence_count = 0, 0
    for start in range(0, data_sen_en.shape[0], n_batch_size):
        rows = slice(start, start + n_batch_size)
        chunk_sen_en = data_sen_en[rows]
        chunk_sen_en_length = data_sen_en_length[rows]

        _, chunk_predictions = pretrained_lstm.predict(
            chunk_sen_en, data_sen_de[rows], chunk_sen_en_length, data_sen_de_length[rows]
        )
        word_acc, end_acc = evaluate(
            chunk_sen_en, chunk_sen_en_length, chunk_predictions, extend_vocabulary
        )

        # evaluate() returns ratios; multiply by the denominators to get counts back.
        chunk_words = np.sum(chunk_sen_en_length)
        chunk_sentences = chunk_sen_en.shape[0]
        word_hits += word_acc * chunk_words
        word_count += chunk_words
        end_hits += end_acc * chunk_sentences
        sentence_count += chunk_sentences
    return 1. * word_hits / word_count, 1. * end_hits / sentence_count
    
# hyperparameter
vocabulary_size = 200
# Ordinary tokens are the decimal strings "0" .. "199", each mapped to its own value.
origin_vocabulary = {str(token_id): token_id for token_id in range(vocabulary_size)}
# The special tokens take the next free ids after the ordinary ones.
extend_vocabulary = dict(origin_vocabulary)
extend_vocabulary.update(
    (special_token, vocabulary_size + offset)
    for offset, special_token in enumerate(["<pad>", "<unk>", "<s>", "</s>"])
)
state_size = 64
n_max_length = 30
n_batch_size = 100

# generate training/testing data
n_train = 60000
n_test = 10000
train_sen_en, train_sen_de, train_sen_en_length, train_sen_de_length = generate_data(
    n_train, n_max_length, origin_vocabulary, extend_vocabulary
)
test_sen_en, test_sen_de, test_sen_en_length, test_sen_de_length = generate_data(
    n_test, n_max_length, origin_vocabulary, extend_vocabulary
)
# Show one encoded sentence from each split as a sanity check.
print(train_sen_en[0])
print(test_sen_en[0])

pretrained_lstm = EncoderDecoder(
    vocabulary=extend_vocabulary,
    state_size=state_size,
    n_max_length=n_max_length,
)

# 20 passes over the training set in its stored order, one update per batch.
for epoch in range(20):
    for cur_idx in range(0, train_sen_en.shape[0], n_batch_size):
        batch_rows = slice(cur_idx, cur_idx + n_batch_size)
        batch_sen_en = train_sen_en[batch_rows]
        batch_sen_de = train_sen_de[batch_rows]
        batch_sen_en_length = train_sen_en_length[batch_rows]
        batch_sen_de_length = train_sen_de_length[batch_rows]

        loss, predictions, sen_en_embedding, mask, cross_ent, clipped_gradients = pretrained_lstm.train(
            batch_sen_en, batch_sen_de, batch_sen_en_length, batch_sen_de_length
        )
    # Progress report: accuracy on the final batch of this epoch only.
    print("epoch", epoch, "last batch", evaluate(batch_sen_en, batch_sen_en_length, predictions, extend_vocabulary))

# Final accuracy over the full training and test sets.
train_accuracy = get_total_accuracy(
    train_sen_en, train_sen_de, train_sen_en_length, train_sen_de_length, extend_vocabulary
)
print("train", train_accuracy)
test_accuracy = get_total_accuracy(
    test_sen_en, test_sen_de, test_sen_en_length, test_sen_de_length, extend_vocabulary
)
print("test", test_accuracy)

[109 142  72 196 164  88  36  51 150 147  49 115  19  34 112  48  44  68
  72   5  57  49  95 104 203 200 200 200 200 200]
[ 65 164 191  17  66 142  54  43 140 192 197 105 168 143 169  33 180  19
  52 196 193   6  54  14 183 203 200 200 200 200]
epoch 0 last batch (0.052609067579127457, 0.0)
epoch 1 last batch (0.064157399486740804, 0.0)
epoch 2 last batch (0.088109495295124032, 0.23)
epoch 3 last batch (0.094097519247219846, 0.24)
epoch 4 last batch (0.11291702309666382, 0.59)
epoch 5 last batch (0.12917023096663816, 0.76)
epoch 6 last batch (0.12745936698032506, 0.47)
epoch 7 last batch (0.14499572284003423, 0.71)
epoch 8 last batch (0.14414029084687768, 0.85)
epoch 9 last batch (0.16295979469632163, 0.66)
epoch 10 last batch (0.17108639863130881, 0.78)
epoch 11 last batch (0.16680923866552608, 0.77)
epoch 12 last batch (0.1775021385799829, 0.76)
epoch 13 last batch (0.18092386655260906, 0.65)
epoch 14 last batch (0.19289991445680069, 0.63)
epoch 15 last batch (0.18776732249786143, 0