In [None]:
%load_ext autoreload 
%autoreload 2 

# Generalizing to 2D and beyond

In this notebook, I want to show how we can generalize from 1D to 2D and beyond.
To do that, I will use a 2D Gaussian mixture as an anchoring example.

In [None]:
from jax import random, numpy as np, vmap

Firstly, we generate data.

In [None]:
from score_models.utils import generate_mixture_2d

key = random.PRNGKey(42)
data, k3 = generate_mixture_2d(key)
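
For readers without the `score_models` package at hand, here is a minimal sketch of how a 2D Gaussian mixture could be sampled with JAX. The function name, component means, weights, and unit covariance are all hypothetical; the actual `generate_mixture_2d` may well differ. Like the call above, the sketch returns the data together with a fresh key.

```python
import jax.numpy as jnp
from jax import random


def sketch_mixture_2d(key, n_samples=1000):
    """Sample from a hypothetical two-component 2D Gaussian mixture."""
    k_comp, k_noise, k_next = random.split(key, 3)
    # Assumed component means and mixing weights (illustrative only).
    means = jnp.array([[-5.0, -5.0], [5.0, 5.0]])
    weights = jnp.array([0.5, 0.5])
    # Pick a component index for every sample.
    comps = random.choice(k_comp, means.shape[0], shape=(n_samples,), p=weights)
    # Add unit-variance isotropic noise around the chosen mean.
    noise = random.normal(k_noise, shape=(n_samples, 2))
    return means[comps] + noise, k_next


toy_data, toy_key = sketch_mixture_2d(random.PRNGKey(0))
```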

Let's plot the data just to make sure we know what it's all about.

In [None]:
import matplotlib.pyplot as plt 

plt.scatter(data[:, 0], data[:, 1], alpha=0.1)
plt.show()

Now, let's train a score model for the data.

In [None]:
data.shape

The score of a 2D distribution is the gradient of its log-density w.r.t. the inputs.
Because the data are multi-dimensional,
the gradient has the same dimensionality as the data;
its entries are the partial derivatives of the log-density w.r.t. each input dimension.
Specifically, the score function maps $\mathbb{R}^d \rightarrow \mathbb{R}^d$.
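
To make the input and output shapes concrete, here is a small sketch that computes the score of a known density with `jax.grad`. The standard 2D Gaussian is an assumed toy density chosen for illustration, not the mixture used in this notebook.

```python
import jax.numpy as jnp
from jax import grad, vmap
from jax.scipy.stats import multivariate_normal

# Assumed toy density: a standard 2D Gaussian.
mean = jnp.zeros(2)
cov = jnp.eye(2)


def log_density(x):
    """Log-density of the toy Gaussian at a single point x of shape (2,)."""
    return multivariate_normal.logpdf(x, mean, cov)


# The score is the gradient of the log-density w.r.t. the input.
score = grad(log_density)

points = jnp.array([[1.0, -2.0], [0.5, 3.0]])
scores = vmap(score)(points)  # one 2D score vector per 2D input
```

For this particular density the score is analytically $-x$, so each row of `scores` should be the negative of the corresponding row of `points`.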

In [None]:
from score_models.models import nn_model
from score_models.losses import score_matching_loss
from functools import partial
from jaxopt import GradientDescent

init_fun, nn_score_func = nn_model(output_dim=2)
k4, k5 = random.split(k3)
_, params_init = init_fun(k4, input_shape=(None, 2))

# Test-drive forward pass
out_test = vmap(partial(nn_score_func, params_init))(data)
out_test

In [None]:
# For debugging purposes
out_test.shape

Now, we need to write the score matching loss.
For each input dimension, the loss averages a per-sample term over all samples
(the derivative of that dimension's score component w.r.t. its own input,
plus half the squared score component),
and then sums those averages over all dimensions,
as given by equation 6 in the JMLR paper (2005) by Aapo Hyvärinen.
In earlier experiments, I also observed exploding weights leading to NaN values,
so I will be applying L2 regularization on the weights to prevent that from happening.
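
As a rough illustration of that objective, the sketch below computes $\frac{1}{N}\sum_n \sum_i \left[\partial_i \psi_i(x_n) + \frac{1}{2}\psi_i(x_n)^2\right]$ plus an optional L2 penalty on the parameters. It is an assumption-laden sketch, not the implementation in `score_models.losses`: the function name, the `l2_weight` argument, and the `score_func(params, x)` calling convention are chosen here for illustration.

```python
import jax.numpy as jnp
from jax import jacfwd, vmap
from jax.tree_util import tree_leaves


def sketch_score_matching_loss(params, score_func, batch, l2_weight=0.0):
    """Sample-based score matching objective with an optional L2 penalty."""

    def per_sample(x):
        score = score_func(params, x)
        # Jacobian of the score w.r.t. the input; its trace is sum_i d(psi_i)/d(x_i).
        jac = jacfwd(lambda z: score_func(params, z))(x)
        return jnp.trace(jac) + 0.5 * jnp.sum(score**2)

    data_term = jnp.mean(vmap(per_sample)(batch))
    # L2 penalty over every parameter array in the pytree.
    l2_term = sum(jnp.sum(leaf**2) for leaf in tree_leaves(params))
    return data_term + l2_weight * l2_term


# Toy check with a hypothetical linear score, psi(x) = -params * x.
toy_score = lambda params, x: -params * x
toy_batch = jnp.array([[1.0, 0.0], [0.0, 2.0]])
toy_loss = sketch_score_matching_loss(jnp.array(1.0), toy_score, toy_batch)
```

For the toy score with `params = 1`, the Jacobian trace is $-2$ for every sample and the squared-norm terms are $0.5$ and $2.0$, so the loss works out to $-2 + 1.25 = -0.75$.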

In [None]:
from jax import jit 
myloss = jit(partial(score_matching_loss, score_func=nn_score_func))
solver = GradientDescent(fun=myloss, maxiter=10000)
result = solver.run(params_init, batch=data)


In [None]:
from jax.tree_util import tree_flatten, tree_map

params_flat, _ = tree_flatten(result.params)
params_flat = tree_map(lambda x: x.flatten(), params_flat)
params_flat = np.concatenate(params_flat)
params_flat.max(), params_flat.min()

## Visualize learned gradient field

In [None]:
n_points = 21
xs = np.linspace(-30, 30, n_points)
ys = np.linspace(-30, 30, n_points)
xxs, yys = np.meshgrid(xs, ys)
xxs.shape, yys.shape

x_y_pair = np.vstack([xxs.flatten(), yys.flatten()]).T
x_y_pair.shape

In [None]:
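fig, axes = plt.subplots(figsize=(10, 10))

# Evaluate the learned score function at every grid point.
gradient_field = vmap(partial(nn_score_func, result.params))(x_y_pair)

# Draw one arrow per grid point, scaled down by 0.1 for readability.
for xy_pair, vect in zip(x_y_pair, gradient_field):
    axes.arrow(*xy_pair, *vect * 0.1, width=0.3, alpha=0.1)
# Overlay the training data for reference.
axes.scatter(*data.T, alpha=0.1, color="black")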

Now, we sample with Langevin Dynamics.
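
For intuition, here is a minimal sketch of one common discretization of Langevin dynamics, $x_{t+1} = x_t + \epsilon \nabla_x \log p(x_t) + \sqrt{2\epsilon}\, z_t$ with $z_t \sim \mathcal{N}(0, I)$. The function name and the toy score are assumptions made for illustration; the `langevin_dynamics` function in `score_models.sampler` may use a different step-size convention or signature.

```python
import jax.numpy as jnp
from jax import lax, random, vmap


def sketch_langevin_chain(key, x0, score_fn, epsilon, n_steps):
    """Run one Langevin chain from x0; returns (final state, all visited states)."""

    def step(x, step_key):
        noise = random.normal(step_key, shape=x.shape)
        # Drift along the score, plus Gaussian noise scaled by sqrt(2 * epsilon).
        x_next = x + epsilon * score_fn(x) + jnp.sqrt(2.0 * epsilon) * noise
        return x_next, x_next

    step_keys = random.split(key, n_steps)
    return lax.scan(step, x0, step_keys)


# Toy target (assumed): a standard 2D Gaussian, whose score is -x.
toy_score = lambda x: -x
n_chains, n_steps = 1000, 300
k_init, k_run = random.split(random.PRNGKey(0))
x0s = 5.0 + random.normal(k_init, shape=(n_chains, 2))  # start far from the mode
chain_keys = random.split(k_run, n_chains)
finals, trajectories = vmap(
    lambda k, x0: sketch_langevin_chain(k, x0, toy_score, 0.05, n_steps)
)(chain_keys, x0s)
```

Each chain gets its own key, and `lax.scan` keeps the per-step loop inside compiled code. With the toy score, the final states should settle around the origin even though every chain starts near $(5, 5)$.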

In [None]:
from score_models.sampler import langevin_dynamics

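# Split k5 so the starting points and the sampler use independent keys;
# `key` was already consumed by `generate_mixture_2d` above.
k6, k7 = random.split(k5)
n_chains = 4000
starter_xs = random.multivariate_normal(
    k6, mean=np.array([-5.0, -5.0]), cov=np.eye(2) * 20, shape=(n_chains,)
)
epsilon = 5e-3
starting_states, final_states, chain_samples = langevin_dynamics(
    n_chains=n_chains,
    n_samples=8000,
    key=k7,
    epsilon=epsilon,
    score_func=nn_score_func,
    params=result.params,
    init_scale=10,
    starter_xs=starter_xs,
)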


In [None]:
final_states

In [None]:
import matplotlib.pyplot as plt 

plt.figure(figsize=(8, 8))
plt.scatter(data[:, 0], data[:, 1], alpha=0.1, label="data")
plt.scatter(starting_states[:, 0], starting_states[:, 1], alpha=0.1, label="starting samples")
plt.scatter(final_states[:, 0], final_states[:, 1], alpha=0.1, label="final samples")
plt.xlim(-15, 15)
plt.ylim(-15, 15)
plt.gca().set_aspect("equal")
plt.legend()
plt.show()

In [None]:
from celluloid import Camera
from tqdm.autonotebook import tqdm 

fig = plt.figure()
camera = Camera(fig)

for timepoint in tqdm(chain_samples.swapaxes(0, 1)[::10]):
    plt.scatter(*timepoint.T, color="blue", alpha=0.1)
    camera.snap()



In [None]:
animation = camera.animate()

In [None]:
from IPython.display import display_html

# display_html(animation)
animation.save("sampling2.mp4", dpi=300, fps=60)

Firstly, it's powerful to just "see" what's happening amongst the chain samples!
