# Convolutional Variational Autoencoder

The goal of this notebook is to show how to train a Variational Autoencoder (VAE) ([1](https://arxiv.org/abs/1312.6114), [2](https://arxiv.org/abs/1401.4082)) model on the Indonesia SFINCS runs dataset (generated in by the `prepare/prepare.ipynb` notebook). It is based on the [CVAE notebook](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/cvae.ipynb) from [tensorflow](https://www.tensorflow.org/tutorials/generative/cvae). A VAE is a probabilistic take on the autoencoder. An autoencoder a model that uses dimension reduction to represent higher dimensional and spaced data into a latent vector space, similar to other dimensionality reduction techniques like PCA. Unlike a traditional autoencoder, which maps the input onto a latent vector, a VAE maps the input data into the parameters of a probability distribution, such as the mean and a standard deviation. This approach can also generate data that is variable.
In our example rain will not always lead to a flooding event, but sometimes. This variation is also taken into account in this model.

## Setup

In [None]:
# the probablistic libraries for tensorflow
!pip install tensorflow-probability

# to generate gifs
!pip install imageio

# extra examples
!pip install git+https://github.com/tensorflow/docs

  Resolved https://github.com/tensorflow/docs to commit 541d9fbaa497cbd8e3587c8118e34dc1dc2d04b6


  Preparing metadata (setup.py) ... [?25l-

 done


[?25hCollecting astor (from tensorflow-docs==2023.7.13.64986)
  Using cached astor-0.8.1-py2.py3-none-any.whl (27 kB)


Building wheels for collected packages: tensorflow-docs


  Building wheel for tensorflow-docs (setup.py) ... [?25l-

 \

 done
[?25h  Created wheel for tensorflow-docs: filename=tensorflow_docs-2023.7.13.64986-py3-none-any.whl size=183666 sha256=702b78374f073b14b5182f3f3786dedd66efab19ab3a5f65e244195786bd6839
  Stored in directory: /tmpfs/tmp/pip-ephem-wheel-cache-zw802_7g/wheels/fc/f8/3b/5d21409a59cb1be9b1ade11f682039ced75b84de9dd6a0c8de
Successfully built tensorflow-docs


Installing collected packages: astor, tensorflow-docs


Successfully installed astor-0.8.1 tensorflow-docs-2023.7.13.64986


In [None]:
import glob
import time

import imageio
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import sklearn.preprocessing
import tensorflow as tf
import tensorflow_probability as tfp
import xarray as xr
from IPython import display

## Load the Indonesia dataset
Each example is composed of a a number of input variables and an expected flood map. 

In [None]:
# open the training examples
ds = xr.open_zarr("../prepare/test-training.zarr")
ds = ds.sel(example=np.arange(0, 100))

In [None]:
img_shape = (128, 128)

In [None]:
ranges = {}
for var in ds.variables:
    ranges[var] = (
        ds.variables[var].min().values.item(),
        ds.variables[var].max().values.item(),
    )
ranges

In [None]:
tensors = pd.read_json("../prepare/tensors.json")
inputs = list(tensors.query("role == 'input'")["name"].values)
outputs = list(tensors.query("role == 'output'")["name"].values)
inputs

In [None]:
num_channels = len(inputs)

# Convert the netcdf's to tensors.
Here we convert and scale our data to a data structure that tensorflow understands. See this [tutorial](https://www.noahbrenowitz.com/post/loading_netcdfs/) for a nice introduction 

In [None]:
# tf.convert_to_tensor(ds["a"])
tensor_data = {}
for i, row in tensors.iterrows():
    # floats only
    arr = ds[row["name"]].values.astype("float32")
    # compute to 0-1
    min_i, max_i = ranges[row["name"]]
    arr = (arr - min_i) / (max_i - min_i)
    # store scaled date
    tensor_data[row["name"]] = tf.convert_to_tensor(arr)

In [None]:
arr = ds[row["name"]].values
min_i, max_i = ranges[row["name"]]
arr = (arr - min_i) / (max_i - min_i)
arr

In [None]:
tf_ds = tf.data.Dataset.from_tensor_slices(tensor_data)

## Use *tf.data* to batch and shuffle the data

In [None]:
train_size = 100
batch_size = 16
train_dataset = tf_ds.shuffle(train_size).batch(batch_size)

In [None]:
a = next(iter(train_dataset))
batch_example = train_dataset.take(1)
example = list(batch_example)[0]

In [None]:
def generate_xy(example):
    # example = list(batch_example)[0]
    x = tf.concat([example[x][..., np.newaxis] for x in inputs], axis=-1)
    y = example[outputs[1]][..., np.newaxis]
    return x, y 
xy_dataset = train_dataset.map(generate_xy)

## Define the encoder and decoder networks with *tf.keras.Sequential*

In this VAE example, use two small ConvNets for the encoder and decoder networks. In the literature, these networks are also referred to as inference/recognition and generative models respectively. Use `tf.keras.Sequential` to simplify implementation. Let $x$ and $z$ denote the observation and latent variable respectively in the following descriptions.

### Encoder network
This defines the approximate posterior distribution $q(z|x)$, which takes as input an observation and outputs a set of parameters for specifying the conditional distribution of the latent representation $z$. 
In this example, simply model the distribution as a diagonal Gaussian, and the network outputs the mean and log-variance parameters of a factorized Gaussian. 
Output log-variance instead of the variance directly for numerical stability.

### Decoder network 
This defines the conditional distribution of the observation $p(x|z)$, which takes a latent sample $z$ as input and outputs the parameters for a conditional distribution of the observation.
Model the latent distribution prior $p(z)$ as a unit Gaussian.

### Reparameterization trick
To generate a sample $z$ for the decoder during training, you can sample from the latent distribution defined by the parameters outputted by the encoder, given an input observation $x$.
However, this sampling operation creates a bottleneck because backpropagation cannot flow through a random node.

To address this, use a reparameterization trick.
In our example, you approximate $z$ using the decoder parameters and another parameter $\epsilon$ as follows:

$$z = \mu + \sigma \odot \epsilon$$

where $\mu$ and $\sigma$ represent the mean and standard deviation of a Gaussian distribution respectively. They can be derived from the decoder output. The $\epsilon$ can be thought of as a random noise used to maintain stochasticity of $z$. Generate $\epsilon$ from a standard normal distribution.

The latent variable $z$ is now generated by a function of $\mu$, $\sigma$ and $\epsilon$, which would enable the model to backpropagate gradients in the encoder through $\mu$ and $\sigma$ respectively, while maintaining stochasticity through $\epsilon$.

### Network architecture
For the encoder network, use two convolutional layers followed by a fully-connected layer. In the decoder network, mirror this architecture by using a fully-connected layer followed by three convolution transpose layers (a.k.a. deconvolutional layers in some contexts). Note, it's common practice to avoid using batch normalization when training VAEs, since the additional stochasticity due to using mini-batches may aggravate instability on top of the stochasticity from sampling.


In [27]:
class CVAE(tf.keras.Model):
    """Convolutional variational autoencoder."""

    def __init__(self, img_shape, latent_dim, num_channels, batch_size):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_channels = num_channels
        self.batch_size = batch_size
        self.encoder = tf.keras.Sequential(
            [
                tf.keras.layers.InputLayer(
                    input_shape=(img_shape[0], img_shape[1], num_channels),
                    batch_size=batch_size,
                ),
                tf.keras.layers.Conv2D(
                    padding="same",
                    filters=32, kernel_size=3, strides=(2, 2), activation="relu"
                ),
                # consider max pooling, here?
                tf.keras.layers.Conv2D(
                    padding="same",
                    filters=64, kernel_size=3, strides=(2, 2), activation="relu"
                ),
                tf.keras.layers.Flatten(),
                # No activation
                tf.keras.layers.Dense(latent_dim + latent_dim),
            ]
        )

        self.decoder = tf.keras.Sequential(
            [
                tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
                # why 32
                tf.keras.layers.Dense(
                    units=img_shape[0] // 4 * img_shape[1] // 4 * 32,
                    activation=tf.nn.relu,
                ),
                tf.keras.layers.Reshape(
                    target_shape=(img_shape[0] // 4, img_shape[1] // 4, 32)
                ),
                tf.keras.layers.Conv2DTranspose(
                    filters=64,
                    kernel_size=3,
                    strides=2,
                    padding="same",
                    activation="relu",
                ),
                tf.keras.layers.Conv2DTranspose(
                    filters=32,
                    kernel_size=3,
                    strides=2,
                    padding="same",
                    activation="relu",
                ),
                # No activation
                tf.keras.layers.Conv2DTranspose(
                    filters=1, kernel_size=3, strides=1, padding="same"
                ),
            ]
        )

    def reparameterize(self, mean, logvar):
        eps = tf.random.normal(shape=(self.batch_size, self.latent_dim))
        return eps * tf.exp(logvar * .5) + mean

    def split(self, encoded):
        mean, logvar = tf.split(encoded, num_or_size_splits=2, axis=1)
        return mean, logvar  

    
    def call(self, x):
        encoded = self.encoder(x)
        # reparameterization trick
        mean, logvar = self.split(encoded)
        # random normal distribution + whatever....
        parameterized = self.reparameterize(mean, logvar)
        
        decoded = self.decoder(parameterized)
        return decoded

## Define the loss function and the optimizer

VAEs train by maximizing the evidence lower bound (ELBO) on the marginal log-likelihood:

$$\log p(x) \ge \text{ELBO} = \mathbb{E}_{q(z|x)}\left[\log \frac{p(x, z)}{q(z|x)}\right].$$

In practice, optimize the single sample Monte Carlo estimate of this expectation:

$$\log p(x| z) + \log p(z) - \log q(z|x),$$
where $z$ is sampled from $q(z|x)$.

Note: You could also analytically compute the KL term, but here you incorporate all three terms in the Monte Carlo estimator for simplicity.

In [28]:
model = CVAE(img_shape=img_shape, latent_dim=100, num_channels=num_channels, batch_size=batch_size)
model.compile(optimizer="adam", loss="mean_squared_error")
input_shape = tf.TensorShape([None, img_shape[0], img_shape[1], num_channels])
model.build(input_shape=input_shape)

In [29]:
model.fit(xy_dataset)

2024-02-14 16:55:50.238729: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.




InvalidArgumentError: Graph execution error:

Detected at node 'cvae_2/mul_1' defined at (most recent call last):
    File "<frozen runpy>", line 198, in _run_module_as_main
    File "<frozen runpy>", line 88, in _run_code
    File "/Users/baart_f/.virtualenvs/py311/lib/python3.11/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/Users/baart_f/.virtualenvs/py311/lib/python3.11/site-packages/traitlets/config/application.py", line 1043, in launch_instance
      app.start()
    File "/Users/baart_f/.virtualenvs/py311/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 725, in start
      self.io_loop.start()
    File "/Users/baart_f/.virtualenvs/py311/lib/python3.11/site-packages/tornado/platform/asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "/opt/local/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/asyncio/base_events.py", line 607, in run_forever
      self._run_once()
    File "/opt/local/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/asyncio/base_events.py", line 1922, in _run_once
      handle._run()
    File "/opt/local/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/asyncio/events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "/Users/baart_f/.virtualenvs/py311/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 513, in dispatch_queue
      await self.process_one()
    File "/Users/baart_f/.virtualenvs/py311/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 502, in process_one
      await dispatch(*args)
    File "/Users/baart_f/.virtualenvs/py311/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 409, in dispatch_shell
      await result
    File "/Users/baart_f/.virtualenvs/py311/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 729, in execute_request
      reply_content = await reply_content
    File "/Users/baart_f/.virtualenvs/py311/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 422, in do_execute
      res = shell.run_cell(
    File "/Users/baart_f/.virtualenvs/py311/lib/python3.11/site-packages/ipykernel/zmqshell.py", line 540, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/Users/baart_f/.virtualenvs/py311/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3024, in run_cell
      result = self._run_cell(
    File "/Users/baart_f/.virtualenvs/py311/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3079, in _run_cell
      result = runner(coro)
    File "/Users/baart_f/.virtualenvs/py311/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/Users/baart_f/.virtualenvs/py311/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3284, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/Users/baart_f/.virtualenvs/py311/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3466, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/Users/baart_f/.virtualenvs/py311/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3526, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/var/folders/fh/tcbrjg6n28b0lzjzh07b5t6m0000gn/T/ipykernel_6398/3302383472.py", line 1, in <module>
      model.fit(xy_dataset)
    File "/Users/baart_f/.virtualenvs/py311/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/Users/baart_f/.virtualenvs/py311/lib/python3.11/site-packages/keras/src/engine/training.py", line 1742, in fit
      tmp_logs = self.train_function(iterator)
    File "/Users/baart_f/.virtualenvs/py311/lib/python3.11/site-packages/keras/src/engine/training.py", line 1338, in train_function
      return step_function(self, iterator)
    File "/Users/baart_f/.virtualenvs/py311/lib/python3.11/site-packages/keras/src/engine/training.py", line 1322, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/Users/baart_f/.virtualenvs/py311/lib/python3.11/site-packages/keras/src/engine/training.py", line 1303, in run_step
      outputs = model.train_step(data)
    File "/Users/baart_f/.virtualenvs/py311/lib/python3.11/site-packages/keras/src/engine/training.py", line 1080, in train_step
      y_pred = self(x, training=True)
    File "/Users/baart_f/.virtualenvs/py311/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/Users/baart_f/.virtualenvs/py311/lib/python3.11/site-packages/keras/src/engine/training.py", line 569, in __call__
      return super().__call__(*args, **kwargs)
    File "/Users/baart_f/.virtualenvs/py311/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/Users/baart_f/.virtualenvs/py311/lib/python3.11/site-packages/keras/src/engine/base_layer.py", line 1150, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/Users/baart_f/.virtualenvs/py311/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "/var/folders/fh/tcbrjg6n28b0lzjzh07b5t6m0000gn/T/ipykernel_6398/251923084.py", line 76, in call
      parameterized = self.reparameterize(mean, logvar)
    File "/var/folders/fh/tcbrjg6n28b0lzjzh07b5t6m0000gn/T/ipykernel_6398/251923084.py", line 64, in reparameterize
      return eps * tf.exp(logvar * .5) + mean
Node: 'cvae_2/mul_1'
Incompatible shapes: [4,100] vs. [16,100]
	 [[{{node cvae_2/mul_1}}]] [Op:__inference_train_function_3340]

In [30]:
model.encoder.summary()
model.decoder.summary()
model.summary()


Model: "sequential_4"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_4 (Conv2D)           (16, 64, 64, 32)          3200      
                                                                 
 conv2d_5 (Conv2D)           (16, 32, 32, 64)          18496     
                                                                 
 flatten_2 (Flatten)         (16, 65536)               0         
                                                                 
 dense_4 (Dense)             (16, 200)                 13107400  
                                                                 
Total params: 13129096 (50.08 MB)
Trainable params: 13129096 (50.08 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
Model: "sequential_5"
_________________________________________________________________
 Layer (type)                Output Shape              P

In [None]:
def generate_and_save_images(model, epoch, test_sample):
    mean, logvar = model.encode(test_sample)
    z = model.reparameterize(mean, logvar)
    predictions = model.sample(z)
    fig = plt.figure(figsize=(4, 4))

    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        plt.imshow(predictions[i, :, :, 0], cmap="gray")
        plt.axis("off")

    # tight_layout minimizes the overlap between 2 sub-plots
    plt.savefig("image_at_epoch_{:04d}.png".format(epoch))
    plt.show()

In [None]:
# Pick a sample of the test set for generating output images
assert batch_size >= num_examples_to_generate
for test_batch in train_dataset.take(1):
    test_sample = test_batch[0:num_examples_to_generate, :, :, :]

In [None]:
test_batch

In [None]:
generate_and_save_images(model, 0, test_sample)
losses = []
for epoch in range(1, epochs + 1):
    start_time = time.time()
    for train_x in train_dataset:
        train_step(model, train_x, optimizer)
    end_time = time.time()

    loss = tf.keras.metrics.Mean()
    for test_x in test_dataset:
        loss(compute_loss(model, test_x))
    elbo = -loss.result()
    losses.append(elbo)
    display.clear_output(wait=False)
    print(
        "Epoch: {}, Test set ELBO: {}, time elapse for current epoch: {}".format(
            epoch, elbo, end_time - start_time
        )
    )
    generate_and_save_images(model, epoch, test_sample)

### Display a generated image from the last training epoch

In [None]:
def display_image(epoch_no):
    return PIL.Image.open("image_at_epoch_{:04d}.png".format(epoch_no))

In [None]:
plt.imshow(display_image(epoch))
plt.axis("off")  # Display images

### Display an animated GIF of all the saved images

In [None]:
anim_file = "cvae.gif"

with imageio.get_writer(anim_file, mode="I") as writer:
    filenames = glob.glob("image*.png")
    filenames = sorted(filenames)
    for filename in filenames:
        image = imageio.imread(filename)
        writer.append_data(image)
    image = imageio.imread(filename)
    writer.append_data(image)

In [None]:
import tensorflow_docs.vis.embed as embed

embed.embed_file(anim_file)

### Display a 2D manifold of digits from the latent space

Running the code below will show a continuous distribution of the different digit classes, with each digit morphing into another across the 2D latent space. Use [TensorFlow Probability](https://www.tensorflow.org/probability) to generate a standard normal distribution for the latent space.

In [None]:
def plot_latent_images(model, n, digit_size=28):
    """Plots n x n digit images decoded from the latent space."""

    norm = tfp.distributions.Normal(0, 1)
    grid_x = norm.quantile(np.linspace(0.05, 0.95, n))
    grid_y = norm.quantile(np.linspace(0.05, 0.95, n))
    image_width = digit_size * n
    image_height = image_width
    image = np.zeros((image_height, image_width))

    for i, yi in enumerate(grid_x):
        for j, xi in enumerate(grid_y):
            z = np.array([[xi, yi]])
            x_decoded = model.sample(z)
            digit = tf.reshape(x_decoded[0], (digit_size, digit_size))
            image[
                i * digit_size : (i + 1) * digit_size,
                j * digit_size : (j + 1) * digit_size,
            ] = digit.numpy()

    plt.figure(figsize=(10, 10))
    plt.imshow(image, cmap="Greys_r")
    plt.axis("Off")
    plt.show()

In [None]:
plot_latent_images(model, 20)

## Next steps

This tutorial has demonstrated how to implement a convolutional variational autoencoder using TensorFlow. 

As a next step, you could try to improve the model output by increasing the network size. 
For instance, you could try setting the `filter` parameters for each of the `Conv2D` and `Conv2DTranspose` layers to 512. 
Note that in order to generate the final 2D latent image plot, you would need to keep `latent_dim` to 2. Also, the training time would increase as the network size increases.

You could also try implementing a VAE using a different dataset, such as CIFAR-10.

VAEs can be implemented in several different styles and of varying complexity. You can find additional implementations in the following sources:
- [Variational AutoEncoder (keras.io)](https://keras.io/examples/generative/vae/)
- [VAE example from "Writing custom layers and models" guide (tensorflow.org)](https://www.tensorflow.org/guide/keras/custom_layers_and_models#putting_it_all_together_an_end-to-end_example)
- [TFP Probabilistic Layers: Variational Auto Encoder](https://www.tensorflow.org/probability/examples/Probabilistic_Layers_VAE)

If you'd like to learn more about the details of VAEs, please refer to [An Introduction to Variational Autoencoders](https://arxiv.org/abs/1906.02691).