In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
import torch
import copy

from nflows.utils import tensor2numpy, create_mid_split_binary_mask
from nflows.distributions import StandardNormal
from nflows.transforms import (
    CompositeTransform, 
    InverseTransform,
    BatchNorm,
    PointwiseAffineTransform,
    Tanh,
    ReversePermutation,
    MaskedAffineAutoregressiveTransform as MAF,
    MaskedPiecewiseRationalQuadraticAutoregressiveTransform as RQ_NSF_AR,
    PiecewiseRationalQuadraticCouplingTransform as RQ_NSF_C,
    )
from nflows.nn.nets import ResidualNet
from nflows.flows.base import Flow

# Seed the global RNGs so that "Restart Kernel -> Run All" reproduces the same run.
# NumPy's global generator is seeded because the make_moons calls in this notebook
# are made without a random_state; torch is seeded for the networks built below.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Device the flows and the training batches are moved to.
device = torch.device('cpu')

In [None]:
# Draw 1,000 noisy two-moons points and show them as a scatter plot
# (first column on the horizontal axis, second column on the vertical axis).
x, y = datasets.make_moons(n_samples=1_000, noise=0.1)
plt.scatter(x[:, 0], x[:, 1]);

## MAF

In [None]:
# Hyperparameters of the masked affine autoregressive flow (MAF).
num_layers = 5
hidden_features = 100
num_blocks = 2
activation = torch.relu

# Two-dimensional standard normal used as the base distribution.
base_dist = StandardNormal(shape=[2])

# Pre-processing applied before the autoregressive layers:
# scale the inputs by 1/6, then apply the inverse of Tanh.
transforms = [
    PointwiseAffineTransform(shift=0., scale=1/6),
    InverseTransform(Tanh()),
]

# Every layer is a reversal of the two features followed by one MAF block.
for _ in range(num_layers):
    transforms += [
        ReversePermutation(features=2),
        MAF(
            features=2,
            hidden_features=hidden_features,
            num_blocks=num_blocks,
            activation=activation,
        ),
    ]

transform = CompositeTransform(transforms)

flow = Flow(transform, base_dist).to(device)

In [None]:
# Train `flow` by minimising the negative mean log-likelihood of fresh
# two-moons mini-batches, keeping a copy of the best-scoring model in `best_flow`.
num_iter = 1000
batch_size = 128
learning_rate = 1e-3

optimizer = torch.optim.Adam(flow.parameters(), lr=learning_rate)

# Lowest mini-batch loss seen so far, the iteration it occurred at, and a copy
# of the flow that produced it. best_flow starts as a copy of the current flow
# so that it always refers to THIS model, even when no iteration improves on
# best_loss (e.g. num_iter == 0 or a NaN loss).
best_epoch = 0
best_loss = np.inf
best_flow = copy.deepcopy(flow)

for i in range(num_iter):

    # A new batch of `batch_size` points is sampled every iteration; the labels are unused.
    x, y = datasets.make_moons(batch_size, noise=.1)
    x = torch.tensor(x, dtype=torch.float32).to(device)

    optimizer.zero_grad()
    loss = -flow.log_prob(inputs=x).mean()
    loss.backward()

    # Work with a plain float so the bookkeeping does not hold on to the autograd graph.
    loss_value = loss.item()

    # Snapshot BEFORE optimizer.step(): `loss` was computed with the current
    # parameters, so copying after the step would store different weights
    # than the ones that achieved best_loss.
    if loss_value < best_loss:
        best_epoch = i
        best_loss = loss_value
        best_flow = copy.deepcopy(flow)

    optimizer.step()

    print(i, loss_value)

In [None]:
# Evaluate the density of `best_flow` on an n x n grid covering [-p, p] x [-p, p]
# and display it as an image.
p = 3
n = 200
x = np.linspace(-p, p, n)
y = np.linspace(-p, p, n)
xgrid, ygrid = np.meshgrid(x, y)

# One (x, y) query point per row, in float32 like the training batches.
xy = np.stack([xgrid.ravel(), ygrid.ravel()], axis=1).astype(np.float32)

with torch.no_grad():
    # Put the grid on the same device as the flow, and bring the result back
    # to the CPU so it can be converted to a NumPy array for plotting.
    zgrid = best_flow.log_prob(torch.tensor(xy).to(device)).exp().reshape(n, n).cpu()

plt.imshow(zgrid.numpy(), origin='lower', aspect='equal', extent=(-p, p, -p, p))
plt.xlabel('x')
plt.ylabel('y')
plt.title('MAF: learned density');

## RQ-NSF

### Autoregressive

In [None]:
# Hyperparameters of the autoregressive rational-quadratic spline flow.
num_layers = 5
hidden_features = 100
num_blocks = 2
activation = torch.relu
num_bins = 10
tails = 'linear'
tail_bound = 5.

# Two-dimensional standard normal used as the base distribution.
base_dist = StandardNormal(shape=[2])

# Every layer is a reversal of the two features followed by one spline block.
transforms = []
for _ in range(num_layers):
    transforms += [
        ReversePermutation(features=2),
        RQ_NSF_AR(
            features=2,
            hidden_features=hidden_features,
            num_blocks=num_blocks,
            activation=activation,
            num_bins=num_bins,
            tails=tails,
            tail_bound=tail_bound,
        ),
    ]

transform = CompositeTransform(transforms)

flow = Flow(transform, base_dist).to(device)

In [None]:
# Train `flow` by minimising the negative mean log-likelihood of fresh
# two-moons mini-batches, keeping a copy of the best-scoring model in `best_flow`.
num_iter = 1000
batch_size = 128
learning_rate = 1e-3

optimizer = torch.optim.Adam(flow.parameters(), lr=learning_rate)

# Lowest mini-batch loss seen so far, the iteration it occurred at, and a copy
# of the flow that produced it. best_flow starts as a copy of the current flow
# so that it never silently keeps pointing at a model trained in an earlier
# cell when no iteration improves on best_loss (e.g. num_iter == 0 or a NaN loss).
best_epoch = 0
best_loss = np.inf
best_flow = copy.deepcopy(flow)

for i in range(num_iter):

    # A new batch of `batch_size` points is sampled every iteration; the labels are unused.
    x, y = datasets.make_moons(batch_size, noise=.1)
    x = torch.tensor(x, dtype=torch.float32).to(device)

    optimizer.zero_grad()
    loss = -flow.log_prob(inputs=x).mean()
    loss.backward()

    # Work with a plain float so the bookkeeping does not hold on to the autograd graph.
    loss_value = loss.item()

    # Snapshot BEFORE optimizer.step(): `loss` was computed with the current
    # parameters, so copying after the step would store different weights
    # than the ones that achieved best_loss.
    if loss_value < best_loss:
        best_epoch = i
        best_loss = loss_value
        best_flow = copy.deepcopy(flow)

    optimizer.step()

    print(i, loss_value)

In [None]:
# Evaluate the density of `best_flow` on an n x n grid covering [-p, p] x [-p, p]
# and display it as an image.
p = 3
n = 200
x = np.linspace(-p, p, n)
y = np.linspace(-p, p, n)
xgrid, ygrid = np.meshgrid(x, y)

# One (x, y) query point per row, in float32 like the training batches.
xy = np.stack([xgrid.ravel(), ygrid.ravel()], axis=1).astype(np.float32)

with torch.no_grad():
    # Put the grid on the same device as the flow, and bring the result back
    # to the CPU so it can be converted to a NumPy array for plotting.
    zgrid = best_flow.log_prob(torch.tensor(xy).to(device)).exp().reshape(n, n).cpu()

plt.imshow(zgrid.numpy(), origin='lower', aspect='equal', extent=(-p, p, -p, p))
plt.xlabel('x')
plt.ylabel('y')
plt.title('RQ-NSF (autoregressive): learned density');

### Coupling

In [None]:
# Hyperparameters of the coupling rational-quadratic spline flow.
num_layers = 5
hidden_features = 100
num_blocks = 2
activation = torch.relu
num_bins = 5
tails = 'linear'
tail_bound = 5.

# Two-dimensional standard normal used as the base distribution.
base_dist = StandardNormal(shape=[2])


def make_transform_net(in_features, out_features):
    """Build the residual network handed to each coupling layer.

    The width, depth and activation are read from the hyperparameters above;
    dropout and batch norm are switched off and no context features are used.
    """
    return ResidualNet(
        in_features,
        out_features,
        hidden_features=hidden_features,
        context_features=None,
        num_blocks=num_blocks,
        activation=activation,
        dropout_probability=0.,
        use_batch_norm=False,
    )


# Every layer is a reversal of the two features followed by one spline coupling block.
# Arguments not passed here (apply_unconditional_transform, img_shape, min_bin_width,
# min_bin_height, min_derivative) are left at the library defaults.
# A BatchNorm(2) after each coupling block was tried and is currently disabled.
transforms = []
for _ in range(num_layers):
    transforms += [
        ReversePermutation(features=2),
        RQ_NSF_C(
            mask=create_mid_split_binary_mask(2),
            transform_net_create_fn=make_transform_net,
            num_bins=num_bins,
            tails=tails,
            tail_bound=tail_bound,
        ),
    ]

transform = CompositeTransform(transforms)

flow = Flow(transform, base_dist).to(device)

In [None]:
# Train `flow` by minimising the negative mean log-likelihood of fresh
# two-moons mini-batches, keeping a copy of the best-scoring model in `best_flow`.
num_iter = 1000
batch_size = 128
learning_rate = 1e-3

optimizer = torch.optim.Adam(flow.parameters(), lr=learning_rate)

# Lowest mini-batch loss seen so far, the iteration it occurred at, and a copy
# of the flow that produced it. best_flow starts as a copy of the current flow
# so that it never silently keeps pointing at a model trained in an earlier
# cell when no iteration improves on best_loss (e.g. num_iter == 0 or a NaN loss).
best_epoch = 0
best_loss = np.inf
best_flow = copy.deepcopy(flow)

for i in range(num_iter):

    # A new batch of `batch_size` points is sampled every iteration; the labels are unused.
    x, y = datasets.make_moons(batch_size, noise=.1)
    x = torch.tensor(x, dtype=torch.float32).to(device)

    optimizer.zero_grad()
    loss = -flow.log_prob(inputs=x).mean()
    loss.backward()

    # Work with a plain float so the bookkeeping does not hold on to the autograd graph.
    loss_value = loss.item()

    # Snapshot BEFORE optimizer.step(): `loss` was computed with the current
    # parameters, so copying after the step would store different weights
    # than the ones that achieved best_loss.
    if loss_value < best_loss:
        best_epoch = i
        best_loss = loss_value
        best_flow = copy.deepcopy(flow)

    optimizer.step()

    print(i, loss_value)

In [None]:
# Evaluate the density of `best_flow` on an n x n grid covering [-p, p] x [-p, p]
# and display it as an image.
p = 3
n = 200
x = np.linspace(-p, p, n)
y = np.linspace(-p, p, n)
xgrid, ygrid = np.meshgrid(x, y)

# One (x, y) query point per row, in float32 like the training batches.
xy = np.stack([xgrid.ravel(), ygrid.ravel()], axis=1).astype(np.float32)

with torch.no_grad():
    # Put the grid on the same device as the flow, and bring the result back
    # to the CPU so it can be converted to a NumPy array for plotting.
    zgrid = best_flow.log_prob(torch.tensor(xy).to(device)).exp().reshape(n, n).cpu()

plt.imshow(zgrid.numpy(), origin='lower', aspect='equal', extent=(-p, p, -p, p))
plt.xlabel('x')
plt.ylabel('y')
plt.title('RQ-NSF (coupling): learned density');