# Normalizing Flow

In [None]:
import math
from typing import TypeAlias

import matplotlib.pyplot as plt
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import seaborn as sns
import torch
import torch.linalg as LA

from distributions import SphericalUniformPrior3D
from models import NormalizingFlowRQS, NormalizingFlowC2
from transforms import MobiusTransform, RQSplineTransform, RQSplineTransformCircularDomain, C2SplineTransform
from utils import metropolis_acceptance, effective_sample_size, spherical_mesh, simple_fnn_conditioner
from visualisations import scatter, pairplot

sns.set_theme()  # apply seaborn's default theme to matplotlib figures created below

# Short aliases used throughout the notebook.
Tensor: TypeAlias = torch.Tensor
π: float = math.pi

## Rational Quadratic Splines for $z$ and $\phi$ (C1 transform)

In [None]:
@torch.no_grad()
def visualise_rqs(model):
    """Yield one diagnostic figure per layer of a trained RQS flow.

    A uniform grid of ``z`` values on [-1, 1] and ``ϕ`` values on [0, 2π] is
    pushed through the model one layer at a time. Each yielded figure has
    four panels:

    1. the z spline (left axis) and its log-gradient (right axis);
    2. histograms of z before and after the layer;
    3. the ϕ spline and its log-gradient;
    4. histograms of ϕ before and after the layer.

    The outputs of one layer are fed in as the inputs of the next.

    Args:
        model: object exposing ``z_spline_params``, ``ϕ_spline_params``,
            ``z_spline`` and ``ϕ_spline`` (as ``NormalizingFlowRQS`` is used
            in this notebook).

    Yields:
        A matplotlib ``Figure`` for each layer, in order.
    """
    n_points = 10000
    z_in = torch.linspace(-1, 1, n_points).unsqueeze(-1)
    # NOTE(review): this grid contains both 0 and 2π, which coincide on the
    # circle; visualise_c2 stays 1e-5 away from them. Confirm the circular
    # spline treats the 2π endpoint as intended.
    ϕ_in = torch.linspace(0, 2 * π, n_points).unsqueeze(-1)

    per_layer = zip(model.z_spline_params.split(1, dim=0), model.ϕ_spline_params)
    for z_params, ϕ_conditioner in per_layer:
        # z uses the same (unconditioned) parameters for every grid point ...
        z_out, ldj_z = model.z_spline(z_in, z_params.expand(*z_in.shape, -1))
        # ... whereas ϕ's spline parameters are predicted from the transformed z.
        ϕ_params = ϕ_conditioner(z_out).view(*ϕ_in.shape, -1)
        ϕ_out, ldj_ϕ = model.ϕ_spline(ϕ_in, ϕ_params)

        z_before, z_after, z_log_grad = z_in.squeeze(), z_out.squeeze(), ldj_z.squeeze()
        ϕ_before, ϕ_after, ϕ_log_grad = ϕ_in.squeeze(), ϕ_out.squeeze(), ldj_ϕ.squeeze()

        fig, axes = plt.subplots(1, 4, figsize=(16, 4))
        ax_z, ax_z_hist, ax_ϕ, ax_ϕ_hist = axes

        # Panel 1: z spline with its log-gradient on a twin y-axis.
        ax_z_grad = ax_z.twinx()
        ax_z_grad.grid(False)
        (z_curve,) = ax_z.plot(z_before, z_after, color="tab:blue")
        (z_grad_curve,) = ax_z_grad.plot(z_before, z_log_grad, color="tab:orange")
        ax_z.set_xlabel(r"$z_{in}$")
        ax_z.set_ylabel(r"$z_{out}$")
        ax_z.set_title("Spline for z coordinate")
        ax_z_grad.legend(handles=[z_curve, z_grad_curve], labels=["spline", "log gradient"])

        # Panel 2: z distribution before and after the layer.
        for values, tag in ((z_before, r"$z_{in}$"), (z_after, r"$z_{out}$")):
            ax_z_hist.hist(values, bins=20, histtype="step", label=tag)
        ax_z_hist.set_title("Histogram for z coordinate")
        ax_z_hist.legend()

        # Panel 3: ϕ spline with its log-gradient on a twin y-axis.
        ax_ϕ_grad = ax_ϕ.twinx()
        ax_ϕ_grad.grid(False)
        (ϕ_curve,) = ax_ϕ.plot(ϕ_before, ϕ_after, color="tab:blue", label="spline")
        (ϕ_grad_curve,) = ax_ϕ_grad.plot(ϕ_before, ϕ_log_grad, color="tab:orange", label="log gradient")
        ax_ϕ.set_xlabel(r"$\phi_{in}$")
        ax_ϕ.set_ylabel(r"$\phi_{out}$")
        ax_ϕ.set_title(r"Spline for $\phi$ coordinate")
        ax_ϕ_grad.legend(handles=[ϕ_curve, ϕ_grad_curve], labels=["spline", "log gradient"])

        # Panel 4: ϕ distribution before and after the layer.
        for values, tag in ((ϕ_before, r"$\phi_{in}$"), (ϕ_after, r"$\phi_{out}$")):
            ax_ϕ_hist.hist(values, bins=20, histtype="step", label=tag)
        ax_ϕ_hist.set_title(r"Histogram for $\phi$ coordinate")
        ax_ϕ_hist.legend()

        fig.tight_layout()

        yield fig

        # This layer's outputs are the next layer's inputs.
        z_in, ϕ_in = z_out, ϕ_out

### $\kappa \approx 0$ (uniform target)

In [None]:
# Sanity check: fit a 1-layer, 6-segment RQS flow to a near-uniform target (κ ≈ 0).
κ = 0.001
μ = [0, 0, 1]

model = NormalizingFlowRQS(
    κ=κ,
    μ=μ,
    n_layers=1,
    n_spline=6,
    hidden_shape=[],
    activation="Identity",
    batch_size=1000,
    val_batch_size=5000,
    init_lr=0.01,
)

trainer = pl.Trainer(
    accelerator="auto",  # same device selection as every other experiment here
    max_steps=1000,
    val_check_interval=50,
    limit_val_batches=1,
    limit_test_batches=1,
    num_sanity_val_steps=1,
    logger=False,
)

trainer.fit(model)

(metrics,) = trainer.test(model)

# Materialise the generator so a figure is built for every layer.
list(visualise_rqs(model))

### $\kappa = 10$, $\mu = (0, 0, 1)$, one layer of 12 segments

In [None]:
# Fit a 1-layer, 12-segment RQS flow to the κ = 10, μ = (0, 0, 1) target.
κ = 10
μ = [0, 0, 1]

model = NormalizingFlowRQS(
    κ=κ,
    μ=μ,
    n_layers=1,
    n_spline=12,
    hidden_shape=[],
    activation="Identity",
    batch_size=1000,
    val_batch_size=5000,
    init_lr=0.01,
)

trainer = pl.Trainer(
    accelerator="auto",
    max_steps=2000,
    val_check_interval=50,
    limit_val_batches=1,
    limit_test_batches=1,
    num_sanity_val_steps=1,
    logger=False,
)

trainer.fit(model)

(metrics,) = trainer.test(model)

# Materialise the generator so a figure is built for every layer.
list(visualise_rqs(model))

### $\kappa = 10$, $\mu = (0, 0, 1)$, two layers of 6 segments

In [None]:
# Fit a 2-layer, 6-segment RQS flow to the κ = 10, μ = (0, 0, 1) target.
κ = 10
μ = [0, 0, 1]

model = NormalizingFlowRQS(
    κ=κ,
    μ=μ,
    n_layers=2,
    n_spline=6,
    hidden_shape=[],
    activation="Identity",
    batch_size=1000,
    val_batch_size=5000,
    init_lr=0.01,
)

trainer = pl.Trainer(
    accelerator="auto",
    max_steps=2000,
    val_check_interval=50,
    limit_val_batches=1,
    limit_test_batches=1,
    num_sanity_val_steps=1,
    logger=False,
)

trainer.fit(model)

(metrics,) = trainer.test(model)

# Materialise the generator so a figure is built for every layer.
list(visualise_rqs(model))

### $\kappa = 10$, $\mu = (0, 1, 0)$, one layer of 12 segments

In [None]:
# Fit a 1-layer, 12-segment RQS flow to the κ = 10, μ = (0, 1, 0) target.
κ = 10
μ = [0, 1, 0]

model = NormalizingFlowRQS(
    κ=κ,
    μ=μ,
    n_layers=1,
    n_spline=12,
    hidden_shape=[],
    activation="Identity",
    batch_size=1000,
    val_batch_size=5000,
    init_lr=0.01,
)

trainer = pl.Trainer(
    accelerator="auto",
    max_steps=2000,
    val_check_interval=50,
    limit_val_batches=1,
    limit_test_batches=1,
    num_sanity_val_steps=1,
    logger=False,
)

trainer.fit(model)

(metrics,) = trainer.test(model)

# Materialise the generator so a figure is built for every layer.
list(visualise_rqs(model))

### $\kappa = 10$, $\mu = (0, 1, 1)$, one layer of 12 segments

In [None]:
# Fit a 1-layer, 12-segment RQS flow to the κ = 10, μ = (0, 1, 1) target.
κ = 10
μ = [0, 1, 1]

model = NormalizingFlowRQS(
    κ=κ,
    μ=μ,
    n_layers=1,
    n_spline=12,
    hidden_shape=[],
    activation="Identity",
    batch_size=1000,
    val_batch_size=5000,
    init_lr=0.01,
)

trainer = pl.Trainer(
    accelerator="auto",
    max_steps=2000,
    val_check_interval=50,
    limit_val_batches=1,
    limit_test_batches=1,
    num_sanity_val_steps=1,
    logger=False,
)

trainer.fit(model)

(metrics,) = trainer.test(model)

# Materialise the generator so a figure is built for every layer.
list(visualise_rqs(model))

### $\kappa = 10$, $\mu = (0, 1, 1)$, two layers of 6 segments

In [None]:
# Fit a 2-layer, 6-segment RQS flow to the κ = 10, μ = (0, 1, 1) target.
κ = 10
μ = [0, 1, 1]

model = NormalizingFlowRQS(
    κ=κ,
    μ=μ,
    n_layers=2,
    n_spline=6,
    hidden_shape=[],
    activation="Identity",
    batch_size=1000,
    val_batch_size=5000,
    init_lr=0.01,
)

trainer = pl.Trainer(
    accelerator="auto",
    max_steps=2000,
    val_check_interval=50,
    limit_val_batches=1,
    limit_test_batches=1,
    num_sanity_val_steps=1,
    logger=False,
)

trainer.fit(model)

(metrics,) = trainer.test(model)

# Materialise the generator so a figure is built for every layer.
list(visualise_rqs(model))

### Save a trained model for Flow HMC

In [None]:
# Train a 4-layer, 6-segment RQS flow on the κ = 10, μ = (1, -1, 1) target and
# keep a checkpoint for later use with Flow HMC. TensorBoard logs go under
# tb_logs/rq_spline; ModelCheckpoint(save_last=True) also keeps the final state.
κ = 10
μ = [1, -1, 1]

model = NormalizingFlowRQS(
    κ=κ,
    μ=μ,
    n_layers=4,
    n_spline=6,
    hidden_shape=[],
    activation="Identity",
    batch_size=1000,
    val_batch_size=5000,
    init_lr=0.001,
)

logger = TensorBoardLogger(save_dir="tb_logs", name="rq_spline")
checkpointing = ModelCheckpoint(save_last=True)

trainer = pl.Trainer(
    accelerator="auto",
    max_steps=4000,
    val_check_interval=500,
    limit_val_batches=1,
    limit_test_batches=1,
    num_sanity_val_steps=1,
    logger=logger,
    callbacks=[checkpointing],
)

trainer.fit(model)

(metrics,) = trainer.test(model)

# Materialise the generator so a figure is built for every layer.
list(visualise_rqs(model))

## Mobius + C2 Splines (C2 transform)

In [None]:
@torch.no_grad()
def visualise_c2(model):
    """Yield one diagnostic figure per layer of a trained Möbius + C2-spline flow.

    A grid of ``z`` values on [-1, 1] and angles ``ϕ`` just inside (0, 2π) is
    pushed through the model one layer at a time. The angles are represented
    on the unit circle as ``(cos ϕ, sin ϕ)``. Each yielded figure has four
    panels:

    1. the z spline (left axis) and its log-gradient (right axis);
    2. histograms of z before and after the layer;
    3. the Möbius map ϕ_in -> ϕ_out with its log-gradient;
    4. histograms of ϕ before and after the layer.

    Every layer output (z, the circle point *and* its angle) becomes the
    corresponding input of the next layer.

    NOTE(review): the per-layer computation is re-implemented here rather than
    calling the model's forward pass; keep it in sync with NormalizingFlowC2.

    Args:
        model: object exposing ``z_params``, ``xy_params``, ``spline_transform``
            and ``mobius_transform`` (as ``NormalizingFlowC2`` is used in this
            notebook).

    Yields:
        A matplotlib ``Figure`` for each layer, in order.
    """
    # Inputs. ϕ stays 1e-5 away from 0 and 2π, which coincide on the circle.
    z_in = torch.linspace(-1, 1, 10000).unsqueeze(-1)
    ϕ_in = torch.linspace(1e-5, 2 * π - 1e-5, 10000).unsqueeze(-1)
    xy_in = torch.cat([ϕ_in.cos(), ϕ_in.sin()], dim=-1)

    for z_params, xy_params in zip(
        model.z_params.split(1, dim=0),
        model.xy_params,
    ):
        z_out, ldj_z = model.spline_transform(z_in, z_params.expand(*z_in.shape, -1))

        # Möbius parameter ω, predicted from the transformed z. tanh + clamp put
        # each component in [-1 + 1e-3, 1 - 1e-3] ...
        omega_x, omega_y = (
            torch.tanh(xy_params(z_out))
            .clamp(-1 + 1e-3, 1 - 1e-3)
            .view(*xy_in.shape)
            .split(1, dim=-1)
        )
        # ... and scaling ω_y by sqrt(1 - ω_x²) keeps ω strictly inside the unit disc.
        omega_y = omega_y * torch.sqrt(1 - omega_x.pow(2))
        omega = torch.cat([omega_x, omega_y], dim=-1)
        xy_out, ldj_xy = model.mobius_transform(xy_in, omega)

        # Angle of the transformed point, wrapped into [0, 2π). The trailing dim
        # is kept so ϕ_out has the same shape as ϕ_in and can replace it below.
        ϕ_out = torch.fmod(
            torch.atan2(xy_out[..., 1], xy_out[..., 0]) + 2 * π, 2 * π
        ).unsqueeze(-1)

        fig, axes = plt.subplots(1, 4, figsize=(16, 4))
        axes = iter(axes)

        # Panel 1: z spline with its log-gradient on a twin y-axis.
        ax = next(axes)
        ax2 = ax.twinx()
        ax2.grid(False)
        spline, = ax.plot(z_in.squeeze(), z_out.squeeze(), color="tab:blue")
        grad, = ax2.plot(z_in.squeeze(), ldj_z.squeeze(), color="tab:orange")
        ax.set_xlabel(r"$z_{in}$")
        ax.set_ylabel(r"$z_{out}$")
        ax.set_title("Spline for z coordinate")
        ax2.legend(handles=[spline, grad], labels=["spline", "log gradient"])

        # Panel 2: z distribution before and after the layer.
        ax = next(axes)
        ax.hist(z_in.squeeze(), bins=20, histtype="step", label=r"$z_{in}$")
        ax.hist(z_out.squeeze(), bins=20, histtype="step", label=r"$z_{out}$")
        ax.set_title("Histogram for z coordinate")
        ax.legend()

        # Panel 3: Möbius map on the angle, with its log-gradient.
        ax = next(axes)
        ax2 = ax.twinx()
        ax2.grid(False)
        mobius, = ax.plot(ϕ_in.squeeze(), ϕ_out.squeeze(), color="tab:blue", label="spline")
        grad, = ax2.plot(ϕ_in.squeeze(), ldj_xy.squeeze(), color="tab:orange", label="log gradient")
        ax.set_xlabel(r"$\phi_{in}$")
        ax.set_ylabel(r"$\phi_{out}$")
        ax.set_title(r"Mobius transformation for $\phi$ coordinate")
        ax2.legend(handles=[mobius, grad], labels=["mobius", "log gradient"])

        # Panel 4: ϕ distribution before and after the layer.
        ax = next(axes)
        ax.hist(ϕ_in.squeeze(), bins=20, histtype="step", label=r"$\phi_{in}$")
        ax.hist(ϕ_out.squeeze(), bins=20, histtype="step", label=r"$\phi_{out}$")
        ax.set_title(r"Histogram for $\phi$ coordinate")
        ax.legend()

        fig.tight_layout()

        yield fig

        # ϕ_in must track xy_in. Otherwise, from the second layer on, the ϕ
        # panels would plot against the original grid instead of this layer's
        # actual input angles.
        z_in, xy_in, ϕ_in = z_out, xy_out, ϕ_out

### $\kappa \approx 0$ (uniform target)

In [None]:
# Sanity check: fit a 1-layer, 6-segment Möbius + C2 flow to a near-uniform
# target (κ ≈ 0), using smaller train/validation batches (500) than elsewhere.
κ = 0.001
μ = [0, 0, 1]

model = NormalizingFlowC2(
    κ=κ,
    μ=μ,
    n_layers=1,
    n_spline=6,
    hidden_shape=[],
    activation="Identity",
    batch_size=500,
    val_batch_size=500,
    init_lr=0.01,
)

trainer = pl.Trainer(
    accelerator="auto",  # same device selection as every other experiment here
    max_steps=1000,
    val_check_interval=50,
    limit_val_batches=1,
    limit_test_batches=1,
    num_sanity_val_steps=1,
    logger=False,
)

trainer.fit(model)

(metrics,) = trainer.test(model)

# Materialise the generator so a figure is built for every layer.
list(visualise_c2(model))

### $\kappa = 10$, $\mu = (0, 0, 1)$, one layer of 12 segments

In [None]:
# Fit a 1-layer, 12-segment Möbius + C2 flow to the κ = 10, μ = (0, 0, 1) target.
κ = 10
μ = [0, 0, 1]

model = NormalizingFlowC2(
    κ=κ,
    μ=μ,
    n_layers=1,
    n_spline=12,
    hidden_shape=[],
    activation="Identity",
    batch_size=1000,
    val_batch_size=5000,
    init_lr=0.01,
)

trainer = pl.Trainer(
    accelerator="auto",
    max_steps=2000,
    val_check_interval=50,
    limit_val_batches=1,
    limit_test_batches=1,
    num_sanity_val_steps=1,
    logger=False,
)

trainer.fit(model)

(metrics,) = trainer.test(model)

# Materialise the generator so a figure is built for every layer.
list(visualise_c2(model))

### $\kappa = 10$, $\mu = (0, 1, 0)$, one layer of 3 segments

In [None]:
# Fit a 1-layer, 3-segment Möbius + C2 flow to the κ = 10, μ = (0, 1, 0) target.
# Compared with the previous cell: larger batches (4000), a lower learning
# rate (1e-3) and twice as many steps.
κ = 10
μ = [0, 1, 0]

model = NormalizingFlowC2(
    κ=κ,
    μ=μ,
    n_layers=1,
    n_spline=3,
    hidden_shape=[],
    activation="Identity",
    batch_size=4000,
    val_batch_size=5000,
    init_lr=0.001,
)

trainer = pl.Trainer(
    accelerator="auto",
    max_steps=4000,
    val_check_interval=50,
    limit_val_batches=1,
    limit_test_batches=1,
    num_sanity_val_steps=1,
    logger=False,
)

trainer.fit(model)

(metrics,) = trainer.test(model)

# Materialise the generator so a figure is built for every layer.
list(visualise_c2(model))

### Save a trained model for Flow HMC

In [None]:
# Train a 4-layer, 6-segment Möbius + C2 flow on the κ = 10, μ = (1, -1, 1)
# target and keep a checkpoint for later use with Flow HMC. TensorBoard logs go
# under tb_logs/c2_spline; ModelCheckpoint(save_last=True) also keeps the final state.
κ = 10
μ = [1, -1, 1]

model = NormalizingFlowC2(
    κ=κ,
    μ=μ,
    n_layers=4,
    n_spline=6,
    hidden_shape=[],
    activation="Identity",
    batch_size=1000,
    val_batch_size=5000,
    init_lr=0.001,
)

logger = TensorBoardLogger(save_dir="tb_logs", name="c2_spline")
checkpointing = ModelCheckpoint(save_last=True)

trainer = pl.Trainer(
    accelerator="auto",
    max_steps=4000,
    val_check_interval=500,
    limit_val_batches=1,
    limit_test_batches=1,
    num_sanity_val_steps=1,
    logger=logger,
    callbacks=[checkpointing],
)

trainer.fit(model)

(metrics,) = trainer.test(model)

# Materialise the generator so a figure is built for every layer.
list(visualise_c2(model))