# Diffusion score matching
Working with *robust* posteriors via diffusion score matching

```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
```

```python
%config InlineBackend.figure_format = "retina"
```

## The setting
### Proposition 3.1 of [1]
Suppose $y\vert\boldsymbol\theta$ is a member of the exponential family such that
$$
 p(y \vert \boldsymbol\theta) = \exp\left(\eta(\boldsymbol\theta)^\intercal r(y) - a(\boldsymbol\theta) + b(y)\right).
$$

Take $\eta(\boldsymbol\theta) = \boldsymbol\theta$ (the natural parametrization) and a Gaussian prior of the form
$$
    \pi(\boldsymbol\theta)\propto
    \exp\left(-\frac{1}{2}(\boldsymbol\theta - \boldsymbol\mu)^\intercal\boldsymbol\Sigma^{-1}(\boldsymbol\theta - \boldsymbol\mu)\right)
$$

with
* $\eta: \boldsymbol\Theta \to \mathbb{R}^p$,
* $r: {\cal X} \to \mathbb{R}^p$,
* $a: \boldsymbol\Theta \to \mathbb{R}$, and
* $b: {\cal X} \to \mathbb{R}$

Then the posterior is a Gaussian truncated to $\boldsymbol\Theta$:
$$
    \pi(\boldsymbol\theta \vert y_{1:T}) \propto
    \exp\left(
        -\frac{1}{2}(\boldsymbol\theta - \boldsymbol\mu_T)^\intercal \boldsymbol\Sigma_T^{-1} (\boldsymbol\theta - \boldsymbol\mu_T)
    \right),
    \qquad \boldsymbol\theta \in \boldsymbol\Theta,
$$
with
* $\boldsymbol\Sigma_T^{-1} = \boldsymbol\Sigma^{-1} + 2\omega T \Lambda_T$
* $\boldsymbol\mu_T = \boldsymbol\Sigma_T (\boldsymbol\Sigma^{-1}\boldsymbol\mu - \omega T \nu_T)$
* $\Lambda_T = \frac{1}{T}\sum_{t=1}^T \Lambda(y_t)$
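* $\nu_T = \frac{2}{T}\sum_{t=1}^T \nu(y_t)$
* $\Lambda(y) = \left( \nabla r^\intercal m m^\intercal \nabla r \right)(y)$
* $\nu(y) = \left( \nabla r^\intercal m m^\intercal \nabla b + \nabla\cdot (m m^\intercal \nabla r) \right)(y)$
* $m: {\cal X} \to \mathbb{R}^{d\times d}$ a weighting function, where $d$ is the dimension of $y$
* $\nabla r(y) \in \mathbb{R}^{d\times p}$ with entries $[\nabla r]_{ij} = \partial r_j / \partial y_i$, $\nabla b(y) \in \mathbb{R}^d$, and $\nabla\cdot$ the divergence taken column by column
* $\omega > 0$ the weight (learning rate) of the diffusion score-matching loss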
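
To see where these expressions come from, assume (as a sketch, not necessarily the exact construction of [1]) a generalized Bayes posterior $\pi(\boldsymbol\theta \vert y_{1:T}) \propto \pi(\boldsymbol\theta)\exp\left(-\omega \sum_{t=1}^T \ell(\boldsymbol\theta; y_t)\right)$ with the diffusion score-matching loss $\ell(\boldsymbol\theta; y) = \lVert m^\intercal \nabla_y \log p(y \vert \boldsymbol\theta) \rVert^2 + 2\nabla\cdot\left(m m^\intercal \nabla_y \log p(y \vert \boldsymbol\theta)\right)$. Since $\nabla_y \log p = \nabla r\,\boldsymbol\theta + \nabla b$ is linear in $\boldsymbol\theta$ and free of $a(\boldsymbol\theta)$, the loss equals $\boldsymbol\theta^\intercal \Lambda(y) \boldsymbol\theta + 2\boldsymbol\theta^\intercal \nu(y)$ up to a constant, and completing the square yields $\boldsymbol\Sigma_T$ and $\boldsymbol\mu_T$. The JAX sketch below checks that algebra numerically; the helper names (`dsm_gaussian_posterior`, `log_post_unnormalized`, `log_gaussian_kernel`) and all numeric inputs are hypothetical.

```python
import jax
import jax.numpy as jnp

def dsm_gaussian_posterior(mu0, Sigma0, Lambda_T, nu_T, omega, T):
    # Closed form: Sigma_T^{-1} = Sigma^{-1} + 2 omega T Lambda_T,
    #              mu_T = Sigma_T (Sigma^{-1} mu - omega T nu_T)
    prec0 = jnp.linalg.inv(Sigma0)
    Sigma_T = jnp.linalg.inv(prec0 + 2.0 * omega * T * Lambda_T)
    mu_T = Sigma_T @ (prec0 @ mu0 - omega * T * nu_T)
    return mu_T, Sigma_T

def log_post_unnormalized(theta, mu0, Sigma0, Lambda_T, nu_T, omega, T):
    # log prior - omega * sum_t loss(theta; y_t), dropping theta-independent terms,
    # using sum_t loss(theta; y_t) = T (theta^T Lambda_T theta + theta^T nu_T) + const
    diff = theta - mu0
    log_prior = -0.5 * diff @ jnp.linalg.solve(Sigma0, diff)
    return log_prior - omega * T * (theta @ Lambda_T @ theta + theta @ nu_T)

def log_gaussian_kernel(theta, mu, Sigma):
    # Unnormalized Gaussian log density
    diff = theta - mu
    return -0.5 * diff @ jnp.linalg.solve(Sigma, diff)

# Hypothetical inputs with p = 2, used only to exercise the algebra
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
A = jax.random.normal(k1, (2, 2))
Lambda_T = A.T @ A + 0.1 * jnp.eye(2)  # symmetric positive definite
nu_T = jax.random.normal(k2, (2,))
mu0, Sigma0 = jnp.zeros(2), 4.0 * jnp.eye(2)
omega, T = 0.1, 10

mu_T, Sigma_T = dsm_gaussian_posterior(mu0, Sigma0, Lambda_T, nu_T, omega, T)

# If the closed form is right, the two log densities differ only by a constant
thetas = jax.random.normal(k3, (5, 2))
gap = jax.vmap(
    lambda th: log_post_unnormalized(th, mu0, Sigma0, Lambda_T, nu_T, omega, T)
    - log_gaussian_kernel(th, mu_T, Sigma_T)
)(thetas)
print(gap.max() - gap.min())  # close to 0, up to float32 round-off
```

A constant gap across random values of $\boldsymbol\theta$ means both expressions describe the same distribution once normalized over $\boldsymbol\Theta$.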

## Example: $\epsilon$-contaminated Gaussian

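Let the uncontaminated observations be Gaussian, $x \sim {\cal N}(\mu, \sigma^2)$, written in natural parameters $\theta_1 = \mu/\sigma^2$ and $\theta_2 = -1/(2\sigma^2)$, so that $\boldsymbol\Theta = \mathbb{R} \times (-\infty, 0)$. Then

$$
a(\theta_1, \theta_2) = -\frac{\theta_1^2}{4\theta_2} - \frac{1}{2}\log(-2\theta_2),
\qquad
b(x) = -\frac{1}{2}\log(2\pi),
$$

$$
    r(x) =
    \begin{bmatrix}
    x \\ x^2
    \end{bmatrix}.
$$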
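
For this example the score is linear in $\boldsymbol\theta$ and the quantities of Proposition 3.1 have simple closed forms: with $d = 1$ and a constant weight $m = 1$ (as in the `m = jnp.eye(1)` cell), $\nabla r(x) = [1, 2x]$ and $\nabla b = 0$, so $\Lambda(x) = \begin{bmatrix} 1 & 2x \\ 2x & 4x^2 \end{bmatrix}$ and $\nu(x) = (0, 2)^\intercal$. The sketch below checks these against JAX autodiff at a single point; the helper names `log_lik`, `grad_r`, `Lambda` and `nu` are hypothetical, and the natural parameters $\boldsymbol\theta = (0.5, -0.25)$, i.e. $\mu = 1$, $\sigma^2 = 2$, are an arbitrary choice.

```python
import jax
import jax.numpy as jnp

def r(x):
    # Sufficient statistic of the Gaussian example: r(x) = [x, x^2]
    return jnp.array([x, x ** 2])

def log_lik(x, theta):
    # Exponential-family log density with a(theta) from the example and b(x) = -log(2 pi)/2
    a = -theta[0] ** 2 / (4 * theta[1]) - 0.5 * jnp.log(-2 * theta[1])
    b = -0.5 * jnp.log(2 * jnp.pi)
    return theta @ r(x) - a + b

def grad_r(x):
    # grad r(x) as a (d, p) = (1, 2) matrix, i.e. [[1, 2x]]
    return jax.jacfwd(r)(x)[None, :]

def Lambda(x):
    # Lambda(x) = grad_r^T m m^T grad_r with m = 1
    G = grad_r(x)
    return G.T @ G

def nu(x):
    # nu(x) = grad_r^T m m^T grad_b + div(m m^T grad_r); grad_b = 0 here, and for
    # d = 1 with m = 1 the divergence is the x-derivative of the row grad_r(x)
    return jax.jacfwd(lambda z: grad_r(z)[0])(x)

x = 1.5
theta = jnp.array([0.5, -0.25])      # natural parameters of N(1, 2)
score = jax.grad(log_lik)(x, theta)  # d/dx log p(x | theta); a(theta) drops out
print(score, grad_r(x)[0] @ theta)   # both equal -(x - 1) / 2 = -0.25
print(Lambda(x))                     # [[1, 3], [3, 9]]
print(nu(x))                         # [0, 2]
```

Because the log-partition $a(\boldsymbol\theta)$ never enters the score, none of these quantities requires the normalizing constant.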

```python
# Draw n_samples points from N(mean, std^2) and, independently with
# probability eps, replace each one by the outlier value `contamination`
key = jax.random.PRNGKey(314)
mean, std = 0, 1

n_samples = 100
eps = 0.05
contamination = 10
key_data, key_epsilon = jax.random.split(key)
samples = mean + std * jax.random.normal(key_data, shape=(n_samples,))
samples_contaminated = jax.random.bernoulli(key_epsilon, p=eps, shape=(n_samples,))

samples = jnp.where(samples_contaminated, contamination, samples)
```
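
In distributional terms, the data cell draws each observation independently from a mixture (a reading of the code, with the outliers placed at the fixed value `contamination = 10`):

$$
y_t \overset{\text{iid}}{\sim} (1-\epsilon)\,{\cal N}(0, 1) + \epsilon\,\delta_{10},
\qquad t = 1, \dots, 100, \quad \epsilon = 0.05,
$$

where $\delta_{10}$ is a point mass at 10.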

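```python
def r(x):
    # Sufficient statistic r(x) = [x, x^2], returned with shape (1, 2) so that
    # jax.jacfwd(r)(x) at a scalar x is the (d, p) = (1, 2) matrix grad r(x)
    return jnp.array([[x, x ** 2]])
```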

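```python
# Constant weighting function m(x) = I_1 (the data are one-dimensional)
m = jnp.eye(1)
# grad r(x_t) for every sample, shape (n_samples, d, p) = (100, 1, 2)
r_samples = jax.vmap(jax.jacfwd(r))(samples)

# Lambda(x_t) = grad_r^T m m^T grad_r for every sample, shape (n_samples, p, p)
Lambda_samples = jnp.einsum("sji,jk,lk,sln->sin", r_samples, m, m, r_samples)
LambdaT = Lambda_samples.mean(axis=0)
```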

```python
# Check one sample against the explicit product (grad_r^T m)(grad_r^T m)^T
ix = 1
(r_samples[ix].T @ m) @ (r_samples[ix].T @ m).T
```

```
Array([[1.       , 1.1393881],
       [1.1393881, 1.2982053]], dtype=float32)
```

```python
# (n_samples, d, p) with the definition of r above
r_samples.shape
```

```
(100, 1, 2)
```
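
Putting the pieces together, a minimal sketch of the full posterior for the contaminated sample might look as follows. It assumes a vague prior $\boldsymbol\mu = \boldsymbol 0$, $\boldsymbol\Sigma = 10^6 I$ and $\omega = 1$ (hypothetical choices, not taken from [1]), computes $\nu(x)$ by differentiating $m m^\intercal \nabla r$ with `jax.jacfwd`, and maps the posterior mean back to $(\mu, \sigma^2)$ via $\sigma^2 = -1/(2\theta_2)$ and $\mu = \theta_1\sigma^2$. The helpers `grad_r`, `Lambda`, `nu` and `dsm_posterior` are hypothetical names.

```python
import jax
import jax.numpy as jnp

# Re-create the contaminated sample (same key, eps and outlier value as the
# data cell) so that this block runs on its own
key_data, key_epsilon = jax.random.split(jax.random.PRNGKey(314))
samples = jax.random.normal(key_data, shape=(100,))
outlier = jax.random.bernoulli(key_epsilon, p=0.05, shape=(100,))
samples = jnp.where(outlier, 10.0, samples)

def grad_r(x):
    # grad r(x) for r(x) = [x, x^2], as a (d, p) = (1, 2) matrix
    return jax.jacfwd(lambda z: jnp.array([z, z ** 2]))(x)[None, :]

def m(x):
    # Constant weighting function, matching `m = jnp.eye(1)` above
    return jnp.eye(1)

def Lambda(x):
    # Lambda(x) = grad_r^T m m^T grad_r
    G, M = grad_r(x), m(x)
    return G.T @ M @ M.T @ G

def nu(x):
    # nu(x) = grad_r^T m m^T grad_b + div(m m^T grad_r); grad_b = 0 for the Gaussian
    # and, with d = 1, the divergence is the x-derivative of the single row
    return jax.jacfwd(lambda z: (m(z) @ m(z).T @ grad_r(z))[0])(x)

def dsm_posterior(xs, mu0, Sigma0, omega):
    T = xs.shape[0]
    Lambda_T = jax.vmap(Lambda)(xs).mean(axis=0)
    nu_T = 2.0 * jax.vmap(nu)(xs).mean(axis=0)
    prec_T = jnp.linalg.inv(Sigma0) + 2.0 * omega * T * Lambda_T
    Sigma_T = jnp.linalg.inv(prec_T)
    mu_T = Sigma_T @ (jnp.linalg.solve(Sigma0, mu0) - omega * T * nu_T)
    return mu_T, Sigma_T

# Hypothetical prior and weight: vague N(0, 10^6 I) prior on (theta_1, theta_2), omega = 1
mu_T, Sigma_T = dsm_posterior(samples, jnp.zeros(2), 1e6 * jnp.eye(2), omega=1.0)

# Map the natural-parameter posterior mean back to (mu, sigma^2)
sigma2_hat = -1.0 / (2.0 * mu_T[1])
mu_hat = mu_T[0] * sigma2_hat
print(mu_hat, sigma2_hat)
```

With this constant $m$ and a vague prior, the posterior mean tends to $-\frac{1}{2}\Lambda_T^{-1}\nu_T$, which maps back exactly to the sample mean and the (biased) sample variance, so any contaminated points enter the estimates just as they enter the sample moments. Because `nu` differentiates the whole product $m m^\intercal \nabla r$, a data-dependent weight can be tried by editing `m` alone.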

# References
* [1] Robust and Scalable Bayesian Online Changepoint Detection: https://arxiv.org/abs/2302.04759