# Normalizing Flow

In [None]:
import inspect
import math
from typing import TypeAlias

import matplotlib.pyplot as plt
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import seaborn as sns
import torch
import torch.linalg as LA

from distributions import SphericalUniformPrior3D
from models import (
    NormalizingFlowRQSRQS,
    NormalizingFlowRQSMobius,
    NormalizingFlowBSRQS,
    NormalizingFlowBSMobius,
)
from utils import (
    metropolis_acceptance,
    effective_sample_size,
    spherical_mesh,
    simple_fnn_conditioner,
)
from visualisations import scatter, pairplot, heatmap, spherical_mesh

# Short names used throughout the rest of the notebook.
π = math.pi
Tensor: TypeAlias = torch.Tensor

# Apply seaborn's default theme to every matplotlib figure produced below.
sns.set_theme()

## Testing

In [None]:
# Flow architecture exercised by the test cell below. To try a different one,
# assign one of these instead:
#   NormalizingFlowRQSMobius, NormalizingFlowBSRQS, NormalizingFlowBSMobius
ModelClass = NormalizingFlowRQSRQS

# Show the constructor parameters so the model arguments can be filled in.
model_signature = inspect.signature(ModelClass)
print(model_signature)

In [None]:
# Quick smoke test: train a small flow with no logger and no checkpoints,
# run the test loop once, then plot samples pushed through the trained flow.
model = ModelClass(
    κ=10,  # NOTE(review): presumably the target density's concentration — confirm in models
    μ=[1, 0, 0],  # NOTE(review): presumably the target's mean direction — confirm in models
    n_layers=2,
    n_spline=12,
    # n_mobius=8,  # NOTE(review): likely needed for the *Mobius variants — check the printed signature
    hidden_shape=[],
    activation="Identity",
    batch_size=1000,
    val_batch_size=5000,
    init_lr=0.01,
)

trainer = pl.Trainer(
    accelerator="auto",
    max_steps=2000,
    val_check_interval=50,
    limit_val_batches=1,
    limit_test_batches=1,
    num_sanity_val_steps=1,
    logger=False,  # throwaway run: no logger ...
    enable_checkpointing=False,  # ... and no checkpoints
)

trainer.fit(model)

# `trainer.test` returns one metrics dict per test dataloader; the unpacking
# expects exactly one.
(metrics,) = trainer.test(model)

with torch.no_grad():
    # `val_dataloader()` may return an iterable (e.g. a DataLoader) rather than
    # an iterator. `iter()` accepts both, whereas a bare `next()` raises
    # TypeError on a DataLoader.
    z, _ = next(iter(model.val_dataloader()))
    z.squeeze_(dim=1)  # drop dimension 1 in place before feeding the flow
    x, ldj = model(z)

fig = scatter(x, s=2)
fig.suptitle("Data Generated by Normalizing Flow")

## Save a trained model

In [None]:
# Train a flow with TensorBoard logging and checkpointing enabled, so the
# trained model is saved under `tb_logs/` and can be reloaded later.
ModelClass = NormalizingFlowRQSRQS

hparams = dict(
    κ=10,
    μ=[1, -1, 1],
    n_layers=2,
    n_spline=12,
    # n_mobius=8,
    hidden_shape=[],
    activation="Identity",
    batch_size=5000,
    val_batch_size=5000,
    init_lr=0.01,
)
model = ModelClass(**hparams)

# Run directory is named after the architecture,
# e.g. NormalizingFlowRQSRQS -> "rqsrqs".
run_name = ModelClass.__name__.replace("NormalizingFlow", "").lower()
logger = TensorBoardLogger(save_dir="tb_logs", name=run_name)

# Keep a `last.ckpt` in addition to the default checkpoint.
checkpointing = ModelCheckpoint(save_last=True)

trainer = pl.Trainer(
    accelerator="auto",
    max_steps=5000,
    val_check_interval=500,
    limit_val_batches=1,
    limit_test_batches=1,
    num_sanity_val_steps=1,
    logger=logger,
    callbacks=[checkpointing],
)

trainer.fit(model)

# One metrics dict per test dataloader; exactly one is expected.
(metrics,) = trainer.test(model)