In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
import torch
from copy import deepcopy

from nflows.utils import tensor2numpy, create_mid_split_binary_mask
from nflows.distributions import StandardNormal
from nflows.transforms import (
    CompositeTransform, 
    InverseTransform,
    BatchNorm,
    PointwiseAffineTransform,
    Tanh,
    ReversePermutation,
    MaskedAffineAutoregressiveTransform as MAF,
    MaskedPiecewiseRationalQuadraticAutoregressiveTransform as RQ_NSF_AR,
    PiecewiseRationalQuadraticCouplingTransform as RQ_NSF_C,
    )
from nflows.nn.nets import ResidualNet
from nflows.flows.base import Flow

# Compute device: the flow and every training batch are moved here with `.to(device)`.
device = torch.device("cpu")

In [None]:
# Preview the two-moons toy dataset: 1,000 noisy 2-D points, coloured by
# the label returned alongside them.
x, y = datasets.make_moons(n_samples=1_000, noise=0.1)
plt.scatter(x[:, 0], x[:, 1], c=y);

## RQ-NSF-C: rational-quadratic neural spline flow with coupling layers

In [None]:
# Architecture hyperparameters for the conditional coupling flow.
num_layers = 5
hidden_features = 100
num_blocks = 2
activation = torch.relu
num_bins = 5
tails = 'linear'
tail_bound = 5.

# Base density over the 2-D data space.
base_dist = StandardNormal(shape=[2])


def make_transform_net(in_features, out_features):
    """Build the residual network that parameterises one coupling layer.

    The input/output sizes are supplied by the coupling transform; every
    other setting is read from the hyperparameters defined above. The
    network takes a single context feature.
    """
    return ResidualNet(
        in_features,
        out_features,
        hidden_features=hidden_features,
        context_features=1,
        num_blocks=num_blocks,
        activation=activation,
        dropout_probability=0.,
        use_batch_norm=False,
    )


# Stack `num_layers` spline coupling layers, each followed by a reversal of the
# two features. Keyword arguments not listed here (apply_unconditional_transform,
# img_shape, min_bin_width, min_bin_height, min_derivative) are not passed.
transforms = []
for _ in range(num_layers):
    coupling = RQ_NSF_C(
        mask=create_mid_split_binary_mask(2),
        transform_net_create_fn=make_transform_net,
        num_bins=num_bins,
        tails=tails,
        tail_bound=tail_bound,
    )
    transforms += [coupling, ReversePermutation(features=2)]
transform = CompositeTransform(transforms)

flow = Flow(transform, base_dist).to(device)

In [None]:
# Training configuration.
epochs = 1000
batch_size = 128
lr = 1e-3

optimizer = torch.optim.Adam(flow.parameters(), lr=lr)
best_epoch = 0
best_loss = np.inf

for epoch in range(epochs):
    # Draw a fresh minibatch every step; the label is passed to the flow
    # as a single-column context.
    x, y = datasets.make_moons(batch_size, noise=.1)
    x = torch.tensor(x, dtype=torch.float32).to(device)
    y = torch.tensor(y[:, None], dtype=torch.float32).to(device)

    optimizer.zero_grad()
    # Negative mean log-likelihood of the batch under the conditional flow.
    loss = -flow.log_prob(inputs=x, context=y).mean()
    # Keep a plain Python float so the running best does not hold on to
    # the autograd graph of an old step.
    loss_value = loss.item()

    # Snapshot BEFORE the optimizer step: the copied weights are the ones
    # that actually produced `loss_value`.
    if loss_value < best_loss:
        best_epoch = epoch
        best_loss = loss_value
        best_flow = deepcopy(flow)

    loss.backward()
    optimizer.step()

    print(epoch, loss_value)

In [None]:
# Draw 1,000 points from the best snapshot with the context fixed to 0.
# The context is placed on `device` (as the training batches are) and the
# result is brought back to the CPU before converting to NumPy, so the cell
# keeps working if `device` is changed.
sample_context = torch.tensor([[0.]]).to(device)
# `[0]` selects the samples for the first (and only) context row.
samples = best_flow.sample(1_000, sample_context)[0]
plt.scatter(*samples.detach().cpu().numpy().T);

In [None]:
zgrid = torch.zeros(xy.shape[0])
zgrid.shape

In [None]:
p = 3
n = 200

x = np.linspace(-p, p, n)
y = np.linspace(-p, p, n)
xgrid, ygrid = np.meshgrid(x, y)
xy = np.concatenate(
    [xgrid.reshape(-1, 1), ygrid.reshape(-1, 1)],
    axis=1,
    dtype=np.float32,
    )

with torch.no_grad():
    zgrid = torch.zeros(xy.shape[0])
    for context in [0., 1.]:
        zgrid += best_flow.log_prob(
            torch.tensor(xy), torch.tensor([[context]]*xy.shape[0]),
            ).exp()#.reshape(n, n)
        
plt.imshow(
    zgrid.numpy().reshape(n, n),
    origin='lower',
    aspect='equal',
    extent=(-p, p, -p, p),
    );