# Flow HMC

In [None]:
import math
from typing import TypeAlias, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import seaborn as sns
import torch
import torch.linalg as LA

from distributions import SphericalUniformPrior3D
from models import NormalizingFlowRQS, NormalizingFlowC2, DummyNormalizingFlow
from transforms import MobiusTransform, RQSplineTransform, RQSplineTransformCircularDomain, C2SplineTransform
from utils import metropolis_acceptance, effective_sample_size, spherical_mesh, simple_fnn_conditioner, batched_dot, batched_outer, batched_mv
from visualisations import scatter, pairplot

# Circle constant, used for the azimuthal-angle computations in the plots.
π = math.pi

# Shorthand used in type annotations throughout the notebook.
Tensor: TypeAlias = torch.Tensor

# Apply seaborn's default theme to subsequent matplotlib figures.
sns.set_theme()

In [None]:
def debug_plots(
    z: list,
    x: list,
    p: list,
    F: list,
    H: list,
    S: list,
    ldj: list,
    *,
    n_traj: int,
    n_steps: int,
    ε: float,
    κ: float,
    μ: Tensor,
    lines: bool,
):
    """Plot the quantities recorded along a sequence of HMC trajectories.

    Each positional argument is a list of per-step tensors; the lists are
    stacked along a new axis (dim 1) that plays the role of integration time.
    Six vertically stacked panels sharing the time axis are drawn: coordinates
    ``z`` and ``x``, their azimuthal angles, momenta ``p``, the action ``S``
    and Hamiltonian ``H``, the force ``F``, and the log-det-Jacobian ``ldj``
    together with ``F - κ μ``.

    NOTE(review): every series is passed through ``.squeeze()`` before being
    plotted against a 1-D time array, so this presumably expects a batch of
    size one -- confirm against callers.

    Args:
        z, x, p, F: Lists of tensors whose last axis holds the three vector
            components.
        H, S, ldj: Lists of tensors with one scalar per batch element.
        n_traj: Number of trajectories; used to draw the vertical separators.
        n_steps: Leapfrog steps per trajectory.
        ε: Leapfrog step size, used to convert step index to time.
        κ: Scale multiplying ``μ`` in ``F - κ μ``.
        μ: Vector subtracted (times ``κ``) from the force.
        lines: If true draw solid/dotted lines without markers, otherwise
            draw markers ("+" / "x") without lines.

    Returns:
        The matplotlib figure holding the six panels.
    """
    # Stack the per-step snapshots along a new time axis (dim 1).
    z = torch.stack(z, dim=1)
    x = torch.stack(x, dim=1)
    p = torch.stack(p, dim=1)
    F = torch.stack(F, dim=1)
    H = torch.stack(H, dim=1)
    S = torch.stack(S, dim=1)
    ldj = torch.stack(ldj, dim=1)

    # Euclidean norms over the component axis.
    modz = LA.vector_norm(z, dim=-1)
    modx = LA.vector_norm(x, dim=-1)
    modp = LA.vector_norm(p, dim=-1)
    modF = LA.vector_norm(F, dim=-1)

    # Azimuthal angles from the first two components, shifted from atan2's
    # (-π, π] range into [0, 2π].
    phi_z = torch.fmod(torch.atan2(z[..., 1], z[..., 0]) + 2 * π, 2 * π)
    phi_x = torch.fmod(torch.atan2(x[..., 1], x[..., 0]) + 2 * π, 2 * π)

    # NOTE(review): plotted under the label "∂/∂z_i log|∂x/∂z|"; this equals
    # the gradient of ldj only if the remaining part of the force is exactly
    # κ μ (presumably the case x = z) -- confirm this is the intended use.
    grad_ldj = F - κ * μ

    # One time stamp per recorded snapshot, spaced by the step size.
    nt = z.shape[1] - 1
    t = np.linspace(0, nt * ε, nt + 1)

    fig, axes = plt.subplots(6, 1, sharex=True, figsize=(16, 24), gridspec_kw=dict(hspace=0.1))

    # Dashed vertical lines at the boundaries between consecutive trajectories.
    for ax in axes:
        for traj_num in range(1, n_traj):
            ax.axvline(traj_num * n_steps * ε, linestyle="--", color="grey", alpha=0.5)

    if lines:
        ls, ls2 = "-", ":"
        m, m2 = "", ""
    else:
        ls, ls2 = "", ""
        m, m2 = "+", "x"

    axes = iter(axes)

    # Panel 1: components and norm of z, overlaid with those of x.
    ax = next(axes)
    ax.set_ylabel(r"$z \mid x$")
    for i, zi in zip(range(1, 4), z.split(1, dim=-1)):
        ax.plot(t, zi.squeeze(), marker=m, linestyle=ls, label=f"$z_{i}$")
    ax.plot(t, modz.squeeze(), marker=m, linestyle=ls, label=r"$|\mathbf{z}|$")
    # Restart the colour cycle so each x series shares the colour of its z series.
    ax.set_prop_cycle(None)
    for i, xi in zip(range(1, 4), x.split(1, dim=-1)):
        ax.plot(t, xi.squeeze(), marker=m2, linestyle=ls2, label=f"$x_{i}$")
    ax.plot(t, modx.squeeze(), marker=m2, linestyle=ls2, label=r"$|\mathbf{x}|$")
    ax.legend(ncols=2)

    # Panel 2: azimuthal angles, with a reference line at 2π.
    ax = next(axes)
    ax.set_ylabel(r"$\phi_z \mid \phi_x$")
    ax.plot(t, phi_z.squeeze(), marker=m, linestyle=ls, label=r"$\phi_z$")
    ax.plot(t, phi_x.squeeze(), marker=m2, linestyle=ls, label=r"$\phi_x$")
    ax.set_ylim(-0.5, 2 * π + 0.5)
    ax.axhline(2 * π, linestyle=":")
    ax.annotate(r"$2 \pi$", xy=(0, 2 * π), xycoords="data")
    ax.legend()

    # Panel 3: momentum components and norm.
    ax = next(axes)
    ax.set_ylabel("$p$")
    for i, p_i in zip(range(1, 4), p.split(1, dim=-1)):
        ax.plot(t, p_i.squeeze(), marker=m, linestyle=ls, label=f"$p_{i}$")
    ax.plot(t, modp.squeeze(), marker=m, linestyle=ls, label=r"$|\mathbf{p}|$")
    ax.legend()

    # Panel 4: action and Hamiltonian.
    ax = next(axes)
    ax.set_ylabel(r"$S \mid H$")
    ax.plot(t, S.squeeze(), marker=m, linestyle=ls, label="$S$")
    ax.plot(t, H.squeeze(), marker=m2, linestyle=ls, label="$H$")
    ax.legend()

    # Panel 5: force components and norm.
    ax = next(axes)
    ax.set_ylabel("$F$")
    for i, Fi in zip(range(1, 4), F.split(1, dim=-1)):
        ax.plot(t, Fi.squeeze(), marker=m, linestyle=ls, label=f"$F_{i}$")
    ax.plot(t, modF.squeeze(), marker=m, linestyle=ls, label=r"$|\mathbf{F}|$")
    ax.legend()

    # Panel 6: log-det-Jacobian and the components of F - κ μ.
    ax = next(axes)
    label = r"$\log \vert \partial \mathbf{x} / \partial \mathbf{z} \vert$"
    ax.set_ylabel(label)
    ax.plot(t, ldj.squeeze(), marker=m, linestyle=ls, label=label)
    for i, gradi in zip(range(1, 4), grad_ldj.split(1, dim=-1)):
        ax.plot(t, gradi.squeeze(), marker=m2, linestyle=ls, label=rf"$\partial / \partial z_{i}$" + label)
    ax.legend()

    return fig

In [None]:
@torch.no_grad()
def flow_hmc(
    trained_model: pl.LightningModule,
    batch_size: int,
    n_traj: int,
    traj_length: int,
    n_steps: int,
    κ: Optional[float] = None,
    μ: Optional[Tensor] = None,
    debug: bool = False,
    lines: bool = True,
):
    """Run Hamiltonian Monte Carlo on the unit sphere in a flow's input space.

    The chain state ``z`` is a batch of 3-vectors. With
    ``x, log_dxdz = trained_model(z)`` the sampled action is
    ``S(z) = -κ x·μ - log_dxdz``. Each trajectory draws a fresh momentum
    projected orthogonal to ``z``, integrates ``n_steps`` leapfrog steps
    (using a rotation for the position update so the norm of ``z`` is
    preserved), and ends with a Metropolis accept/reject test. The acceptance
    rate is printed at the end.

    Args:
        trained_model: Callable returning ``(x, log_dxdz)`` for a
            ``(batch_size, 3)`` input. Its ``κ`` / ``μ`` attributes are used
            when the corresponding argument is ``None``.
        batch_size: Number of independent chains run in parallel.
        n_traj: Number of trajectories (Metropolis proposals) per chain.
        traj_length: Integration time of one trajectory.
        n_steps: Leapfrog steps per trajectory; the step size is
            ``traj_length / n_steps``.
        κ: Coefficient of ``x·μ`` in the action; defaults to
            ``trained_model.κ``.
        μ: Direction vector, given as a tensor or a sequence of numbers.
            It is normalised to unit length when provided; the caller's
            object is left untouched. Defaults to ``trained_model.μ``
            (used as stored).
        debug: If true, record per-step diagnostics and pass them to
            ``debug_plots``.
        lines: Forwarded to ``debug_plots``.

    Returns:
        Tensor of shape ``(batch_size, n_traj, 3)`` containing the flow output
        ``x`` for the chain state after each trajectory.
    """
    flow = trained_model

    κ = κ if κ is not None else trained_model.κ
    if μ is not None:
        μ = μ if isinstance(μ, torch.Tensor) else torch.tensor(μ, dtype=torch.float32)
        # Out-of-place on purpose: an in-place division would silently
        # rescale a tensor owned by the caller.
        μ = μ / LA.vector_norm(μ)
    else:
        μ = trained_model.μ

    ε = traj_length / n_steps

    # Initial state randomly distributed on the sphere
    z0, _ = next(SphericalUniformPrior3D(batch_size))
    z0.squeeze_(dim=1)
    assert z0.shape == torch.Size([batch_size, 3])

    outputs = torch.empty(n_traj, batch_size, 3)

    # Loop-invariant identity used to build the tangent-space projector.
    eye = torch.eye(3, dtype=z0.dtype, device=z0.device)

    def get_force(z: Tensor) -> Tensor:
        """Return ``F = -∂S/∂z`` for every element of the batch."""
        # Differentiate with respect to a detached leaf so the chain state's
        # autograd flags are never toggled.
        z_leaf = z.detach().requires_grad_(True)
        with torch.enable_grad():
            x, log_dxdz = flow(z_leaf)
            S = -κ * torch.mv(x, μ) - log_dxdz
            assert S.shape == torch.Size([batch_size])
            # autograd.grad only computes the gradient for ``z_leaf``;
            # ``S.backward()`` would additionally accumulate ``.grad`` on
            # every model parameter at each leapfrog step.
            (grad_S,) = torch.autograd.grad(S, z_leaf, grad_outputs=torch.ones_like(S))
        return grad_S.negative()

    n_accepted = 0

    # Quantities to track for debugging
    H_history = []
    F_history = []
    z_history = []
    p_history = []
    x_history = []
    S_history = []
    ldj_history = []

    for i in range(n_traj):

        z = z0.clone()

        # Initial momenta: Gaussian draw projected orthogonal to z
        p = torch.empty_like(z).normal_()
        M = eye - batched_outer(z, z)
        p = batched_mv(M, p)

        x, log_dxdz = flow(z)
        S0 = -κ * torch.mv(x, μ) - log_dxdz
        H0 = 0.5 * (p ** 2).sum(dim=1) + S0

        # Begin leapfrog with a half-step momentum kick
        F = get_force(z)
        p += 0.5 * ε * batched_mv(M, F)

        if i == 0 and debug:
            H_history.append(H0)
            F_history.append(F)
            z_history.append(z)
            p_history.append(p)
            x_history.append(x)
            S_history.append(S0)
            ldj_history.append(log_dxdz)

        for t in range(n_steps):

            # Non-trivial coordinate update to preserve unit norm: rotate
            # (z, p) by the angle ε|p| in the plane they span.
            mod_p = LA.vector_norm(p, dim=-1, keepdim=True)
            cos_εp = torch.cos(ε * mod_p)
            sin_εp = torch.sin(ε * mod_p)
            z_tmp = cos_εp * z + (1 / mod_p) * sin_εp * p
            p = -mod_p * sin_εp * z + cos_εp * p
            z = z_tmp

            F = get_force(z)
            M = eye - batched_outer(z, z)
            # Full kick between position updates, half kick to close the
            # trajectory.
            if t == n_steps - 1:
                p += 0.5 * ε * batched_mv(M, F)
            else:
                p += ε * batched_mv(M, F)

            if debug:
                x, log_dxdz = flow(z)
                S = -κ * torch.mv(x, μ) - log_dxdz
                H = 0.5 * (p ** 2).sum(dim=1) + S
                H_history.append(H)
                F_history.append(F)
                z_history.append(z)
                p_history.append(p)
                x_history.append(x)
                S_history.append(S)
                ldj_history.append(log_dxdz)

        x, log_dxdz = flow(z)
        HT = 0.5 * (p ** 2).sum(dim=1) - κ * torch.mv(x, μ) - log_dxdz

        # Metropolis test: accept with probability min(1, exp(H0 - HT))
        accepted = (H0 - HT).clamp(max=0).exp() > torch.rand_like(H0)
        n_accepted += accepted.sum()

        # Rejected chains keep their previous state
        z0[accepted] = z[accepted]

        x, _ = flow(z0)
        outputs[i] = x

    acceptance_rate = n_accepted / (batch_size * n_traj)
    print("acceptance: ", acceptance_rate)

    if debug:
        _ = debug_plots(
            z_history,
            x_history,
            p_history,
            F_history,
            H_history,
            S_history,
            ldj_history,
            n_traj=n_traj,
            n_steps=n_steps,
            ε=ε,
            κ=κ,
            μ=μ,
            lines=lines,
        )

    return outputs.transpose(0, 1)  # (batch_size, n_traj, 3)

## Test the algorithm with a dummy model

### Uniform target

In [None]:
# Stand-in flow used to exercise the sampler.
dummy_model = DummyNormalizingFlow()

# One chain, eight trajectories of 50 leapfrog steps each, with the
# diagnostic figure enabled.
x_from_flow_hmc = flow_hmc(
    dummy_model,
    batch_size=1,
    n_traj=8,
    traj_length=1,
    n_steps=50,
    κ=0.001,
    μ=[0, 0, 1],
    debug=True,
    lines=True,
)

plt.show()

In [None]:
# Reference sample: push 5000 prior draws straight through the model.
with torch.no_grad():
    z, _ = next(SphericalUniformPrior3D(5000))
    z.squeeze_(dim=1)
    x_from_model, _ = dummy_model(z)

# Ten chains of 500 short trajectories each; result is (batch_size, n_traj, 3).
x_from_flow_hmc = flow_hmc(
    dummy_model,
    batch_size=10,
    n_traj=500,
    traj_length=1,
    n_steps=4,
    κ=0.001,
    μ=[0, 0, 1],
    debug=False,
)
# Optional: merge the chain and trajectory axes before plotting.
# x_from_flow_hmc = x_from_flow_hmc.flatten(start_dim=0, end_dim=1)

# Compare the two samples in separate figures.
fig1 = scatter(x_from_model, s=2)
fig1.suptitle("Data Generated by Normalizing Flow")

fig2 = scatter(x_from_flow_hmc, s=2)
fig2.suptitle("Data Generated by Flow HMC")

plt.show()

### Concentrated target

In [None]:
# Fresh stand-in flow for the larger-κ run.
dummy_model = DummyNormalizingFlow()

# One chain, eight trajectories of 50 leapfrog steps each, with the
# diagnostic figure enabled; μ is normalised inside flow_hmc.
x_from_flow_hmc = flow_hmc(
    dummy_model,
    batch_size=1,
    n_traj=8,
    traj_length=1,
    n_steps=50,
    κ=10,
    μ=[1, -1, 1],
    debug=True,
    lines=True,
)

plt.show()

In [None]:
# Ten chains of 500 short trajectories each; result is (batch_size, n_traj, 3).
x_from_flow_hmc = flow_hmc(
    dummy_model,
    batch_size=10,
    n_traj=500,
    traj_length=1,
    n_steps=4,
    κ=10,
    μ=[1, -1, 1],
    debug=False,
)
# Optional: merge the chain and trajectory axes before plotting.
# x_from_flow_hmc = x_from_flow_hmc.flatten(start_dim=0, end_dim=1)

# NOTE(review): x_from_model is not recomputed in this cell, so it is whatever
# an earlier cell left in the namespace -- confirm that is the intended
# comparison.
fig1 = scatter(x_from_model, s=2)
fig1.suptitle("Data Generated by Normalizing Flow")

fig2 = scatter(x_from_flow_hmc, s=2)
fig2.suptitle("Data Generated by Flow HMC")

plt.show()

## Rational Quadratic Splines

In [None]:
# Load the NormalizingFlowRQS checkpoint and show its hyperparameters.
ckpt_path = "tb_logs/rq_spline/version_1/checkpoints/last.ckpt"
model = NormalizingFlowRQS.load_from_checkpoint(ckpt_path)
print(model.hparams)

# Logging is disabled; only a single test batch is evaluated below.
trainer = pl.Trainer(
    logger=False,
    accelerator="auto",
    max_steps=4000,
    val_check_interval=500,
    num_sanity_val_steps=1,
    limit_val_batches=1,
    limit_test_batches=1,
)

# Unpack the single entry of the returned sequence.
[metrics] = trainer.test(model)

In [None]:
# One chain, eight trajectories of 100 leapfrog steps each, with the
# diagnostic figure enabled; κ and μ fall back to the model's own attributes.
x_from_flow_hmc = flow_hmc(
    model,
    batch_size=1,
    n_traj=8,
    traj_length=1,
    n_steps=100,
    debug=True,
)

plt.show()

In [None]:
# Reference sample: push 5000 prior draws straight through the model.
with torch.no_grad():
    z, _ = next(SphericalUniformPrior3D(5000))
    z.squeeze_(dim=1)
    x_from_model, _ = model(z)

# Ten chains of 500 trajectories each; result is (batch_size, n_traj, 3).
x_from_flow_hmc = flow_hmc(
    model,
    batch_size=10,
    n_traj=500,
    traj_length=1,
    n_steps=10,
    debug=False,
)
# Optional: merge the chain and trajectory axes before plotting.
# x_from_flow_hmc = x_from_flow_hmc.flatten(start_dim=0, end_dim=1)

# Compare the two samples in separate figures.
fig1 = scatter(x_from_model, s=2)
fig1.suptitle("Data Generated by Normalizing Flow")

fig2 = scatter(x_from_flow_hmc, s=2)
fig2.suptitle("Data Generated by Flow HMC")

plt.show()

## C2 Splines

In [None]:
# Load the NormalizingFlowC2 checkpoint and show its hyperparameters.
ckpt_path = "tb_logs/c2_spline/version_1/checkpoints/last.ckpt"
model = NormalizingFlowC2.load_from_checkpoint(ckpt_path)
print(model.hparams)

# Logging is disabled; only a single test batch is evaluated below.
trainer = pl.Trainer(
    logger=False,
    accelerator="auto",
    max_steps=4000,
    val_check_interval=500,
    num_sanity_val_steps=1,
    limit_val_batches=1,
    limit_test_batches=1,
)

# Unpack the single entry of the returned sequence.
[metrics] = trainer.test(model)

In [None]:
# One chain, eight trajectories of 100 leapfrog steps each, with the
# diagnostic figure enabled; κ and μ fall back to the model's own attributes.
x_from_flow_hmc = flow_hmc(
    model,
    batch_size=1,
    n_traj=8,
    traj_length=1,
    n_steps=100,
    debug=True,
)

plt.show()

In [None]:
# Reference sample: push 5000 prior draws straight through the model.
with torch.no_grad():
    z, _ = next(SphericalUniformPrior3D(5000))
    z.squeeze_(dim=1)
    x_from_model, _ = model(z)

# Ten chains of 500 trajectories each; result is (batch_size, n_traj, 3).
x_from_flow_hmc = flow_hmc(
    model,
    batch_size=10,
    n_traj=500,
    traj_length=1,
    n_steps=10,
    debug=False,
)
# Optional: merge the chain and trajectory axes before plotting.
# x_from_flow_hmc = x_from_flow_hmc.flatten(start_dim=0, end_dim=1)

# Compare the two samples in separate figures.
fig1 = scatter(x_from_model, s=2)
fig1.suptitle("Data Generated by Normalizing Flow")

fig2 = scatter(x_from_flow_hmc, s=2)
fig2.suptitle("Data Generated by Flow HMC")

plt.show()