<img src="vae.png" width="300">

In [1]:
from modified import ModifiedBasicDecoder, ModifiedBeamSearchDecoder

import tensorflow as tf
import numpy as np

In [2]:
# Notebook-wide hyperparameters; build_vocab() later adds the vocabulary
# tables ('word2idx', 'idx2word') to this same dict.
# NOTE(review): later cells also read PARAMS['rnn_size'], PARAMS['embed_dims']
# and PARAMS['latent_size'], which are not set here -- presumably they are
# defined in a cell outside this view; confirm before running top to bottom.
PARAMS = dict(
    max_len=15,             # fixed length every review is padded/truncated to
    word_dropout_rate=0.2,  # probability of replacing a decoder-input token with <unk>
    beam_width=5,           # number of beams used by the beam-search decoder
)

In [3]:
def build_vocab(index_from=4):
    """Build the IMDB vocabulary tables and store them in ``PARAMS``.

    Populates ``PARAMS['word2idx']`` (word -> id) and ``PARAMS['idx2word']``
    (id -> word). Every id of the Keras IMDB word index is shifted by
    ``index_from``; ids 0-3 are then assigned to the special tokens
    ``<pad>``, ``<start>``, ``<unk>`` and ``<end>``.

    Args:
        index_from: offset added to every id of the raw Keras word index.

    Returns:
        The word -> id dict (the same object stored in
        ``PARAMS['word2idx']``), so ``word2idx = build_vocab()`` binds the
        vocabulary instead of ``None``.
    """
    raw_index = tf.keras.datasets.imdb.get_word_index()
    word2idx = {word: idx + index_from for word, idx in raw_index.items()}
    word2idx['<pad>'] = 0
    word2idx['<start>'] = 1
    word2idx['<unk>'] = 2
    word2idx['<end>'] = 3

    idx2word = {idx: word for word, idx in word2idx.items()}
    # Task-specific exception handling: ids -1 and 4 are mapped to their own
    # string form so that looking them up never raises KeyError.
    # NOTE(review): presumably -1 comes from decoder output padding and 4 is
    # the id left unused by index_from=4 -- confirm. For other index_from
    # values the entry for 4 may overwrite a real word.
    idx2word[-1] = '-1'
    idx2word[4] = '4'

    PARAMS['word2idx'] = word2idx
    PARAMS['idx2word'] = idx2word
    return word2idx

    
def load_data(index_from=4):
    """Load the IMDB reviews as sequences of word ids (labels are discarded).

    Uses ``tf.keras`` -- the same namespace the rest of this notebook uses
    for the word index and for padding -- rather than the deprecated
    ``tf.contrib.keras`` alias.

    Args:
        index_from: offset applied to every word id; must match the value
            passed to ``build_vocab`` so ids and vocabulary agree.

    Returns:
        Tuple ``(X_train, X_test)`` of review sequences. ``num_words=None``
        keeps the full vocabulary.
    """
    (X_train, _), (X_test, _) = tf.keras.datasets.imdb.load_data(
        num_words=None, index_from=index_from)
    return (X_train, X_test)


def word_dropout(x, rate=None, unk_id=None):
    """Randomly replace ordinary tokens in ``x`` with the ``<unk>`` id.

    Every position is selected independently with probability ``rate``.
    Selected positions are replaced by ``unk_id`` unless they hold one of the
    reserved ids 0-3 (``<pad>``, ``<start>``, ``<unk>``, ``<end>``), which are
    never replaced.

    Args:
        x: integer array of token ids, any shape (may be empty).
        rate: replacement probability; defaults to
            ``PARAMS['word_dropout_rate']``.
        unk_id: replacement id; defaults to ``PARAMS['word2idx']['<unk>']``.

    Returns:
        A new array with the same shape and dtype as ``x``.
    """
    x = np.asarray(x)
    if rate is None:
        rate = PARAMS['word_dropout_rate']
    if unk_id is None:
        unk_id = PARAMS['word2idx']['<unk>']

    # One Bernoulli draw per position (same RNG call as before, so a seeded
    # run selects the same positions).
    is_dropped = np.random.binomial(1, rate, x.shape).astype(bool)
    is_special = (x >= 0) & (x <= 3)

    # np.where keeps the work in NumPy instead of a per-element Python
    # lambda, works on empty arrays, and always preserves x's dtype
    # (np.vectorize inferred the dtype from the first element's result).
    return np.where(is_dropped & ~is_special,
                    np.asarray(unk_id, dtype=x.dtype), x)

In [4]:
# Build the vocabulary tables. build_vocab() stores them in PARAMS, so the
# mapping is read back from there; this binds the real dict rather than
# depending on what build_vocab() returns.
build_vocab()
word2idx = PARAMS['word2idx']

# All reviews (train + test) as variable-length sequences of token ids.
X_raw = np.concatenate(load_data())

# Two fixed-length views of every review, stacked along the batch axis:
# the first max_len tokens (truncating='post') and the last max_len tokens
# (truncating='pre'). Shorter reviews are right-padded with 0 (<pad>).
# NOTE(review): rows from the 'pre' view of long reviews do not begin with
# <start>, yet their first column is dropped / used as decoder start below
# just the same -- confirm this is intended.
X = np.concatenate((
    tf.keras.preprocessing.sequence.pad_sequences(
        X_raw, PARAMS['max_len'], truncating='post', padding='post'),
    tf.keras.preprocessing.sequence.pad_sequences(
        X_raw, PARAMS['max_len'], truncating='pre', padding='post')))

# Encoder input: the sequence without its first column.
enc_inp = X[:, 1:]
# Decoder input: the full sequence with random word dropout applied.
dec_inp = word_dropout(X)
# Decoder target: the sequence shifted left by one, with an <end> column
# appended so it has the same width as dec_inp.
dec_out = np.concatenate([X[:, 1:], np.full([X.shape[0], 1],
                                            PARAMS['word2idx']['<end>'])], 1)

<start> shown in australia as 'hydrosphere' this incredibly bad movie is so bad that you

shown in australia as 'hydrosphere' this incredibly bad movie is so bad that you

<start> shown in <unk> as <unk> this incredibly bad movie is so bad that you

shown in australia as 'hydrosphere' this incredibly bad movie is so bad that you <end>


In [5]:
def reparam_trick(z_mean, z_logvar):
    """Sample ``z ~ N(z_mean, exp(z_logvar))`` via the reparameterization trick.

    Computes ``z = z_mean + exp(0.5 * z_logvar) * eps`` so that gradients
    flow through ``z_mean`` and ``z_logvar`` while the randomness lives in
    ``eps``.

    Args:
        z_mean: tensor holding the mean of the approximate posterior.
        z_logvar: tensor of the same shape holding its log-variance.

    Returns:
        A tensor of sampled latent codes with the shape of ``z_logvar``.
    """
    # eps must be standard normal: the closed-form KL term used with this
    # sampler assumes a Gaussian posterior. A truncated normal clips the
    # tails (variance < 1) and would not match that assumption.
    epsilon = tf.random_normal(tf.shape(z_logvar))
    return z_mean + tf.exp(0.5 * z_logvar) * epsilon


def kl_w_fn(anneal_max, anneal_bias, global_step):
    """Sigmoid annealing schedule for the KL-loss weight.

    Computes ``anneal_max * sigmoid((10 / anneal_bias) * (global_step -
    anneal_bias / 2))``: the weight rises from about 0 towards
    ``anneal_max`` and reaches ``anneal_max / 2`` at step ``anneal_bias / 2``.

    Args:
        anneal_max: upper bound of the weight.
        anneal_bias: Python number controlling the midpoint
            (``anneal_bias / 2``) and the steepness (``10 / anneal_bias``).
        global_step: integer step tensor.

    Returns:
        A float tensor holding the KL weight for the current step.
    """
    # tf.to_float is the TF 1.x cast helper (there is no tf.to_float32).
    step = tf.to_float(global_step)
    # Float literals keep the division exact even under Python 2 int inputs.
    slope = 10.0 / anneal_bias
    midpoint = anneal_bias / 2.0
    return anneal_max * tf.sigmoid(slope * (step - midpoint))


def kl_loss_fn(self, mean, gamma):
    """KL divergence between N(mean, exp(gamma)) and N(0, I), averaged per example.

    Computes ``0.5 * sum(exp(gamma) + mean**2 - 1 - gamma) / batch_size``,
    where the sum runs over every element and ``batch_size`` is the size of
    the first dimension of ``mean``.

    NOTE(review): ``self`` is never used and this is a module-level function,
    so callers must pass a placeholder first argument -- looks like a
    leftover from a class method; confirm before changing the signature.

    Args:
        self: unused.
        mean: tensor of posterior means.
        gamma: tensor of posterior log-variances, same shape as ``mean``.

    Returns:
        A scalar float tensor.
    """
    batch_size = tf.to_float(tf.shape(mean)[0])
    kl_terms = tf.exp(gamma) + tf.square(mean) - 1 - gamma
    return 0.5 * tf.reduce_sum(kl_terms) / batch_size

In [6]:
def rnn_cell():
    """Create a fresh GRU cell of size ``PARAMS['rnn_size']``.

    The cell's kernel is initialised with an orthogonal initializer.

    Returns:
        A new ``tf.nn.rnn_cell.GRUCell`` instance.
    """
    num_units = PARAMS['rnn_size']
    orthogonal_init = tf.orthogonal_initializer()
    return tf.nn.rnn_cell.GRUCell(num_units, kernel_initializer=orthogonal_init)


def forward(inputs, labels, reuse, is_training):
    """Build the encoder/decoder graph and return the decoder output.

    The encoder embeds ``inputs`` and runs a GRU over them; its final state
    feeds two dense layers producing ``z_mean`` and ``z_logvar``. The decoder
    is a GRU initialised from the encoder's final state and shares the
    encoder's embedding table.

    Args:
        inputs: integer tensor of encoder token ids, ``[batch, time]``;
            id 0 is treated as padding when computing sequence lengths.
        labels: decoder inputs for teacher forcing -- either a dict with a
            ``'dec_inp'`` entry or the id tensor itself. Only read when
            ``is_training`` is true, so it may be ``None`` at inference.
        reuse: forwarded to the ``'Encoder'`` / ``'Decoder'`` variable scopes.
        is_training: Python bool selecting teacher-forced training decoding
            or beam-search inference.

    Returns:
        Training: the decoder's ``rnn_output`` (per-step vocabulary logits
        after the output projection).
        Inference: ``predicted_ids[:, :, 0]``, i.e. beam 0 of the beam-search
        result for every batch entry.
    """
    batch_sz = tf.shape(inputs)[0]
    # id 0 is <pad>, so the number of non-zero ids per row is its length.
    enc_seq_len = tf.count_nonzero(inputs, 1, dtype=tf.int32)

    # Size the tables by the largest id, not by len(word2idx): ids need not
    # be contiguous (build_vocab special-cases id 4), so len() can be smaller
    # than max id + 1 and the highest id would index out of range.
    vocab_size = max(PARAMS['word2idx'].values()) + 1

    with tf.variable_scope('Encoder', reuse=reuse):
        embedding = tf.get_variable(
            'lookup_table', [vocab_size, PARAMS['embed_dims']])
        x = tf.nn.embedding_lookup(embedding, inputs)

        _, enc_state = tf.nn.dynamic_rnn(
            rnn_cell(), x, enc_seq_len, dtype=tf.float32)

        z_mean = tf.layers.dense(enc_state, PARAMS['latent_size'])
        z_logvar = tf.layers.dense(enc_state, PARAMS['latent_size'])

    # NOTE(review): z is sampled here but nothing below consumes it -- the
    # decoder is initialised from enc_state, and z_mean / z_logvar are not
    # returned, so no KL term can be built from this function's outputs.
    # The unused ModifiedBasicDecoder / ModifiedBeamSearchDecoder imports
    # suggest z was meant to be wired in through them -- TODO confirm.
    z = reparam_trick(z_mean, z_logvar)

    with tf.variable_scope('Decoder', reuse=reuse):
        output_proj = tf.layers.Dense(vocab_size)
        dec_cell = rnn_cell()

        if is_training:
            # Accept both forms of `labels` (dict with 'dec_inp' or the
            # tensor itself).
            dec_inp = labels['dec_inp'] if isinstance(labels, dict) else labels
            dec_seq_len = tf.count_nonzero(dec_inp, 1, dtype=tf.int32)

            helper = tf.contrib.seq2seq.TrainingHelper(
                inputs = tf.nn.embedding_lookup(embedding, dec_inp),
                sequence_length = dec_seq_len)
            decoder = tf.contrib.seq2seq.BasicDecoder(
                cell = dec_cell,
                helper = helper,
                # A GRU state is a plain tensor, so the encoder state is the
                # initial state directly (there is no state object to clone).
                initial_state = enc_state,
                output_layer = output_proj)
            decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder = decoder,
                maximum_iterations = tf.reduce_max(dec_seq_len))
            return decoder_output.rnn_output
        else:
            # Beam search expects the initial state replicated once per beam.
            beam_init_state = tf.contrib.seq2seq.tile_batch(
                enc_state, multiplier=PARAMS['beam_width'])
            decoder = tf.contrib.seq2seq.BeamSearchDecoder(
                cell = dec_cell,
                embedding = embedding,
                start_tokens = tf.tile(
                    tf.constant([PARAMS['word2idx']['<start>']], tf.int32),
                    [batch_sz]),
                end_token = PARAMS['word2idx']['<end>'],
                initial_state = beam_init_state,
                beam_width = PARAMS['beam_width'],
                output_layer = output_proj)
            # Bound decoding so a model that never emits <end> cannot loop
            # forever; max_len is the length sequences are padded to.
            decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder = decoder,
                maximum_iterations = PARAMS['max_len'])
            return decoder_output.predicted_ids[:, :, 0]