# Variational AutoEncoder (VAE) 
 (inspired by the work of @AlexAdam)

In [1]:
import sys
sys.path.append("..")

In [13]:
import numpy as np

from keras import objectives, backend as K
from keras.layers import Input, Dense, Embedding, Bidirectional, LSTM, TimeDistributed, RepeatVector
from keras.models import Model
from keras.optimizers import Adam
from keras.utils import to_categorical

import vae


class Hyper(vae.Hyper):
    def __init__(self,
            vocab_size=1000, 
            embedding_dim=64, 
            max_length=300,
            batch_size=10,
            lr=0.001, 
            latent_dim=435,
            intermediate_dim=200,
            encoder_hidden_dim=500, 
            decoder_hidden_dim=500,
            epochs=50,
            epsilon_std=0.01):
        vae.Hyper.__init__(self,            
            batch_size=batch_size,
            lr=lr, 
            latent_dim=latent_dim,
            intermediate_dim=intermediate_dim,
            epochs=epochs,
            epsilon_std=epsilon_std)
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.max_length = max_length
        self.encoder_hidden_dim=encoder_hidden_dim
        self.decoder_hidden_dim=decoder_hidden_dim


class TextVae(vae.Vae):
    def __init__(self, hyper):
        vae.Vae.__init__(self, hyper)

    def build_optimizer(self):
        return Adam(lr=self.h.lr)
        
    def compute_vae_loss(self, x, x_decoded_mean, z_mean, z_log_var):
        print("shapes")
        print(x.shape)
        print(x.dtype)
        print(x_decoded_mean.shape)
        x = K.flatten(x)
#         x = K.one_hot(x, self.h.vocab_size)
        x_decoded_mean = K.flatten(x_decoded_mean)
        xent_loss = self.h.max_length * objectives.binary_crossentropy(x, x_decoded_mean)
        kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
        return xent_loss + kl_loss

#     def build_vae_loss_layer(self, x, x_decoded_mean, z_mean, z_log_var):
#         return vae.Vae.build_vae_loss_layer(self, self.x_embed, x_decoded_mean, z_mean, z_log_var)
        
    def build_encoder(self):
        x = Input(shape=(self.h.max_length,))
        self.x_embed = Embedding(self.h.vocab_size, self.h.embedding_dim, input_length=self.h.max_length)(x)
        h = Bidirectional(LSTM(self.h.encoder_hidden_dim, 
                return_sequences=True, 
                name='encoder_rnn_1'), 
            merge_mode='concat')(self.x_embed)
        h = Bidirectional(LSTM(self.h.encoder_hidden_dim, 
                return_sequences=False, 
                name='encoder_rnn_2'),
            merge_mode='concat')(h)

        return x, Dense(self.h.intermediate_dim, activation='relu', name='encoder_output')(h)
            
    def build_decoder_layers(self):
        decoder_rnn_1 = LSTM(self.h.decoder_hidden_dim, 
            return_sequences=True, 
            name='decoder_rnn_1')
        decoder_rnn_2 = LSTM(self.h.decoder_hidden_dim,
            return_sequences=True, 
            name='decoder_rnn_2')
        decoder_mean = TimeDistributed(Dense(self.h.vocab_size, activation='softmax'), name='decoded_mean')
        return decoder_rnn_1, decoder_rnn_2, decoder_mean

    def build_decoder(self, z):
        h_decoded = decoder_h(z)
        return decoder_mean(h_decoded)

    def build_decoder(self, encoded):
        decoder_rnn_1, decoder_rnn_2, decoder_mean = self.decoder_layers

        h = RepeatVector(self.h.max_length)(encoded)
        h = decoder_rnn_1(h)
        h = decoder_rnn_2(h)

        return decoder_mean(h)
    
    def build_generator(self):
        decoder_rnn_1, decoder_rnn_2, decoder_mean = self.decoder_layers

        decoder_input = Input(shape=(self.h.latent_dim,))

        h = RepeatVector(self.h.max_length)(decoder_input)
        h = decoder_rnn_1(h)
        h = decoder_rnn_2(h)

        return Model(decoder_input, decoder_mean(h))


In [14]:
hyper = Hyper()
model = TextVae(hyper)

shapes
(?, ?, ?)
<dtype: 'float32'>
(?, 300, 1000)


In [6]:
from keras.datasets import imdb
from keras.preprocessing.sequence import pad_sequences

MAX_LENGTH = 300
NUM_WORDS = 1000

In [7]:
(X_train, y_train), (X_test, y_test) = imdb.load_data(num_words=NUM_WORDS)

print("Training data")
print(X_train.shape)
print(y_train.shape)

print("Number of words:")
print(len(np.unique(np.hstack(X_train))))

X_train = pad_sequences(X_train, maxlen=MAX_LENGTH)
X_test = pad_sequences(X_test, maxlen=MAX_LENGTH)

train_indices = np.random.choice(np.arange(X_train.shape[0]), 2000, replace=False)
test_indices = np.random.choice(np.arange(X_test.shape[0]), 1000, replace=False)

X_train = X_train[train_indices]
y_train = y_train[train_indices]

X_test = X_test[test_indices]
y_test = y_test[test_indices]

temp = np.zeros((X_train.shape[0], MAX_LENGTH, NUM_WORDS))
temp[np.expand_dims(np.arange(X_train.shape[0]), axis=0).reshape(X_train.shape[0], 1), np.repeat(np.array([np.arange(MAX_LENGTH)]), X_train.shape[0], axis=0), X_train] = 1

X_train_one_hot = temp

temp = np.zeros((X_test.shape[0], MAX_LENGTH, NUM_WORDS))
temp[np.expand_dims(np.arange(X_test.shape[0]), axis=0).reshape(X_test.shape[0], 1), np.repeat(np.array([np.arange(MAX_LENGTH)]), X_test.shape[0], axis=0), X_test] = 1

X_test_one_hot = temp

Training data
(25000,)
(25000,)
Number of words:
998


In [15]:
model.fit(x=X_train, y=X_train_one_hot, batch_size=10, epochs=1, validation_data=(X_test, X_test_one_hot))

Train on 2000 samples, validate on 1000 samples
Epoch 1/1


<keras.callbacks.History at 0x7f319480cdd8>

In [18]:
encode = model.encode(X_test[:3])
decode = model.generate(encode)

In [21]:
X_test[:3]

array([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   1,   6, 171,   7,  61,
          2, 927,  28,   2,  14,   2, 114, 791,  38,  13,  80,   2,   8,
         49,   7,   4,   2,   2,  10,  10, 300,  14,   9, 441,  21,  24,
         18,   4,   2,   2, 104,  45, 954,  21, 820, 131,   2,  10,  10,
        241,   2, 119, 602,  47,   8,  30,  44,   2,   4,   2, 250,  15,
          2,   2,   9,  96,   2,  10,  10, 342,   2,   2,  27,   2,  39,
        443,   2,  82,  48,  61,   2,   2, 272,  40,   2, 146, 170,   8,
         28,   8,   2,  18,   6,   2,  10,  10, 470

In [23]:
decode.argmax(axis=-1)

array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 