In [None]:
#| echo: false 
#| output: false
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

import os 
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
import jax.numpy as np  # import here so that any warnings about no GPU are not shown in website.
np.arange(3)

# Generalizing to 2D

Thus far, we've explored score models in the context of 1D data.
This is intentional!
By working out the core ideas in a single dimension,
we can more easily reason about what actually is happening --
humans are, after all, very good at thinking in 1D.
In effect, we eliminate the cognitive load that comes with thinking multi-dimensionally.
Through this, the framework for thinking about
how to use score models to generate data becomes quite clear.
Our ingredients are:

- Data,
- A trainable model that can approximate the score of our data (implying that yes, we will train that model!), and
- A procedure for noising up data and reversing that process to re-generate new data.

Alas, the data that inhabit our world are rarely just 1D.
More often than not, the data that we will encounter is going to be multi-dimensional.
To exacerbate the matter, our data are also oftentimes discrete and not continuous,
such as text, protein sequences, and more.
Do the ideas explored in 1D generalize to multiple dimensions?[^1]
In this notebook, I want to show how we can generalize from 1D to 2D.
(With a bit of hand-waving,
I'll claim at the end that this all works in n-dimensions too!)

[^1]: Of course, yes -- this is a rhetorical question --
and the more important point here is figuring out 
what we need to do to generalize beyond 1D.

## Data: Half Moons

As our anchoring example, we will use the half-moons dataset from `scikit-learn`.

In [None]:
#| code-fold: true
#| fig-cap: Half moons dataset.
#| label: fig-moons
from sklearn.datasets import make_moons
import matplotlib.pyplot as plt
import seaborn as sns

# Fix the seed so the dataset (and every figure derived from it) is reproducible
# under Restart & Run All; without it, make_moons draws fresh noise on every run.
X, y = make_moons(n_samples=1000, noise=0.1, random_state=42)
fig, axes = plt.subplots(figsize=(4, 4))
# Draw onto the axes created above rather than relying on pyplot's implicit current axes.
sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=y, ax=axes)
sns.despine()

The half-moons dataset comes with class labels (`y`), which we use above only to colour the points. Though we have labels, we'll work with them later.

Next we noise up the data.
Strictly speaking, with a constant drift term,
we need only parameterize our diffusion term using `t` (time)
and don't really need to use `diffrax`'s SDE capabilities.
We can noise up data by adding a draw
from an isotropic Gaussian with covariance equal to the time elapsed
(i.e. $t \cdot I$).

In [None]:
from jax import vmap
from functools import partial

def noise_batch(key, X: np.ndarray, t: float) -> np.ndarray:
    """Noise up a single data point with isotropic Gaussian noise.

    Despite the name, this operates on ONE sample and is intended to be
    vmapped over a batch of samples.

    :param key: JAX PRNG key used for the noise draw.
    :param X: One data point of shape (n_dims,).
    :param t: Time scale at which to noise up.
        The added noise has covariance `t * I`; `t` must be non-negative.
    :returns: A JAX array with the same shape as `X` holding the noised-up point.
    """
    # A draw from N(0, t * I) is sqrt(t) times a standard-normal draw.
    # Writing it this way avoids a Python-level `if t == 0.0` branch (which
    # fails when `t` is a traced value under jit/vmap) and avoids factorising
    # an all-zero covariance matrix; at t == 0 the noise term is exactly zero.
    return X + np.sqrt(t) * random.normal(key, shape=np.shape(X))


def noise(key, X, t):
    """Noise up every row of `X` at time scale `t`.

    :param key: JAX PRNG key; split into `len(X)` sub-keys, one per row.
    :param X: Batch of data, mapped over its leading axis.
    :param t: Time scale forwarded to `noise_batch` for every row.
    :returns: The stacked per-row outputs of `noise_batch`.
    """
    per_row_keys = random.split(key, num=len(X))
    noise_one_row = partial(noise_batch, t=t)
    return vmap(noise_one_row)(per_row_keys, X)

from jax import random 

fig, axes = plt.subplots(figsize=(8, 8), nrows=3, ncols=3)
# Nine noise levels, one per panel of the 3x3 grid.
ts = np.linspace(0.001, 0.2, 9)
key = random.PRNGKey(99)
# One PRNG key per noise level.
noise_level_keys = random.split(key, 9)
noised_datas = []
# NOTE(review): the loop variable `key` rebinds the notebook-level `key`; after the
# loop it holds the last entry of `noise_level_keys`, and later cells that read
# `key` see that value.
for t, ax, key in zip(ts, axes.flatten(), noise_level_keys):
    noised_data = noise(key, X, t)
    noised_datas.append(noised_data)
    ax.scatter(noised_data[:, 0], noised_data[:, 1], alpha=0.1)
    # Panel title is the noise time `t`, to two decimals.
    ax.set_title(f"{t:.2f}")
# Stack the per-level arrays along a new leading axis (one entry per noise level).
noised_datas = np.stack(noised_datas)
sns.despine()
plt.tight_layout()

Sanity check `noised_datas`'s shape, which should be `(time, batch, n_data_dims)`:

In [None]:
noised_datas.shape

Now, we can set up a score model to be trained on each time point's noised-up data.

In [None]:
import equinox as eqx
from jax import nn


class ScoreModel2D(eqx.Module):
    """Time-dependent score model for 2D data.

    An MLP that takes one data point `x` concatenated with the time `t`
    (2 + 1 = 3 inputs by default) and outputs a vector with one estimated
    score component per data dimension (2 outputs by default).
    """

    mlp: eqx.Module

    def __init__(
        self,
        in_size=3,
        out_size=2,
        width_size=256,
        depth=2,
        activation=nn.softplus,
        key=random.PRNGKey(45),
    ):
        """Build the underlying MLP.

        :param in_size: Number of MLP inputs (data dimensions plus one for `t`).
        :param out_size: Number of MLP outputs (one score component per data dimension).
        :param width_size: Width of each hidden layer.
        :param depth: Number of hidden layers.
        :param activation: Activation function used between layers.
        :param key: JAX PRNG key used to initialise the MLP weights.
        """
        self.mlp = eqx.nn.MLP(
            in_size=in_size,
            out_size=out_size,
            width_size=width_size,
            depth=depth,
            activation=activation,
            key=key,
        )

    @eqx.filter_jit
    def __call__(self, x: np.ndarray, t: float):
        """Forward pass for a single data point.

        :param x: One data point of shape (n_data_dims,);
            the model is intended to be vmapped over batches of data.
        :param t: Scalar time at which to evaluate the score.
        :returns: Estimated score at `x` and time `t`, of shape (out_size,).
        """
        # Promote the scalar time to shape (1,) so it can be appended to `x`.
        t = np.array([t])
        x = np.concatenate([x, t])
        return self.mlp(x)


Test that the model's forward pass works:

In [None]:
from functools import partial
from jax import vmap 

# Instantiate the model with its default hyperparameters and PRNG key.
model = ScoreModel2D()
t = 0.1

# `key` here is whatever the notebook-level `key` currently holds.
X_noised = noise(key, X, t)
# Bind `t`, then map the single-point model over the batch of noised points.
out = vmap(partial(model, t=t))(X_noised)
out.shape

Now we define the loss function.

In [None]:
from jax import jacfwd, jit 

def sde_score_matching_loss(model, noised_data: np.ndarray, t: float):
    """Score matching loss for SDE-based score models at a single time point.

    With `s = model(x, t)`, each sample contributes
    `sum_i (d s_i / d x_i + 0.5 * s_i ** 2)`; the result is the mean of
    that quantity over the batch.

    :param model: Callable `model(x, t)` (e.g. an Equinox model) mapping one
        data point of shape (n_data_dims,) to a score of the same shape.
    :param noised_data: Batch of data from 1 noise scale of shape (batch, n_data_dims).
    :param t: Time in SDE at which the noise scale was evaluated.
    :returns: Scalar loss averaged over the batch.
    """
    score_at_t = partial(model, t=t)
    jacobian_at_t = jacfwd(score_at_t, argnums=0)
    # Per-sample Jacobian of the score, reduced to its diagonal d(s_i)/d(x_i).
    term1 = vmap(np.diagonal)(vmap(jacobian_at_t)(noised_data))
    # Evaluate the score on the same noised batch that was passed in, not on
    # a notebook-level global, so both terms refer to the same points.
    term2 = 0.5 * vmap(score_at_t)(noised_data) ** 2
    inner_term = term1 + term2
    summed_by_dims = vmap(np.sum)(inner_term)
    return np.mean(summed_by_dims)

@eqx.filter_jit
def joint_sde_score_matching_loss(model, noised_data_all, ts):
    """Score matching loss summed over every noise level.

    :param model: An equinox model.
    :param noised_data_all: An array of shape (time, batch, n_data_dims).
    :param ts: An array of shape (time,).
    :returns: Scalar sum of the per-time-point losses.
    """

    def loss_at_one_time(noised_data, t):
        # Single-noise-level loss with the model held fixed.
        return sde_score_matching_loss(model, noised_data, t)

    # Map over the leading (time) axis of both the data and the time points.
    per_time_losses = vmap(loss_at_one_time)(noised_data_all, ts)
    return per_time_losses.sum()

Now, we calculate loss over all noised up data.

In [None]:
# arguments
# NOTE(review): `key` is assigned here but not read again in this cell.
key = random.PRNGKey(55)
# Fresh, untrained model with default hyperparameters.
model = ScoreModel2D()
data = X

# model
# Evaluate the joint loss across all noise levels in `noised_datas` / `ts`.
joint_sde_score_matching_loss(model, noised_datas, ts=ts)

In [None]:
# Wrap the joint loss so a single call returns both the loss value and its gradients.
dloss = eqx.filter_value_and_grad(joint_sde_score_matching_loss)
value, grads = dloss(model, noised_datas, ts=ts)
value

In [None]:
import optax
from tqdm.auto import tqdm

model = ScoreModel2D()

# NOTE(review): transforms in an optax chain are applied in the order listed, so
# `clip` here acts on Adam's output (the updates) rather than on the raw
# gradients -- confirm this ordering is intended.
optimizer = optax.chain(
    optax.adam(5e-2),
    optax.clip(0.001),
)

# Optimizer state is initialised over the model's array leaves only.
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
dloss = eqx.filter_value_and_grad(joint_sde_score_matching_loss)

n_steps = 13_000
iterator = tqdm(range(n_steps))
loss_history = []
# NOTE(review): `key` / `keys` are created here but not used by the loop below.
key = random.PRNGKey(555)
keys = random.split(key, n_steps)

# Every step trains on the full set of pre-noised data (all noise levels at once).
updated_score_model = model
for step in iterator:
    loss_score, grads = dloss(updated_score_model, noised_datas, ts)
    updates, opt_state = optimizer.update(grads, opt_state)
    updated_score_model = eqx.apply_updates(updated_score_model, updates)
    # Show the current loss in the progress bar and record it for plotting.
    iterator.set_description(f"Score: {loss_score}")
    loss_history.append(float(loss_score))


In [None]:
import matplotlib.pyplot as plt 
# Loss recorded at each training step, in order.
plt.plot(loss_history)

In [None]:
from jax import random, numpy as np, vmap

Firstly, we generate data.

In [None]:
from score_models.utils import generate_mixture_2d

key = random.PRNGKey(42)
# `generate_mixture_2d` returns the dataset plus a second value that later cells
# use as a PRNG key (`k3`). This rebinds the notebook-level name `data`, which was
# assigned earlier (`data = X`).
data, k3 = generate_mixture_2d(key)

Let's plot the data just to make sure we know what it's all about.

In [None]:
import matplotlib.pyplot as plt 

# Scatter the two data columns against each other; low alpha exposes point density.
plt.scatter(data[:, 0], data[:, 1], alpha=0.1)
plt.show()

Now, let's train a score model for the data.

In [None]:
data.shape

The score of a 2D dataset is the gradient of the log-density w.r.t. the inputs.
Because the data are multi-dimensional,
our gradients are necessarily equally dimensioned;
they are essentially the partial derivatives w.r.t. each input dimension.
Specifically, the score function maps $\mathbb{R}^d \rightarrow \mathbb{R}^d$.

In [None]:
from score_models.models import nn_model
from score_models.losses import score_matching_loss
from functools import partial
from jaxopt import GradientDescent

# `nn_model` returns an (init, apply) pair -- presumably stax-style; verify in score_models.models.
init_fun, nn_score_func = nn_model(output_dim=2)
k4, k5 = random.split(k3)
# Initialise parameters for 2-feature inputs; the first returned value is discarded.
_, params_init = init_fun(k4, input_shape=(None, 2))

# Test-drive forward pass
# The score function is applied per data point, with the initial parameters bound.
out_test = vmap(partial(nn_score_func, params_init))(data)
out_test

In [None]:
# For debugging purposes
out_test.shape

Now, we need to write the score matching loss.
The score matching loss is the sum over all dimensions
of the mean over all samples,
as given by equation 6 in the JMLR paper (2005) by Aapo Hyvärinen.
In earlier experiments, I also observed exploding weights leading to NaN values,
so I will be applying weight L2 regularization to prevent that from happening.

In [None]:
from score_models.losses import score_matching_loss

In [None]:
from jax import jit 
# Bind the score function into the loss and JIT-compile the result.
myloss = jit(partial(score_matching_loss, score_func=nn_score_func))
# Gradient descent capped at 10,000 iterations.
solver = GradientDescent(fun=myloss, maxiter=10000)
# The whole dataset is passed as the `batch` keyword on every iteration.
result = solver.run(params_init, batch=data)


In [None]:
from jax.tree_util import tree_flatten, tree_map

# Collapse every parameter leaf into one long vector to inspect the overall
# range of the fitted weights.
leaves, _ = tree_flatten(result.params)
params_flat = np.concatenate([leaf.flatten() for leaf in leaves])
params_flat.max(), params_flat.min()

## Visualize learned gradient field

In [None]:
# Regular n_points x n_points lattice covering [-30, 30] in both directions.
n_points = 21
xs = np.linspace(-30, 30, n_points)
ys = np.linspace(-30, 30, n_points)
xxs, yys = np.meshgrid(xs, ys)

# One (x, y) row per lattice point.
x_y_pair = np.stack([xxs.flatten(), yys.flatten()], axis=1)
x_y_pair.shape

In [None]:
fig, axes = plt.subplots(figsize=(10, 10))

# Bare expression mid-cell: it is not the last statement, so nothing is displayed.
result.params
# Evaluate the fitted score function at every grid point.
gradient_field = vmap(partial(nn_score_func, result.params))(x_y_pair)

# One arrow per grid point, anchored at the point and scaled to 10% of the score vector.
for xy_pair, vect in zip(x_y_pair, gradient_field):
    axes.arrow(*xy_pair, *vect * 0.1, width=0.3, alpha=0.1)    
# Overlay the training data for reference.
axes.scatter(*data.T, alpha=0.1, color="black")

Now, we sample with Langevin Dynamics.

In [None]:
from score_models.sampler import langevin_dynamics

# 4000 starting points drawn from a Gaussian centred at (-5, -5) with covariance 20 * I.
starter_xs = random.multivariate_normal(k5, mean=np.array([-5, -5]), cov=np.eye(2)*20, shape=(4000,)) 
epsilon = 5e-3
# NOTE(review): `key` is the same PRNGKey(42) already passed to `generate_mixture_2d`;
# confirm that reusing it here is intended.
# NOTE(review): `n_chains=4000` repeats the size used for `starter_xs`; how `init_scale`
# interacts with an explicit `starter_xs` is not visible here -- check score_models.sampler.
starting_states, final_states, chain_samples = langevin_dynamics(
    n_chains=4000, 
    n_samples=8000, 
    key=key, 
    epsilon=epsilon, 
    score_func=nn_score_func, 
    params=result.params, 
    init_scale=10, 
    starter_xs=starter_xs,
)


In [None]:
final_states

In [None]:
import matplotlib.pyplot as plt 

# Overlay the data, the chains' starting states and their final states on one figure.
plt.figure(figsize=(8, 8))
plt.scatter(data[:, 0], data[:, 1], alpha=0.1, label="data")
plt.scatter(starting_states[:, 0], starting_states[:, 1], alpha=0.1, label="starting samples")
plt.scatter(final_states[:, 0], final_states[:, 1], alpha=0.1, label="final samples")
# Fixed, equal-aspect window; points outside [-15, 15] on either axis are not shown.
plt.xlim(-15, 15)
plt.ylim(-15, 15)
plt.gca().set_aspect("equal")
plt.legend()
plt.show()

In [None]:
from celluloid import Camera
from tqdm.autonotebook import tqdm 

fig = plt.figure()
camera = Camera(fig)

# assumes chain_samples is (n_chains, n_samples, n_dims), so swapping axes 0 and 1
# iterates over sampling steps -- TODO confirm against langevin_dynamics.
# `[::10]` keeps every 10th entry along the leading axis.
for timepoint in tqdm(chain_samples.swapaxes(0, 1)[::10]):
    plt.scatter(*timepoint.T, color="blue", alpha=0.1)
    # Capture the current figure contents as one animation frame.
    camera.snap()



In [None]:
animation = camera.animate()

In [None]:
from IPython.display import display_html

# display_html(animation)
# Write the animation to "sampling2.mp4" at 300 dpi and 60 frames per second.
animation.save("sampling2.mp4", dpi=300, fps=60)

Firstly, it's powerful to just "see" what's happening amongst the chain samples!
