In [None]:
# Locate the project root and add it to the Python path; presumably this is what
# makes the local `mltools` package importable, so it must run before that import.
import pyrootutils
root = pyrootutils.setup_root(search_from=".", pythonpath=True)

import torch as T
import numpy as np
import torch.nn as nn
import normflows as nf
from sklearn.datasets import make_moons
from matplotlib import pyplot as plt
from tqdm import tqdm
from mltools.mltools.flows import rqs_flow

# Seed the global RNGs so Restart & Run All reproduces the figures:
# torch drives weight initialisation and dropout, and make_moons (called
# below without random_state) draws from numpy's global RNG.
SEED = 42
T.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Build the 2-D flow to be fitted to the two-moons data (ctxt_dim=0, so
# presumably unconditional). Arguments are grouped by concern.
flow = rqs_flow(
    # data / context dimensionality
    xz_dim=2,
    ctxt_dim=0,
    # overall flow structure
    flow_type="coupling",
    num_stacks=16,
    num_bins=4,
    tail_bound=4,
    init_identity=True,
    do_lu=False,
    do_norm=False,
    # network hyperparameters
    mlp_width=32,
    mlp_depth=2,
    mlp_act=nn.SiLU,
    dropout=0.1,
)

In [None]:
# Use a CUDA device when PyTorch can see one, otherwise fall back to the CPU,
# then move the flow there and show its architecture.
device = T.device("cuda") if T.cuda.is_available() else T.device("cpu")
flow = flow.to(device)
print(flow)

In [None]:
grid_size = 200
x_bins = T.linspace(-1.5, 2.5, grid_size)
y_bins = T.linspace(-2, 2, grid_size)

# Plot target distribution: a large two-moons sample histogrammed on the same
# grid that the flow density is evaluated on below.
x_np, _ = make_moons(2 ** 20, noise=0.1)
plt.figure(figsize=(4, 4))
plt.hist2d(x_np[:, 0], x_np[:, 1], bins=[x_bins, y_bins])
plt.title("Target distribution")
plt.show()

# Plot initial flow distribution.
# Build the (grid_size**2, 2) array of grid points; with indexing="xy",
# xx and yy both have shape (len(y_bins), len(x_bins)).
xx, yy = T.meshgrid(x_bins, y_bins, indexing="xy")
zz = T.stack([xx, yy], dim=2).reshape(-1, 2).to(device)

# Pure inference: eval mode disables dropout, and no_grad avoids building an
# autograd graph over every grid point. try/finally guarantees the flow is put
# back in training mode even if log_prob raises.
flow.eval()
try:
    with T.no_grad():
        log_prob = flow.log_prob(zz).to("cpu").view(*xx.shape)
finally:
    flow.train()
prob = T.exp(log_prob)
prob[T.isnan(prob)] = 0  # render NaN log-probs as zero density

plt.figure(figsize=(4, 4))
plt.pcolormesh(xx, yy, prob.numpy())
plt.gca().set_aspect("equal", "box")
plt.title("Initial flow density")
plt.show()

In [None]:
max_iter = 5000
num_samples = 2 ** 10
# Collect losses in a list: np.append copies the whole array on every call,
# which makes the loop quadratic in max_iter.
loss_hist = []
optimizer = T.optim.Adam(flow.parameters(), lr=1e-3, weight_decay=1e-5)
for _ in tqdm(range(max_iter)):
    optimizer.zero_grad()

    # Draw a fresh batch of two-moons samples every iteration.
    x_np, _ = make_moons(num_samples, noise=0.1)
    x = T.from_numpy(x_np).float().to(device)

    # Forward-KL objective from normflows (presumably the batch negative
    # log-likelihood of the data under the flow).
    loss = flow.forward_kld(x)

    # Skip the update on a NaN/inf loss so a single bad batch cannot
    # corrupt the weights; the value is still recorded below.
    if T.isfinite(loss):
        loss.backward()
        optimizer.step()

    loss_hist.append(loss.item())

# Keep loss_hist a 1-D numpy array, as in the original cell.
loss_hist = np.array(loss_hist)

# Plot loss
plt.figure(figsize=(4, 4))
plt.plot(loss_hist, label='loss')
plt.xlabel("iteration")
plt.ylabel("loss")
plt.legend()
plt.show()

In [None]:
# Plot learned distribution: evaluate the trained flow on the same grid
# (zz, xx, yy) that was built for the initial-distribution plot.
# Eval mode disables dropout and no_grad skips autograd bookkeeping for this
# pure-inference pass; try/finally always restores training mode.
flow.eval()
try:
    with T.no_grad():
        log_prob = flow.log_prob(zz).to("cpu").view(*xx.shape)
finally:
    flow.train()
prob = T.exp(log_prob)
prob[T.isnan(prob)] = 0  # render NaN log-probs as zero density

plt.figure(figsize=(4, 4))
plt.pcolormesh(xx, yy, prob.numpy())
plt.gca().set_aspect("equal", "box")
plt.title("Learned flow density")
plt.show()