# LoFi: Low-rank (extended) Kalman filter

In [1]:
import jax
import flax.linen as nn
import jax.numpy as jnp
import matplotlib.pyplot as plt

In [2]:
from rebayes_mini.methods import low_rank_filter as lofi

In [3]:
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"

In [4]:
key = jax.random.PRNGKey(314)
diag = jax.random.uniform(key, (10,))
W = jax.random.normal(key, (10, 20))
dyn = 0.3

In [5]:
diag_pred = 1 / (1 / diag + dyn)
diag_pred

Array([0.07444949, 0.07686263, 0.58481854, 0.33342364, 0.31272268,
       0.6124436 , 0.40948197, 0.32819864, 0.53268456, 0.4249124 ],      dtype=float32)

In [6]:
(W.T @ jnp.diag(diag_pred) @ jnp.diag(1 / diag) @ W).sum(axis=0)

Array([ 15.17951   ,  31.687904  ,   0.47476912,  27.473309  ,
        -2.2373953 ,  -0.28310037,  -0.5816016 ,   7.62045   ,
        19.146309  ,   5.816488  ,   4.3553276 ,  -8.488842  ,
        25.902523  ,   9.020243  ,  11.763821  ,  -5.0538864 ,
       -25.545195  ,  -4.041623  ,  30.341537  ,   2.0045276 ],      dtype=float32)

In [7]:
jnp.einsum("ji,j,j,jk->ik", W, diag_pred, 1 / diag, W).sum(axis=0)

Array([ 15.227335  ,  31.653662  ,   0.5590105 ,  27.50296   ,
        -2.227013  ,  -0.21534252,  -0.58198357,   7.592178  ,
        19.10627   ,   5.8170905 ,   4.2841187 ,  -8.524591  ,
        25.912235  ,   9.0385685 ,  11.78618   ,  -5.1031895 ,
       -25.543436  ,  -4.0270414 ,  30.284472  ,   1.9885311 ],      dtype=float32)

In [8]:
C_inv = jnp.einsum("ji,j,j,jk->ik", W, diag_pred, 1 / diag, W)

In [9]:
org = (jnp.diag(diag_pred) @ jnp.diag(1 / diag) @ W @ C_inv).sum(axis=1)
org

Array([ 186.09457 ,  212.62166 ,   79.07352 , -110.429886, -144.54663 ,
       -188.22388 ,   56.556385,  103.65888 ,  -49.551247,  172.92783 ],      dtype=float32)

In [10]:
ein = jnp.einsum("i,i,ij,jk -> i", diag_pred, 1 / diag, W, C_inv)
ein

Array([ 186.57985 ,  213.61336 ,   79.079506, -109.92603 , -144.95905 ,
       -187.01534 ,   56.58297 ,  103.55682 ,  -49.75179 ,  173.07147 ],      dtype=float32)

In [11]:
jnp.linalg.cholesky(jnp.array([[4]]))

Array([[2.]], dtype=float32)

## Moons' dataset

In [12]:
from sklearn.datasets import make_moons

In [13]:
n_samples = 500
n_test = 300
data = make_moons(n_samples=(n_samples + n_test), random_state=3141, noise=0.15)
X, y = jax.tree_map(jnp.array, data)
X_test, y_test = X[-n_test:], y[-n_test:]
X, y = X[:-n_test], y[:-n_test]

key = jax.random.PRNGKey(314)

In [187]:
%%time
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(50)(x)
        x = nn.relu(x)
        x = nn.Dense(50)(x)
        x = nn.relu(x)
        x = nn.Dense(50)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x # sklearn


model = MLP()
params = model.init(key, X)

CPU times: user 35.7 ms, sys: 1.15 ms, total: 36.8 ms
Wall time: 21.8 ms


In [188]:
agent = lofi.BernoulliFilter(
    model.apply,
    dynamics_covariance=1e-7
)

In [204]:
bel_init = agent.init_bel(params, cov=0.1)
bel_pred = agent._predict(bel_init)
bel_update = agent._update(bel_pred, X[0], y[0])

In [219]:
bel = agent.init_bel(params, cov=0.1)
bel, _ = agent.step(bel, (X[0], y[0]), callback_fn=callbacks.get_null)
bel

LoFiState(mean=Array([0.48382032, 0.        , 0.        , ..., 0.00447403, 4.289242  ,
       0.06426697], dtype=float32), diagonal=Array([0.1, 0.1, 0.1, ..., 0.1, 0.1, 0.1], dtype=float32), low_rank=Array([[0.06140734, 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.5240219 , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ]], dtype=float32))

In [222]:
bel, _ = agent.step(bel, (X[1], y[1]), callback_fn=callbacks.get_null)

In [223]:
bel

LoFiState(mean=Array([nan, nan, nan, ..., nan, nan, nan], dtype=float32), diagonal=Array([nan, nan, nan, ..., nan, nan, nan], dtype=float32), low_rank=Array([[nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       ...,
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan]], dtype=float32))

In [191]:
from rebayes_mini import callbacks

In [192]:
bel, bel_hist = agent.scan(bel_init, y, X, callbacks.get_updated_mean)
bel = jax.block_until_ready(bel)

In [193]:
bel

LoFiState(mean=Array([nan, nan, nan, ..., nan, nan, nan], dtype=float32), diagonal=Array([nan, nan, nan, ..., nan, nan, nan], dtype=float32), low_rank=Array([[nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       ...,
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan]], dtype=float32))