## Langevin Dynamics

In this notebook, I'd like to further explore how [Langevin dynamics][langevin]
can be used to generate samples from an unknown distribution.
(For now, I will stick to 1D distributions.)

[langevin]: https://yang-song.github.io/blog/2021/score/#langevin-dynamics

According to Yang Song's blog,

> Langevin dynamics provides an MCMC procedure to sample from a distribution
> $p(x)$ using only its score function $\nabla_x \log p(x)$. 
> Specifically, it initializes the chain from an arbitrary prior distribution
> $x_0 \sim \pi(x)$, and then iterates the following:
>
> $$x_{i+1} \leftarrow x_i + \epsilon \nabla_x \log p(x) + \sqrt{2 \epsilon} z_i$$

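where $\epsilon > 0$ is the step size, $i = 0, 1, \ldots, K$,
$z_i \sim \mathcal{N}(0, I)$ is standard Gaussian noise (in 1D, simply $\mathcal{N}(0, 1)$),
and the score is evaluated at the current iterate $x_i$.
As $\epsilon \to 0$ and $K \to \infty$, $x_K$ converges to a sample from $p(x)$ under some regularity conditions.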
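
To see the update rule in isolation, here is a minimal sketch on an illustrative target whose score is known in closed form: a standard normal, where $\nabla_x \log p(x) = -x$. It runs many independent 1D chains at once with `jax.lax.scan`. The names `gaussian_score` and `langevin_chain`, the step size, and the starting point are my own choices, not from the blog.

```python
import jax.numpy as jnp
from jax import lax, random


def gaussian_score(x):
    """Exact score of a standard normal: d/dx log N(x; 0, 1) = -x."""
    return -x


def langevin_chain(key, x0, score, epsilon=1e-2, n_steps=3000):
    """Iterate x <- x + eps * score(x) + sqrt(2 * eps) * z and return every iterate."""
    keys = random.split(key, n_steps)

    def step(x, k):
        z = random.normal(k, shape=x.shape)  # fresh standard normal noise z_i
        x_next = x + epsilon * score(x) + jnp.sqrt(2 * epsilon) * z
        return x_next, x_next  # carry the new state and also record it

    _, xs = lax.scan(step, x0, keys)
    return xs  # shape (n_steps, *x0.shape)


key = random.PRNGKey(0)
k_init, k_run = random.split(key)
# 1000 independent chains, deliberately started far from the target, near x = 10.
x0 = 10.0 + 0.1 * random.normal(k_init, shape=(1000,))
xs = langevin_chain(k_run, x0, gaussian_score)
final = xs[-1]
print(final.mean(), final.std())
```

Starting every chain near $x = 10$ shows that the initial distribution $\pi(x)$ gets forgotten: after enough steps the final states should have mean near 0 and standard deviation near 1 (very slightly above 1, since a finite $\epsilon$ adds a small discretization bias).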

## Setup

We're going to generate data from a 1D mixture of two Gaussians.

In [None]:
from jax import random
import jax.numpy as np
import matplotlib.pyplot as plt

def ecdf(data):
    """Empirical CDF: sorted values and their cumulative fractions."""
    x, y = np.sort(data), np.arange(1, len(data) + 1) / len(data)
    return x, y

key = random.PRNGKey(45)
k1, k2, k3 = random.split(key, 3)

# Two components: 10000 draws from N(-5, 3^2) and 5000 draws from N(6, 1^2).
mix1 = random.normal(k1, shape=(10000,)) * 3 - 5
mix2 = random.normal(k2, shape=(5000,)) * 1 + 6

data = np.concatenate([mix1, mix2])  # already 1D, no flatten needed
plt.plot(*ecdf(data))
plt.show()
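
Because we generated the data ourselves, we also know its exact score, which is handy as ground truth for the trained model. A sketch of it below: the weights 2/3 and 1/3 follow from the 10000/5000 split above, and the names (`mixture_log_density`, `mixture_score`) are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

# Mixture from the setup cell: 10000 draws from N(-5, 3^2) and 5000 from N(6, 1^2),
# so the component weights are 2/3 and 1/3.
WEIGHTS = jnp.array([2 / 3, 1 / 3])
MEANS = jnp.array([-5.0, 6.0])
STDS = jnp.array([3.0, 1.0])


def mixture_log_density(x):
    """log p(x) for a scalar x, computed stably with logsumexp over components."""
    log_components = norm.logpdf(x, loc=MEANS, scale=STDS)
    return logsumexp(log_components + jnp.log(WEIGHTS))


# Score = derivative of the log density; vmap makes it work on arrays of points.
mixture_score = jax.vmap(jax.grad(mixture_log_density))

grid = jnp.linspace(-15.0, 12.0, 7)
print(mixture_score(grid))
```

For example, `mixture_score(data)` gives exact scores at the training points, which can later be compared against `model_scores`.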

## Approximate Gradients

Great, we have a Gaussian mixture.
Now, we need to train a score model to approximate the gradient of its log density, $\nabla_x \log p(x)$.
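
The notebook's actual objective lives in `score_models.losses.loss`, which isn't shown here, so I can't say exactly what it computes. As a hedged illustration of one standard way to fit a score model without access to the true score, here is Hyvärinen's implicit score matching, $\mathbb{E}\left[\partial_x s(x) + \tfrac{1}{2} s(x)^2\right]$, on a toy problem: a linear score model fit to samples from a single Gaussian, where the optimum ($a = -1/\sigma^2$, $b = \mu/\sigma^2$) is known in closed form. `score_model` and `implicit_score_matching_loss` are hypothetical names.

```python
import jax
import jax.numpy as jnp
from jax import random


def score_model(params, x):
    """Hypothetical linear score model s(x) = a * x + b (exact for a single Gaussian)."""
    a, b = params
    return a * x + b


def implicit_score_matching_loss(params, x):
    """Hyvarinen's objective E[ds/dx + 0.5 * s(x)^2]; it never uses the true score."""
    ds_dx = jax.vmap(jax.grad(score_model, argnums=1), in_axes=(None, 0))(params, x)
    s = score_model(params, x)
    return jnp.mean(ds_dx + 0.5 * s**2)


key = random.PRNGKey(0)
x = random.normal(key, shape=(5000,)) * 0.5 + 1.0  # samples from N(1, 0.5^2)

params = (jnp.array(0.0), jnp.array(0.0))
grad_fn = jax.jit(jax.grad(implicit_score_matching_loss))
for _ in range(2000):  # plain gradient descent; the loss is a convex quadratic here
    grads = grad_fn(params, x)
    params = tuple(p - 0.5 * g for p, g in zip(params, grads))

a, b = params
# The optimum is a = -1/var, b = mean/var (roughly -4 and 4 for this data).
print(a, b, -1 / jnp.var(x), jnp.mean(x) / jnp.var(x))
```

A neural score model replaces the linear `score_model`, but the objective has the same shape; the derivative term is what lets the loss be estimated from samples alone.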

In [None]:
from score_models.models import nn_score_func, nn_model
from score_models.losses import loss
from functools import partial
from jaxopt import GradientDescent


init_fun, apply_fun = nn_model()
k1, k2 = random.split(k3)
_, params_init = init_fun(k1, input_shape=(None, 1))
myloss = partial(loss, score_func=nn_score_func)
solver = GradientDescent(fun=myloss, maxiter=1200)
result = solver.run(params_init, batch=data.reshape(-1, 1))


In [None]:
from jax import vmap
model_scores = vmap(partial(nn_score_func, result.params))(data).squeeze()
model_scores

## Sample from Score Function

In [None]:
from typing import Callable
from jax import lax

key = random.PRNGKey(40)
k1, k2 = random.split(key)
x = random.normal(k2, shape=(1,)) * 3 + 0  # starting point x_0 ~ N(0, 3^2)
epsilon = 5e-3  # step size

# Equivalent (but slow) Python loop, kept for reference:
# samples = []
# for i in range(10000):
#     k1, k2 = random.split(k2)
#     draw = random.normal(k1, shape=(1,))
#     x = x + epsilon * nn_score_func(result.params, x) + np.sqrt(2 * epsilon) * draw
#     samples.append(x)

def langevin_dynamics_one_chain(
    x: np.ndarray,
    key: np.ndarray,
    n_samples: int,
    epsilon: float,
    score_func: Callable,
    params,
):
    """One chain of Langevin dynamics sampling for score models."""

    keys = random.split(key, n_samples)
    def inner(prev_x, key):
        draw = random.normal(key, shape=(1,))
        new_x = prev_x + epsilon * score_func(params, prev_x) + np.sqrt(2 * epsilon) * draw
        # Carry new_x forward but record prev_x, so the recorded chain starts at x_0.
        return new_x, prev_x

    _, xs = lax.scan(inner, x, keys)  # xs has shape (n_samples, 1)
    return np.concatenate(xs)


# Use k1 here: k2 was already consumed to draw the starting point.
samples = langevin_dynamics_one_chain(x, k1, 10000, epsilon, nn_score_func, result.params)
samples

In [None]:
# Now, vmap across many chains
def langevin_dynamics_many_chains(n_chains, n_samples, key, epsilon, score_func, params):
    """MCMC with Langevin dynamics to sample from the data generating distribution."""
    k_init, k_chains = random.split(key)  # separate keys for initialization and for the chains
    starter_xs = random.normal(k_init, shape=(n_chains, 1))  # x_0 ~ N(0, 1), one row per chain
    keys = random.split(k_chains, num=n_chains)
    chain_func = partial(
        langevin_dynamics_one_chain, n_samples=n_samples, epsilon=epsilon,
        score_func=score_func, params=params,
    )
    samples = vmap(chain_func)(starter_xs, keys)  # shape (n_chains, n_samples)
    return samples

chain_samples = langevin_dynamics_many_chains(
    n_chains=1000, n_samples=2000, key=key, epsilon=epsilon,
    score_func=nn_score_func, params=result.params,
)


In [None]:
plt.plot(*ecdf(chain_samples.flatten()), label="samples")
plt.plot(*ecdf(data), label="data")
plt.legend()
plt.show()
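
The ECDF overlay can be condensed into one number: the two-sample Kolmogorov-Smirnov statistic, i.e. the largest vertical gap between the two curves (0 means identical ECDFs, 1 means no overlap at all). A small JAX sketch; `ks_statistic` is an illustrative helper, not part of the notebook.

```python
import jax.numpy as jnp


def ks_statistic(a, b):
    """Two-sample Kolmogorov-Smirnov statistic: the largest vertical gap between two ECDFs."""
    a, b = jnp.sort(jnp.ravel(a)), jnp.sort(jnp.ravel(b))
    grid = jnp.concatenate([a, b])  # the ECDFs only jump at observed values
    # ECDF of each sample on the grid: fraction of values <= each grid point.
    cdf_a = jnp.searchsorted(a, grid, side="right") / a.size
    cdf_b = jnp.searchsorted(b, grid, side="right") / b.size
    return jnp.max(jnp.abs(cdf_a - cdf_b))


x = jnp.array([0.0, 1.0, 2.0, 3.0])
print(ks_statistic(x, x))         # identical samples give 0
print(ks_statistic(x, x + 10.0))  # completely separated samples give 1
```

With the notebook's variables, `ks_statistic(chain_samples, data)` summarizes the ECDF comparison; because `ravel` is applied first, the `(n_chains, n_samples)` array can be passed directly.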

OMG OMG OMG! We can actually sample from the Gaussian mixture!

Some caveats:

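1. The relative weights of the two Gaussians in the samples differ from those in the data, which is disturbing.
   I am not sure whether this is because the score function is inaccurate;
   there is very little data between the two modes, so the learned score is least reliable exactly where the chains must cross.
   Another possible cause is the sampler itself: Langevin chains are known to mix slowly between modes separated by a low-density region,
   so finite chains may not recover the correct weights even with an accurate score.
   In any case, we can definitely change the score function model,
   as it is quite small at the moment.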
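
One way to tell the explanations in caveat 1 apart is to rerun the sampler with the *exact* score of the mixture, which takes the neural network out of the picture entirely. If the exact-score chains also misallocate mass between the two modes, the sampler settings (chain length, step size, initialization) are at least part of the story. A sketch under these assumptions: it reuses the notebook's chain count, length, step size, and N(0, 1) initialization; the cutoff of 3.2 is my rough estimate of where the two weighted components cross; all names are illustrative.

```python
import jax
import jax.numpy as jnp
from jax import lax, random
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

# Assumed ground truth: the mixture from the setup cell (weights 2/3 and 1/3).
WEIGHTS = jnp.array([2 / 3, 1 / 3])
MEANS = jnp.array([-5.0, 6.0])
STDS = jnp.array([3.0, 1.0])
CUTOFF = 3.2  # approximate boundary between the two modes (illustrative choice)


def log_p(x):
    """Exact log density of the mixture at a scalar x."""
    return logsumexp(norm.logpdf(x, loc=MEANS, scale=STDS) + jnp.log(WEIGHTS))


# Elementwise exact score, applied to a whole vector of chains at once.
exact_score = jax.vmap(jax.grad(log_p))


def final_states(key, n_chains, n_steps, epsilon):
    """Run n_chains Langevin chains with the exact score; return the last state of each."""
    k_init, k_run = random.split(key)
    x0 = random.normal(k_init, shape=(n_chains,))  # N(0, 1) initialization, as in the notebook
    keys = random.split(k_run, n_steps)

    def step(x, k):
        z = random.normal(k, shape=x.shape)
        return x + epsilon * exact_score(x) + jnp.sqrt(2 * epsilon) * z, None

    x_final, _ = lax.scan(step, x0, keys)
    return x_final


short = final_states(random.PRNGKey(0), n_chains=1000, n_steps=2000, epsilon=5e-3)
frac_right = jnp.mean(short > CUTOFF)  # fraction of chains ending in the right mode
# Mass the exact mixture puts above the cutoff, for comparison (close to 1/3).
target = jnp.sum(WEIGHTS * (1.0 - norm.cdf(CUTOFF, loc=MEANS, scale=STDS)))
print(frac_right, target)
```

Rerunning `final_states` with a larger `n_steps` would show whether the split drifts toward `target` as the chains get longer.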