# Normalizing flows for probability distribution reconstruction

Normalizing flows can approximate complex probability distributions by applying a sequence of invertible transformations to a simpler base distribution, such as a multivariate Gaussian. If the base distribution has a joint probability density function (PDF) $p\left(\textbf{x}\right)$, the sequence of invertible transformations $f_1, f_2, \ldots, f_k$ is applied to the original variable $\textbf{x}$ such that

$$\textbf{z} = f_k\left(f_{k-1}\left(...f_1\left(\textbf{x}\right)\right)\right) = g\left(\textbf{x}\right), $$

where $\textbf{z}$ is the transformed variable and $g\left(\textbf{x}\right)$ is the composition of all the transformations $f_i$. Because each $f_i$ can be a flexible, learnable function, this framework can represent complex probability distributions that are difficult to capture with simpler parametric families.
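A minimal NumPy sketch of the composition, assuming two hand-picked elementwise bijections (`f1`, an affine map, and `f2`, $\sinh$); these are illustrative only and are not layers from any library. The forward pass applies the maps in order and accumulates the log-determinant of each map's Jacobian (by the chain rule the determinants multiply, so the log-determinants add), while the inverse applies the inverse maps in reverse order.

```python
import numpy as np

# Hypothetical elementwise bijections; each forward map returns (output, log|det Jacobian|)
def f1(x):
    # Affine map z = 2x + 1: dz/dx = 2 in every dimension
    return 2.0 * x + 1.0, np.full(x.shape[:-1], x.shape[-1] * np.log(2.0))

def f1_inv(z):
    return (z - 1.0) / 2.0

def f2(x):
    # sinh is a smooth bijection of the real line with derivative cosh
    return np.sinh(x), np.sum(np.log(np.cosh(x)), axis=-1)

def f2_inv(z):
    return np.arcsinh(z)

forward_maps = [f1, f2]          # g = f2 ∘ f1
inverse_maps = [f2_inv, f1_inv]  # g^{-1} = f1^{-1} ∘ f2^{-1} (reverse order)

def g(x):
    """Apply the maps in order, summing the log|det| of each Jacobian."""
    total_log_det = np.zeros(x.shape[:-1])
    for f in forward_maps:
        x, log_det = f(x)
        total_log_det += log_det
    return x, total_log_det

def g_inv(z):
    """Undo the maps in reverse order."""
    for f_inv in inverse_maps:
        z = f_inv(z)
    return z

x = np.random.default_rng(0).normal(size=(4, 2))  # draws from a 2D base distribution
z, log_det = g(x)
print(np.allclose(g_inv(z), x))  # True: the inverse recovers the base samples
```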

Another benefit of normalizing flows is efficient sampling from complex distributions. Samples are cheap to draw from the base distribution (for example, a Gaussian), and passing them through the normalizing flow transform $g\left(\textbf{x}\right)$ yields samples from the learned distribution.

The transformed joint PDF $p\left(\textbf{z}\right)$ can be easily calculated from the trained normalizing flow. The normalization condition for a PDF dictates that, for both $p\left(\textbf{x}\right)$ and $p\left(\textbf{z}\right)$,

$$\int ... \int d x_1 ... d x_N \; p\left(\textbf{x}\right) = 1, \tag{1}$$

$$\int ... \int d z_1 ... d z_N \; p\left(\textbf{z}\right) = 1. $$

Because $\textbf{z} = g\left(\textbf{x}\right)$, the second integral can be expressed in terms of $\textbf{x}$ by a change of variables:

$$
\begin{aligned}
    1 &= \int ... \int d z_1 ... d z_N \; p\left(\textbf{z}\right) \\
    &= \int ... \int d x_1 ... d x_N \left| \text{det} \frac{d \textbf{z}}{d \textbf{x}} \right| p\left(g\left(\textbf{x}\right)\right),
\end{aligned}\tag{2}
$$

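where $\frac{d \textbf{z}}{d \textbf{x}}$ is the *Jacobian* matrix of first derivatives $\partial z_i / \partial x_j$ of the transformation $g$, and its determinant is the Jacobian determinant. With the correct assumptions about the integration domain (the change of variables holds over any region, not only over all of space), comparing the integrands of Equations 1 and 2 gives $p\left(\textbf{x}\right) = p\left(g\left(\textbf{x}\right)\right) \left| \text{det} \frac{d \textbf{z}}{d \textbf{x}} \right|$. Since $\text{det} \frac{d \textbf{x}}{d \textbf{z}} = 1 / \text{det} \frac{d \textbf{z}}{d \textbf{x}}$ for an invertible transformation, this is equivalent to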

$$p\left(\textbf{z}\right) = p\left(\textbf{x}\right) \left| \text{det} \frac{d \textbf{x}}{d \textbf{z}} \right|, $$

which can be evaluated directly from the trained normalizing flow. Flow layers are typically designed so that the determinant of each layer's Jacobian is cheap to compute, and because the Jacobian of the composition $g$ is the product of the layers' Jacobians, its log-determinant is the sum of the per-layer log-determinants. See these resources [[1]](https://en.wikibooks.org/wiki/Probability/Transformation_of_Random_Variables) [[2]](https://stats.libretexts.org/Bookshelves/Probability_Theory/Probability_Mathematical_Statistics_and_Stochastic_Processes_(Siegrist)/03%3A_Distributions/3.07%3A_Transformations_of_Random_Variables) [[3]](https://en.wikipedia.org/wiki/Probability_density_function) for more details about the transformation of random variables.
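The transformation formula can be checked numerically. The NumPy sketch below assumes a hand-picked invertible affine map $\textbf{z} = A\textbf{x} + \textbf{c}$ (the values of $A$ and $\textbf{c}$ are arbitrary) applied to a standard 2D Gaussian $p\left(\textbf{x}\right)$. Here $\frac{d \textbf{x}}{d \textbf{z}} = A^{-1}$, and the exact answer is known to be a Gaussian with mean $\textbf{c}$ and covariance $AA^\top$, so both the normalization condition and that closed form can be compared against.

```python
import numpy as np

def p_x(x):
    """Base density: standard 2D Gaussian, evaluated along the last axis."""
    return np.exp(-0.5 * np.sum(x**2, axis=-1)) / (2 * np.pi)

# Hypothetical invertible affine map z = g(x) = A x + c
A = np.array([[2.0, 0.5],
              [0.0, 1.5]])
c = np.array([1.0, -1.0])
A_inv = np.linalg.inv(A)

def g_inv(z):
    """x = g^{-1}(z) = A^{-1} (z - c), applied to every point in z."""
    return (z - c) @ A_inv.T

def p_z(z):
    """Change of variables: p(z) = p(x) |det dx/dz|, with dx/dz = A^{-1}."""
    return p_x(g_inv(z)) * abs(np.linalg.det(A_inv))

# Closed form for comparison: z ~ Normal(c, A A^T)
cov = A @ A.T
cov_inv = np.linalg.inv(cov)

def p_z_analytic(z):
    d = z - c
    quad = np.einsum('...i,ij,...j->...', d, cov_inv, d)
    return np.exp(-0.5 * quad) / (2 * np.pi * np.sqrt(np.linalg.det(cov)))

# Normalization of p(z), checked with a Riemann sum on a grid covering roughly +/- 7 std
z1 = np.linspace(-14, 16, 601)
z2 = np.linspace(-12, 10, 441)
Z1, Z2 = np.meshgrid(z1, z2, indexing='ij')
grid = np.stack([Z1, Z2], axis=-1)
total = p_z(grid).sum() * (z1[1] - z1[0]) * (z2[1] - z2[0])

pts = np.random.default_rng(0).normal(size=(10, 2)) * 3
print(total)                                        # very close to 1
print(np.allclose(p_z(pts), p_z_analytic(pts)))    # True
```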

## Tutorial

This tutorial demonstrates the `normflows` Python package (see the `normalizing-flows` [GitHub repository](https://github.com/VincentStimper/normalizing-flows)) with a simple example: using a normalizing flow to approximate an unknown target distribution. An IPython notebook containing the full code in this tutorial can be found [here](https://github.com/jayspendlove/blog/tree/main/simple-normalizing-flow). For this example, the 2D target distribution $p(a,b)$ consists of a random variable $a$ sampled from a uniform distribution and a random variable $b$ sampled from a Normal distribution whose mean and variance are both equal to $a$:

$$ a \sim \text{Uniform}\left[1,2\right] $$
$$ b \sim \text{Normal}\left(\mu=a, \sigma^2=a\right). $$

Stated differently, we want the normalizing flow approximation $NF\left(a, b\right) \approx p\left(a,b\right)$. It is assumed that samples from the target distribution $p\left(a,b\right)$ can be easily obtained. 

This tutorial will begin by visualizing the target distribution $p\left(a,b\right)$, followed by construction and training of the normalizing flow on samples from $p\left(a,b\right)$. 

In [None]:
import normflows as nf
import torch
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# DO NOT INCLUDE IN BLOG POST
# NF Cuda setup-- move model on GPU if available
enable_cuda = False # not on my laptop
device = torch.device('cuda' if torch.cuda.is_available() and enable_cuda else 'cpu')

## Visualizing the target distribution

To get an idea of what the target distribution $p\left(a,b\right)$ looks like, we will visualize it in a couple of different ways. First, we draw $10000$ pairs $\left(a,b\right)$ from the distributions defined above and visualize the resulting 2D scatter plot.

In [None]:
N = 10000
sampled_a_vals = np.random.uniform(1, 2, N) # Uniform distribution
sampled_b_vals = np.random.normal(sampled_a_vals, np.sqrt(sampled_a_vals)) # Normal distribution with mean a and variance a

In [None]:
# DO NOT INCLUDE IN BLOG POST
alim = (0.95, 2.05)
blim = (-3, 7)

# If I can delete a cell and not the picture output, I will for the plotting code
plt.figure(figsize=(6,6))
plt.scatter(sampled_a_vals, sampled_b_vals, alpha=0.15)
# For the b distribution, plot mean and STD's
a_mean_var = np.linspace(1,2,100)
plt.plot(a_mean_var, a_mean_var, 'r', label="Mean")
colors = ['b', 'lime', 'hotpink']
for i,n_std in enumerate([1, 2, 3]):
    plt.plot(a_mean_var, a_mean_var + n_std*np.sqrt(a_mean_var), c=colors[i], linestyle='--', label=f"{n_std} STD")
    plt.plot(a_mean_var, a_mean_var - n_std*np.sqrt(a_mean_var), c=colors[i], linestyle='--')
plt.xlabel("a")
plt.ylabel("b")
plt.xlim(*alim)
plt.ylim(*blim)
plt.legend()
plt.show()

As expected, both the mean and the standard deviation of the $b$ distribution increase with increasing $a$. The sampled points also appear densest at low $a$: $p\left(a\right)$ is flat, but $p\left(b|a\right)$ has variance $a$, so it is narrower and more sharply peaked there. We can see the same thing by calculating the probability distribution $p\left(a,b\right)$ analytically with the product rule, $p\left(a,b\right) = p\left(a\right) p\left(b|a\right)$.

In [None]:
def p_a(a):
    """Probability density function of a ~ Uniform[1,2]. Trivial function in this case, but included for consistency."""
    return np.where((a >= 1) & (a <= 2), 1, 0)

def p_b_given_a(b, a):
    """Probability density function of b ~ Normal(a, a), p(b|a). 
    
    Args:
        b (np.ndarray): Values of b.
        a (np.ndarray): Values of a, mean and variance of b.
    """
    return (1 / (np.sqrt(2 * np.pi * a))) * np.exp(-0.5 * ((b - a)**2 / a))

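def p_a_b(ab):
    """Joint probability density p(a, b) = p(a) p(b|a).
    
    Args:
        ab (np.ndarray): 2D array where each row is a tuple (a, b).
    """
    a = ab[:, 0]
    b = ab[:, 1]
    # p(b|a) is undefined for a <= 0, so substitute a placeholder value outside the
    # support of p(a); there p(a) = 0 zeroes out the product anyway (avoids NaNs)
    a_safe = np.where(p_a(a) > 0, a, 1.0)
    return p_a(a) * p_b_given_a(b, a_safe)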

In [None]:
# DO NOT INCLUDE IN BLOG POST
grid_size = [1000, 1000]
a_vals = np.linspace(*alim, grid_size[0])
b_vals = np.linspace(*blim, grid_size[1])
aa, bb = np.meshgrid(a_vals, b_vals, indexing='ij')
zz = np.stack([aa, bb], axis=2).reshape(-1, 2)

prob = p_a_b(zz).reshape(*grid_size)

In [None]:
# DO NOT INCLUDE IN BLOG POST

plt.figure(figsize=(6,6))
plt.pcolormesh(aa, bb, prob, cmap='coolwarm')
plt.colorbar(label='Probability Density')
plt.xlabel('a')
plt.ylabel('b')

# mark mean and std like on other plot
a_mean_var = np.linspace(1,2,100)
plt.plot(a_mean_var, a_mean_var, 'r', label="Mean")
colors = ['b', 'lime', 'hotpink']
for i,n_std in enumerate([1, 2, 3]):
    plt.plot(a_mean_var, a_mean_var + n_std*np.sqrt(a_mean_var), c=colors[i], linestyle='--', label=f"{n_std} STD")
    plt.plot(a_mean_var, a_mean_var - n_std*np.sqrt(a_mean_var), c=colors[i], linestyle='--')
plt.legend()
plt.show()

In addition to visualizing the full 2D distribution, we can look at the $a$ and $b$ target distributions individually by marginalizing out the other variable. Because $a$ does not depend on $b$, the marginal distribution of $a$ is simply $p\left(a\right)$, i.e. $a \sim \text{Uniform}\left[1,2\right]$. The marginal distribution of $b$ is more challenging to obtain because it depends on $a$ in a non-trivial way: as $a$ increases from $1$ to $2$, the conditional distribution of $b$ drifts and spreads out, so the marginal distribution of $b$ is a mixture of all these Normal distributions for different values of $a$, weighted by $p\left(a\right)$. By the law of total probability, the marginal distribution $p\left(b\right)$ is

$$
p(b) = \int_1^2 da \; p(a)\, p(b|a) = \int_1^2 da \frac{1}{\sqrt{2 \pi a}} \exp\left(-\frac{\left(b-a\right)^2}{2a}\right).
$$

This integral is tedious to evaluate analytically, but it is straightforward to compute by numerical integration. The marginal distributions for $a$ and $b$ are shown below.
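As a numerical check of the integral above, here is a NumPy sketch that evaluates $p(b)$ on a grid with the trapezoidal rule (a small helper is written out instead of relying on `np.trapezoid`, which requires NumPy 2.0 or later) and compares its moments with closed-form values from the laws of total expectation and total variance: $\mathbb{E}[b] = \mathbb{E}[a] = 3/2$ and $\text{Var}[b] = \mathbb{E}[\text{Var}(b|a)] + \text{Var}(\mathbb{E}[b|a]) = \mathbb{E}[a] + \text{Var}(a) = 3/2 + 1/12 = 19/12$. The grid sizes and the $b$ range are arbitrary choices for illustration.

```python
import numpy as np

def trapezoid(y, x, axis=-1):
    """Trapezoidal rule along `axis` (written out so it works on any NumPy version)."""
    y = np.moveaxis(y, axis, -1)
    return np.sum((y[..., 1:] + y[..., :-1]) * np.diff(x), axis=-1) / 2

a_grid = np.linspace(1, 2, 801)     # support of p(a)
b_grid = np.linspace(-8, 12, 2001)  # wide enough that p(b) is negligible at both ends
a_mesh, b_mesh = np.meshgrid(a_grid, b_grid, indexing='ij')

# p(a) = 1 on [1, 2], so the integrand p(a) p(b|a) reduces to p(b|a)
integrand = np.exp(-(b_mesh - a_mesh)**2 / (2 * a_mesh)) / np.sqrt(2 * np.pi * a_mesh)

# Law of total probability: integrate over a to get p(b) at every b on the grid
p_b_grid = trapezoid(integrand, a_grid, axis=0)

total = trapezoid(p_b_grid, b_grid)
mean_b = trapezoid(b_grid * p_b_grid, b_grid)
var_b = trapezoid((b_grid - mean_b)**2 * p_b_grid, b_grid)

# Expect roughly 1, 3/2 and 19/12 (about 1.5833)
print(total, mean_b, var_b)
```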

In [None]:
# DO NOT INCLUDE IN BLOG POST
def marginalize_p_b(p_a_b, a_vals, b_val):
    """
    Calculate the marginal probability p(b) by integrating over the joint probability distribution p(a, b).

    Args:
        p_a_b (function): Joint probability distribution function p(a, b).
        a_vals (numpy.ndarray): Array of a values to integrate over.
        b_val (float): The value of b for which to calculate the marginal probability p(b).

    Returns:
        float: The marginal probability p(b).
    """
    # Calculate the joint probability p(a, b) for each a in a_vals
    p_a_b_vals = p_a_b(np.column_stack((a_vals, np.full_like(a_vals, b_val))))
    
    # Use the trapezoidal rule to integrate over a_vals
    p_b = np.trapezoid(p_a_b_vals, a_vals)
    
    return p_b

prob_b = np.array([marginalize_p_b(p_a_b, a_vals, b_val) for b_val in b_vals])

In [None]:
# DO NOT INCLUDE IN BLOG POST
fig, axs = plt.subplots(1, 2, figsize=(12, 6))

# Plot a Target Distribution
axs[0].hist(sampled_a_vals, bins=50, alpha=0.5, density=True)
# axs[0].plot([0.9, 0.999, 1.0, 2.0, 2.001, 2.1], [0, 0, 1, 1, 0, 0], 'r')
axs[0].plot(a_vals, p_a(a_vals), 'r')
axs[0].set_xlim(*alim)
axs[0].set_xlabel("a")
axs[0].set_ylabel("Probability Density")
axs[0].set_title("Marginalized a Target Distribution")

# Plot Marginalized b Target Distribution
axs[1].plot(b_vals, prob_b, c="r")
axs[1].hist(sampled_b_vals, bins=50, alpha=0.4, density=True, color="tab:blue")
axs[1].set_xlim(*blim)
axs[1].set_xlabel("b")
axs[1].set_ylabel("Probability Density")
axs[1].set_title("Marginalized b Target Distribution")

plt.tight_layout()
plt.show()

## Setting up the normalizing flow

Now that we understand our target distributions, we can begin to set up our normalizing flow to approximate $p(a,b)$.

Here, we use the `normflows` Python package, which is built on PyTorch. Setting up a normalizing flow involves choosing the number and type of flow layers and the base distribution. The base distribution is the initial probability density, which the normalizing flow transforms into an approximation of the target; training adjusts that transformation. These choices are problem-specific.

In [None]:
# Define NF architecture
torch.manual_seed(0)
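K = 16 # number of repeated (coupling + permutation) blocks
latent_size = 2 # dimensionality of the data (a, b)
flows = []
for i in range(K):
    # MLP maps one coordinate (1 input) to parameters for the other (2 outputs);
    # init_zeros=True zero-initializes the last layer so each coupling starts near the identity
    param_map = nf.nets.MLP([1, 64, 64, 2], init_zeros=True)
    flows += [nf.flows.AffineCouplingBlock(param_map)]
    # Learned invertible linear layer that mixes the two coordinates between couplings
    flows += [nf.flows.LULinearPermute(latent_size)]

base = nf.distributions.DiagGaussian(latent_size, trainable=False) # Fixed Gaussian base distribution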
model = nf.NormalizingFlow(q0=base, flows=flows)
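To give a feel for what each coupling block does, here is a from-scratch NumPy sketch of a generic 2D affine coupling transform, in the style popularized by RealNVP. It is an illustration, not the `normflows` implementation: the `conditioner` function is a hypothetical fixed stand-in for the trained MLP, and details such as which coordinate is transformed or how the scale is parameterized may differ in the library. One coordinate passes through unchanged and determines a scale and shift for the other, which is presumably why the MLP above maps 1 input to 2 outputs. Because the Jacobian is triangular, its log-determinant is just the log-scale, and the inverse is available in closed form. The `LULinearPermute` layers between coupling blocks then mix the coordinates so that, across blocks, both get transformed.

```python
import numpy as np

def conditioner(x1):
    """Hypothetical stand-in for the MLP: maps the untouched coordinate x1
    to a log-scale s and a shift t (any smooth functions would do here)."""
    s = np.tanh(0.5 * x1)
    t = 0.3 * x1**2
    return s, t

def coupling_forward(x):
    """Affine coupling z1 = x1, z2 = x2 * exp(s(x1)) + t(x1); returns (z, log|det dz/dx|)."""
    x1, x2 = x[..., 0], x[..., 1]
    s, t = conditioner(x1)
    z = np.stack([x1, x2 * np.exp(s) + t], axis=-1)
    # The Jacobian is lower triangular with diagonal (1, exp(s)), so log|det| = s
    return z, s

def coupling_inverse(z):
    """Exact inverse x1 = z1, x2 = (z2 - t(z1)) * exp(-s(z1)); returns (x, log|det dx/dz|)."""
    z1, z2 = z[..., 0], z[..., 1]
    s, t = conditioner(z1)  # z1 == x1, so s and t can be recomputed from the output
    x = np.stack([z1, (z2 - t) * np.exp(-s)], axis=-1)
    return x, -s

def finite_difference_log_det(f, x, h=1e-6):
    """log|det J| of f at a single 2D point, using a central-difference Jacobian."""
    jac = np.zeros((2, 2))
    for j in range(2):
        step = np.zeros(2)
        step[j] = h
        jac[:, j] = (f(x + step)[0] - f(x - step)[0]) / (2 * h)
    return np.log(abs(np.linalg.det(jac)))

x = np.array([0.7, -1.2])
z, log_det = coupling_forward(x)
print(np.isclose(log_det, finite_difference_log_det(coupling_forward, x)))  # True
print(np.allclose(coupling_inverse(z)[0], x))                               # True
```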

## Training the normalizing flow

The next step is to train the normalizing flow on samples of $(a,b)$ from the true distribution. The loss function is the negative log-likelihood of a batch of $M$ samples under the flow, $-\frac{1}{M}\sum_{m=1}^{M} \log NF\left(a_m, b_m\right)$. Using the Adam optimizer, the normalizing flow is trained for 1000 epochs, with a fresh batch of 512 samples from the true distribution drawn each epoch.
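Because the training samples come from $p(a,b)$, the expected value of this loss is $\mathbb{E}_p\left[-\log NF\right] = H(p) + \mathrm{KL}\left(p \,\|\, NF\right)$, where $H(p)$ is the entropy of the target. The KL divergence is non-negative, so the average loss cannot go below $H(p)$, which makes it a useful reference value. For this target, $p(a) = 1$ on $[1,2]$, so $H(p) = \mathbb{E}_a\left[\tfrac{1}{2}\log(2\pi e a)\right]$, and $\mathbb{E}[\log a] = 2\log 2 - 1$ for $a \sim \text{Uniform}[1,2]$. The NumPy sketch below (independent of `normflows`, a sanity check rather than part of the training code) compares that closed form with a Monte Carlo estimate; both are in nats, the same units as `log_prob`.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1_000_000

# Draw (a, b) from the target, as in the tutorial
a = rng.uniform(1, 2, n)
b = rng.normal(a, np.sqrt(a))

# -log p(a, b) = -log p(a) - log p(b|a), with log p(a) = 0 on [1, 2]
neg_log_p = 0.5 * np.log(2 * np.pi * a) + (b - a)**2 / (2 * a)
entropy_mc = neg_log_p.mean()

# Closed form: H = 0.5 * (log(2*pi) + 1 + E[log a]), with E[log a] = 2 log 2 - 1
entropy_exact = 0.5 * (np.log(2 * np.pi) + 1 + (2 * np.log(2) - 1))
print(entropy_mc, entropy_exact)  # both close to 1.612
```

If training works well, the average loss should settle near, but not below, this value (individual mini-batch losses fluctuate around it).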

In [None]:
# DO NOT INCLUDE IN BLOG POST
model = model.to(device)

def sample_ab(num_samples, device):
    '''Sample from a and b distributions'''
    a = np.random.uniform(1, 2, num_samples)
    b = np.random.normal(a, np.sqrt(a)) # Normal with mean a and variance a (vectorized over a)
    x_np = np.stack([a, b], axis=1)
    x = torch.tensor(x_np).float().to(device)
    return x

In [None]:
# DO NOT INCLUDE IN BLOG POST
# Define before for saving the necessary information
epochs = 1000
loss_hist = np.zeros(epochs) # Store loss values

# Plotting parameters for 2D probability density grids used in tutorial
grid_size = (100, 120)
a_bounds = [-0.5, 2.1]
b_bounds = [-3, 6]
a_vals = torch.linspace(*a_bounds, grid_size[0])
b_vals = torch.linspace(*b_bounds, grid_size[1])
aa, bb = torch.meshgrid(a_vals, b_vals, indexing='ij')
zz = torch.cat([aa.unsqueeze(2), bb.unsqueeze(2)], 2).view(-1, 2)
zz_np = zz.detach().numpy()
zz = zz.to(device)
vmax = 0.4

# Save probability density for base distribution
base_prob = base.log_prob(zz).exp().to('cpu').view(*aa.shape)

# Save target probability density (zz_np is the CPU/NumPy copy of the grid, so this also works on GPU)
prob_target = p_a_b(zz_np)
prob_target[np.isnan(prob_target)] = 0 # set any NaNs to 0
prob_target = prob_target.reshape(*grid_size)

# Iterations at which to save the NF density (0-based, i.e. one less than the "human readable" epoch number)
n_samp = 5
end = 20
start = 1
iter_early = np.arange(start + 1, end + 1, (end - start) // (n_samp + 1))
show_iter = 200
iter_later = np.arange(show_iter - 1, epochs, show_iter)
iterations = np.hstack((iter_early, iter_later))
prob_grid = np.zeros((len(iterations), *grid_size)) # store probability grids

def save_info(i, model, loss):
    loss_hist[i] = loss.to('cpu').data.numpy()
    if i in iterations:
        model.eval() #change to evaluation mode
        log_prob = model.log_prob(zz)
        model.train() #change back to training mode
        prob = torch.exp(log_prob.to('cpu').view(*aa.shape))
        prob[torch.isnan(prob)] = 0 # set NaNs to 0
        prob_grid[np.where(iterations == i)[0][0],:,:] = prob.data.numpy()


In [None]:
# Train NF
epochs = 1000
num_samples = 2 ** 9 # 512 samples per iteration
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

for it in range(epochs):
    optimizer.zero_grad()
    x = sample_ab(num_samples, device) # Draw a fresh batch of training samples
    loss = -model.log_prob(x).mean() # Negative log-likelihood
    # Backprop and optimizer step, skipping batches whose loss is NaN or inf
    if torch.isfinite(loss):
        loss.backward()
        optimizer.step()
    save_info(it, model, loss)

After 1000 training iterations, the distribution learned by the normalizing flow captures the primary features of the target distribution. The panels below show, from left to right, the base distribution, the learned distribution, the true distribution, and the absolute error between the learned and true densities, which shows where the two disagree.

In [None]:
# DO NOT INCLUDE IN BLOG POST
# Plot base distribution, learned distribution, true distribution, and error side by side
fig, axs = plt.subplots(1, 4, figsize=(15, 7))

# Plot base distribution
axs[0].pcolormesh(aa, bb, base_prob.data.numpy(), cmap='coolwarm', vmin=0, vmax=vmax)
axs[0].set_aspect('equal', 'box')
axs[0].set_title("Base Distribution")

# Plot learned distribution
model.eval()
log_prob = model.log_prob(zz).to('cpu').view(*aa.shape)
model.train()
prob_nf = torch.exp(log_prob)
prob_nf[torch.isnan(prob_nf)] = 0
prob_nf = prob_nf.data.numpy()

axs[1].pcolormesh(aa, bb, prob_nf, cmap='coolwarm', vmin=0, vmax=vmax)
axs[1].set_aspect('equal', 'box')
axs[1].set_title("Learned Distribution")

# Plot true distribution
axs[2].pcolormesh(aa, bb, prob_target, cmap='coolwarm', vmin=0, vmax=vmax)
axs[2].set_aspect('equal', 'box')
axs[2].set_title("True Distribution")

# Plot error
error = np.abs(prob_nf - prob_target)
axs[3].pcolormesh(aa, bb, error, cmap='gray_r')
axs[3].set_aspect('equal', 'box')
axs[3].set_title("Error (|Learned - True|)")

# Add a shared colorbar for the first three plots
cbar = fig.colorbar(axs[0].collections[0], ax=axs[:3], orientation='vertical', fraction=0.02, pad=0.04, aspect=70)
cbar.set_label('Probability Density')

# Add a separate colorbar for the error plot
cbar_error = fig.colorbar(axs[3].collections[0], ax=axs[3], orientation='vertical', fraction=0.02, pad=0.04, aspect=70)
cbar_error.set_label('Error')

plt.show()

To see how quickly the normalizing flow converged to this approximation, we can look at the training loss over epochs. The loss has plateaued by the end of training; in fact, even training for 10x as long does not significantly improve the learned distribution.

In [None]:
# Plot loss
plt.figure(figsize=(6, 6))
plt.plot(loss_hist, label='loss')
plt.title('Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.yscale('log')
plt.show()

By looking at probability densities at intermediate epochs, we can see that the approximate normalizing flow solution achieves reasonable accuracy early in training.

In [None]:
# DON'T INCLUDE CELL IN BLOG POST
fig, axs = plt.subplots(1, len(iter_later) + 2, figsize=(18, 6),sharey=True)

# Plot base distribution
axs[0].pcolormesh(aa, bb, base_prob, cmap='coolwarm', vmin=0, vmax=vmax)
axs[0].set_aspect('equal', 'box')
axs[0].set_title("Base Distribution")
axs[0].set_xlabel("a")
axs[0].set_ylabel("b")

# Plot the saved densities from the later epochs
for plot_i, it in enumerate(iter_later):
    iter_i = np.where(iterations == it)[0][0]
    axs[plot_i + 1].pcolormesh(aa, bb, prob_grid[iter_i], cmap='coolwarm', vmin=0, vmax=vmax)
    axs[plot_i + 1].set_aspect('equal', 'box')
    axs[plot_i + 1].set_title(f"Epoch {it + 1}")
    axs[plot_i + 1].set_xlabel("a")

# Plot the target probability density
axs[-1].pcolormesh(aa, bb, prob_target, cmap='coolwarm', vmin=0, vmax=vmax)
axs[-1].set_aspect('equal', 'box')
axs[-1].set_title("Target Distribution")
axs[-1].set_xlabel("a")

# Add a shared colorbar
cbar = fig.colorbar(axs[0].collections[0], ax=axs, orientation='vertical', fraction=0.02, pad=0.04)
cbar.set_label('Probability Density')
plt.show()


The loss curve shows that the largest changes happen at the very beginning of training. In the subplots below, the first few epochs show the probability density migrating from the base Gaussian centered at $(0,0)$ toward the region occupied by the target density.

In [None]:
# DON'T INCLUDE CELL IN BLOG POST
fig, axs = plt.subplots(1, len(iter_early) + 2, figsize=(18, 6),sharey=True)

# Plot base distribution
axs[0].pcolormesh(aa, bb, base_prob, cmap='coolwarm', vmin=0, vmax=vmax)
axs[0].set_aspect('equal', 'box')
axs[0].set_title("Base Distribution")
axs[0].set_xlabel("a")
axs[0].set_ylabel("b")

# Plot the saved densities from the early epochs (stored first in prob_grid)
for i, it in enumerate(iter_early):
    axs[i + 1].pcolormesh(aa, bb, prob_grid[i], cmap='coolwarm', vmin=0, vmax=vmax)
    axs[i + 1].set_aspect('equal', 'box')
    axs[i + 1].set_title(f"Epoch {it + 1}")
    axs[i + 1].set_xlabel("a")

# Plot the target probability density
axs[-1].pcolormesh(aa, bb, prob_target, cmap='coolwarm', vmin=0, vmax=vmax)
axs[-1].set_aspect('equal', 'box')
axs[-1].set_title("Target Distribution")
axs[-1].set_xlabel("a")

# Add a shared colorbar
cbar = fig.colorbar(axs[0].collections[0], ax=axs, orientation='vertical', fraction=0.02, pad=0.04)
cbar.set_label('Probability Density')

plt.show()


As another perspective on the normalizing flow solution, we can compare the marginal probability densities of $a$ and $b$ for the true distribution (red) and the normalizing flow approximation (black, dashed).

In [None]:
# DO NOT INCLUDE IN BLOG POST
fig, axs = plt.subplots(1, 2, figsize=(12, 6))

# Evaluate the NF density on the grid once (on the CPU) and reuse it for both marginals
model.eval()
with torch.no_grad():
    prob = torch.exp(model.log_prob(zz)).to('cpu').view(*aa.shape)
prob[torch.isnan(prob)] = 0  # Fix NaNs
prob = prob.numpy()

# Plot marginalized a distribution: sum the NF density over b, then normalize
prob_marg = prob.sum(axis=1)
da = (a_vals[1] - a_vals[0]).item()
prob_marg_norm = prob_marg / (prob_marg.sum()*da)
axs[0].plot(a_vals, prob_marg_norm, c="k", label="NF", linestyle="dashed")
axs[0].plot(a_vals, p_a(a_vals.numpy()), 'r', label="True")
axs[0].set_xlabel("a")
axs[0].set_ylabel("Probability Density")
axs[0].set_title("Marginalized a distribution")
axs[0].legend()

# Plot marginalized b distribution
# NF estimate: sum the NF density over a, then normalize
prob_marg = prob.sum(axis=0)
db = (b_vals[1] - b_vals[0]).item()
prob_marg_norm = prob_marg / (prob_marg.sum()*db)
axs[1].plot(b_vals, prob_marg_norm, label="NF", c="k", linestyle="dashed")
# True marginalized b distribution
a_vals_int = np.linspace(1, 2, 1000)
prob_b = np.array([marginalize_p_b(p_a_b, a_vals_int, b_val) for b_val in b_vals.numpy()])

axs[1].plot(b_vals, prob_b, c="r", label="True")
axs[1].set_xlabel("b")
axs[1].set_ylabel("Probability Density")
axs[1].set_title("Marginalized b Distribution")
axs[1].legend()

plt.tight_layout()
plt.show()

## Sampling from the trained normalizing flow

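Once the normalizing flow has been trained, sampling from it is trivial: `model.sample` draws points from the base distribution, passes them through the flow layers, and returns the samples together with their log-densities.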

In [None]:
model.eval() #Set model to evaluation mode
samples, log_prob = model.sample(num_samples=1000) #Sample from normalizing flow
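To illustrate why sampling is cheap, here is a hand-built flow that maps a standard 2D Gaussian exactly onto this particular target: $a = 1 + \Phi(x_1)$, where $\Phi$ is the standard normal CDF, and $b = a + \sqrt{a}\,x_2$. This is a minimal NumPy sketch for illustration only (the names `exact_flow` and `std_normal_cdf` are hypothetical); it is not what `normflows` does internally, and the trained model has no reason to learn this particular map. It shows both halves of the story: samples come from pushing base draws through the transformation, and the density follows from the change-of-variables formula, here with a lower-triangular Jacobian whose determinant is $\phi(x_1)\sqrt{a}$.

```python
import math
import numpy as np

# Standard normal CDF built from the standard library's erf, vectorized over arrays
std_normal_cdf = np.vectorize(lambda u: 0.5 * (1.0 + math.erf(u / math.sqrt(2.0))))

def log_std_normal_pdf(u):
    return -0.5 * u**2 - 0.5 * np.log(2 * np.pi)

def exact_flow(x):
    """Hand-built bijection from a standard 2D Gaussian to p(a, b):
    a = 1 + Phi(x1) gives a ~ Uniform[1, 2]; b = a + sqrt(a) * x2 gives b | a ~ Normal(a, a)."""
    a = 1.0 + std_normal_cdf(x[:, 0])
    b = a + np.sqrt(a) * x[:, 1]
    return np.stack([a, b], axis=1)

# Sampling: draw from the base distribution, then push the draws through the flow
rng = np.random.default_rng(1)
x_base = rng.standard_normal((50_000, 2))
ab_exact = exact_flow(x_base)
a_s, b_s = ab_exact[:, 0], ab_exact[:, 1]

# Density by change of variables: the Jacobian d(a, b)/d(x1, x2) is lower triangular with
# diagonal (phi(x1), sqrt(a)), so log p(a, b) = log phi(x1) + log phi(x2) - log(phi(x1) sqrt(a))
log_p_flow = (log_std_normal_pdf(x_base[:, 0]) + log_std_normal_pdf(x_base[:, 1])
              - (log_std_normal_pdf(x_base[:, 0]) + 0.5 * np.log(a_s)))

# Direct evaluation of log p(a, b) = log p(a) + log p(b|a), with log p(a) = 0 on [1, 2]
log_p_direct = -0.5 * np.log(2 * np.pi * a_s) - (b_s - a_s)**2 / (2 * a_s)

print(a_s.mean(), b_s.mean(), b_s.var())      # near 1.5, 1.5 and 19/12 (about 1.583)
print(np.allclose(log_p_flow, log_p_direct))  # True
```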

In [None]:
# DO NOT INCLUDE IN BLOG POST
# Extract a and b from the samples
a_samples = samples[:, 0].detach().numpy()
b_samples = samples[:, 1].detach().numpy()

# Plot the sampled values
plt.figure()
plt.scatter(a_samples, b_samples, alpha=0.5)
plt.xlabel("a")
plt.ylabel("b")
plt.title("Samples from the trained Normalizing Flow")
plt.show()

## Acknowledgements

Some of the code in this tutorial was adapted from example scripts in the `normalizing-flows` repository: https://github.com/VincentStimper/normalizing-flows

Stimper et al., (2023). normflows: A PyTorch Package for Normalizing Flows. Journal of Open Source Software, 8(86), 5361, https://doi.org/10.21105/joss.05361

## Author
Jay Spendlove

PhD student, Arizona State University

jcspendl@asu.edu