# Score Models

Generative models of data are all the rage in the machine learning world.
As someone embedded deeply in both the biology and machine learning fields,
I find that one really cool application of generative models
is generating novel biological sequences.
To do so, we can turn to a variety of model classes,
such as models that leverage autoregression (recurrent neural networks, RNNs),
models that use a cat-and-mouse game of trickery 
to learn the data generating distribution (generative adversarial networks, GANs),
and models that try to approximate the data generating distribution
using a latent distribution (variational autoencoders, VAEs).

Even back in the day when I was still relatively untrained in probabilistic modelling,
RNNs, GANs, and VAEs all felt a bit mystical.
The reasons? I actually couldn't pin them down back then.
Fast-forward a few years, though,
and after being embedded in the world of probabilistic modelling and Bayes,
I find the reasons much clearer.
I'd like to explore one collection of work,
based heavily on Yang Song's blog post on [score models][score]
(this is also his PhD thesis topic),
and share what I think is the core of that idea.

[score]: https://yang-song.github.io/blog/2021/score/


## Probability distributions and generative models

From what I have seen, the whole premise of score models 
is to deal with data that are drawn from unknown data-generating distributions.
How do we generate new data that looks like existing data
when we don't know what the underlying data generating distribution is?
That's the key question that score-based models attempt to answer.

Let's consider the case of generating new data
when we know the data-generating distribution.
To anchor our understanding, 
we'll use what I consider to be the simplest complex example for this topic:
a Gaussian distribution.

Firstly, we have to understand that all probability distributions,
including the Gaussian distribution,
are capable of generating new data.
(In my other essay, [An Introduction to Probability and Computational Bayesian Statistics][compstats],
we go over the anatomy of a probability distribution in detail,
and I would recommend referencing that essay.)
Generating new data is also called "drawing samples".
As such, probability distributions can be considered a "data generator".
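
To make "drawing samples" concrete, here is a minimal sketch using JAX's random number generator.
The seed, the sample size, and the mean and standard deviation of the second Gaussian
are arbitrary choices made for illustration.

```python
import jax.numpy as np
from jax import random

# A PRNG key makes the draws reproducible; the seed 42 is arbitrary.
key = random.PRNGKey(42)

# Draw 1000 samples from a standard Gaussian (mean 0, standard deviation 1).
samples = random.normal(key, shape=(1000,))

# Shift and scale the same draws to get samples from a Gaussian
# with mean 3 and standard deviation 2 (illustrative values).
mu, sigma = 3.0, 2.0
shifted_samples = mu + sigma * samples

print(samples.mean(), samples.std())
print(shifted_samples.mean(), shifted_samples.std())
```

The sample mean and standard deviation should land close to the parameters of the distribution we drew from,
which is exactly what we expect from a "data generator".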

[compstats]: ./computational-bayesian-stats

Now, all continuous probability distributions have a **probability density function** (PDF).
The PDF is a function that controls the propensity of a distribution
to draw values on the support of the distribution.
For a Gaussian, values around the mean have the highest propensity to be drawn,
while values far away from the mean have the lowest propensity to be drawn.

In [None]:
#| echo: false
import jax.numpy as np 
from jax.scipy.stats import norm
import matplotlib.pyplot as plt 
import seaborn as sns

fig, ax = plt.subplots()
x = np.linspace(-3, 3, 1000)
y = norm.pdf(x)
plt.plot(x, y)
plt.xlabel("Support")
plt.ylabel("Likelihood")
sns.despine()

Now, note that the PDF, 
which returns the likelihood of observing any point on the support,
is nothing more than a math function.
That means we can take its logarithm and obtain the "log PDF"
(from here onwards I'll call it the "logp" function).
The logp function is useful in many settings,
not least in computational settings where multiplying likelihoods together
can be transformed to adding log-likelihoods.
For the Gaussian above, it looks like this:


In [None]:
fig, ax = plt.subplots()
y = norm.logpdf(x)
plt.plot(x, y)
plt.xlabel("Support")
plt.ylabel("Log Likelihood")
sns.despine()
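
To see why adding log-likelihoods is preferable to multiplying likelihoods,
here is a small sketch that computes the joint likelihood of 2000 draws both ways.
The seed and the number of draws are arbitrary choices for illustration.

```python
import jax.numpy as np
from jax import random
from jax.scipy.stats import norm

key = random.PRNGKey(0)
data = random.normal(key, shape=(2000,))

# Multiplying 2000 densities, each of which is below 0.4,
# underflows to zero in floating-point arithmetic.
joint_likelihood = np.prod(norm.pdf(data))

# Adding the log densities instead stays comfortably within range.
joint_loglike = np.sum(norm.logpdf(data))

print(joint_likelihood)
print(joint_loglike)
```

The product collapses to zero, while the sum of logs remains a finite (large, negative) number that we can still compare and differentiate.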

The logp function is also differentiable!
That means we can take its derivative easily (using JAX, for example).
This gives us what we call a **score function**.
The derivative of a logp function is the score function of a PDF.
To help remember, think of the following chain:

```text
PDF (likelihood) --> logPDF (log likelihood) --> dlogPDF (score)
```

In [None]:
#| echo: false
from jax.scipy.stats import norm
from jax import grad, vmap

fig, axes = plt.subplots(figsize=(8, 4), nrows=1, ncols=2, sharex=True, sharey=True)

x = np.linspace(-3, 3)
y = norm.logpdf(x)

plt.sca(axes[0])
plt.plot(x, y, color="black")
plt.ylabel("logP(x)")
plt.title("Log Likelihood")

plt.sca(axes[1])
x_pt = -1.5
y_pt = norm.logpdf(x_pt)


def line(x):
    return grad(norm.logpdf)(x_pt) * (x - x_pt) + y_pt
xrange = np.linspace(x_pt - 1, x_pt + 1, 10)
plt.plot(x, y, color="black")
plt.scatter(x_pt, y_pt, color="gray")
plt.plot(xrange, vmap(line)(xrange), color="gray", ls="--")
plt.title("Score")

sns.despine()
plt.tight_layout()
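
As a quick sanity check on the chain above, here is a sketch that obtains the score function of a standard Gaussian with `grad`
and compares it against the closed form.
For a standard Gaussian, $\log P(x) = -\frac{x^2}{2} + \text{const}$, so the score is simply $-x$.
The grid of evaluation points is an arbitrary choice.

```python
import jax.numpy as np
from jax import grad, vmap
from jax.scipy.stats import norm

# The score function is the derivative of logp with respect to x.
gaussian_score = grad(norm.logpdf)

# Evaluate the score on a handful of points on the support.
xs = np.linspace(-3.0, 3.0, 7)
autodiff_scores = vmap(gaussian_score)(xs)

# Closed-form score of a standard Gaussian.
analytic_scores = -xs

print(autodiff_scores)
print(np.allclose(autodiff_scores, analytic_scores, atol=1e-5))
```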

Now, one thing we know about gradients is that
they tell us the direction in which to move on the x-axis
in order to go upwards on the y-axis.
So if $\frac{d\log{P(x)}}{dx} = 1$,
then we need to move in the positive direction
in order to approach a maximum of $\log{P}$.
And if $\frac{d\log{P(x)}}{dx} = -2$,
then we need to move in the negative direction
in order to approach a maximum of $\log{P}$.
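
Here is a minimal sketch of that idea: starting from $x = -1.5$ on a standard Gaussian,
we repeatedly take a small step in the direction the score points.
The step size and the number of steps are hypothetical values picked for illustration,
and this is plain gradient ascent rather than a sampling algorithm.

```python
from jax import grad
from jax.scipy.stats import norm

# Score function of a standard Gaussian.
score = grad(norm.logpdf)

x = -1.5         # starting point on the support
step_size = 0.1  # hypothetical step size, chosen for illustration

for _ in range(100):
    # Positive score: step to the right. Negative score: step to the left.
    x = x + step_size * score(x)

print(x)
```

Following the score walks us up the logp curve until we settle near $x = 0$, the mode of the standard Gaussian.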

When we know the data generating distribution's log PDF equation,
with a bit of calculus, we can easily derive the distribution's score function too.
However, what happens when we _don't_ know the data generating distribution's log PDF?

## Estimating the score function

As it turns out, there is [a paper][hyvarinen]
published by Aapo Hyvärinen in 2005 in the Journal of Machine Learning Research
that details how to _estimate_ the score function
in the absence of knowledge of the true data generating distribution.
When I first heard of the idea, I thought it was crazy --
crazy cool that we could even do this!

[hyvarinen]: https://jmlr.csail.mit.edu/papers/volume6/hyvarinen05a/old.pdf

One key equation in the paper is equation #4.
This equation details how we can use an arbitrary function, $\psi(x, \theta)$,
to approximate the score function,
and the loss function needed to train the parameters of the function $\theta$
to approximate the score function.
I've replicated the equation below,
alongside a bullet-point explanation of what each term means:

$$J(\theta) = \frac{1}{T} \sum_{t=1}^{T} \sum_{i=1}^{n} \left[ \partial_i \psi_i(x(t); \theta) + \frac{1}{2} \psi_i(x(t); \theta)^2 \right] + \text{const}$$

Here:

- $J(\theta)$ is the loss function that we wish to minimize w.r.t. the parameters $\theta$.
- $\theta$ are the parameters of the function $\psi$.
- $\psi_i(x(t); \theta)$ is the $i$-th component of the estimated score function, evaluated at the sample $x(t)$.
- $x(t)$, for $t = 1, \ldots, T$, are the i.i.d. samples from the unknown data-generating distribution.
- $n$ is the number of dimensions in $x$.
- $\partial_i$ refers to the partial derivative w.r.t. dimension $i$ in $x$.
- $\text{const}$ is a constant term that does not depend on $\theta$ and can effectively be ignored.
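
To get a feel for the loss, it helps to work through the univariate case by hand.
As an illustrative assumption, suppose we choose $\psi$ to be the score function of a Gaussian
with unknown mean $\mu$ and variance $\sigma^2$, so that $\theta = (\mu, \sigma^2)$:

$$\psi(x; \mu, \sigma^2) = -\frac{x - \mu}{\sigma^2}, \qquad \frac{d\psi}{dx} = -\frac{1}{\sigma^2}$$

Plugging both into the loss gives:

$$J(\mu, \sigma^2) = \frac{1}{T} \sum_{t=1}^{T} \left[ -\frac{1}{\sigma^2} + \frac{(x(t) - \mu)^2}{2\sigma^4} \right] + \text{const}$$

Setting the derivatives w.r.t. $\mu$ and $\sigma^2$ to zero yields:

$$\hat{\mu} = \frac{1}{T} \sum_{t=1}^{T} x(t), \qquad \hat{\sigma}^2 = \frac{1}{T} \sum_{t=1}^{T} (x(t) - \hat{\mu})^2$$

In other words, for this particular choice of $\psi$,
minimizing the score matching loss recovers the sample mean and the sample variance,
the same answer that maximum likelihood gives us,
without ever evaluating the PDF itself.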


We can implement this loss function in JAX!
This is what it would look like for a univariate distribution:

In [None]:
from jax import grad, vmap
from functools import partial 

def score_matching_loss(params, score_func, batch):
    score_func = partial(score_func, params)
    dscore_func = grad(score_func)

    term1 = vmap(dscore_func)(batch)
    term2 = (0.5 * vmap(score_func)(batch) ** 2)

    inner_term = term1 + term2
    return np.mean(inner_term).squeeze()
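
To check that this loss does what we hope, here is a sketch that fits a score model to Gaussian draws.
Everything apart from the loss itself is an assumption made for illustration:
`linear_score` is a hypothetical straight-line score model,
the data are drawn from a Gaussian with mean 1 and standard deviation 2,
and the optimizer is plain gradient descent with a hand-picked learning rate.
The true score of that Gaussian is $-\frac{x - 1}{4} = -0.25x + 0.25$,
so we hope to recover a slope near $-0.25$ and an intercept near $0.25$.

```python
from functools import partial

import jax.numpy as np
from jax import grad, jit, random, vmap


def score_matching_loss(params, score_func, batch):
    """Univariate score matching loss, repeated here so the sketch is self-contained."""
    score_func = partial(score_func, params)
    dscore_func = grad(score_func)
    term1 = vmap(dscore_func)(batch)
    term2 = 0.5 * vmap(score_func)(batch) ** 2
    return np.mean(term1 + term2)


def linear_score(params, x):
    """Hypothetical score model: slope params[0], intercept params[1]."""
    return params[0] * x + params[1]


# Pretend we do not know that the data come from a Gaussian(mean=1, std=2).
key = random.PRNGKey(0)
data = 1.0 + 2.0 * random.normal(key, shape=(5000,))


def loss_fn(params):
    return score_matching_loss(params, linear_score, data)


dloss = jit(grad(loss_fn))

# Plain gradient descent; the learning rate and step count are hand-picked.
params = np.array([0.0, 0.0])
learning_rate = 0.1
for _ in range(500):
    params = params - learning_rate * dloss(params)

print(params)
```

Note that the loss never touches the true PDF: it only needs the samples and the derivative of the score model.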


And more generally, for multivariate data:

In [None]:
from jax import jacfwd 
from typing import Callable 

def score_matching_loss(params, score_func: Callable, batch: np.ndarray) -> float:
    """Score matching loss function.

    This is taken from (Hyvärinen, 2005) (JMLR)
    and https://yang-song.github.io/blog/2019/ssm/.

    :param params: The parameters to the score function.
    :param score_func: Score function with signature `func(params, x)`,
        where `x` is a single sample of shape `(i,)`.
        It should return an array of shape `(i,)`,
        one score component per dimension.
    :param batch: A batch of data. Should be of shape `(batch, i)`.
    :returns: Score matching loss, a scalar.
    """
    score_func = partial(score_func, params)
    dscore_func = jacfwd(score_func)

    # Jacobian of score function (i.e. dlogp estimator function).
    # This is also the Hessian (2nd derivative) of the logp.
    # The Jacobian shape is: `(i, i)`,
    # where `i` is the number of dimensions of the input data,
    # or the number of random variables.
    # Here, we want the diagonals instead, which is of shape (i,)
    term1 = vmap(dscore_func)(batch)
    term1 = vmap(np.diagonal)(term1)

    # Half the squared score, evaluated elementwise for each sample.
    term2 = 0.5 * vmap(score_func)(batch) ** 2
    term2 = np.reshape(term2, term1.shape)

    # Summation over the inner term, by commutative property of addition,
    # automagically gives us the trace of the Jacobian of the score function.
    # Yang Song's blog post refers to the trace
    # (final equation in the section
    # "Learning unnormalized models with score matching"),
    # while Hyvärinen's JMLR paper uses an explicit summation in Equation 4.
    inner_term = term1 + term2
    summed_by_dims = vmap(np.sum)(inner_term)
    return np.mean(summed_by_dims)
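
As a sanity check on the multivariate version, here is a sketch that evaluates the loss
at the true score function and at a deliberately wrong one.
The loss below is a condensed restatement of the function above so that the block runs on its own,
and `gaussian_score`, the true mean, and the wrong mean are all hypothetical choices for illustration.

```python
from functools import partial

import jax.numpy as np
from jax import jacfwd, random, vmap


def score_matching_loss(params, score_func, batch):
    """Condensed restatement of the multivariate score matching loss."""
    score_func = partial(score_func, params)
    # Diagonal of the Jacobian of the score for each sample: shape (batch, i).
    term1 = vmap(np.diagonal)(vmap(jacfwd(score_func))(batch))
    # Half the squared score for each sample: shape (batch, i).
    term2 = 0.5 * vmap(score_func)(batch) ** 2
    # Sum over dimensions, then average over samples.
    return np.mean(np.sum(term1 + term2, axis=1))


def gaussian_score(params, x):
    """Score of a unit-variance isotropic Gaussian; `params` is its mean vector."""
    return -(x - params)


# 2000 draws from a 2D Gaussian with identity covariance.
key = random.PRNGKey(0)
true_mean = np.array([1.0, -2.0])
data = true_mean + random.normal(key, shape=(2000, 2))

loss_at_truth = score_matching_loss(true_mean, gaussian_score, data)
loss_at_wrong = score_matching_loss(np.array([0.0, 0.0]), gaussian_score, data)

print(loss_at_truth, loss_at_wrong)
```

The loss should be lower at the true mean than at the wrong one,
which is what lets us train the parameters of a score model by minimizing it.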
