# MCMC from scratch 

This notebook is a study of Markov Chain Monte Carlo methods,
and is inspired, in part, by these great papers: 
* A basic overview of sampling methods such as AR, MH and Gibbs
  [Chib, Greenberg (1995)](http://web1.sph.emory.edu/users/hwu30/teaching/statcomp/papers/chibGreenbergMH.pdf)
. Includes a hybrid MH-AR method (where the proposal is sampled using AR)
* A self contained introduction and study of Monte Carlo based on Hamiltonian dynamics
  [Betancourt (2017)](https://arxiv.org/abs/1701.02434)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

import numpy as np

import torch
import torch.nn.functional as F

import tqdm

Get inflated `rect` for better $2$-d plot aesthetics

In [None]:
def get_rect(data, dim=0, r=5e-2, a=1e-3):
    """Return an inflated bounding box of the finite slices of `data`.

    Slices taken along `dim` that contain at least one non-finite value
    (nan or inf) are ignored. The remaining slices are reduced with
    min / max over `dim`, and the resulting box is enlarged about its
    centre: every half-extent is scaled by `1 + r` and padded by `a`.

    Parameters
    ----------
    data : torch.Tensor
        Tensor with at least one dimension.
    dim : int, default 0
        Dimension that enumerates the points; negative values count from
        the end. It is reduced away in the output.
    r : float, default 5e-2
        Relative inflation of each half-extent.
    a : float, default 1e-3
        Absolute padding added to each half-extent.

    Returns
    -------
    (ll, uu) : tuple of torch.Tensor
        Lower and upper corners of the inflated box, each shaped like
        `data` with `dim` removed.
    """
    dim = dim % data.dim()

    # flag the slices along `dim` that contain at least one nan / inf
    bad = ~torch.isfinite(data)
    others = [d for d in range(data.dim()) if d != dim]
    if others:
        bad = bad.sum(others) > 0

    # keep the clean slices only; `index_select` works for any `dim`,
    # unlike boolean indexing which always filters the leading dimension
    keep = (~bad).nonzero().flatten()
    clean = data.index_select(dim, keep)

    # get the enclosing rectangle ...
    (uu, _), (ll, _) = clean.max(dim), clean.min(dim)

    # ... center it, inflate, ...
    cc = (uu + ll) / 2
    uu, ll = uu - cc, ll - cc
    uu = uu + abs(uu) * r + a
    ll = ll - abs(ll) * r - a

    # ... and translate back
    return ll + cc, uu + cc

Create a multi-modal density: we can evaluate it, know its structure,
but not the normalizing constant.

$$
p(\theta)
    = \tfrac1Z \mathop{\mathrm{exp}} \bigl(
        \log f(\theta)
    \bigr)
    \,.
$$


Let's make a banana distribution:

$$
p(x)
    \propto p_{\mathcal{N}(0, 1)} \circ \phi(x)
    \,, $$

where $\phi$ is based on the
[Banana](https://en.wikipedia.org/wiki/Rosenbrock_function)
function and given by $
\phi
\colon \mathbb{R}^2 \to \mathbb{R}^2
\colon (x, y) \mapsto (a x, b (y-x^2))
$

In [None]:
def log_banana_base(x, a=0.75, b=1.05):
    """Unnormalized log-density of a standard Gaussian composed with a warp.

    The last dimension of `x` holds the two coordinates `(u, v)`. They are
    mapped to `(a * (u - v**2), b * (v - u**2))` and the result is scored
    with `-0.5 * ||.||^2`. Leading dimensions are treated as batch.
    """
    # alternative one-sided warp (a=1.75, b=5): (a - u, b * (v - u**2))
    u, v = x[..., 0], x[..., 1]
    warped = torch.stack([a * (u - v**2), b * (v - u**2)], dim=-1)

    sq_norm = torch.norm(warped, p=2, keepdim=False, dim=-1)**2
    return -0.5 * sq_norm

In [None]:
def log_density(x):
    """Unnormalized log-density of a two-component mixture of warped blobs.

    Each component is `log_banana_base` evaluated at a shifted and
    sign-flipped copy of `x`; the components are combined with a
    log-sum-exp, i.e. the densities are added with equal weights.
    """
    centres = [torch.tensor((2., 2.)), torch.tensor((-2., -2.))]
    a_vals, b_vals, signs = [+0.75, -0.75], [+3.05, -1.05], [+1, -1]

    # `.to(x)` matches the dtype and device of the centre to the input
    components = [
        log_banana_base(sign * (x - centre.to(x)), a_val, b_val)
        for centre, a_val, b_val, sign in zip(centres, a_vals, b_vals, signs)
    ]

    return torch.logsumexp(torch.stack(components, dim=0), dim=0)

In [None]:
def log_funnel(x, a=5e-1, b=1.):
    """Unnormalized log-density of a curved funnel-like distribution.

    The last dimension of `x` holds the two coordinates `(u, v)`; leading
    dimensions are treated as batch.
    """
    # simpler straight funnel: -0.5 * (a * v**2 + exp(-b * v) * u**2)
    u, v = x[..., 0], x[..., 1]

    ridge = a * (torch.exp(- b * u) - v)**2
    spread = torch.exp(- b * v) * u**2
    return -0.5 * (ridge + spread)

In [None]:
# log_density = log_funnel

In [None]:
# def log_density(x):
#     return -0.5 * torch.norm(x, p=2, keepdim=False, dim=-1)**2

And a plot of it

In [None]:
# Build a regular 101 x 101 grid over the square [-5, 16] x [-5, 16] ...
mesh = torch.meshgrid(2*[torch.linspace(-5, +16, 101)])
# ... and flatten it into a batch of 2-d points (one row per grid node).
marg = torch.stack(mesh, dim=-1).flatten(0, -2)

# Unnormalized density at every grid node.
z = torch.exp(log_density(marg))

In [None]:
# Filled contour plot of the unnormalized density over the grid.
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, xlabel=r"$\theta_1$", ylabel=r"$\theta_2$",
                    title="2x'blob' density")

# `z` is flat, so it is reshaped back to the grid layout before plotting.
ax.contourf(*mesh, z.reshape_as(mesh[0]), levels=51, cmap=plt.cm.terrain)

plt.show()

Plot the gradient field

In [None]:
# Differentiate the log-density w.r.t. every grid point in one backward pass.
theta = marg.clone().requires_grad_(True)
# NOTE: `.mean()` (rather than `.sum()`) scales every per-point gradient
# by 1 / n_points; only the direction and relative size are plotted below.
log_density(theta).mean().backward()

# Put the per-point gradients back onto the grid: (*grid_shape, n_features).
dz = theta.grad.reshape(*mesh[0].shape, -1)

In [None]:
# Density contours with the gradient field of the log-density on top.
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, xlabel=r"$\theta_1$", ylabel=r"$\theta_2$",
                    title="2x'blob' density")

ax.contourf(*mesh, z.reshape_as(mesh[0]), levels=51, cmap=plt.cm.terrain)

# Manual toggle for the arrow overlay.
if True:
    ax.quiver(*mesh, dz[..., 0], dz[..., 1], pivot='mid',
              color="fuchsia", scale_units='xy', angles='uv', scale=.5, alpha=0.5)


plt.show()

<br>

Now let's create some samplers!

In [None]:
class Proposal:
    """Interface of a transition proposal used by the samplers below.

    Subclasses override `sample` and / or `log_prob`; both take the
    conditioning state through the keyword-only argument `at`.
    """

    def sample(self, n_samples=1, *, at=None):
        """Draw `n_samples` proposals conditionally on the state `at`."""
        raise NotImplementedError

    def log_prob(self, x, *, at=None):
        """Evaluate the log-density of `x` conditionally on the state `at`."""
        raise NotImplementedError

Simulate paths

In [None]:
def sample_chain(sampler, n_steps=501, n_chains=15):
    """Run `n_chains` 2-d chains for `n_steps` transitions of `sampler`.

    The chains start at independent Gaussian points with standard
    deviation 15. Every transition is one call to
    `sampler.sample(1, at=state)`. The visited states, including the
    initial one, are stacked along the second-to-last dimension.
    """
    state = torch.randn(n_chains, 2) * 15
    visited = [state]

    for _ in tqdm.trange(n_steps):
        state = sampler.sample(1, at=state)
        visited.append(state)

    return torch.stack(visited, dim=-2)

A path plotter using quiver.

In [None]:
def plot_path(path, ax=None, **kwargs):
    """Draw a path as a sequence of arrows between consecutive points.

    `path` is indexed as `path[t, coordinate]`. Steps of exactly zero
    length (the path did not move) are additionally marked with faint
    black dots. Extra keyword arguments are forwarded to `ax.quiver`,
    whose return value is returned. The current axes are used when `ax`
    is not given.
    """
    if ax is None:
        ax = plt.gca()

    starts, ends = path[:-1], path[1:]
    steps = ends - starts

    # points where the path stayed put
    lengths = np.linalg.norm(steps, axis=-1, keepdims=False)
    idle = ends[lengths == 0]
    ax.scatter(idle[:, 0], idle[:, 1], c="k", alpha=0.1)

    return ax.quiver(starts[:, 0], starts[:, 1], steps[:, 0], steps[:, 1],
                     scale_units='xy', angles='xy', scale=1.,
                     **kwargs)

Check out the Markov Chain paths

In [None]:
def plot_chains(paths, log_density):
    """Plot every chain in `paths` over faint contours of the density.

    `paths` is indexed as `paths[chain, step, coordinate]`. The plotting
    window is the inflated bounding box of all visited points, obtained
    from `get_rect`; `exp(log_density)` is evaluated on a 201 x 201 grid
    inside that window.
    """
    fig = plt.figure(figsize=(16, 16))
    fig.patch.set_alpha(1.0)

    # grid over the inflated bounding box of all visited points
    lower, upper = get_rect(paths.flatten(0, -2))
    grid = torch.meshgrid(*map(torch.linspace, lower, upper, [201, 201]))
    nodes = torch.stack(grid, dim=-1).flatten(0, -2)
    density = torch.exp(log_density(nodes))

    ax = fig.add_subplot(111, xlabel=r"$\theta_1$", ylabel=r"$\theta_2$")
    ax.contourf(*grid, density.reshape_as(grid[0]), levels=21,
                cmap=plt.cm.terrain, alpha=0.05, zorder=10)

    # one colour per chain; `plot_path` draws on the current axes
    palette = plt.cm.Accent(np.linspace(0, 1, num=len(paths)))
    for chain, colour in zip(paths, palette):
        plot_path(chain.numpy(), color=colour, alpha=.5)

    plt.show()

<br>

#### Notation

Consider a measurable space $(\Omega, \mathcal{F}, \mu)$.
The Markov Chain sampler needs:
* *(proposal)* the transition `kernel`
$P \colon \Omega \times \mathcal{F} \to [0, 1]$

In MC-MC we typically consider kernels over Lebesgue carrier measures
$\mu = d{x}$ and defined by

$$
P(x, d\omega)
    = q(\omega \vert x) \mu(d\omega) + r(x) \delta_x(d\omega)
    \,, $$

where $Q(\bullet \vert x) = q(\cdot \vert x) \mu(d\omega)$ is a nonnegative measure
with $Q(\Omega\vert x) \leq 1$ and  $q(x \vert x) = 0$.

What does this notation mean?

Well, $\delta_x(\cdot)$ defines a probability measure on the
measurable space $(\Omega, \mathcal{F})$ according to

$$
\delta_x
\colon \mathcal{F} \to [0, 1]
\colon A \mapsto 1_A(x)
    \,. $$

This notation is a shorthand for (like in SDE)
$$
P(x, dy) = q(y \vert x) dy + r(x) \delta_x(dy)
    \Leftrightarrow
    P(x, A)
        = \int_A \bigl( q(y \vert x) dy + r(x) \delta_x(dy) \bigr)
        = Q(A \vert x) + r(x) 1_A(x)
    \,, $$

This implies that for $P(x, \cdot)$ to be a probability measure
we need $r(x) = 1 - Q(\Omega \vert x)$.

#### Did i really need this proof? NO!

$\delta_x(\emptyset) = 1_\emptyset(x) = 0$, and
$$
\delta_x\bigl(\biguplus_{n\geq 1} A_n\bigr)
    = 1_{\uplus_{n\geq 1} A_n}(x)
    = \begin{cases}
    1 & \exists{n\geq1}\, x \in A_n\\
    0 & \text{otherwise}
    \end{cases}
    = \sum_{n\geq 1} 1_{A_n}(x)
    = \sum_{n\geq 1} \delta_x(A_n)
    \,. $$

Via the MCT this implies that $\int f \delta_x(d\omega) = f(x)$

In [None]:
class Target(Proposal):
    """Wrap a plain log-density callable into the `Proposal` interface."""

    def __init__(self, fn):
        # callable mapping a batch of points to their log-density values
        self.fn = fn

    def log_prob(self, x, *, at=None):
        """Evaluate the wrapped callable at `x`; `at` is ignored."""
        log_value = self.fn(x)
        return log_value

In [None]:
class IndependentProposal(Proposal):
    """Proposal that draws from a fixed distribution, ignoring the state."""

    def __init__(self, distribution):
        # object exposing `sample(shape)` and `log_prob(x)`
        self.distribution = distribution

    def log_prob(self, x, *, at=None):
        """Log-density of `x` under the fixed distribution; `at` is ignored."""
        return self.distribution.log_prob(x)

    def sample(self, n_samples=1, *, at):
        """Draw `n_samples` values per leading (batch) entry of `at`.

        Only the shape of `at` is used. The sample dimension is squeezed
        away when `n_samples == 1`.
        """
        *batch_shape, _ = at.shape
        draws = self.distribution.sample((*batch_shape, n_samples))
        return draws.squeeze(-2)

In [None]:
class RandomWalkProposal(Proposal):
    """Proposal that adds an increment drawn from a fixed distribution."""

    def __init__(self, distribution):
        # distribution of the increment; exposes `sample(shape)` and `log_prob(x)`
        self.distribution = distribution

    def log_prob(self, x, *, at):
        """Log-density of moving from `at` to `x`, i.e. of the increment."""
        increment = x - at
        return self.distribution.log_prob(increment)

    def sample(self, n_samples=1, *, at):
        """Draw `n_samples` moves from every state in `at`.

        The sample dimension is squeezed away when `n_samples == 1`.
        """
        *batch_shape, _ = at.shape

        increments = self.distribution.sample((*batch_shape, n_samples))
        moved = at.unsqueeze(-2) + increments
        return moved.squeeze(-2)

To evaluate gradients of a scalar function in bulk, observe the following:
$$
(\nabla_\theta f(\theta_i))_{i}
    \colon (\theta_i)_{i} \mapsto
        \bigl(
            \tfrac{\partial}{\partial \theta_j}
                \sum_i f(\theta_i)
        \bigr)_{ji} \mathbf{1}
    \,. $$

In [None]:
def batch_grad(x, f, **kwargs):
    """Gradient of a batched scalar function at every point of `x`.

    `f(theta, **kwargs)` must return one scalar per batch entry, each of
    which depends on its own entry of `theta` only. Differentiating the
    sum of the outputs then yields all per-entry gradients at once.

    Parameters
    ----------
    x : torch.Tensor
        Points at which to differentiate. It is never modified, and no
        gradient is accumulated into it, even if it requires grad.
    f : callable
        Function to differentiate with respect to its first argument.
    **kwargs
        Passed to `f` unchanged.

    Returns
    -------
    torch.Tensor
        Gradient with the same shape as `x`.
    """
    # `detach()` makes the copy a leaf tensor: without it a copy of an `x`
    # that requires grad is a non-leaf, its `.grad` stays `None`, and the
    # backward pass leaks into `x.grad` instead
    theta = x.detach().clone().requires_grad_(True)
    f(theta, **kwargs).sum().backward()
    return theta.grad

```python
def f(x, q):
    return - log_density(x) + 0.5 * torch.norm(q, p=2, dim=-1)**2

from torch.autograd import grad

x = torch.randn(10, 2).requires_grad_(True)
q = torch.randn_like(x).requires_grad_(True)

par = x, q
dx, dq = grad(f(*par).sum(), par, create_graph=False)
```

A handy function for clipping a tensor's norm:
$$ 
    x \mapsto x \min\bigl\{\tfrac{C}{\| x \|_p}, 1 \bigr\}
    \,. $$

In [None]:
def clip_norm(x, max_norm, p=2, dim=-1, inplace=True):
    """Rescale `x` so that its ell-`p` norm along `dim` is at most `max_norm`.

    Slices whose norm is already within the bound are left unchanged.
    With `inplace=True` (the default) `x` itself is overwritten;
    otherwise a new tensor is returned and `x` is left intact.
    """
    # shrink factor min(max_norm / ||x||_p, 1), broadcast along `dim`
    norm = torch.norm(x, p=p, dim=dim, keepdim=True)
    factor = torch.clamp(max_norm / norm, max=1)

    if inplace:
        return x.mul_(factor)
    return x.mul(factor)

<br>

### Langevin proposal

The SDE

$$
dX_t
    = \mu_t(X_t) dt
    + \sigma_t(X_t) dW_t
    \,, $$

can be numerically approximated using Euler-Maruyama method: for $\delta > 0$

$$
X_{t+\delta} - X_t
    = \mu_t(X_t) \delta
    + \sigma_t(X_t) \sqrt{\delta} \xi_t
    \,, $$

where $\xi_t \sim \mathcal{N}(0, I)$.

Let's consider Langevin's Itô diffusion:

$$
d\theta_t
    = \nabla_\theta \log \pi(\theta_t) dt
    + \sqrt{2} dW_t
    \,. $$

The Euler-Maruyama step with $\delta > 0$ yields the following
finite-difference approximation

$$
\theta_{t + \delta}
    = \theta_t
    + \delta \nabla_\theta \log \pi(\theta_t)
    + \sqrt{2 \delta} \xi_t
    \,. $$

Thus the proposal is 

$$
q(\theta \vert x)
    = \mathcal{N}\bigl(
        \theta \,\big\vert\,
        x + \delta \nabla_\theta \log \pi(\theta) \vert_{\theta=x},
        2 \delta I
    \bigr)
    \,. $$

Recall that, if $x = \mu+ \sigma \xi$ and $\xi \sim p_\xi$, then 
$$
p_x(x)
    = \tfrac1\sigma p_\xi(\tfrac{x - \mu}\sigma)
    \,, \text{ or }\,
    p_\xi(\xi)
        = \sigma p_x(\mu + \sigma \xi)
    \,. $$

In [None]:
import math

class LangevinProposal(Proposal):  # , RandomWalkProposal):
    r"""Proposal given by one Euler-Maruyama step of the Langevin diffusion.

    A move from `at` is `mu(at) + sqrt(2 * delta) * xi`, where
    `mu(at) = at + delta * grad log_prob_target(at)` and `xi` is drawn
    from `proposal`.

    Parameters
    ----------
    proposal : torch.distributions.Distribution
        Distribution of the standardized noise `xi`. It is sampled with a
        shape tuple and its `event_shape` is read.
    target : Proposal
        Object whose `log_prob` is differentiated to get the drift.
    delta : float, default 1e-2
        Step size of the discretization.
    clip_grad : float, default 0
        If positive, the ell-2 norm of the drift gradient along the last
        dimension is clipped to this value.
    """

    def __init__(self, proposal, target, delta=1e-2, clip_grad=0):
        self.target, self.delta = target, delta
        self.proposal, self.clip_grad = proposal, clip_grad

    def log_prob(self, x, *, at):
        """Log-density of proposing `x` from the state `at`."""
        scale = math.sqrt(2 * self.delta)
        z = (x - self.mu(at=at)) / scale

        # change of variables x = mu + scale * z: every coordinate of an
        # event contributes one `log(scale)` to the log-Jacobian, so the
        # correction is multiplied by the event size (1 for scalar events)
        n_event = torch.Size(self.proposal.event_shape).numel()
        return self.proposal.log_prob(z) - n_event * math.log(scale)

    def mu(self, *, at):
        """Mean of the proposal: `at` shifted along the target's gradient."""
        grad = batch_grad(at, self.target.log_prob)

        # an ell-2 norm grad clipping step
        if self.clip_grad > 0:
            grad = clip_norm(grad, self.clip_grad, p=2, dim=-1)

        return at + grad * self.delta

    def sample(self, n_samples=1, *, at):
        """Draw `n_samples` moves from every state in `at`.

        The sample dimension is squeezed away when `n_samples == 1`.
        """
        *head, _ = at.shape

        step = math.sqrt(2 * self.delta) * self.proposal.sample((*head, n_samples))
        return (self.mu(at=at).unsqueeze(-2) + step).squeeze(-2)

<br>

$x \in \Omega^{b_1 \times \ldots \times b_p}$ for $\Omega \subseteq \mathbb{R}$
and any transition kernel must 

The MH kernel is given by the density:
$$
P(x, dy)
    = p(x, y) dy + r(x) \delta_x(dy)
    \,. $$

<br>

### Metropolis-Hastings Proposal

For some transition density (or pmf) $p(y \vert x)$ the transition kernel
$$
P(x, dy)
%     = p(dy \vert x) + r(x) \delta_x(dy)
    = p(y \vert x) dy + r(x) \delta_x(dy)
    \,, $$
with $r(x) = 1 - \int_\Omega p(dy \vert x) \leq 1$ has $\pi$ as the stationary
distribution if $p$ complies with the detailed (micro-) balance condition:
$
\pi(dx) p(dy \vert x) = \pi(dy) p(dx \vert y)
$.

Indeed, we have
$$
\begin{align}
\int_\Omega \pi(dx) P(x, B)
    &= \int_\Omega \pi(dx) \int_B P(x, dy)
    = \int_\Omega \int_B p(dy \vert x) \pi(dx)
      + \int_\Omega \int_B r(x) \delta_x(dy) \pi(dx)
    \\
    &= \int^x_\Omega \int^y_B p(dy \vert x) \pi(dx)
      + \int_\Omega r(x) 1_B(x) \pi(dx)
    = \int^y_B \int^x_\Omega p(dx \vert y) \pi(dy) + \int_B r(x) \pi(dx)
    \\
    &= \int_B (1 - r(y)) \pi(dy) + \int_B r(x) \pi(dx)
    = \pi(B)
    \,,
\end{align}$$

Thus for a chosen proposal density $q(dy \vert x)$ we need to find
a transition density $p(dy \vert x)$, that satisfies the balance.

**(heuristic)** Let's use importance sampling (analogue of): introduce a rv that
controls the transitions and adjusts the resulting `density` so that it has the
needed mass.

Consider $
p(y\vert x)
    = q(y\vert x) \alpha(y \vert x)
$, where $\alpha$ enforces `reversibility`:
$$
\pi(dx) q(dy\vert x) \alpha(y \vert x)
    = \pi(dy) q(dx\vert y) \alpha(x \vert y)
    \,. $$

Assuming $\pi(dx) = \pi(x) dx$ and $q(dy \vert x) = q(y\vert x) dy$ we
may observe the following:
* If $\pi(x) q(y\vert x) > \pi(y) q(x\vert y)$ then $
\alpha(y \vert x)
    = \tfrac{\pi(y) q(x\vert y)}{\pi(x) q(y\vert x)}
$ and $\alpha(x \vert y) = 1$

* If $\pi(x) q(y\vert x) < \pi(y) q(x\vert y)$ then $
\alpha(x \vert y)
    = \tfrac{\pi(x) q(y\vert x)}{\pi(y) q(x\vert y)}
$ and $\alpha(y \vert x) = 1$

Thus the sought $\alpha$ is
$$
\alpha(y\vert x)
    = \min\Bigl\{
        1, \frac{
            \overbrace{
                \pi(y) q(x\vert y)
            }^{
                y \to x
            }
        }{
            \underbrace{\pi(x) q(y\vert x)}_{
                x \to y
            }
        }
    \Bigr\}
    \,. $$

It is called the *probability of move*

Thus the basic step of MH MC sampler is: given $x_t$ do
1. sample $y \sim q(y \vert x_t)$
2. independently draw $u \sim \mathrm{U}[0, 1]$ and
    * set $x_{t+1} = y$ if $u \leq \alpha(y\vert x_t)$
    * put $x_{t+1} = x_t$ otherwise

In [None]:
class MetropolisHastingsProposal(Proposal):
    """Metropolis-Hastings transition built from a proposal and a target.

    Parameters
    ----------
    proposal : Proposal
        Provides `sample(n, at=...)` and `log_prob(x, at=...)`.
    target : Proposal
        Provides `log_prob(x)`, the (possibly unnormalized) target
        log-density.

    Attributes
    ----------
    alpha_, log_alpha_ : torch.Tensor
        Acceptance probabilities of the most recent `sample` call and the
        unclipped log-ratios they were computed from.
    """

    def __init__(self, proposal, target):
        self.target, self.proposal = target, proposal

    def sample(self, n_samples=1, *, at):
        """Make one MH transition from every state in `at`.

        Only `n_samples == 1` is supported. Returns a tensor shaped like
        `at`, holding the accepted proposals and, where the move was
        rejected, the current states.
        """
        p, q = self.target, self.proposal

        assert n_samples == 1
        prop, curr = q.sample(n_samples, at=at), at
        # NOTE: do not unsqueeze `curr` here, it would add a dimension
        #  on every transition (and produce branching paths).

        # \log \pi(prop) q(curr \vert prop)
        log_alpha = q.log_prob(curr, at=prop) + p.log_prob(prop)

        # - \log \pi(curr) q(prop \vert curr)
        log_alpha -= q.log_prob(prop, at=curr) + p.log_prob(curr)

        alpha = torch.exp(torch.clamp(log_alpha, max=0))

        # `rand` draws from [0, 1), so the strict inequality accepts with
        # probability exactly `alpha` and never accepts a move with
        # `alpha == 0`; this also matches `HamiltonianMHProposal`
        accept = torch.rand_like(log_alpha) < alpha

        self.alpha_, self.log_alpha_ = alpha, log_alpha

        return torch.where(accept.unsqueeze(-1), prop, curr)


From [Betancourt (2017) p.11](https://arxiv.org/abs/1701.02434)
* **(first phase)** MC travels towards the typical set. MC-based
estimators have strong bias (*burn-in*)

* **(second phase)** MC `persists through the first sojourn across the typical set`.
Accuracy of estimators improves as the bias from burn-in dampens

* **(third phase)** MC continues and gradually refines its exploration
of the typical set.

Use a random walk transition kernel

In [None]:
from  torch.distributions import MultivariateNormal

# Random-walk Metropolis: standard 2-d Gaussian increments.
proposal = RandomWalkProposal(MultivariateNormal(torch.zeros(2), torch.eye(2)))
rwmh = MetropolisHastingsProposal(proposal, Target(log_density))

# Simulate 15 chains (default number of steps) and plot them.
plot_chains(sample_chain(rwmh, n_chains=15), log_density)

In [None]:
# Metropolis-adjusted Langevin: standard Gaussian noise, gradient-driven drift.
gauss = MultivariateNormal(torch.zeros(2), torch.eye(2))
# alternative setting: a larger step with tighter gradient clipping
# ldp = LangevinProposal(gauss, Target(log_density), delta=1e-1, clip_grad=1e1)
ldp = LangevinProposal(gauss, Target(log_density), delta=1e-2, clip_grad=5e2)
ldmh = MetropolisHastingsProposal(ldp, Target(log_density))

# Simulate 15 chains of 501 transitions each and plot them.
plot_chains(sample_chain(ldmh, n_steps=501, n_chains=15), log_density)

From [Betancourt (2017) pp. 16-19](https://arxiv.org/abs/1701.02434)

Although MH with Random Walk proposal (aka Random Walk Metropolis, RWMH) is simple
and intuitively clear, it suffers dramatically from the curse of dimensionality
and the complexity of the target distribution.

> ... the volume exterior to the typical set overwhelms the interior volume
and almost every RWMH chain gets stuck outside of the typical set towards the
tails, due to low acceptance rate, induced by negligible densities. ... In the
worst case RWMH won't even complete a single sojourn through the typical set.

The idea behind Hamiltonian MC is to use first order information,
$\nabla_\theta \log \pi(\theta)$, about the target distribution $\pi(\theta)$
to make informed moves towards the typical set. However, by itself
the gradient pulls towards a mode of $\pi(\theta)$ and would make
the chain collapse in it, which is not the typical set (what is it then?)

So in HMC we consider $\pi(\theta, m) = \pi(\theta) p(m \vert \theta)$
where we have introduced an auxiliary random variable, the momentum $m$. This
`lifts the target distribution onto a joint probability distribution on
phase space`. If the momentum is marginalized, then the original target
density is recovered, which means that during sampling we can simply discard
$m$ when requesting a sample of $\theta$.

Consider $\pi(\theta, m) = \exp\bigl\{ - H(\theta, m) \bigr\}$. Then for a
trajectory $t\mapsto (\theta_t, m_t)$ to stay on the level set of $H$
we must have $\tfrac{d}{d t} H(\theta_t, m_t) = 0$, i.e.

$$
dH = \nabla_\theta^\top H(\theta_t, m_t) \dot{\theta}_t dt
    + \nabla_m^\top H(\theta_t, m_t) \dot{m}_t dt
    = 0
    \,, $$

which is satisfied when
$$
\dot{\theta}_t = \nabla_m H(\theta_t, m_t)
    \,,\,
    \dot{m}_t = - \nabla_\theta H(\theta_t, m_t)
    \,. $$

Typically the Hamiltonian is decomposed as

$$
H(\theta, m)
    = - \log \pi(\theta) p(m \vert \theta)
    = \underbrace{- \log \pi(\theta)}_{\text{potential}}
    + \bigl(
        \underbrace{- \log p(m \vert \theta)}_{\text{kinetic}}
    \bigr)
    \,. $$

From [Betancourt (2017) pp. 27](https://arxiv.org/abs/1701.02434)

The choice of the kinetic energy term is what determines the interaction of the Chain with the target.
Herein lies the scope of HMC design.

The randomness in the HMC scheme comes from the randomness of the `momentum lift`,
whereas the traversal of the level set of $H$ is deterministic.

Since the HMC trajectory journeys along the level set in the phase-space,
it is natural to factorize the canonical distribution $\pi(\theta, m)$,
by foliating into concentric level sets:

$$
\pi(\theta, m)
    = \pi(x_E \vert H(x_E) = E)
    \, \pi(H(x_E) = E)
    \,. $$

Lifts determine jumps between the energy levels, while each trajectory explores
the corresponding level set $\{H(\theta, m) = E\}$. Thus we get two phases

* deterministic traversal of energy level sets (how long we integrate)
* stochastic exploration between level sets (how quickly jumps `diffuse across energies typical to the energy marginal distribution`)

<br>

The following kinetic energy is called Euclidean-Gaussian if $G(\theta) = \Sigma$ is constant,
and Riemannian-Gaussian if $G(\theta)$ varies with the position $\theta$.

* we can optimize $\Sigma$ using the extended burn-in phase
[Betancourt (2017) pp. 31](https://arxiv.org/abs/1701.02434)

* if $G(\theta)$ resembles the Hessian of the target, i.e. its Fisher info-matrix,
then the energy level exploration would be uniform and efficient

<br>

Choosing the integration time: if we integrate for a short period of time,
we risk having insufficient diversity on the samples as they would tend to
clump together. On the other hand, long integration times can degrade exploration
in case when the level sets are `topologically` compact (is there any other
notion of compactness?).

Dynamic Ergodicity for orbits $\phi = \{(\theta_t, m_t)\colon t \geq 0\}$ states
that a uniform temporal sample from a trajectory resembles a uniform spatial sample
from $\phi$.

<br>

Symplectic integrators [Betancourt (2017) pp. 36](https://arxiv.org/abs/1701.02434)

### Let's try to integrate something

The Cauchy problem for ODE $$
\dot{x}
    = A x\,, x(0) = x_0
\,. $$

In [None]:
def euler(x, f, eps=1e-3):
    """One explicit Euler step of size `eps` for the ODE `dx/dt = f(x)`."""
    slope = f(x)
    return x + slope * eps

In [None]:
# Linear vector field x -> x A for a batch of row vectors `x`.
# NOTE: this first definition is immediately overwritten by the one below,
#  so only the second matrix is in effect after the cell has run.
grad = lambda x: np.dot(x, np.array([
    [-1., +2.],
    [-10., -1.],
]))

# Skew-symmetric matrix: a pure rotation field (x, y) -> (-y, x).
grad = lambda x: np.dot(x, np.array([
    [0, +1.],
    [-1., 0],
]))

In [None]:
# Integrate 5 random initial points for 100 explicit Euler steps of size 0.1.
paths = [np.random.randn(5, 2)]
for _ in range(100):
    paths.append(euler(paths[-1], grad, eps=1e-1))

# Stack the visited states: the result is indexed as (path, step, coordinate).
paths = np.stack(paths, axis=-2)

In [None]:
# Draw every integrated path in its own shade of red.
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111)

colours = plt.cm.Reds(np.linspace(0.5, 1, num=len(paths)))
for i, col in enumerate(colours):
    plot_path(paths[i], color=col, alpha=1)

In [None]:
# Regular 101 x 101 grid over [-5, 5] x [-5, 5] ...
np_mesh = np.meshgrid(
    np.linspace(-5, +5, num=101),
    np.linspace(-5, +5, num=101),
)
# ... flattened into one row of two coordinates per grid node.
np_marg = np.stack(np_mesh, axis=-1).reshape(-1, 2)

In [None]:
# Evaluate the vector field at every grid node and restore the grid layout.
uv = grad(np_marg).reshape(*np_mesh[0].shape, -1)

In [None]:
# Quiver plot of the linear vector field over the grid.
# NOTE(review): the title is the same as on the density plots although this
#  figure shows the ODE's vector field; looks copy-pasted, confirm.
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, xlabel=r"$\theta_1$", ylabel=r"$\theta_2$",
                    title="2x'blob' density")

ax.quiver(*np_mesh, uv[..., 0], uv[..., 1], pivot='mid',
          color="fuchsia", scale=2000., alpha=0.5)


plt.show()

<br>

### For Hamiltonian Monte Carlo we need a Hamiltonian

Let's introduce an auxiliary variable $q \approx \dot{\theta}$ with
$p(q\vert \theta) = \mathcal{N}(q \vert 0, G(\theta))$. Then the
joint density of $(\theta, q)$

$$
p(\theta, q)
    = p(\theta) p(q\vert \theta)
    \propto \exp{(- H(\theta, q))}
    \,. $$

$$
H(\theta, q)
    = - \ell(\theta)
    + \tfrac12\log\det 2 \pi G(\theta)
    + \tfrac12 q^\top G(\theta)^{-1} q
    \,, $$
with $\ell(\theta) = \log p(\theta)$.

Hamiltonian dynamics in the case of $q$ behaving like a momentum
of $\theta$ is
$$
\begin{align}
    \dot{\theta}
        =& \partial_q H(\theta, q)
        \,, \\
   - \dot{q}
        =& \partial_\theta H(\theta, q)
        \,,
\end{align}
$$

Now the derivative of $\log-\det$ is

$$
\partial_\theta \log \det G(\theta)
    = - \partial_\theta \log \det G(\theta)^{-1}
    = - \mathrm{tr}
        \tfrac1{\det G(\theta)^{-1}} (\det G(\theta)^{-1})
            G(\theta)^\top \partial_\theta G(\theta)^{-1}
    = - \mathrm{tr} G(\theta) \partial_\theta G(\theta)^{-1}
    \,, $$

and the second term is

$$
\partial_\theta \tfrac12 q^\top G(\theta)^{-1} q
    = \tfrac12 \mathrm{tr} q q^\top \partial_\theta G(\theta)^{-1}
    \,. $$

And $
\partial_\theta G(\theta)^{-1}
    = - G(\theta)^{-1}
        \bigl( \partial_\theta G(\theta) \bigr)
    G(\theta)^{-1}
$

The dynamics is thus

$$
\begin{align}
    \dot{\theta}
        =& G(\theta)^{-1} q
        \,, \\
    - \dot{q}
%         =& - \partial_\theta \ell(\theta)
%            + \tfrac12 \mathrm{tr} \bigl(
%                q q^\top - G(\theta)
%            \bigr) \partial_\theta G(\theta)^{-1}
        =& - \partial_\theta \ell(\theta)
           - \tfrac12 \mathrm{tr} G(\theta)^{-1} \bigl(
               q q^\top - G(\theta)
           \bigr) G(\theta)^{-1}
           \partial_\theta G(\theta)
        \,,
\end{align}
$$

For $G(\theta) = \Sigma$ we have

$$
\dot{\theta} = \Sigma^{-1} q
    \,, \dot{q} = \partial_\theta \ell(\theta)
    \,, $$

<br>

$$
\ddot{x} = F(x)
    \Leftrightarrow
    \dot{x} = v
    \,, \dot{v} = F(x)
\,. $$

$$
\begin{pmatrix}
    \dot{q} \\ \dot{p}
\end{pmatrix}
    = \begin{pmatrix}
        0 & \partial_p \\
        - \partial_q & 0
    \end{pmatrix}
    H(q, p)
    \,, $$

For Euclidean-Gaussian kinetic energy we have

$$
H(\theta, q)
    = - \log p(\theta)
    + \tfrac12\log\det 2 \pi \Sigma
    + \tfrac12 q^\top \Sigma^{-1} q
    \,. $$

In [None]:
def hamiltonian(x, q):
    """Hamiltonian with potential `-log_density(x)` and unit-mass kinetic term.

    The kinetic energy is half the squared ell-2 norm of the momentum `q`
    along its last dimension; the constant log-determinant term is dropped.
    """
    potential = - log_density(x)
    kinetic = 0.5 * torch.norm(q, p=2, dim=-1)**2
    return potential + kinetic

Hence the dynamics obeys

$$
\dot{\theta} = \Sigma^{-1} q
    \,,\, \dot{q} = - \nabla_\theta H(\theta, q) = \nabla_\theta \log p(\theta)
    \,. $$

For $\tau = s t$, $s = \pm 1$ we have:

$$
\frac{d x}{d \tau}
    = \frac{d x}{d t} \frac{d t}{d \tau}
    % = \frac{d x}{d t} \frac1s
    = s \frac{d x}{d t}
    \,, $$

whence the dynamics becomes

$$
\dot{\theta} = \Sigma^{-1} s q
    \,,\, \dot{q} = s \nabla_\theta \log p(\theta)
    \,. $$

In [None]:
def integrate_(H, x, q, *, epsilon=1e-3, M=None, inplace=False, clip_grad=0):
    r"""Leap-frog integrator for \dot{x} = M q, \dot{q} = - \nabla_x H(x, q).

    This is an endless generator: every `next()` advances the state by one
    leap-frog step of size `epsilon` and yields `(x, -q)`. The momentum is
    yielded with its sign flipped, which is the form needed for a
    time-reversible proposal; internally the integration carries on with
    the unflipped momentum.

    Parameters
    ----------
    H : callable
        Hamiltonian, called as `H(x, q=q)`; it is differentiated with
        respect to `x` by `batch_grad`.
    x, q : torch.Tensor
        Initial position and momentum.
    epsilon : float, default 1e-3
        Step size.
    M : torch.Tensor, optional
        If given, the position moves along `torch.matmul(q, M)` instead
        of `q`.
    inplace : bool, default False
        If true, `x` and `q` are overwritten in place; otherwise the
        integrator works on copies and the arguments are left intact.
    clip_grad : float, default 0
        If positive, the ell-2 norm of every gradient of `H` is clipped
        to this value along the last dimension.

    Yields
    ------
    (x, -q) : tuple of torch.Tensor
        State after each completed step.
    """
    # (todo) add stopped process t \wedge n_j
    # (todo) understand what NUTS does and how it does it
    # (todo) investigate rare NANs
    # (todo) add time direction
    add = torch.Tensor.add_ if inplace else torch.Tensor.add
    sub = torch.Tensor.sub_ if inplace else torch.Tensor.sub
    if not inplace:
        x, q = x.clone(), q.clone()

    def force(x, q):
        # (optionally clipped) \nabla_x H(x, q)
        grad = batch_grad(x, H, q=q)
        if clip_grad > 0:
            grad = clip_norm(grad, clip_grad, p=2, dim=-1)
        return grad

    # the very first gradient is clipped as well: both half-kicks of a
    # step must use the same force field, or the step is not reversible
    grad = force(x, q)
    while True:
        # q_{\tfrac12} = q_0 - \nabla_x H(x_0) \tfrac\epsilon2
        q = sub(q, grad * epsilon / 2)

        # x_1 = x_0 + M q_{\tfrac12} \epsilon
        x = add(x, epsilon * (q if M is None else torch.matmul(q, M)))

        # q_1 = q_{\tfrac12} - \nabla_x H(x_1) \tfrac\epsilon2
        # (this gradient is reused for the first half-kick of the next step)
        grad = force(x, q)
        q = sub(q, grad * epsilon / 2)

        # preemptively flip the sign of the momentum to ease time-reversibility
        yield x, -q

<br>

The Hamiltonian Metropolis-Hastings proposal is the same as MH
but with variates in phase-space and the acceptance probability

$$
\alpha(y\vert x)
    = \min\Bigl\{
        1, \frac{
            p(\theta_y, -q_y)
        }{
            p(\theta_x, q_x)
        } \frac{
            h_T((\theta_x, q_x) \vert (\theta_y, -q_y))
        }{
            h_T((\theta_y, -q_y) \vert (\theta_x, q_x))
        }
    \Bigr\}
    \,, $$

where $
h_T((\theta, q) \vert (\theta_0, q_0))
    = \delta_{\theta - \theta_T} \delta_{q - (-q_T)}
$
and $(\theta_T, q_T)$ is the integral of the Hamiltonian ODE at $T$

$$
\begin{align}
    \dot{\theta}
        =& \partial_q H(\theta, q)
        \,, \\
   \dot{q}
        =& - \partial_\theta H(\theta, q)
        \,,
\end{align}
$$

with initial conditions $(\theta_0, q_0)$. Thus

$$
\alpha((\theta_T, -q_T)\vert (\theta_0, q_0))
    = \alpha_T(\theta_0, q_0)
    = \min\Bigl\{
        1, \frac{
            p(\theta_T, -q_T)
        }{
            p(\theta_0, q_0)
        } \frac{
            h_T((\theta_0, q_0) \vert (\theta_T, -q_T))
        }{
            h_T((\theta_T, -q_T) \vert (\theta_0, q_0))
        }
    \Bigr\}
%     = \min\Bigl\{
%         1, \frac{
%             \exp{-H(\theta_T, -q_T)}
%         }{
%             \exp{-H(\theta_0, q_0)}
%         }
    = \min\Bigl\{
        1, \exp{\{H(\theta_0, q_0) - H(\theta_T, -q_T)\}}
    \Bigr\}
    \,. $$

The negation is needed to make sure that the momentum is flipped and,
since hamiltonian dynamics is deterministic and time-reversible, that
when integrating back in time from $(\theta_T, -q_T)$ we recover $(\theta_0, q_0)$.

In [None]:
class HamiltonianMHProposal(Proposal):
    """Hamiltonian Monte Carlo transition with a Metropolis correction.

    Parameters
    ----------
    proposal : Proposal
        Distribution of the momentum; provides `sample(n, at=...)` and
        `log_prob(q, at=...)`.
    target : Proposal
        Provides `log_prob(x)`, the (possibly unnormalized) target
        log-density.
    time : float, default 2e-2
        Total integration time of one trajectory.
    delta : float, default 1e-3
        Leap-frog step size; a trajectory takes `ceil(time / delta)` steps.
    clip_grad : float, default 0
        Passed to `integrate_` as its gradient clipping threshold.

    Attributes
    ----------
    alpha_, log_alpha_ : torch.Tensor
        Acceptance probabilities of the most recent `sample` call and the
        unclipped log-ratios they were computed from.
    """

    def __init__(self, proposal, target, time=2e-2, delta=1e-3, clip_grad=0):
        self.target, self.proposal = target, proposal
        self.delta, self.time = delta, time
        self.clip_grad = clip_grad

    def log_prob(self, x, q, *, at=None):
        """Joint log-density of the position `x` and the momentum `q`."""
        return self.target.log_prob(x) + self.proposal.log_prob(q, at=x)

    def hamiltonian(self, x, q, *, at=None):
        """Negative joint log-density of `(x, q)`."""
        return - self.log_prob(x, q, at=at)

    def sample(self, n_samples=1, *, at):
        """Make one HMC transition from every state in `at`.

        Only `n_samples == 1` is supported. Returns a tensor shaped like
        `at`, holding the end points of the accepted trajectories and,
        where the move was rejected, the current states.
        """
        assert n_samples == 1
        # sample the initial momentum q (from kinetic)
        curr = at, self.proposal.sample(n_samples, at=at)

        # number of leap-frog steps: ceil(time / delta)
        n_steps = -int(-self.time // self.delta)

        # integrate \dot{x} = q, \dot{q} = -\nabla_x H(x, q)
        #  from x_0, q_0 until T flipping the sign
        integrator = integrate_(self.hamiltonian, *curr,
                                epsilon=self.delta, inplace=False,
                                clip_grad=self.clip_grad)

        # with no steps to make the proposal is the current state itself;
        #  without this default `prop` would be unbound below
        prop = curr
        for _, prop in zip(range(n_steps), integrator):
            # consume the integrator until the required number of steps is done
            pass
        del integrator

        # \log q(curr \vert prop) \pi(prop) - \log q(prop \vert curr) \pi(curr)
        log_alpha = self.log_prob(*prop) - self.log_prob(*curr)
        alpha = torch.exp(torch.clamp(log_alpha, max=0))
        accept = torch.rand_like(log_alpha) < alpha

        self.alpha_, self.log_alpha_ = alpha, log_alpha

        return torch.where(accept.unsqueeze(-1), prop[0], curr[0])

<br>

Let's integrate some!

In [None]:
# 255 random initial positions with random momenta.
x = torch.randn(255, 2)
q = torch.randn_like(x)
# total integration time and leap-frog step size
until, epsilon = 0.5, 5e-3

# Record the whole trajectory: ceil(until / epsilon) leap-frog steps.
int_ = integrate_(hamiltonian, x, q, epsilon=epsilon, inplace=False)
path = [(x, q)]
path.extend((x, q) for _, (x, q) in zip(tqdm.trange(-int(-until // epsilon)), int_))

In [None]:
# Split the recorded (position, momentum) pairs and stack the positions
# along a new step axis (second to last).
x, q = zip(*path)
paths = torch.stack(x, dim=-2)

In [None]:
# Inflated bounding box of all visited positions ...
ll, ur = get_rect(paths.flatten(0, -2))

# ... covered by a 201 x 201 grid, flattened into a batch of 2-d points.
mesh = torch.meshgrid(*map(torch.linspace, ll, ur, [201, 201]))
marg = torch.stack(mesh, dim=-1).flatten(0, -2)

# Unnormalized density at every grid node.
z = torch.exp(log_density(marg))

In [None]:
# Density contours with the tail of every trajectory drawn on top.
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111)

ax.contourf(*mesh, z.reshape_as(mesh[0]), levels=51, cmap=plt.cm.terrain)

colours = plt.cm.Reds(np.linspace(0.5, 1, num=len(paths)))
for i, col in enumerate(colours):
    # only the last 100 recorded states of each trajectory are shown
    plot_path(paths[i][-100:], color=col, alpha=0.5)

Simulate paths

In [None]:
from  torch.distributions import MultivariateNormal

# HMC with standard Gaussian momenta: trajectories of length 0.2 integrated
# with step 0.01 and gradients clipped at norm 100.
idp = IndependentProposal(MultivariateNormal(torch.zeros(2), torch.eye(2)))
hmh = HamiltonianMHProposal(idp, Target(log_density), time=2e-1, delta=1e-2, clip_grad=100)

# Simulate 15 chains (default number of steps) and plot them.
plot_chains(sample_chain(hmh, n_chains=15), log_density)

<br>

In [None]:
# Always fails; presumably a guard that stops a top-to-bottom run of the
# notebook before the scratch cells below.
assert False

In [None]:
# run the sampler
from pyhmc import hmc

In [None]:
# !pip install git+https://github.com/rmcgibbo/pyhmc

In [None]:
def logprob(x):
    """Log-density and its gradient at the numpy point(s) `x`, as numpy arrays.

    Returns the pair `(log_density(x), gradient)`.
    """
    point = torch.from_numpy(x)

    value = log_density(point)
    slope = batch_grad(point, log_density)

    return value.numpy(), slope.numpy()

In [None]:
from pyhmc import hmc

In [None]:
# Run the external `pyhmc` sampler from a random 2-d starting point.
samples = hmc(logprob, x0=np.random.randn(2), args=(), n_samples=10000)

In [None]:
# Add a leading axis so that the samples are plotted as a single chain.
plot_chains(torch.from_numpy(samples[np.newaxis]), log_density)

<br>

In [None]:
# Scratch: the finite entries of the first path (flattened by the boolean mask).
paths[0][torch.isfinite(paths[0])]

In [None]:
# Scratch: count the non-finite entries over the last two dimensions and
# show the paths that have at least one.
mask = (~torch.isfinite(paths)).sum(dim=-1).sum(dim=-1)
paths[mask > 0]

<br>

In [None]:
def leap_frog(nabla, q, p, grad=None, eps=0.01):
    r"""One in-place leap-frog step for \dot{q} = p, \dot{p} = nabla(q).

    The momentum receives two half-kicks of `+ eps / 2 * nabla(q)` around
    a full position drift of `eps * p`. Both `q` and `p` are overwritten
    in place.

    NOTE(review): the kicks *add* `nabla(q)`, so `nabla` has to return the
    force, i.e. minus the gradient of the potential; confirm with callers.

    Parameters
    ----------
    nabla : callable
        Maps the position to the force acting on the momentum.
    q, p : torch.Tensor
        Position and momentum; modified in place.
    grad : torch.Tensor, optional
        Precomputed `nabla(q)` at the initial position, e.g. the one
        returned by the previous step.
    eps : float, default 0.01
        Step size.

    Returns
    -------
    (p, q, grad)
        The updated momentum and position, and `nabla` at the new position.
    """
    if grad is None:
        grad = nabla(q)

    # first half-kick: p_{\tfrac12} = p_0 + \tfrac\epsilon2 nabla(q_0)
    p.add_(eps * grad / 2)

    # drift: q_1 = q_0 + \epsilon p_{\tfrac12}
    q.add_(eps * p)

    # second half-kick: p_1 = p_{\tfrac12} + \tfrac\epsilon2 nabla(q_1)
    grad = nabla(q)
    p.add_(eps * grad / 2)

    return p, q, grad


In [None]:
# Scratch: three 2-d points, presumably the centres of a mixture target.
mu = torch.tensor([
    [-3., -3.],
    [+3., -3.],
    [+3., +3.],
])

In [None]:
# NOTE(review): unfinished scratch cell; `torch.tensor[]` is a syntax error.
cov = torch.tensor[]

In [None]:
# NOTE(review): called without arguments; looks like a placeholder, confirm.
torch.triangular_solve()

In [None]:
def log_prob(value, mu, loc):
    """Unimplemented placeholder; always returns `None`."""
    return None

In [None]:
# NOTE(review): `Normal` is not imported anywhere in the visible cells; confirm.
a = Normal(mu, 1.)

In [None]:
# NOTE(review): `target` is not defined anywhere in the visible cells; confirm.
target.rsample((10,)).mean(0)

In [None]:
# IPython-only `??` introspection; this is not valid plain Python.
target.log_prob??

In [None]:
# Scratch: inspect the `scale` attribute of `target`.
target.scale

In [None]:
# Scratch: evaluate `target.log_prob` at a differentiable copy of the first centre.
q = mu[0].clone().requires_grad_(True)
target.log_prob(q).mean(0)

In [None]:
# Hand-rolled leap-frog integration of 15 random initial states.
x = torch.randn(15, 2)
q = torch.randn_like(x)
eps = 5e-2

# applies the one-step leap-frog integrator for \ddot{x} = \nabla_x \ell(x)
r"""Leap-frog integrator for \dot{x} = q, \dot{q} = \nabla_x \ell(x)"""

paths = [x]
grad = batch_grad(x, log_density)
for _ in range(2500):
    # (todo) add stopped process t \wedge n_j
    # (todo) understand what NUTS does and how it does it
    # (todo) investigate rare NANs

    # leap-frog integrator of \ddot{x} = f(x)
    # first half-kick: q_{\tfrac12} = q_0 + \nabla_x \ell(x_0) \tfrac\epsilon2
    q_half = q + grad * eps / 2

    # drift: x_1 = x_0 + q_{\tfrac12} \epsilon
    x = x + q_half * eps

    # second half-kick: q_1 = q_{\tfrac12} + \nabla_x \ell(x_1) \tfrac\epsilon2
    # (the gradient is reused for the first half-kick of the next step)
    grad = batch_grad(x, log_density)
    q = q_half + grad * eps / 2

    paths.append(x)

# Stack the visited positions along a new step axis (second to last).
paths = torch.stack(paths, dim=-2)

In [None]:
# total integration time and leap-frog step size for the run below
until, epsilon = 1., 1e-3

In [None]:
# 15 random initial positions with random momenta.
x = torch.randn(15, 2)
q = torch.randn_like(x)

In [None]:
# Scratch: display the current position and momentum.
x, q

In [None]:
# Advance `x` and `q` in place by ceil(until / epsilon) leap-frog steps;
# the yielded states are discarded.
integrator = integrate_(hamiltonian, x, q, epsilon=epsilon, inplace=True)
for _ in zip(range(-int(-until // epsilon)), integrator):
    pass

In [None]:
# Simulate 15 Langevin-MH chains with the default number of steps.
paths = sample_chain(ldmh, n_chains=15)

In [None]:
# Scratch: display the first chain.
paths[0]

In [None]:
# Scratch: gradient of the target log-density along the first chain.
# NOTE(review): `clip_grad` is assigned but not used in the visible cells.
clip_grad = 10.
grad = batch_grad(paths[0], ldmh.target.log_prob)