# Hybrid Monte Carlo

We first define a Hamiltonian by augmenting the action $S(\mathbf{x}) \equiv -\log f_2(\mathbf{x})$ with auxiliary Gaussian degrees of freedom $\mathbf{p} \in \mathbb{R}^3$, which play the role of conjugate momenta,

$$
H(\mathbf{x}, \mathbf{p} \mid \kappa, \boldsymbol{\mu}) = \frac{1}{2} \mathbf{p} \cdot \mathbf{p} - \kappa \boldsymbol{\mu} \cdot \mathbf{x}
$$

At all times the system is subject to the following constraints,

$$
\mathbf{x} \cdot \mathbf{x} = 1 \, , \qquad
\mathbf{p} \cdot \mathbf{x} = 0 \, .
$$

The system can be evolved along a trajectory of constant $H$ by integrating Hamilton's equations of motion,

$$
\begin{align}
    \frac{\mathrm{d}}{\mathrm{d} t}
    \begin{pmatrix} \mathbf{x} \\ \mathbf{p} \end{pmatrix}
    = \begin{pmatrix} \mathbf{0} & \mathbf{1} \\ -\mathbf{1} & \mathbf{0} \end{pmatrix}
    \begin{pmatrix} \mathcal{P}_\mathbf{x} \nabla_\mathbf{x} H \\ \nabla_\mathbf{p} H \end{pmatrix}
    = \begin{pmatrix} \mathbf{p} \\ \mathcal{P}_\mathbf{x} \kappa \boldsymbol{\mu} \end{pmatrix}
\end{align}
$$

where

$$
\mathcal{P}_\mathbf{x} \equiv (\mathbb{I} - \mathbf{x} \mathbf{x}^\top)
$$

is the matrix that projects a vector onto the plane orthogonal to $\mathbf{x}$, i.e. it removes any component parallel to $\mathbf{x}$.
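
As a quick illustration of what $\mathcal{P}_\mathbf{x}$ does, the sketch below applies it to an arbitrary vector and checks that the result is tangent to the sphere at $\mathbf{x}$. It uses plain Python lists rather than torch tensors so that it runs standalone; the helper name `project_tangent` and the example vectors are assumptions made for illustration only.

```python
def project_tangent(x, v):
    """Apply P_x = I - x x^T to v, assuming x is a unit vector."""
    x_dot_v = sum(xi * vi for xi, vi in zip(x, v))
    return [vi - x_dot_v * xi for xi, vi in zip(x, v)]


x = [0.0, 0.6, 0.8]  # unit vector: 0.36 + 0.64 = 1
v = [1.0, 2.0, 3.0]  # arbitrary vector in R^3

p = project_tangent(x, v)

# p should lie in the tangent plane at x, so x . p should vanish
x_dot_p = sum(xi * pi for xi, pi in zip(x, p))

# Projecting a second time should leave p unchanged (P_x is idempotent)
p_again = project_tangent(x, p)

# x itself should be mapped to the zero vector
p_of_x = project_tangent(x, x)
```

Writing the projection as $\mathbf{v} - (\mathbf{x} \cdot \mathbf{v}) \mathbf{x}$ avoids building the $3 \times 3$ matrix explicitly, which is equivalent provided $\lvert \mathbf{x} \rvert = 1$.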

In a numerical HMC simulation these equations of motion are integrated using finite-difference updates.
We will use the Leapfrog algorithm, which accumulates finite-step errors proportional to the square of the step size, $\epsilon^2$.
A full update is composed of three steps:

$$
\begin{split}
\mathbf{p}^{(t+\epsilon/2)} &= \mathbf{p}^{(t)} + \frac{\epsilon}{2} \mathcal{P}_{\mathbf{x}^{(t)}} \kappa \boldsymbol{\mu} \\
\\
\begin{pmatrix} \mathbf{x}^{(t+\epsilon)} \\ \tilde{\mathbf{p}}^{(t+\epsilon/2)} \end{pmatrix}
   &= \begin{pmatrix} 
        \cos \left( \epsilon \lvert \mathbf{p}^{(t+\epsilon/2)} \rvert \right)
        & \frac{1}{\lvert \mathbf{p}^{(t+\epsilon/2)} \rvert} \sin \left( \epsilon \lvert \mathbf{p}^{(t+\epsilon/2)} \rvert \right) \\
        - \lvert \mathbf{p}^{(t+\epsilon/2)} \rvert \sin \left( \epsilon \lvert \mathbf{p}^{(t+\epsilon/2)} \rvert \right)
        & \cos \left( \epsilon \lvert \mathbf{p}^{(t+\epsilon/2)} \rvert \right)
    \end{pmatrix}
\begin{pmatrix} \mathbf{x}^{(t)} \\ \mathbf{p}^{(t+\epsilon/2)} \end{pmatrix} \\
\\
\mathbf{p}^{(t+\epsilon)} &= \tilde{\mathbf{p}}^{(t+\epsilon/2)} + \frac{\epsilon}{2} \mathcal{P}_{\mathbf{x}^{(t+\epsilon)}} \kappa \boldsymbol{\mu}
\end{split}
$$
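
The three-step update above, and the $\epsilon^2$ error scaling, can be checked numerically. The following is a minimal sketch in plain Python (lists instead of torch tensors, so that it runs standalone). The helper names `kick`, `drift`, `hamiltonian` and `max_energy_error`, the starting point, and the choice $\kappa = 1$ are all illustrative assumptions.

```python
import math


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def kick(x, p, kappa, mu, dt):
    """Momentum update using forces: p += dt * P_x (kappa * mu)."""
    x_dot_mu = dot(x, mu)
    return [pi + dt * kappa * (mi - x_dot_mu * xi) for xi, pi, mi in zip(x, p, mu)]


def drift(x, p, eps):
    """Force-free update: rotate (x, p) along a great circle."""
    mod_p = math.sqrt(dot(p, p))
    c, s = math.cos(eps * mod_p), math.sin(eps * mod_p)
    x_new = [c * xi + (s / mod_p) * pi for xi, pi in zip(x, p)]
    p_new = [-mod_p * s * xi + c * pi for xi, pi in zip(x, p)]
    return x_new, p_new


def hamiltonian(x, p, kappa, mu):
    return 0.5 * dot(p, p) - kappa * dot(mu, x)


def max_energy_error(n_steps, traj_length=1.0, kappa=1.0):
    """Largest |H(t) - H(0)| seen along one leapfrog trajectory."""
    mu = [0.0, 0.0, 1.0]
    x = [1.0, 0.0, 0.0]  # on the unit sphere
    p = [0.0, 1.0, 0.5]  # tangent at x (x . p = 0)
    eps = traj_length / n_steps
    h0 = hamiltonian(x, p, kappa, mu)
    worst = 0.0
    for _ in range(n_steps):
        p = kick(x, p, kappa, mu, 0.5 * eps)  # half step in momentum
        x, p = drift(x, p, eps)               # full force-free step
        p = kick(x, p, kappa, mu, 0.5 * eps)  # half step in momentum
        worst = max(worst, abs(hamiltonian(x, p, kappa, mu) - h0))
    return worst


err_coarse = max_energy_error(n_steps=10)
err_fine = max_energy_error(n_steps=20)
ratio = err_coarse / err_fine
```

If the error in $H$ scales as $\epsilon^2$, halving the step size should reduce the error by a factor of roughly 4, so `ratio` is expected to be close to 4 for sufficiently small steps.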

The middle step, which updates both coordinates and momenta, is the exact solution to the equations of motion with no forces ($\kappa = 0$, i.e. a uniform target density) subject to the constraints.
It can be viewed as an orthogonal transformation which rotates $\mathbf{x}$ (such that it remains on the sphere) and $\mathbf{p}$ (such that it remains in the tangent plane).
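
The claim that the force-free step preserves both constraints can be verified directly. The sketch below is written in plain Python so that it runs standalone; the function name `geodesic_step`, the zero-momentum guard and the example values are assumptions added for illustration.

```python
import math


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def geodesic_step(x, p, eps):
    """Rotate (x, p) by an angle eps * |p| in the plane spanned by x and p."""
    mod_p = math.sqrt(dot(p, p))
    if mod_p == 0.0:
        # Assumption: with zero momentum nothing moves (avoids dividing by zero)
        return list(x), list(p)
    c, s = math.cos(eps * mod_p), math.sin(eps * mod_p)
    x_new = [c * xi + (s / mod_p) * pi for xi, pi in zip(x, p)]
    p_new = [-mod_p * s * xi + c * pi for xi, pi in zip(x, p)]
    return x_new, p_new


x = [1.0, 0.0, 0.0]   # on the unit sphere
p = [0.0, 0.3, -0.4]  # tangent at x, with |p| = 0.5

x1, p1 = geodesic_step(x, p, eps=0.7)

norm_x1 = math.sqrt(dot(x1, x1))  # should remain 1
x1_dot_p1 = dot(x1, p1)           # should remain 0
norm_p1 = math.sqrt(dot(p1, p1))  # should remain |p| = 0.5
```

Because the step is a rotation, it also conserves $\lvert \mathbf{p} \rvert$, and hence the kinetic term $\frac{1}{2} \mathbf{p} \cdot \mathbf{p}$, up to floating-point rounding.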

In [None]:
from random import random
from typing import TypeAlias

import torch
import torch.linalg as LA

from distributions import SphericalUniformPrior3D
from visualisations import pairplot, heatmap, line3d

Tensor: TypeAlias = torch.Tensor

In [None]:
@torch.no_grad()
def hmc(
    n_traj: int,
    traj_length: int,
    n_steps: int,
    κ: float,
    μ: Tensor | list[float] | tuple[float, ...] = (0.0, 0.0, 1.0),
) -> Tensor:
    # Make sure μ is a float Tensor and correctly normalised. The division is
    # done out of place so that a tensor passed in by the caller is not modified.
    μ = torch.as_tensor(μ, dtype=torch.float32)
    μ = μ / LA.vector_norm(μ)

    # Step size
    ε = traj_length / n_steps

    # Initial state randomly distributed on the sphere
    x0, _ = next(SphericalUniformPrior3D(1))
    x0.squeeze_()
    assert x0.shape == torch.Size([3])

    sample = torch.empty(n_traj, 3)

    n_accepted = 0

    for i in range(n_traj):
        # Coordinates at t=0
        x = x0.clone()

        # Momenta at t=0 (project out part that is not orthogonal to x)
        p = torch.empty_like(x).normal_()
        M = torch.eye(3) - torch.outer(x, x)
        p = M @ p

        # assert torch.allclose(M @ x, torch.zeros(3), atol=1e-4), f"{M @ x}"
        # assert torch.allclose(torch.dot(x, p), torch.zeros(3), atol=1e-4), f"{torch.dot(x, p)}"

        # Hamiltonian at t=0
        H0 = 0.5 * torch.dot(p, p) - κ * torch.dot(μ, x)

        # ------------- Begin leapfrog ----------- #

        p += 0.5 * ε * M @ (κ * μ)

        for t in range(n_steps):
            # Simultaneous update of coordinates and momenta in absence of forces
            mod_p = LA.vector_norm(p)
            cos_εp = torch.cos(ε * mod_p)
            sin_εp = torch.sin(ε * mod_p)
            x_tmp = cos_εp * x + (1 / mod_p) * sin_εp * p
            p = -mod_p * sin_εp * x + cos_εp * p
            x = x_tmp

            # print("|x| = ", LA.vector_norm(x).item(), "x . p = ", torch.dot(p, x).item())

            # Momentum update using forces
            M = torch.eye(3) - torch.outer(x, x)
            if t == n_steps - 1:
                p += 0.5 * ε * M @ (κ * μ)
            else:
                p += ε * M @ (κ * μ)

        # ---------- End leapfrog ------------- #

        HT = 0.5 * torch.dot(p, p) - κ * torch.dot(μ, x)

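        # Metropolis test: accept with probability min(1, exp(H0 - HT));
        # on rejection the chain stays at the previous state x0
        if HT < H0 or (H0 - HT).exp() > random():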
            n_accepted += 1
            x0 = x

        sample[i] = x0.clone()

    print("acceptance: ", n_accepted / n_traj)

    max_dev = (LA.vector_norm(sample, dim=-1) - 1).abs().max()
    print(f"largest deviation from unit sphere: ||x| - 1| = {max_dev}")

    return sample


_ = hmc(10, 1, 10, 1)
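
One way to sanity-check the samples is to compare their mean with the analytic result. Since $S(\mathbf{x}) = -\kappa \boldsymbol{\mu} \cdot \mathbf{x}$, the target density on the sphere is proportional to $\exp(\kappa \boldsymbol{\mu} \cdot \mathbf{x})$, for which $\mathbb{E}[\boldsymbol{\mu} \cdot \mathbf{x}] = \coth \kappa - 1/\kappa$. The sketch below evaluates this expression; the function name `mean_resultant_length` and the small-$\kappa$ cutoff are assumptions made for illustration.

```python
import math


def mean_resultant_length(kappa: float) -> float:
    """E[mu . x] for a density proportional to exp(kappa * mu . x) on the unit 2-sphere."""
    if abs(kappa) < 1e-4:
        # Leading-order Taylor expansion, used to avoid cancellation near kappa = 0
        return kappa / 3.0
    return 1.0 / math.tanh(kappa) - 1.0 / kappa


expected_small = mean_resultant_length(0.001)  # nearly uniform target
expected_large = mean_resultant_length(10.0)   # concentrated target
```

These values could be compared with the sample average of $\boldsymbol{\mu} \cdot \mathbf{x}$ computed from the output of `hmc`, using the normalised $\boldsymbol{\mu}$. Agreement is only expected within statistical error, which depends on the autocorrelation of the chain.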

## Nearly uniform target ($\kappa = 0.001$)

In [None]:
x_hmc = hmc(n_traj=100, traj_length=1, n_steps=1, κ=0.001, μ=[0, 0, 1])
_ = line3d(x_hmc, lw=0.5)

In [None]:
x_hmc = hmc(n_traj=10000, traj_length=1, n_steps=1, κ=0.001, μ=[0, 0, 1])
_ = pairplot(x_hmc)

## Concentrated target ($\kappa = 10$, $\boldsymbol{\mu} = (0, 0, 1)$)

In [None]:
x_hmc = hmc(n_traj=10000, traj_length=1, n_steps=4, κ=10, μ=[0, 0, 1])
_ = line3d(x_hmc, lw=0.5)

In [None]:
x_hmc = hmc(n_traj=10000, traj_length=1, n_steps=4, κ=10, μ=[0, 0, 1])
_ = pairplot(x_hmc)

## Concentrated target ($\kappa = 10$, $\boldsymbol{\mu} \propto (1, 1, 0)$)

In [None]:
x_hmc = hmc(n_traj=10000, traj_length=1, n_steps=4, κ=10, μ=[1, 1, 0])

_ = pairplot(x_hmc)
_ = heatmap(x_hmc)