# L-RVGA for rotated MNIST

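The Limited-memory Recursive Variational Gaussian Approximation (L-RVGA) update at time $t$ consists of two KL minimisations: an R-VGA step, followed by a projection onto Gaussians whose precision matrix has the low-rank-plus-diagonal form $WW^\intercal + \Psi$: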

$$
\begin{aligned}
    q^*_t(\theta) &= \arg\min_{\mu^*, P^*}
        \text{KL}\left(  \mathcal{N}(\theta \vert \mu^*, (P^*)^{-1}) \,\Vert\, \mathcal{N}(\theta \vert \mu_{t-1}, P_{t-1}^{-1})\, p(y_t \vert\theta) \right)\\
    q_t(\theta) &= \arg\min_{\mu, W, \Psi}
         \text{KL}\left( \mathcal{N}(\theta \vert \mu, \left(WW^\intercal + \Psi\right)^{-1}) \,\Vert\, \mathcal{N}(\theta \vert \mu^*, (P^*)^{-1})   \right)
\end{aligned}
$$
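In the second step both arguments of the KL divergence are Gaussian, so the objective has a closed form. The sketch below evaluates it with both distributions parameterised by precision matrices, as in the equations above. The helper name `gaussian_kl_precision` is hypothetical; this is an illustration, not part of the original notebook.

```python
import jax.numpy as jnp


def gaussian_kl_precision(mu0, P0, mu1, P1):
    """KL( N(mu0, P0^{-1}) || N(mu1, P1^{-1}) ), with P0, P1 precision matrices."""
    d = mu0.shape[0]
    diff = mu1 - mu0
    # tr(Sigma1^{-1} Sigma0) = tr(P1 P0^{-1}) = tr(P0^{-1} P1), via a linear solve
    trace_term = jnp.trace(jnp.linalg.solve(P0, P1))
    # Mahalanobis term (mu1 - mu0)^T P1 (mu1 - mu0)
    mahalanobis = diff @ P1 @ diff
    # log det Sigma1 - log det Sigma0 = log det P0 - log det P1
    _, logdet_P0 = jnp.linalg.slogdet(P0)
    _, logdet_P1 = jnp.linalg.slogdet(P1)
    return 0.5 * (trace_term + mahalanobis - d + logdet_P0 - logdet_P1)


d = 3
mu0 = jnp.zeros(d)
P = 2.0 * jnp.eye(d)
mu1 = jnp.array([1.0, 0.0, 0.0])
kl_same = gaussian_kl_precision(mu0, P, mu0, P)   # identical Gaussians: KL = 0
kl_shift = gaussian_kl_precision(mu0, P, mu1, P)  # 0.5 * (mu1 - mu0)^T P (mu1 - mu0) = 1
```

For equal means, minimising this over the first precision $C$ amounts to minimising $\log\det C + \operatorname{tr}(C^{-1}P^*)$, which is why a factor-analysis fit appears in the second step.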

```python
import jax
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
from dynamax.utils import datasets
```

## Load rotated MNIST

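```python
np.random.seed(314)  # fix NumPy's global seed for reproducibility
num_train = 100      # number of training examples to keep

# Load rotated MNIST for target_digit=2
train, test = datasets.load_rotated_mnist(target_digit=2)
X_train, y_train = train
X_test, y_test = test

X_train = jnp.array(X_train)
y_train = jnp.array(y_train)

# Keep only the first num_train examples
X = X_train[:num_train]
y = y_train[:num_train]

# Optionally sort the examples by their target value:
# ix_sort = jnp.argsort(y)
# X = X[ix_sort]
# y = y[ix_sort]
```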

## L-RVGA

We consider a linear problem. Let $y_t = \theta^\intercal x_t + \epsilon_t$, with $\epsilon_t \sim \mathcal{N}(0, 1)$. Under the factor-analysis (FA) approximation of the R-VGA, the updates for $\mu_t\in\mathbb{R}^{d}$, $W_t\in\mathbb{R}^{d\times p}$, and the diagonal matrix $\Psi_t\in\mathbb{R}^{d\times d}$ become

$$
\begin{aligned}
\mu_t &= \mu_{t-1} + \left(W_t W_t^\intercal + \Psi_t\right)^{-1}x_t\left(y_t - \mu_{t-1}^\intercal x_t\right)\\
W_t W_t^\intercal + \Psi_t &\underset{\text{FA}}{\approx} \alpha_t\left( W_{t-1}W_{t-1}^\intercal + \Psi_{t-1} \right) + \beta_t x_tx_t^\intercal
\end{aligned}
$$
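Because the precision is stored as $W_tW_t^\intercal + \Psi_t$ with diagonal $\Psi_t$, the solve in the mean update does not require forming a $d\times d$ matrix. One option, assumed here since the notebook does not say how it performs this solve, is the Woodbury identity $(\Psi + WW^\intercal)^{-1} = \Psi^{-1} - \Psi^{-1}W M^{-1}W^\intercal\Psi^{-1}$ with $M = {\bf I}_p + W^\intercal\Psi^{-1}W$, which costs $O(dp^2 + p^3)$ instead of $O(d^3)$. A minimal JAX sketch (the name `mean_update` is hypothetical), checked against a dense solve on small random inputs:

```python
import jax
import jax.numpy as jnp


def mean_update(mu_prev, W, Psi, x, y):
    """mu_t = mu_{t-1} + (W W^T + diag(Psi))^{-1} x (y - mu_{t-1}^T x).

    W: (d, p) factor matrix; Psi: (d,) diagonal of Psi, all entries > 0.
    """
    Psi_inv = 1.0 / Psi
    p = W.shape[1]
    # M = I_p + W^T Psi^{-1} W  (p x p)
    M = jnp.eye(p) + (W.T * Psi_inv) @ W
    # Woodbury: (Psi + W W^T)^{-1} x = Psi^{-1} x - Psi^{-1} W M^{-1} W^T Psi^{-1} x
    Psi_inv_x = Psi_inv * x
    gain = Psi_inv_x - Psi_inv * (W @ jnp.linalg.solve(M, W.T @ Psi_inv_x))
    residual = y - mu_prev @ x
    return mu_prev + gain * residual


# Compare with a dense O(d^3) solve on small random inputs
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
d, p = 6, 2
W = jax.random.normal(k1, (d, p))
Psi = jax.random.uniform(k2, (d,), minval=0.5, maxval=1.5)
x = jax.random.normal(k3, (d,))
mu_prev = jax.random.normal(k4, (d,))
y = 0.7

mu = mean_update(mu_prev, W, Psi, x, y)
P = W @ W.T + jnp.diag(Psi)
mu_dense = mu_prev + jnp.linalg.solve(P, x) * (y - mu_prev @ x)
```

The same matrix $M$ reappears in the FA iteration, so it can be shared between the two updates.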

The FA update of the posterior precision $W_tW_t^\intercal + \Psi_t$ is obtained by iterating the following equations `nb_inner_loop` times, then setting $W_t = W$ and $\Psi_t = \Psi$:

$$
\begin{aligned}
M &= {\bf I}_p + W^\intercal \Psi^{-1} W\\
V &= \beta_t x_tx_t^\intercal\Psi^{-1} W + \alpha_t\left[ W_{t-1}W_{t-1}^\intercal \Psi^{-1} W + \Psi_{t-1}\Psi^{-1} W \right]\\
W^{(n)} &= V\left({\bf I}_p + M^{-1} W^\intercal\Psi^{-1}V\right)^{-1}\\
\Psi^{(n)} &= \beta_t \text{diag}\left(x_tx_t^\intercal\right) + \alpha_t \text{diag}\left(W_{t-1}W_{t-1}^\intercal\right) + \alpha_t\Psi_{t-1} - \text{diag}\left(W^{(n)}M^{-1}V^\intercal\right)\\
W &= W^{(n)}, \Psi = \Psi^{(n)}
\end{aligned}
$$
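Writing $S = \alpha_t(W_{t-1}W_{t-1}^\intercal + \Psi_{t-1}) + \beta_t x_tx_t^\intercal$ for the target precision of the FA approximation, the iteration reads $V = S\Psi^{-1}W$ and $\Psi^{(n)} = \text{diag}(S) - \text{diag}(W^{(n)}M^{-1}V^\intercal)$. These are the EM updates for factor analysis with $S$ in the role of the sample covariance, so the objective $-\tfrac12\left[\log\det C + \operatorname{tr}(C^{-1}S)\right]$ with $C = WW^\intercal + \Psi$ should not decrease from one iteration to the next; maximising it is equivalent to minimising the second KL divergence of the L-RVGA equations with $P^* = S$. The dense sketch below forms $S$ explicitly, so it only suits small $d$ and is meant as a reference. The names `fa_em_dense` and `fa_objective`, the illustrative values of $\alpha_t, \beta_t$, and starting the iteration from $W_{t-1}, \Psi_{t-1}$ are assumptions of this sketch.

```python
import jax
import jax.numpy as jnp


def fa_em_dense(S, W, Psi, nb_inner_loop):
    """Run nb_inner_loop FA-EM iterations fitting W W^T + diag(Psi) to the dense matrix S."""
    I = jnp.eye(W.shape[1])
    for _ in range(nb_inner_loop):
        Psi_inv = 1.0 / Psi
        M_inv = jnp.linalg.inv(I + (W.T * Psi_inv) @ W)  # M^{-1}, M = I_p + W^T Psi^{-1} W
        V = S @ (Psi_inv[:, None] * W)                   # V = S Psi^{-1} W
        A = I + M_inv @ (W.T * Psi_inv) @ V              # I_p + M^{-1} W^T Psi^{-1} V
        W_new = jnp.linalg.solve(A.T, V.T).T             # W^(n) = V A^{-1}
        # Psi^(n) = diag(S) - diag(W^(n) M^{-1} V^T), using M^{-1} from the old W
        Psi = jnp.diag(S) - jnp.einsum("ij,jk,ik->i", W_new, M_inv, V)
        W = W_new
    return W, Psi


def fa_objective(S, W, Psi):
    """-(1/2) [log det C + tr(C^{-1} S)] with C = W W^T + diag(Psi)."""
    C = W @ W.T + jnp.diag(Psi)
    _, logdet = jnp.linalg.slogdet(C)
    return -0.5 * (logdet + jnp.trace(jnp.linalg.solve(C, S)))


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
d, p = 8, 2
alpha, beta = 1.0, 1.0  # illustrative values only
W_prev = jax.random.normal(k1, (d, p))
Psi_prev = jax.random.uniform(k2, (d,), minval=0.5, maxval=1.5)
x = jax.random.normal(k3, (d,))
S = alpha * (W_prev @ W_prev.T + jnp.diag(Psi_prev)) + beta * jnp.outer(x, x)

# Track the objective over five single-step calls
W, Psi = W_prev, Psi_prev
objectives = []
for _ in range(5):
    W, Psi = fa_em_dense(S, W, Psi, nb_inner_loop=1)
    objectives.append(float(fa_objective(S, W, Psi)))
```

If $S$ already has the form $WW^\intercal + \Psi$, the current $(W, \Psi)$ is a fixed point of the iteration, which gives a quick sanity check for a memory-efficient implementation.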

### Initialisation

We use the initialisation rule of §5.1.2:
1. $\Psi_0 = \psi_0 {\bf I}_d$, where $\psi_0 > 0$,
2. $W_0 \in \mathbb{R}^{d\times p}$, whose columns $u_k$ are random vectors drawn independently from an isotropic Gaussian distribution in $\mathbb{R}^d$ and then normalised so that $\|u_k\| = w_0$ for all $k$.

We let

$$
\begin{aligned}
    \psi_0 &= (1 - \epsilon) \frac{1}{\sigma^2_0},\\
    w_0 &= \sqrt{\frac{\epsilon d}{p\sigma^2_0}}
\end{aligned}
$$

with $0 < \epsilon \ll 1$ a small parameter and $\sigma_0^2$ the prior variance, i.e. the prior precision is $\sigma_0^{-2}{\bf I}_d$.
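A minimal JAX sketch of this rule follows. The helper name `init_lrvga` and the values of $d$, $p$, $\sigma_0$ and $\epsilon$ in the usage line are illustrative, not from the source; $\Psi_0$ is stored as the vector of its diagonal, as the notebook's code does. With these choices $\operatorname{tr}(W_0W_0^\intercal) + \operatorname{tr}(\Psi_0) = p w_0^2 + d\psi_0 = \epsilon d/\sigma_0^2 + (1-\epsilon)d/\sigma_0^2 = d/\sigma_0^2$, the trace of $\sigma_0^{-2}{\bf I}_d$, so $\epsilon$ controls how much of that trace is carried by the low-rank part.

```python
import jax
import jax.numpy as jnp


def init_lrvga(key, num_dims, num_basis, sigma0, eps):
    """Return (W_0, diagonal of Psi_0) following the initialisation rule above."""
    psi0 = (1.0 - eps) / sigma0**2
    w0 = jnp.sqrt(eps * num_dims / (num_basis * sigma0**2))
    Psi0 = jnp.full((num_dims,), psi0)
    # Columns drawn from an isotropic Gaussian, then rescaled to have norm w0
    U = jax.random.normal(key, (num_dims, num_basis))
    W0 = w0 * U / jnp.linalg.norm(U, axis=0, keepdims=True)
    return W0, Psi0


d, p, sigma0, eps = 50, 5, 2.0, 1e-2  # illustrative values only
W0, Psi0 = init_lrvga(jax.random.PRNGKey(0), d, p, sigma0, eps)
```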

```python
x = X_train[0]
num_dims, *_ = x.shape  # d, the input dimension
```

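```python
num_basis = 100  # p, the number of columns of W
key = jax.random.PRNGKey(314)
# Split the key: reusing one key for every draw gives correlated samples
key_x, key_W, key_Psi = jax.random.split(key, 3)

x = jax.random.normal(key_x, (num_dims,))
W = jax.random.normal(key_W, (num_dims, num_basis))
# Psi is diagonal (stored as a vector) and must be positive, so draw from [0.5, 1.5)
Psi = jax.random.uniform(key_Psi, (num_dims,), minval=0.5, maxval=1.5)
Psi_inv = 1 / Psi
```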

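```python
def fa_approx(W, Psi, W_prev, Psi_prev, x, alpha, beta):
    """One inner iteration of the FA update of the precision W W^T + diag(Psi).

    W, Psi: current iterate, a (d, p) matrix and the length-d diagonal of Psi.
    W_prev, Psi_prev: factors from step t-1. alpha, beta: alpha_t and beta_t.
    """
    num_basis = W.shape[1]  # p (len(Psi) would be d)
    Psi_inv = 1 / Psi
    I = jnp.eye(num_basis)
    # M = I_p + W^T Psi^{-1} W  (p x p)
    M = I + jnp.einsum("ji,j,jk->ik", W, Psi_inv, W)
    M_inv = jnp.linalg.inv(M)
    # V = beta x x^T Psi^{-1} W + alpha (W_prev W_prev^T + Psi_prev) Psi^{-1} W  (d x p)
    V_beta = jnp.einsum("i,j,j,jk->ik", x, x, Psi_inv, W)
    V_alpha = (
        jnp.einsum("ij,kj,k,kl->il", W_prev, W_prev, Psi_inv, W)
        + jnp.einsum("i,i,ij->ij", Psi_prev, Psi_inv, W)
    )
    V = beta * V_beta + alpha * V_alpha
    # W^(n) = V A^{-1} with A = I_p + M^{-1} W^T Psi^{-1} V; since A^{-1} multiplies
    # from the right, solve A^T Z = V^T and transpose back
    W_solve = I + jnp.einsum("ij,kj,k,kl->il", M_inv, W, Psi_inv, V)
    W_new = jnp.linalg.solve(W_solve.T, V.T).T
    # Psi^(n): diag(x x^T) is x**2 elementwise; the last term is diag(W^(n) M^{-1} V^T)
    Psi_new = (
        beta * x**2
        + alpha * jnp.einsum("ij,ij->i", W_prev, W_prev)
        + alpha * Psi_prev
        - jnp.einsum("ij,jk,ik->i", W_new, M_inv, V)
    )
    return W_new, Psi_new
```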