In [None]:
from glasflow.nflows.transforms import LULinear
from glasflow.nflows.distributions import StandardNormal
from nessai.flows.base import NFlow
import torch
import matplotlib.pyplot as plt
import seaborn as sns

from thesis_utils.plotting import (
    set_plotting,
    save_figure,
    get_default_figsize,
)
from thesis_utils.random import seed_everything

# Project-wide notebook setup from thesis_utils: plotting style first, then
# seeding for reproducibility.
# NOTE(review): seed_everything() is called without an argument, so the seed is
# whatever default the helper uses — presumably it seeds torch (used for all
# sampling below); confirm.
set_plotting()
seed_everything()

In [None]:
# Dimensionality shared by the target distribution, the LU transform and the
# standard-normal base distribution.
dims: int = 2

In [None]:
# Target distribution: a correlated Gaussian with mean 1 in every coordinate,
# variance 1.2 on the diagonal and covariance 0.9 between every pair of
# coordinates. The covariance is built from `dims` (instead of a hard-coded
# 2x2 tensor) so it always matches the mean's dimensionality; for dims == 2 it
# equals [[1.2, 0.9], [0.9, 1.2]] exactly. The matrix 0.3*I + 0.9*J has
# eigenvalues 0.3 and 0.3 + 0.9*dims, so it is positive definite for any dims.
covariance = torch.full((dims, dims), 0.9)
covariance.fill_diagonal_(1.2)
data_dist = torch.distributions.MultivariateNormal(
    torch.ones(dims), covariance_matrix=covariance
)

In [None]:
# Fixed sample from the target distribution; later pushed through the trained
# flow and shown as KDE contours.
n_reference = 50_000
reference_data = data_dist.sample(sample_shape=(n_reference,))

In [None]:
# Single-transform flow: an LU-parameterised linear map (LULinear) on top of a
# standard-normal base distribution with the same dimensionality.
transform, dist = LULinear(dims), StandardNormal((dims,))
flow = NFlow(transform, dist)

In [None]:
# Fit the flow by maximum likelihood: every step draws a fresh batch from the
# target distribution and minimises the mean negative log-likelihood.
optimizer = torch.optim.Adam(flow.parameters(), lr=0.05)
n_epochs = 1000  # number of optimisation steps, one freshly sampled batch each
batch_size = 1000
history = []  # training loss recorded after every step

# Explicitly switch to training mode: the evaluation cell below calls
# `flow.eval()`, so re-running this cell after it would otherwise train a
# module left in evaluation mode. On a fresh kernel this is a no-op.
flow.train()

for e in range(n_epochs):
    data = data_dist.sample((batch_size,))

    optimizer.zero_grad()

    loss = -flow.log_prob(data).mean()
    loss.backward()
    optimizer.step()
    history.append(loss.item())

In [None]:
# Training curve: one point per optimisation step recorded in `history`.
loss_ax = plt.gca()
loss_ax.plot(history)
loss_ax.set(xlabel="Epoch", ylabel="Loss")
plt.show()

In [None]:
# Push the reference samples through the trained flow without tracking
# gradients. Only the first element of the forward output is kept as the
# latent sample `z`; the remaining output(s) are unused here.
flow.eval()
with torch.inference_mode():
    forward_outputs = flow.forward(reference_data)
z = forward_outputs[0]

In [None]:
# Side-by-side KDE comparison: reference samples in data space (left panel)
# and the same samples mapped through the trained flow into latent space
# (right panel).
# NOTE(review): figsize is matplotlib's default here; `get_default_figsize` is
# imported at the top of the notebook but never used — possibly intended here.
fig, axs = plt.subplots(1, 2)
# Square panels so neither density is visually stretched.
plt.setp(axs.flat, aspect=1.0, adjustable="box")

# Used for both the x and y limits of each panel and as the KDE clip range.
xrange = [-4.0, 4.0]

# Shared keyword arguments for both seaborn.kdeplot calls: filled contours,
# "Blues" colormap, 6 levels, thresh=0, density not evaluated outside xrange.
kwargs = dict(
    fill=True,
    cmap="Blues",
    levels=6,
    thresh=0,
    clip=xrange,
)


# Left panel: density of the reference samples in data space.
sns.kdeplot(
    x=reference_data[:, 0],
    y=reference_data[:, 1],
    ax=axs[0],
    **kwargs,
)


# Right panel: density of the same samples after the flow's forward map.
sns.kdeplot(
    x=z[:, 0],
    y=z[:, 1],
    ax=axs[1],
    **kwargs,
)

# Use a contour fill colour as the axes background so areas outside the KDE
# blend in; judging by the name this is meant to be the lowest (zero-density)
# fill colour.
# NOTE(review): picking child index 2 of axs[1] depends on the order in which
# artists were added above and on matplotlib/seaborn internals (e.g. how many
# collections contourf creates) — confirm this still selects the intended
# artist after upgrading either library.
zero_colour = axs[1].get_children()[2].get_facecolor()
for ax in axs:
    ax.set_facecolor(zero_colour)
    ax.set_xlim(xrange)
    ax.set_ylim(xrange)

# Data-space coordinates on the left, latent-space coordinates on the right.
axs[0].set_xlabel(r"$x_0$")
axs[0].set_ylabel(r"$x_1$")
axs[1].set_xlabel(r"$z_0$")
axs[1].set_ylabel(r"$z_1$")

plt.show()

In [None]:
# Write the comparison figure to disk via the project helper.
# NOTE(review): the third argument is presumably the output directory — confirm
# against thesis_utils.plotting.save_figure.
output_name, output_dir = "lu_factorization", "figures"
save_figure(fig, output_name, output_dir)