Click to open in Colab to access a GPU environment: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/charlesollion/dlexperiments/blob/master/Going%20Further%20with%20VAEs.ipynb)

# Going further with VAEs

Building on the simple VAE example, this notebook is organized as follows:
1. Convolutional VAE: a stronger architecture, with a few computational tricks to make it work
2. [$\beta$-VAE](https://openreview.net/forum?id=Sy2fzU9gl), increasing the importance of the KL term
3. Better parametrization of the output distribution
4. Conditional VAEs, typically conditioning the modeled density on a class
5. Exploring [Importance weighted Autoencoders](https://arxiv.org/abs/1509.00519) (IWAE)
6. Final thoughts: can VAEs efficiently model simple toy datasets (distributions)?

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from tensorflow.keras.datasets import fashion_mnist

# Load Fashion-MNIST and display the first 18 training images.
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
plt.figure(figsize=(16, 8))
for i in range(0, 18):
    plt.subplot(3, 6, i + 1)
    plt.imshow(x_train[i], cmap="gray")
    plt.axis("off")
plt.show()

# Convert to float32, divide by 255 and append a trailing channel axis,
# giving arrays of shape (N, 28, 28, 1) as expected by the conv encoder below.
x_train = np.expand_dims(x_train.astype('float32') / 255., -1)
x_test = np.expand_dims(x_test.astype('float32') / 255., -1)
x_train.shape, x_test.shape

In [None]:
import jax
import jax.numpy as jnp
from jax import jit, grad
from jax import random
rand_key = random.PRNGKey(1)
# stax (neural network library) lives in jax.example_libraries in newer JAX
# releases; fall back to the old jax.experimental location for older ones.
try:
    from jax.example_libraries import stax
    from jax.example_libraries.stax import Dense, Relu, Sigmoid, Selu
except ImportError:
    from jax.experimental import stax
    from jax.experimental.stax import Dense, Relu, Sigmoid, Selu


## 1. Convolutional VAE

A convolutional VAE makes use of the spatial structure of the images and relies on convolutions instead of Dense layers, in both the encoder and the decoder.



In [None]:
# Convolutional building blocks; newer JAX keeps stax in jax.example_libraries.
try:
    from jax.example_libraries.stax import Conv, BatchNorm, MaxPool, Flatten
except ImportError:  # older JAX releases
    from jax.experimental.stax import Conv, BatchNorm, MaxPool, Flatten

input_dim = x_train.shape[-1]
hidden_dim = 128
latent_dim = 8

# Encoder: two conv stages, each followed by a 2x2 max-pool (28x28 -> 14x14 -> 7x7),
# then a Dense head producing 2 * latent_dim values; later cells split them into
# z_mean (first half) and z_log_var (second half).
encoder_init, encoder_fn = stax.serial(
    Conv(32, (3, 3), padding="SAME"), BatchNorm(), Selu,
    Conv(32, (3, 3), padding="SAME"), BatchNorm(), Selu,
    MaxPool((2, 2), (2, 2)),
    Conv(64, (3, 3), padding="SAME"), BatchNorm(), Selu,
    Conv(64, (3, 3), padding="SAME"), BatchNorm(), Selu,
    MaxPool((2, 2), (2, 2)),
    Flatten, Dense(latent_dim * 2)
)

# Initialize the parameters with the freshly split `key`, so `rand_key` (which
# later cells split again) was never consumed by the initializer.
rand_key, key = random.split(rand_key)
out_shape, params_enc = encoder_init(key, (-1, 28, 28, 1))

def count_params(params):
    """Return the total number of scalar parameters in a stax parameter list.

    ``params`` is what a stax ``init_fun`` returns: one tuple of arrays per
    layer. Parameter-free layers contribute an empty tuple.
    """
    return sum(np.prod(array.shape) for layer in params for array in layer)

# Number of encoder layers. Later cells concatenate encoder and decoder params
# and use params[0:params_num] / params[params_num:] to split them back apart.
params_num = len(params_enc)
print(f"Number of param objects: {params_num}, total number of params: {count_params(params_enc)}")

In [None]:
# Time the encoder on 10 images, once through jit and once eagerly
# (the jitted timing presumably includes the one-off compilation cost).
%time z = jit(encoder_fn)(params_enc, x_train[0:10])
%time z = encoder_fn(params_enc, x_train[0:10])
print(f"output shape: {z.shape}")

In [None]:
def sample(rand_key, z_mean, z_log_var):
    """Draw z ~ N(z_mean, exp(z_log_var)) with the reparameterization trick.

    Standard-normal noise of shape ``z_mean.shape`` is drawn from ``rand_key``
    and scaled by the standard deviation exp(z_log_var / 2).
    """
    std = jnp.exp(0.5 * z_log_var)
    noise = random.normal(rand_key, shape=z_mean.shape)
    return noise * std + z_mean


# jit-compiled variant of `sample`
fast_sample = jit(sample)

In [None]:
# Encode 3 images, split the encoder output into its (z_mean, z_log_var)
# halves and draw one latent sample per image with a fresh key.
z = encoder_fn(params_enc, x_train[0:3])
d = z.shape[-1]//2
z_mean, z_log_var = z[:, :d], z[:,d:]
rand_key, key = random.split(rand_key)
samples = sample(key, z_mean, z_log_var)
print(f"z shape (concatenation of z_mean and z_log_var) : {z.shape}, samples shape: {samples.shape}")

In [None]:
# We need an Upsample and a Reshape layer; we can build them using the stax layer API.
# A layer is a tuple (init_fun, apply_fun):
# init_fun creates the parameter arrays and returns (output_shape, params)
# apply_fun applies the layer's mathematical operation to the inputs

def Upsample(scale=2):
    """Stax layer that nearest-neighbour upsamples 4-D (N, H, W, C) inputs.

    Every spatial position is repeated ``scale`` times along height and width,
    so (N, H, W, C) becomes (N, scale*H, scale*W, C). The layer has no
    parameters.

    Args:
        scale: positive integer upsampling factor. The default of 2 is what
            the decoder in this notebook uses.

    Returns:
        A stax ``(init_fun, apply_fun)`` pair.

    Raises:
        ValueError: if ``scale`` is not a positive integer, or (from init_fun)
            if the input shape is not 4-D.
    """
    if not isinstance(scale, int) or scale < 1:
        raise ValueError(f"scale must be a positive integer, got {scale!r}")

    def init_fun(rng, input_shape):
        if len(input_shape) != 4:
            raise ValueError(f"Upsample expects a 4-D input shape, got {input_shape}")
        n, h, w, c = input_shape
        return (n, h * scale, w * scale, c), ()

    def apply_fun(params, inputs, **kwargs):
        # Repeat rows, then columns. This gives the same values as the former
        # transpose + kron(ones) trick, without the extra multiplications.
        return jnp.repeat(jnp.repeat(inputs, scale, axis=1), scale, axis=2)

    return init_fun, apply_fun

def Reshape(shape):
    """Stax layer that replaces the last axis of its input with ``shape``.

    Leading axes are kept unchanged, e.g. (N, 7*7*64) -> (N, 7, 7, 64) for
    ``shape=(7, 7, 64)``. The layer has no parameters.
    """
    def swap_last_axis(dims):
        # keep every leading dimension and substitute the target shape for the final one
        return dims[:-1] + shape

    def init_fun(rng, input_shape):
        return swap_last_axis(input_shape), ()

    def apply_fun(params, inputs, **kwargs):
        return inputs.reshape(swap_last_axis(inputs.shape))

    return (init_fun, apply_fun)

In [None]:
# Sanity check: Reshape should turn a (2, 7*7*64) input into (2, 7, 7, 64).
Reshape((7, 7, 64))[1]((), jnp.ones([2,7*7*64])).shape

In [None]:
# Visual check of Upsample: the first training image, upsampled by the layer.
plt.imshow(Upsample()[1]([], x_train[0:3])[0,:,:,0], cmap="gray");

In [None]:
# Decoder: Dense -> (7, 7, 64) feature map, two Upsample stages back to 28x28,
# then a single-channel conv followed by Sigmoid, so each output lies in (0, 1).
# These outputs are used as lambda parameters in the cells below.
decoder_init, decoder_fn = stax.serial(
    Dense(64*7*7), Reshape((7,7,64)),
    Upsample(), 
    Conv(32, (3, 3), padding="SAME"), BatchNorm(), Selu, 
    Conv(32, (3, 3), padding="SAME"), BatchNorm(), Selu, 
    Upsample(),
    Conv(64, (3, 3), padding="SAME"), BatchNorm(), Selu, 
    Conv(64, (3, 3), padding="SAME"), BatchNorm(), Selu, 
    Conv(1, (3, 3), padding="SAME"), Sigmoid
)

#initialize the parameters
# NOTE(review): this initializes with `rand_key`, and the freshly split `key`
# is only consumed later (a following cell splits `key`); kept as-is so that
# later key usage does not collide with this initialization.
rand_key, key = random.split(rand_key)
out_shape, params_dec = decoder_init(rand_key, (-1, latent_dim))

# Concatenated parameter list: encoder layers first, then decoder layers.
params = params_enc + params_dec

In [None]:
print(f"Number of params: {count_params(params_dec)}")
# Decode the latent samples drawn earlier from 3 images
# (presumably giving (3, 28, 28, 1)).
decoder_fn(params_dec, samples).shape

## 2. Bernoulli Output distribution

Previously, we used a Bernoulli distribution and the associated cross-entropy loss, which assumes that the data takes values in $\{0, 1\}$.

This is very common, and the problem is pointed out by [this paper](https://arxiv.org/pdf/1907.06845.pdf), which proposes a different parametrization of the output instead of a Bernoulli likelihood. Since the pixel data is $[0, 1]$-valued, the authors introduce the continuous Bernoulli distribution, which is a normalized Bernoulli:
$$p(x|λ) = C(λ) λ^x (1 − λ)^{1−x}$$

See the paper for details about $C(λ)$: for $λ \neq 0.5$, $C(λ) = \frac{2\tanh^{-1}(1 - 2λ)}{1 - 2λ}$, and $C(0.5) = 2$. The following cell computes the log of this normalization factor:

In [None]:
EPSILON = 1e-6

@jit
def cont_bern_log_norm(l):
    """Elementwise log of the continuous-Bernoulli normalizing constant C(l).

    C(l) = 2 * artanh(1 - 2l) / (1 - 2l). Numerator and denominator share the
    sign of (1 - 2l), so the ratio is computed from absolute values. EPSILON
    keeps artanh away from +/-1 (l = 0 or 1) and the denominator away from 0.

    Near l = 0.5 the ratio is a 0/0 form, so every l in [0.49, 0.51] is
    replaced by 0.49. The result there is approximately log 2, with zero
    gradient with respect to l.
    """
    lower_lim, upper_lim = 0.49, 0.51
    outside_band = jnp.logical_or(l < lower_lim, l > upper_lim)
    safe_l = jnp.where(outside_band, l, lower_lim * jnp.ones_like(l))
    dist = jnp.abs(1 - 2.0 * safe_l)
    log_numerator = jnp.log(jnp.abs(2.0 * jnp.arctanh(dist - EPSILON)))
    log_denominator = jnp.log(dist + EPSILON)
    return log_numerator - log_denominator

@jit
def xent(x, xt):
    """Negative continuous-Bernoulli log-likelihood of targets ``xt``.

    ``x`` holds the decoder's lambda parameters. The per-element log-likelihood
    (including the log normalizer) is summed over axes 1, 2 and 3, giving one
    value per leading (batch) index.
    """
    log_lik = (xt * jnp.log(x + EPSILON)
               + (1 - xt) * jnp.log(1 - x + EPSILON)
               + cont_bern_log_norm(x))
    return -jnp.sum(log_lik, axis=(1, 2, 3))


@jit
def kl(z_mean, z_log_var):
    """KL(N(z_mean, exp(z_log_var)) || N(0, I)), summed over the last axis."""
    per_dim = 1 + z_log_var - z_mean ** 2 - jnp.exp(z_log_var)
    return -0.5 * jnp.sum(per_dim, axis=(-1))

In [None]:
# the normalizing factor can be seen below as a function of lambda
# (log C(lambda); the flat segment around 0.5 comes from the clamping in cont_bern_log_norm)
plt.plot(jnp.linspace(0,1,100), cont_bern_log_norm(jnp.linspace(0,1,100)));

With the standard Bernoulli, the output of our decoder, i.e. the parameters $\lambda_i$ of the Bernoulli distribution, was equal to the expected value of the output:

$$X \sim \mathcal{B}(\lambda) \implies \mathbb{E}[X] = \lambda$$ for each coordinate of $X$ and $\lambda$.

This no longer holds for the continuous Bernoulli: to obtain the expected value of the output (e.g. to display a decoded image), we have to compute the following:

$$X \sim \mathcal{CB}(\lambda) \implies \mathbb{E}[X] = \frac{\lambda}{2\lambda - 1} + \frac{1}{2\tanh^{-1}(1-2\lambda)}$$

if $\lambda \neq 0.5$, otherwise $\mathbb{E}[X] = 0.5$

In [None]:
def mean_from_lambda(l, lower_lim=0.49, upper_lim=0.51):
    """Elementwise expected value of a continuous Bernoulli CB(l).

    For l != 0.5:  E[X] = l / (2l - 1) + 1 / (2 artanh(1 - 2l)).
    Near l = 0.5 this is a difference of two large, nearly equal terms with a
    0/0 limit. Inside [lower_lim, upper_lim] the first-order expansion
    E[X] ~= 0.5 + (2l - 1) / 6 is used instead. It is exactly 0.5 at l = 0.5
    and matches the closed form at the band edges; the neglected term is
    O((1 - 2l)^3).

    Args:
        l: array of lambda parameters in [0, 1].
        lower_lim, upper_lim: edges of the band around 0.5 where the Taylor
            expansion replaces the closed form.

    Returns:
        Array of the same shape as ``l`` holding E[X].
    """
    outside = jnp.logical_or(l < lower_lim, l > upper_lim)
    # Feed a harmless value to the closed form inside the band so the
    # unstable 0/0 region is never evaluated (nor differentiated through).
    safe_l = jnp.where(outside, l, lower_lim * jnp.ones_like(l))
    closed_form = safe_l / (2.0 * safe_l - 1.0) + 1.0 / (2.0 * jnp.arctanh(1.0 - 2.0 * safe_l))
    # Returning closed_form(lower_lim) for the whole band would give a constant
    # ~0.4967 (not 0.5) and a jump at upper_lim; the expansion avoids both.
    taylor = 0.5 + (2.0 * l - 1.0) / 6.0
    return jnp.where(outside, closed_form, taylor)

In [None]:
# the expected output as a function of lambda
# (solid: E[X] under the continuous Bernoulli; dashed: identity, i.e. the standard Bernoulli mean)
plt.plot(jnp.linspace(0,1,100), mean_from_lambda(jnp.linspace(0,1,100)))
plt.plot(jnp.linspace(0,1,100), jnp.linspace(0,1,100), ls="dashed");

In [None]:
# Decoder function class prior on image generation, before training
# Left: raw decoder output (lambda); right: the corresponding continuous-Bernoulli mean.
# NOTE(review): this splits `key` (left unused by the decoder-initialization
# cell) rather than `rand_key`.
rand_key, key = random.split(key)
z = random.normal(key, shape=(1,latent_dim))
generated = decoder_fn(params_dec, z)
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.imshow(generated.reshape(28, 28), cmap=plt.cm.gray)
plt.axis('off');
plt.subplot(1, 2, 2)
plt.imshow(mean_from_lambda(generated).reshape(28, 28), cmap=plt.cm.gray)
plt.axis('off');

## 2. $\beta$ - VAE

- [This paper](https://openreview.net/forum?id=Sy2fzU9gl) introduces a minimal modification of the VAE loss: the KL term is scaled by a factor $\beta$, usually with $\beta \gt 1$, to increase its weight in the loss. This $\beta$ balances latent channel capacity and independence constraints against reconstruction accuracy. The authors show that this modification is sound and not just a trick; however, it adds a new hyperparameter to tune.

In [None]:
EPSILON = 1e-6
beta = 2.0


@jit
def vae_loss(rand_key, params, x):
    """β-VAE loss averaged over the batch ``x``.

    ``params`` is ``params_enc + params_dec``; its first ``params_num`` entries
    belong to the encoder. The encoder output is split in half into
    (z_mean, z_log_var), and one latent sample per input is drawn with
    ``rand_key`` and decoded. The loss per input is the continuous-Bernoulli
    reconstruction term ``xent`` plus ``beta`` times the KL term.
    """
    enc_params, dec_params = params[:params_num], params[params_num:]
    stats = encoder_fn(enc_params, x)
    half = stats.shape[-1] // 2
    z_mean, z_log_var = stats[:, :half], stats[:, half:]
    x_rec = decoder_fn(dec_params, sample(rand_key, z_mean, z_log_var))
    per_example = xent(x_rec, x) + beta * kl(z_mean, z_log_var)
    return jnp.mean(per_example)


In [None]:
# Time one loss evaluation on a batch of 32 images (the first call presumably includes jit compilation).
%time vae_loss(rand_key, params, x_train[0:32])

### Training the VAE

The following cells:
    - reinitialize parameters
    - initialize an Adam optimizer
    - run mini-batch training (by default a single epoch, truncated to the first `subset_train_num` batches to keep it short)

In [None]:
# You may run this cell to reinit parameters if needed.
# Use two independent keys so the encoder and decoder are not initialized
# from the same random stream, and keep a fresh `rand_key` for later cells.
rand_key, enc_key, dec_key = random.split(rand_key, 3)
_, params_enc = encoder_init(enc_key, (-1, 28, 28, 1))
_, params_dec = decoder_init(dec_key, (-1, latent_dim))
params = params_enc + params_dec

In [None]:
# stax/optimizers live in jax.example_libraries in newer JAX releases.
try:
    from jax.example_libraries import stax, optimizers
except ImportError:  # older JAX releases
    from jax.experimental import stax, optimizers

data_size = x_train.shape[0]
batch_size = 32
learning_rate = 0.003

# Adam; opt_update takes the step index as its first argument.
opt_init, opt_update, get_params = optimizers.adam(learning_rate)
opt_state = opt_init(params)

losses = []

In [None]:
import itertools

# Python-side step counter for `update`. optimizers.adam's update function uses
# its first argument (the iteration index) for bias correction, so passing a
# constant 0 on every call froze the correction at its first-step value.
# Re-run this cell after re-creating `opt_state` to restart the count.
_update_steps = itertools.count()


@jit
def _update_step(key, batch, opt_state, step):
    """One jitted optimizer step on `batch`; returns (new_opt_state, loss)."""
    params = get_params(opt_state)
    loss, grads = jax.value_and_grad(lambda p, x: vae_loss(key, p, x))(params, batch)
    return opt_update(step, grads, opt_state), loss


def update(key, batch, opt_state, step=None):
    """Apply one optimizer step and return ``(opt_state, loss)``.

    Args:
        key: PRNG key used for latent sampling inside ``vae_loss``.
        batch: mini-batch of training inputs.
        opt_state: current optimizer state.
        step: iteration index passed to ``opt_update``. If omitted, an internal
            counter supplies 0, 1, 2, ... on successive calls.
    """
    if step is None:
        step = next(_update_steps)
    return _update_step(key, batch, opt_state, step)

In [None]:
# Can be long!
subset_train_num = 100  # cap on the number of mini-batches per epoch
num_epochs = 1
# every full batch fits: indices [i * batch_size, (i + 1) * batch_size) stay < data_size
num_batches = min(data_size // batch_size, subset_train_num)

for epoch in range(num_epochs):
    rand_key, key = random.split(rand_key)
    permutation = random.permutation(key, data_size)
    for i in range(num_batches):
        batch = x_train[permutation[i * batch_size:(i + 1) * batch_size]]
        rand_key, key = random.split(rand_key)
        opt_state, loss = update(key, batch, opt_state)
        losses.append(loss)

In [None]:
import matplotlib.pyplot as plt
# training loss, one point per mini-batch
plt.plot(losses);

In [None]:
params = get_params(opt_state)
# Split the main stream: `key` was already consumed by the last training step.
rand_key, key = random.split(rand_key)
z = random.normal(key, shape=(1, latent_dim))
generated = decoder_fn(params[params_num:], z)

plt.figure(figsize=(6, 3))
plt.subplot(1, 2, 1)  # raw decoder output (lambda parameters)
plt.imshow(generated.reshape(28, 28), cmap=plt.cm.gray)
plt.axis('off');
plt.subplot(1, 2, 2)  # expected image under the continuous Bernoulli
plt.imshow(mean_from_lambda(generated).reshape(28, 28), cmap=plt.cm.gray)
plt.axis('off');

### 2D plot of the image classes in the latent space

We can also use the encoder to visualize the distribution of the test set in the latent space of the VAE model, projected to 2D with PCA. In the following plot, the colors show the true class labels of the test samples.

Note that the VAE is an unsupervised model: it did not use any label information during training. However, we can observe that the latent space (here projected to 2D) is largely structured around the categories of images used in the training set.

In [None]:
# Fashion-MNIST class index -> human-readable label
id_to_labels = dict(enumerate([
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]))

In [None]:
import sklearn
from sklearn.decomposition import PCA
# Encode the test set; the first latent_dim columns of the encoder output are z_mean.
x_test_encoded = encoder_fn(params[0:params_num], x_test)
# Project the z_mean part to 2D with PCA for plotting.
pca = PCA(n_components=2)
encoded_pca_x=pca.fit_transform(x_test_encoded[:,:latent_dim])
plt.figure(figsize=(7, 6))
plt.scatter(encoded_pca_x[:, 0], encoded_pca_x[:, 1], c=y_test,
            cmap=plt.cm.tab10)

# colorbar with one tick per class, labelled with the class name
cb = plt.colorbar()
cb.set_ticks(list(id_to_labels.keys()))
cb.set_ticklabels(list(id_to_labels.values()))
cb.update_ticks()
plt.show()

### 2D panel view of samples from the VAE manifold

The following linearly spaced coordinates on the unit square were transformed through the inverse CDF (ppf) of the Gaussian to produce values of the latent variables z (here, coordinates in the 2D PCA plane, which are then mapped back to the latent space). This makes it possible to use a square arrangement of panels that spans the Gaussian prior of the latent space.

In [None]:
n = 15  # figure with 15x15 panels
img_size = 28
figure = np.zeros((img_size * n, img_size * n))
# Grid in 2D PCA coordinates, spaced by standard-normal quantiles.
grid_x = norm.ppf(np.linspace(0.05, 0.95, n)).astype(np.float32)
grid_y = norm.ppf(np.linspace(0.05, 0.95, n)).astype(np.float32)

for i, yi in enumerate(grid_x):
    for j, xi in enumerate(grid_y):
        # z_sample = np.array([[xi, yi]])  # use this directly if latent_dim = 2
        # Map the 2D PCA point back to the latent space. inverse_transform
        # adds back pca.mean_, which a bare product with pca.components_ drops.
        z_sample = pca.inverse_transform(np.array([[xi, yi]]))
        # NOTE(review): panels are decoded one at a time; the decoder contains
        # BatchNorm layers, so batching all panels might change the outputs.
        x_decoded = mean_from_lambda(decoder_fn(params[params_num:], z_sample))
        img = x_decoded[0].reshape(img_size, img_size)
        figure[i * img_size: (i + 1) * img_size,
               j * img_size: (j + 1) * img_size] = img

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()

# A simpler dataset

Up to now, we used the VAE on complex structured data (Fashion MNIST), where it can be seen as a dimensionality reduction method that aims at building a coherent, disentangled latent space.

The following explores how a VAE can capture the distribution of toy datasets, instead of high dimensional and heavily correlated data.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def eightgaussian(n_points, radius=5.0, noise_std=0.2):
    """
    Returns the eight gaussian dataset.

    Each of the ``n_points`` 2D points picks one of eight centres, spaced every
    45 degrees on a circle of radius ``radius``, and adds isotropic Gaussian
    noise with standard deviation ``noise_std``. NumPy's global RNG is used,
    so ``np.random.seed`` makes the output reproducible.

    Returns:
        Array of shape (n_points, 2).
    """
    mode = np.random.randint(0, 8, n_points)
    noisex = np.random.normal(size=(n_points)) * noise_std
    noisey = np.random.normal(size=(n_points)) * noise_std
    angle = mode * np.pi / 4.0
    x = np.cos(angle) * radius + noisex
    y = np.sin(angle) * radius + noisey
    return np.column_stack((x, y))
            
# 1000 training points from the eight-gaussian toy distribution
X = eightgaussian(1000)
X.shape
plt.figure(figsize=(6,6))
plt.scatter(X[:,0], X[:,1], s=1);
plt.axis('square');

In [None]:

input_dim = X.shape[-1]
hidden_dim = 512
latent_dim = 2
k_samples = 1  # latent samples drawn per data point
beta = 4.0  # weight of the KL term in vae_loss

# MLP encoder; outputs 2 * latent_dim values (z_mean, z_log_var)
encoder_init, encoder_fn = stax.serial(
    Dense(hidden_dim), Selu, Dense(hidden_dim), Selu, Dense(hidden_dim), Selu, Dense(latent_dim * 2))

# Initialize the parameters with the freshly split `key`, so `rand_key`
# stays unused until it is split again.
rand_key, key = random.split(rand_key)
out_shape, params_enc = encoder_init(key, (-1, input_dim))

def sample(rand_key, z_mean, z_log_var, k_samples):
    """Draw ``k_samples`` reparameterized samples from N(z_mean, exp(z_log_var)).

    The draws are stacked along a new leading axis, so the result has shape
    ``(k_samples,) + z_mean.shape``.
    """
    noise = random.normal(rand_key, shape=(k_samples, *z_mean.shape))
    return noise * jnp.exp(0.5 * z_log_var) + z_mean

# NOTE(review): k_samples sets an array shape, so calling fast_sample would
# presumably need it marked static (e.g. static_argnums=3).
fast_sample = jit(sample)

# MLP decoder mapping latent codes back to input_dim coordinates
decoder_init, decoder_fn = stax.serial(
    Dense(hidden_dim), Selu, Dense(hidden_dim), Selu, Dense(hidden_dim), Selu, Dense(input_dim))

# Initialize the parameters with the freshly split `key`, so `rand_key`
# stays unused until it is split again.
rand_key, key = random.split(rand_key)
out_shape, params_dec = decoder_init(key, (-1, latent_dim))

params_enc_num = len(params_enc)  # number of encoder layers at the front of `params`
params = params_enc + params_dec

EPSILON = 1e-6
# squared Euclidean reconstruction error, summed over the last axis
l2 = jit(lambda x, y: jnp.sum((x - y)**2, axis=-1))
# analytic KL(N(z_mean, exp(z_log_var)) || N(0, I)), summed over the last axis
kl = jit(lambda z_mean, z_log_var: - 0.5 * jnp.sum(1 + z_log_var - z_mean ** 2 -
                                                   jnp.exp(z_log_var), axis=-1))

@jit
def vae_loss(rand_key, params, x):
    """β-VAE loss with an L2 reconstruction term, averaged over all entries.

    ``params`` is ``params_enc + params_dec``; the first ``params_enc_num``
    entries are the encoder's. ``k_samples`` latent draws are taken per input
    and decoded, and the per-draw reconstruction error is added to ``beta``
    times the KL term (which is broadcast across the draws).
    """
    enc_params = params[:params_enc_num]
    dec_params = params[params_enc_num:]
    stats = encoder_fn(enc_params, x)
    half = stats.shape[-1] // 2
    z_mean, z_log_var = stats[:, :half], stats[:, half:]
    z = sample(rand_key, z_mean, z_log_var, k_samples)
    rec_err = l2(decoder_fn(dec_params, z), x)
    return jnp.mean(rec_err + beta * kl(z_mean, z_log_var))
    

# stax/optimizers live in jax.example_libraries in newer JAX releases.
try:
    from jax.example_libraries import stax, optimizers
except ImportError:  # older JAX releases
    from jax.experimental import stax, optimizers

data_size = X.shape[0]
batch_size = 32
learning_rate = 0.0003

# Adam; opt_update takes the step index as its first argument.
opt_init, opt_update, get_params = optimizers.adam(learning_rate)
opt_state = opt_init(params)

import itertools

# Python-side step counter for `update`. optimizers.adam's update function uses
# its first argument (the iteration index) for bias correction, so passing a
# constant 0 on every call froze the correction at its first-step value.
_update_steps = itertools.count()


@jit
def _update_step(key, batch, opt_state, step):
    """One jitted optimizer step on `batch`; returns (new_opt_state, loss)."""
    params = get_params(opt_state)
    loss, grads = jax.value_and_grad(lambda p, x: vae_loss(key, p, x))(params, batch)
    return opt_update(step, grads, opt_state), loss


def update(key, batch, opt_state, step=None):
    """Apply one optimizer step and return ``(opt_state, loss)``.

    ``step`` is the iteration index passed to ``opt_update``. If omitted, an
    internal counter supplies 0, 1, 2, ... on successive calls.
    """
    if step is None:
        step = next(_update_steps)
    return _update_step(key, batch, opt_state, step)

losses = []

In [None]:
iters = 1000

for iteration in range(iters):
    # random mini-batch of indices drawn with replacement
    batch = X[np.random.choice(X.shape[0], batch_size)]
    rand_key, key = random.split(rand_key)
    opt_state, loss = update(key, batch, opt_state)
    losses.append(loss)

import matplotlib.pyplot as plt
plt.plot(losses);

### Latent space analysis

Below: means (blue) and samples (red) of the encoded dataset.

Note that the latent space has the same shape as the original data (simply scaled). One reason for this is that we chose a 2D latent space, which matches the data dimensionality. Maybe this choice of latent-space parametrization was not a good one!

In [None]:
params = get_params(opt_state)
params_enc, params_dec = params[0:params_enc_num], params[params_enc_num:]
x_encoded = np.asarray(encoder_fn(params_enc, X[:]))
plt.figure(figsize=(7, 7))
z_mean, z_log_var = x_encoded[:, 0:latent_dim], x_encoded[:, latent_dim:]
rand_key, key = random.split(rand_key)
z_sample = sample(key, z_mean, z_log_var, k_samples)  # use the fresh subkey
z_sample = np.reshape(z_sample, (k_samples * x_encoded.shape[0], latent_dim))
plt.scatter(z_sample[:, 0], z_sample[:, 1], s=1, c="r")  # latent samples
plt.scatter(x_encoded[:, 0], x_encoded[:, 1], s=2, c="b")  # latent means
plt.show()
# exp(z_log_var / 2) is the standard deviation, not the variance
print(f"average of predicted latent standard deviations on dataset: {np.mean(np.exp(z_log_var[:,0]/2)):.2f}, {np.mean(np.exp(z_log_var[:,1]/2)):.2f}")

### Reconstructed samples vs train samples

In [None]:
# Decode the latent means from the previous cell and overlay the
# reconstructions (red) on the training data (blue).
generated = decoder_fn(params_dec, z_mean)
plt.figure(figsize=(7, 7))
plt.scatter(X[:, 0], X[:, 1], s=1, c="b")
plt.scatter(generated[:, 0], generated[:, 1], s=1, c="r")
plt.show()

### Input distribution vs sampled distribution

In [None]:
# Decode 5000 draws from the standard-normal prior (red) and compare them
# with the training data (blue).
rand_key, key = random.split(rand_key)
z = random.normal(key, shape=(5000,latent_dim))
generated = np.asarray(decoder_fn(params_dec, z))
plt.figure(figsize=(7, 7))
plt.scatter(X[:, 0], X[:, 1], s=1, c="b")
plt.scatter(generated[:, 0], generated[:, 1], s=2, c="r")
plt.show()

In [None]:
# gaussian_kde is exposed directly by scipy.stats; the scipy.stats.kde
# submodule namespace is deprecated.
from scipy.stats import gaussian_kde

nbins = 100
# One evaluation grid on [-6, 6]^2, shared by both density estimates.
xi, yi = np.mgrid[-6:6:nbins*1j, -6:6:nbins*1j]
grid_points = np.vstack([xi.flatten(), yi.flatten()])

zi = gaussian_kde(generated.T)(grid_points)  # density of the VAE samples
zio = gaussian_kde(X.T)(grid_points)  # density of the training data

plt.figure(figsize=(6, 3))
plt.subplot(1, 2, 1)
plt.pcolormesh(xi, yi, zi.reshape(xi.shape))
plt.title("VAE distribution")
plt.axis('off');
plt.subplot(1, 2, 2)
plt.pcolormesh(xi, yi, zio.reshape(xi.shape))
plt.title("original distribution")
plt.axis('off');

## Importance Weighted Autoencoders (IWAE)

The idea is simply to sample several latent codes from the Gaussian distribution instead of a single one: sample $\{z_i\}_{i=1}^k \sim q_\phi(z|x)$. The aim is to get a more accurate loss that better represents the distribution, instead of a single-sample Monte Carlo estimate of the expectation $\mathbb{E}_{q(z|x)}$.

$$\nabla_\theta \mathcal{L}_k = \sum_{i=1}^k \tilde{w}_i \nabla_\theta [\log p(x, z_i) - \log q(z_i|x)]$$

$$w_i = \frac{p(x, z_i)}{q(z_i|x)} ; \quad \tilde{w}_i = \frac{w_i}{\sum_{j=1}^k w_j}$$

In [None]:
k_samples = 10  # number of latent samples drawn per input for the IWAE loss


# k-sample reparameterized sampler
def sample(rand_key, z_mean, z_log_var, k_samples):
    """Return ``k_samples`` draws from N(z_mean, exp(z_log_var)).

    The draws are stacked along a new leading axis of size ``k_samples``.
    """
    std = jnp.exp(0.5 * z_log_var)
    draws = random.normal(rand_key, shape=(k_samples, *z_mean.shape))
    return draws * std + z_mean

@jit
def vae_loss(rand_key, params, x):
    """IWAE-style surrogate loss averaged over the batch ``x``.

    Draws ``k_samples`` latent codes z_i ~ q(z|x) per input and scores each
    one with  L_i = l2(x, decode(z_i)) + [log q(z_i|x) - log p(z_i)].
    Treating l2 as the negative log-likelihood, L_i = -log w_i up to a
    constant. The returned value is sum_i w~_i * L_i, where the normalized
    weights w~_i = softmax_i(-L_i) are held constant (stop_gradient), so its
    gradient matches the weighted-gradient formula above.

    ``params`` is ``params_enc + params_dec``; the first ``params_enc_num``
    entries belong to the (MLP) encoder.
    """
    # Slice with the MLP encoder's layer count. `params_num` counts the earlier
    # convolutional encoder's layers and would hand the decoder an empty
    # parameter list, turning it into the identity.
    latent = encoder_fn(params[:params_enc_num], x)
    z_mean, z_log_var = latent[:, :latent_dim], latent[:, latent_dim:]
    # shape (k_samples,) + z_mean.shape
    z_sample = sample(rand_key, z_mean, z_log_var, k_samples)

    # decoding applies to each of the samples
    x_rec = decoder_fn(params[params_enc_num:], z_sample)
    l2_loss = l2(x, x_rec)  # one value per (sample, batch element)

    # Per-sample estimate of log q(z_i|x) - log p(z_i), summed over latent dims
    # (the log(2*pi) terms cancel). The analytic KL would be identical for every
    # sample and thus would cancel in the softmax.
    kl_sample = 0.5 * jnp.sum(
        z_sample ** 2 - z_log_var - (z_sample - z_mean) ** 2 * jnp.exp(-z_log_var),
        axis=-1)

    per_sample_loss = l2_loss + kl_sample  # -log w_i (+ const)
    # Normalized importance weights: softmax of log w_i = -per_sample_loss.
    # (A softmax of +loss would give the worst samples the largest weights.)
    normalized_w_i = jax.lax.stop_gradient(jax.nn.softmax(-per_sample_loss, axis=0))

    weighted_sum = (normalized_w_i * per_sample_loss).sum(axis=0)

    # average over the batch
    return jnp.mean(weighted_sum)

# stax/optimizers live in jax.example_libraries in newer JAX releases.
try:
    from jax.example_libraries import stax, optimizers
except ImportError:  # older JAX releases
    from jax.experimental import stax, optimizers

data_size = X.shape[0]
batch_size = 32
learning_rate = 0.0003

# Fresh Adam state, starting from the current `params`.
opt_init, opt_update, get_params = optimizers.adam(learning_rate)
opt_state = opt_init(params)

import itertools

# Python-side step counter for `update`. optimizers.adam's update function uses
# its first argument (the iteration index) for bias correction, so passing a
# constant 0 on every call froze the correction at its first-step value.
_update_steps = itertools.count()


@jit
def _update_step(key, batch, opt_state, step):
    """One jitted optimizer step on `batch`; returns (new_opt_state, loss)."""
    params = get_params(opt_state)
    loss, grads = jax.value_and_grad(lambda p, x: vae_loss(key, p, x))(params, batch)
    return opt_update(step, grads, opt_state), loss


def update(key, batch, opt_state, step=None):
    """Apply one optimizer step and return ``(opt_state, loss)``.

    ``step`` is the iteration index passed to ``opt_update``. If omitted, an
    internal counter supplies 0, 1, 2, ... on successive calls.
    """
    if step is None:
        step = next(_update_steps)
    return _update_step(key, batch, opt_state, step)

losses = []

iters = 1000

for iteration in range(iters):
    # random mini-batch of indices drawn with replacement
    batch = X[np.random.choice(X.shape[0], batch_size)]
    rand_key, key = random.split(rand_key)
    opt_state, loss = update(key, batch, opt_state)
    losses.append(loss)

import matplotlib.pyplot as plt
plt.plot(losses);

In [None]:
# Here is a new stax layer that implements a Gelu activation function https://arxiv.org/abs/1606.08415
# using the sigmoid approximation x * sigmoid(1.702 * x) from that paper.
_gelu = jit(lambda inputs: inputs * jax.nn.sigmoid(1.702 * inputs))


def _gelu_apply(params, inputs, **kwargs):
    # stax combinators pass extra keywords (e.g. rng=...) to every layer's
    # apply function, like the **kwargs accepted by Upsample/Reshape above.
    # A two-argument apply function would raise TypeError inside stax.serial.
    return _gelu(inputs)


# Parameter-free layer: init returns the input shape unchanged and no params.
Gelu = (lambda rng, input_shape: (input_shape, ()), _gelu_apply)