# Training TensorFlow Probability distributions

In [None]:
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

tfd = tfp.distributions

sns.set_theme()

# Fix TensorFlow's global random seed so that sampling (and therefore every
# result below) is reproducible under "Restart & Run All".
SEED = 42
tf.random.set_seed(SEED)

## Objective

Generate synthetic data from a chosen probability distribution, then fit a second (trainable) distribution to the data by optimizing its parameters.

## Generate data

In [None]:
# "True" parameters of the Normal distribution the data are drawn from.
generating_loc, generating_scale = 3.3, 2.1

generating_distr = tfd.Normal(
    loc=generating_loc,
    scale=generating_scale,
)

generating_distr

In [None]:
# Size of the synthetic dataset drawn from the generating distribution.
n_samples = 5000
samples = generating_distr.sample(sample_shape=n_samples)

In [None]:
fig, ax = plt.subplots(figsize=(14, 6))

sns.histplot(
    x=samples.numpy(),
    stat='density',
    color=sns.color_palette()[0],
    label='Histogram of samples',
    ax=ax,
)

# Evaluation grid for the PDF, covering the bulk of the generating
# distribution (loc 3.3, scale 2.1).
grid_points = tf.linspace(-5., 10., 1000)

sns.lineplot(
    x=grid_points.numpy(),
    y=generating_distr.prob(grid_points).numpy(),
    color=sns.color_palette()[1],
    label='Generating PDF',
    ax=ax,
)

ax.set_title('Samples and generating PDF', fontsize=16)
ax.set_xlabel('Values', fontsize=12)
ax.set_ylabel('Density', fontsize=12)
ax.legend(loc='upper right', fontsize=12);

## Train a distribution on the data

**Enforcing constraints on the parameters:** some probability distributions have parameters that must obey constraints (e.g. they must be positive, add up to 1, etc.). These constraints should be satisfied at every step of the optimization procedure, and there are various ways to achieve this, including:
- "Clipping" negative values to 0 or to a small positive number $\epsilon$: every time the optimizer generates a negative number this is mapped to either 0 or $\epsilon$, depending on what's required.
- Reparametrize: optimize an unconstrained value $u$ and pass it through a function that maps it to positive values (e.g. $\sigma = e^{u}$).
- Use an optimizer that supports constraints natively (e.g. L-BFGS-B, which handles bound constraints).

In this notebook we implement the first and the second options; the user chooses which one to use via the `positivity_constraint_method` argument (`'clip'` or `'exp'`).

In [None]:
def nll(samples, distr):
    """
    Average negative log-likelihood of `samples` under `distr`.

    Serves as the loss function minimized during the optimization phase.
    """
    log_likelihoods = distr.log_prob(samples)
    return -tf.reduce_mean(log_likelihoods)

@tf.function
def get_loss_and_grads(samples, loss, distr):
    """
    Evaluate a loss and its gradient w.r.t. the distribution's trainable
    variables.

    Args:
        samples: data the loss is evaluated on.
        loss: callable `loss(samples, distr)`, expected to return a scalar
            tensor (e.g. `nll`).
        distr: distribution whose `trainable_variables` are differentiated.

    Returns:
        `(loss_value, grads)`, where `grads` holds one gradient per entry of
        `distr.trainable_variables`, in the same order.
    """
    with tf.GradientTape() as tape:
        # Trainable variables are watched automatically; the explicit watch
        # only makes the dependency obvious.
        tape.watch(distr.trainable_variables)

        # Loss at the current point in parameter space.
        loss_value = loss(samples, distr)

    # Gradient at the current point in parameter space. It is computed
    # outside the tape context so that the gradient ops are not recorded too.
    grads = tape.gradient(loss_value, distr.trainable_variables)

    return loss_value, grads


def distr_optimization(
    samples,
    loss,
    distr,
    param_names,
    n_iter=500,
    learning_rate=0.2,
    positive_params=(),
    positivity_constraint_method='clip',
    clip_eps=1e-6,
):
    """
    Fit the parameters of `distr` to `samples` by minimizing `loss` with SGD.

    Args:
        samples: data the loss is evaluated on.
        loss: callable `loss(samples, distr)` returning a scalar tensor
            (e.g. `nll`).
        distr: distribution whose `trainable_variables` are optimized.
        param_names: keys of `distr.parameters` whose values are recorded at
            every iteration.
        n_iter: number of SGD steps.
        learning_rate: SGD learning rate.
        positive_params: names (a subset of `param_names`) of the parameters
            that must stay positive.
        positivity_constraint_method: how positivity is enforced:
            - 'clip': after every step, each positive parameter that is a
              `tf.Variable` is overwritten with `max(value, clip_eps)`.
            - 'exp': positivity must come from the parametrization itself,
              e.g. a parameter built as
              `tfp.util.TransformedVariable(x, bijector=tfp.bijectors.Exp())`,
              whose trainable variable is the unconstrained log-value; nothing
              is applied during the loop.
            Parameters the chosen method cannot constrain trigger a warning.
        clip_eps: lower bound used by the 'clip' method.

    Returns:
        `(loss_history, grad_history, params_history)`, recorded before each
        step and once more after the last one (`n_iter + 1` entries each):
        - `loss_history`: shape `(n_iter + 1,)`.
        - `grad_history`: shape `(n_vars, n_iter + 1, ...)`, ordered like
          `distr.trainable_variables`; all variables must share one shape.
        - `params_history`: shape `(len(param_names), n_iter + 1, ...)`,
          ordered like `param_names`; all parameters must share one shape.

    Raises:
        ValueError: if some `positive_params` are missing from `param_names`
            or `positivity_constraint_method` is unknown.
    """
    import warnings

    missing_params = [p for p in positive_params if p not in param_names]
    if missing_params:
        raise ValueError(
            "Some parameters required to be positive are not mentioned among "
            f"the parameters' names: {missing_params}")

    # Validate the method before any variable gets modified by the optimizer.
    if positivity_constraint_method not in ('clip', 'exp'):
        raise ValueError(
            f'Positivity constraint method {positivity_constraint_method} '
            'not available')

    # `distr.parameters` returns a new dict on every access, so assigning into
    # it leaves `distr` untouched: constraints must act on the variables.
    positive_values = {p: distr.parameters[p] for p in positive_params}
    plain_variables = {
        p: v for p, v in positive_values.items() if isinstance(v, tf.Variable)
    }

    if positivity_constraint_method == 'clip':
        variables_to_clip = list(plain_variables.values())
        unconstrained = [p for p in positive_values if p not in plain_variables]
        hint = 'only tf.Variable parameters can be clipped'
    else:
        # A plain tf.Variable is used as-is by the distribution, so nothing
        # maps it to positive values.
        variables_to_clip = []
        unconstrained = list(plain_variables)
        hint = ('build them with tfp.util.TransformedVariable and an Exp '
                'bijector, or use the "clip" method')

    if unconstrained:
        warnings.warn(
            f"Positivity constraint method '{positivity_constraint_method}' "
            f'cannot be applied to parameters {unconstrained}: {hint}.')

    losses, grads_per_step, params_per_step = [], [], []

    def _record_step():
        """Evaluate loss/gradients and store them with the current params."""
        loss_value, grads = get_loss_and_grads(samples, loss, distr)
        losses.append(loss_value)
        grads_per_step.append(grads)
        params_per_step.append([
            tf.convert_to_tensor(distr.parameters[name]).numpy()
            for name in param_names
        ])
        return grads

    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)

    for _ in range(n_iter):
        grads = _record_step()
        optimizer.apply_gradients(zip(grads, distr.trainable_variables))

        # Project the positive parameters back onto [clip_eps, +inf).
        for var in variables_to_clip:
            var.assign(tf.maximum(var, clip_eps))

    # Record the values reached after the last step.
    _record_step()

    loss_history = tf.stack(losses)
    grad_history = tf.stack(grads_per_step, axis=1)
    params_history = tf.stack(params_per_step, axis=1)

    return loss_history, grad_history, params_history

Define a trainable distribution.

In [None]:
# Trainable parameters must be TensorFlow variables. They start from a
# standard Normal (loc 0, scale 1), away from the generating values.
trainable_loc = tf.Variable(0.0, name='trainable_loc')
trainable_scale = tf.Variable(1.0, name='trainable_scale')

trainable_distr = tfd.Normal(
    loc=trainable_loc,
    scale=trainable_scale,
)

trainable_distr

Plot the samples, the generating PDF and the PDF corresponding to the trainable distribution with the initial values for the parameters.

In [None]:
fig, ax = plt.subplots(figsize=(14, 6))

sns.histplot(
    x=samples.numpy(),
    stat='density',
    color=sns.color_palette()[0],
    label='Histogram of samples',
    ax=ax,
)

grid_points = tf.linspace(-5., 10., 1000)

sns.lineplot(
    x=grid_points.numpy(),
    y=generating_distr.prob(grid_points).numpy(),
    color=sns.color_palette()[1],
    label='Generating PDF',
    ax=ax,
)

sns.lineplot(
    x=grid_points.numpy(),
    y=trainable_distr.prob(grid_points).numpy(),
    color=sns.color_palette()[3],
    label='Trainable PDF (initial values for parameters)',
    ax=ax,
)

ax.set_title('Generating PDF vs. initial trainable PDF', fontsize=16)
ax.set_xlabel('Values', fontsize=12)
ax.set_ylabel('Density', fontsize=12)
ax.legend(loc='upper right', fontsize=12);

Fit the trainable distribution to the samples by minimizing the negative log-likelihood.

In [None]:
# The trainable scale is a plain `tf.Variable` used directly by the
# distribution, so it is kept positive by clipping it after every step
# ('clip'). The 'exp' method would instead require the scale to be
# parametrized through an exponential.
loss_history, grad_history, params_history = distr_optimization(
    samples,
    nll,
    trainable_distr,
    param_names=['loc', 'scale'],
    n_iter=500,
    learning_rate=.2,
    positive_params=['scale'],
    positivity_constraint_method='clip'
)

print(
    f'Initial loss: {loss_history[0]}\n'
    f'Final loss: {loss_history[-1]}\n'
)

In [None]:
fig, axs = plt.subplots(nrows=3, ncols=1, figsize=(14, 15), sharex=True)

iterations = np.arange(params_history.shape[1])

# Loss.
sns.lineplot(
    x=np.arange(len(loss_history)),
    y=loss_history.numpy(),
    color=sns.color_palette()[0],
    ax=axs[0]
)
axs[0].set_ylabel('Loss values', fontsize=12)
axs[0].set_title('Optimization history', fontsize=16)

# Trained parameters (row order of params_history follows param_names:
# loc, scale) against the values used to generate the data.
param_panels = [
    (axs[1], params_history[0, :], generating_loc, 'Loc param values'),
    (axs[2], params_history[1, :], generating_scale, 'Scale param values'),
]
for ax, history, true_value, ylabel in param_panels:
    sns.lineplot(
        x=iterations,
        y=history.numpy(),
        color=sns.color_palette()[0],
        label='Trained value',
        ax=ax
    )
    sns.lineplot(
        x=iterations,
        y=np.full(len(iterations), true_value),
        color=sns.color_palette()[1],
        label='Generating value',
        ax=ax
    )
    ax.set_ylabel(ylabel, fontsize=12)

axs[-1].set_xlabel('N iteration', fontsize=12);

In [None]:
# Relative errors of the fitted parameters w.r.t. the generating ones; the
# absolute value in the denominator keeps them non-negative for any sign.
loc_error = tf.abs(trainable_distr.loc - generating_loc) / abs(generating_loc)
scale_error = tf.abs(trainable_distr.scale - generating_scale) / abs(generating_scale)

print(
    f'Relative error on loc parameter: {loc_error}\n'
    f'Relative error on scale parameter: {scale_error}\n'
)

In [None]:
fig, ax = plt.subplots(figsize=(14, 6))

sns.histplot(
    x=samples.numpy(),
    stat='density',
    color=sns.color_palette()[0],
    label='Histogram of samples',
    ax=ax,
)

grid_points = tf.linspace(-5., 10., 1000)

sns.lineplot(
    x=grid_points.numpy(),
    y=generating_distr.prob(grid_points).numpy(),
    color=sns.color_palette()[1],
    label='Generating PDF',
    ax=ax,
)

# Snapshot the trainable PDF every `snapshot_every` iterations, always
# including the final recorded values exactly once.
snapshot_every = 50
n_recorded = params_history.shape[1]
snapshot_iters = list(range(0, n_recorded, snapshot_every))
if snapshot_iters[-1] != n_recorded - 1:
    snapshot_iters.append(n_recorded - 1)

for it in snapshot_iters:
    # Rows of params_history follow param_names: loc, scale.
    distr_to_plot = tfd.Normal(
        loc=params_history[0, it],
        scale=params_history[1, it]
    )

    sns.lineplot(
        x=grid_points.numpy(),
        y=distr_to_plot.prob(grid_points).numpy(),
        color=sns.color_palette()[3],
        label=f'Trainable PDF (iteration {it})',
        alpha=0.4,
        ax=ax,
    )

ax.set_title('Trainable PDF along the optimization', fontsize=16)
ax.set_xlabel('Values', fontsize=12)
ax.set_ylabel('Density', fontsize=12)
ax.legend(loc='upper right', fontsize=12);

## An example in 2 dimensions

Here we use a 2x2 covariance matrix that is diagonal by construction. In the most general case, we should use a full covariance matrix with an optimization procedure that makes sure that at each iteration the covariance matrix itself stays symmetric and positive definite.

In [None]:
# "True" parameters of the 2D generating distribution: one loc and one
# scale per component (diagonal covariance).
generating_loc_2d = [3.3, 4.1]
generating_scale_2d = [2.1, 1.1]

generating_distr_2d = tfd.MultivariateNormalDiag(
    loc=generating_loc_2d,
    scale_diag=generating_scale_2d,
)

generating_distr_2d

In [None]:
# Same number of draws as in the 1D case; each draw is a 2D point.
samples_2d = generating_distr_2d.sample(sample_shape=n_samples)

samples_2d.shape

In [None]:
fig, ax = plt.subplots(figsize=(14, 6))

sns.scatterplot(
    x=samples_2d[:, 0].numpy(),
    y=samples_2d[:, 1].numpy(),
    ax=ax,
)

# Regular grid on which the generating PDF is evaluated for the contours.
n_grid = 100
xs, ys = np.meshgrid(
    np.linspace(-5., 12.5, n_grid),
    np.linspace(0., 8.5, n_grid)
)

# Shape (n_grid * n_grid, 2). Cast explicitly to float32, the dtype the
# distribution gets by default from its Python-float parameters.
grid_points_2d = np.stack([xs.ravel(), ys.ravel()], axis=1).astype(np.float32)

ax.contour(
    xs,
    ys,
    generating_distr_2d.prob(grid_points_2d).numpy().reshape(xs.shape),
    levels=20,
)

ax.set_title('Generated points', fontsize=16)
ax.set_xlabel('x', fontsize=12)
ax.set_ylabel('y', fontsize=12);

In [None]:
# Trainable 2D parameters (TensorFlow variables), initialized to a standard
# bivariate Normal: zero loc and unit scale on each component.
trainable_loc_2d = tf.Variable([0., 0.])
trainable_scale_2d = tf.Variable([1., 1.])

trainable_distr_2d = tfd.MultivariateNormalDiag(
    loc=trainable_loc_2d, scale_diag=trainable_scale_2d)

trainable_distr_2d

In [None]:
# The 2D scale is a plain `tf.Variable` used directly by the distribution,
# so it is kept positive by clipping it after every step ('clip'). The 'exp'
# method would instead require the scale to be parametrized through an
# exponential.
loss_history_2d, grad_history_2d, params_history_2d = distr_optimization(
    samples_2d,
    nll,
    trainable_distr_2d,
    param_names=['loc', 'scale_diag'],
    n_iter=500,
    learning_rate=.2,
    positive_params=['scale_diag'],
    positivity_constraint_method='clip'
)

print(
    f'Initial loss: {loss_history_2d[0]}\n'
    f'Final loss: {loss_history_2d[-1]}\n'
)

In [None]:
fig, axs = plt.subplots(nrows=5, ncols=1, figsize=(14, 15), sharex=True)

iterations_2d = np.arange(params_history_2d.shape[1])

# Loss.
sns.lineplot(
    x=np.arange(len(loss_history_2d)),
    y=loss_history_2d.numpy(),
    color=sns.color_palette()[0],
    ax=axs[0]
)
axs[0].set_ylabel('Loss values', fontsize=12)
axs[0].set_title('Optimization history', fontsize=16)

# One panel per (parameter, component). params_history_2d is indexed as
# [param (loc=0, scale_diag=1), iteration, component].
param_panels = [
    (0, 0, generating_loc_2d[0], 'loc (comp. 0)\nparam values'),
    (0, 1, generating_loc_2d[1], 'loc (comp. 1)\nparam values'),
    (1, 0, generating_scale_2d[0], 'scale_diag (comp. 0)\nparam values'),
    (1, 1, generating_scale_2d[1], 'scale_diag (comp. 1)\nparam values'),
]
for ax, (param_idx, comp, true_value, ylabel) in zip(axs[1:], param_panels):
    sns.lineplot(
        x=iterations_2d,
        y=params_history_2d[param_idx, :, comp].numpy(),
        color=sns.color_palette()[0],
        label='Trained value',
        ax=ax
    )
    sns.lineplot(
        x=iterations_2d,
        y=np.full(len(iterations_2d), true_value),
        color=sns.color_palette()[1],
        label='Generating value',
        ax=ax
    )
    ax.set_ylabel(ylabel, fontsize=12)

axs[-1].set_xlabel('N iteration', fontsize=12);

In [None]:
fig, axs = plt.subplots(nrows=2, ncols=1, figsize=(14, 10))

# Trajectory of each 2D parameter vector during the optimization, compared
# with the value used to generate the data. Each trajectory has shape
# (n_iterations + 1, 2).
trajectory_panels = [
    (axs[0], params_history_2d[0].numpy(), generating_loc_2d, 'loc'),
    (axs[1], params_history_2d[1].numpy(), generating_scale_2d, 'scale_diag'),
]
for ax, trajectory, true_value, name in trajectory_panels:
    sns.scatterplot(
        x=trajectory[:, 0],
        y=trajectory[:, 1],
        alpha=0.3,
        label='Optimization trajectory',
        ax=ax
    )
    sns.scatterplot(
        x=[true_value[0]],
        y=[true_value[1]],
        alpha=1.,
        color=sns.color_palette()[1],
        label='Generating value',
        ax=ax
    )
    ax.set_xlabel(f'{name}_0', fontsize=12)
    ax.set_ylabel(f'{name}_1', fontsize=12)

axs[0].set_title('History of parameters along the optimization', fontsize=16);