# The reparametrization trick

In [None]:
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import seaborn as sns

# Shorthand alias for the TensorFlow Probability distributions module.
tfd = tfp.distributions

# Apply seaborn's default theme to the figures created in this notebook.
sns.set_theme()

Let's assume we work on a 2-dimensional event space, i.e. with 2-dimensional multivariate Gaussians, with distributions parametrized by the outputs of the encoder part of a variational autoencoder. These are two tensors `z_mean` and `z_log_var` with shape `(n_samples, 2)`, parametrizing respectively the mean and the log variance of the multivariate Gaussian corresponding to each input sample.

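The conversion between log variance, variance and standard deviation is simply
$$
\sigma^2 = \exp\left( \log\left( \sigma^2 \right) \right) = \exp\left( \mathrm{NN}_2(x) \right)\,,
\qquad
\sigma = \exp\left( \frac{1}{2} \log\left( \sigma^2 \right) \right) = \exp\left( \frac{1}{2} \mathrm{NN}_2(x) \right)\,,
$$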
where we emphasized that the second output of the encoder network is precisely the log variance, $\mathrm{NN}_2(x) = \log\left( \sigma^2 \right)$.

**Note:** we assume independence along the two dimensions, i.e. the multivariate Gaussians are just the products of two univariate ones.

In [None]:
# Simulate the outputs of the encoder.

# Fix the global seed so that "Restart Kernel -> Run All" reproduces the
# same synthetic encoder outputs (and re-running this cell is idempotent).
SEED = 0
tf.random.set_seed(SEED)

n_samples = 5000

# Synthetic means, shape `(n_samples, 2)`: one column per latent dimension.
z_mean = tf.concat(
    [
        tf.random.normal(shape=(n_samples, 1), mean=-0.05, stddev=0.04),
        tf.random.normal(shape=(n_samples, 1), mean=0.05, stddev=0.04)
    ],
    axis=1
)

# Synthetic log variances, shape `(n_samples, 2)`.
z_log_var = tf.concat(
    [
        tf.random.normal(shape=(n_samples, 1), mean=-0.06, stddev=0.04),
        tf.random.normal(shape=(n_samples, 1), mean=-0.09, stddev=0.04)
    ],
    axis=1
)

# Compute the variances once and for all: sigma^2 = exp(log(sigma^2)).
# Note that `exp(0.5 * z_log_var)` would be the standard deviation sigma,
# not the variance; the standard deviation is recovered as `tf.sqrt(z_var)`.
z_var = tf.exp(z_log_var)

## Reparametrization trick: constant variance

In [None]:
# Standard deviation shared by every distribution and by both dimensions.
# It is passed as `scale` below, i.e. it is NOT a variance.
sigma = 1e-1

gaussians = tfd.Independent(
    tfd.Normal(loc=z_mean, scale=sigma * tf.ones_like(z_mean)),
    reinterpreted_batch_ndims=1
)

# Sample from `n_samples` independent 2-dimensional Gaussians
# with means `z_mean` and standard deviation `sigma` (constant across
# distributions - i.e. input samples) along each dimension.
samples_gaussians = gaussians.sample()

# Generate samples with the reparametrization trick:
# z = mu + sigma * eps, with eps drawn from a standard normal.
samples_rt = z_mean + sigma * tfd.Independent(
    tfd.Normal(loc=tf.zeros_like(z_mean), scale=tf.ones_like(z_mean)),
    reinterpreted_batch_ndims=1
).sample()

# Plot samples on an explicit axes object so both scatter plots are
# guaranteed to land on the same figure.
fig, ax = plt.subplots(figsize=(14, 6))

sns.scatterplot(
    x=samples_gaussians[:, 0],
    y=samples_gaussians[:, 1],
    color=sns.color_palette()[0],
    alpha=0.4,
    label='Samples without reparametrization',
    ax=ax
)
sns.scatterplot(
    x=samples_rt[:, 0],
    y=samples_rt[:, 1],
    color=sns.color_palette()[1],
    alpha=0.4,
    label='Samples with reparametrization',
    ax=ax
)

ax.set(
    title='Reparametrization trick: constant variance',
    xlabel='$z_1$',
    ylabel='$z_2$'
)
plt.show()

## Reparametrization trick: non-constant variance

Let's now use the values for the variance generated from the input samples to the encoder (here they are just synthetic values).

In [None]:
# Standard deviation of each Gaussian, computed directly from the log
# variance: sigma = exp(0.5 * log(sigma^2)). `tfd.Normal` expects the
# standard deviation (not the variance) as its `scale` argument.
z_std = tf.exp(0.5 * z_log_var)

gaussians_var = tfd.Independent(
    tfd.Normal(loc=z_mean, scale=z_std),
    reinterpreted_batch_ndims=1
)

# Sample from `n_samples` independent 2-dimensional Gaussians
# with means `z_mean` and variances `exp(z_log_var)`.
samples_gaussians_var = gaussians_var.sample()

# Generate samples with the reparametrization trick:
# z = mu + sigma * eps, with eps drawn from a standard normal of the
# same shape as `z_mean` (one 2-dimensional draw per input sample).
samples_var_rt = z_mean + z_std * tfd.Independent(
    tfd.Normal(loc=tf.zeros_like(z_mean), scale=tf.ones_like(z_mean)),
    reinterpreted_batch_ndims=1
).sample()

# Plot samples on an explicit axes object so both scatter plots are
# guaranteed to land on the same figure.
fig, ax = plt.subplots(figsize=(14, 6))

sns.scatterplot(
    x=samples_gaussians_var[:, 0],
    y=samples_gaussians_var[:, 1],
    color=sns.color_palette()[0],
    alpha=0.2,
    label='Samples without reparametrization',
    ax=ax
)

sns.scatterplot(
    x=samples_var_rt[:, 0],
    y=samples_var_rt[:, 1],
    color=sns.color_palette()[1],
    alpha=0.2,
    label='Samples with reparametrization',
    ax=ax
)

ax.set(
    title='Reparametrization trick: non-constant variance',
    xlabel='$z_1$',
    ylabel='$z_2$'
)
plt.show()