In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Lambda, Input, Dense, Layer, Concatenate
from tensorflow.keras.models import Model
import random as rn
import numpy as np
import tensorflow as tf

# Seed every RNG in play (Python `random`, NumPy, TensorFlow) from one constant
# so the whole notebook is reproducible and the seed is tuned in one place.
SEED = 123
rn.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

def sampling(args):
    """Reparameterization trick: sample z ~ N(z_mean, exp(z_log_var)) differentiably.

    Draws epsilon from an isotropic unit Gaussian and returns
    ``z_mean + exp(0.5 * z_log_var) * epsilon``, so gradients flow through
    ``z_mean`` and ``z_log_var`` rather than through the random draw.

    # Arguments
        args (pair of tensors): ``(z_mean, z_log_var)``, the mean and the log
            of the variance of Q(z|X). Both are expected to share one shape.

    # Returns
        z (tensor): sampled latent vector with the same shape as ``z_mean``.
    """
    z_mean, z_log_var = args
    # Take the noise shape from the *dynamic* shape of z_mean for every axis:
    # int_shape() yields None for dimensions unknown when the graph is built,
    # which would make random_normal fail. For a (batch, dim) input this is the
    # same (batch, dim) shape the static lookup produced.
    epsilon = tf.keras.backend.random_normal(shape=tf.shape(z_mean), seed=123)

    # exp(0.5 * log_var) == standard deviation.
    return z_mean + tf.keras.backend.exp(0.5 * z_log_var) * epsilon

# --- Data: MNIST, each image flattened to a vector and scaled to [0, 1] -------
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

image_size = x_train.shape[1]      # side length; assumes square images (MNIST is 28x28)
original_dim = image_size ** 2     # length of one flattened image
x_train = x_train.reshape(-1, original_dim).astype('float32') / 255
x_test = x_test.reshape(-1, original_dim).astype('float32') / 255

# --- Model / training hyperparameters ---------------------------------------
input_shape = (original_dim,)
intermediate_dim = 512
batch_size = 128
latent_dim = 2
epochs = 50

# --- Labels as one-hot vectors (the conditioning input of the CVAE) ---------
y_train_one_hot, y_test_one_hot = (
    tf.keras.utils.to_categorical(labels, num_classes=10)
    for labels in (y_train, y_test)
)

NUM_CLASSES = y_train_one_hot.shape[1]

In [2]:
import tensorflow_probability as tfp
tfd = tfp.distributions

In [50]:
# Define the names this cell needs here instead of relying on the commented-out
# cell below (which would leave them undefined on a fresh kernel).
tfpl = tfp.layers
encoded_size = 16

# Standard-normal prior over an `encoded_size`-dimensional latent space:
# `encoded_size` independent N(0, 1) components treated as one event.
prior = tfd.Independent(tfd.Normal(loc=tf.zeros(encoded_size), scale=1),
                        reinterpreted_batch_ndims=1)

# Number of Dense outputs needed to parameterize a full-covariance Gaussian:
# encoded_size means + encoded_size * (encoded_size + 1) / 2 lower-triangular entries.
tfpl.MultivariateNormalTriL.params_size(encoded_size)

152

In [51]:

# tfpl = tfp.layers
# tfkl = tf.keras.layers
# base_depth = 100
# encoded_size = 16
# encoder = tf.keras.Sequential([
#     tfkl.InputLayer(input_shape=(28,28,1)),
#     tfkl.Lambda(lambda x: tf.cast(x, tf.float32) - 0.5),
#     tfkl.Conv2D(base_depth, 5, strides=1,
#                 padding='same', activation=tf.nn.leaky_relu),
#     tfkl.Conv2D(base_depth, 5, strides=2,
#                 padding='same', activation=tf.nn.leaky_relu),
#     tfkl.Conv2D(2 * base_depth, 5, strides=1,
#                 padding='same', activation=tf.nn.leaky_relu),
#     tfkl.Conv2D(2 * base_depth, 5, strides=2,
#                 padding='same', activation=tf.nn.leaky_relu),
#     tfkl.Conv2D(4 * encoded_size, 7, strides=1,
#                 padding='valid', activation=tf.nn.leaky_relu),
#     tfkl.Flatten(),
#     tfkl.Dense(tfpl.MultivariateNormalTriL.params_size(encoded_size),
#                activation=None,name='dense'),
#     tfpl.MultivariateNormalTriL(
#         encoded_size,
#         activity_regularizer=tfpl.KLDivergenceRegularizer(prior, weight=1.0)),
# ])

# encoder.get_layer('dense').output

In [3]:
# Conditional VAE: both the encoder and the decoder see the one-hot label `cond`.
X_input = Input(shape=(x_train.shape[1],))
cond = Input(shape=(NUM_CLASSES,))

inputs = Concatenate()([X_input, cond])

# Encoder: one hidden layer, then linear heads for the mean and the
# log-variance of Q(z | x, c). The training loop's KL term treats the second
# head as a log-variance, and `encoder` below looks up the layer named 'mu'.
encoder_h = Dense(intermediate_dim, activation='relu')(inputs)
mu = Dense(latent_dim, activation='linear', name='mu')(encoder_h)
l_sigma = Dense(latent_dim, activation='linear')(encoder_h)

# Reparameterized latent sample (`sampling` is defined in the first cell).
z = Lambda(sampling, output_shape=(latent_dim,))([mu, l_sigma])

zc = Concatenate(name='z_condition')([z, cond])

# Decoder layers are created once and applied twice (here and in the standalone
# decoder below) so both models share the same weights.
decoder_hidden = Dense(intermediate_dim, activation='relu', name='decoder_hidden')
decoder_output = Dense(original_dim, activation='sigmoid', name='decoder_output')

decoder_intermediate = decoder_hidden(zc)
decoder_output_layer = decoder_output(decoder_intermediate)

# cvae also exposes mu and l_sigma so the training loop can compute the KL term.
cvae = Model([X_input, cond], [decoder_output_layer, mu, l_sigma])
encoder = Model(cvae.inputs, cvae.get_layer('mu').output)

# Standalone decoder: (z, condition) -> reconstruction, reusing the shared layers.
decoder_z = Input(shape=(latent_dim,))
decoder_cond = Input(shape=(NUM_CLASSES,))

decoder_input = Concatenate()([decoder_z, decoder_cond])

slice_decoder_hidden = decoder_hidden(decoder_input)
slice_decoder_output = decoder_output(slice_decoder_hidden)
decoder = Model([decoder_z, decoder_cond], slice_decoder_output)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
mse_loss_fn = tf.keras.losses.MeanSquaredError()

# Batch with the configured `batch_size` instead of repeating the literal 128.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train_one_hot))
    .shuffle(buffer_size=1024)
    .batch(batch_size)
)

In [11]:
# Standard-normal prior p(z) = N(0, I) over the latent space.
# `scale_identity_multiplier` is no longer accepted by MultivariateNormalDiag in
# current TFP releases; an all-ones `scale_diag` defines the same distribution.
latent_prior = tfd.MultivariateNormalDiag(loc=tf.zeros([latent_dim]),
                                          scale_diag=tf.ones([latent_dim]))

In [None]:

# Custom training loop. Per-image loss = summed squared reconstruction error
# + analytic KL(Q(z|x,c) || N(0, I)), averaged over the batch.
# `epochs` comes from the configuration in the first cell (not redefined here).


@tf.function
def train_step(x_batch, y_batch):
    """Run one optimizer step on a batch and return the scalar batch loss."""
    with tf.GradientTape() as tape:
        # The second encoder head is a log-variance, hence the name z_log_var.
        reconstructed, z_mu, z_log_var = cvae([x_batch, y_batch], training=True)

        # MeanSquaredError averages over pixels as well as the batch; scaling by
        # original_dim turns it into the per-image sum of squared errors.
        mse_loss = mse_loss_fn(x_batch, reconstructed) * original_dim

        # Closed-form KL divergence between N(z_mu, exp(z_log_var)) and N(0, I).
        kl_loss = -0.5 * tf.reduce_sum(
            1 + z_log_var - tf.square(z_mu) - tf.exp(z_log_var), axis=-1)

        loss = tf.reduce_mean(mse_loss + kl_loss)

    grads = tape.gradient(loss, cvae.trainable_weights)
    optimizer.apply_gradients(zip(grads, cvae.trainable_weights))
    return loss


for epoch in range(epochs):
    # Fresh metric each epoch so the reported value is that epoch's mean loss.
    loss_metric = tf.keras.metrics.Mean()

    for x_batch_train, y_batch_train in train_dataset:
        loss_metric(train_step(x_batch_train, y_batch_train))

    # Report every 10th epoch and always the last one.
    if epoch % 10 == 0 or epoch == epochs - 1:
        print('Epoch %s: mean loss = %.4f' % (epoch, float(loss_metric.result())))