# Normalizing Flows

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mariogemoll/normalizing-flows/blob/main/py/normalizing-flows.ipynb)

Let's create a normalizing flow model for the moons dataset.

In [None]:
from sklearn.datasets import make_moons
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import matplotlib.animation as animation
from IPython.display import display
from matplotlib import rc


# Pick the compute device: prefer CUDA, then Apple's MPS backend, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

def sample_from_moons(n, noise=0.05, random_state=None):
    """Draw ``n`` points from scikit-learn's two-moons toy dataset.

    Args:
        n: Number of points to sample.
        noise: Standard deviation of the Gaussian noise added to the points
            (forwarded to ``make_moons``). Defaults to 0.05.
        random_state: Seed or ``RandomState`` forwarded to ``make_moons`` for
            reproducible samples. ``None`` (the default) draws fresh samples
            on every call.

    Returns:
        A float32 ``torch.Tensor`` holding the sampled points. The class labels
        returned by ``make_moons`` are discarded.
    """
    points, _labels = make_moons(
        n_samples=n, noise=noise, random_state=random_state
    )
    return torch.from_numpy(points.astype(np.float32))

# Draw a reference sample of the target distribution and show it as a
# scatter plot (X.T unpacks into the x- and y-coordinate columns).
X = sample_from_moons(10000)
plt.scatter(*X.T, s=1, alpha=0.5)
plt.show()

The model is a stack of affine coupling layers (we will use 8). Successive layers alternate between transforming the first and the second of the two dimensions, each conditioned on the other one:

In [None]:
class MLP(nn.Module):
    """Small fully connected network with two ReLU hidden layers.

    In this notebook it maps the conditioning coordinate of a coupling layer
    to a scale or a shift for the transformed coordinate.

    Args:
        hidden_dims: Width of both hidden layers.
        in_dims: Number of input features (default 1).
        out_dims: Number of output features (default 1).
    """

    def __init__(self, hidden_dims, *, in_dims=1, out_dims=1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dims, hidden_dims),
            nn.ReLU(),
            nn.Linear(hidden_dims, hidden_dims),
            nn.ReLU(),
            nn.Linear(hidden_dims, out_dims),
        )

    def forward(self, x):
        """Apply the network to ``x`` of shape ``(batch, in_dims)``."""
        return self.net(x)

class CouplingLayer(nn.Module):
    """Affine coupling layer for 2-D inputs.

    The input is split into two single-column halves along dim 1. One half
    (the conditioner) passes through unchanged. It determines a log-scale
    ``s`` and a shift ``t`` that are applied to the other half:
    ``y = exp(s) * x + t``. Because the conditioner is untouched, the map is
    trivially invertible and its log-Jacobian-determinant is ``sum(s)``.

    Args:
        flip: If False, column 1 is transformed conditioned on column 0.
            If True, column 0 is transformed conditioned on column 1.
        hidden_dims: Width of the hidden layers of the scale and shift
            networks (default 24).
    """

    def __init__(self, flip, hidden_dims=24):
        super().__init__()
        self.flip = flip
        self.scale_net = MLP(hidden_dims)
        self.shift_net = MLP(hidden_dims)

    def _split(self, v):
        """Return ``(conditioner, transformed)`` halves of ``v`` per ``flip``."""
        first, second = v.chunk(2, dim=1)
        return (second, first) if self.flip else (first, second)

    def _merge(self, cond, other):
        """Reassemble halves in their original column order (inverse of _split)."""
        parts = [other, cond] if self.flip else [cond, other]
        return torch.cat(parts, dim=1)

    def _scale_and_shift(self, cond):
        """Compute the log-scale ``s`` and shift ``t`` from the conditioner."""
        # tanh bounds the log-scale to (-1, 1), so exp(s) stays in (1/e, e);
        # presumably this is for numerical stability during training.
        s = torch.tanh(self.scale_net(cond))
        t = self.shift_net(cond)
        return s, t

    def forward(self, x):
        """Apply the coupling transform in the data -> latent direction.

        Returns:
            ``(y, log_det)``. ``y`` has the same shape as ``x``. ``log_det``
            has one entry per row and equals ``sum(s)`` over the transformed
            features.
        """
        cond, other = self._split(x)
        s, t = self._scale_and_shift(cond)
        y_other = torch.exp(s) * other + t
        return self._merge(cond, y_other), s.sum(dim=1)

    def inverse(self, y):
        """Undo ``forward``: map latent-side ``y`` back to the data side.

        Returns:
            ``(x, log_det)``. ``log_det`` is the log-determinant of the inverse
            map, i.e. ``-sum(s)``.
        """
        cond, other = self._split(y)
        s, t = self._scale_and_shift(cond)
        x_other = (other - t) * torch.exp(-s)
        return self._merge(cond, x_other), -s.sum(dim=1)

class NormalizingFlow(nn.Module):
    """A stack of affine coupling layers with alternating ``flip``.

    Args:
        num_layers: Number of coupling layers. Layer ``i`` is built with
            ``flip=(i % 2 == 0)``, so consecutive layers transform different
            coordinates.
    """

    def __init__(self, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([
            CouplingLayer(i % 2 == 0)
            for i in range(num_layers)
        ])

    def forward(self, x):
        """Push ``x`` through every layer in order (data -> latent).

        Returns:
            ``(zs, log_det)``. ``zs`` lists the input followed by each layer's
            output, so ``zs[-1]`` is the final latent. ``log_det`` holds one
            summed log-Jacobian-determinant per row.
        """
        # Allocate the accumulator next to the input instead of on the
        # notebook-global ``device``, so the module works wherever it and its
        # inputs live.
        log_det = torch.zeros(x.shape[0], device=x.device, dtype=x.dtype)
        zs = [x]
        for layer in self.layers:
            x, layer_log_det = layer(x)
            log_det = log_det + layer_log_det
            zs.append(x)
        return zs, log_det

    def inverse(self, z):
        """Pull ``z`` back through every layer in reverse order (latent -> data).

        Returns:
            ``(xs, log_det)``. ``xs`` lists the input followed by each inverse
            layer's output, so ``xs[-1]`` is the reconstructed data.
            ``log_det`` holds one summed inverse log-determinant per row.
        """
        log_det = torch.zeros(z.shape[0], device=z.device, dtype=z.dtype)
        xs = [z]
        for layer in reversed(self.layers):
            z, layer_log_det = layer.inverse(z)
            log_det = log_det + layer_log_det
            xs.append(z)
        return xs, log_det

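Let's train the model. We minimize the negative log-likelihood of fresh moons samples under a standard normal base distribution, using the change-of-variables formula.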

In [None]:
flow = NormalizingFlow(num_layers=8).to(device)
optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)

losses = []
pbar = tqdm(range(1000))
# Every iteration draws a fresh batch from the moons distribution. There is no
# fixed dataset, so these are optimizer steps rather than epochs.
for step in pbar:
    x = sample_from_moons(1024).to(device)
    optimizer.zero_grad()

    zs, log_j = flow(x)
    z = zs[-1]
    # Negative log-likelihood under a standard normal base distribution, by the
    # change-of-variables formula:
    #     -log p(x) = 0.5 * ||z||^2 - log|det J| + log(2*pi)   (for 2-D data).
    # The constant does not affect gradients and is dropped.
    loss = (0.5 * z.pow(2).sum(dim=1) - log_j).mean()

    loss.backward()
    optimizer.step()
    # Read the scalar once: .item() forces a host/device sync.
    loss_value = loss.item()
    pbar.set_description(f'Loss: {loss_value:.2f}')
    losses.append(loss_value)

plt.figure()
plt.plot(losses)
plt.xlabel("Training step")
plt.ylabel("Loss (NLL up to a constant)")
plt.show()

We can now see how the trained model transforms the moons data distribution into something that
looks very close to a standard normal distribution:

In [None]:
flow.eval()
with torch.no_grad():
    # Push the reference sample through the flow. ``zs`` holds every
    # intermediate stage, with the final latent last.
    zs, _ = flow(X.to(device))
X = X.cpu()
z = zs[-1].cpu()

# Side-by-side: data space on the left, latent space on the right.
fig, (ax_data, ax_latent) = plt.subplots(1, 2, figsize=(12, 5))
ax_data.scatter(X[:, 0], X[:, 1], s=1, alpha=0.5)
ax_data.set_title("Original Data")
ax_latent.scatter(z[:, 0], z[:, 1], s=1, alpha=0.5)
ax_latent.set_title("Transformed Data")
plt.show()

In [None]:
# Enable JS animation in Jupyter
rc('animation', html='jshtml')

flow_outputs = [output.cpu() for output in zs]

# Interpolation settings
num_interp_frames = 10
total_frames = (len(flow_outputs) - 1) * num_interp_frames

# Precompute interpolated frames
interp_outputs = []
for i in range(len(flow_outputs) - 1):
    for t in np.linspace(0, 1, num_interp_frames):
        interpolated = (1 - t) * flow_outputs[i] + t * flow_outputs[i + 1]
        interp_outputs.append(interpolated)

# Turn off interactive mode to prevent automatic figure display
plt.ioff()

fig, ax = plt.subplots()
sc = ax.scatter([], [], s=1, alpha=0.5)

x_min, x_max = (
    min(d[:, 0].min() for d in flow_outputs),
    max(d[:, 0].max() for d in flow_outputs)
)
y_min, y_max = (
    min(d[:, 1].min() for d in flow_outputs),
    max(d[:, 1].max() for d in flow_outputs)
)
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)

def update(frame):
    data = interp_outputs[frame]
    sc.set_offsets(data)
    return sc,

anim = animation.FuncAnimation(
    fig, update, frames=total_frames, interval=50, blit=False
)

# Only display the animation, not the static figure
display(anim)