# Test to see if VAE can produce a good model

In [1]:
import tensorflow as tf
import tensorflow_probability as tfp
import keras
from keras import layers
import numpy as np
import matplotlib.pyplot as plt
import gym

In [2]:
from vae_recurrent import VAE, create_decoder, create_encoder
from util import random_observation_sequence, transform_observations

In [14]:
import tensorflow as tf
import tensorflow_probability as tfp
import keras
from keras import layers
import numpy as np


class Sampling(layers.Layer):
    """Draw a latent sample z = z_mean + z_stddev * noise.

    The noise tensor comes from ``tf.keras.backend.random_normal`` with its
    default arguments and has shape ``(batch, dim)`` taken from ``z_mean``.
    """

    def call(self, inputs):
        # inputs is a pair: (z_mean, z_stddev).
        mean, stddev = inputs
        mean_shape = tf.shape(mean)
        noise = tf.keras.backend.random_normal(
            shape=(mean_shape[0], mean_shape[1])
        )

        return mean + stddev * noise


def create_encoder(input_dim, latent_dim, hidden_units=(16, 8)):
    """Build the encoder network of the VAE.

    Args:
        input_dim: size of the input vector. An int is wrapped into a
            1-tuple; any other value is handed to ``keras.Input`` unchanged.
        latent_dim: number of latent dimensions.
        hidden_units: widths of the Dense layers (activation "silu") applied
            in order between the input and the latent heads.

    Returns:
        A ``keras.Model`` named "encoder" that maps the input to the list
        ``[z_mean, z_stddev, z]``, where ``z`` is drawn by ``Sampling``.
    """
    # Declare the input with an explicit tuple shape, the same way
    # create_decoder does, instead of relying on Keras accepting a bare int.
    input_shape = (input_dim,) if isinstance(input_dim, int) else input_dim
    encoder_inputs = keras.Input(shape=input_shape)

    x = encoder_inputs
    for n in hidden_units:
        x = layers.Dense(n, activation="silu")(x)

    z_mean = layers.Dense(latent_dim, name="z_mean")(x)
    # NOTE: this layer is *named* "z_stddev" but its output is treated as the
    # log of the standard deviation; the name is kept so layer/weight names
    # stay the same as before.
    z_log_std = layers.Dense(latent_dim, name="z_stddev")(x)
    z_stddev = tf.exp(z_log_std)  # exponentiate so the stddev is always positive
    z = Sampling()([z_mean, z_stddev])
    encoder = keras.Model(encoder_inputs, [z_mean, z_stddev, z], name="encoder")

    return encoder


def create_decoder(latent_dim, output_dim, hidden_units=(16, 8)):
    """Build the decoder network of the VAE.

    Args:
        latent_dim: size of the latent vector the decoder takes as input.
        output_dim: size of the reconstructed output vector.
        hidden_units: widths of the Dense layers (activation "silu") applied
            in order between the latent input and the output layer.

    Returns:
        A ``keras.Model`` named "decoder" mapping a latent vector to an
        ``output_dim``-sized vector. The output layer has no activation.
    """
    latent_inputs = keras.Input(shape=(latent_dim,))

    x = latent_inputs
    for n in hidden_units:
        x = layers.Dense(n, activation="silu")(x)

    decoder_outputs = layers.Dense(output_dim)(x)
    decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")

    return decoder


class VAE(keras.Model):
    """Variational autoencoder trained with a Gaussian reconstruction term plus a KL term.

    The loss for each sample is
    ``llik_scaling * nll_gaussian(reconstruction, x) + kl_scaling * KL(q(z|x) || reg)``
    where ``q(z|x)`` is the diagonal Gaussian produced by the encoder and
    ``reg`` is a fixed diagonal Gaussian given by ``reg_mean`` / ``reg_stddev``.

    Args:
        encoder: model returning ``(z_mean, z_stddev, z)`` for an input batch.
        decoder: model mapping ``z`` to a reconstruction.
        latent_dim: stored on the instance; not otherwise used by this class.
        reg_mean: location of the regularising Gaussian.
        reg_stddev: diagonal scale of the regularising Gaussian.
        recon_stddev: standard deviation of the reconstruction Gaussian; it is
            squared and passed to ``nll_gaussian`` as the variance.
        llik_scaling: multiplier applied to the reconstruction term.
        kl_scaling: multiplier applied to the KL term.
        **kwargs: forwarded to ``keras.Model.__init__``.
    """

    def __init__(self, encoder, decoder, latent_dim, reg_mean, reg_stddev, recon_stddev=0.05, llik_scaling=1, kl_scaling=1, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

        # Running means reported by train_step and exposed through `metrics`.
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

        self.latent_dim = latent_dim

        self.reg_mean = reg_mean
        self.reg_stddev = reg_stddev

        self.reconstruction_stddev = recon_stddev

        self.llik_scaling = llik_scaling
        self.kl_scaling = kl_scaling

    @property
    def metrics(self):
        """Return the three loss trackers updated by ``train_step``."""
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def call(self, inputs, training=None, mask=None):
        """Encode ``inputs``, then decode the sampled ``z``.

        ``training`` and ``mask`` are accepted but not used.
        """
        _, _, z = self.encoder(inputs)
        reconstruction = self.decoder(z)
        return reconstruction

    def _loss_terms(self, x):
        """Return ``(reconstruction_loss, kl_loss, total_loss)`` for batch ``x``.

        Single source of truth for the loss, shared by ``compute_loss`` and
        ``train_step`` so the two cannot drift apart.
        """
        z_mean, z_stddev, z = self.encoder(x)
        reconstruction = self.decoder(z)

        # Gaussian negative log-likelihood without the additive constant,
        # summed over axis 1 inside nll_gaussian.
        # TODO (from the original author): confirm this is the right
        # reconstruction term compared with binary cross-entropy or a
        # tfp MultivariateNormalDiag log_prob.
        reconstruction_loss = nll_gaussian(reconstruction, x, self.reconstruction_stddev**2, use_consts=False) * self.llik_scaling

        posterior_dist = tfp.distributions.MultivariateNormalDiag(loc=z_mean, scale_diag=z_stddev)
        reg_dist = tfp.distributions.MultivariateNormalDiag(loc=self.reg_mean, scale_diag=self.reg_stddev)
        kl_loss = tfp.distributions.kl_divergence(posterior_dist, reg_dist) * self.kl_scaling

        total_loss = reconstruction_loss + kl_loss
        return reconstruction_loss, kl_loss, total_loss

    def compute_loss(self, x=None):
        """Return the total loss (reconstruction + KL) for batch ``x``.

        The result is not reduced over the batch.
        """
        _, _, total_loss = self._loss_terms(x)
        return total_loss

    def train_step(self, data):
        """Run one optimisation step on ``data`` and return the running loss means.

        ``data`` is used directly as the input batch; no targets are unpacked.
        """
        x = data
        with tf.GradientTape() as tape:
            reconstruction_loss, kl_loss, total_loss = self._loss_terms(x)
        # total_loss is passed to the tape unreduced, exactly as before.
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),  # TODO should this be total_loss not loss
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }


def nll_gaussian(pred, target, variance, use_consts=True):
    """Gaussian negative log-likelihood of ``target`` under mean ``pred``.

    Each element contributes ``(pred - target)**2 / (2 * variance)``; when
    ``use_consts`` is true, ``0.5 * log(2 * pi * variance)`` is added to every
    element as well. The per-element values are then summed over axis 1.
    """
    per_element = (pred - target) ** 2 / (2 * variance)

    if use_consts:
        per_element += 0.5 * np.log(2 * np.pi * variance)

    return tf.reduce_sum(per_element, axis=1)


In [15]:
# Collect rollouts from MountainCarContinuous and build two parallel datasets
# with transform_observations: one using obs_stddev and one using a zero
# stddev. (presumably the stddev argument controls added observation noise --
# confirm against util.transform_observations)
env = gym.make('MountainCarContinuous-v0')
env.action_space.seed(42)

# Bounds handed to transform_observations for rescaling.
observation_max = np.array([0.6, 0.07])
observation_min = np.array([-1.2, -0.07])

all_observations = []
all_observations_scaled = []
num_episodes = 10

obs_stddev = [0.05, 0.05]

for i in range(num_episodes):
    # a and r are returned alongside the observations but are not used here.
    o, a, r = random_observation_sequence(env, 1000, epsilon=0.1)
    o_scaled = transform_observations(o, observation_max, observation_min, obs_stddev)

    # Same transform with a zero stddev argument, kept for comparison.
    o = transform_observations(o, observation_max, observation_min, [0, 0])

    all_observations_scaled.append(o_scaled)
    all_observations.append(o)

# Stack the per-episode arrays into one array per dataset.
all_observations = np.vstack(all_observations)
all_observations_scaled = np.vstack(all_observations_scaled)
all_observations.shape

  deprecation(
  deprecation(


(6057, 2)

In [16]:
# 2-D input, 2-D latent, one hidden layer of 20 units in encoder and decoder.
enc = create_encoder(input_dim=2, latent_dim=2, hidden_units=[20])
dec = create_decoder(latent_dim=2, output_dim=2, hidden_units=[20])

vae = VAE(
    encoder=enc,
    decoder=dec,
    latent_dim=2,
    reg_mean=[0, 0],
    reg_stddev=[0.3, 0.3],
    recon_stddev=0.05,
    llik_scaling=1,
    kl_scaling=1,
)
vae.compile(optimizer=tf.keras.optimizers.Adam())

In [17]:
# Train on the dataset built with obs_stddev; the VAE's custom train_step
# uses the batch itself as the reconstruction target.
vae.fit(all_observations_scaled, epochs=20, batch_size=64)

Epoch 1/20


2022-08-15 22:07:49.562691: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.


Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.callbacks.History at 0x175179730>

In [18]:
# Display the dataset that was transformed with a zero stddev argument.
all_observations

array([[-1.16355234e-01,  0.00000000e+00],
       [-1.16411825e-01, -7.27777504e-04],
       [-1.16524643e-01, -1.45044001e-03],
       ...,
       [ 7.04051044e-01,  8.80452724e-01],
       [ 7.70915303e-01,  8.59683273e-01],
       [ 8.36608463e-01,  8.44626288e-01]])

In [19]:
# Display the dataset that was transformed with obs_stddev (the training data).
all_observations_scaled

array([[-0.11767816, -0.03169052],
       [ 0.03034906,  0.04527695],
       [-0.07049181,  0.00613367],
       ...,
       [ 0.67064226,  0.91221653],
       [ 0.76796412,  0.87188531],
       [ 0.87081527,  0.83921904]])

In [20]:
# Run the training data through encoder + decoder (VAE.call) and display the
# reconstructions for comparison with the inputs shown above.
vae(all_observations_scaled)

<tf.Tensor: shape=(6057, 2), dtype=float32, numpy=
array([[-0.19115901,  0.01483192],
       [ 0.0132947 ,  0.06215818],
       [-0.05159848,  0.0013382 ],
       ...,
       [ 0.6358071 ,  0.92173404],
       [ 0.8268296 ,  0.9244214 ],
       [ 0.9040931 ,  0.8984023 ]], dtype=float32)>

In [21]:
# Total loss (reconstruction + KL) for every row of the training data.
vae.compute_loss(all_observations_scaled)

<tf.Tensor: shape=(6057,), dtype=float32, numpy=
array([ 3.181439 ,  3.2379115,  3.1048906, ...,  9.992523 , 10.016426 ,
       10.375691 ], dtype=float32)>

In [12]:
# Value of the additive constant that nll_gaussian adds per element when
# use_consts=True, for a variance of 0.05**2.
0.5*np.log(2*np.pi*(0.05**2))

-2.076793740349318

In [15]:
# Fresh, untrained VAE: its output on the training data serves as a baseline.
enc = create_encoder(2, 2, [20])
dec = create_decoder(2, 2, [20])
# latent_dim is a required positional parameter of VAE; without it the
# remaining positional arguments shift and reg_stddev goes missing.
vae = VAE(enc, dec, 2, [0, 0], [0.3, 0.3], llik_scaling=10000, kl_scaling=1)
vae.compile(optimizer=tf.keras.optimizers.Adam())
vae(all_observations_scaled)

<tf.Tensor: shape=(5843, 2), dtype=float32, numpy=
array([[0.53299767, 0.5704217 ],
       [0.6018372 , 0.44273585],
       [0.58923995, 0.44603327],
       ...,
       [0.50266194, 0.48137212],
       [0.45818692, 0.52328986],
       [0.44035968, 0.50200146]], dtype=float32)>

Random test to make sure posterior collapse is avoided: decode random latent vectors and check that the decoder output changes with the latent input.

In [119]:
# Ten random 2-D latent vectors with entries drawn uniformly from [0, 1).
a = np.random.random((10, 2))
print(a)
# Decode them directly, bypassing the encoder.
vae.decoder(a)

[[0.12924743 0.30960184]
 [0.34146105 0.63325081]
 [0.2006226  0.23677983]
 [0.39715259 0.65883864]
 [0.33274673 0.92715291]
 [0.30385739 0.43389812]
 [0.20808629 0.38681099]
 [0.77297407 0.25049935]
 [0.63391301 0.06432549]
 [0.41720471 0.97503232]]


<tf.Tensor: shape=(10, 2), dtype=float32, numpy=
array([[0.2079564 , 0.3798874 ],
       [0.17498101, 0.44049573],
       [0.19405416, 0.3613656 ],
       [0.1657172 , 0.4435221 ],
       [0.18144417, 0.5038548 ],
       [0.17846933, 0.3992164 ],
       [0.19482864, 0.39318183],
       [0.10141916, 0.33977422],
       [0.11905786, 0.30723736],
       [0.16754702, 0.51014173]], dtype=float32)>

## MNIST

In [120]:
num_classes = 10
input_shape = (28, 28, 1)

# Load the data and split it between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Rescale pixel values by 255 (the next cell confirms the maximum is 1.0).
x_train = x_train/255
# Flatten each image to one row; -1 infers the row length, so the code no
# longer assumes height == width.
x_train_flat = x_train.reshape(x_train.shape[0], -1)

In [121]:
# Sanity check on the rescaling: the largest pixel value after dividing by 255.
x_train.max()

1.0

In [122]:
# 784-D flattened images, 10-D latent space.
# create_encoder takes (input_dim, latent_dim, hidden_units), so the image
# size comes first; create_decoder takes (latent_dim, output_dim, hidden_units).
enc = create_encoder(784, 10, [256, 128])
dec = create_decoder(10, 784, [256, 128])

# latent_dim is a required positional parameter of VAE and must come before
# reg_mean / reg_stddev.
vae = VAE(enc, dec, 10, [0]*10, [1]*10, llik_scaling=100, kl_scaling=1)
vae.compile(optimizer=tf.keras.optimizers.Adam())

In [123]:
# Train on the flattened images; the VAE's custom train_step uses the batch
# itself as the reconstruction target.
vae.fit(x_train_flat, batch_size=64, epochs=20)

Epoch 1/20


2022-08-07 13:54:56.950320: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.


Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
196/938 [=====>........................] - ETA: 7s - loss: 23.0395 - reconstruction_loss: 19.8954 - kl_loss: 3.1952

KeyboardInterrupt: 

In [None]:
# Reconstruct every training image and fold the flat rows back into 28x28.
out = vae(x_train_flat).numpy()
out = out.reshape(out.shape[0], 28, 28)
out.shape

In [None]:
# Index of the training image to inspect; reused by the next cell.
example = 16
# Show the VAE reconstruction of that image.
plt.imshow(out[example])

In [None]:
# Show the original image at the same index for side-by-side comparison.
plt.imshow(x_train[example])

In [None]:
# Draw one 10-D latent vector from a normal with mean 0 and scale 1 in every
# dimension, decode it, and show the result as a 28x28 image.
z = np.random.normal(loc=[0] * 10, scale=[1] * 10, size=(1, 10))
decoded = vae.decoder(z)
fake_num = decoded.numpy().reshape(28, 28)
plt.imshow(fake_num)