# Variational autoencoder with a learning rate schedule

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/tutorials/vae_tutorial.ipynb)

This tutorial demonstrates how to train a simple variational autoencoder (VAE) end-to-end with learning rate scheduling using Flax and [Optax](https://optax.readthedocs.io/).

For the optimizer schedule, you will use the linear warmup followed by cosine decay ([`optax.warmup_cosine_decay_schedule`](https://optax.readthedocs.io/en/latest/api.html#optax.warmup_cosine_decay_schedule)).

The tutorial uses a lot of fundamental concepts covered in [Getting started](https://flax.readthedocs.io/en/latest/getting_started.html). If you're new to Flax, start there.

## Setup

- Install/upgrade Flax, which will also set up [Optax](https://optax.readthedocs.io/) (for common optimizers and loss functions), and JAX.
- Install [TensorFlow Datasets](https://www.tensorflow.org/datasets) to load a dataset for this tutorial.
- Import the necessary libraries.

In [1]:
!pip install --upgrade -q flax tensorflow_datasets 

[K     |████████████████████████████████| 189 kB 10.1 MB/s 
[K     |████████████████████████████████| 4.7 MB 56.2 MB/s 
[K     |████████████████████████████████| 8.3 MB 16.8 MB/s 
[K     |████████████████████████████████| 237 kB 54.8 MB/s 
[K     |████████████████████████████████| 154 kB 63.4 MB/s 
[K     |████████████████████████████████| 51 kB 4.6 MB/s 
[K     |████████████████████████████████| 85 kB 2.6 MB/s 
[?25h

In [2]:
import jax
import jax.numpy as jnp                # JAX NumPy

from flax import linen as nn           # The Flax Linen API
from flax.training import train_state  # A Flax dataclass to keep the train state

import numpy as np                     # Ordinary NumPy
import optax                           # The Optax library
import tensorflow as tf                # TensorFlow for preprocessing operations (`tf.cast`, `tf.reshape`)
import tensorflow_datasets as tfds     # TFDS for the dataset

Create a simple [variational autoencoder](https://arxiv.org/abs/1312.6114) model, subclassed from [Flax `Module`](https://flax.readthedocs.io/en/latest/guides/flax_basics.html#module-basics).

In [3]:
class Encoder(nn.Module):
  latents: int

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(500, name='fc1')(x)
    x = nn.relu(x)
    mean_x = nn.Dense(self.latents, name='fc2_mean')(x)
    logvar_x = nn.Dense(self.latents, name='fc2_logvar')(x)
    return mean_x, logvar_x

In [4]:
class Decoder(nn.Module):

  @nn.compact
  def __call__(self, z):
    z = nn.Dense(500, name='fc1')(z)
    z = nn.relu(z)
    z = nn.Dense(784, name='fc2')(z)
    return z

In [5]:
def reparameterize(rng, mean, logvar):
  std = jnp.exp(0.5 * logvar)
  eps = jax.random.normal(rng, logvar.shape)
  return mean + eps * std

In [6]:
class VAE(nn.Module):
  latents: int = 20

  def setup(self):
    self.encoder = Encoder(self.latents)
    self.decoder = Decoder()

  def __call__(self, x, z_rng):
    mean, logvar = self.encoder(x)
    z = reparameterize(z_rng, mean, logvar)
    recon_x = self.decoder(z)
    return recon_x, mean, logvar

  def generate(self, z):
    return nn.sigmoid(self.decoder(z))

In [7]:
def model():
  return VAE(latents=latents)

Define the binary cross-entropy loss function.

In addition to optimizers, Optax provides a number of common loss functions, including [`optax.sigmoid_binary_cross_entropy`](https://optax.readthedocs.io/en/latest/api.html#optax.sigmoid_binary_cross_entropy).

In [8]:
@jax.vmap
def binary_cross_entropy_with_logits(logits, labels):
 logits = nn.log_sigmoid(logits)
 return -jnp.sum(labels * logits + (1. - labels) * jnp.log(-jnp.expm1(logits)))

Define the KL divergence:

In [9]:
@jax.vmap
def kl_divergence(mean, logvar):
  return -0.5 * jnp.sum(1 + logvar - jnp.square(mean) - jnp.exp(logvar))

Create a function for the loss and accuracy metrics:

In [10]:
def compute_metrics(recon_x, x, mean, logvar):
  bce_loss = binary_cross_entropy_with_logits(recon_x, x).mean()
  kld_loss = kl_divergence(mean, logvar).mean()
  return {
      'bce': bce_loss,
      'kld': kld_loss,
      'loss': bce_loss + kld_loss
  }

Define the training step function. Note that:

- During the forward pass with [`flax.linen.apply()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#init-apply).
- The model constructor argument should be set to `training=True`.

In [11]:
@jax.jit
def train_step(state, batch, z_rng):
  """Train for a single step."""
  def loss_fn(params):
    # Perform the forward pass with `flax.linen.apply()`.
    recon_x, mean, logvar = model().apply({'params': params}, batch, z_rng)
    # Calculate the binary cross-entropy loss.
    bce_loss = binary_cross_entropy_with_logits(recon_x, batch).mean()
    # Calculate the KL divergence loss.
    kld_loss = kl_divergence(mean, logvar).mean()
    # Calculate the total loss.
    loss = bce_loss + kld_loss
    return loss
  # Compute the gradients
  grads = jax.grad(loss_fn)(state.params)
  return state.apply_gradients(grads=grads)

Write the evaluation step function. Remember to set the model constructor argument to `training=false`.

In [12]:
@jax.jit
def eval_step(params, images, z, z_rng):
  def eval_model(vae):
    recon_images, mean, logvar = vae(images, z_rng)
    comparison = jnp.concatenate([images[:8].reshape(-1, 28, 28, 1),
                                  recon_images[:8].reshape(-1, 28, 28, 1)])

    generate_images = vae.generate(z)
    generate_images = generate_images.reshape(-1, 28, 28, 1)
    metrics = compute_metrics(recon_images, images, mean, logvar)
    return metrics, comparison, generate_images

  return nn.apply(eval_model, model())({'params': params})

Download the dataset and split it into training and test sets:

In [13]:
def prepare_image(x):
  x = tf.cast(x['image'], tf.float32)
  x = tf.reshape(x, (-1,))
  return x

# Write a function for loading your dataset with [TensorFlow Datasets](https://www.tensorflow.org/datasets):
def get_datasets():
  ds_builder = tfds.builder('binarized_mnist')
  ds_builder.download_and_prepare()
  train_ds = ds_builder.as_dataset(split=tfds.Split.TRAIN)
  train_ds = train_ds.map(prepare_image)
  train_ds = train_ds.cache()
  train_ds = train_ds.repeat()
  train_ds = train_ds.shuffle(50000)
  train_ds = train_ds.batch(batch_size)
  train_ds = iter(tfds.as_numpy(train_ds))

  test_ds = ds_builder.as_dataset(split=tfds.Split.TEST)
  test_ds = test_ds.map(prepare_image).batch(10000)
  test_ds = np.array(list(test_ds)[0])
  test_ds = jax.device_put(test_ds)

  return train_ds, test_ds

In [14]:
batch_size = 128

train_ds, test_ds = get_datasets()

Downloading and preparing dataset 104.68 MiB (download: 104.68 MiB, generated: Unknown size, total: 104.68 MiB) to /root/tensorflow_datasets/binarized_mnist/1.0.0...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]

Generating train examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/binarized_mnist/1.0.0.incompleteBBC09W/binarized_mnist-train.tfrecord*...:…

Generating validation examples...:   0%|          | 0/10000 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/binarized_mnist/1.0.0.incompleteBBC09W/binarized_mnist-validation.tfrecord…

Generating test examples...:   0%|          | 0/10000 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/binarized_mnist/1.0.0.incompleteBBC09W/binarized_mnist-test.tfrecord*...: …

Dataset binarized_mnist downloaded and prepared to /root/tensorflow_datasets/binarized_mnist/1.0.0. Subsequent calls will reuse this data.


Use a JAX PRNG key and split it to get one key for parameter initialization:

In [15]:
seed = 0
rng = jax.random.PRNGKey(seed=seed)
rng, key = jax.random.split(key=rng)

In [16]:
latents = 20

rng, z_key, eval_rng = jax.random.split(rng, 3)
z = jax.random.normal(z_key, (64, latents))

Create and initialize the Flax [`TrainState`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#train-state) with an [Optax](https://optax.readthedocs.io/) optimizer. Remember that:

- When initializing the variables, use the `params_key` PRNG key (the `params_key` is equivalent to a dictionary of PRNGs).
- The model constructor is `training=False` before you start training.

This example uses the Yogi optimizer ([`optax.adabelief`](https://optax.readthedocs.io/en/latest/api.html#yogi)).

In [17]:
def create_train_state(rng, learning_rate_fn):
  # Instantiate the model with `training=False`.
  init_data = jnp.ones((batch_size, 784), jnp.float32)
  params = model().init(key, init_data, rng)['params']
  # Use an Optax optimizer.
  # The `learning_rate_fn` is an Optax learning rate schedule (defined further below).
  tx = optax.yogi(learning_rate_fn)
  return train_state.TrainState.create(
      apply_fn=model().apply, params=params, tx=tx)

Define a learning rate schedule that uses [`optax.linear_schedule`](https://optax.readthedocs.io/en/latest/api.html#optax.linear_schedule) and [`optax.cosine_decay_schedule`](https://optax.readthedocs.io/en/latest/api.html#optax.cosine_decay_schedule) (cosine learning rate decay).

Note: You can learn more about Optax, its optimizers, loss functions and schedules in the [Optax tutorial](https://optax.readthedocs.io/en/latest/optax-101.html).

In [18]:
def create_learning_rate_fn(base_learning_rate, warmup_steps, steps_per_epoch):
  warmup_fn = optax.linear_schedule(
      init_value=0.0,
      end_value=base_learning_rate,
      transition_steps=warmup_steps * steps_per_epoch,
  )
  cosine_epochs = max(num_epochs - warmup_steps, 1)
  cosine_fn = optax.cosine_decay_schedule(
      init_value=base_learning_rate,
      decay_steps=cosine_epochs * steps_per_epoch
  )
  schedule_fn = optax.join_schedules(
      schedules=[warmup_fn, cosine_fn],
      boundaries=[warmup_steps, steps_per_epoch],
  )

  return schedule_fn

Instantiate the learning rate schedule:

In [19]:
num_epochs = 10 # For simplicity, train for 10 epochs.
base_learning_rate = 0.0001
warmup_steps = 1.0
steps_per_epoch = 10 

learning_rate_fn = create_learning_rate_fn(
    base_learning_rate=base_learning_rate,
    warmup_steps=warmup_steps,
    steps_per_epoch=steps_per_epoch
    )

Initialize the Flax `TrainState`, passing in the learning rate schedule for the optimizer:

In [20]:
state = create_train_state(rng=key, learning_rate_fn=learning_rate_fn)

Train the model for 10 epochs:

In [21]:
for epoch in range(num_epochs):
  for _ in range(steps_per_epoch):
    batch = next(train_ds)  
    # Use a separate PRNG key to permute image data during shuffling.
    rng, key = jax.random.split(key=key)
      # Run an optimization step over a training batch.
    state = train_step(state=state, batch=batch, z_rng=key)

  metrics, comparison, sample = eval_step(params=state.params, images=test_ds, z=z, z_rng=eval_rng)
  print('Eval epoch: %d, loss: %.2f, binary cross-entropy: %.2f, KL divergence: %.2f' % (
      epoch+1, metrics['loss'], metrics['bce'], metrics['kld']))

Eval epoch: 1, loss: 542.03, binary cross-entropy: 541.22, KL divergence: 0.81
Eval epoch: 2, loss: 496.17, binary cross-entropy: 494.80, KL divergence: 1.37
Eval epoch: 3, loss: 457.73, binary cross-entropy: 455.39, KL divergence: 2.34
Eval epoch: 4, loss: 422.83, binary cross-entropy: 418.90, KL divergence: 3.93
Eval epoch: 5, loss: 394.51, binary cross-entropy: 388.54, KL divergence: 5.97
Eval epoch: 6, loss: 375.41, binary cross-entropy: 367.59, KL divergence: 7.83
Eval epoch: 7, loss: 365.03, binary cross-entropy: 356.02, KL divergence: 9.01
Eval epoch: 8, loss: 360.92, binary cross-entropy: 351.43, KL divergence: 9.49
Eval epoch: 9, loss: 360.18, binary cross-entropy: 350.60, KL divergence: 9.58
Eval epoch: 10, loss: 360.18, binary cross-entropy: 350.60, KL divergence: 9.58
