In [None]:
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from nessai.flowmodel import FlowModel
import numpy as np
from scipy.stats import norm, multivariate_normal, chi
import seaborn as sns
import torch

import thesis_utils.colours as thesis_colours
from thesis_utils.plotting import pp_plot, set_plotting, save_figure
from thesis_utils.random import seed_everything

# Project helpers imported from thesis_utils. Presumably these fix the random
# seeds and apply the thesis-wide matplotlib style -- confirm in thesis_utils.
seed_everything()
set_plotting()

# Training data

In [None]:
# 10,000 two-dimensional training points: standard-normal draws scaled by 5,
# i.e. a zero-mean Gaussian with standard deviation 5 in each dimension.
data = np.random.randn(10_000, 2) * 5

In [None]:
# One-dimensional target density: normal with mean 2 and standard deviation 2.
target_dist = norm(scale=2, loc=2)

In [None]:
# Per-sample weights: target density divided by the density of a zero-mean
# normal with scale 5, each evaluated per dimension and multiplied over the
# two dimensions.
weights = np.prod(target_dist.pdf(data), axis=1) / np.prod(
    norm(scale=5).pdf(data), axis=1
)

In [None]:
# Reference samples drawn directly from the target distribution
# (10,000 points, two independent dimensions).
true = target_dist.rvs(
    size=(10_000, 2),
)

## Flow

In [None]:
# Settings handed to FlowModel: top-level keys are training options, and
# "model_config" describes the flow itself.
config = {
    "annealing": False,
    "patience": 50,
    "lr": 0.001,
    "batch_size": 1000,
    "model_config": {
        "n_inputs": 2,
        "n_neurons": 16,
        "n_blocks": 2,
        "kwargs": {
            "batch_norm_between_layers": False,
            "batch_norm_within_layers": False,
            # linear_transform="lu",
        },
    },
}

In [None]:
# Construct the nessai flow model from the config dictionary.
# NOTE(review): `output` is presumably the directory nessai writes training
# artefacts to -- confirm against the nessai FlowModel documentation.
fm = FlowModel(config=config, output="outdir/training_w_weights/")

In [None]:
# Train the flow on the training points, passing the per-sample weights.
# The returned history is indexed by "loss" and "val_loss" when it is plotted.
history = fm.train(data, weights=weights)

In [None]:
# Training diagnostics: both loss curves on one labelled axis so the figure
# can be read on its own.
fig, ax = plt.subplots()
ax.plot(history["loss"], label="Training")
ax.plot(history["val_loss"], label="Validation")
ax.set_xlabel("Epoch")  # assumes one history entry per epoch -- TODO confirm
ax.set_ylabel("Loss")
ax.legend()
plt.show()

In [None]:
# Draw 5,000 samples from the trained flow. torch.inference_mode() disables
# gradient tracking, which is not needed when only sampling.
with torch.inference_mode():
    samples = fm.sample(5_000)

In [None]:
# Quick visual check: training points, reference draws from the target, and
# samples from the trained flow on a single labelled axis.
fig, ax = plt.subplots()
ax.scatter(data[:, 0], data[:, 1], label="Training data")
ax.scatter(true[:, 0], true[:, 1], label="Target")
ax.scatter(samples[:, 0], samples[:, 1], label="Flow")
ax.set_xlabel(r"$x_0$")
ax.set_ylabel(r"$x_1$")
ax.legend()
plt.show()

In [None]:
# Evaluation grid of 1000 points on [-20, 20] (both ends included) and the
# analytic one-dimensional densities evaluated on it.
x = np.linspace(-20, 20, num=1000)
data_pdf = norm(loc=0, scale=5).pdf(x)
target_pdf = target_dist.pdf(x)

In [None]:
# Per-dimension labels and colours, shared by both panels.
labels = [r"$x_0$", r"$x_1$"]
colours = ["C0", "C1"]

fig, axs = plt.subplots(2, 1)

# Top panel: analytic one-dimensional densities against histograms of the
# flow samples, one histogram per dimension.
axs[0].plot(x, target_pdf, label="Target", c="k")
axs[0].plot(x, data_pdf, label="Training data", ls="--", c=thesis_colours.pillarbox)
for d, label, colour in zip(samples.T, labels, colours):
    axs[0].hist(
        d,
        32,
        density=True,
        histtype="step",
        ls="-.",
        label=label,
        color=colour,
    )
axs[0].set_xlabel(r"$x_i$")
axs[0].set_ylabel(r"$p(x_i)$")
axs[0].legend()

# Bottom panel: P-P plot built from the target CDF evaluated at the sorted
# flow samples of each dimension.
n_steps = 1000
for d, label, colour in zip(samples.T, labels, colours):
    sorted_data = np.sort(d)
    pp_data = target_dist.cdf(sorted_data)
    # pp_plot is given axs[1], which belongs to `fig`, so its return value is
    # not needed and `fig` keeps referring to the figure created above.
    pp_plot(
        pp_data,
        labels=label,
        ax=axs[1],
        n_steps=n_steps,
        colours=colour,
    )

axs[1].legend()
axs[1].set_xlabel("Theoretical percentiles")
axs[1].set_ylabel("Sample percentiles")

# Save before showing: with some backends the figure is closed by show(),
# so saving afterwards can write an empty image.
save_figure(fig, "flow_weights_plot", "figures")
plt.show()

In [None]:
def get_circle_points(r, levels, n=100):
    """Return points on origin-centred circles for the given probability levels.

    The radius for each level is the quantile of a chi distribution with two
    degrees of freedom and scale ``r``, i.e. the radius of the circle that
    contains a fraction ``level`` of the mass of an isotropic two-dimensional
    Gaussian with standard deviation ``r``.

    Parameters
    ----------
    r : float
        Scale of the chi distribution.
    levels : sequence of float
        Probabilities in [0, 1], one circle per entry.
    n : int, optional
        Number of points per circle. The angles run from 0 to 2*pi inclusive,
        so the first and last points coincide and the curve is closed.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(levels), 2, n)``; ``[i, 0]`` holds the x
        coordinates and ``[i, 1]`` the y coordinates of circle ``i``.
    """
    theta = np.linspace(0, 2 * np.pi, n)
    # The unit circle is the same for every level, so evaluate it once.
    unit_circle = np.stack([np.cos(theta), np.sin(theta)])
    dist = chi(2, scale=r)
    points = np.empty([len(levels), 2, n])
    for i, level in enumerate(levels):
        points[i] = dist.ppf(level) * unit_circle
    return points

In [None]:
# Probability levels 1 - exp(-k^2 / 2) for k = 1, 2, 3: the mass of a
# two-dimensional Gaussian enclosed within k standard deviations.
levels = 1.0 - np.exp(-0.5 * np.square(np.arange(1.0, 3.1, 1.0)))

In [None]:
# Proxy artists for a hand-built legend: one (line style, colour, label)
# triple per entry, each turned into a zero-length Line2D.
legend_elements = [
    Line2D([0, 0], [0, 0], ls=line_style, color=line_colour, label=line_label)
    for line_style, line_colour, line_label in (
        ("--", "grey", "Training data"),
        ("-.", thesis_colours.pillarbox, "Target"),
        ("-", thesis_colours.cobalt, "Flow"),
    )
]

In [None]:
fig, axs = plt.subplots(1, 2)

# Left panel: analytic contours for the training-data distribution (scale 5,
# centred on the origin) and the target (scale 2, shifted by 2 in both
# coordinates), with a KDE of the flow samples on top.
data_points = get_circle_points(5, levels)
for p in data_points:
    axs[0].plot(p[0], p[1], c="grey", ls="--")

target_points = get_circle_points(2, levels) + 2
for p in target_points:
    axs[0].plot(p[0], p[1], c=thesis_colours.pillarbox, ls="-.")

# seaborn draws each contour so that the given fraction of the mass lies
# outside it, hence the complement of the enclosed-mass levels, reversed so
# the values are increasing.
sns.kdeplot(
    x=samples[:, 0],
    y=samples[:, 1],
    levels=(1 - levels)[::-1],
    ax=axs[0],
    color=thesis_colours.cobalt,
    bw_adjust=1.0,
)

axs[0].set_xlabel(r"$x_0$")
axs[0].set_ylabel(r"$x_1$")

axs[0].legend(handles=legend_elements, loc="lower left")

# Right panel: P-P plot built from the target CDF evaluated at the sorted
# flow samples of each dimension.
n_steps = 1000
for d, label, colour in zip(samples.T, labels, colours):
    sorted_data = np.sort(d)
    pp_data = target_dist.cdf(sorted_data)
    # pp_plot is given axs[1], which belongs to `fig`, so its return value is
    # not needed and `fig` keeps referring to the figure created above.
    pp_plot(
        pp_data,
        labels=label,
        ax=axs[1],
        n_steps=n_steps,
        confidence_intervals=[0.995],
        colours=colour,
    )

axs[1].legend()
axs[1].set_xlabel("Theoretical percentiles")
axs[1].set_ylabel("Sample percentiles")

for ax in axs:
    ax.set_aspect("equal", "box")

plt.tight_layout()
# Save before showing: with some backends the figure is closed by show(),
# so saving afterwards can write an empty image.
# NOTE(review): unlike the "flow_weights_plot" figure, no output directory is
# passed here -- confirm that save_figure's default location is intended.
save_figure(fig, "flow_weights_plot_alt")
plt.show()