In [None]:
import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt

# Variational Autoencoder (VAE): generating new data by manipulating a learned latent space

## What is this?
VAE stands for Variational AutoEncoder. It is trained to encode its input into a latent space and then reconstruct the input from that latent representation; it is a strategy for learning that latent space.

It has 2 important blocks. The Encoder is responsible for condensing the data into a low-dimensional latent space (in practice, a vector).
The Decoder takes one point from the latent space (how do you choose a point? By sampling) and reconstructs a new image from it.

## How do VAEs build a latent space?
That is what differentiates a VAE from a plain AE: the V, which means variational. The encoder outputs a probability distribution instead of a single point, which lets the VAE create continuous and structured latent spaces.

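## How do you sample the latent space?

The encoder outputs a mean and a log-variance for every latent dimension. A point is drawn from that Gaussian with the reparameterization trick (see the *Sampler* section below), so the sampling step stays differentiable.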


The latent space is structured, dense (non-sparse), continuous and low-dimensional, and each direction encodes a meaningful axis of variation of the data. That means it can be manipulated with concept vectors: vectors that isolate and represent a single concept.
For example, take the concept of a smile. Add the "smile" concept vector to the latent representation of another image, then pass the result to the decoder; it produces a new image of the same person smiling.

```There are concept vectors for any independent direction of the latent space```
```deeplearning with bayesian inference```

## Steps
- Build the encoder
    - The output will be 2 vectors: the mean and the log-variance of the latent Gaussian
- Build the sampler, which combines a random noise vector with the 2 vectors the encoder gives us
    - z = z_mean + exp(0.5 * z_log_var) * epsilon, where epsilon is a random vector drawn from a standard normal distribution
- Build the loss
    - Kullback–Leibler (KL) divergence between the encoded distribution and a standard normal
    - reconstruction loss (binary cross-entropy, already implemented in Keras)
- Build the decoder
    - the decoder takes the sampled latent point and reconstructs a valid image from it!

# Encoder

The encoder transforms the image into 2 parameter vectors that define a normal distribution over the latent space: the mean (`z_mean`) and the log-variance (`z_log_var`).

Dummy dataset: MNIST (the cell below currently loads one of the image folders selected by `DATASET`).

In [None]:
IMAGE_HEIGHT = IMAGE_WIDTH = 64
DATASET = "monet"
# DATASET = "celebs"
# DATASET = "cat_vs_dogs"
COLOR_MODE = "rgb"
BATCH_SIZE = 256

# One image folder per supported dataset. Every dataset is loaded with identical
# settings, so a lookup table replaces three copy-pasted loader calls.
DATASET_DIRS = {
    "cat_vs_dogs": "./PetImages/",
    "celebs": "./Celeb_images/",
    "monet": "./art_images/monet_jpg/",
}

if DATASET not in DATASET_DIRS:
    # Fail loudly instead of leaving `whole_dataset` undefined (or stale from an earlier run).
    raise ValueError(f"Unknown DATASET {DATASET!r}; expected one of {sorted(DATASET_DIRS)}")

# label_mode=None -> the dataset yields image batches only (no labels), which is all an
# autoencoder needs. smart_resize crops to the target aspect ratio instead of distorting.
whole_dataset = tf.keras.utils.image_dataset_from_directory(
    DATASET_DIRS[DATASET],
    label_mode=None,
    image_size=(IMAGE_HEIGHT, IMAGE_WIDTH),
    batch_size=BATCH_SIZE,
    shuffle=True,
    color_mode=COLOR_MODE,
    smart_resize=True,
)

In [None]:
# Display the dataset object; its repr shows the element spec (batched image shape and dtype).
whole_dataset

In [None]:
# Count how many images the dataset yields. This is one full pass over the files on disk,
# counted per batch instead of per image.
num_images = sum(int(tf.shape(batch)[0]) for batch in whole_dataset)
i = num_images  # legacy name of the counter
num_images

In [None]:
# Rescale pixels from [0, 255] to [0, 1]. Both the sigmoid decoder output and the
# binary cross-entropy reconstruction loss assume this range.
# NOTE: this rebinds `whole_dataset`, so running the cell twice rescales twice.
whole_dataset = whole_dataset.map(
    lambda images: images / 255.0,
    num_parallel_calls=tf.data.AUTOTUNE,
)

In [None]:
# --- Model hyper-parameters -------------------------------------------------
# Number of stride-2 convolutions in the encoder (each one halves height and width).
CONVS_NUMBER_ENCODER = 5

# Width of the dense projection between the conv stack and the z_mean / z_log_var heads.
PROY_DIM = 128  # previously tried: 16

# Image channels follow the loader's color mode.
if COLOR_MODE == "rgb":
    IMAGE_CHANNELS = 3
else:
    IMAGE_CHANNELS = 1

# Dimensionality of the latent vector z.
LATENT_DIMS = 2

# CHANNELS_ENCODER_OUTPUT (depth of the last encoder conv, reused by the decoder's
# first feature map) is derived in the encoder cell.

In [None]:
# Encoder: CONVS_NUMBER_ENCODER stride-2 convolutions (each halves height/width and doubles
# the filter count), then a dense projection and two linear heads that parameterise the
# approximate posterior q(z|x) = N(z_mean, exp(z_log_var)).
# The input shape is fixed by IMAGE_HEIGHT / IMAGE_WIDTH / IMAGE_CHANNELS.
# TODO: wrap in a class to lift the fixed-size limitation.
assert CONVS_NUMBER_ENCODER >= 1, "the encoder needs at least one convolution"

X_input_encoder = tf.keras.Input(shape=(IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS))
# Filter count of the current conv. The historical name is kept; these are *encoder* filters.
filters_decoder = 16

X = X_input_encoder
for channel_index in range(CONVS_NUMBER_ENCODER):
    if channel_index > 0:
        filters_decoder *= 2  # 16, 32, 64, ...
    X = tf.keras.layers.Conv2D(filters=filters_decoder, kernel_size=3, activation="relu", strides=2, padding="same")(X)
    # Observation while experimenting: adding more convs first lowers the KL term, then raises it.

# Depth of the last conv's output; the decoder starts from a feature map with this many
# channels. Assigned after the loop so it is always defined.
CHANNELS_ENCODER_OUTPUT = filters_decoder

X = tf.keras.layers.Flatten()(X)

# TODO: try a deeper projection (e.g. two Dense layers) later.
X = tf.keras.layers.Dense(units=PROY_DIM, activation="relu")(X)

z_mean = tf.keras.layers.Dense(units=LATENT_DIMS, name="z_mean")(X)
# Log-variance instead of sigma: it is unconstrained, and exp() later keeps the variance positive.
z_log_var = tf.keras.layers.Dense(units=LATENT_DIMS, name="z_log_var")(X)

encoder = tf.keras.Model(inputs=X_input_encoder, outputs=[z_mean, z_log_var], name="encoder")

In [None]:
# Print the encoder's layers, output shapes and parameter counts.
encoder.summary()

# Sampler!

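We sample from a normal distribution using the reparameterization trick:

$$z = \mu + \exp\left(\tfrac{1}{2}\log\sigma^2\right)\,\epsilon = \mu + \sigma\,\epsilon$$

where $\mu$ is `z_mean`, $\log\sigma^2$ is `z_log_var`, and $\epsilon \sim \mathcal{N}(0, I)$ is a random vector drawn from a standard normal distribution.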

In [None]:
class Sampler(tf.keras.layers.Layer):
    """Draw a latent point z ~ N(z_mean, exp(z_sigma)) via the reparameterization trick.

    Despite its name, ``z_sigma`` is the *log-variance* (the encoder's ``z_log_var``
    head). The sample is written as ``mean + std * epsilon`` with ``epsilon ~ N(0, I)``.
    This keeps the randomness outside the learnable values, so gradients can flow back
    into ``z_mean`` and ``z_sigma``.
    """

    def call(self, z_mean, z_sigma):
        # std = sqrt(variance) = exp(log_variance / 2): this is where the "/ 2" comes from.
        std = tf.math.exp(z_sigma / 2)
        # One independent standard-normal draw per (example, latent dimension).
        noise_shape = (tf.shape(z_mean)[0], tf.shape(z_mean)[1])
        epsilon = tf.random.normal(shape=noise_shape)
        return z_mean + std * epsilon

# Decoder

In [None]:
import math

In [None]:
# Decoder: maps a latent vector back to an image.
# 1. A dense layer produces a small feature map of shape
#    (IMAGE_HEIGHT_DECODE, IMAGE_WIDTH_DECODE, CHANNELS_ENCODER_OUTPUT).
# 2. CONVS_NUMBER stride-2 transposed convs each double height/width and halve the
#    filter count, until the original IMAGE_HEIGHT x IMAGE_WIDTH is reached.
CONVS_NUMBER = 3
UPSAMPLE_FACTOR = 2**CONVS_NUMBER
IMAGE_HEIGHT_DECODE = math.ceil(IMAGE_HEIGHT / UPSAMPLE_FACTOR)
IMAGE_WIDTH_DECODE = math.ceil(IMAGE_WIDTH / UPSAMPLE_FACTOR)
# The upsampling path only lands exactly on the input size when both sides are divisible
# by 2**CONVS_NUMBER. Otherwise the reconstruction and input shapes differ and the loss breaks.
assert IMAGE_HEIGHT % UPSAMPLE_FACTOR == 0 and IMAGE_WIDTH % UPSAMPLE_FACTOR == 0, (
    f"IMAGE_HEIGHT/IMAGE_WIDTH ({IMAGE_HEIGHT}x{IMAGE_WIDTH}) must be divisible by {UPSAMPLE_FACTOR}"
)

X_input_decoder = tf.keras.Input(shape=(LATENT_DIMS,))

X = tf.keras.layers.Dense(units=IMAGE_HEIGHT_DECODE * IMAGE_WIDTH_DECODE * CHANNELS_ENCODER_OUTPUT)(X_input_decoder)
X = tf.keras.layers.Reshape((IMAGE_HEIGHT_DECODE, IMAGE_WIDTH_DECODE, CHANNELS_ENCODER_OUTPUT))(X)

filters = CHANNELS_ENCODER_OUTPUT
for channel_index in range(CONVS_NUMBER):
    X = tf.keras.layers.Conv2DTranspose(filters, 3, activation="relu", strides=2, padding="same")(X)
    filters = math.ceil(filters / 2)

# Sigmoid keeps pixels in [0, 1], matching the /255-normalised inputs and the BCE loss.
X_decoder_output = tf.keras.layers.Conv2D(IMAGE_CHANNELS, 3, activation="sigmoid", padding="same")(X)

# X_input_decoder is the latent input.
decoder = tf.keras.Model(X_input_decoder, X_decoder_output, name="decoder")

In [None]:
# Print the decoder's layers, output shapes and parameter counts.
decoder.summary()

In [None]:
# Toy tensor of shape (8, 5, 4, 1), used below to check how the reconstruction loss reduces
# over axes. Seeded so the displayed values are reproducible across runs.
sample = tf.constant(np.random.default_rng(0).normal(size=(8, 5, 4, 1)))

In [None]:
# Shape of the toy tensor.
sample.shape

In [None]:
# Sum over height and width: one value per (example, channel).
tf.reduce_sum(sample, axis=[1, 2])

In [None]:
# Batch average of the per-example sums: the same reduction the reconstruction loss uses.
tf.reduce_mean(
    tf.reduce_sum(sample, axis=[1, 2])
)

In [None]:
class VAE(tf.keras.Model):
    """Variational autoencoder with a custom training step.

    It has 3 main blocks:
        - encoder: images -> (z_mean, z_log_var)
        - sampler: (z_mean, z_log_var) -> latent point z
        - decoder: z -> reconstructed images

    Training minimises the negative ELBO: a per-image reconstruction loss plus the KL
    divergence between q(z|x) = N(z_mean, exp(z_log_var)) and the prior N(0, I).
    """

    def __init__(self, encoder, decoder, sampler, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.sampler = sampler

        # Running means of each loss term. Keras resets these at each epoch because
        # they are listed in `metrics`.
        # Sum of both losses.
        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_flat_loss")
        # How close the decoder's reconstruction is to the original input.
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(name="reconstruction_loss")
        # How far the encoded distribution is from the prior N(0, I), the same
        # distribution the sampler's epsilon is drawn from.
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    def call(self, inputs):
        """Encode, sample and decode `inputs` (stochastic: a fresh z is drawn each call)."""
        z_mean, z_log_var = self.encoder(inputs)
        return self.decoder(self.sampler(z_mean, z_log_var))

    @property
    def metrics(self):
        return [self.total_loss_tracker,
                self.reconstruction_loss_tracker,
                self.kl_loss_tracker]

    def train_step(self, batch_data):
        """Run one optimisation step on a batch of images (assumed scaled to [0, 1])."""
        with tf.GradientTape() as tape:
            z_mean, z_log_var = self.encoder(batch_data)
            sampled_point_z = self.sampler(z_mean, z_log_var)
            reconstructed_data = self.decoder(sampled_point_z)

            # The image needs to be normalised to [0, 1] for binary cross-entropy to make sense.
            # binary_crossentropy averages over the last (channel) axis -> (batch, H, W).
            # Summing over H and W gives one loss per image, which is then averaged over the batch.
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    tf.keras.losses.binary_crossentropy(batch_data, reconstructed_data),
                    axis=(1, 2),
                )
            )

            # Closed-form KL(N(mu, sigma^2) || N(0, 1)) for each latent dimension.
            kl_per_dim = -0.5 * (1 + z_log_var - tf.math.square(z_mean) - tf.math.exp(z_log_var))
            # KL of a diagonal Gaussian is the SUM over latent dims. Averaging over them
            # would silently down-weight KL by LATENT_DIMS. Then average over the batch.
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_per_dim, axis=1))

            # Negative ELBO: reconstruction quality + closeness of the latent distribution to the prior.
            total_loss = reconstruction_loss + kl_loss

        # Differentiate w.r.t. every trainable weight (encoder + decoder), not the
        # z_mean / z_log_var activations, and let the optimizer update them.
        gradients = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_weights))

        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        return {
            "total_loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result()
        }

In [None]:
# Reset Keras' global state (e.g. counters used to auto-name layers).
# NOTE(review): this runs after `encoder`/`decoder` were already built. Those model objects
# are still referenced and used below; confirm the reset is intended at this point.
tf.keras.backend.clear_session()

In [None]:
# Assemble the VAE from the encoder and decoder models and a fresh Sampler layer.
vae = VAE(encoder=encoder, decoder=decoder, sampler=Sampler())

In [None]:
# No `loss=` argument: the loss is computed inside VAE.train_step.
# NOTE(review): `tf.keras.optimizers.legacy` is not supported under Keras 3 (TF >= 2.16).
# Confirm the TF version in use, or switch to `tf.keras.optimizers.Adam()` there.
vae.compile(optimizer=tf.keras.optimizers.legacy.Adam(), run_eagerly=None)

# TODO: add a callback to preview the output of the generator vs the original image

In [None]:
class VAEMonitor(tf.keras.callbacks.Callback):
    """Save images decoded from random latent vectors at the end of every epoch.

    Args:
        num_img: how many images to generate per epoch.
        latent_dim: size of the latent vectors fed to the decoder.
        output_dir: folder the PNGs are written to (created if missing).
    """

    def __init__(self, num_img=3, latent_dim=LATENT_DIMS, output_dir="./celeb_vae_gen"):
        # The base Callback initialises state Keras relies on (e.g. `self.model`).
        super().__init__()
        self.num_img = num_img
        self.latent_dim = latent_dim
        self.output_dir = output_dir

    def on_epoch_end(self, epoch, logs=None):
        # Create the folder lazily so saving never fails on a fresh checkout.
        os.makedirs(self.output_dir, exist_ok=True)
        # Latent vectors drawn from the prior N(0, I) that the VAE was trained to match.
        random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))
        # Decoder output is a sigmoid in [0, 1]; scale to the 0..255 pixel range.
        generated_images = self.model.decoder(random_latent_vectors) * 255
        for i in range(self.num_img):
            img = tf.keras.utils.array_to_img(generated_images[i])
            img.save(os.path.join(self.output_dir, f"generated_img_{epoch:03d}_{i}.png"))

In [None]:
# Train for 30 epochs; VAEMonitor writes 10 images decoded from random latent vectors after each epoch.
vae.fit(whole_dataset, epochs=30, callbacks=[VAEMonitor(num_img=10, latent_dim=LATENT_DIMS)])

In [None]:
# First image of the first (shuffled) batch, with a leading batch axis of size 1
# because the models expect batched input.
sample = np.expand_dims(next(iter(whole_dataset))[0], axis=0)

In [None]:
# Run the full VAE (encode -> sample -> decode) on the single-image batch.
# The sampler draws fresh noise, so repeated runs give slightly different reconstructions.
reconstructed = vae.predict(sample)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# `predict` returns a batch; keep the single reconstructed image.
# Guarded on ndim so re-running the cell does not strip off another axis.
if reconstructed.ndim == 4:
    reconstructed = reconstructed[0]
reconstructed.shape

In [None]:
# The decoder ends in a sigmoid, so pixels are in [0, 1]; map them to 0..255 for display.
# np.clip takes (array, min, max). Skip if already converted so the cell can be re-run.
if reconstructed.dtype != np.uint8:
    reconstructed = np.clip(reconstructed * 255.0, 0, 255).astype("uint8")

In [None]:
# Show the original input image (float pixels after the /255 rescaling).
plt.imshow(sample[0])

In [None]:
# Encode the image. NOTE: the second output is the log-variance (`z_log_var`), not the variance.
mu, var = encoder(sample)

In [None]:
# Latent mean of the sample image.
mu

In [None]:
# Draw one latent point z from N(mu, exp(var)).
z = Sampler()(mu, var)

In [None]:
# The sampled latent point.
z

In [None]:
# Decode a hand-picked latent point and display it as an 8-bit image.
# NOTE(review): the point is hard-coded with 2 values, so this only works while LATENT_DIMS == 2.
plt.imshow((decoder(tf.constant([[.2, -0.2]], dtype="float32"))[0].numpy()*255).astype("uint8"))

In [None]:
# Decode the sampled latent point z and display it as an 8-bit image.
plt.imshow((decoder(z)[0].numpy()*255).astype("uint8"))

In [None]:
# Display the VAE reconstruction computed by `vae.predict` above.
plt.imshow(reconstructed)

In [None]:
# NOTE(review): repeats the earlier `encoder(sample)` call; `s` is the log-variance, not sigma.
m, s = encoder(sample)