# Score Functions

One of the coolest new developments in generative modelling
is the use of _score models_ to generate samples
that look like an existing collection of samples.
Pioneering this approach is Yang Song,
who did the work while a PhD student at Stanford University.
I first came across his work through Twitter,
which led me to his excellently written [blog post][yangblog] on the topic.

[yangblog]: https://yang-song.github.io/blog/2021/score/

In this collection of notebooks,
I would like to explore the fundamental ideas that underlie his work.
Along the way, we will work towards
a pedagogical implementation of score models as generative models.
By the end of this journey,
we should have a much better understanding of score models and their core concepts,
and should also have a framework for writing the code necessary
to implement score models in JAX and Python.

We need to establish some basic knowledge and terminology first;
otherwise, the terminology may become overwhelming,
especially for those who are not well-versed in probabilistic modelling.
As such, we're going to start with a bunch of definitions.
Don't skip these; they're important!

## Definition

What's a score function?
The score function is defined as follows:

> The score function is
> the gradient of the log of the probability density function 
> of a probability distribution
> with respect to the distribution's support.

There's a lot to unpack in there, 
so let's dissect the anatomy of this definition bit by bit.
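
As a compact reference for the sections below, the definition can also be written in symbols. The notation is an assumption for illustration rather than something fixed by the sources: $p(x)$ stands for the probability density function, $x$ for a point in the support, and $s(x)$ for the score.

$$s(x) = \nabla_x \log p(x)$$

For a one-dimensional distribution, the gradient reduces to the ordinary derivative $\frac{d}{dx} \log p(x)$.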

### Probability Distributions

Probability distributions are super cool objects in stats[^bayes].
Distributions can be **configured** through their parameters;
for example, by setting the values of $\mu$ and $\sigma$ for a Gaussian.
We can use probability distributions to generate data, 
and we can use them to evaluate the likelihood of observed data.
The latter point is done by using a probability distribution's
**probability density function**[^discrete].

[^bayes]: I've explored the anatomy of a probability distribution
in my essay on [Bayesian and computational statistics][bayes],
and would recommend looking at it for a refresher.

[^discrete]: Or the probability mass function, for discrete distributions,
but we're going to stick with continuous distributions for this essay.

[bayes]: https://ericmjl.github.io/essays-on-data-science/machine-learning/computational-bayesian-stats/

### Probability Density Function

A distribution's probability density function (PDF)
describes the propensity of a probability distribution
to generate draws of a particular value.
As mentioned above, we primarily use the PDF to
_evaluate the likelihood of observing data, given the distribution's configuration_.
If you need an anchoring example, 
think of the venerable Gaussian probability density function in @fig-likelihood:

In [None]:
#| echo: false 
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

Every distribution has a **support**,
which is the range of values for which the probability distribution is defined.
The Gaussian has support in the range $(-\infty, \infty)$,
while positive-only distributions (such as the Exponential)
have support in the range $(0, \infty)$.

### Log PDF

Because the PDF is nothing more than a math function, we can take its logarithm!
In computational statistics, taking the log is usually done for pragmatic purposes,
as we usually end up with underflow issues otherwise.
For the standard Gaussian above, its log PDF looks like what we see in @fig-likelihood.
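
To make the underflow point concrete, here is a small sketch using only the Python standard library. The helper names and the 2000 identical observations are made up for illustration. It multiplies the densities directly and compares that with summing the log densities.

```python
import math


def gaussian_pdf(x, mu=0.0, sigma=1.0):
    """Density of a Gaussian at x."""
    z = (x - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))


def gaussian_logpdf(x, mu=0.0, sigma=1.0):
    """Log density of a Gaussian at x, computed without exponentiating."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


# Hypothetical data set: 2000 observations, all at x = 1.0.
observations = [1.0] * 2000

# Multiplying many densities smaller than 1 underflows to zero...
likelihood = 1.0
for x in observations:
    likelihood *= gaussian_pdf(x)

# ...while summing log densities stays comfortably representable.
log_likelihood = sum(gaussian_logpdf(x) for x in observations)

print(likelihood)
print(log_likelihood)
```

The product collapses to zero in double precision, while the sum of logs remains an ordinary finite number that we can still compare and differentiate.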

We often call the log PDF **logp** for short,
and in the probabilistic programming language PyMC,
`logp` is the name of the class method used for calculating
the log likelihood of data under the distribution.

### Score

Finally, we get to the **score**.
As it turns out, because the logp function is differentiable,
we can take its derivative easily (using JAX, for example).
The derivative of the logp function is called the **score function**:
the score of a distribution is the gradient of the logp function w.r.t. the support.
You can see what it looks like in @fig-likelihood.

In [None]:
#| echo: false
#| label: fig-likelihood
#| fig-cap: "$P(x)$ (likelihood, PDF), $log P(x)$ (log likelihood, logp), and $dlogP(x)$ (score) of a Gaussian."
import jax.numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from jax import grad, vmap
from jax.scipy.stats import norm
 
fig, axes = plt.subplots(figsize=(8, 3), nrows=1, ncols=3, sharex=True)

plt.sca(axes[0])
x = np.linspace(-3.5, 3.5, 1000)
y = norm.pdf(x)
plt.plot(x, y, color="black")
plt.xlabel("Support")
plt.ylabel("Likelihood")
plt.title("PDF")
sns.despine()

y = norm.logpdf(x)
plt.sca(axes[1])
plt.plot(x, y, color="black")
plt.ylabel("logP(x)")
plt.title("Log PDF")

# Tangent line at x_pt; its slope is the score evaluated at x_pt.
x_pt = -1.5
y_pt = norm.logpdf(x_pt)

def line(x):
    return grad(norm.logpdf)(x_pt) * (x - x_pt) + y_pt

xrange = np.linspace(x_pt - 1, x_pt + 1, 10)
plt.scatter(x_pt, y_pt, color="gray")
plt.plot(xrange, vmap(line)(xrange), color="gray", ls="--")


plt.sca(axes[2])
plt.plot(x, vmap(grad(norm.logpdf))(x), color="black")
plt.axhline(y=0, ls="--", color="black")
plt.axvline(x=0, ls="--", color="black")
plt.title("Score")

sns.despine()
plt.tight_layout()

In JAX, obtaining the score function is relatively easy.
We simply need to use JAX's `grad` function to obtain the transformed logp function.

In [None]:
import jax.numpy as np

x = np.linspace(-3, 3, 1000)
y = norm.logpdf(x, loc=0, scale=1)

gaussian_score = grad(norm.logpdf)


From visual inspection above,
we know that at the top of the Gaussian,
the gradient should be zero,
and we can verify as much.

In [None]:
gaussian_score(0.0)


At the tails, the gradient should be of higher magnitude
than at the mean.

In [None]:
gaussian_score(-3.0)


In [None]:
gaussian_score(3.0)
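
As a cross-check on these values, the score of a Gaussian also has a closed form, $-(x - \mu) / \sigma^2$, so for the standard Gaussian we would expect values near $3$ at $x = -3$ and $-3$ at $x = 3$. The sketch below compares that formula against autodiff. The function name `analytic_gaussian_score` is made up for illustration, and `jnp` is used as the alias for `jax.numpy` to keep the block self-contained.

```python
import jax.numpy as jnp
from jax import grad, vmap
from jax.scipy.stats import norm


def analytic_gaussian_score(x, mu=0.0, sigma=1.0):
    """Closed-form derivative of the Gaussian log PDF w.r.t. x."""
    return -(x - mu) / sigma**2


# grad differentiates w.r.t. the first positional argument, which is x here.
autodiff_score = grad(norm.logpdf)

xs = jnp.array([-3.0, 0.0, 3.0])
autodiff_values = vmap(autodiff_score)(xs)
analytic_values = analytic_gaussian_score(xs)

print(autodiff_values)
print(analytic_values)
```

The sign is worth noting: to the left of the mean the score is positive, and to the right it is negative, so the score always points back towards the region of higher density.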


## Estimating the score function

As it turns out, there is [a paper][hyvarinen]
published by Aapo Hyvärinen in 2005 in the Journal of Machine Learning Research
that details how to _estimate_ the score function
in the absence of knowledge of the true data generating distribution.
When I first heard of the idea, I thought it was crazy --
crazy cool that we could even do this!

[hyvarinen]: https://jmlr.csail.mit.edu/papers/volume6/hyvarinen05a/old.pdf

One key equation in the paper is equation #4.
This equation details how we can use an arbitrary function, $\psi(x, \theta)$,
to approximate the score function,
and the loss function needed to train the parameters of the function $\theta$
to approximate the score function.
I've replicated the equation below,
alongside a bullet-point explanation of what each of the terms is:

$$J(\theta) = \frac{1}{T} \sum_{t=1}^{T} \sum_{i=1}^{n} \left[ \partial_i \psi_i(x(t); \theta) + \frac{1}{2} \psi_i(x(t); \theta)^2 \right] + \text{const}$$

Here:

- $J(\theta)$ is the loss function that we wish to minimize w.r.t. the parameters $\theta$.
- $\theta$ are the parameters of the function $\psi_i$.
- $\psi_i(x(t); \theta)$ is the score function estimator, parameterized by $\theta$, for dimension $i$ of $x$.
- $x(t)$ are the $T$ i.i.d. samples from the unknown data-generating distribution.
- $\partial_i$ refers to the partial derivative w.r.t. dimension $i$ of $x$.
- $\text{const}$ is a constant term that does not depend on $\theta$ and can effectively be ignored.
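
To get a feel for what minimizing this loss does, here is a sketch of the one-dimensional Gaussian case, worked by hand. It assumes the score model is $\psi(x; \mu, \sigma) = -\frac{x - \mu}{\sigma^2}$ (the score of a Gaussian) and drops the constant term.

$$
\begin{aligned}
\frac{d\psi}{dx} &= -\frac{1}{\sigma^2} \\
J(\mu, \sigma) &= \frac{1}{T} \sum_{t=1}^{T} \left[ -\frac{1}{\sigma^2} + \frac{(x(t) - \mu)^2}{2\sigma^4} \right] \\
\frac{\partial J}{\partial \mu} = 0 &\implies \mu = \frac{1}{T} \sum_{t=1}^{T} x(t) \\
\frac{\partial J}{\partial \sigma} = 0 &\implies \sigma^2 = \frac{1}{T} \sum_{t=1}^{T} (x(t) - \mu)^2
\end{aligned}
$$

Under that assumption, the loss is minimized at the sample mean and the (biased) sample variance, even though the loss never touches the true density.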


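Let's explore the idea in a bit more detail.
Ultimately, we want to use a simple feed-forward neural network
as the score function estimator $\psi(x(t); \theta)$.
Before we get there, we'll draw data from a known Gaussian
and fit a Gaussian score model to it,
so that we can check the loss function against ground truth.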

In [None]:
from jax import random

key = random.PRNGKey(44)

true_mu = 3.0
true_sigma = 1.0
data = random.normal(key, shape=(1000,)) * true_sigma + true_mu
data[0:10]  # showing just the first 10 samples drawn


Let's also verify that the sample mean and standard deviation of the data
are close to the ground truth $\mu$ and $\sigma$.

In [None]:
data.mean(), data.std()

Evaluate the score of the data _under the true model_.
We ensure that the resulting score function has the same signature
as the underlying distribution that it is based on.

In [None]:
from score_models.models import gaussian_model

init_fun, apply_fun = gaussian_model()
true_params = (true_mu, np.log(true_sigma))

(
    apply_fun(true_params, true_mu),
    apply_fun(true_params, true_mu + true_sigma),
    apply_fun(true_params, true_mu - true_sigma * 3),
)

Calculate the true data score per draw.

In [None]:
from functools import partial

# Don't forget to pass in log of true_sigma!!!
true_data_score = vmap(partial(apply_fun, true_params))(data)
true_data_score


Use gradient descent to find the parameters of the Gaussian
that minimize the score matching loss.
To do this, we will use the `GradientDescent` solver from `jaxopt`,
which gives us a really concise syntax.

In [None]:
k1, k2, k3 = random.split(key, 3)
_, params_init = init_fun(k1)
params_init

In [None]:
score_func = partial(apply_fun, params_init)
dscore_func = grad(score_func)

In [None]:
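from typing import Callable

from jax import grad


def score_matching_loss(params, score_func: Callable, batch):
    """Score matching loss (Hyvärinen 2005, eq. 4), up to an additive constant.

    `score_func(params, x)` must map a scalar draw `x` to a scalar score.
    """
    # Bind the parameters so that we differentiate w.r.t. the input x only.
    score_func = partial(score_func, params)
    dscore_func = grad(score_func)

    term1 = vmap(dscore_func)(batch)  # derivative of the score w.r.t. x
    term2 = 0.5 * vmap(score_func)(batch) ** 2  # half the squared score

    inner_term = term1 + term2
    return np.mean(inner_term).squeeze()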



In [None]:

score_matching_loss(params_init, apply_fun, data)
myloss = partial(score_matching_loss, score_func=apply_fun)

In [None]:
from jaxopt import GradientDescent

solver = GradientDescent(fun=myloss, maxiter=20000, stepsize=5e-2)
result = solver.run(params_init, batch=data)

Do the resulting params match up?

In [None]:
mu, log_sigma = result.params
mu, np.exp(log_sigma)


In [None]:
np.mean(data), np.std(data)


Looks like they do!
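
The `gaussian_model` helper comes from the `score_models` package, whose internals aren't shown here. As a rough, self-contained sketch of the same experiment, the block below assumes a Gaussian score parameterized by `(mu, log_sigma)` (matching how `true_params` is built above) and swaps the `jaxopt` solver for a hand-written gradient descent loop. The function names, starting point, and number of steps are illustrative choices, not the package's actual implementation.

```python
import jax.numpy as jnp
from jax import grad, jit, random, vmap


def gaussian_score(params, x):
    """Score of a Gaussian at a scalar x; params = (mu, log_sigma)."""
    mu, log_sigma = params
    return -(x - mu) / jnp.exp(2.0 * log_sigma)


def score_matching_loss(params, batch):
    """Batch mean of d(score)/dx + 0.5 * score**2."""
    score = lambda x: gaussian_score(params, x)
    dscore = grad(score)  # derivative w.r.t. the input x, not the params
    return jnp.mean(vmap(dscore)(batch) + 0.5 * vmap(score)(batch) ** 2)


key = random.PRNGKey(44)
data = random.normal(key, shape=(1000,)) * 1.0 + 3.0

# Plain gradient descent on the parameters (assumed stand-in for jaxopt).
loss_grad = jit(grad(score_matching_loss))
params = (jnp.array(0.0), jnp.array(0.0))  # arbitrary start: mu=0, sigma=1
stepsize = 5e-2
for _ in range(5000):
    grads = loss_grad(params, data)
    params = tuple(p - stepsize * g for p, g in zip(params, grads))

est_mu, est_sigma = params[0], jnp.exp(params[1])
print(est_mu, est_sigma)
print(data.mean(), data.std())
```

Optimizing `log_sigma` rather than `sigma` keeps the scale positive without needing a constrained optimizer.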

Now let's compare the model's scores against the true scores.
First, a helper to draw the $y = x$ reference line:

In [None]:
def y_eq_x(x, y, ax):
    minval = min(min(x), min(y))
    maxval = max(max(x), max(y))

    ax.plot([minval, maxval], [minval, maxval])


In [None]:
est_mu, est_log_sigma = result.params

est_mu - true_mu, np.exp(est_log_sigma) - true_sigma


In [None]:
gaussian_scores = vmap(partial(apply_fun, (result.params)))(data)
plt.scatter(true_data_score, gaussian_scores)
y_eq_x(true_data_score, gaussian_scores, plt.gca())
plt.xlabel("True Data Score")
plt.ylabel("Model Score")
plt.title("Gaussian Model Performance")
plt.show()


What if we try to approximate the score function with a neural network instead?

## Approximate Score Function with NN

In [None]:
from typing import Tuple
from jax.example_libraries import stax 



In [None]:
from score_models.models import nn_model
init_fun, apply_fun = nn_model()

def score_fun(params, batch):
    out = apply_fun(params, batch).squeeze()
    return out 


_, params_init = init_fun(rng=random.PRNGKey(44), input_shape=(1,))
myloss = partial(score_matching_loss, score_func=score_fun)

solver = GradientDescent(fun=myloss, maxiter=1200)
result = solver.run(params_init, batch=data)


In [None]:
gaussian_scores = vmap(partial(apply_fun, result.params))(data).squeeze()
plt.scatter(true_data_score, gaussian_scores)
y_eq_x(true_data_score, gaussian_scores, plt.gca())
plt.title("Trained Neural Network")
plt.xlabel("True Data Score")
plt.ylabel("Model Data Score")
plt.show()

In [None]:
gaussian_scores = vmap(partial(apply_fun, params_init))(data).squeeze()
plt.scatter(true_data_score, gaussian_scores)
y_eq_x(true_data_score, gaussian_scores, plt.gca())
plt.title("Initialized Neural Network")
plt.xlabel("True Data Score")
plt.ylabel("Model Data Score")
plt.show()
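
The `nn_model` used above also comes from the `score_models` package, so its architecture isn't visible in this notebook. For readers who want something to experiment with, here is a hypothetical stand-in built with `stax`. The layer sizes and the choice of `Tanh` activations are assumptions of this sketch, not a description of what `nn_model` does.

```python
import jax.numpy as jnp
from jax import grad, random, vmap
from jax.example_libraries import stax

# Hypothetical architecture: two hidden layers with smooth activations,
# so that the derivative of the output w.r.t. the input is well-behaved.
init_fun, apply_fun = stax.serial(
    stax.Dense(32), stax.Tanh,
    stax.Dense(32), stax.Tanh,
    stax.Dense(1),
)


def score_fun(params, x):
    """Scalar-in, scalar-out score estimate, so `grad` w.r.t. x is defined."""
    return apply_fun(params, jnp.reshape(x, (1,))).squeeze()


def score_matching_loss(params, batch):
    """Batch mean of d(score)/dx + 0.5 * score**2."""
    score = lambda x: score_fun(params, x)
    return jnp.mean(vmap(grad(score))(batch) + 0.5 * vmap(score)(batch) ** 2)


_, params_init = init_fun(random.PRNGKey(44), (1,))
batch = random.normal(random.PRNGKey(0), shape=(256,)) + 3.0

# The loss and its gradient w.r.t. the network parameters are what an
# optimizer would consume at every step.
initial_loss = score_matching_loss(params_init, batch)
loss_grads = grad(score_matching_loss)(params_init, batch)
print(initial_loss)
```

Note that two levels of differentiation are in play: the inner `grad` is taken with respect to the input $x$, while the optimizer needs the gradient of the whole loss with respect to the network parameters.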


## Ensure this works with mixture distributions

Real data will most often look like a mixture distribution.
Let's make sure our neural network
can approximate the score function of a mixture accurately.

In [None]:
import seaborn as sns 

x = np.linspace(-10, 10, 200)
mus = np.array([-3, 3])
sigmas = np.array([1, 1])
ws = np.array([0.1, 0.9])

def mixture_pdf(x, mus, sigmas, ws):
    component_pdfs = vmap(partial(norm.pdf, x))(mus, sigmas)  # shape (n_components,) for scalar x
    scaled_component_pdfs = vmap(np.multiply)(component_pdfs, ws)
    total_pdf = np.sum(scaled_component_pdfs, axis=0)
    return total_pdf

def mixture_logpdf(x, mus, sigmas, ws):
    return np.log(mixture_pdf(x, mus, sigmas, ws))

dmixture_logpdf = grad(mixture_logpdf, argnums=0)
mixture_logpdf_grads = vmap(partial(dmixture_logpdf, mus=mus, sigmas=sigmas, ws=ws))(x)
plt.plot(x, mixture_logpdf_grads)
plt.xlabel("Support")
plt.ylabel("Score")
sns.despine()
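
One caveat with taking the log of a summed PDF: in JAX's default 32-bit floats, the component densities can underflow to zero far away from every component, at which point the log becomes `-inf`. A possible alternative, sketched below, stays in log space using `logsumexp`. The function name `mixture_logpdf_stable` and the test point `x = 40.0` are illustrative assumptions, and the mixture parameters are copied from the cell above.

```python
import jax.numpy as jnp
from jax import grad, vmap
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

mus = jnp.array([-3.0, 3.0])
sigmas = jnp.array([1.0, 1.0])
ws = jnp.array([0.1, 0.9])


def mixture_logpdf_stable(x, mus, sigmas, ws):
    """log(sum_k w_k * N(x; mu_k, sigma_k)), computed in log space."""
    component_logpdfs = norm.logpdf(x, loc=mus, scale=sigmas)  # (n_components,)
    return logsumexp(component_logpdfs + jnp.log(ws))


mixture_score = grad(mixture_logpdf_stable, argnums=0)

xs = jnp.linspace(-10.0, 10.0, 200)
scores = vmap(lambda x: mixture_score(x, mus, sigmas, ws))(xs)

# Far from both components, the score should approach that of the nearest
# component, i.e. roughly -(40 - 3) here.
far_tail_score = mixture_score(40.0, mus, sigmas, ws)
print(far_tail_score)
```

Within the plotted range the two approaches should agree; the log-space version mainly matters when evaluating far out in the tails.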

Now, we need to make sure that our neural network model
is able to approximate the score function above.

In [None]:
import numpy as onp 


draws = 1000
mix1 = random.normal(k1, shape=(1000,)) * 1 - 3
mix2 = random.normal(k2, shape=(9000,)) * 1 + 3
data = np.concatenate([mix1, mix2])
plt.hist(onp.array(data))


In [None]:
_, params_init = init_fun(rng=random.PRNGKey(44), input_shape=(1,))
myloss = partial(score_matching_loss, score_func=score_fun)

solver = GradientDescent(fun=myloss, maxiter=12000)
result = solver.run(params_init, batch=data)


In [None]:
xs = np.linspace(-6, 6, 1000)
out = apply_fun(result.params, inputs=xs.reshape(-1, 1))

plt.plot(xs, out.squeeze(), label="NN Estimated")
plt.plot(x, mixture_logpdf_grads, label="Ground Truth")
plt.xlabel("Support")
plt.ylabel("Score")
plt.legend()
plt.title("Estimated vs. Truth Score Function")
sns.despine()


Not bad, we can estimate the gradient in the regime where we have lots of data!
However, we can't seem to say the same for the tails (and midpoint of the mixture)
where there is very little data present.
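
One rough way to see why is to ask how many of our draws we should expect in each region. The sketch below uses only the standard library and the mixture parameters from above. The region boundaries are arbitrary choices for illustration, and `n_draws` matches the 1000 + 9000 samples we generated.

```python
import math


def normal_cdf(x, mu=0.0, sigma=1.0):
    """Cumulative distribution function of a Gaussian."""
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))


def mixture_mass(lo, hi, mus, sigmas, ws):
    """Probability that a draw from the mixture lands in [lo, hi]."""
    return sum(
        w * (normal_cdf(hi, m, s) - normal_cdf(lo, m, s))
        for m, s, w in zip(mus, sigmas, ws)
    )


mus, sigmas, ws = [-3.0, 3.0], [1.0, 1.0], [0.1, 0.9]
n_draws = 10_000

# Unit-width windows; the boundaries are illustrative, not canonical.
regions = {
    "left tail": (-6.0, -5.0),
    "minor mode": (-3.5, -2.5),
    "midpoint": (-0.5, 0.5),
    "major mode": (2.5, 3.5),
}
expected_counts = {
    name: n_draws * mixture_mass(lo, hi, mus, sigmas, ws)
    for name, (lo, hi) in regions.items()
}

for name, count in expected_counts.items():
    print(f"{name}: {count:.1f} expected draws")
```

Windows of the same width receive very different numbers of samples, which is consistent with the estimate being best near the major mode and worst in the left tail and around the midpoint.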

## Confirm with 3-component mixture

In [None]:
mus = np.array([-7, -2, 3])
sigmas = np.ones(3)
ws = np.array([0.1, 0.4, 0.5])
mixture_logpdf_grads = vmap(partial(dmixture_logpdf, mus=mus, sigmas=sigmas, ws=ws))(x)
plt.plot(x, mixture_logpdf_grads)
plt.xlabel("Support")
plt.ylabel("Score")
plt.title("3-Component Mixture Score")
sns.despine()

In [None]:
draws = 1000
# Component means match `mus` above, so the data and ground-truth score agree.
mix1 = random.normal(k1, shape=(1000,)) * 1 - 7
mix2 = random.normal(k2, shape=(4000,)) * 1 - 2
mix3 = random.normal(k3, shape=(5000,)) * 1 + 3
data = np.concatenate([mix1, mix2, mix3])
plt.hist(onp.array(data), bins=100)


In [None]:
_, params_init = init_fun(rng=random.PRNGKey(44), input_shape=(1,))
myloss = partial(score_matching_loss, score_func=score_fun)

solver = GradientDescent(fun=myloss, maxiter=12000, tol=1e-4)
result = solver.run(params_init, batch=data)


In [None]:
xs = np.linspace(-10, 8, 1000)
out = apply_fun(result.params, inputs=xs.reshape(-1, 1))

plt.plot(xs, out.squeeze(), label="NN Estimated")
plt.plot(x, mixture_logpdf_grads, label="Ground Truth")
plt.xlabel("Support")
plt.ylabel("Score")
plt.legend()  
plt.title("Estimated vs. Truth Score Function")
sns.despine()


Not bad, this actually works!