# Flow HMC

In [None]:
import math
from typing import TypeAlias, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import seaborn as sns
import torch
import torch.linalg as LA

from distributions import SphericalUniformPrior3D
from models import RecursiveFlowS2, DummyNormalizingFlow
from utils import batched_dot, batched_outer, batched_mv
from visualisations import line, line3d, scatter, scatter3d, pairplot, spherical_mesh

# Shorthand for the tensor type used in signatures throughout the notebook.
Tensor: TypeAlias = torch.Tensor

# Unicode alias so angle arithmetic below reads like the written maths.
π = math.pi

# Seaborn look for every matplotlib figure created afterwards (these are the
# set_theme defaults, spelled out).
sns.set_theme(context="notebook", style="darkgrid", palette="deep")

In [None]:
def debug_plots(
    z: list,
    x: list,
    p: list,
    F: list,
    H: list,
    S: list,
    ldj: list,
    *,
    n_traj: int,
    n_steps: int,
    ε: float,
    κ: float,
    μ: Tensor,
    lines: bool,
):
    r"""Plot the per-step history of a single HMC chain recorded by ``flow_hmc``.

    Every history argument is a list with one tensor per recorded leapfrog
    step. The lists are stacked along ``dim=1``, so the step index becomes the
    time axis (spacing ``ε``).

    Parameters
    ----------
    z, x, p, F : list of Tensor
        Latent coordinates, flowed coordinates, momenta and forces at each
        recorded step. Components 1-3 of the last dimension are plotted.
    H, S, ldj : list of Tensor
        Hamiltonian, effective action and log-Jacobian
        ``log|∂x/∂z|`` at each recorded step.
    n_traj : int
        Number of trajectories contained in the history. A dashed vertical
        line is drawn at each boundary between consecutive trajectories.
    n_steps : int
        Leapfrog steps per trajectory.
    ε : float
        Leapfrog step size.
    κ, μ :
        Concentration and mean direction of the von Mises-Fisher term.
    lines : bool
        If True, draw the series as lines. Otherwise draw markers only.

    Returns
    -------
    tuple
        ``(fig, fig2, fig3)``: the 8-panel time-series figure, plus the
        figures returned by ``line`` and ``line3d`` for the ``z`` path.
    """
    z = torch.stack(z, dim=1)
    x = torch.stack(x, dim=1)
    p = torch.stack(p, dim=1)
    F = torch.stack(F, dim=1)
    H = torch.stack(H, dim=1)
    S = torch.stack(S, dim=1)
    ldj = torch.stack(ldj, dim=1)

    modz = LA.vector_norm(z, dim=-1)
    modx = LA.vector_norm(x, dim=-1)
    modp = LA.vector_norm(p, dim=-1)
    modF = LA.vector_norm(F, dim=-1)

    # atan2 returns angles in [-π, π]. Shift by 2π and wrap so the plotted
    # azimuth lies in [0, 2π).
    phi_z = torch.fmod(torch.atan2(z[..., 1], z[..., 0]) + 2 * π, 2 * π)
    phi_x = torch.fmod(torch.atan2(x[..., 1], x[..., 0]) + 2 * π, 2 * π)

    # flow_hmc computes F = -∂S/∂z with S = -κ x·μ - ldj, so
    # F = κ (∂x/∂z)ᵀ μ + ∂ldj/∂z. Subtracting κμ isolates ∂ldj/∂z only when
    # ∂x/∂z is the identity.
    # NOTE(review): for a non-trivial flow this panel is not the pure
    # log-Jacobian gradient. Confirm before interpreting it.
    grad_ldj = F - κ * μ

    # One time point per recorded step, spaced by ε.
    nt = z.shape[1] - 1
    t = np.linspace(0, nt * ε, nt + 1)

    fig, axes = plt.subplots(
        8, 1, sharex=True, figsize=(16, 24), gridspec_kw=dict(hspace=0.1)
    )
    # Mark the boundary between consecutive trajectories on every panel.
    for ax in axes:
        for traj_num in range(1, n_traj):
            ax.axvline(traj_num * n_steps * ε, linestyle="--", color="grey", alpha=0.5)

    if lines:
        ls, ls2 = "--", ":"
        m, m2 = "", ""
    else:
        ls, ls2 = "", ""
        m, m2 = "+", "x"

    axes = iter(axes)

    ax = next(axes)
    ax.set_ylabel("$z$")
    for i, zi in zip(range(1, 4), z.split(1, dim=-1)):
        ax.plot(t, zi.squeeze(), marker=m, linestyle=ls, label=f"$z_{i}$")
    ax.plot(t, modz.squeeze(), marker=m, linestyle=ls, label=r"$|\mathbf{z}|$")
    ax.legend()

    ax = next(axes)
    ax.set_ylabel("$x$")
    for i, xi in zip(range(1, 4), x.split(1, dim=-1)):
        ax.plot(t, xi.squeeze(), marker=m, linestyle=ls, label=f"$x_{i}$")
    ax.plot(t, modx.squeeze(), marker=m, linestyle=ls, label=r"$|\mathbf{x}|$")
    ax.legend()

    ax = next(axes)
    ax.set_ylabel(r"$\phi_z \mid \phi_x$")
    ax.plot(t, phi_z.squeeze(), marker=m, linestyle=ls, label=r"$\phi_z$")
    # Secondary marker/line style so φ_x stays distinguishable from φ_z.
    ax.plot(t, phi_x.squeeze(), marker=m2, linestyle=ls2, label=r"$\phi_x$")
    ax.set_ylim(-0.5, 2 * π + 0.5)
    ax.axhline(2 * π, linestyle=":")
    ax.annotate(r"$2 \pi$", xy=(0, 2 * π), xycoords="data")
    ax.legend()

    ax = next(axes)
    ax.set_ylabel("$p$")
    for i, pi in zip(range(1, 4), p.split(1, dim=-1)):
        ax.plot(t, pi.squeeze(), marker=m, linestyle=ls, label=f"$p_{i}$")
    ax.plot(t, modp.squeeze(), marker=m, linestyle=ls, label=r"$|\mathbf{p}|$")
    ax.legend()

    ax = next(axes)
    ax.set_ylabel("$F$")
    for i, Fi in zip(range(1, 4), F.split(1, dim=-1)):
        ax.plot(t, Fi.squeeze(), marker=m, linestyle=ls, label=f"$F_{i}$")
    ax.plot(t, modF.squeeze(), marker=m, linestyle=ls, label=r"$|\mathbf{F}|$")
    ax.legend()

    ax = next(axes)
    ax.set_ylabel(r"$\log \vert \partial \mathbf{x} / \partial \mathbf{z} \vert$")
    label = r"$\log \vert \partial \mathbf{x} / \partial \mathbf{z} \vert$"
    for i, gradi in zip(range(1, 4), grad_ldj.split(1, dim=-1)):
        ax.plot(
            t,
            gradi.squeeze(),
            marker=m2,
            linestyle=ls,
            # Raw f-string: "\p" is not a valid escape in a normal string.
            label=rf"$\partial / \partial z_{i}$" + label,
        )
    ax.plot(t, ldj.squeeze(), marker=m, linestyle=ls, label=label)
    ax.legend()

    ax = next(axes)
    ax.set_ylabel("$S$")
    # S already includes -ldj, so S + ldj recovers the bare vMF term.
    ax.plot(t, (S + ldj).squeeze(), marker=m, linestyle=ls, label="$S_{vMF}$")
    ax.plot(
        t,
        (-ldj).squeeze(),
        marker=m,
        linestyle=ls,
        label=r"$-\log \vert \partial \mathbf{x} / \partial \mathbf{z} \vert$",
    )
    ax.plot(t, S.squeeze(), marker=m, linestyle=ls, label=r"$\tilde{S}$")
    ax.legend()

    ax = next(axes)
    ax.set_ylabel("$H$")
    ax.plot(t, H.squeeze(), marker=m, linestyle=ls, label="$H$")
    ax.legend()

    fig2 = line(z.squeeze(), ls="", marker=".", markersize=1)
    fig3 = line3d(z.squeeze(), ls="-", lw=0.7)

    return fig, fig2, fig3

In [None]:
def visualise_forces(
    model: RecursiveFlowS2,
    inputs: Optional[Tensor] = None,
    projection="aitoff",
    bins: int = 50,
):
    r"""Plot the action and force field of ``model`` on a grid over the sphere.

    A ``bins x bins`` grid is built from the polar angle θ = arccos(1 - 2u),
    with u evenly spaced on [0, 1] (so cos θ is evenly spaced), and the
    azimuth ϕ evenly spaced on [0, 2π]. At each grid point autograd gives the
    force F = -∂S_eff/∂z, where S_vMF = -κ x·μ and S_eff = S_vMF - log|∂x/∂z|.

    NOTE(review): a later cell redefines ``visualise_forces`` with signature
    ``(model, bins, projection)``. When the notebook runs top to bottom, that
    definition replaces this one.

    Parameters
    ----------
    model : RecursiveFlowS2
        Flow called as ``model(z) -> (x, log|∂x/∂z|)``. Must expose ``κ`` and
        ``μ``.
    inputs : Tensor, optional
        Ignored; the grid is always rebuilt from ``bins``. Kept only so
        existing calls that pass it keep working.
    projection : str
        Matplotlib geographic projection: "aitoff", "hammer", "lambert" or
        "mollweide".
    bins : int
        Grid resolution along each angle. Previously this was referenced
        without being defined, which raised NameError.

    Returns
    -------
    tuple
        ``(fig, fig2)``:
        - ``fig``: a 5x2 grid of heatmaps (S_vMF, S_eff, each force component
          and |F|, both linear and log10).
        - ``fig2``: a scatter of the flowed grid points coloured by S_vMF.
    """
    # cos θ evenly spaced in [-1, 1]; ϕ evenly spaced in [0, 2π].
    θ = torch.acos(1 - 2 * torch.linspace(0, 1, bins))
    ϕ = torch.linspace(0, 2 * π, bins)
    # "ij" is torch's historical default. Passing it explicitly silences the
    # UserWarning and keeps θ varying along rows and ϕ along columns.
    θ, ϕ = torch.meshgrid(θ, ϕ, indexing="ij")
    x = θ.sin() * ϕ.cos()
    y = θ.sin() * ϕ.sin()
    z = θ.cos()
    inputs = torch.stack([x, y, z], dim=-1).view(-1, 3)

    # F = -∂S_eff/∂z via autograd. enable_grad makes this work even when
    # called from inside a no_grad context.
    inputs.requires_grad_(True)
    inputs.grad = None
    with torch.enable_grad():
        outputs, log_dxdz = model(inputs)
        S_vMF = -model.κ * torch.mv(outputs, model.μ)
        S_eff = S_vMF - log_dxdz
        S_eff.backward(gradient=torch.ones_like(S_eff))
    F = inputs.grad.negative()
    inputs.requires_grad_(False)
    inputs.grad = None

    outputs = outputs.view(bins, bins, 3).detach()
    F = F.view(bins, bins, 3).detach()
    S_vMF = S_vMF.view(bins, bins).detach()
    S_eff = S_eff.view(bins, bins).detach()

    Fx, Fy, Fz = [Fi.squeeze() for Fi in F.split(1, dim=-1)]
    mod_F = LA.vector_norm(F, dim=-1)

    # Map coordinates for the geographic axes, in radians.
    # Latitude is π/2 - θ, so z = +1 (θ = 0) appears at the top of the map;
    # the previous θ - π/2 drew the sphere upside down.
    # Longitude is ϕ - π, which keeps the grid monotonic but rotates the map
    # by 180° about the z axis relative to the true azimuth.
    lon = ϕ - π
    lat = π / 2 - θ

    def make_heatmap(data, ax, title):
        """Draw ``data`` on the (lon, lat) grid of ``ax`` with its own colourbar."""
        cf = ax.pcolormesh(lon, lat, data, cmap="viridis")
        fig.colorbar(cf, ax=ax, shrink=1)
        ax.set_title(title)

    def symlog(x: Tensor):
        """Signed log: identity for |x| < 1, sign(x)·(log10|x| + 1) otherwise.

        Alternative to the plain log10 panels (see the commented-out calls
        below). log10 gives -inf wherever a component is exactly zero.
        """
        y = torch.empty_like(x)
        id_mask = x.abs() < 1
        y[id_mask] = x[id_mask]
        y[~id_mask] = (x[~id_mask].abs().log10() + 1) * x[~id_mask].sign()
        return y

    fig, axes = plt.subplots(
        5, 2, figsize=(12, 12), subplot_kw=dict(projection=projection)
    )
    axes = iter(axes.flatten())

    make_heatmap(S_vMF, next(axes), r"$S_{vMF}$")
    make_heatmap(S_eff, next(axes), r"$S_{vMF} - \log_e \mathcal{V}$")
    make_heatmap(Fx, next(axes), r"$F_x$")
    make_heatmap(Fx.abs().log10(), next(axes), r"$\log_{10}|F_x|$")
    # make_heatmap(symlog(Fx), next(axes), r"$\log_{10}|F_x|$")
    make_heatmap(Fy, next(axes), r"$F_y$")
    make_heatmap(Fy.abs().log10(), next(axes), r"$\log_{10}|F_y|$")
    # make_heatmap(symlog(Fy), next(axes), r"$\log_{10}|F_y|$")
    make_heatmap(Fz, next(axes), r"$F_z$")
    make_heatmap(Fz.abs().log10(), next(axes), r"$\log_{10}|F_z|$")
    # make_heatmap(symlog(Fz), next(axes), r"$\log_{10}|F_z|$")
    make_heatmap(mod_F, next(axes), r"$|\mathbf{F}|$")
    make_heatmap(mod_F.log10(), next(axes), r"$\log_{10}|\mathbf{F}|$")

    fig.tight_layout()

    # Ideally I would use flowed coordinates for visualisations,
    # but this causes issues with the heatmap
    fig2 = scatter(
        outputs.view(-1, 3),
        colours=S_vMF.view(-1),
        projection=projection,
        s=0.1,
        marker="x",
    )

    return fig, fig2

In [None]:
def visualise_forces(model: RecursiveFlowS2, bins: int = 50, projection="aitoff"):
    r"""Plot the action and force field of ``model`` on a grid over the sphere.

    A ``bins x bins`` grid is built from the polar angle θ = arccos(1 - 2u),
    with u evenly spaced on [0, 1] (so cos θ is evenly spaced), and the
    azimuth ϕ evenly spaced on [0, 2π]. At each grid point autograd gives the
    force F = -∂S_eff/∂z, where S_vMF = -κ x·μ and S_eff = S_vMF - log|∂x/∂z|.

    Parameters
    ----------
    model : RecursiveFlowS2
        Flow called as ``model(z) -> (x, log|∂x/∂z|)``. Must expose ``κ`` and
        ``μ``.
    bins : int
        Grid resolution along each angle.
    projection : str
        Matplotlib geographic projection: "aitoff", "hammer", "lambert" or
        "mollweide".

    Returns
    -------
    tuple
        ``(fig, fig2)``:
        - ``fig``: a 5x2 grid of heatmaps (S_vMF, S_eff, each force component
          and |F|, both linear and log10).
        - ``fig2``: a scatter of the flowed grid points coloured by S_vMF.
    """
    # cos θ evenly spaced in [-1, 1]; ϕ evenly spaced in [0, 2π].
    θ = torch.acos(1 - 2 * torch.linspace(0, 1, bins))
    ϕ = torch.linspace(0, 2 * π, bins)
    # "ij" is torch's historical default. Passing it explicitly silences the
    # UserWarning and keeps θ varying along rows and ϕ along columns.
    θ, ϕ = torch.meshgrid(θ, ϕ, indexing="ij")
    x = θ.sin() * ϕ.cos()
    y = θ.sin() * ϕ.sin()
    z = θ.cos()
    inputs = torch.stack([x, y, z], dim=-1).view(-1, 3)

    # F = -∂S_eff/∂z via autograd. enable_grad makes this work even when
    # called from inside a no_grad context.
    inputs.requires_grad_(True)
    inputs.grad = None
    with torch.enable_grad():
        outputs, log_dxdz = model(inputs)
        S_vMF = -model.κ * torch.mv(outputs, model.μ)
        S_eff = S_vMF - log_dxdz
        S_eff.backward(gradient=torch.ones_like(S_eff))
    F = inputs.grad.negative()
    inputs.requires_grad_(False)
    inputs.grad = None

    outputs = outputs.view(bins, bins, 3).detach()
    F = F.view(bins, bins, 3).detach()
    S_vMF = S_vMF.view(bins, bins).detach()
    S_eff = S_eff.view(bins, bins).detach()

    Fx, Fy, Fz = [Fi.squeeze() for Fi in F.split(1, dim=-1)]
    mod_F = LA.vector_norm(F, dim=-1)

    # Map coordinates for the geographic axes, in radians.
    # Latitude is π/2 - θ, so z = +1 (θ = 0) appears at the top of the map;
    # the previous θ - π/2 drew the sphere upside down.
    # Longitude is ϕ - π, which keeps the grid monotonic but rotates the map
    # by 180° about the z axis relative to the true azimuth.
    lon = ϕ - π
    lat = π / 2 - θ

    def make_heatmap(data, ax, title):
        """Draw ``data`` on the (lon, lat) grid of ``ax`` with its own colourbar."""
        cf = ax.pcolormesh(lon, lat, data, cmap="viridis")
        fig.colorbar(cf, ax=ax, shrink=1)
        ax.set_title(title)

    def symlog(x: Tensor):
        """Signed log: identity for |x| < 1, sign(x)·(log10|x| + 1) otherwise.

        Alternative to the plain log10 panels (see the commented-out calls
        below). log10 gives -inf wherever a component is exactly zero.
        """
        y = torch.empty_like(x)
        id_mask = x.abs() < 1
        y[id_mask] = x[id_mask]
        y[~id_mask] = (x[~id_mask].abs().log10() + 1) * x[~id_mask].sign()
        return y

    fig, axes = plt.subplots(
        5, 2, figsize=(12, 12), subplot_kw=dict(projection=projection)
    )
    axes = iter(axes.flatten())

    make_heatmap(S_vMF, next(axes), r"$S_{vMF}$")
    make_heatmap(S_eff, next(axes), r"$S_{vMF} - \log_e \mathcal{V}$")
    make_heatmap(Fx, next(axes), r"$F_x$")
    make_heatmap(Fx.abs().log10(), next(axes), r"$\log_{10}|F_x|$")
    # make_heatmap(symlog(Fx), next(axes), r"$\log_{10}|F_x|$")
    make_heatmap(Fy, next(axes), r"$F_y$")
    make_heatmap(Fy.abs().log10(), next(axes), r"$\log_{10}|F_y|$")
    # make_heatmap(symlog(Fy), next(axes), r"$\log_{10}|F_y|$")
    make_heatmap(Fz, next(axes), r"$F_z$")
    make_heatmap(Fz.abs().log10(), next(axes), r"$\log_{10}|F_z|$")
    # make_heatmap(symlog(Fz), next(axes), r"$\log_{10}|F_z|$")
    make_heatmap(mod_F, next(axes), r"$|\mathbf{F}|$")
    make_heatmap(mod_F.log10(), next(axes), r"$\log_{10}|\mathbf{F}|$")

    fig.tight_layout()

    # Ideally I would use flowed coordinates for visualisations,
    # but this causes issues with the heatmap
    fig2 = scatter(
        outputs.view(-1, 3),
        colours=S_vMF.view(-1),
        projection=projection,
        s=0.1,
        marker="x",
    )

    return fig, fig2

In [None]:
@torch.no_grad()
def flow_hmc(
    trained_model: pl.LightningModule,
    batch_size: int,
    n_traj: int,
    traj_length: int,
    n_steps: int,
    κ: Optional[float] = None,
    μ: Optional[Tensor] = None,
    debug: bool = False,
    lines: bool = True,
):
    r"""Run HMC on the unit sphere in the latent space of a normalizing flow.

    The chain state ``z`` lives on the unit sphere. The effective action is

        S(z) = -κ x(z)·μ - log|∂x/∂z|,    (x, log|∂x/∂z|) = flow(z).

    Each trajectory proceeds as follows:
    - Draw standard-normal momenta and project them onto the tangent plane
      at ``z``.
    - Integrate with a leapfrog scheme. Its position update moves ``z``
      along a great circle, so |z| = 1 is kept; round-off is corrected by
      re-normalising.
    - Accept or reject with a Metropolis test on the change in
      H = |p|²/2 + S.

    Parameters
    ----------
    trained_model : pl.LightningModule
        Flow called as ``trained_model(z) -> (x, log|∂x/∂z|)``. Its ``κ`` and
        ``μ`` attributes are used when the corresponding arguments are None.
    batch_size : int
        Number of independent chains run in parallel.
    n_traj : int
        Number of trajectories (accept/reject steps) per chain.
    traj_length : int
        Integration time per trajectory.
    n_steps : int
        Leapfrog steps per trajectory. Step size is ``traj_length / n_steps``.
    κ : float, optional
        Concentration of the von Mises-Fisher term.
    μ : Tensor or sequence of 3 numbers, optional
        Mean direction, normalised to unit length. A tensor passed by the
        caller is not modified.
    debug : bool
        Record the per-step history of accepted trajectories and plot it
        with ``debug_plots``. Requires ``batch_size == 1``.
    lines : bool
        Forwarded to ``debug_plots``.

    Returns
    -------
    Tensor
        Flowed states ``x`` of shape ``(batch_size, n_traj, 3)``: the chain
        state after each trajectory. After a rejection the previous state
        is repeated.

    Raises
    ------
    ValueError
        If ``debug`` is requested with ``batch_size > 1``.
    """
    if debug and batch_size > 1:
        raise ValueError("debug only works with a batch size of one")

    flow = trained_model

    κ = κ if κ is not None else trained_model.κ
    if μ is not None:
        μ = μ if isinstance(μ, torch.Tensor) else torch.tensor(μ, dtype=torch.float32)
        # Out-of-place division: the previous in-place `div_` silently
        # rescaled a tensor owned by the caller.
        μ = μ / LA.vector_norm(μ)
    else:
        μ = trained_model.μ

    ε = traj_length / n_steps

    # Initial state randomly distributed on the sphere
    z0, _ = next(SphericalUniformPrior3D(batch_size))
    z0.squeeze_(dim=1)
    assert z0.shape == torch.Size([batch_size, 3])

    outputs = torch.empty(n_traj, batch_size, 3)

    def get_force(z: Tensor) -> Tensor:
        """Return F = -∂S/∂z at ``z`` via autograd, then reset z's grad state."""
        z.requires_grad_(True)
        z.grad = None
        with torch.enable_grad():
            x, log_dxdz = flow(z)
            S = -κ * torch.mv(x, μ) - log_dxdz
            assert S.shape == torch.Size([batch_size])
            S.backward(gradient=torch.ones_like(S))
        F = z.grad.negative()
        z.requires_grad_(False)
        z.grad = None
        return F

    n_accepted = 0
    n_recorded = 0  # trajectories whose debug history was kept

    # Quantities to track for debugging
    H_history = []
    F_history = []
    z_history = []
    p_history = []
    x_history = []
    S_history = []
    ldj_history = []

    for i in range(n_traj):
        z = z0.clone()

        H_history_i = []
        F_history_i = []
        z_history_i = []
        p_history_i = []
        x_history_i = []
        S_history_i = []
        ldj_history_i = []

        # Initial momenta: standard normal, projected onto the tangent plane
        # at z by M = I - z zᵀ.
        p = torch.empty_like(z).normal_()
        M = torch.eye(3) - batched_outer(z, z)
        p = batched_mv(M, p)

        x, log_dxdz = flow(z)
        H0 = 0.5 * (p**2).sum(dim=1) - κ * torch.mv(x, μ) - log_dxdz

        # Leapfrog: opening half-step in momentum.
        F = get_force(z)
        dpdt = batched_mv(M, F)
        p += 0.5 * ε * dpdt

        for t in range(n_steps):
            if debug:
                x, log_dxdz = flow(z)
                S = -κ * torch.mv(x, μ) - log_dxdz
                pp = p - 0.5 * ε * dpdt  # momentum synchronised with z
                H = 0.5 * (pp**2).sum(dim=1) + S
                H_history_i.append(H)
                F_history_i.append(F)
                z_history_i.append(z)
                p_history_i.append(pp)
                x_history_i.append(x)
                S_history_i.append(S)
                ldj_history_i.append(log_dxdz)

            # Geodesic (great-circle) position update, which preserves |z| = 1.
            mod_p = LA.vector_norm(p, dim=-1, keepdim=True)
            cos_εp = torch.cos(ε * mod_p)
            sin_εp = torch.sin(ε * mod_p)
            z_tmp = cos_εp * z + (1 / mod_p) * sin_εp * p
            p = -mod_p * sin_εp * z + cos_εp * p
            z = z_tmp

            # Re-normalise (correct for numerical errors)
            z = z / LA.vector_norm(z, dim=-1, keepdim=True)

            F = get_force(z)
            M = torch.eye(3) - batched_outer(z, z)
            dpdt = batched_mv(M, F)
            # Full momentum steps in between. The last step is a half-step so
            # that p is synchronised with z when HT is evaluated.
            p += (ε if t < n_steps - 1 else 0.5 * ε) * dpdt

        x, log_dxdz = flow(z)
        HT = 0.5 * (p**2).sum(dim=1) - κ * torch.mv(x, μ) - log_dxdz

        # Metropolis test: accept with probability min(1, exp(H0 - HT)).
        accepted = (H0 - HT).clamp(max=0).exp() > torch.rand_like(H0)
        n_accepted += int(accepted.sum())

        z0[accepted] = z[accepted]

        # Keep the history only if THIS trajectory was accepted. The previous
        # check on the cumulative n_accepted also recorded every later
        # rejected trajectory. In debug mode batch_size == 1, so `accepted`
        # holds a single flag.
        if debug and accepted.any():
            n_recorded += 1
            H_history += H_history_i
            F_history += F_history_i
            z_history += z_history_i
            p_history += p_history_i
            x_history += x_history_i
            S_history += S_history_i
            ldj_history += ldj_history_i

        x, _ = flow(z0)
        outputs[i] = x

    acceptance_rate = n_accepted / (batch_size * n_traj)
    print("acceptance: ", acceptance_rate)

    if debug:
        if n_recorded == 0:
            # torch.stack would fail on empty histories.
            print("debug: no trajectory was accepted, nothing to plot")
        else:
            _ = debug_plots(
                z_history,
                x_history,
                p_history,
                F_history,
                H_history,
                S_history,
                ldj_history,
                # Boundaries are drawn between the trajectories actually recorded.
                n_traj=n_recorded,
                n_steps=n_steps,
                ε=ε,
                κ=κ,
                μ=μ,
                lines=lines,
            )

    return outputs.transpose(0, 1)  # (batch_size, n_traj, 3)

## Test the algorithm with a dummy model

### Uniform target

In [None]:
dummy_model = DummyNormalizingFlow()

# Nearly flat target (κ ≈ 0): a single chain, ten long trajectories, with
# the per-step debug plots switched on.
x_from_flow_hmc = flow_hmc(
    dummy_model,
    κ=0.001,
    μ=[0, 0, 1],
    batch_size=1,
    n_traj=10,
    traj_length=10,
    n_steps=400,
    lines=True,
    debug=True,
)

plt.show()

In [None]:
# Reference samples: 5000 points from SphericalUniformPrior3D pushed through
# the flow.
with torch.no_grad():
    z, _ = next(SphericalUniformPrior3D(5000))
    z.squeeze_(dim=1)
    x_from_model, _ = dummy_model(z)

# HMC: 10 parallel chains x 500 short trajectories (4 leapfrog steps each).
x_from_flow_hmc = flow_hmc(
    dummy_model,
    κ=0.001,
    μ=[0, 0, 1],
    batch_size=10,
    n_traj=500,
    traj_length=1,
    n_steps=4,
    debug=False,
)
# flow_hmc returns (batch_size, n_traj, 3). Use .flatten(0, 1) if a flat
# (N, 3) array is needed.

fig1 = scatter(x_from_model, s=2)
fig1.suptitle("Data Generated by Normalizing Flow")
fig2 = scatter(x_from_flow_hmc, s=2)
fig2.suptitle("Data Generated by Flow HMC")

plt.show()

### Concentrated target

In [None]:
dummy_model = DummyNormalizingFlow()

# Sharply peaked target: κ = 10 around μ ∝ (1, -1, 1). flow_hmc normalises μ.
# A single chain with debug plots.
x_from_flow_hmc = flow_hmc(
    dummy_model,
    κ=10,
    μ=[1, -1, 1],
    batch_size=1,
    n_traj=10,
    traj_length=10,
    n_steps=500,
    lines=True,
    debug=True,
)

plt.show()

In [None]:
# 10 parallel chains x 500 short trajectories on the peaked target.
x_from_flow_hmc = flow_hmc(
    dummy_model,
    κ=10,
    μ=[1, -1, 1],
    batch_size=10,
    n_traj=500,
    traj_length=1,
    n_steps=4,
    debug=False,
)
# flow_hmc returns (batch_size, n_traj, 3). Use .flatten(0, 1) if a flat
# (N, 3) array is needed.

# x_from_model is the reference set from the uniform-target cell above.
fig1 = scatter(x_from_model, s=2)
fig1.suptitle("Data Generated by Normalizing Flow")
fig2 = scatter(x_from_flow_hmc, s=2)
fig2.suptitle("Data Generated by Flow HMC")

plt.show()

## Load a trained model

In [None]:
# Restore a trained flow from its last Lightning checkpoint.
ckpt_path = "tb_logs/test/version_6/checkpoints/last.ckpt"
model = RecursiveFlowS2.load_from_checkpoint(checkpoint_path=ckpt_path)

model.hparams  # shown as the cell output

In [None]:
# Quick evaluation on a single test batch, with logging disabled.
trainer = pl.Trainer(logger=False, limit_test_batches=1)

[metrics] = trainer.test(model)

## Visualise the forces

In [None]:
# Heatmaps of the action and force field on a 100x100 grid (Lambert
# projection). `forces_heatmap` is not defined anywhere in the notebook; the
# plotting function is `visualise_forces`.
_ = visualise_forces(model, bins=100, projection="lambert")

## Sampling

In [None]:
# Reference samples: 5000 points from SphericalUniformPrior3D pushed through
# the trained flow.
with torch.no_grad():
    z, _ = next(SphericalUniformPrior3D(5000))
    z.squeeze_(dim=1)
    x_from_model, _ = model(z)

# κ and μ are not given, so flow_hmc takes them from the model.
x_from_flow_hmc = flow_hmc(
    model,
    batch_size=10,
    n_traj=500,
    traj_length=1,
    n_steps=4,
    debug=False,
)

fig1 = scatter(x_from_model, s=2)
fig1.suptitle("Data Generated by Normalizing Flow")
fig2 = scatter(x_from_flow_hmc, s=2)
fig2.suptitle("Data Generated by Flow HMC")

plt.show()

## Visualise trajectories

In [None]:
# A few long trajectories of a single chain on the trained model, with the
# per-step debug plots.
x_from_flow_hmc = flow_hmc(
    model,
    batch_size=1,
    n_traj=4,
    traj_length=4,
    n_steps=400,
    debug=True,
)

plt.show()