# Learnable normalizing flows (NF)

Normalizing flows (NF) are transformations $g$ mapping a simple distribution $p_z(z)$ that we can easily sample from to a complicated one $p_x(x)$ representing the data:

$$
x = g(z)\,.
$$

$g$ needs to be invertible ($z = g^{-1}(x)$) and is implemented via TFP's `Bijector` objects.
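
Invertibility is what keeps the density of $x$ tractable. As a reminder, the standard change-of-variables formula (written here for the one-dimensional case used throughout this notebook) reads

$$
p_x(x) = p_z\big(g^{-1}(x)\big)\,\left|\frac{\mathrm{d}g^{-1}(x)}{\mathrm{d}x}\right|\,.
$$

The Jacobian factor is what a `Bijector` exposes through its log-det-Jacobian methods.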

NFs can have learnable parameters and can be fitted to the data via maximum likelihood: this way we learn the best transformation between the two distributions, within the parametric family of transformations we choose.

Source: [here](https://github.com/tensorchiefs/dl_book/blob/master/chapter_06/nb_ch06_03.ipynb)

In [None]:
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import seaborn as sns

tfd = tfp.distributions

sns.set_theme()

## An affine mapping between Gaussians

Let's generate samples from two Gaussian distributions with different parameters and learn the NF transforming one into the other.

In [None]:
n_samples = 10000

standard_gaussian_samples = tfd.Normal(loc=0., scale=1.).sample(n_samples)
generic_gaussian_samples = tfd.Normal(loc=5., scale=0.2).sample(n_samples)

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.histplot(
    x=standard_gaussian_samples.numpy(),
    stat='density',
    color=sns.color_palette()[0],
    label='Standard Gaussian samples'
)

sns.histplot(
    x=generic_gaussian_samples.numpy(),
    stat='density',
    color=sns.color_palette()[1],
    label='Generic Gaussian samples'
)

plt.legend()
plt.title('Samples')

Define an affine bijector (implementing a linear transformation between samples) depending on two trainable parameters.

In [None]:
# Initial values for the parameters (scale and shift)
# of the affine transformation.
m = tf.Variable(.5, name='m')
q = tf.Variable(1.2, name='q')

affine_bij = tfp.bijectors.Chain([
    tfp.bijectors.Shift(shift=q),
    tfp.bijectors.Scale(scale=m)
])
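
To make the bijector above concrete, here is a minimal framework-free sketch of the same affine map in plain Python. The helper names (`affine_forward`, `affine_inverse`, `affine_forward_log_det_jacobian`) are hypothetical and only mirror what the TFP chain computes. Since a chain applies its last element first, the scale acts before the shift.

```python
import math


def affine_forward(z, m, q):
    """Forward map x = m * z + q: scale first, then shift."""
    return m * z + q


def affine_inverse(x, m, q):
    """Inverse map z = (x - q) / m (requires m != 0)."""
    return (x - q) / m


def affine_forward_log_det_jacobian(m):
    """log|dx/dz| of the forward map; constant in z for an affine map."""
    return math.log(abs(m))


# Same initial values as the trainable variables above.
m0, q0 = 0.5, 1.2

z = 2.0
x = affine_forward(z, m0, q0)
z_back = affine_inverse(x, m0, q0)

print(f"forward: {x}, inverse of forward: {z_back}")
print(f"forward log-det-Jacobian: {affine_forward_log_det_jacobian(m0)}")
```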

Define the transformed distribution.

In [None]:
transformed_distr = tfd.TransformedDistribution(
    distribution=tfd.Normal(loc=0., scale=1.),  # Source distribution.
    bijector=affine_bij
)
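
As a sanity check on what the transformed distribution represents, the sketch below evaluates its log-density by hand for the affine case: the log-density of the standard Gaussian at the inverse image, minus $\log|m|$. The function names are hypothetical. For an affine map the result should coincide with the log-density of a Gaussian with mean $q$ and standard deviation $|m|$.

```python
import math


def standard_normal_log_prob(z):
    """Log-density of N(0, 1) at z."""
    return -0.5 * z * z - 0.5 * math.log(2.0 * math.pi)


def transformed_log_prob(x, m, q):
    """Log-density of x = m * z + q with z ~ N(0, 1), via change of variables."""
    z = (x - q) / m  # Inverse of the affine map.
    return standard_normal_log_prob(z) - math.log(abs(m))


def normal_log_prob(x, loc, scale):
    """Closed-form log-density of N(loc, scale) at x, for comparison."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


m0, q0 = 0.5, 1.2

for x in (-1.0, 0.0, 1.2, 3.0):
    by_flow = transformed_log_prob(x, m0, q0)
    closed_form = normal_log_prob(x, loc=q0, scale=abs(m0))
    print(f"x = {x:5.2f} | flow: {by_flow:.6f} | closed form: {closed_form:.6f}")
```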

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.histplot(
    x=standard_gaussian_samples.numpy(),
    stat='density',
    color=sns.color_palette()[0],
    label='Standard Gaussian samples'
)

sns.histplot(
    x=generic_gaussian_samples.numpy(),
    stat='density',
    color=sns.color_palette()[1],
    label='Generic Gaussian samples'
)

sns.histplot(
    x=transformed_distr.sample(10000).numpy(),
    stat='density',
    color=sns.color_palette()[2],
    label='Samples from the transformed distribution',
    alpha=0.5
)

plt.legend()
plt.title('Samples')

The loss function to minimize w.r.t. the variables of the NF (the affine bijector) is the negative log-likelihood of the target data under the transformed distribution.

In [None]:
def nll(samples, distr):
    """
    Return the negative log-likelihood of `samples` under `distr`,
    averaged over the samples.
    """
    return - tf.reduce_mean(
        distr.log_prob(samples)
    )
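
For the affine flow this loss has a known minimizer, which gives a useful reference for the training below: the shift equals the sample mean and the scale equals the (population) sample standard deviation. The sketch below checks this in plain Python on synthetic data. The function `affine_flow_nll`, the seed and the sample generation are assumptions made for illustration, not part of the notebook's pipeline.

```python
import math
import random
import statistics


def affine_flow_nll(samples, m, q):
    """Average negative log-likelihood of samples under x = m * z + q, z ~ N(0, 1)."""
    total = 0.0
    for x in samples:
        z = (x - q) / m
        total += 0.5 * z * z + 0.5 * math.log(2.0 * math.pi) + math.log(abs(m))
    return total / len(samples)


# Synthetic target data, mimicking the generic Gaussian used above.
rng = random.Random(0)
samples = [rng.gauss(5.0, 0.2) for _ in range(10_000)]

# Closed-form maximum-likelihood estimates for a Gaussian.
q_mle = statistics.fmean(samples)
m_mle = statistics.pstdev(samples)

nll_at_mle = affine_flow_nll(samples, m_mle, q_mle)
nll_at_init = affine_flow_nll(samples, 0.5, 1.2)

print(f"MLE: m = {m_mle:.4f}, q = {q_mle:.4f}")
print(f"NLL at MLE: {nll_at_mle:.4f} | NLL at initial values: {nll_at_init:.4f}")
```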

Training: we apply gradient descent to minimize the loss function w.r.t. the NF variables.

__Note:__ by trial and error, a larger learning rate (~0.1) is needed for the first ~800 epochs. If the learning rate is then kept constant, SGD starts overshooting the minimum of the loss, and the final value depends on where in the "overshooting cycle" the training ends. Decreasing the learning rate after epoch 800 via a decay schedule allows the minimum to be found (first 800 epochs) and the overshooting to be kept at bay (afterwards).
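
A rough way to see where the overshooting comes from is a local stability check. The sketch below assumes the exact Gaussian target (scale 0.2) rather than the finite sample, and approximates the loss by a quadratic around its minimum. Under these assumptions, plain gradient descent is locally stable only if the learning rate stays below 2 divided by the largest curvature.

```python
# Local curvature of the affine-flow NLL at its minimum (m = sigma, q = mu),
# assuming the exact target N(mu, sigma) instead of a finite sample:
#   NLL(m, q) = log(m) + ((mu - q)**2 + sigma**2) / (2 * m**2) + const
#   d2NLL/dq2 = 1 / m**2                          -> 1 / sigma**2 at the minimum
#   d2NLL/dm2 = -1 / m**2 + 3 * sigma**2 / m**4   -> 2 / sigma**2 at the minimum
# The mixed derivative vanishes at the minimum, so these two values
# are the eigenvalues of the Hessian there.
sigma = 0.2

curvature_q = 1.0 / sigma**2
curvature_m = 2.0 / sigma**2

# Gradient descent on a quadratic is stable for learning rates below 2 / curvature.
max_stable_lr = 2.0 / max(curvature_q, curvature_m)

print(f"curvature along q: {curvature_q:.1f}")
print(f"curvature along m: {curvature_m:.1f}")
print(f"largest locally stable learning rate: {max_stable_lr:.3f}")
```

Under this approximation the threshold is about 0.04, so a learning rate of 0.1 is well above it, in line with the overshooting described above. The value 0.05 is much closer to the threshold but still slightly above it, so a smaller final learning rate may be worth trying.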

In [None]:
epochs = 1000

# n_lr_values = 5

lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[800],  # list(tf.cast(tf.linspace(0, epochs - 1, n_lr_values + 1), dtype=tf.int64)[1:-1].numpy()),
    values=[0.1, 0.05]  # list(tf.linspace(0.1, 0.01, n_lr_values).numpy())
)

optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)

loss_history = []
params_history = [[m.numpy(), q.numpy()]]

for i in range(epochs):
    with tf.GradientTape() as tape:
        loss = nll(generic_gaussian_samples, transformed_distr)
        
    loss_history.append(loss.numpy())
    
    grad = tape.gradient(loss, [m, q])
    
    optimizer.apply_gradients(zip(grad, [m, q]))
    
    params_history.append([m.numpy(), q.numpy()])

loss_history.append(nll(generic_gaussian_samples, transformed_distr).numpy())

params_history = tf.constant(params_history)
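
For comparison, the same optimization can be sketched without TensorFlow by writing the gradients of the loss by hand. This is an illustration under stated assumptions, not the notebook's training setup: it uses the exact moments of the target Gaussian instead of samples, a small constant learning rate of 0.01 instead of the schedule above, and hypothetical function names.

```python
def nll_gradients(m, q, mu=5.0, sigma=0.2):
    """Gradients of NLL(m, q) = log(m) + E[(x - q)**2] / (2 * m**2) + const,
    with x ~ N(mu, sigma) and the expectation taken exactly."""
    second_moment = (mu - q) ** 2 + sigma**2  # E[(x - q)**2]
    grad_m = 1.0 / m - second_moment / m**3
    grad_q = -(mu - q) / m**2
    return grad_m, grad_q


def fit_affine_flow(m, q, learning_rate=0.01, steps=5000):
    """Plain gradient descent on the affine-flow NLL."""
    for _ in range(steps):
        grad_m, grad_q = nll_gradients(m, q)
        m -= learning_rate * grad_m
        q -= learning_rate * grad_q
    return m, q


# Same initial values as the trainable variables above.
m_fit, q_fit = fit_affine_flow(m=0.5, q=1.2)

print(f"fitted scale m = {m_fit:.4f} (target scale: 0.2)")
print(f"fitted shift q = {q_fit:.4f} (target mean: 5.0)")
```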

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.lineplot(
    x=range(len(loss_history)),
    y=loss_history
)

plt.title('Training loss', fontsize=14)
plt.xlabel('Epoch')
plt.ylabel('Loss value')

In [None]:
params_history[-1, :]

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.histplot(
    x=standard_gaussian_samples.numpy(),
    stat='density',
    color=sns.color_palette()[0],
    label='Standard Gaussian samples'
)

sns.histplot(
    x=generic_gaussian_samples.numpy(),
    stat='density',
    color=sns.color_palette()[1],
    label='Generic Gaussian samples'
)

sns.histplot(
    x=transformed_distr.sample(10000).numpy(),
    stat='density',
    color=sns.color_palette()[2],
    label='Samples from the transformed distribution after training',
    alpha=0.5
)

plt.legend()
plt.title('Samples')

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.scatterplot(
    x=params_history[::10, 0].numpy(),
    y=params_history[::10, 1].numpy(),
    alpha=tf.linspace(0.1, 1.0, params_history[::10, 0].shape[0]).numpy(),
    label="Parameters' values"
)

sns.scatterplot(
    x=params_history[-1:, 0].numpy(),
    y=params_history[-1:, 1].numpy(),
    color='red',
    label='Final values'
)

plt.title("Parameters' trajectory along training", fontsize=14)
plt.xlabel('m')
plt.ylabel('q')

## A more complicated stack of NF, with nonlinearities

Load the Old Faithful dataset and fit the data with a more complicated stack of NF containing nonlinearities. We'll work on the `TimeWaiting` column.

In [None]:
import pandas as pd

In [None]:
old_faithful_data = pd.read_csv('../data/learnable_normalizing_flows/OldFaithful.csv')

old_faithful_data

In [None]:
of_data = old_faithful_data.iloc[:, 1].values

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.histplot(
    x=of_data,
    stat='density'
)

plt.title('Old Faithful TimeWaiting data', fontsize=14)
plt.xlabel('x')

Define a stack of bijectors. Each "layer" in the sequence is itself composed of a sub-sequence of two bijectors, listed here in order of application:
- A `SinhArcsinh` bijector.
- An affine bijector.

__Note:__ the order in which the bijectors appear in the sequence is inverted w.r.t. the one in which they act!

The source distribution will be a standard Gaussian.
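
The ordering convention can be illustrated with ordinary functions. The `chain` helper below is a hypothetical stand-in that composes callables the way a bijector chain does, applying the last entry of the list first.

```python
def chain(functions):
    """Compose callables so that the LAST entry of the list is applied first."""
    def composed(x):
        for f in reversed(functions):
            x = f(x)
        return x
    return composed


def shift_by_one(x):
    return x + 1.0


def scale_by_two(x):
    return 2.0 * x


# Listed as [shift, scale]: the scale acts first, then the shift.
scale_then_shift = chain([shift_by_one, scale_by_two])

# Listed as [scale, shift]: the shift acts first, then the scale.
shift_then_scale = chain([scale_by_two, shift_by_one])

print(scale_then_shift(3.0))  # 2 * 3 + 1
print(shift_then_scale(3.0))  # 2 * (3 + 1)
```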

In [None]:
n_layers = 5

bij_list = []

for i in range(n_layers):
    # Add an affine bijector.
    shift = tf.Variable(0., name=f'shift_{i}')
    scale = tf.Variable(1., name=f'scale_{i}')

    bij_list.append(tfp.bijectors.Chain([
        tfp.bijectors.Shift(shift=shift),
        tfp.bijectors.Scale(scale=scale)
    ]))

    # Add a `SinhArcsinh` bijector.
    skewness = tf.Variable(0., name=f'skewness_{i}')
    tailweight = tf.Variable(1., name=f'tailweight_{i}')
    
    bij_list.append(tfp.bijectors.SinhArcsinh(
        skewness=skewness,
        tailweight=tailweight
    ))
    
bij_stack = tfp.bijectors.Chain(bij_list)

trainable_distr = tfd.TransformedDistribution(
    distribution=tfd.Normal(loc=0., scale=1.),
    bijector=bij_stack
)
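
To get a feeling for the nonlinear layer, here is a plain-Python sketch of the sinh-arcsinh map. It is a simplification: TFP's implementation also applies a tailweight-dependent normalizing factor, omitted here, which to my understanding equals 1 for `tailweight=1`. The function names are hypothetical. With the initial values used above (`skewness=0`, `tailweight=1`) the map is the identity, so together with `shift=0` and `scale=1` the whole stack starts as the identity and the trainable distribution starts as a standard Gaussian.

```python
import math


def sinh_arcsinh_forward(x, skewness, tailweight):
    """y = sinh((arcsinh(x) + skewness) * tailweight), without normalization."""
    return math.sinh((math.asinh(x) + skewness) * tailweight)


def sinh_arcsinh_inverse(y, skewness, tailweight):
    """Inverse of the map above (requires tailweight > 0)."""
    return math.sinh(math.asinh(y) / tailweight - skewness)


# Initial values used in the stack: the map reduces to the identity.
for x in (-2.0, 0.0, 0.5, 3.0):
    print(x, sinh_arcsinh_forward(x, skewness=0.0, tailweight=1.0))

# Non-trivial values: positive skewness pushes mass to the right,
# tailweight > 1 stretches the tails.
y = sinh_arcsinh_forward(0.5, skewness=0.3, tailweight=1.5)
print(y, sinh_arcsinh_inverse(y, skewness=0.3, tailweight=1.5))
```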

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.histplot(
    x=of_data,
    stat='density',
    color=sns.color_palette()[0],
    label='Data'
)

sns.histplot(
    x=trainable_distr.sample(100).numpy(),
    stat='density',
    color=sns.color_palette()[1],
    label='Samples from the trainable distribution (before training)'
)

plt.title('Old Faithful TimeWaiting data', fontsize=14)
plt.xlabel('x')
plt.legend()

Training.

__Note:__ training is extremely sensitive to the learning rate, so we need to proceed slowly and with many epochs. I still couldn't reach a really optimal value (the trained distribution still failed to model the left peak in the data); probably something more could be done with the bijectors.

In [None]:
loss_history_2 = []
params_history_2 = [[var.numpy() for var in trainable_distr.trainable_variables]]

In [None]:
epochs = 10000

optimizer_2 = tf.keras.optimizers.SGD(learning_rate=1e-4)

for i in range(epochs):
    with tf.GradientTape() as tape:
        loss = nll(of_data, trainable_distr)
        
    loss_history_2.append(loss.numpy())
    
    grad = tape.gradient(loss, trainable_distr.trainable_variables)
    
    optimizer_2.apply_gradients(zip(grad, trainable_distr.trainable_variables))
    
    params_history_2.append([var.numpy() for var in trainable_distr.trainable_variables])
    
    if i % 10 == 0:
        print(f'Epoch: {i} | Loss: {loss}')

loss_history_2.append(nll(of_data, trainable_distr).numpy())

# params_history_2 = tf.constant(params_history_2)

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.lineplot(
    x=range(len(loss_history_2))[100:],
    y=loss_history_2[100:]
)

plt.title('Training loss', fontsize=14)
plt.xlabel('Epoch')
plt.ylabel('Loss value')

In [None]:
trainable_distr.trainable_variables

In [None]:
test_samples = trainable_distr.sample(100)

# test_samples = test_samples[~tf.math.is_inf(test_samples)]
test_samples

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.histplot(
    x=of_data,
    stat='density',
    color=sns.color_palette()[0],
    label='Data'
)

x_plot = tf.linspace(of_data.min(), of_data.max(), 100).numpy()

sns.lineplot(
    x=x_plot,
    y=trainable_distr.prob(x_plot).numpy(),
    color=sns.color_palette()[1],
    label='Trained distribution'
)

plt.title('Old Faithful TimeWaiting data', fontsize=14)
plt.xlabel('x')
plt.legend()