# MCMC from scratch 

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

import numpy as np

import torch
import torch.nn.functional as F

Create a multi-modal density: we can evaluate it, know its structure,
but not the normalizing constant.

$$
p(\theta)
    = \tfrac1Z \mathop{\mathrm{exp}} \bigl(
        \log f(\theta)
    \bigr)
    \,.
$$


Let's make a banana distribution:

$$
p(x)
    \propto p_{\mathcal{N}(0, 1)} \circ \phi(x)
    \,, $$

where $\phi$ is based on the
[Banana](https://en.wikipedia.org/wiki/Rosenbrock_function)
function and, as implemented below, is given by $
\phi
\colon \mathbb{R}^2 \to \mathbb{R}^2
\colon (x, y) \mapsto (a (x - y^2), b (y - x^2))
$.

In [None]:
def log_banana_base(x, a=0.75, b=1.05):
    r"""Unnormalised log-density of a standard Gaussian pulled back through a warp.

    The last axis of `x` is indexed at positions 0 and 1, giving the two
    coordinates $(x_0, x_1)$. They are mapped to

        \phi(x) = (a (x_0 - x_1^2), b (x_1 - x_0^2))

    and the function returns $-\tfrac12 \|\phi(x)\|_2^2$, reduced over the
    last axis only, so every leading index is evaluated independently.
    """
    # An earlier variant of the warp was (a - x_0, b (x_1 - x_0^2)) with
    # a=1.75, b=5; the symmetric form below is the one in use.
    first, second = x[..., 0], x[..., 1]
    warped = torch.stack(
        [a * (first - second**2), b * (second - first**2)],
        dim=-1,
    )
    return -0.5 * torch.norm(warped, p=2, keepdim=False, dim=-1)**2

In [None]:
def log_density(x):
    """Unnormalised log-density of a two-component mixture.

    Each component is `log_banana_base` evaluated at `s * (x - m)`: the input
    is shifted by the component centre `m` ((2, 2) and (-2, -2)) and, for the
    second component, reflected through the origin (`s = -1`). The components
    are combined with a log-sum-exp over the component axis, so the result
    has the shape of `x` without its last axis.
    """
    components = [
        log_banana_base(sign * (x - torch.tensor(centre)), a, b)
        for centre, a, b, sign in zip(
            [(2., 2.), (-2., -2.)],
            [+0.75, -0.75],
            [+3.05, -1.05],
            [+1, -1],
        )
    ]
    return torch.logsumexp(torch.stack(components, dim=0), dim=0)

And a plot of it

In [None]:
# Regular 101 x 101 grid over [-6, 6]^2 (the same axis is used twice).
mesh = torch.meshgrid([torch.linspace(-6, +6, 101) for _ in range(2)])

# One row per grid node: the coordinate arrays are stacked on a trailing
# axis and the grid axes are collapsed.
marg = torch.stack(mesh, dim=-1).reshape(-1, len(mesh))

# Unnormalised density value at every grid node.
z = log_density(marg).exp()

In [None]:
# Filled contour plot of the unnormalised density on the evaluation grid.
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111)
ax.set(xlabel=r"$\theta_1$", ylabel=r"$\theta_2$", title="2x'blob' density")

# `z` is flat (one value per grid node), so fold it back onto the grid.
ax.contourf(mesh[0], mesh[1], z.reshape(mesh[0].shape),
            levels=51, cmap=plt.cm.terrain)

plt.show()

Plot the gradient field

In [None]:
# Differentiable copy of the grid nodes.
theta = marg.clone()
theta.requires_grad_(True)

# Back-propagate the mean log-density; the result lands in `theta.grad`,
# with one gradient row per grid node (scaled by the averaging).
log_density(theta).mean().backward()

# Fold the per-node gradients back onto the grid, components on the last axis.
dz = theta.grad.reshape(*mesh[0].shape, theta.shape[-1])

In [None]:
# Density contours with the gradient field of the log-density drawn on top.
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111)
ax.set(xlabel=r"$\theta_1$", ylabel=r"$\theta_2$", title="2x'blob' density")

ax.contourf(mesh[0], mesh[1], z.reshape(mesh[0].shape),
            levels=51, cmap=plt.cm.terrain)

# Constant guard kept as a manual switch for the arrow overlay.
if True:
    ax.quiver(mesh[0], mesh[1], dz[..., 0], dz[..., 1],
              pivot='mid', color="fuchsia", scale=.5, alpha=0.5)

plt.show()

Now let's create some samplers!

In [None]:
class Proposal:
    """Common interface for the samplers and densities in this notebook.

    Both methods are abstract here: they raise `NotImplementedError` until a
    subclass overrides them.
    """

    def sample(self, x=None, n_samples=1):
        """Draw new points; subclasses in this notebook start from the state `x`."""
        raise NotImplementedError

    def log_prob(self, x, at=None):
        """Log-density of `x`; subclasses may condition it on the state `at`."""
        raise NotImplementedError

#### Notation

Consider a measurable space $(\Omega, \mathcal{F}, \mu)$.
The Markov Chain sampler needs:
* *(proposal)* the transition `kernel`
$P \colon \Omega \times \mathcal{F} \to [0, 1]$

In MCMC we typically consider kernels over the Lebesgue carrier measure
$\mu = d{x}$ and defined by

$$
P(x, d\omega)
    = q(\omega \vert x) \mu(d\omega) + r(x) \delta_x(d\omega)
    \,, $$

where $Q(d\omega \vert x) = q(\omega \vert x) \mu(d\omega)$ is a nonnegative measure
with $Q(\Omega \vert x) \leq 1$ and $q(x \vert x) = 0$.

What does this notation mean?

Well, $\delta_x(\cdot)$ defines a probability measure on the
measurable space $(\Omega, \mathcal{F})$ according to

$$
\delta_x
\colon \mathcal{F} \to [0, 1]
\colon A \mapsto 1_A(x)
    \,. $$

This notation is a shorthand (as with SDEs) for

$$
P(x, dy) = q(y \vert x) dy + r(x) \delta_x(dy)
    \Leftrightarrow
    P(x, A)
        = \int_A \bigl( q(y \vert x) dy + r(x) \delta_x(dy) \bigr)
        = Q(A \vert x) + r(x) 1_A(x)
    \,. $$

This implies that for $P(x, \cdot)$ to be a probability measure
we need $r(x) = 1 - Q(\Omega \vert x)$.

#### Did I really need this proof? NO!
$\delta_x(\emptyset) = 1_\emptyset(x) = 0$, and
$$
\delta_x\bigl(\biguplus_{n\geq 1} A_n\bigr)
    = 1_{\uplus_{n\geq 1} A_n}(x)
    = \begin{cases}
    1 & \exists{n\geq1}\, x \in A_n\\
    0 & \text{otherwise}
    \end{cases}
    = \sum_{n\geq 1} 1_{A_n}(x)
    = \sum_{n\geq 1} \delta_x(A_n)
    \,. $$

Via the monotone convergence theorem (MCT) this implies that
$\int f(\omega) \delta_x(d\omega) = f(x)$.

In [None]:
class Target(Proposal):
    """Adapter that exposes a plain log-density callable through `log_prob`."""

    def __init__(self, fn):
        # Callable applied to the query points in `log_prob`.
        self.fn = fn

    def log_prob(self, x, at=None):
        """Return `fn(x)`; `at` is accepted for interface parity and ignored."""
        log_p = self.fn(x)
        return log_p

In [None]:
class RandomWalkProposal(Proposal):
    """Random-walk kernel: the current point plus noise drawn from `distribution`."""

    def __init__(self, distribution):
        # Distribution of the increments; must provide `sample` and `log_prob`.
        self.distribution = distribution

    def log_prob(self, x, at):
        """Log-density of landing on `x` when starting at `at`.

        This is the increment distribution evaluated at `x - at`.
        """
        increment = x - at
        return self.distribution.log_prob(increment)

    def sample(self, x, n_samples=1):
        """Propose `n_samples` moves from every point in `x`.

        The last axis of `x` is the feature axis and all leading axes are
        treated as batch axes. Noise is drawn with sample shape
        `(*batch, n_samples)` and added to `x`, which is broadcast over the
        new sample axis (second to last). That axis is squeezed away again
        when `n_samples == 1`.
        """
        *batch_shape, _ = x.shape

        noise = self.distribution.sample((*batch_shape, n_samples))
        proposals = x.unsqueeze(-2) + noise
        return proposals.squeeze(-2)

In [None]:
class MetropolisHastingsProposal(Proposal):
    """Metropolis-Hastings transition kernel built from a proposal and a target.

    `proposal` must provide `sample(x, n)` and `log_prob(x, at)`, the latter
    being the log-density of moving to `x` from `at`. `target` must provide
    `log_prob(x)`, which may be unnormalised.
    """

    def __init__(self, proposal, target):
        self.target, self.proposal = target, proposal

    def sample(self, x, n_samples=1):
        """Advance every chain in `x` by one Metropolis-Hastings step.

        A candidate is drawn from the proposal for each row of `x` and
        accepted with probability

            alpha = min(1, pi(x') q(x | x') / (pi(x) q(x' | x)))

        Rejected chains keep their current state. `n_samples` is accepted for
        interface compatibility only: exactly one step is taken per call.

        Returns a tensor shaped like `x` holding the next state of each chain.
        """
        p, q = self.target, self.proposal

        x_next = q.sample(x, 1)

        # Hastings ratio in log space. `q.log_prob(a, b)` is log q(a | b), so
        # the reverse move q(x | x') goes in the numerator and the forward
        # move q(x' | x) in the denominator.
        log_alpha = p.log_prob(x_next) + q.log_prob(x, x_next)
        log_alpha = log_alpha - (p.log_prob(x) + q.log_prob(x_next, x))

        # Clamping at zero before `exp` implements min(1, .) without overflow.
        alpha = torch.exp(torch.clamp(log_alpha, max=0))

        accept = torch.rand_like(log_alpha) < alpha

        # Broadcast the per-chain decision over the feature axis.
        return torch.where(accept.unsqueeze(-1), x_next, x)

Use a random walk transition kernel

In [None]:
from  torch.distributions import MultivariateNormal

# Standard 2-d Gaussian increments (zero mean, identity covariance).
gauss = MultivariateNormal(loc=torch.zeros(2), covariance_matrix=torch.eye(2))

# Random-walk proposal driven by those increments.
rwp = RandomWalkProposal(distribution=gauss)

In [None]:
# Metropolis-Hastings kernel: random-walk proposals tested against the mixture density.
mhp = MetropolisHastingsProposal(proposal=rwp, target=Target(fn=log_density))

A path plotter using quiver.

In [None]:
def plot_path(path, ax=None, **kwargs):
    """Draw a trajectory as arrows joining consecutive points.

    `path` is indexed as `path[k, 0]`, `path[k, 1]` for the two coordinates
    of the k-th point. Each arrow starts at point k and ends exactly at point
    k + 1 (data units, unit scale). Extra keyword arguments are forwarded to
    `Axes.quiver`, whose return value is passed back. When `ax` is omitted
    the current matplotlib axes are used.
    """
    if ax is None:
        ax = plt.gca()

    starts = path[:-1]
    deltas = path[1:] - starts
    return ax.quiver(
        starts[:, 0], starts[:, 1], deltas[:, 0], deltas[:, 1],
        scale_units='xy', angles='xy', scale=1., **kwargs
    )

In [None]:
# Five chains started from widely dispersed Gaussian draws (scale 5).
paths = [5 * torch.randn(5, 2)]

# Take 500 Metropolis-Hastings steps, each one from the latest state.
while len(paths) <= 500:
    paths.append(mhp.sample(paths[-1]))

# Put the step axis second to last: (chain, step, coordinate).
paths = torch.stack(paths, dim=-2)

Check out the Markov chain paths

In [None]:
# Bounding box of every visited point: lower-left and upper-right corners.
ll = paths.reshape(-1, paths.shape[-1]).min(dim=0).values
ur = paths.reshape(-1, paths.shape[-1]).max(dim=0).values

# 201 x 201 grid spanning that box, one axis per coordinate.
mesh = torch.meshgrid(*[torch.linspace(lo, hi, 201) for lo, hi in zip(ll, ur)])
marg = torch.stack(mesh, dim=-1).reshape(-1, len(mesh))

# Unnormalised density at the grid nodes, used as the plot backdrop.
z = log_density(marg).exp()

In [None]:
# Chain trajectories on top of a faint rendering of the density.
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111)
ax.set(xlabel=r"$\theta_1$", ylabel=r"$\theta_2$")

ax.contourf(mesh[0], mesh[1], z.reshape(mesh[0].shape), levels=21,
            cmap=plt.cm.terrain, alpha=0.05, zorder=10)

# One colour per chain, spread evenly over the Accent colour map.
colours = plt.cm.Accent(np.linspace(0, 1, num=paths.shape[0]))
for chain, col in zip(paths, colours):
    plot_path(chain.numpy(), ax=ax, color=col, alpha=.5)

plt.show()

$x \in \Omega^{b_1 \times \ldots \times b_p}$ for $\Omega \subseteq \mathbb{R}$
and any transition kernel must 

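The MH kernel is given by the density:
$$
P(x, dy)
    = p(x, y) dy + r(x) \delta_x(dy)
    \,, $$

where, for a target density $\pi$ and a proposal density $q$,
$p(x, y) = q(y \vert x) \min\bigl\{1, \tfrac{\pi(y) q(x \vert y)}{\pi(x) q(y \vert x)}\bigr\}$
is the density of proposing and accepting $y$, and
$r(x) = 1 - \int p(x, y) dy$ is the probability of staying at $x$.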

In [None]:
# Deliberate stop: this always fails, so running the notebook top to bottom
# halts here. Presumably this is because the cells below are unfinished
# scratch work -- confirm before removing.
assert False

<br>

In [None]:
def leap_frog(nabla, q, p, grad=None, eps=0.01):
    r"""One leap-frog step for \dot{p} = - \nabla_q V(q), \dot{q} = p.

    Parameters
    ----------
    nabla : callable
        Maps a position `q` to the gradient of the potential, \nabla_q V(q).
    q, p : tensors
        Position and momentum. Both are UPDATED IN PLACE.
    grad : tensor, optional
        \nabla_q V at the incoming `q`, e.g. the third value returned by the
        previous call. When omitted it is computed with `nabla(q)`.
    eps : float
        Step size.

    Returns
    -------
    (p, q, grad) : the updated momentum and position (the same objects that
    were passed in) and the gradient at the new position. Note the order:
    momentum first.
    """
    # p_{\tfrac12} = p_0 - \tfrac\epsilon2 \nabla_q V(q_0)
    grad = nabla(q) if grad is None else grad
    p.sub_(eps * grad / 2)

    # q_1 = q_0 + \epsilon p_{\tfrac12}
    q.add_(eps * p)

    # p_1 = p_{\tfrac12} - \tfrac\epsilon2 \nabla_q V(q_1)
    grad = nabla(q)
    p.sub_(eps * grad / 2)
    return p, q, grad


In [None]:
# Three 2-d locations, one per row.
mu = torch.tensor([(-3., -3.), (+3., -3.), (+3., +3.)])

In [None]:
# NOTE(review): unfinished scratch line -- `torch.tensor[]` is not valid
# Python. Presumably a covariance matrix was meant to be filled in here.
cov = torch.tensor[]

In [None]:
# NOTE(review): scratch call with no arguments; looks like a placeholder for
# a triangular solve that was never written -- confirm intent or remove.
torch.triangular_solve()

In [None]:
def log_prob(value, mu, loc):
    """Unimplemented placeholder: ignores its arguments and returns None."""
    return None

In [None]:
# NOTE(review): `Normal` is not imported anywhere in the visible notebook;
# presumably `torch.distributions.Normal` -- confirm.
a = Normal(mu, 1.)

In [None]:
# NOTE(review): `target` is not defined in the visible notebook; this cell and
# the ones below look like interactive exploration of a distribution object.
target.rsample((10,)).mean(0)

In [None]:
# IPython-only `??` introspection syntax (not plain Python).
target.log_prob??

In [None]:
target.scale

In [None]:
# Differentiable copy of the first row of `mu`, fed to `target.log_prob`.
q = mu[0].clone().requires_grad_(True)
target.log_prob(q).mean(0)