# Cholesky and Gram-Schmidt

In [2]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax.scipy.linalg import solve_triangular

In [3]:
import seaborn as sns
# Colour-blind-safe palette for every seaborn/matplotlib figure below.
sns.set_palette("colorblind")
# Drop the top and right frame lines on all axes for a lighter look.
plt.rcParams.update({
    "axes.spines.top": False,
    "axes.spines.right": False,
})

In [4]:
# IPython magic: use the "retina" inline figure format (higher-resolution PNGs on HiDPI screens).
%config InlineBackend.figure_format = "retina"

## The setup

### Building innovations — the Gram-Schmidt algorithm
The innovations are initialised according to
$$
\begin{aligned}
    \varepsilon_1 &= y_1,\\
    R_1 &= {\rm Var}(\varepsilon_1) = {\rm Var}(y_1).
\end{aligned}
$$
Then, for $t = 2, 3, \ldots$,
$$
\begin{aligned}
    \varepsilon_t &= y_t - \sum_{j=1}^{t-1}{\rm Cov}(y_t, \varepsilon_j)\,R_j^{-1}\,\varepsilon_j,\\
    R_t &= {\rm Var}(\varepsilon_t).
\end{aligned}
$$

As a consequence, it can be shown that
$$
    y_{1:t} = {\bf L}\,\varepsilon_{1:t},
$$
with
$$
    {\bf L}_{t,j} =
    \begin{cases}
    {\rm Cov}(y_t, \varepsilon_j)R_j^{-1} & j=1,\ldots,t-1,\\
    {\bf I} & j=t,\\
    {\bf 0} & j > t.
    \end{cases}
$$

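### The Cholesky decomposition
The link between the innovations and the Cholesky decomposition comes from the fact that the innovations are mutually uncorrelated, so ${\rm Var}(\varepsilon_{1:t})$ is (block-)diagonal. Combined with $y_{1:t} = {\bf L}\,\varepsilon_{1:t}$, this gives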
$$
    {\rm Var}(y_{1:t}) = {\bf L}\,{\bf R}\,{\bf L}^\intercal
$$
with
$$
    {\bf R} = {\rm Diag}(R_1, \ldots, R_t)
$$
and
$$
    {\bf L}_{t,j} =
    \begin{cases}
    {\rm Cov}(y_t, \varepsilon_j)R_j^{-1} & j=1,\ldots,t-1,\\
    {\bf I} & j=t,\\
    {\bf 0} & j > t.
    \end{cases}
$$

Hence,
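> computing the innovations is equivalent to computing the Cholesky factorisation of ${\rm Var}(y_{1:t})$ in its ${\bf L}\,{\bf R}\,{\bf L}^\intercal$ (LDL$^\intercal$) form, with ${\bf L}$ unit lower-triangular and ${\bf R}$ diagonal.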

In [11]:
def ssm(key, signal_init, n_steps, qt, rt):
    """Simulate a scalar random-walk-plus-noise state-space model.

    The latent signal evolves as ``f_t = f_{t-1} + qt * u_t`` and is observed
    as ``y_t = f_t + rt * e_t``, where ``u_t`` and ``e_t`` are independent
    standard normal draws. Because ``qt`` and ``rt`` multiply standard
    normals, they act as noise *standard deviations*, not variances.

    Args:
        key: JAX PRNG key that drives the whole trajectory.
        signal_init: signal value before the first step (not itself returned).
        n_steps: number of time steps to simulate.
        qt: scale of the state (process) noise.
        rt: scale of the measurement noise.

    Returns:
        Tuple ``(signals, measurements)``, each with leading axis ``n_steps``.
    """
    def _transition(prev_signal, step_key):
        # The split order fixes which sub-key feeds which noise term.
        k_obs, k_state = jax.random.split(step_key)
        signal = prev_signal + jax.random.normal(k_state) * qt
        observation = signal + jax.random.normal(k_obs) * rt
        return signal, (signal, observation)

    step_keys = jax.random.split(key, n_steps)
    _, trajectory = jax.lax.scan(_transition, signal_init, step_keys)
    signals, measurements = trajectory
    return signals, measurements

# Batch over PRNG keys only; out_axes=-1 stacks trials along the LAST axis,
# so batched outputs have shape (n_steps, n_keys).
vssm = jax.vmap(ssm, in_axes=(0,) + (None,) * 4, out_axes=-1)

In [7]:
# Root PRNG key for the notebook, split once into two independent streams:
# one for the Monte Carlo batch and one for a held-out test trajectory.
key = jax.random.PRNGKey(314)
key_sample, key_test = jax.random.split(key, num=2)

In [8]:
# Single reference trajectory. It draws from `key_test` instead of the root
# `key`: `key` has already been split into `key_sample`/`key_test`, and JAX's
# PRNG design requires that a key not be reused once it has been split.
signal_init = 0.0
n_steps = 100
qt = 0.05  # state-noise scale (multiplies a standard normal inside ssm)
rt = 0.1   # measurement-noise scale (multiplies a standard normal inside ssm)
signals, measurements = ssm(key_test, signal_init, n_steps, qt, rt)

## The Cholesky decomposition

Recall:
$$
    y_{1:t} = {\bf L}\,\varepsilon_{1:t}.
$$
Since ${\bf L}$ is unit lower-triangular, it is invertible, and
$$
    \varepsilon_{1:t} = {\bf L}^{-1}\,y_{1:t}.
$$

In [12]:
# One independent PRNG key per simulated trajectory, all derived from key_sample.
n_trials = 500
keys = jax.random.split(key_sample, num=n_trials)

In [88]:
# Simulation settings for the Monte Carlo batch (ordered as in ssm's signature).
signal_init = 0.0
n_steps = 100
qt = 0.05
rt = 0.1

# vssm stacks trials along the last axis, so both outputs are (n_steps, n_trials).
signals, measurements = vssm(keys, signal_init, n_steps, qt, rt)

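#### Building innovations from the Cholesky decomposition

`jnp.linalg.cholesky` returns the standard lower-triangular factor ${\bf C}$ with ${\rm Var}(y_{1:t}) = {\bf C}\,{\bf C}^\intercal$. Writing ${\bf D} = {\rm Diag}({\bf C}_{11}, \ldots, {\bf C}_{tt})$, we have ${\bf C} = {\bf L}\,{\bf D}$ and ${\bf R} = {\bf D}^2$. In other words, ${\bf L} = {\bf C}\,{\bf D}^{-1}$ (each column divided by its diagonal entry), and the innovation variances are the *squared* diagonal of ${\bf C}$.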

In [89]:
# Empirical covariance across trials. jnp.cov treats rows as variables; here
# `measurements` is expected to be (n_steps, n_trials), giving var_y of shape (n_steps, n_steps).
var_y = jnp.cov(measurements)

# jnp.linalg.cholesky returns lower-triangular C with var_y = C C^T.
# Writing C = L diag(d) with d = diag(C) gives var_y = L diag(d**2) L^T, where L
# is UNIT lower-triangular. Hence:
#   - L is C with each column divided by its diagonal entry (simply overwriting
#     the diagonal with 1 would leave the off-diagonal entries wrongly scaled);
#   - the innovation variances R are the SQUARED diagonal of C (diag(C) alone
#     holds standard deviations).
chol = jnp.linalg.cholesky(var_y)
chol_diag = jnp.diag(chol)
L = chol / chol_diag  # broadcasts over columns: L[:, j] = chol[:, j] / chol[j, j]
R = chol_diag ** 2

# Sanity check: L diag(R) L^T must reproduce the covariance (float32 tolerance).
assert jnp.allclose(L @ jnp.diag(R) @ L.T, var_y, atol=1e-4)

In [75]:
# eps = L^{-1} y. L is lower-triangular, so this is a forward substitution
# (O(n^2) per column) rather than the general O(n^3) LU solve of jnp.linalg.solve.
# Each column of `measurements` (one per trial) is solved simultaneously.
innovations = solve_triangular(L, measurements, lower=True)

## Numerical evaluation of the BLUP
Using the innovations, the best linear unbiased predictor (BLUP) of $f_t$ given $y_{1:j}$ can be written as

$$
\begin{aligned}
    f_{t|j}
    &= {\rm Cov}(f_t, y_{1:j})\,{\rm Var}(y_{1:j})^{-1}\,y_{1:j}\\
    &= {\rm Cov}(f_t, {\bf L}\,\varepsilon_{1:j})\,{\rm Var}({\bf L}\,\varepsilon_{1:j})^{-1}\,{\bf L}\,\varepsilon_{1:j}\\
    &= {\rm Cov}(f_t, \varepsilon_{1:j})\,{\rm Var}(\varepsilon_{1:j})^{-1}\,\varepsilon_{1:j}\\
    &= \sum_{k=1}^j\,{\rm Cov}(f_t, \varepsilon_k)\,{\rm Var}(\varepsilon_k)^{-1}\,\varepsilon_k\\
    &= \sum_{k=1}^j\,{\rm Cov}(f_t, \varepsilon_k)\,R_k^{-1}\,\varepsilon_k
\end{aligned}
$$