# Score Functions

The first thing I need to wrap my head around
is what the score function is
and how we estimate it from data.
I will use the venerable Gaussian to anchor my understanding.

## Definition

The score function is defined as
the gradient of the log density of a probability distribution
w.r.t. the support, that is, w.r.t. the data point $x$ rather than the parameters.

In [None]:
from jax import numpy as np, grad
from jax.scipy.stats import norm 

In [None]:
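x = np.linspace(-3, 3, 1000)
y = norm.logpdf(x, loc=0, scale=1)  # log density of the standard Gaussian on the grid

# grad differentiates w.r.t. the first argument (x); loc and scale keep their defaults (0 and 1).
model_score = grad(norm.logpdf)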

We know that at the top of the Gaussian,
the gradient should be zero.

In [None]:
model_score(0.0)

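At the tails, the gradient should have a larger magnitude
than at the mean, and its sign should point back toward the mean:
positive to the left of the mean and negative to the right.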

In [None]:
model_score(-3.0)

In [None]:
model_score(3.0)
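
For a Gaussian, the score also has a closed form, $-(x - \mu) / \sigma^2$, which gives a quick way to check the autodiff result. The sketch below compares the two on a small grid; `analytic_score` and `autodiff_score` are names introduced here for illustration and are not part of the original notebook.

```python
import jax.numpy as jnp
from jax import grad, vmap
from jax.scipy.stats import norm


def analytic_score(x, loc=0.0, scale=1.0):
    """Closed-form score of a Gaussian: d/dx log N(x | loc, scale^2)."""
    return -(x - loc) / scale**2


# grad differentiates w.r.t. x (the first argument); vmap maps it over a batch of points.
autodiff_score = vmap(grad(norm.logpdf))

xs = jnp.linspace(-3.0, 3.0, 7)
print(autodiff_score(xs))  # score from automatic differentiation
print(analytic_score(xs))  # score from the closed-form expression
```

The two printed arrays should agree up to floating-point error, with a score of zero at the mean and opposite signs in the two tails.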

## Estimating the score function

What happens if we have data
but don't know the parameters of the true data-generating density?
In this case, we need to estimate the score function,
which means estimating the parameters of the model.
To do this, I will lean on work by Aapo Hyvärinen from 2005 in JMLR.
In this work, Hyvärinen proposes 
to estimate the parameters of the data-generating density

> by minimizing the expected squared distance between the model score function
> $\psi(.;\theta)$ and the data score function $\psi_x(.)$.

This squared distance is defined as the function $J(\theta)$,
where $\theta$ are the parameters of the data-generating model:

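> $J(\theta) = \frac{1}{2} \int_{\xi \in \mathbb{R}^n} p_x(\xi) \, \lVert \psi(\xi; \theta) - \psi_x(\xi) \rVert^2 \, d\xi$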


TBD 

In [None]:
from jax import random 

key = random.PRNGKey(44)

true_mu = 3.0
true_sigma = 1.0
data = random.normal(key, shape=(100,)) * true_sigma + true_mu
data

In [None]:
from functools import partial 
from jax import vmap 

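# Score of the true data-generating density, evaluated at each observation.
# This is only available here because the data are simulated with known parameters.
score_data = grad(partial(norm.logpdf, loc=true_mu, scale=true_sigma))

true_score_data = vmap(score_data)(data)
true_score_data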

In [None]:
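# Random initial guesses; sigma is parameterized on the log scale so it stays positive.
k1, k2, _ = random.split(key, 3)  # the third subkey is unused
mu = random.normal(k1)
log_sigma = random.normal(k2)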

params_init = (mu, log_sigma)

In [None]:
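def loss(params, obs_data):
    """Sample score-matching objective for a 1-D Gaussian model."""
    mu_est, log_sigma_est = params
    sigma_est = np.exp(log_sigma_est)  # optimize on the log scale so sigma stays positive
    # Model score: d/dx log p(x; mu, sigma).
    model_score_fn = grad(partial(norm.logpdf, loc=mu_est, scale=sigma_est))
    # Derivative of the model score w.r.t. x (second derivative of the log density).
    model_score_grad_fn = grad(model_score_fn)

    term1 = vmap(model_score_grad_fn)(obs_data)
    term2 = 0.5 * vmap(model_score_fn)(obs_data) ** 2

    # Average over observations to approximate the expectation under the data density.
    return np.mean(term1 + term2)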
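
The data score $\psi_x$ is unknown in practice, so $J(\theta)$ cannot be evaluated directly. Hyvärinen (2005) shows that, under regularity conditions, $J(\theta)$ equals, up to a constant that does not depend on $\theta$, an expression involving only the model score. The `loss` function above appears to implement the sample version of that expression for 1-D data. The notation below ($T$ observations $x_t$, and $\tilde{J}$ for the sample objective) is my reading of the paper, written out for the Gaussian case:

```latex
\tilde{J}(\theta)
  = \frac{1}{T} \sum_{t=1}^{T}
    \left[ \partial_x \psi(x_t; \theta) + \frac{1}{2} \psi(x_t; \theta)^2 \right]
    + \text{const.}

% For a 1-D Gaussian model with \theta = (\mu, \sigma):
\psi(x; \mu, \sigma) = -\frac{x - \mu}{\sigma^2},
\qquad
\partial_x \psi(x; \mu, \sigma) = -\frac{1}{\sigma^2}
```

Here `term1` corresponds to $\partial_x \psi$ and `term2` to $\frac{1}{2} \psi^2$; the constant is dropped because it does not affect the minimizer.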


In [None]:
loss(params_init, data)

In [None]:
from jax import value_and_grad
dloss = value_and_grad(loss)

In [None]:
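# Full-batch gradient descent (not stochastic: every step uses all of `data`).
learning_rate = 0.1

mu, log_sigma = params_init
losses = []
for i in range(1200):
    loss_val, (dmu, dlog_sigma) = dloss((mu, log_sigma), data)
    mu -= dmu * learning_rate
    log_sigma -= dlog_sigma * learning_rate
    losses.append(loss_val)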

In [None]:
import matplotlib.pyplot as plt 

plt.plot(losses)
plt.xlabel("Iteration")
plt.ylabel("Score matching loss")

In [None]:
params_init

In [None]:
mu, np.exp(log_sigma)
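
One way to sanity-check the fitted values: for a 1-D Gaussian model, the score-matching objective can be written in closed form, and setting its derivatives to zero gives the sample mean and the biased (divide-by-$n$) sample standard deviation as the minimizer. The sketch below checks this numerically on freshly simulated data. The helper name `gaussian_score_matching_loss`, the seed, the sample size, and the simulated parameters are all assumptions made for illustration, not part of the original notebook.

```python
import jax.numpy as jnp
from jax import grad, random


def gaussian_score_matching_loss(params, obs_data):
    """Closed-form score-matching loss for a 1-D Gaussian (hypothetical helper)."""
    mu, log_sigma = params
    sigma2 = jnp.exp(2.0 * log_sigma)
    score = -(obs_data - mu) / sigma2  # d/dx log p(x)
    d_score = -1.0 / sigma2            # d/dx of the score
    return jnp.mean(d_score + 0.5 * score**2)


key = random.PRNGKey(0)  # arbitrary seed, chosen for this example only
data = random.normal(key, shape=(500,)) * 1.5 + 3.0

mu_hat = jnp.mean(data)
log_sigma_hat = jnp.log(jnp.std(data))  # jnp.std defaults to the biased (ddof=0) estimate

# At the minimizer, both partial derivatives of the loss should be close to zero.
dmu, dlog_sigma = grad(gaussian_score_matching_loss)((mu_hat, log_sigma_hat), data)
print(dmu, dlog_sigma)
```

If gradient descent has converged, the fitted `mu` and `np.exp(log_sigma)` should be close to `np.mean(data)` and `np.std(data)` computed on the notebook's own `data`.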