# Normalizing Flow

In [None]:
import math
from typing import TypeAlias

import matplotlib.pyplot as plt
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import seaborn as sns
import torch
import torch.linalg as LA

from distributions import SphericalUniformPrior3D
from models import RecursiveFlowS2
from transforms import (
    MobiusMixtureTransform,
    RQSplineTransform,
    RQSplineTransformCircularDomain,
    BSplineTransform,
)
from utils import (
    metropolis_acceptance,
    effective_sample_size,
    spherical_mesh,
    simple_fnn_conditioner,
)
from visualisations import scatter, pairplot, heatmap, spherical_mesh

# Apply seaborn's default theme globally so every figure below shares it.
sns.set_theme()

# Short notebook-wide aliases.
π = math.pi
Tensor: TypeAlias = torch.Tensor

In [None]:
# IPython help syntax: show the signature and docstring of RecursiveFlowS2.
?RecursiveFlowS2

## Testing: train a small flow briefly and plot samples from it

In [None]:
# Quick sanity run: train a small flow for a fixed number of steps with no
# logger and no checkpoints, then draw one validation batch through the flow
# and plot the generated points.
model = RecursiveFlowS2(
    κ=10,
    μ=[1, -1, 1],
    # Alternative transformers, swap in to compare:
    # z_transformer=RQSplineTransform([-1, 1], 12),
    z_transformer=BSplineTransform([-1, 1], 12),
    xy_transformer=RQSplineTransformCircularDomain(12),
    # xy_transformer=MobiusMixtureTransform(8),
    n_layers=2,
    batch_size=1000,
    net_hidden_shape=[],
    init_lr=0.01,
)

trainer = pl.Trainer(
    accelerator="auto",
    max_steps=2000,
    # An integer interval counts training batches; with
    # check_val_every_n_epoch=None it is counted across epoch boundaries,
    # so validation is triggered every 2000 global steps.
    val_check_interval=2000,
    check_val_every_n_epoch=None,
    limit_val_batches=1,
    limit_test_batches=1,
    num_sanity_val_steps=1,
    logger=False,
    enable_checkpointing=False,
)

trainer.fit(model)

# trainer.test returns one metrics dict per test dataloader; a single
# test dataloader is expected here.
(metrics,) = trainer.test(model)

with torch.no_grad():
    # iter() makes this work whether val_dataloader() returns an iterator or a
    # re-iterable loader such as torch.utils.data.DataLoader (calling next()
    # directly on a DataLoader raises TypeError).
    z, _ = next(iter(model.val_dataloader()))
    # NOTE(review): assumes the batch carries a singleton dim 1 that the
    # forward pass does not expect — confirm against RecursiveFlowS2.
    z.squeeze_(dim=1)
    x, ldj = model(z)

fig = scatter(x, s=2)
fig.suptitle("Data Generated by Normalizing Flow")

## Train a model and save checkpoints

In [None]:
# Longer training run with TensorBoard logging (under tb_logs/test) and
# checkpointing; ModelCheckpoint(save_last=True) also keeps a "last" checkpoint.
model = RecursiveFlowS2(
    κ=20,
    μ=[1, 0, .5],
    z_transformer=RQSplineTransform([-1, 1], 12),
    # Alternative transformers, swap in to compare:
    # z_transformer=BSplineTransform([-1, 1], 12),
    xy_transformer=RQSplineTransformCircularDomain(12),
    # xy_transformer=MobiusMixtureTransform(8),
    n_layers=2,
    batch_size=4000,
    net_hidden_shape=[],
    init_lr=0.01,
    softmax_beta=1e12,
)

logger = TensorBoardLogger(save_dir="tb_logs", name="test")
checkpointing = ModelCheckpoint(save_last=True)

trainer = pl.Trainer(
    accelerator="auto",
    max_steps=5000,
    # As in the test cell: count the integer validation interval in global
    # steps across epochs, so validation runs every 500 training steps
    # regardless of how many batches one epoch contains.
    val_check_interval=500,
    check_val_every_n_epoch=None,
    limit_val_batches=1,
    limit_test_batches=1,
    num_sanity_val_steps=1,
    logger=logger,
    callbacks=[checkpointing],
)

trainer.fit(model)

# trainer.test returns one metrics dict per test dataloader; a single
# test dataloader is expected here.
(metrics,) = trainer.test(model)