# Going further with VAEs

Building on the simple VAE example, this notebook is organized as follows:
1. Convolutional VAE: a stronger architecture, with a few computational tricks to make it work
2. [$\beta$-VAE](https://openreview.net/forum?id=Sy2fzU9gl), increasing the importance of the KL term
3. Conditional VAEs, typically conditioning the modeled density on a class
4. Exploring [Importance weighted Autoencoders](https://arxiv.org/abs/1509.00519) (IWAE)
5. Final thoughts: can VAEs efficiently model simple toy datasets (distributions)?

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from tensorflow.keras.datasets import fashion_mnist

(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

# Preview the first 18 training images on a 3 x 6 grid.
fig, axes = plt.subplots(3, 6, figsize=(16, 8))
for i, ax in enumerate(axes.flat):
    ax.imshow(x_train[i], cmap="gray")
    ax.axis("off")
plt.show()

# Scale pixels from uint8 [0, 255] to float32 [0, 1] and append a channel axis,
# so each set becomes (N, 28, 28, 1) for the convolutional layers below.
x_train = x_train.astype('float32')[..., np.newaxis] / 255.
x_test = x_test.astype('float32')[..., np.newaxis] / 255.
x_train.shape, x_test.shape

In [None]:
import jax
import jax.numpy as jnp
from jax import jit, grad
from jax import random
SEED = 1  # fixed seed: parameter initialization and sampling are reproducible
rand_key = random.PRNGKey(SEED)

## Convolutional VAE

A convolutional VAE makes use of the spatial structure of the images by relying on convolutions instead of Dense layers, in both the encoder and the decoder.



In [None]:
# stax-style layer (init_fun, apply_fun) computing the sigmoid approximation of GELU:
#     gelu(x) ~= x * sigmoid(1.702 * x)
# init_fun keeps the input shape and has no parameters. apply_fun accepts **kwargs
# because stax.serial forwards an `rng=` keyword argument to every layer's apply_fun.
Gelu = (lambda rng, input_shape: (input_shape, ()),
        jit(lambda params, inputs, **kwargs: inputs * jax.nn.sigmoid(1.702 * inputs))
       )

In [None]:
from jax.experimental import stax # neural network library
from jax.experimental.stax import Dense, Relu, Sigmoid, Conv, BatchNorm, MaxPool, Flatten

# x_train is (N, 28, 28, 1) at this point, so input_dim is the channel count.
input_dim = x_train.shape[-1]
hidden_dim = 128
latent_dim = 2

# Convolutional encoder: two blocks of (Conv-BatchNorm-Relu) x 2, each followed by a
# 2x2 max-pool (28 -> 14 -> 7), then a Dense layer producing 2 * latent_dim values:
# the first latent_dim are z_mean, the last latent_dim are z_log_var.
encoder_init, encoder_fn = stax.serial(
    Conv(32, (3, 3), padding="SAME"), BatchNorm(), Relu, 
    Conv(32, (3, 3), padding="SAME"), BatchNorm(), Relu, 
    MaxPool((2,2), (2,2)),
    Conv(64, (3, 3), padding="SAME"), BatchNorm(), Relu, 
    Conv(64, (3, 3), padding="SAME"), BatchNorm(), Relu, 
    MaxPool((2,2), (2,2)),
    Flatten, Dense(latent_dim*2)
)

# initialize the parameters from the fresh subkey; rand_key is only kept for further splits
rand_key, key = random.split(rand_key)
out_shape, params_enc = encoder_init(key, (-1, 28, 28, 1))

def count_params(params):
    """Return the total number of scalar parameters in a stax parameter pytree.

    stax stores parameters as a tuple with one entry per layer, e.g. (W, b) for
    Dense/Conv and () for parameter-free layers. Flattening with jax.tree_util also
    handles nested models (stax.serial inside stax.serial), whose entries are tuples
    of tuples rather than arrays.
    """
    return sum(int(np.prod(leaf.shape)) for leaf in jax.tree_util.tree_leaves(params))

# Number of encoder layers; later cells use it to split params = params_enc + params_dec.
params_num = len(params_enc)
total_params = count_params(params_enc)
print(f"Number of param objects: {params_num}, total number of params: {total_params}")

In [None]:
# First line: jitted encoder (the first call also includes compilation time).
# Second line: the same encoder run without jit, for comparison.
%time z = jit(encoder_fn)(params_enc, x_train[0:10])
%time z = encoder_fn(params_enc, x_train[0:10])
print(f"output shape: {z.shape}")

In [None]:
def sample(rand_key, z_mean, z_log_var):
    """Draw z ~ N(z_mean, exp(z_log_var)) with the reparameterization trick.

    Standard-normal noise of the same shape as z_mean is scaled by the standard
    deviation exp(z_log_var / 2) and shifted by the mean.
    """
    noise = random.normal(rand_key, shape=z_mean.shape)
    std = jnp.exp(0.5 * z_log_var)
    return z_mean + std * noise

fast_sample = jit(sample)

In [None]:
z = encoder_fn(params_enc, x_train[0:3])
# The encoder output concatenates [z_mean, z_log_var] along the last axis.
d = z.shape[-1] // 2
z_mean, z_log_var = z[..., :d], z[..., d:]
rand_key, key = random.split(rand_key)
samples = sample(key, z_mean, z_log_var)
print(f"z shape (concatenation of z_mean and z_log_var) : {z.shape}, samples shape: {samples.shape}")

In [None]:
from functools import partial

def _upsample(x):
    x = x.transpose((0,3,1,2))
    upx = jnp.kron(x, jnp.ones((2,2)))
    return upx.transpose((0,2,3,1))

Upsample = (lambda rng, ish: ((ish[0], ish[1]*2, ish[2]*2, ish[3]), ()), 
        jit(lambda params, inputs, **kwargs: _upsample(inputs))
       )
def Reshape(shape):
    def init_fun(rng, input_shape):
        return (input_shape[0],) + shape, ()

    def apply_fun(params, inputs, **kwargs):
        return inputs.reshape((inputs.shape[0],) + shape)
    
    return init_fun, apply_fun

In [None]:
# Sanity check: flat vectors of size 7*7*64 become (batch, 7, 7, 64) feature maps.
reshape_init, reshape_apply = Reshape((7, 7, 64))
reshape_apply((), jnp.ones((2, 7 * 7 * 64))).shape

In [None]:
upsampled = _upsample(x_train[0:3])  # (3, 56, 56, 1)
plt.imshow(upsampled[0, :, :, 0], cmap="gray");

In [None]:
# Convolutional decoder: a Dense layer maps z to a 7x7x64 feature map, which is upsampled
# twice (7 -> 14 -> 28) with conv blocks in between. The final 1-channel conv + Sigmoid
# gives per-pixel values in (0, 1), used as Bernoulli means by the cross-entropy loss.
decoder_init, decoder_fn = stax.serial(
    Dense(64*7*7), Reshape((7,7,64)),
    Upsample, 
    Conv(32, (3, 3), padding="SAME"), BatchNorm(), Relu, 
    Conv(32, (3, 3), padding="SAME"), BatchNorm(), Relu, 
    Upsample,
    Conv(64, (3, 3), padding="SAME"), BatchNorm(), Relu, 
    Conv(1, (3, 3), padding="SAME"), Sigmoid
)

# initialize the parameters from the fresh subkey; rand_key is only kept for further splits
rand_key, key = random.split(rand_key)
out_shape, params_dec = decoder_init(key, (-1, latent_dim))

# joint parameter tuple: the first params_num entries belong to the encoder
params = params_enc + params_dec

In [None]:
n_dec_params = count_params(params_dec)
print(f"Number of params: {n_dec_params}")
# decoding the 3 latent samples drawn above should give a batch of 28x28x1 images
decoder_fn(params_dec, samples).shape

In [None]:
# Sample z from the N(0, I) prior and decode it with the (still untrained) decoder.
# Split rand_key, the running PRNG state, rather than the already-consumed `key`.
rand_key, key = random.split(rand_key)
z = random.normal(key, shape=(1, latent_dim))
generated = decoder_fn(params_dec, z)
plt.imshow(generated.reshape(28, 28), cmap=plt.cm.gray)
plt.axis('off');

In [None]:
EPSILON = 1e-6  # keeps the logs finite when the decoder outputs exactly 0 or 1

# Bernoulli negative log-likelihood, summed over each image's last three (H, W, C) axes.
# Summing over the trailing axes (rather than axes 1..3) gives the same result for
# (batch, H, W, C) inputs and also supports extra leading axes such as (k, batch, H, W, C).
xent = jit(lambda x, xt: - jnp.sum(x * jnp.log(xt + EPSILON) + (1 - x) * jnp.log(1 - xt + EPSILON),
                                   axis=(-3, -2, -1)))
# Closed-form KL(N(z_mean, exp(z_log_var)) || N(0, I)), summed over the latent axis.
kl = jit(lambda z_mean, z_log_var: - 0.5 * jnp.sum(1 + z_log_var - z_mean ** 2 - jnp.exp(z_log_var), axis=-1))


@jit
def vae_loss(rand_key, params, x):
    """Negative ELBO averaged over the batch, using one reparameterized sample per input.

    Args:
        rand_key: PRNG key for the reparameterization noise.
        params: params_enc + params_dec; the first params_num entries are the encoder's.
        x: batch of images in [0, 1].
    """
    latent = jit(encoder_fn)(params[0:params_num], x)
    d = latent.shape[-1] // 2
    z_mean, z_log_var = latent[:, :d], latent[:, d:]
    z_sample = sample(rand_key, z_mean, z_log_var)
    x_rec = jit(decoder_fn)(params[params_num:], z_sample)

    xent_loss = xent(x, x_rec)
    kl_loss = kl(z_mean, z_log_var)
    return jnp.mean(xent_loss + kl_loss)


In [None]:
# The first call traces and compiles the jitted vae_loss, so this timing includes compilation.
%time vae_loss(rand_key, params, x_train[0:32])

### Training the VAE

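The following cells:

- reinitialize the parameters
- initialize an Adam optimizer
- run a short mini-batch training loop (only a few batches by default, to keep the notebook fast; train longer, for several epochs, to get good samples)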

In [None]:
# You may run this cell to reinit parameters if needed
_, params_enc = encoder_init(rand_key, (-1, 28,28,1))
_, params_dec = decoder_init(rand_key, (-1, latent_dim))
params = params_enc + params_dec

In [None]:
from jax.experimental import stax, optimizers

# Training hyper-parameters.
data_size, batch_size, learning_rate = x_train.shape[0], 32, 0.003

# jax.experimental.optimizers returns an (init, update, get_params) triple.
adam_fns = optimizers.adam(learning_rate)
opt_init, opt_update, get_params = adam_fns
opt_state = opt_init(params)

losses = []  # one training loss per optimizer step, reset together with opt_state

In [None]:
@jit
def update(key, batch, opt_state):
    params = get_params(opt_state)
    value_and_grad_fun = jit(jax.value_and_grad(lambda params, x: vae_loss(key, params, x)))
    loss, grads = value_and_grad_fun(params, batch)
    opt_state = opt_update(0, grads, opt_state)
    return opt_state, loss

In [None]:
for epochs in range(1):
    for i in range(10):  #data_size // 32 -1):
        batch = x_train[i * 32:(i+1)*32]
        rand_key, key = random.split(rand_key)
        opt_state, loss = update(key, batch, opt_state)
        losses.append(loss)

In [None]:
import matplotlib.pyplot as plt
plt.plot(range(len(losses)), losses);  # training loss per optimizer step

In [None]:
params = get_params(opt_state)
# Sample the N(0, I) prior and decode with the trained decoder (entries after params_num).
# Split rand_key, the running PRNG state, rather than the already-consumed `key`.
rand_key, key = random.split(rand_key)
z = random.normal(key, shape=(1, latent_dim))
generated = decoder_fn(params[params_num:], z)
plt.imshow(generated.reshape(28, 28), cmap=plt.cm.gray)
plt.axis('off');

### 2D plot of the image classes in the latent space

We can also use the encoder to visualize the distribution of the test set in the 2D latent space of the VAE model. In the following plot, the colors show the true class labels of the test samples.

Note that the VAE is an unsupervised model: it did not use any label information during training. However we can observe that the 2D latent space is largely structured around the categories of images used in the training set.

In [None]:
# Fashion-MNIST class index -> human-readable label.
id_to_labels = dict(enumerate([
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]))

In [None]:
x_test_encoded = encoder_fn(params[0:params_num], x_test)
# Columns 0 and 1 of the encoder output are the two coordinates of z_mean (latent_dim = 2).
z_mean_x, z_mean_y = x_test_encoded[:, 0], x_test_encoded[:, 1]

plt.figure(figsize=(7, 6))
plt.scatter(z_mean_x, z_mean_y, c=y_test, cmap=plt.cm.tab10)
colorbar = plt.colorbar()
colorbar.set_ticks(list(id_to_labels))
colorbar.set_ticklabels(list(id_to_labels.values()))
colorbar.update_ticks()
plt.show()

### 2D panel view of samples from the VAE manifold

Linearly spaced coordinates on the unit square (from 0.05 to 0.95) are transformed through the inverse CDF (ppf) of the Gaussian to produce values of the latent variables z. This makes it possible to use a square arrangement of panels that spans the Gaussian prior of the latent space.

In [None]:
n = 15  # figure with 15x15 panels
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))
# Equally spaced probabilities in [0.05, 0.95] mapped through the Gaussian inverse CDF,
# so the panel positions follow the N(0, 1) prior (dense near 0, sparse in the tails).
quantiles = np.linspace(0.05, 0.95, n)
grid_x = norm.ppf(quantiles).astype(np.float32)
grid_y = norm.ppf(quantiles).astype(np.float32)

# Panel (row i, column j) shows the decoding of z = (grid_x[j], grid_y[i]).
# NOTE(review): points are decoded one per call; the decoder contains BatchNorm layers,
# so batching all points together may change the outputs — confirm before batching.
for i, yi in enumerate(grid_y):
    rows = slice(i * digit_size, (i + 1) * digit_size)
    for j, xi in enumerate(grid_x):
        z_sample = np.array([[xi, yi]])
        x_decoded = decoder_fn(params[params_num:], z_sample)
        digit = x_decoded[0].reshape(digit_size, digit_size)
        figure[rows, j * digit_size:(j + 1) * digit_size] = digit

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()

## Importance Weighted Autoencoders (IWAE)

The idea is to draw several latent samples from the Gaussian posterior instead of a single one: sample $z_1, \dots, z_k \sim q(z|x)$. Combining them with importance weights gives a tighter lower bound on $\log p(x)$ than the single-point Monte Carlo estimate of $\mathbb{E}_{q(z|x)}\left[\log p(x, z) - \log q(z|x)\right]$ used by the standard VAE.

$$\mathcal{L}_k = \log \frac{1}{k} \sum_{i=1}^k w_i, \qquad \nabla_\theta \mathcal{L}_k = \sum_{i=1}^k \tilde{w}_i \nabla_\theta \left[\log p(x, z_i) - \log q(z_i|x)\right]$$

$$w_i = \frac{p(x, z_i)}{q(z_i|x)}, \qquad \tilde{w}_i = \frac{w_i}{\sum_{j=1}^k w_j}$$

In [None]:
k_samples = 10


def sample(rand_key, z_mean, z_log_var, k_samples=None):
    """Reparameterized draw(s) z ~ N(z_mean, exp(z_log_var)).

    With k_samples=None the result has the shape of z_mean, like the earlier 3-argument
    `sample`, so cells that call it that way keep working. Otherwise k_samples independent
    draws are stacked on a new leading axis: (k_samples,) + z_mean.shape.
    """
    shape = z_mean.shape if k_samples is None else (k_samples,) + z_mean.shape
    epsilon = random.normal(rand_key, shape=shape)
    return z_mean + jnp.exp(z_log_var / 2) * epsilon


def log_normal_diag(z, mean, log_var):
    """Log-density of a diagonal Gaussian N(mean, exp(log_var)), summed over the last axis."""
    return -0.5 * jnp.sum(jnp.log(2 * jnp.pi) + log_var + (z - mean) ** 2 / jnp.exp(log_var), axis=-1)


@jit
def vae_loss(rand_key, params, x):
    """Negative IWAE bound -L_k, averaged over the batch.

    For each input, k_samples latents z_i ~ q(z|x) are drawn, and the log importance weights
        log w_i = log p(x|z_i) + log p(z_i) - log q(z_i|x)
    are combined as L_k = logsumexp_i(log w_i) - log k. The gradient of logsumexp is
    sum_i softmax(log w)_i * grad(log w_i), i.e. the normalized weights w~_i of the formula
    above, so no explicit stop_gradient on the weights is needed.
    """
    latent = jit(encoder_fn)(params[0:params_num], x)
    d = latent.shape[-1] // 2
    z_mean, z_log_var = latent[:, :d], latent[:, d:]
    z_sample = sample(rand_key, z_mean, z_log_var, k_samples)   # (k, batch, d)
    k, batch = z_sample.shape[0], z_sample.shape[1]

    # The decoder's Reshape layer keeps only one leading batch axis, so decode the k * batch
    # latents as a single flat batch and restore the (k, batch) axes afterwards.
    x_rec = jit(decoder_fn)(params[params_num:], z_sample.reshape((k * batch, d)))
    x_rep = jnp.broadcast_to(x, (k,) + x.shape).reshape((k * batch,) + x.shape[1:])
    log_px_given_z = -xent(x_rep, x_rec).reshape((k, batch))

    # Per-sample log p(z_i) - log q(z_i|x); the closed-form KL is identical for every z_i
    # and therefore cannot be used to weight the samples against each other.
    prior_stats = jnp.zeros_like(z_mean)
    log_pz = log_normal_diag(z_sample, prior_stats, prior_stats)
    log_qz = log_normal_diag(z_sample, z_mean, z_log_var)

    log_w = log_px_given_z + log_pz - log_qz                      # (k, batch)
    iwae_bound = jax.nn.logsumexp(log_w, axis=0) - jnp.log(k)     # (batch,)
    return -jnp.mean(iwae_bound)

In [None]:
%time vae_loss(rand_key, params, x_train_standard[0:10])

# A simpler dataset

Up to now, we used the VAE on complex structured data (Fashion MNIST), and it can be seen as a dimensionality reduction method, not unlike a [probabilistic PCA](http://edwardlib.org/tutorials/probabilistic-pca).

The following explores how a VAE captures the distribution of a simple 2D toy dataset: a mixture of eight Gaussians arranged on a circle.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def eightgaussian(n_points):
    """
    Returns the eight gaussian dataset.

    Each point picks one of 8 centres spaced every 45 degrees on the circle of radius 5,
    then Gaussian noise with standard deviation 0.1 is added to each coordinate.
    Uses the global numpy RNG. Output shape: (n_points, 2).
    """
    component = np.random.randint(0, 8, n_points)
    noise_x = np.random.normal(size=(n_points)) * 0.1
    noise_y = np.random.normal(size=(n_points)) * 0.1
    angle = component * np.pi / 4.0
    xs = np.cos(angle) * 5 + noise_x
    ys = np.sin(angle) * 5 + noise_y
    return np.stack((xs, ys), axis=0).T
            
X = eightgaussian(1000)
assert X.shape == (1000, 2)  # one (x, y) coordinate pair per point
plt.scatter(X[:, 0], X[:, 1], s=1);

In [None]:
import jax
import jax.numpy as jnp
from jax import jit, grad
from jax import random
rand_key = random.PRNGKey(1)
from jax.experimental import stax # neural network library
from jax.experimental.stax import Dense, Relu, Sigmoid, Selu

input_dim = X.shape[-1]
hidden_dim = 512
latent_dim = 2
k_samples = 1  # latent samples drawn per input
beta = 1  # KL weight for a beta-VAE (beta = 1 is the standard VAE)

# MLP encoder: three SELU hidden layers; the output holds [z_mean, z_log_var].
encoder_init, encoder_fn = stax.serial(
    Dense(hidden_dim), Selu, Dense(hidden_dim), Selu, Dense(hidden_dim), Selu, Dense(latent_dim * 2))

# initialize the parameters from the fresh subkey; rand_key is only kept for further splits
rand_key, key = random.split(rand_key)
out_shape, params_enc = encoder_init(key, (-1, input_dim))

def sample(rand_key, z_mean, z_log_var, k_samples=None):
    """Reparameterized draw(s) z ~ N(z_mean, exp(z_log_var)).

    With k_samples=None the result has the shape of z_mean (compatible with the earlier
    3-argument `sample`); otherwise k_samples draws are stacked on a new leading axis,
    giving (k_samples,) + z_mean.shape.
    """
    shape = z_mean.shape if k_samples is None else (k_samples,) + z_mean.shape
    epsilon = random.normal(rand_key, shape=shape)
    return z_mean + jnp.exp(z_log_var / 2) * epsilon

# k_samples determines an array shape, so it must be static (a Python int) under jit.
fast_sample = jit(sample, static_argnums=3)

# MLP decoder mirroring the encoder, mapping z back to a point in the data space.
decoder_init, decoder_fn = stax.serial(
    Dense(hidden_dim), Selu, Dense(hidden_dim), Selu, Dense(hidden_dim), Selu, Dense(input_dim))

# initialize the parameters from the fresh subkey; rand_key is only kept for further splits
rand_key, key = random.split(rand_key)
out_shape, params_dec = decoder_init(key, (-1, latent_dim))

params = params_enc + params_dec

EPSILON = 1e-6
# Squared Euclidean reconstruction error, summed over the last (coordinate) axis.
l2 = jit(lambda x, xt: jnp.sum(jnp.square(x - xt), axis=-1))
# Closed-form KL(N(z_mean, exp(z_log_var)) || N(0, I)), summed over the latent axis.
kl = jit(lambda z_mean, z_log_var: -0.5 * jnp.sum(1 + z_log_var - jnp.square(z_mean) - jnp.exp(z_log_var), axis=-1))

@jit
def vae_loss(rand_key, params, x):
    """beta-weighted negative ELBO for the 2D toy data, averaged over samples and batch.

    Args:
        rand_key: PRNG key for the reparameterization noise.
        params: params_enc + params_dec (encoder layers first).
        x: batch of 2D points.
    """
    n_enc = len(params_enc)  # number of encoder layers at the front of `params`
    latent = jit(encoder_fn)(params[:n_enc], x)
    d = latent.shape[-1] // 2
    z_mean, z_log_var = latent[:, :d], latent[:, d:]
    z_sample = sample(rand_key, z_mean, z_log_var, k_samples)   # (k_samples, batch, d)
    x_rec = jit(decoder_fn)(params[n_enc:], z_sample)           # (k_samples, batch, input_dim)
    l2_loss = l2(x, x_rec)            # (k_samples, batch)
    kl_loss = kl(z_mean, z_log_var)   # (batch,), broadcast over the sample axis
    return jnp.mean(l2_loss + beta * kl_loss)

# You may run this cell to reinit parameters if needed.
# Encoder and decoder have the same layer layout (and input_dim == latent_dim == 2), so
# initializing both from one key would derive their matching layers from the same
# randomness; use independent subkeys instead.
rand_key, enc_key, dec_key = random.split(rand_key, 3)
out_shape, params_enc = encoder_init(enc_key, (-1, input_dim))
out_shape, params_dec = decoder_init(dec_key, (-1, latent_dim))
params = params_enc + params_dec

from jax.experimental import stax, optimizers

data_size = X.shape[0]
batch_size = 32
learning_rate = 0.0003

opt_init, opt_update, get_params = optimizers.adam(learning_rate)
opt_state = opt_init(params)

losses = []  # one loss per optimizer step, reset together with opt_state

In [None]:
@jit
def update(key, batch, opt_state):
    params = get_params(opt_state)
    value_and_grad_fun = jit(jax.value_and_grad(lambda params, x: vae_loss(key, params, x)))
    loss, grads = value_and_grad_fun(params, batch)
    opt_state = opt_update(0, grads, opt_state)
    return opt_state, loss

In [None]:
iters = 1000
data_generator = (X[np.random.choice(X.shape[0], 32)] for _ in range(iters))

for epochs in range(iters):
    batch = X[i * 32:(i+1)*32]
    rand_key, key = random.split(rand_key)
    opt_state, loss = update(key, next(data_generator), opt_state)
    losses.append(loss)

In [None]:
import matplotlib.pyplot as plt
plt.plot(range(len(losses)), losses);  # training loss per optimizer step

In [None]:
params = get_params(opt_state)
n_enc = len(params_enc)  # encoder layers come first in params = params_enc + params_dec
params_enc = params[:n_enc]
params_dec = params[n_enc:]

x_encoded = np.asarray(encoder_fn(params_enc, X))
z_mean, z_log_var = x_encoded[:, :latent_dim], x_encoded[:, latent_dim:]
# Sample with the fresh subkey `key`; rand_key is only used for further splitting.
rand_key, key = random.split(rand_key)
z_sample = sample(key, z_mean, z_log_var, k_samples)
z_sample = np.reshape(z_sample, (k_samples * x_encoded.shape[0], latent_dim))

plt.figure(figsize=(7, 7))
plt.scatter(z_sample[:, 0], z_sample[:, 1], s=1, c="r", label="z ~ q(z|x)")
plt.scatter(x_encoded[:, 0], x_encoded[:, 1], s=2, c="b", label="z_mean")
plt.legend()
plt.show()

In [None]:
# Mean posterior standard deviation exp(z_log_var / 2) for each latent dimension;
# columns latent_dim onwards of x_encoded hold z_log_var.
tuple(np.mean(np.exp(x_encoded[:, latent_dim + dim] / 2)) for dim in range(latent_dim))

In [None]:
# Decode the posterior means: reconstructions (red) vs. the data (blue).
generated = np.asarray(decoder_fn(params_dec, z_mean))
plt.figure(figsize=(7, 7))
for pts, color in ((X, "b"), (generated, "r")):
    plt.scatter(pts[:, 0], pts[:, 1], s=1, c=color)
plt.show()

In [None]:
# Sample the N(0, I) prior and decode: generated points (red) vs. the data (blue).
rand_key, key = random.split(rand_key)
n_generated = 1000
z = random.normal(key, shape=(n_generated, latent_dim))
generated = np.asarray(decoder_fn(params_dec, z))
plt.figure(figsize=(7, 7))
for pts, size, color in ((X, 1, "b"), (generated, 2, "r")):
    plt.scatter(pts[:, 0], pts[:, 1], s=size, c=color)
plt.show()