# Training distributions using the KL divergence

__Objective:__ train a TFP distribution to approximate a given one by minimizing the Kullback-Leibler divergence between the two.

The Kullback-Leibler divergence between distributions $p$ and $q$ is defined as
$$
\begin{array}{lll}
D_\mathrm{KL} [p || q] &\equiv& - \int \mathrm{d}^dx\, p(x) \log \left( \frac{q(x)}{p(x)} \right)\\
&=& \mathbb{E}_{x \sim p} \left[ \log(p(x)) - \log(q(x))\right]
\end{array}
$$
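and quantifies how much $p$ differs from $q$: it is non-negative and vanishes if and only if $p = q$ (almost everywhere). Notice, however, that it can't be regarded as a metric on the space of distributions: it isn't symmetric ($D_\mathrm{KL} [p || q] \neq D_\mathrm{KL} [q || p]$ in general) and it does not satisfy the triangle inequality.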

If $p$ is parametrized by a set of parameters $\theta$, we can minimize $D_\mathrm{KL} [p || q]$ w.r.t. $\theta$, finding the optimal parameters minimizing the difference between the two distributions. Because the KL divergence is not symmetric, minimizing $D_\mathrm{KL} [q || p]$, though perfectly legitimate, would give a different result. In particular, the expectation is taken w.r.t. the **first** distribution. As a result, $D_\mathrm{KL}$ heavily penalizes regions where the first distribution has mass but the second has little (and diverges where the second has none). It therefore tends to give better (lower) scores if the support of the **first** distribution is contained in that of the second one. Therefore,
- Minimizing $D_\mathrm{KL} [p || q]$ (with the trainable distribution as the first argument) will tend to find $p$ with support contained in the support of $q$ (which can lead to an underestimate of the variance of $q$).
- Minimizing $D_\mathrm{KL} [q || p]$ (with the trainable distribution as the second argument) will tend to find $p$ with a support that contains that of $q$ (which in turn can lead to an overestimate of the variance of $q$).

In [None]:
import sys
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import seaborn as sns

# Add the repository's shared modules directory to the import path
# (presumably where `keras_utilities` lives — needed by the import below).
sys.path.append('../../modules/')

# Project-local helper; called below with the per-epoch loss history.
from keras_utilities import plot_history

# Short aliases for TFP's distributions and bijectors modules.
tfd = tfp.distributions
tfb = tfp.bijectors

# Apply seaborn's default theme to the matplotlib figures that follow.
sns.set_theme()

## Target distribution

Define the target distribution: a 2-dimensional multivariate Gaussian distribution with full covariance matrix.

In [None]:
# Mean of the target distribution.
q_mu = [0., 0.]

# Bijector mapping an unconstrained 3-vector to a 2x2 lower-triangular
# matrix with strictly positive diagonal, i.e. a valid Cholesky factor
# (scale matrix) of a covariance matrix.
lower_triangle_bij = tfb.Chain([
    # `Chain` applies its bijectors right-to-left, so this one runs
    # second: it maps the diagonal entries of the matrix through
    # softplus, making them positive.
    tfb.TransformDiagonal(tfb.Softplus()),
    # Runs first: fills the lower triangle of a matrix with the
    # elements of the input vector, in a clockwise spiral order.
    tfb.FillTriangular()
])

# Random lower-triangular scale matrix L. The target covariance is
# Sigma = L @ L^T, whose off-diagonal entry L[0, 0] * L[1, 0] is
# (almost surely) non-zero, so the two components are correlated.
q_l = lower_triangle_bij(tf.random.uniform(shape=(3,)))

# Multivariate normal parametrized by its mean and the lower-triangular
# scale (Cholesky factor of the covariance), not by the covariance itself.
q = tfd.MultivariateNormalTriL(loc=q_mu, scale_tril=q_l)

In [None]:
# Plotting window: +/- 3 marginal standard deviations around the mean
# along each axis. The marginal std of component i is sqrt(Sigma_ii)
# with Sigma = L @ L^T. For the second component this also involves
# the off-diagonal entry of `scale_tril`, so `scale_tril[1, 1]` alone
# would underestimate the spread along y and clip the contours.
# `q.stddev()` returns the correct marginal standard deviations.
q_mean = q.mean().numpy()
q_std = q.stddev().numpy()

x_min, x_max = q_mean[0] - 3. * q_std[0], q_mean[0] + 3. * q_std[0]
y_min, y_max = q_mean[1] - 3. * q_std[1], q_mean[1] + 3. * q_std[1]

x_plot, y_plot = np.meshgrid(
    np.linspace(x_min, x_max, 1000, dtype=np.float32),
    np.linspace(y_min, y_max, 1000, dtype=np.float32),
)

# Target density on the flattened grid (one (x, y) point per row).
prob_plot = q.prob(tf.stack(
    [x_plot.flatten(), y_plot.flatten()],
    axis=1
))

fig = plt.figure(figsize=(14, 6))

plt.contour(
    x_plot,
    y_plot,
    np.reshape(prob_plot, x_plot.shape),
    cmap='Blues'
)

plt.title('Target distribution', fontsize=12)

## Variational distribution

Define the distribution we'll use to approximate the target one: a multivariate Gaussian with diagonal covariance matrix.

**Note:** the target distribution exhibits correlation between the two dimensions (its covariance matrix is not diagonal), while the approximate distribution does not (diagonal covariance). This means that the target distribution does **not** belong to the family of distributions parametrized by the approximate one, so the optimization won't be able to match it perfectly.

In [None]:
p = tfd.MultivariateNormalDiag(
    # Randomly initialized mean vector.
    loc=tf.Variable(tf.random.normal(shape=(2,))),
    # Randomly initialized diagonal entries of the scale matrix (the
    # per-axis standard deviations; off-diagonal entries are zero).
    # `TransformedVariable` stores the trainable variable in
    # unconstrained space and maps it through `Exp`. The optimizer can
    # then never push the scale to zero or below. The initial
    # (constrained) value is exp(u) with u ~ U[0, 1).
    scale_diag=tfp.util.TransformedVariable(
        tfb.Exp()(tf.random.uniform(shape=(2,))),
        bijector=tfb.Exp(),
    )
)

In [None]:
# Plotting window for p: +/- 3 standard deviations around the mean
# along each axis. The distribution's own `mean()` / `stddev()` are
# used instead of indexing its constructor parameters. This works for
# any parametrization of the scale and always gives non-negative widths.
p_mean = p.mean().numpy()
p_std = p.stddev().numpy()

x_min_p, x_max_p = p_mean[0] - 3. * p_std[0], p_mean[0] + 3. * p_std[0]
y_min_p, y_max_p = p_mean[1] - 3. * p_std[1], p_mean[1] + 3. * p_std[1]

x_plot_p, y_plot_p = np.meshgrid(
    np.linspace(x_min_p, x_max_p, 1000, dtype=np.float32),
    np.linspace(y_min_p, y_max_p, 1000, dtype=np.float32),
)

# Density of p on the flattened grid (one (x, y) point per row).
prob_plot_p = p.prob(tf.stack(
    [x_plot_p.flatten(), y_plot_p.flatten()],
    axis=1
))

fig = plt.figure(figsize=(14, 6))

# Red: the (still untrained) approximation p.
plt.contour(
    x_plot_p,
    y_plot_p,
    np.reshape(prob_plot_p, x_plot_p.shape),
    cmap='Reds'
)

# Blue: the target q, on the grid computed for the target plot.
plt.contour(
    x_plot,
    y_plot,
    np.reshape(prob_plot, x_plot.shape),
    cmap='Blues'
)

plt.title('Distributions', fontsize=12)

## Optimization

The training loop iteratively minimizes $D_\mathrm{KL} [p || q]$ w.r.t. the parameters of $p$ using a gradient-based optimizer (Adam).

In [None]:
@tf.function
def loss_and_grads(distributions, loss_f=tfd.kl_divergence, trainable_distr=0):
    """
    Compute a loss between two distributions and its gradients.

    Parameters
    ----------
    distributions : sequence of two distributions
        Pair `(dist_1, dist_2)`. The loss is evaluated as
        `loss_f(dist_1, dist_2)`, so the order matters for
        non-symmetric losses such as the KL divergence.
    loss_f : callable, optional
        Function of two distributions returning the loss. Defaults to
        `tfd.kl_divergence`, i.e. the loss is D_KL[dist_1 || dist_2].
    trainable_distr : int, optional
        Index into `distributions` of the distribution whose
        `trainable_variables` the gradients are taken with respect to.

    Returns
    -------
    loss
        Value of `loss_f(dist_1, dist_2)`.
    grad
        List of gradients of `loss`, one per element (and in the same
        order) of `distributions[trainable_distr].trainable_variables`,
        ready to be zipped with those variables in
        `optimizer.apply_gradients`.
    """
    dist_1, dist_2 = distributions

    # Variables read inside the tape (the distributions' trainable
    # variables) are watched automatically.
    with tf.GradientTape() as tape:
        loss = loss_f(dist_1, dist_2)

    trainable = (dist_1, dist_2)[trainable_distr]
    grad = tape.gradient(loss, trainable.trainable_variables)

    return loss, grad

In [None]:
learning_rate = 1e-2
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

# Per-epoch loss values, appended by the training cell.
history = dict(loss=[])

# Total number of epochs run so far. It is not reset by the training
# cell, so re-running that cell continues the count.
epoch_counter = 0

In [None]:
epochs = 500

for _ in range(epochs):
    epoch_counter += 1

    # Minimize D_KL[p || q]: p is the first argument and the trainable
    # one (default `trainable_distr=0`).
    loss, grad = loss_and_grads([p, q])

    optimizer.apply_gradients(zip(grad, p.trainable_variables))

    history['loss'].append(loss.numpy())

    # Log every one of the first 9 epochs, then every 100th epoch.
    if epoch_counter < 10 or epoch_counter % 100 == 0:
        print(f'Epoch: {epoch_counter} | Loss: {loss}')

plot_history(history)

In [None]:
# Plotting window for the trained p: +/- 3 standard deviations around
# the mean along each axis. `mean()` / `stddev()` are read from the
# distribution itself, which is correct for any parametrization of the
# scale and always gives non-negative widths.
p_mean = p.mean().numpy()
p_std = p.stddev().numpy()

x_min_p, x_max_p = p_mean[0] - 3. * p_std[0], p_mean[0] + 3. * p_std[0]
y_min_p, y_max_p = p_mean[1] - 3. * p_std[1], p_mean[1] + 3. * p_std[1]

x_plot_p, y_plot_p = np.meshgrid(
    np.linspace(x_min_p, x_max_p, 1000, dtype=np.float32),
    np.linspace(y_min_p, y_max_p, 1000, dtype=np.float32),
)

# Density of the trained p on the flattened grid (one (x, y) per row).
prob_plot_p = p.prob(tf.stack(
    [x_plot_p.flatten(), y_plot_p.flatten()],
    axis=1
))

fig = plt.figure(figsize=(14, 6))

# Red: p after minimizing D_KL[p || q].
plt.contour(
    x_plot_p,
    y_plot_p,
    np.reshape(prob_plot_p, x_plot_p.shape),
    cmap='Reds'
)

# Blue: the target q, on the grid computed for the target plot.
plt.contour(
    x_plot,
    y_plot,
    np.reshape(prob_plot, x_plot.shape),
    cmap='Blues'
)

plt.title('Distributions', fontsize=12)

### Experiment: swap the terms in the KL divergence

Let's reinitialize the approximate distribution $p$ and minimize $D_\mathrm{KL} [q || p]$ this time.

In [None]:
# Fresh approximate distribution, from the same family and with the
# same random initialization scheme as before.
p_2 = tfd.MultivariateNormalDiag(
    # Randomly initialized mean vector.
    loc=tf.Variable(tf.random.normal(shape=(2,))),
    # Positivity-constrained per-axis scale: the trainable variable
    # lives in unconstrained space and is mapped through `Exp`. The
    # initial (constrained) value is exp(u) with u ~ U[0, 1).
    scale_diag=tfp.util.TransformedVariable(
        tfb.Exp()(tf.random.uniform(shape=(2,))),
        bijector=tfb.Exp(),
    )
)

In [None]:
learning_rate = 1e-2
optimizer_2 = tf.keras.optimizers.Adam(learning_rate=learning_rate)

# Per-epoch loss values for the swapped-KL experiment.
history_2 = dict(loss=[])

# Total number of epochs run so far in the swapped-KL experiment; not
# reset by the training cell, so re-running it continues the count.
epoch_counter_2 = 0

In [None]:
epochs = 500

for _ in range(epochs):
    epoch_counter_2 += 1

    # Here p and q are swapped w.r.t. before: the loss is
    # D_KL[q || p_2], and `trainable_distr=1` selects p_2 (the second
    # argument) as the distribution to differentiate against.
    loss, grad = loss_and_grads([q, p_2], trainable_distr=1)

    optimizer_2.apply_gradients(zip(grad, p_2.trainable_variables))

    history_2['loss'].append(loss.numpy())

    # Log every one of the first 9 epochs, then every 100th epoch.
    if epoch_counter_2 < 10 or epoch_counter_2 % 100 == 0:
        print(f'Epoch: {epoch_counter_2} | Loss: {loss}')

plot_history(history_2)

In [None]:
# Plotting window for the trained p_2: +/- 3 standard deviations around
# the mean along each axis. The moments come from the distribution
# itself, so the window is correct for any parametrization of the scale.
p_2_mean = p_2.mean().numpy()
p_2_std = p_2.stddev().numpy()

x_min_p_2, x_max_p_2 = (
    p_2_mean[0] - 3. * p_2_std[0],
    p_2_mean[0] + 3. * p_2_std[0],
)
y_min_p_2, y_max_p_2 = (
    p_2_mean[1] - 3. * p_2_std[1],
    p_2_mean[1] + 3. * p_2_std[1],
)

x_plot_p_2, y_plot_p_2 = np.meshgrid(
    np.linspace(x_min_p_2, x_max_p_2, 1000, dtype=np.float32),
    np.linspace(y_min_p_2, y_max_p_2, 1000, dtype=np.float32),
)

# Density of the trained p_2 on the flattened grid (one (x, y) per row).
prob_plot_p_2 = p_2.prob(tf.stack(
    [x_plot_p_2.flatten(), y_plot_p_2.flatten()],
    axis=1
))

fig = plt.figure(figsize=(14, 6))

# Red: p_2 after minimizing D_KL[q || p_2].
plt.contour(
    x_plot_p_2,
    y_plot_p_2,
    np.reshape(prob_plot_p_2, x_plot_p_2.shape),
    cmap='Reds'
)

# Blue: the target q, on the grid computed for the target plot.
plt.contour(
    x_plot,
    y_plot,
    np.reshape(prob_plot, x_plot.shape),
    cmap='Blues'
)

plt.title('Distributions', fontsize=12)

Compare the two optimized distributions: the one obtained by minimizing $D_\mathrm{KL} [p || q]$ has a smaller variance than the one obtained by minimizing $D_\mathrm{KL} [q || p]$, as expected from the discussion above.

In [None]:
fig = plt.figure(figsize=(14, 6))

# Overlay both fitted approximations, each on its own grid:
# red for p (trained on D_KL[p || q]), blue for p_2 (trained on
# D_KL[q || p_2]).
contour_specs = (
    (x_plot_p, y_plot_p, prob_plot_p, 'Reds'),
    (x_plot_p_2, y_plot_p_2, prob_plot_p_2, 'Blues'),
)
for grid_x, grid_y, density, colormap in contour_specs:
    plt.contour(
        grid_x,
        grid_y,
        np.reshape(density, grid_x.shape),
        cmap=colormap
    )

plt.title('Distributions', fontsize=12)