# Cholesky and Gram-schmidt

In [2]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

In [3]:
import seaborn as sns
sns.set_palette("colorblind")
plt.rcParams["axes.spines.right"] = False
plt.rcParams["axes.spines.top"] = False

In [4]:
%config InlineBackend.figure_format = "retina"

## The setup

### Building innovations — the Gramm-Schmidt algorithm
The innovations initialised according to
$$
\begin{aligned}
    \varepsilon_1 &= y_1,\\
    R_1 &= {\rm Var}(\varepsilon_1) = {\rm Var}(y_1).
\end{aligned}
$$
Then, 
$$
\begin{aligned}
    \varepsilon_t &= y_t - \sum_{j=1}^{t-1}{\rm Cov}(y_t, \varepsilon_j)\,R_j^{-1}\,\varepsilon_j,\\
    R_t &= {\rm Var}(\varepsilon_t).
\end{aligned}
$$

### The cholesky decomposition
The relationship between innovations and the Cholesky decomposition is given by
$$
    {\rm Var}(y_{1:t}) = {\bf L}\,{\bf R}\,{\bf L}^\intercal
$$
with
$$
    {\bf R} = {\rm Diag}(R_1, \ldots, R_t)
$$
and
$$
    {\bf L}_{t,j} =
    \begin{cases}
    {\rm Cov}(y_t, \varepsilon_j)R_j^{-1} & j=1,\ldots,t-1,\\
    {\bf I} & j=t,\\
    {\bf 0} & j > t.
    \end{cases}
$$

Hence,
> computation of the innovations is equivalent to computing the cholesky factorisation of ${\rm Var}(y_{1:t})$.

In [11]:
def ssm(key, signal_init, n_steps, qt, rt):
    def ssm_step(signal_prev, key):
        key_measurement, key_state = jax.random.split(key)
        ut = jax.random.normal(key_state) * qt
        et = jax.random.normal(key_measurement) * rt

        signal = signal_prev + ut
        measurement = signal + et

        return signal, (signal, measurement)

    keys = jax.random.split(key, n_steps)
    _, (signals, measurements) = jax.lax.scan(ssm_step, signal_init, keys)
    return signals, measurements

vssm = jax.vmap(ssm, in_axes=(0, None, None, None, None), out_axes=-1)

In [7]:
key = jax.random.PRNGKey(314)
key_sample, key_test = jax.random.split(key)

In [8]:
signal_init = 0.0
n_steps = 100
qt = 0.05
rt = 0.1
signals, measurements = ssm(key, signal_init, n_steps, qt, rt)

## The cholesky decompoistion

In [12]:
n_trials = 500
keys = jax.random.split(key_sample, n_trials)

In [13]:
n_steps = 100
signal_init = 0.0
rt = 0.1
qt = 0.05

signals, measurements = vssm(keys, signal_init, n_steps, qt, rt)

In [40]:
var_y = jnp.cov(measurements)

In [42]:
jnp.linalg.cholesky(var_y)

Array([[ 0.10886046,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ],
       [ 0.01506142,  0.12025248,  0.        , ...,  0.        ,
         0.        ,  0.        ],
       [ 0.0172424 ,  0.03949473,  0.12908962, ...,  0.        ,
         0.        ,  0.        ],
       ...,
       [ 0.04262104,  0.09049029, -0.01284586, ...,  0.11989492,
         0.        ,  0.        ],
       [ 0.03957634,  0.084027  , -0.01073   , ...,  0.05197307,
         0.10901859,  0.        ],
       [ 0.03986522,  0.08856162, -0.00640565, ...,  0.04934085,
         0.03799685,  0.1199625 ]], dtype=float32)

In [22]:
Rvals = jnp.diag()

In [23]:
Rvals

Array([0.0118506 , 0.0146875 , 0.01852126, 0.0192252 , 0.0247822 ,
       0.02648306, 0.02678788, 0.03088834, 0.03053957, 0.03285145,
       0.03272242, 0.03545737, 0.04072683, 0.04229142, 0.04577431,
       0.0471628 , 0.05026382, 0.05352138, 0.05901953, 0.05662038,
       0.05921943, 0.06327067, 0.06815398, 0.07089395, 0.07562165,
       0.07366988, 0.0751664 , 0.08050065, 0.08111937, 0.0832592 ,
       0.09176692, 0.09055013, 0.09518983, 0.09526988, 0.09790165,
       0.1070644 , 0.10218427, 0.10821304, 0.10993323, 0.1121665 ,
       0.10933848, 0.11234052, 0.10574164, 0.10864503, 0.11467545,
       0.1139204 , 0.12170965, 0.12352176, 0.12664904, 0.13058473,
       0.13054559, 0.12624188, 0.13363758, 0.1360575 , 0.13613725,
       0.13879074, 0.14282578, 0.145495  , 0.14597006, 0.15086403,
       0.15087079, 0.15628406, 0.15450262, 0.16241577, 0.16619822,
       0.16896376, 0.16326421, 0.16755441, 0.17383629, 0.17321117,
       0.17214847, 0.18351167, 0.17444202, 0.17694551, 0.18983