In [1]:
import tensorflow as tf
import numpy as np
from tqdm import tqdm

In [2]:
maxlen = 20
max_vocab = 20000

In [3]:
word2idx = tf.keras.datasets.imdb.get_word_index()
word2idx = {k: (v + 4) for k, v in word2idx.items()}
word2idx['<PAD>'] = 0
word2idx['<START>'] = 1
word2idx['<UNK>'] = 2
word2idx['<END>'] = 3
idx2word = {i: w for w, i in word2idx.items()}

In [4]:
(train_X, _), (test_X, _) = tf.contrib.keras.datasets.imdb.load_data(num_words = max_vocab, index_from= 4)

In [5]:
X = np.concatenate([train_X, test_X])

In [6]:
X = np.concatenate((tf.keras.preprocessing.sequence.pad_sequences(
                            X, maxlen, truncating='post', padding='post'),
                        tf.keras.preprocessing.sequence.pad_sequences(
                            X, maxlen, truncating='pre', padding='post')))

In [7]:
Y_input = X[:]
Y_output = np.concatenate([X[:, 1:], np.full([X.shape[0], 1], word2idx['<END>'])], 1)
X = X[:, 1:]

In [8]:
X.shape, Y_input.shape, Y_output.shape

((100000, 19), (100000, 20), (100000, 20))

In [9]:
import decoder

In [10]:
class VAE:
    def __init__(self, size_layer, num_layers, embedded_size, dict_size, learning_rate,
                dropout = 0.8):
        
        self.X = tf.placeholder(tf.int32, [None, None])
        self.Y_input = tf.placeholder(tf.int32, [None, None])
        self.Y_output = tf.placeholder(tf.int32, [None, None])
        self.lambda_coeff = tf.placeholder(tf.float32, shape=())
        batch_size = tf.shape(self.X)[0]
        
        self.X_seq_len = tf.count_nonzero(self.X, 1, dtype=tf.int32)
        self.Y_seq_len = tf.count_nonzero(self.Y_input, 1, dtype=tf.int32)
        
        encoder_embeddings = tf.Variable(tf.random_uniform([dict_size, embedded_size], -1, 1))
        decoder_embeddings = tf.Variable(tf.random_uniform([dict_size, embedded_size], -1, 1))
        encoder_embedded = tf.nn.embedding_lookup(encoder_embeddings, self.X)
        decoder_embedded = tf.nn.embedding_lookup(decoder_embeddings, self.Y_input)
        
        main = tf.strided_slice(self.Y_input, [0, 0], [batch_size, -1], [1, 1])
        decoder_input = tf.concat([tf.fill([batch_size, 1], word2idx['<START>']), main], 1)
        
        for i in range(num_layers):
            with tf.variable_scope('encoder_%d'%(i)):
                cell_fw = tf.contrib.rnn.LayerNormBasicLSTMCell(size_layer)
                cell_fw = tf.contrib.rnn.DropoutWrapper(cell_fw, input_keep_prob=dropout)

                cell_bw = tf.contrib.rnn.LayerNormBasicLSTMCell(size_layer)
                cell_bw = tf.contrib.rnn.DropoutWrapper(cell_bw, input_keep_prob=dropout)
                
                self.enc_output, self.enc_state = tf.nn.bidirectional_dynamic_rnn(cell_fw,
                                                                                  cell_bw,
                                                                                  encoder_embedded,
                                                                                  self.X_seq_len,
                                                                                  dtype=tf.float32)
        
        self.encoder_state = tf.concat([self.enc_state[0][1], self.enc_state[1][1]], axis=-1)
        self.encoder_out = tf.concat(self.enc_output, 2)
        self.z_mean = tf.layers.dense(self.encoder_state, size_layer)
        self.z_log_sigma = tf.layers.dense(self.encoder_state, size_layer)
        
        epsilon = tf.random_normal(tf.shape(self.z_log_sigma))
        self.z_vector = self.z_mean + tf.exp(self.z_log_sigma)
        
        dense = tf.layers.Dense(dict_size)
        
        decoder_cells = []
        for i in range(num_layers):
            dec_cell = tf.contrib.rnn.LayerNormBasicLSTMCell(2 * size_layer)
            dec_cell = tf.contrib.rnn.DropoutWrapper(dec_cell, input_keep_prob=dropout)
            decoder_cells.append(dec_cell)
        
        self.decoder_cells = tf.nn.rnn_cell.MultiRNNCell(decoder_cells)
        
        attn_mech = tf.contrib.seq2seq.LuongAttention(2 * size_layer,
                                                      self.encoder_out,
                                                      memory_sequence_length=self.X_seq_len)
        
        attn_cell = tf.contrib.seq2seq.AttentionWrapper(self.decoder_cells, attn_mech, size_layer)
        self.init_state = attn_cell.zero_state(batch_size, tf.float32)
        
        training_helper = tf.contrib.seq2seq.TrainingHelper(inputs=decoder_embedded,
                                                            sequence_length=self.Y_seq_len,
                                                            time_major=False)
        
        training_decoder = decoder.BasicDecoder(attn_cell,
                                                training_helper,
                                                initial_state=self.init_state,
                                                latent_vector=self.z_vector,
                                                output_layer=dense)
        self.training_logits, _, _ = tf.contrib.seq2seq.dynamic_decode(training_decoder,
                                                                       output_time_major=False,
                                                                       impute_finished=True,
                                                                       maximum_iterations=
                                                                       tf.reduce_max(self.Y_seq_len))
        self.training_logits = self.training_logits.rnn_output
        
        start_tokens = tf.tile(tf.constant([word2idx['<START>']], dtype=tf.int32), 
                               [batch_size])
        inference_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(decoder_embeddings,
                                                                    start_tokens,
                                                                    word2idx['<END>'])
        inference_decoder = decoder.BasicDecoder(attn_cell,
                                                         inference_helper,
                                                         initial_state=self.init_state,
                                                         latent_vector=self.z_vector,
                                                         output_layer=dense)
        self.inference_logits, _, _ = tf.contrib.seq2seq.dynamic_decode(inference_decoder,
                                                                        output_time_major=False,
                                                                        impute_finished=True,
                                                                        maximum_iterations=
                                                                        tf.reduce_max(self.X_seq_len))
        self.logits = self.inference_logits.sample_id
        self.kl_loss = -0.5 * tf.reduce_sum(1.0 + 2 * self.z_log_sigma - self.z_mean ** 2 - 
                             tf.exp(2 * self.z_log_sigma), 1)
        self.kl_loss = tf.scalar_mul(self.lambda_coeff, self.kl_loss)
        
        masks = tf.sequence_mask(self.Y_seq_len, tf.reduce_max(self.Y_seq_len), dtype=tf.float32)
        self.cost = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits,
                                                     targets = self.Y_output,
                                                     weights = masks)
        self.cost = tf.reduce_sum(self.cost + self.kl_loss)
        self.optimizer = tf.train.AdamOptimizer(learning_rate).minimize(self.cost)

In [11]:
size_layer = 128
num_layers = 2
embedded_size = 128
learning_rate = 1e-3
batch_size = 8
epoch = 7

In [12]:
tf.reset_default_graph()
sess = tf.InteractiveSession()
model = VAE(size_layer, num_layers, embedded_size, len(word2idx), learning_rate)
sess.run(tf.global_variables_initializer())

In [13]:
def word_dropout(x):
    is_dropped = np.random.binomial(1, 0.8, x.shape)
    fn = np.vectorize(lambda x, k: word2idx['<UNK>'] if (
                      k and (x not in range(4))) else x)
    return fn(x, is_dropped)

def inf_inp(test_strs):
    x = [[word2idx.get(w, 2) for w in s.split()] for s in test_strs]
    x = tf.keras.preprocessing.sequence.pad_sequences(
        x, maxlen, truncating='post', padding='post')
    return x

In [14]:
test_strings = ['i love this film and i think it is one of the best films',
             'this movie is a waste of time and there is no point to watch it']

In [15]:
y_input = word_dropout(Y_input[:2])
y_output = Y_output[:2]
x = X[:2]

In [16]:
sess.run([model.cost, model.kl_loss],
         feed_dict = {model.X: x, model.Y_input: y_input,
                      model.Y_output: y_output,
                      model.lambda_coeff: 0})

[22.774961, array([0., 0.], dtype=float32)]

In [17]:
r_aug = sess.run(model.logits, feed_dict = {model.X: inf_inp(test_strings)})[0]
' '.join([idx2word[r] for r in r_aug])

"tacones lois's krista's krista's onto onto krista's krista's krista's d'orleans' rettig krista's krista's rettig rettig"

In [18]:
epoch = 10
batch_size = 32

In [19]:
iter_i = 0
lambda_val = 0.0

In [20]:
for e in range(epoch):
    pbar = tqdm(
        range(0, len(X), batch_size), desc = 'minibatch loop')
    cost = 0
    for i in pbar:
        iter_i += 1
        if iter_i <= 3000:
            lambda_val = np.round((np.tanh((iter_i - 4500) / 1000) + 1) / 2, decimals=6)
            
        index = min(i + batch_size, len(X))
        y_input = word_dropout(Y_input[i: index])
        y_output = Y_output[i: index]
        x = X[i: index]
        c, _ = sess.run([model.cost, model.optimizer],
         feed_dict = {model.X: x, model.Y_input: y_input,
                      model.Y_output: y_output,
                      model.lambda_coeff: lambda_val})
        cost += c
        pbar.set_postfix(cost = c)
    cost /= (len(X) / batch_size)
    r_aug = sess.run(model.logits, feed_dict = {model.X: inf_inp(test_strings)})[0]
    print('epoch %d, average loss %f'%(e + 1, cost))
    print('real string: %s'%(test_strings[0]))
    print('augmented string: %s'%(' '.join([idx2word[r] for r in r_aug])))

minibatch loop: 100%|██████████| 3125/3125 [19:34<00:00,  2.63it/s, cost=3.85]
minibatch loop:   0%|          | 0/3125 [00:00<?, ?it/s]

epoch 1, average loss 53.976270
real string: i love this film and i think it is one of the best films
augmented string: i love this film and and i think it is one of the the best


minibatch loop: 100%|██████████| 3125/3125 [19:18<00:00,  2.74it/s, cost=0.855]
minibatch loop:   0%|          | 0/3125 [00:00<?, ?it/s]

epoch 2, average loss 4.715931
real string: i love this film and i think it is one of the best films
augmented string: i love this film film i i think it is one of the the best


minibatch loop: 100%|██████████| 3125/3125 [19:15<00:00,  2.71it/s, cost=0.735]
minibatch loop:   0%|          | 0/3125 [00:00<?, ?it/s]

epoch 3, average loss 2.217968
real string: i love this film and i think it is one of the best films
augmented string: i love this film and i think think it is one of the best best


minibatch loop: 100%|██████████| 3125/3125 [19:15<00:00,  2.72it/s, cost=0.288] 
minibatch loop:   0%|          | 0/3125 [00:00<?, ?it/s]

epoch 4, average loss 1.686360
real string: i love this film and i think it is one of the best films
augmented string: i love this film and i think think it is one of the best best


minibatch loop: 100%|██████████| 3125/3125 [19:15<00:00,  2.67it/s, cost=0.146] 
minibatch loop:   0%|          | 0/3125 [00:00<?, ?it/s]

epoch 5, average loss 0.854036
real string: i love this film and i think it is one of the best films
augmented string: i love this film and i think it is one one of the best best


minibatch loop: 100%|██████████| 3125/3125 [19:15<00:00,  2.63it/s, cost=0.933] 
minibatch loop:   0%|          | 0/3125 [00:00<?, ?it/s]

epoch 6, average loss 2.273920
real string: i love this film and i think it is one of the best films
augmented string: i love this film film and i think it is one of of the the


minibatch loop: 100%|██████████| 3125/3125 [19:15<00:00,  2.71it/s, cost=0.0989]
minibatch loop:   0%|          | 0/3125 [00:00<?, ?it/s]

epoch 7, average loss 0.978381
real string: i love this film and i think it is one of the best films
augmented string: i love this film and i think it is one one of the the best


minibatch loop: 100%|██████████| 3125/3125 [19:15<00:00,  2.70it/s, cost=0.0298]
minibatch loop:   0%|          | 0/3125 [00:00<?, ?it/s]

epoch 8, average loss 0.279438
real string: i love this film and i think it is one of the best films
augmented string: i love this film and i think it is one one of the the best


minibatch loop: 100%|██████████| 3125/3125 [19:15<00:00,  2.71it/s, cost=0.0195]
minibatch loop:   0%|          | 0/3125 [00:00<?, ?it/s]

epoch 9, average loss 1.825285
real string: i love this film and i think it is one of the best films
augmented string: i love this film and i think it is one of of the the best


minibatch loop: 100%|██████████| 3125/3125 [19:15<00:00,  2.70it/s, cost=0.0204]

epoch 10, average loss 0.386087
real string: i love this film and i think it is one of the best films
augmented string: i love this film and i think it is one of of the <PAD> best





In [25]:
r_aug = sess.run(model.logits, feed_dict = {model.X: inf_inp(test_strings)})[0]
print('augmented string: %s'%(' '.join([idx2word[r] for r in r_aug])))

augmented string: i love this film and i think it is one of of the <PAD> best
