### Plotting and Imports

In [0]:
# Set up plotting
import IPython
import bokeh
from bokeh.plotting import figure
from bokeh.io import output_notebook, show
from bokeh.layouts import gridplot
IPython.get_ipython().events.register('pre_run_cell', lambda: output_notebook(hide_banner=True))

In [0]:
import sys
import numpy as np
import torch
from torch import nn, distributions
from sklearn import cluster, datasets, mixture, preprocessing
import tqdm

### Data

In [0]:
# Create and visualize data
get_data = lambda: torch.from_numpy(datasets.make_moons(n_samples=100, noise=0.05, shuffle=False)[0].astype(np.float32))

# Plot data sample 
def get_fig(data, title='title', x_label='x', y_label='y', w=400, h=400, xr=None, yr=None, size=10, color='blue'):
    fig = figure(title=title, x_axis_label=x_label, y_axis_label=y_label, width=w, height=h, x_range=xr, y_range=yr)
    fig.scatter(data.numpy()[:,0], data.numpy()[:,1], size=size, color=color)
    return fig
show(get_fig(get_data(), title="Classic Moons Dataset (X space)", x_label='X (dim=1)', y_label='X (dim=2)'))

### Model

In [0]:
class CouplingLayer(nn.Module):
  
    def __init__(self): 
        super(CouplingLayer, self).__init__()
        
        # For masked convolution from Section 3.4
        self.mask_beg = nn.Parameter(torch.tensor([1.0,0.0]), requires_grad=False)
        self.mask_end = nn.Parameter(torch.tensor([0.0,1.0]), requires_grad=False)
        
        # The t and s functions from Eq. 9
        self.t = nn.Sequential(
            nn.Linear(2, 256), 
            nn.LeakyReLU(), 
            nn.Linear(256, 256), 
            nn.LeakyReLU(), 
            nn.Linear(256, 2), 
        )
        self.s = nn.Sequential(
            nn.Linear(2, 256), 
            nn.LeakyReLU(), 
            nn.Linear(256, 256), 
            nn.LeakyReLU(), 
            nn.Linear(256, 2), 
            nn.Tanh(),
        )

    # Forward NVP flow (X -> Z)
    def f_layer(self, x):
        z, log_det_beg = self.f_(x, self.mask_beg)
        z, log_det_end = self.f_(z, self.mask_end)        
        return z, log_det_beg + log_det_end

    # Eq. 8 (note: b is binary mask)
    def f_(self, x, b):        
        s = self.s(x * b)
        t = self.t(x * b)
        y = (x * b) + (1 - b) * (x - t) * torch.exp(-s)        
        log_det = - (s * (1 - b)).sum(dim=1)
        return y, log_det
    
    # Inverse NVP flow (Z -> X)
    def f_inv_layer(self, z):
        y = self.f_inv_(z, self.mask_end)
        y = self.f_inv_(y, self.mask_beg)
        return y

    # Eq. 7 / Eq. 9 (note: b is binary mask)
    def f_inv_(self, z, b): 
        s = self.s(z * b)
        t = self.t(z * b)
        return (z * b) + (1 - b) * (z * torch.exp(s) + t) 
            
class NVP(nn.Module):
  
    def __init__(self, prior, depth=3):
        super(NVP, self).__init__()
        self.prior = prior
        self.depth = depth
        self.layers = nn.ModuleList([CouplingLayer() for _ in range(depth)])
                
    # Applies f: X \to Z
    def f(self, x):
        z = x
        log_det_total = 0
        for layer in self.layers: 
            z, log_det = layer.f_layer(z)
            log_det_total += log_det
        return z, log_det_total
    
    # Applies f^{-1}: Z \to X
    def f_inv(self, z):
        x = z
        for layer in reversed(self.layers):
            x = layer.f_inv_layer(x)
        return x
        
    # Calculates log p(x) = log p(z) + log(|det J(f_x)|)
    def log_prob(self, x):
        z, log_det = self.f(x)
        return self.prior.log_prob(z) + log_det.reshape(-1,1)
    
    # Samples from the prior and passes through f^{-1}: Z \to X
    def sample(self, size):
        z = self.prior.sample((size, 2))
        return self.f_inv(z)

In [0]:
# Create model with N(0,1) prior
prior = distributions.Normal(0,1)

model = NVP(prior)
optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-4)

In [0]:

# Train model with SGD 
epochs = 5000
for i in range(epochs):
    loss = - model.log_prob(get_data()).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i % 500 == 0:
        print(f'Completed {i:5d}/{epochs:5d} | loss {loss.item():.3f} ')

Completed     0/ 5000 | loss 1.405 
Completed   500/ 5000 | loss -1.012 
Completed  1000/ 5000 | loss -1.076 
Completed  1500/ 5000 | loss -1.177 
Completed  2000/ 5000 | loss -1.214 
Completed  2500/ 5000 | loss -1.250 
Completed  3000/ 5000 | loss -0.978 
Completed  3500/ 5000 | loss -1.330 
Completed  4000/ 5000 | loss -1.292 
Completed  4500/ 5000 | loss -1.269 


### Plotting

In [0]:
# Generate data to plot
with torch.no_grad():
    X_data = get_data()
    X_model = model.sample(100)
    Z_prior = prior.sample((100, 2))
    Z_model = model.f(get_data())[0]

# Plot
X_data_fig = get_fig(X_data.detach(), title="x ~ p(X)", x_label='', y_label='X space')
X_model_fig = get_fig(X_model.detach(), title="x = f^{-1}(z)", x_label='', y_label='X space', xr=X_data_fig.x_range, yr=X_data_fig.y_range,)
Z_prior_fig = get_fig(Z_prior.detach(), title="z ~ p(z)", x_label='', y_label='Z space', xr=X_data_fig.x_range, yr=X_data_fig.y_range, color='green')
Z_model_fig = get_fig(Z_model.detach(), title="z = f(x)", x_label='', y_label='Z space', xr=X_data_fig.x_range, yr=X_data_fig.y_range, color='green')
show(gridplot([X_data_fig, X_model_fig, Z_prior_fig, Z_model_fig], ncols=2))

In [0]:
# Plot grid
colors = bokeh.palettes.viridis(255)
grid = np.stack(np.meshgrid(np.linspace(-2,2,30), np.linspace(-2,2,30))).reshape(2,-1).transpose((1,0))
grid_colors = [colors[i] for i in ((grid**2).sum(axis=1) / (grid**2).sum(axis=1).max() * 250).astype(int)]
with torch.no_grad():
    grid_z = torch.from_numpy(grid).float()
    grid_x = model.f_inv(grid_z)
grid_z_fig = get_fig(grid_z, size=7, color=grid_colors)
grid_x_fig = get_fig(grid_x, size=7, color=grid_colors)
show(gridplot([grid_z_fig, grid_x_fig], ncols=2))