# Real NVP (non-volume-preserving) flows

__Note:__ it looks like `tfb.real_nvp_default_template` doesn't produce trainable variables that either Keras or TensorFlow can track. This is probably related to [this issue](https://github.com/tensorflow/probability/issues/1439). To overcome this, a `Layer` subclass is defined that uses Keras `Dense` layers to mimic what `tfb.real_nvp_default_template` does; in this case the variables are correctly tracked.

__Objective:__ train a real NVP model.

Source: [here](https://github.com/tensorchiefs/dl_book/blob/master/chapter_06/nb_ch06_04.ipynb)

**Idea:** build an alternative to inverse autoregressive flows that singles out the first $n$ dimensions. We work on a $d$-dimensional event space. As with inverse autoregressive flows, we start from a "complicated space" with points $\mathbf{z}\in\mathbb{R}^d$, distributed according to the probability density function $p_z(\mathbf{z})$ we'd like to model. We want to find a transformation $\mathbf{G}$ whose inverse maps points $\mathbf{z}$ to points $\mathbf{x}\in\mathbb{R}^d$, with the distribution becoming $p_x(\mathbf{x})$, which we want to be very simple (e.g. a multivariate isotropic Gaussian). In direct form, $\mathbf{x} = \mathbf{G}(\mathbf{z})$. In real NVP flows, $\mathbf{G}$ is taken to be the identity for the first $n$ components, each of which acts only on the corresponding component of $\mathbf{z}$:

$$
x_i = G_i(\mathbf{z}) = z_i\quad \forall i = 1, \ldots, n\,,
$$

while for the remaining $d - n$ components $\mathbf{G}$ is taken to be an affine transformation of the corresponding components. Its parameters depend **only on the first $n$ components** and are given by neural networks:

$$
x_i = G_i(\mathbf{z}) = \exp\left( \alpha_i(z_1, \ldots, z_n) \right)\, z_i + b_i(z_1, \ldots, z_n)\quad \forall i = n + 1, \ldots, d\,.
$$

Tensorflow Probability provides the following implementation:
- `tfp.bijectors.Bijector`: the general class implementing an invertible transformation (a normalizing flow).
- `tfp.bijectors.RealNVP`: the bijector implementing the full real NVP flow.

As per the first note, `tfb.real_nvp_default_template` doesn't produce trackable variables. For this reason, the network returning the parameters of the affine transformations was written by hand.

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Dense
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import seaborn as sns

# Short aliases for the TensorFlow Probability submodules used below.
tfd, tfb = tfp.distributions, tfp.bijectors

# Apply seaborn's default theme to every subsequent plot.
sns.set_theme()

## Generate data

Generate a complicated distribution of points in 2 dimensions.

In [None]:
n_samples = 2500

# x2 is Gaussian with standard deviation 4; given x2, x1 is Gaussian
# with unit standard deviation and a mean quadratic in x2.
x2_samples = tfd.Normal(loc=0., scale=4.).sample(n_samples)
x1_samples = tfd.Normal(
    loc=.25 * tf.square(x2_samples),
    scale=tf.ones(n_samples, dtype=tf.float32),
).sample()

# Pair the coordinates column-wise (one row per point) and shrink the
# range of values by a constant factor of 40.
samples = tf.stack([x1_samples, x2_samples], axis=1) / 40.

samples

In [None]:
fig = plt.figure(figsize=(14, 6))

# Scatter plot of the generated data: first column of `samples` on the
# horizontal axis, second column on the vertical one.
sns.scatterplot(x=samples[:, 0], y=samples[:, 1])

plt.gca().set(xlabel='$x_1$', ylabel='$x_2$')
plt.title('Samples', fontsize=14)

## Build the NVP model

Build the model as a Keras `Model` object (subclass).

In [None]:
class RealNVPLayer(tf.keras.layers.Layer):
    """
    Shift and log-scale network for a `tfb.RealNVP` bijector, with no
    hidden layers.

    Each head (log-scale `alpha` and shift `b` in the notation of the
    text) is a single `Dense` layer applied to the masked components.
    The heads have no activation. A ReLU would force both the shift and
    the log-scale to be non-negative (i.e. the scale to be >= 1). That
    would severely restrict the transformations the flow can represent
    and could zero out gradients.
    """
    def __init__(self, full_dim, num_masked):
        """
        Parameters
        ----------
        full_dim : int
            Total number of dimensions of the event space.
        num_masked : int
            Number of masked (conditioning) dimensions. Each head
            outputs `full_dim - num_masked` values.
        """
        super().__init__()

        self.full_dim = full_dim
        self.num_masked = num_masked

        # Log-scale head (alpha): linear output.
        self.dense_alpha = Dense(units=full_dim - num_masked)
        # Shift head (b): linear output.
        self.dense_b = Dense(units=full_dim - num_masked)

    def call(self, x, *inputs):
        """
        Compute the affine parameters from the masked components `x`.

        Parameters
        ----------
        x : tf.Tensor
            The masked components (presumably of shape
            `(batch, num_masked)` - TODO confirm).
        *inputs
            Extra positional arguments, presumably the number of output
            units passed along by `tfb.RealNVP`. They are ignored because
            the output size is fixed at construction.

        Returns
        -------
        tuple of tf.Tensor
            `(shift, log_scale)`, the order in which `tfb.RealNVP`
            unpacks the output of `shift_and_log_scale_fn`.
        """
        return self.dense_b(x), self.dense_alpha(x)
    
    
class RealNVPDeepLayer(tf.keras.layers.Layer):
    """
    Shift and log-scale network for a `tfb.RealNVP` bijector, with
    hidden layers.

    As with `tfb.real_nvp_default_template`, `hidden_layers` lists the
    number of units of each hidden layer. Each of the two heads (log-scale
    `alpha` and shift `b` in the notation of the text) is built as
    follows:
      * a stack of ReLU `Dense` hidden layers, one per entry of
        `hidden_layers`;
      * a final linear `Dense` layer with `full_dim - num_masked` units.
    The output layer has no activation, so that both the shift and the
    log-scale can take any sign.
    """
    def __init__(self, full_dim, num_masked, hidden_layers):
        """
        Parameters
        ----------
        full_dim : int
            Total number of dimensions of the event space.
        num_masked : int
            Number of masked (conditioning) dimensions.
        hidden_layers : list of int
            Number of units of each hidden layer. An empty list (or
            `None`) gives a single linear layer per head.
        """
        super().__init__()

        self.full_dim = full_dim
        self.num_masked = num_masked
        self.hidden_layers = hidden_layers

        self.alpha_layers = self._make_head_layers()
        self.b_layers = self._make_head_layers()

    def _make_head_layers(self):
        """Return the layers of one head: ReLU hidden layers, then a linear output."""
        hidden = [
            Dense(units=units, activation='relu')
            for units in (self.hidden_layers or [])
        ]
        return hidden + [Dense(units=self.full_dim - self.num_masked)]

    def call(self, x, *inputs):
        """
        Compute the affine parameters from the masked components `x`.

        Parameters
        ----------
        x : tf.Tensor
            The masked components (presumably of shape
            `(batch, num_masked)` - TODO confirm).
        *inputs
            Extra positional arguments, presumably the number of output
            units passed along by `tfb.RealNVP`. They are ignored because
            the output size is fixed at construction.

        Returns
        -------
        tuple of tf.Tensor
            `(shift, log_scale)`, the order in which `tfb.RealNVP`
            unpacks the output of `shift_and_log_scale_fn`.
        """
        alpha = x
        for layer in self.alpha_layers:
            alpha = layer(alpha)

        b = x
        for layer in self.b_layers:
            b = layer(b)

        return b, alpha


class RealNVP(tf.keras.Model):
    """
    Subclass of a Keras `Model` object implementing a real
    NVP flow.

    The flow (`self.flow`) is a `tfd.TransformedDistribution`. Its base
    distribution is an isotropic Gaussian in `output_dim` dimensions
    with zero mean and default (unit) scale. It is transformed by a
    chain of real NVP bijectors interleaved with coordinate
    permutations.
    """
    def __init__(self, *, output_dim, num_masked, hidden_layers,
                 num_blocks=5, **kwargs):
        """
        Constructor of the real NVP.

        Parameters
        ----------
        output_dim : int
            Dimensionality of the event space.
        num_masked : int
            Number of masked dimensions in each real NVP layer. In 2
            dimensions this can only be 1 to get a nontrivial case.
        hidden_layers : list of int or None
            Number of units in each hidden layer of the networks that
            compute the affine parameters. `None`, `0` or an empty list
            means no hidden layers.
        num_blocks : int, optional
            Number of real NVP layers in the flow. Defaults to 5.
        **kwargs
            Forwarded to `tf.keras.Model` (e.g. `name`).
        """
        # Unpack `kwargs`: passing the dict positionally would not
        # forward its entries to Keras as keyword arguments.
        super().__init__(**kwargs)

        self.output_dim = output_dim
        self.num_masked = num_masked
        self.num_blocks = num_blocks
        self.hidden_layers = hidden_layers
        self.nets = []

        bijectors = []

        # `None`, `0` and an empty list all mean "no hidden layers".
        use_hidden_layers = bool(hidden_layers)

        if use_hidden_layers:
            print('Instantiating real NVP model with hidden layers')
        else:
            print('Instantiating real NVP model with no hidden layer')

        # Permutation reversing the order of the coordinates. It ensures
        # that consecutive real NVP layers don't single out the same
        # dimensions. In 2 dimensions this is [1, 0], the only
        # non-identity permutation.
        permutation = list(reversed(range(output_dim)))

        for _ in range(num_blocks):
            # Network computing the affine parameters (shift and
            # log-scale) of this real NVP layer.
            if use_hidden_layers:
                net = RealNVPDeepLayer(
                    full_dim=output_dim,
                    num_masked=num_masked,
                    hidden_layers=hidden_layers,
                )
            else:
                net = RealNVPLayer(full_dim=output_dim, num_masked=num_masked)

            bijectors.append(
                tfb.RealNVP(
                    shift_and_log_scale_fn=net,
                    # Number of masked dimensions.
                    num_masked=num_masked,
                )
            )
            bijectors.append(tfb.Permute(permutation))

            # Keep a reference to the network so that Keras tracks its
            # variables as part of this model.
            self.nets.append(net)

        # `tfb.Chain` applies its bijectors from last to first in the
        # forward direction (base -> data). Reversing the list therefore
        # makes the forward pass apply them in the order they were
        # appended: real NVP, permutation, real NVP, ..., real NVP.
        # The trailing permutation (`bijectors[-1]`) is dropped first,
        # so the chain starts and ends with a real NVP layer.
        bijector = tfb.Chain(list(reversed(bijectors[:-1])))

        # Flow: a simple source distribution pushed through the full
        # bijector built above.
        self.flow = tfd.TransformedDistribution(
            distribution=tfd.MultivariateNormalDiag(loc=tf.zeros(output_dim)),
            bijector=bijector,
        )

    def call(self, *inputs):
        """
        Forward pass: map points from the base space to the data space
        through the flow's bijector.
        """
        return self.flow.bijector.forward(*inputs)

In [None]:
# Smoke-test each custom layer on its own. Wrap it in a single
# `tfb.RealNVP` bijector and push a small random batch through it.
for header, shift_and_log_scale_fn in (
    ('Test with no hidden layers:', RealNVPLayer(full_dim=2, num_masked=1)),
    (
        '\nTest with hidden layers:',
        RealNVPDeepLayer(full_dim=2, num_masked=1, hidden_layers=[4, 4]),
    ),
):
    print(header)
    print(
        tfb.RealNVP(
            num_masked=1,
            shift_and_log_scale_fn=shift_and_log_scale_fn,
        ).forward(tf.random.uniform(shape=(5, 2)))
    )

In [None]:
# Flow whose affine-parameter networks have no hidden layers.
test_model = RealNVP(hidden_layers=None, num_masked=1, output_dim=2)

# Smoke test: push a small batch of uniform random points through it.
test_model(tf.random.uniform((5, 2)))

In [None]:
len(test_model.trainable_weights)  # Number of trainable variables tracked by Keras.

**Note:** what the untrained flow does depends on the random initialization of the NN weights.

In [None]:
# Push draws from the source distribution through the untrained flow.
# The NN weights are randomly initialized for each new model instance,
# so each instance gives a different transformation. Nothing sensible
# is expected here.
transformed_samples_untrained = test_model(
    test_model.flow.distribution.sample(2500)
)

fig = plt.figure(figsize=(14, 6))

# Overlay the data and the output of the untrained flow.
for points, label in (
    (samples, 'Original samples'),
    (transformed_samples_untrained, 'Transformed samples (untrained flow)'),
):
    sns.scatterplot(
        x=points[:, 0].numpy(),
        y=points[:, 1].numpy(),
        label=label,
    )

plt.xlabel('$x_1$')
plt.ylabel('$x_2$')
plt.title('Samples', fontsize=14)

Training: fit the flow by minimizing the negative log-likelihood of the samples.

In [None]:
def nll(samples, distr):
    """
    Mean negative log-likelihood of `samples` under `distr`.

    `distr` can be any object exposing a `log_prob` method (e.g. a TFP
    distribution). The per-sample log-probabilities are averaged, then
    negated.
    """
    log_probs = distr.log_prob(samples)
    mean_log_prob = tf.reduce_mean(log_probs)
    return -mean_log_prob

In [None]:
# Fresh flow to be trained (no hidden layers in the affine networks).
model = RealNVP(
    output_dim=2,
    num_masked=1,
    hidden_layers=None,
)

loss_history = []  # Loss values collected during training.

In [None]:
@tf.function
def training_step(x):
    """
    Perform one optimization step on the mean negative log-likelihood
    of `x` under the flow of the global `model`.

    The update is applied with the global `optimizer`, which must exist
    before the first call. Returns the loss computed before the update.
    """
    with tf.GradientTape() as tape:
        loss = nll(x, model.flow)

    params = model.trainable_variables
    optimizer.apply_gradients(zip(tape.gradient(loss, params), params))

    return loss

In [None]:
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-2)

epochs = 20_000

for i in range(epochs):
    loss = training_step(samples)
    loss_history.append(loss.numpy())

    # Log each of the first 10 epochs, then every 100th one.
    if i % 100 == 0 or i < 10:
        print(f'Epoch: {i} | Loss: {loss_history[-1]}')

# Also record the loss reached by the final (trained) parameters.
loss_history.append(nll(samples, model.flow).numpy())

In [None]:
# Training curve: one point per recorded loss value.
fig = plt.figure(figsize=(14, 6))
sns.lineplot(x=range(len(loss_history)), y=loss_history)

plt.title('Training loss', fontsize=14)
plt.xlabel('Epoch')
plt.ylabel('Loss value')



fig = plt.figure(figsize=(14, 6))

# Push draws from the source distribution through the trained flow.
# If training worked, the result should resemble the original samples,
# as the flow maps the simple space into the complicated one.
transformed_samples = model(model.flow.distribution.sample(2500))

sns.scatterplot(
    x=samples[:, 0].numpy(),
    y=samples[:, 1].numpy(),
    label='Original samples',
)
sns.scatterplot(
    x=transformed_samples[:, 0].numpy(),
    y=transformed_samples[:, 1].numpy(),
    label='Transformed samples (trained flow)',
    alpha=0.1,
)

plt.xlabel('$x_1$')
plt.ylabel('$x_2$')
plt.title('Samples', fontsize=14)