In [None]:
# Reload imported modules automatically before each cell executes, so edits
# to project code are picked up without restarting the kernel.
%load_ext autoreload 
%autoreload 2

# Noise Annealing

We're going to do noise annealing once more, just this time in 2D.
Here we go!

In [None]:
from score_models.utils import generate_mixture_2d
from jax import random , numpy as np
import jax 

# jax.config.update('jax_platform_name', 'cpu')  # this notebook's code doesn't fit in 8GB of GPU memory!

# Root PRNG key with a fixed seed so the notebook is reproducible.
key = random.PRNGKey(42)

# `generate_mixture_2d` hands back the dataset together with a key (`k3`);
# later cells split `k3` to obtain fresh sub-keys.
data, k3 = generate_mixture_2d(key)

In [None]:
# These values are used as the diagonal of each noise covariance in
# `noise_up_one`, and as per-scale loss weights in `multi_scale_loss`.
noise_scales = np.linspace(1, 5, 20)  # 20 different noise scales.

We take the original data and progressively noise it up.
This is a `lax.scan` operation, not a `vmap` operation;
the second noising depends on the first,
the third depends on the second, etc.

TODO: In my first notebook (`noise-scales.ipynb`),
I implemented noising incorrectly, I think.

In [None]:
from jax import vmap , lax
from functools import partial


# Three independent sub-keys derived from `k3`; `k4` and `k5` are consumed by
# the test drive below and by the per-sample noising in the next cell.
k4, k5, k6 = random.split(k3, 3)

# The draws to add are definitely independently drawn.

def noise_up_one(key, sample, noise_scales):
    """Build the sequence of progressively noisier versions of one sample.

    One zero-mean Gaussian increment is drawn per entry of ``noise_scales``,
    with covariance ``scale * I``. The increments are accumulated onto
    ``sample`` with ``lax.scan`` because every level builds on the previous one.

    :param key: PRNG key; split into one sub-key per noise scale.
    :param sample: A single data point (1-D, length ``n_dims``).
    :param noise_scales: 1-D array of scales, each used as the diagonal of
        the covariance of one increment.
    :returns: Array of shape ``(len(noise_scales), n_dims)``.
        Entry ``k`` is the running state *before* increment ``k`` is added,
        so entry 0 equals ``sample`` and the last increment is never emitted.
    """
    n_dims = len(sample)
    n_scales = len(noise_scales)

    def accumulate(state, increment):
        """Scan step: carry the noised state forward, emit the incoming state."""
        return state + increment, state

    # One independent key, covariance matrix and Gaussian draw per noise scale.
    per_scale_keys = random.split(key, n_scales)
    identity = np.eye(n_dims)
    zero_mean = np.zeros(n_dims)
    covariances = vmap(lambda scale: identity * scale)(noise_scales)
    increments = vmap(
        lambda draw_key, cov: random.multivariate_normal(draw_key, mean=zero_mean, cov=cov)
    )(per_scale_keys, covariances)

    # Index 1 of scan's result is the stacked per-step outputs.
    return lax.scan(accumulate, sample, increments)[1]

# Test-drive on the first data point; the cell displays the shape of the
# resulting stack of noised versions.
out = noise_up_one(k4, data[0], noise_scales)
out.shape

In [None]:
# One PRNG key per data point, so every sample gets its own noise sequence.
sample_keys = random.split(k5, len(data))

# Map `noise_up_one` over (key, sample) pairs; `noise_scales` is shared by all
# samples. Later cells index the result as [:, scale_index, :].
noised_up_data = vmap(partial(noise_up_one, noise_scales=noise_scales))(sample_keys, data)
noised_up_data

## Visualize the data

Plot distribution over noise scales.

In [None]:
import matplotlib.pyplot as plt
# One panel per noise level on a 4 x 5 grid with shared axes.
fig, axes = plt.subplots(nrows=4, ncols=5, figsize=(20, 16), sharex=True, sharey=True)
# Put the scale axis first so that iterating yields one noise level at a time.
per_scale_data = noised_up_data.swapaxes(0, 1)
for scale_index, (ax, noised_data) in enumerate(zip(axes.flatten(), per_scale_data)):
    x_coords, y_coords = noised_data[:, 0], noised_data[:, 1]
    ax.scatter(x_coords, y_coords)
    ax.set_title(f"Noise scale: {noise_scales[scale_index]:.2f}")

## Jointly train

We now jointly train all 20 neural networks, one per noise scale.

In [None]:
from score_models.models import nn_model
# Build the score network: `init_fun` creates parameters, `nn_score_func` evaluates the model.
init_fun, nn_score_func = nn_model(output_dim=2)
# One initialisation key per noise scale.
# NOTE(review): `key` was already passed to `generate_mixture_2d` above, and
# `keys[-1]` is both split on the next line and used for initialisation below.
# Reusing PRNG keys like this can correlate draws -- consider a fresh sub-key.
keys = random.split(key, len(noise_scales))
key = random.split(keys[-1])[0]
# vmap over the keys so each noise scale gets its own parameter set.
_, params_init = vmap(partial(init_fun, input_shape=(None, 2)))(keys)  # n sets of params now

## Test output

Should be of shape (None, 2)

In [None]:
from score_models.losses import score_matching_loss
from typing import Callable

def multi_scale_loss(params, score_func: Callable, batch: np.ndarray, scales: np.ndarray):
    """Scale-weighted sum of the per-noise-level score-matching losses.

    :param params: Stacked parameter sets, one per noise scale, with the scale
        index on the leading axis so that `vmap` can map over it.
    :param score_func: Score estimator handed to `score_matching_loss`;
        it is applied once per parameter set through `vmap`.
    :param batch: Data points sampled from the noised-up distributions,
        laid out as (n_observations, n_scales, n_dims).
    :param scales: One weight per noise scale; each per-scale loss is
        multiplied by its weight before summation.
    :returns: The sum over scales of `scale * loss`.
    """
    # Bring the scale axis to the front so it lines up with the leading axis of `params`.
    scale_major_batch = batch.swapaxes(0, 1)

    def loss_at_one_scale(scale_params, scale_batch):
        # Original author's note: score_matching_loss already includes l2 norm on params.
        return score_matching_loss(scale_params, score_func=score_func, batch=scale_batch)

    per_scale_loss = vmap(loss_at_one_scale)(params, scale_major_batch)
    return np.sum(per_scale_loss * scales)
    

In [None]:
# Smoke test: evaluate the joint loss once on the freshly initialised parameters.
multi_scale_loss(params_init, nn_score_func, noised_up_data, noise_scales)

In [None]:
from jaxopt import GradientDescent
from jax import jit 

# Bind the score function and JIT-compile the loss; the remaining arguments are
# (params, batch, scales).
joint_loss_func = jit(partial(multi_scale_loss, score_func=nn_score_func))
# Full-batch gradient descent, capped at 1000 iterations.
solver = GradientDescent(joint_loss_func, maxiter=1000)
# `batch` and `scales` are passed as keyword arguments for the objective;
# `result.params` is read by the plotting and sampling cells below.
result = solver.run(params_init, batch=noised_up_data, scales=noise_scales)

In [None]:
# Sanity check of the data layout; `multi_scale_loss` documents it as
# (n_observations, n_scales, n_dims).
noised_up_data.shape

## Visualize learned gradient field

In [None]:
# Square lattice of n_points x n_points locations covering [-30, 30] x [-30, 30],
# used to evaluate the learned score field.
n_points = 21
xs = ys = np.linspace(-30, 30, n_points)
xxs, yys = np.meshgrid(xs, ys)

# Pair up the coordinates: one (x, y) row per lattice point.
x_y_pair = np.stack([xxs.ravel(), yys.ravel()], axis=1)
x_y_pair.shape

In [None]:
from tqdm.autonotebook import tqdm 
from jax.tree_util import tree_map
def get_params(params, i):
    """Pick entry ``i`` along the leading axis of every leaf in ``params``.

    :param params: Pytree whose leaves are all indexable along their first axis.
    :param i: Index to select from each leaf.
    :returns: A pytree with the same structure as ``params`` whose leaves are
        ``leaf[i]``.
    """
    def _select(leaf):
        return leaf[i]

    return tree_map(_select, params)


# One panel per noise scale: the learned score field drawn as arrows on the
# evaluation lattice, with that scale's noised data underneath.
fig, axes = plt.subplots(nrows=4, ncols=5, figsize=(20, 16), sharex=True, sharey=True)
flat_axes = axes.flatten()

for i in tqdm(range(len(noise_scales))):
    ax = flat_axes[i]
    # Evaluate the i-th trained model at every lattice point.
    param = get_params(result.params, i)
    score_at_point = partial(nn_score_func, param)
    gradient_field = vmap(score_at_point)(x_y_pair)

    # Each arrow starts at a lattice point and spans the predicted score vector.
    for (x0, y0), (dx, dy) in zip(x_y_pair, gradient_field):
        ax.arrow(x0, y0, dx, dy, width=0.3, alpha=0.1)
    data_x, data_y = noised_up_data[:, i, :].T
    ax.scatter(data_x, data_y, alpha=0.1, color="black")

## Confirm sampling from score function

In [None]:
from jax.tree_util import tree_map
from score_models.sampler import langevin_dynamics


# Sample from each noise scale's trained score model independently and overlay
# the samples on that scale's noised data.
fig, axes = plt.subplots(figsize=(16, 16), nrows=5, ncols=4, sharex=True, sharey=True)
flat_axes = axes.flatten()

# `k5` was already consumed to build `sample_keys`, so split the so-far-unused
# `k6` instead of reusing `k5`. `k7` seeds the annealed sampler further down.
k_sampling, k7 = random.split(k6)
keys = random.split(k_sampling, len(noise_scales))  # one sampler key per scale

for i, scale in tqdm(enumerate(noise_scales), total=len(noise_scales)):
    ax = flat_axes[i]
    # Pull the i-th parameter set out of the stacked (vmapped) parameters.
    param = get_params(result.params, i)
    initial_states, final_states, chain_samples_joint = langevin_dynamics(
        n_chains=200,
        n_samples=1000,
        key=keys[i],
        epsilon=5e-3,
        score_func=nn_score_func,
        params=param,
        init_scale=3,
        sample_shape=(None, 2),
    )
    # Clip runaway chains so they do not blow up the shared axis limits.
    final_states = np.clip(final_states, -30, 30)
    ax.scatter(noised_up_data[:, i, 0], noised_up_data[:, i, 1], label="Data", color="blue", alpha=0.1)
    ax.scatter(final_states[:, 0], final_states[:, 1], label=f"Perturbation: {scale:.2f}", color="black", alpha=0.1)
    ax.legend()
    ax.set_title(f"Perturbation: {scale:.2f}")
plt.show()

## Annealed Langevin Dynamics

In [None]:
fig, axes = plt.subplots(figsize=(20, 16), nrows=5, ncols=4, sharex=True, sharey=True)
flat_axes = axes.flatten()
n_chains = 1000
n_samples = 1000
k8, k9, k10 = random.split(k7, 3)
# All chains start from standard-normal draws.
starter_xs = random.normal(k8, shape=(n_chains, 2))
sampler_starter_xs_record = [starter_xs]
chain_samples_record = []
epsilon = 5e-3
kk, kk_ = random.split(k9)
# We start by sampling with the model for the largest noise scale and work
# down to the smallest; each level starts from the previous level's final states.
for i, scale in tqdm(enumerate(noise_scales[::-1]), total=len(noise_scales)):
    kk, kk_ = random.split(kk_)
    # The scales are visited in reverse, so step i pairs with index -(i + 1).
    # (A plain -i would select index 0, the smallest scale, on the first step.)
    param = get_params(result.params, -(i + 1))
    # Unpack in the same order as the per-scale sampling cell:
    # (initial states, final states, full chain).
    initial_states, final_states, chain_samples_annealed = langevin_dynamics(
        n_chains=n_chains,
        n_samples=n_samples,
        key=kk,
        epsilon=epsilon,
        score_func=nn_score_func,
        params=param,
        init_scale=10,
        starter_xs=starter_xs,
    )
    # The next (less noisy) level continues from where this one finished.
    starter_xs = final_states
    sampler_starter_xs_record.append(starter_xs)
    chain_samples_record.append(chain_samples_annealed)
    flat_axes[i].scatter(*final_states.T, label=f"Noise {scale:.2f}", color="black", alpha=0.1)
    flat_axes[i].scatter(*noised_up_data[:, -(i + 1), :].T, label="Data", alpha=0.1)
    flat_axes[i].legend()
    flat_axes[i].set_title(f"Scale: {scale:.2f}")
plt.show()

## Make an animation

I want to see if we are actually doing the thing that I thought we were doing.

In [None]:
from celluloid import Camera 

# Left panel: noised data for the current level. Right panel: chain positions.
fig, axes = plt.subplots(figsize=(16, 8), ncols=2, sharex=True, sharey=True)
cam = Camera(fig)

# Aim for about 50 frames per noise level. Clamp the stride to at least 1 so a
# small `n_samples` does not produce a zero step, which `range` rejects.
frame_stride = max(1, n_samples // 50)

for i, scale_chain_record in enumerate(chain_samples_record):
    # Chains were recorded from the last noise scale to the first, hence -(i + 1).
    level_data = noised_up_data[:, -(i + 1), :]
    for step in range(0, n_samples, frame_stride):
        axes[0].scatter(*level_data.T, color="black", alpha=0.1)
        axes[1].scatter(*scale_chain_record[:, step, :].T, color="blue", alpha=0.1)
        cam.snap()

In [None]:
# Assemble the captured snapshots into an animation and write it to an mp4 file
# in the current working directory.
animation = cam.animate()
animation.save("annealed_sampling_joint.mp4", dpi=300, fps=60)