# Score Functions

The first thing I need to wrap my head around
is what the score function is
and how we estimate it from data.
I will use the venerable Gaussian to anchor my understanding.

## Definition

The score function of a probability distribution
is the gradient of its log density
with respect to the data point $x$ (not the parameters):
$\psi(x) = \nabla_x \log p(x)$.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import jax.numpy as np
from jax import grad, random, vmap
from jax.scipy.stats import norm
from jax.tree_util import Partial as partial
import matplotlib.pyplot as plt
from jaxopt import GradientDescent


In [None]:
x = np.linspace(-3, 3, 1000)
y = norm.logpdf(x, loc=0, scale=1)

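# grad differentiates w.r.t. the first argument (x),
# so this is the score of a standard Gaussian (loc=0, scale=1).
model_score = grad(norm.logpdf)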


We know that at the top of the Gaussian,
the gradient should be zero.

In [None]:
model_score(0.0)


At the tails, the gradient should be of higher magnitude
than at the mean.

In [None]:
model_score(-3.0)


In [None]:
model_score(3.0)
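
As a sanity check on the values above: the score of a Gaussian has a closed form, $\psi(x) = -(x - \mu) / \sigma^2$, so for the standard Gaussian it is simply $-x$. That gives $+3$ at $x = -3$ and $-3$ at $x = 3$, pointing back toward the mean. The sketch below compares that formula against autodiff. The names `analytic_score`, `autodiff_score` and `xs` are made up for illustration.

```python
import jax.numpy as jnp
from jax import grad, vmap
from jax.scipy.stats import norm


def analytic_score(x, mu=0.0, sigma=1.0):
    """Closed-form score of a Gaussian: d/dx log N(x | mu, sigma)."""
    return -(x - mu) / sigma**2


# grad differentiates w.r.t. the first argument, x (loc=0, scale=1 by default).
autodiff_score = grad(norm.logpdf)

xs = jnp.array([-3.0, 0.0, 3.0])
auto = vmap(autodiff_score)(xs)  # score via automatic differentiation
exact = analytic_score(xs)       # score via the closed-form expression

print(auto, exact)
```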


## Estimating the score function

What happens if we have data
but don't know the parameters of the true data-generating density?
In this case, we need to estimate the score function,
which means estimating the parameters of the model.
To do this, I will lean on work by [Aapo Hyvärinen from 2005 in JMLR][jmlr2005].
In that work, Hyvärinen proposes 
to estimate the parameters of the data-generating density:

> by minimizing the expected squared distance between the model score function
> $\psi(.;\theta)$ and the data score function $\psi_x(.)$.

[jmlr2005]: https://www.jmlr.org/papers/volume6/hyvarinen05a/hyvarinen05a.pdf

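This squared distance is defined as the function $J(\theta)$,
where $\theta$ are the parameters of the model.
We can't evaluate $J(\theta)$ directly because the data score function is unknown,
but Hyvärinen shows that it can be rewritten in terms of the model score function alone.
For a finite sample, he provides an explicit formula that we can try to implement
in Python/JAX NumPy.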

> $\tilde{J}(\theta) = \frac{1}{T} \sum_{t=1}^{T} \sum_{i=1}^{n} [\partial_i \psi_i(x(t); \theta) + \frac{1}{2} \psi_i (x(t); \theta)^2 ] + \text{const.}$

That's one heck of a complicated formula.


As with all formulas, we need the definitions.

- $i$ indexes the dimensions of a multidimensional probability distribution,
  and $n$ is the number of dimensions.
- $\psi_i$ is the $i$-th component of the _model_ score function,
  i.e. the partial derivative of the model's log density w.r.t. $x_i$.
- $\partial_i \psi_i$ is the partial derivative of $\psi_i$ w.r.t. $x_i$.
  Yes, you heard right, we need the derivative of a derivative,
  i.e. the 2nd derivative of the log density!
- $\theta$ are the parameters of the score function.
  If the score function comes from the logpdf of a Gaussian,
  then $\mu$ and $\sigma$ are the parameters;
  if the score function is a neural network approximation,
  then $\theta$ refers to the parameters of the neural network.
- $x(t)$ are the observed samples of data, with $T$ samples in total.
  As the outer $\frac{1}{T} \sum_{t=1}^{T}$ suggests,
  we will need to take a mean over observed samples.
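
To make the formula concrete, here is a minimal sketch for a 1-D Gaussian ($n = 1$), dropping the constant term. In that case $\psi(x; \theta) = -(x - \mu)/\sigma^2$ and $\partial \psi = -1/\sigma^2$, so the objective has a closed form that can be checked against autodiff. The function names, the seed, and the sample size are my own choices for illustration.

```python
import jax.numpy as jnp
from jax import grad, random, vmap
from jax.scipy.stats import norm


def gaussian_score(x, mu, sigma):
    """psi(x; theta): derivative of the Gaussian log density w.r.t. x."""
    return grad(norm.logpdf)(x, mu, sigma)


def objective_autodiff(mu, sigma, samples):
    """Sample objective: mean over t of [d psi/dx + 0.5 * psi^2]."""
    axes = (0, None, None)  # map over samples only, not over mu and sigma
    dpsi = vmap(grad(gaussian_score), in_axes=axes)(samples, mu, sigma)
    psi = vmap(gaussian_score, in_axes=axes)(samples, mu, sigma)
    return jnp.mean(dpsi + 0.5 * psi**2)


def objective_closed_form(mu, sigma, samples):
    """Same objective with psi and d psi/dx written out by hand."""
    return jnp.mean(-1.0 / sigma**2 + 0.5 * ((samples - mu) / sigma**2) ** 2)


samples = 3.0 + random.normal(random.PRNGKey(0), shape=(500,))
j_auto = objective_autodiff(2.5, 1.5, samples)
j_closed = objective_closed_form(2.5, 1.5, samples)
print(j_auto, j_closed)
```

Setting the derivatives of the closed form to zero gives the sample mean for $\mu$ and the (biased) sample variance for $\sigma^2$, which is what the optimizer should find.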

Sample some data from a Gaussian.

In [None]:
key = random.PRNGKey(44)

true_mu = 3.0
true_sigma = 1.0
data = random.normal(key, shape=(1000,)) * true_sigma + true_mu
data


In [None]:
data.mean(), data.std()

Evaluate the score of the data _under the true model_.
The model's `apply_fun` takes the parameters first and the point `x` second,
and the parameters are `(mu, log_sigma)`: the scale is passed on the log scale.

In [None]:
from score_models.models import gaussian_model

init_fun, apply_fun = gaussian_model()
true_params = (true_mu, np.log(true_sigma))

(
    apply_fun(true_params, true_mu),
    apply_fun(true_params, true_mu + true_sigma),
    apply_fun(true_params, true_mu - true_sigma * 3),
)
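
The source of `score_models.models.gaussian_model` isn't reproduced here. As a rough mental model only, the sketch below shows what a model with this calling convention _could_ look like, inferred purely from how it is used in this notebook. `gaussian_model_sketch`, the empty-tuple placeholder returned by `init_fun`, and the random initialisation are all assumptions, not the package's actual implementation.

```python
import jax.numpy as jnp
from jax import grad, random
from jax.scipy.stats import norm


def gaussian_model_sketch():
    """Hypothetical stand-in for a Gaussian score model (an assumption)."""

    def init_fun(rng):
        # Draw a random starting point for (mu, log_sigma).
        k_mu, k_sigma = random.split(rng)
        params = (random.normal(k_mu), random.normal(k_sigma))
        # First return value is a placeholder so `_, params = init_fun(key)` works.
        return (), params

    def apply_fun(params, x):
        # Score at a single scalar x: d/dx log N(x | mu, exp(log_sigma)).
        mu, log_sigma = params
        return grad(norm.logpdf)(x, mu, jnp.exp(log_sigma))

    return init_fun, apply_fun


sketch_init, sketch_apply = gaussian_model_sketch()
sketch_true_params = (3.0, jnp.log(1.0))

at_mean = sketch_apply(sketch_true_params, 3.0)       # score at mu
one_sigma_up = sketch_apply(sketch_true_params, 4.0)  # score at mu + sigma
three_sigma_down = sketch_apply(sketch_true_params, 0.0)  # score at mu - 3 * sigma
print(at_mean, one_sigma_up, three_sigma_down)
```

Working with `log_sigma` rather than `sigma` keeps the scale positive during unconstrained gradient descent.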

Calculate the true data score per draw.

In [None]:
# Don't forget to pass in log of true_sigma!!!
true_data_score = vmap(partial(apply_fun, true_params))(data)
true_data_score


Use gradient descent to find the parameters of the Gaussian
that minimize the score matching loss.
To do this, we will use the GradientDescent solver from `jaxopt`,
which will give us a really concise syntax.

In [None]:
k1, k2, k3 = random.split(key, 3)
_, params_init = init_fun(k1)
params_init

In [None]:
score_func = partial(apply_fun, params_init)
dscore_func = grad(score_func)

In [None]:
from score_models.losses import l2_norm
from jax import grad
from typing import Callable

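def score_matching_loss(params, score_func, batch):
    """Finite-sample score matching objective, without the constant term.

    `score_func(params, x)` must return the model score at a single scalar `x`.
    """
    # Bind the parameters so that the score is a function of x only.
    score_func = partial(score_func, params)
    # Derivative of the score w.r.t. x, i.e. the 2nd derivative of the log density.
    dscore_func = grad(score_func)

    term1 = vmap(dscore_func)(batch)
    term2 = 0.5 * vmap(score_func)(batch) ** 2

    inner_term = term1 + term2
    return np.mean(inner_term).squeeze()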



In [None]:

score_matching_loss(params_init, apply_fun, data)
myloss = partial(score_matching_loss, score_func=apply_fun)

In [None]:
solver = GradientDescent(fun=myloss, maxiter=20000, stepsize=5e-2)
result = solver.run(params_init, batch=data)

Do the resulting params match up?

In [None]:
mu, log_sigma = result.params
mu, np.exp(log_sigma)


In [None]:
np.mean(data), np.std(data)


Looks like they do!
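
As a cross-check that doesn't depend on `jaxopt` or `score_models`, here is a minimal sketch of the same fit with a hand-rolled gradient descent loop on the closed-form Gaussian objective. The step size of 5e-2 mirrors the solver call above. The names `closed_form_loss` and `run_gd`, the zero initialisation, the seed, and the 5000 iterations are my own choices for illustration.

```python
import jax.numpy as jnp
from jax import grad, jit, random
from jax.lax import fori_loop


def closed_form_loss(params, samples):
    """Score matching objective for a 1-D Gaussian; params = [mu, log_sigma]."""
    mu, log_sigma = params[0], params[1]
    var = jnp.exp(2.0 * log_sigma)
    return jnp.mean(-1.0 / var + 0.5 * ((samples - mu) / var) ** 2)


@jit
def run_gd(params, samples):
    """Plain gradient descent with a fixed step size of 5e-2."""
    loss_grad = grad(closed_form_loss)

    def step(_, p):
        return p - 5e-2 * loss_grad(p, samples)

    return fori_loop(0, 5000, step, params)


demo_data = 3.0 + random.normal(random.PRNGKey(0), shape=(1000,))
gd_params = run_gd(jnp.zeros(2), demo_data)
gd_mu, gd_sigma = gd_params[0], jnp.exp(gd_params[1])

print(gd_mu, gd_sigma)
print(demo_data.mean(), demo_data.std())
```

The estimates should land on the sample mean and standard deviation rather than exactly on the true values, because the objective is minimized with respect to the observed sample.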

Now let's compare the estimated parameters and scores against the true ones:

In [None]:
def y_eq_x(x, y, ax):
    """Draw the y = x reference line spanning the range of both inputs."""
    minval = min(x.min(), y.min())
    maxval = max(x.max(), y.max())

    ax.plot([minval, maxval], [minval, maxval])


In [None]:
est_mu, est_log_sigma = result.params

est_mu - true_mu, np.exp(est_log_sigma) - true_sigma


In [None]:
model_scores = vmap(partial(apply_fun, result.params))(data)
plt.scatter(true_data_score, model_scores)
y_eq_x(true_data_score, model_scores, plt.gca())
plt.xlabel("True Data Score")
plt.ylabel("Model Score")
plt.title("Gaussian Model Performance")
plt.show()


What if we try to approximate the score function with a neural network instead?

## Approximate Score Function with NN

In [None]:
from typing import Tuple
from jax.example_libraries import stax 



In [None]:
from score_models.models import nn_model
init_fun, apply_fun = nn_model()

def score_fun(params, batch):
    out = apply_fun(params, batch).squeeze()
    return out 


_, params_init = init_fun(rng=random.PRNGKey(44), input_shape=(1,))
myloss = partial(score_matching_loss, score_func=score_fun)

solver = GradientDescent(fun=myloss, maxiter=1200)
result = solver.run(params_init, batch=data)
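
`nn_model` comes from `score_models`, and its architecture isn't shown here. As a self-contained stand-in, the sketch below builds a small MLP with `jax.example_libraries.stax` and evaluates the score matching loss and its parameter gradients on it. The layer sizes, the `Tanh` activation, and every name ending in `_sketch` are assumptions for illustration, not the package's actual model.

```python
import jax.numpy as jnp
from jax import grad, random, vmap
from jax.example_libraries import stax


def nn_model_sketch(hidden=64):
    """Hypothetical MLP mapping a 1-D input to a 1-D score (an assumption)."""
    return stax.serial(
        stax.Dense(hidden), stax.Tanh,
        stax.Dense(hidden), stax.Tanh,
        stax.Dense(1),
    )


init_sketch, apply_sketch = nn_model_sketch()
_, params_sketch = init_sketch(random.PRNGKey(0), (1,))


def score_sketch(params, x):
    # x is a single scalar; reshape to (1,) for the network, return a scalar.
    return apply_sketch(params, jnp.reshape(x, (1,))).squeeze()


def loss_sketch(params, batch):
    # Same objective as before: mean of [d psi/dx + 0.5 * psi^2].
    def psi(x):
        return score_sketch(params, x)

    return jnp.mean(vmap(grad(psi))(batch) + 0.5 * vmap(psi)(batch) ** 2)


batch_sketch = 3.0 + random.normal(random.PRNGKey(1), shape=(256,))
loss_value = loss_sketch(params_sketch, batch_sketch)
loss_grads = grad(loss_sketch)(params_sketch, batch_sketch)
print(loss_value)
```

A smooth activation is a deliberate choice here: the loss needs the derivative of the network output w.r.t. its input, and training then differentiates that again w.r.t. the weights.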


In [None]:
model_scores = vmap(partial(apply_fun, result.params))(data).squeeze()
plt.scatter(true_data_score, model_scores)
y_eq_x(true_data_score, model_scores, plt.gca())
plt.title("Trained Neural Network")
plt.xlabel("True Data Score")
plt.ylabel("Model Data Score")
plt.show()

In [None]:
model_scores = vmap(partial(apply_fun, params_init))(data).squeeze()
plt.scatter(true_data_score, model_scores)
y_eq_x(true_data_score, model_scores, plt.gca())
plt.title("Initialized Neural Network")
plt.xlabel("True Data Score")
plt.ylabel("Model Data Score")
plt.show()
