In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules
# (e.g. `score_models`) are reloaded before each cell executes.
%load_ext autoreload
%autoreload 2

# Noise Scales

With a score function approximator, we have one small issue:
in regions of low sample density,
our estimate of the score function will be inaccurate,
simply because we have few samples in those regimes.
To get around this, we can:

> perturb data points with noise 
> and train score-based models on the noisy data points instead.
> When the noise magnitude is sufficiently large, 
> it can populate low data density regions 
> to improve the accuracy of estimated scores.

There is a huge tradeoff here, though:
the larger the amount of perturbation,
the greater the corruption of the input data.
Let's see that in action.

In [None]:
from jax import random, vmap
import jax.numpy as np 
import matplotlib.pyplot as plt

def ecdf(data):
    """Return the empirical CDF of `data` as an `(x, y)` pair.

    :param data: One-dimensional collection of observations.
    :returns: Tuple `(x, y)` where `x` holds the sorted observations and
        `y` holds the cumulative fractions `1/n, 2/n, ..., 1`.
    """
    n_points = len(data)
    sorted_values = np.sort(data)
    cumulative_fraction = np.arange(1, n_points + 1) / n_points
    return sorted_values, cumulative_fraction

# Seed the PRNG and derive one key per random draw (k3 is kept for later cells).
key = random.PRNGKey(45)
k1, k2, k3 = random.split(key, 3)

# Two-component Gaussian mixture:
# 1000 draws centred at -10 with scale 2, and 500 draws centred at 6 with scale 1.
mix1 = -10 + 2 * random.normal(k1, shape=(1000,))
mix2 = 6 + 1 * random.normal(k2, shape=(500,))

# Pool both components into a single 1-D sample and plot its ECDF.
data = np.ravel(np.concatenate((mix1, mix2)))
ecdf_x, ecdf_y = ecdf(data)
plt.plot(ecdf_x, ecdf_y)
plt.show()

With perturbation:

In [None]:
# One noise standard deviation per column of the perturbed data.
# Every column receives noise, including the first one (sigma = 1.0).
noise_scales = np.array([1.0, 2.0, 5.0, 10.0])
k1, k2 = random.split(k3)

# Unit-normal noise of shape (n_observations, n_scales); broadcasting against
# `noise_scales` multiplies column j by noise_scales[j].
perturbations = noise_scales * random.normal(k1, shape=(len(data), len(noise_scales)))
# Add each noise column to the same original observations.
data_perturbed = data[:, None] + perturbations

# One ECDF curve per noise scale.
for column, sigma in zip(data_perturbed.T, noise_scales):
    column_x, column_y = ecdf(column)
    plt.plot(column_x, column_y, label=f"sigma={sigma}")
plt.legend()
plt.show()

It should be evident from the figure above
that when we add more noise, the data look more and more like a single Gaussian
and less like the original.
Most crucially, in the regions of low density between the two mixture Gaussians,
we have a much more nicely-defined PDF,
and hence a better ability to compute the score function accurately,
which we will be able to use when generating data.

## Confirm that we can sample using score function



We want to make sure that we can sample from the original, unperturbed data distribution
using the procedure we showed in the previous notebook.

In [None]:
from score_models.sampler import langevin_dynamics
from score_models.models import nn_model
from score_models.losses import score_matching_loss

from jaxopt import GradientDescent
from functools import partial

# Build the score network and its initialiser.
init_fun, nn_score_func = nn_model()

# Derive fresh keys for this cell. A plain `random.split(k3)` would reproduce
# exactly the keys already used to draw the perturbation noise, and the root
# `key` has already been split to generate the data, so neither is reused here.
# `fold_in` yields a distinct, reproducible stream from `k3`.
k1, k2 = random.split(random.fold_in(k3, 1))
_, params_init = init_fun(k1, input_shape=(None, 1))

# Fit the score model to the original (unperturbed) data, as a column vector.
myloss = partial(score_matching_loss, score_func=nn_score_func)
solver = GradientDescent(fun=myloss, maxiter=1200)
result = solver.run(params_init, batch=data.reshape(-1, 1))

# Draw samples with Langevin dynamics using the fitted score function.
initial_states, final_states, chain_samples_naive = langevin_dynamics(
    n_chains=10000,
    n_samples=2000,
    key=k2,  # dedicated sampler key, independent of the init key above
    epsilon=5e-3,
    score_func=nn_score_func,
    params=result.params,
    init_scale=10,
    sample_shape=(None, 1),
)


In [None]:
# Compare the ECDF of the Langevin samples with the ECDF of the training data.
plt.plot(*ecdf(final_states.flatten()), label="score model langevin")
# `data` is the original, unperturbed sample (not the sigma=1.0 column of
# `data_perturbed`), so it is labelled "Data" as in the later comparison plots.
plt.plot(*ecdf(data), label="Data")
plt.legend()
plt.show()

OK, yes, we're able to!
The weights are off,
but at least we can _sample_.

## One score model per perturbation

One key idea in Yang Song's blog post
is that we can train score models for each of the noise levels
and then use Langevin dynamics in an annealed fashion
to progressively obtain better and better samples.
In our example, we will have four models trained with a single loss function,
which is the weighted sum of Fisher divergences
(in the implementation below, every noise level is given equal weight).

In [None]:
# Joint loss
from score_models.models import nn_model
from functools import partial

key = random.PRNGKey(44)

# One model (one set of parameters) per noise scale.
init_fun, apply_fun = nn_model()
n_models = len(noise_scales)

# Split once into n_models + 1 keys: the first is carried forward as `key`
# for later cells, the remaining ones initialise the models. This keeps the
# carried key independent of every key consumed by `init_fun`.
split_keys = random.split(key, n_models + 1)
key, keys = split_keys[0], split_keys[1:]

# vmap the initialiser over the keys: every leaf of `params_init` gains a
# leading axis of length n_models.
_, params_init = vmap(partial(init_fun, input_shape=(None, 1)))(keys)

In [None]:
from score_models.losses import score_matching_loss
from typing import Callable
from jax import jit 


def multi_scale_loss(params: list, score_func: Callable, batch: np.ndarray, scales: np.ndarray):
    """Sum of the per-scale score matching losses.

    The batch is transposed so that its leading axis indexes noise scales;
    `score_matching_loss` is then vmapped over the leading axis of both
    `params` and the transposed batch, and the resulting losses are summed.

    :param params: Stacked parameters for `score_func`.
        Mapped over the leading axis, one slice per noise scale.
    :param score_func: A function that estimates the score of a data point.
        Passed through to `score_matching_loss` as `score_func`.
    :param batch: Perturbed data points.
        Should be of shape (n_observations, n_scales).
    :param scales: Noise perturbation scale parameter, one per column of `batch`.
        NOTE(review): not used in the computation, so every scale
        contributes to the sum with equal weight.
    :returns: The unweighted sum of the per-scale losses.
    """
    per_scale_loss = partial(score_matching_loss, score_func=score_func)
    # vmap maps over axis 0, so move the scale axis to the front:
    # (n_observations, n_scales) -> (n_scales, n_observations).
    batch_by_scale = np.transpose(batch)
    losses = vmap(per_scale_loss)(params, batch=batch_by_scale)
    return losses.sum()


# Evaluate the joint loss once on the initial parameters; the value is shown as the cell output.
multi_scale_loss(
    params=params_init,
    score_func=nn_score_func,
    batch=data_perturbed,
    scales=noise_scales,
)

In [None]:
from jaxopt import GradientDescent
from jax import jit 

# Bind the score network into the joint loss and JIT-compile it.
joint_loss_func = jit(partial(multi_scale_loss, score_func=nn_score_func))

# Minimise the joint loss with gradient descent, starting from the stacked initial params.
solver = GradientDescent(fun=joint_loss_func, maxiter=5000)
result = solver.run(
    params_init,
    batch=data_perturbed,
    scales=noise_scales,
)

In [None]:
# Show the lengths of the trained and the initial parameter containers side by side.
tuple(len(p) for p in (result.params, params_init))

## Confirm sampling works after joint training

We need to make sure that after joint training,
samples from the first set of params approximate the true data most closely.

In [None]:
from jax.tree_util import tree_map
# Sample from each jointly-trained model on its own and compare with the data.
fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(16, 4), sharex=True, sharey=True)
for i, scale in enumerate(noise_scales):
    # Slice the i-th model's parameters out of the stacked pytree.
    param = tree_map(lambda leaf: leaf[i], result.params)
    # The same key `k1` is passed to the sampler for every scale.
    initial_states, final_states, chain_samples_joint = langevin_dynamics(
        score_func=nn_score_func,
        params=param,
        key=k1,
        n_chains=2000,
        n_samples=10000,
        epsilon=5e-3,
        init_scale=3,
        sample_shape=(None, 1),
    )
    # Clamp the samples to [-30, 30] before plotting.
    final_states = np.clip(final_states, -30, 30)

    ax = axes[i]
    panel_name = f"Perturbation: {scale}"
    sample_x, sample_y = ecdf(final_states.flatten())
    data_x, data_y = ecdf(data)
    ax.plot(sample_x, sample_y, label=panel_name, color="black")
    ax.plot(data_x, data_y, label="Data")
    ax.legend()
    ax.set_title(panel_name)
plt.legend()
plt.show()

Hmm, not bad.
We see that the samples drawn from Perturbation 1.0 (i.e. the smallest perturbation)
match the data the closest, in contrast to the samples drawn from perturbation 10.0.

## Annealed Langevin Dynamics

We now implement annealed Langevin dynamics,
where we sequentially sample from the data distributions at each noise level,
starting from the largest noise scale and ending at the smallest.

In [None]:
from scipy.stats import wasserstein_distance

In [None]:
# Annealed Langevin dynamics: run one Langevin sampler per noise scale, from
# the largest scale down to the smallest, starting each run from the final
# states of the previous one.
from jax.tree_util import tree_map
from score_models.sampler import langevin_dynamics

n_scales = len(noise_scales)
fig, axes = plt.subplots(figsize=(16, 4), ncols=n_scales, sharex=True)
n_chains = 40000
k1, k2 = random.split(key)
starter_xs = random.normal(k1, shape=(n_chains, 1))
sampler_starter_xs_record = [starter_xs]
chain_samples_record = []
epsilon = 5e-3
# We start by sampling with the model for the largest noise scale
# and work our way down to the smallest.
for i, scale in enumerate(noise_scales[::-1]):
    k1, k2 = random.split(k2)
    # `result.params` is stacked in the same order as `noise_scales`, so the
    # i-th entry of the reversed scales lives at index n_scales - 1 - i.
    # (Indexing with `-i` would pick index 0, the smallest scale's model,
    # on the first pass and leave every later stage off by one.)
    scale_idx = n_scales - 1 - i
    param = tree_map(lambda x: x[scale_idx], result.params)
    # The final states of this stage become the starting points of the next.
    _, starter_xs, chain_samples_annealed = langevin_dynamics(
        n_chains=n_chains, 
        n_samples=1000, 
        key=k1, 
        epsilon=epsilon, 
        score_func=nn_score_func, 
        params=param, 
        init_scale=10, 
        starter_xs=starter_xs,
    )
    sampler_starter_xs_record.append(starter_xs)
    chain_samples_record.append(chain_samples_annealed)

    # Distance between the original data and the samples after this stage.
    dist = wasserstein_distance(data.flatten(), starter_xs.flatten())

    axes[i].plot(*ecdf(starter_xs.flatten()), label=f"Noise {scale}", color="black")
    axes[i].plot(*ecdf(data), label="Data")
    axes[i].legend()
    axes[i].set_title(f"Distance: {dist:.2f}")
plt.show()

We see a tighter and tighter match to the original data distribution. 