# RealNVP flows from scratch

__Objective:__ build and train a simple RealNVP flow model from scratch.

__Source:__ D. Foster, [_Generative deep learning_](https://www.oreilly.com/library/view/generative-deep-learning/9781492041931/) (2nd ed.) (with notebooks [here](https://github.com/davidADSP/Generative_Deep_Learning_2nd_Edition)).

**Setup:**
- We start from a vector $z \in \mathbb{R}^D$ in latent space, which we sample from a multivariate standard normal distribution, so $z \sim p_Z = \mathcal{N}(0, I)$.
- We transform $z$ to the "real" data space $x \in \mathbb{R}^D$ via the RealNVP transformation, so that $z \to x = x(z)$ is the **forward** transformation (this is the opposite of the convention used in the source, where this is taken to be the inverse transformation; for RealNVP it doesn't really matter, because the forward and the inverse transformations are equally cheap to compute).
- The RealNVP transformation is implemented by a stack of **coupling layers** with feature permutation operations (bijectors) in between.
- Following the RealNVP recipe, in each coupling layer the first $d$ dimensions (features) of $x$ are singled out and used to generate the corresponding dimensions of $z$ (an identity transformation) and to parametrize (via a neural network) an affine transformation for the last $(D - d)$ dimensions of $z$.
- Full transformation for a single coupling layer, written in the $x \to z$ direction (i.e. the inverse of the forward map $z \to x$ defined above):
$$
\begin{array}{lll}
z_i &=& x_i\quad \forall i = 1, \ldots, d\\
z_j &=& x_j\,\exp\left( s_j(x_1, \ldots, x_d) \right) + t_j(x_1, \ldots, x_d)\quad \forall j = d+1, \ldots, D
\end{array}
$$
where the vectors $s, t \in \mathbb{R}^{D-d}$ are the tensors output by the coupling layer, i.e. functions of $x_1, \ldots, x_d$ computed by a neural network.

In [None]:
# Standard and third-party imports used throughout the notebook.
import sys
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import seaborn as sns

# Make the local `../modules/` directory importable; the project modules
# imported in later cells (`real_nvp`, `keras_utilities`) are presumably
# located there (TODO confirm).
sys.path.append('../modules/')

# Short alias for the TensorFlow Probability distributions module.
tfd = tfp.distributions

# Apply seaborn's default theme to all subsequent figures.
sns.set_theme()

# IPython magics: automatically reload edited modules (e.g. `real_nvp`)
# before executing each cell.
%load_ext autoreload
%autoreload 2

## Coupling layer

The coupling layer is responsible for taking the first $d$ dimensions (features) of the input and outputting a scale and a translation tensor (so two outputs) to be used to parametrize an affine transformation for the remaining $(D - d)$ dimensions of the input.

In [None]:
from real_nvp import CouplingLayer

In [None]:
# Smoke test for a coupling layer: it reads `n_masked_dims` features and
# outputs a scale and a translation tensor for the remaining
# `n_affine_dims` features.
n_masked_dims = 2
n_affine_dims = 3

test_cl = CouplingLayer(
    hidden_layers_dims=[32, 32],
    n_masked_dims=n_masked_dims,
    n_affine_dims=n_affine_dims,
)

# A random batch of 14 samples whose width equals the total number of
# features (2 + 3 = 5); unpack the (scale, translation) outputs.
s, t = test_cl(tf.random.normal(shape=(14, n_masked_dims + n_affine_dims)))

## RealNVP bijector

Parametrize an affine (scale and then shift) transformation with the output from the `CouplingLayer`.

In [None]:
from real_nvp import AffineBijector

In [None]:
# Generate simple test data: a (4, 5) batch filled with the constant 2.4.
# The width (5) presumably has to equal n_masked_dims + n_affine_dims of
# `test_cl` (2 + 3) — TODO confirm against AffineBijector.
test_data = tf.ones(shape=(4, 5)) * 2.4

# Instantiate an affine bijector parametrized by a
# coupling layer so that it works as a RealNVP bijector.
test_real_nvp_bij = AffineBijector(test_cl)

Test the forward and inverse transformations. In both cases the first `n_masked_dims` of the datapoints should be left unaltered.

In [None]:
# Apply the bijector in both directions; in each output the first
# `n_masked_dims` columns should still equal the input value (2.4).
test_real_nvp_bij.forward(test_data), test_real_nvp_bij.inverse(test_data)

Check a "cycle condition": applying the forward and then the inverse transformation to some data (and vice versa) should recover the starting tensors.

In [None]:
# inverse(forward(x)) should give back x: this norm should be ~0, up to
# floating-point error.
tf.norm(test_real_nvp_bij.inverse(test_real_nvp_bij.forward(test_data)) - test_data)

In [None]:
# forward(inverse(x)) should give back x as well: this norm should be ~0,
# up to floating-point error.
tf.norm(test_real_nvp_bij.forward(test_real_nvp_bij.inverse(test_data)) - test_data)

## RealNVP layer

The `RealNVPLayer` object represents one RealNVP block inside a larger model. It's composed of 2 operations: a feature permutation followed by an affine transformation parametrized by a `CouplingLayer` object.

In [None]:
from real_nvp import RealNVPLayer

In [None]:
# A single RealNVP block acting on 2 + 3 = 5 features, whose coupling
# network has two hidden layers of 32 units.
test_rnvp_layer = RealNVPLayer(
    n_masked_dims=2,
    n_affine_dims=3,
    hidden_layers_dims=[32, 32]
)

In [None]:
# Apply the block to a batch of 4 identical rows [0, 1, ..., D - 1],
# where D = n_masked_dims + n_affine_dims is the number of features.
test_rnvp_layer(
    tf.constant(
        [
            list(range(test_rnvp_layer.n_masked_dims + test_rnvp_layer.n_affine_dims))
            for _ in range(4)
        ],
        dtype=tf.float32,
    )
)

In [None]:
# Report the number of parameters of the RealNVP block.
test_rnvp_layer.count_params()

## RealNVP model

In [None]:
from real_nvp import RealNVPModel

In [None]:
# A full RealNVP model stacking 3 RealNVP blocks over 2 + 3 = 5 features.
test_rnvp_model = RealNVPModel(
    n_masked_dims=2,
    n_affine_dims=3,
    n_real_nvp_blocks=3,
    hidden_layers_dims=[32, 32]
)

In [None]:
# Call the model once on the test batch (presumably this is what builds its
# weights so that `summary()` can report them — TODO confirm), then print
# the layer summary.
test_rnvp_model(test_data)

test_rnvp_model.summary()

Compute log probabilities with the base and the transformed distributions.

**Note:** the event shape of the distribution should be equal to the shape of one sample so that one sample corresponds to one value of log prob.

In [None]:
# Inspect the base (latent) distribution and the flow-transformed distribution.
test_rnvp_model.base_distr, test_rnvp_model.transformed_distr

In [None]:
# Log probability of each test sample under the base distribution and under
# the transformed (model) distribution.
test_rnvp_model.base_distr.log_prob(test_data), test_rnvp_model.transformed_distr.log_prob(test_data)

## Training

Generate data.

In [None]:
from sklearn.datasets import make_moons
from tensorflow.keras import backend as K
from keras_utilities import plot_history

In [None]:
# Two interleaving half-moons in 2D (2500 points with noise=0.06). The
# labels are not used anywhere below: this is unsupervised density fitting.
# NOTE(review): no `random_state` is passed, so the data differs between runs.
data, labels_data = make_moons(n_samples=2500, noise=.06)

# Convert the NumPy array to a float32 TensorFlow tensor.
data = tf.constant(data, dtype=tf.float32)

In [None]:
# 2D flow: each coupling layer keeps 1 feature fixed and uses it to
# parametrize an affine map of the other one; 3 RealNVP blocks are stacked.
rnvp_model = RealNVPModel(
    n_masked_dims=1,
    n_affine_dims=1,
    n_real_nvp_blocks=3,
    hidden_layers_dims=[32, 32]
)

In [None]:
# Left panel: samples from the base (latent) distribution.
# Right panel: the moons data together with the same base samples pushed
# through the (still untrained) flow, to show where training starts from.
fig, axs = plt.subplots(figsize=(14, 6), nrows=1, ncols=2)

base_samples = rnvp_model.base_distr.sample(2500)

sns.scatterplot(
    x=base_samples[:, 0].numpy(),
    y=base_samples[:, 1].numpy(),
    ax=axs[0]
)
axs[0].set_xlabel('x')
axs[0].set_ylabel('y')
axs[0].set_title('Base distribution', fontsize=14)

sns.scatterplot(
    x=data[:, 0].numpy(),
    y=data[:, 1].numpy(),
    ax=axs[1],
    label='Data'
)

# Map the base samples through the flow's forward transformation (z -> x).
transformed_samples = rnvp_model.transformed_distr.bijector.forward(base_samples)

sns.scatterplot(
    x=transformed_samples[:, 0].numpy(),
    y=transformed_samples[:, 1].numpy(),
    ax=axs[1],
    label='Transformed samples (before training)',
    alpha=.2
)
axs[1].set_xlabel('x')
axs[1].set_ylabel('y')
# The right panel shows both the data and the flow samples, so the title says so.
axs[1].set_title('Data and transformed samples (before training)', fontsize=14)

Fit model to the data using a custom training step.

In [None]:
def nll(data, distr):
    """
    Training objective: the negative mean log-likelihood of the samples
    in `data` under the distribution `distr` (lower is better).
    """
    log_likelihoods = distr.log_prob(data)
    return -tf.reduce_mean(log_likelihoods)


@tf.function
def training_step(data, model, loss_fn, optimizer):
    """
    Perform a single gradient step on the batch `data`.

    Parameters
    ----------
    data : tensor of training samples.
    model : object exposing `transformed_distr` (the flow distribution
        being fitted) and `trainable_variables`.
    loss_fn : callable `loss_fn(data, distr)` returning a scalar loss.
    optimizer : Keras optimizer used to apply the gradients.

    Returns
    -------
    The loss value computed before the parameter update.
    """
    with tf.GradientTape() as tape:
        # Use the injected objective rather than a hard-coded one, so the
        # `loss_fn` argument actually determines what is minimized.
        loss = loss_fn(data, model.transformed_distr)

    grads = tape.gradient(loss, model.trainable_variables)

    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    return loss

In [None]:
# Per-epoch records of the training loss and of the learning rate.
training_history = dict(loss=[], learning_rate=[])

In [None]:
# Step-wise learning rate: 1e-2 up to optimizer step 2500, 1e-3 afterwards.
# The schedule is indexed by the optimizer's iteration count. Each
# `training_step` below is one full-batch update, so steps coincide with epochs.
lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[2500],
    values=[1e-2, 1e-3]
)

# Plain SGD driven by the schedule above.
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)

In [None]:
# K.set_value(optimizer.learning_rate, 0.001)

In [None]:
# Full-batch training: every epoch is a single gradient step on all of `data`.
epochs = 15000
# Progress is printed for the first few epochs, every `log_every` epochs,
# and for the final epoch.
log_every = 200

for epoch in range(epochs):
    training_history['loss'].append(training_step(data, rnvp_model, nll, optimizer).numpy())

    # NOTE(review): with a LearningRateSchedule this relies on
    # `optimizer.learning_rate` evaluating to the current learning-rate
    # tensor, which depends on the Keras optimizer version — confirm.
    training_history['learning_rate'].append(optimizer.learning_rate.numpy())

    if epoch < 5 or epoch % log_every == 0 or epoch == epochs - 1:
        print(f'Epoch: {epoch} | Loss: {training_history["loss"][-1]} | Learning rate: {training_history["learning_rate"][-1]}')

plot_history(training_history)


# Plot data and samples from the transformed distribution.
# Draw on an explicit Axes instead of relying on pyplot's implicit current axes.
fig, ax = plt.subplots(figsize=(14, 6))

transformed_samples = rnvp_model.transformed_distr.sample(2500)

sns.scatterplot(
    x=data[:, 0].numpy(),
    y=data[:, 1].numpy(),
    ax=ax,
    label='Data'
)

# Convert to NumPy, as is done for `data`, so both series are handled the same way.
sns.scatterplot(
    x=transformed_samples[:, 0].numpy(),
    y=transformed_samples[:, 1].numpy(),
    ax=ax,
    label='Transformed samples (after training)',
    alpha=.2
)

ax.set_xlabel('x')
ax.set_ylabel('y')
# The plot shows both the data and samples from the trained flow.
ax.set_title('Data and transformed samples (after training)', fontsize=14)