# Parallel Kalman Filter
## With JAX

In [1]:
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax import random

In [2]:
# Set the inline plotting backend's figure format to "retina" for the figures in this notebook.
%config InlineBackend.figure_format = "retina"

In [3]:
# Initial mean and a small diagonal covariance (1/1000 on the diagonal).
μ0 = jnp.array([0, 0])
Σ0 = jnp.eye(2) / 1000

# 2x2 quarter-turn rotation scaled by 1/3 (presumably the state-transition matrix).
A = jnp.array([[0, -1], [1, 0]]) / 3

# 2x2 matrix whose shape is unpacked downstream as (observation_size, state_size).
C = jnp.array([[1, -0.3], [0.5, 1]])

# Isotropic 2x2 covariances: Q = 2*I, R = I/5.
# NOTE(review): presumably Q is the state-noise and R the observation-noise covariance — confirm.
Q = 2 * jnp.eye(2)
R = jnp.eye(2) / 5

# T is divided by `timesteps` downstream to obtain the step size dt.
T = 4
timesteps = 20
# Number of independent draws taken along the leading axis of the noise arrays.
n_samples = 5
# Root PRNG key for the notebook.
key = random.PRNGKey(314)



In [4]:
# Rows of C give the observation dimension, columns the state dimension.
observation_size, state_size = C.shape
# Step size: T divided evenly into `timesteps` steps.
dt = T / timesteps
# Derive three independent subkeys from the root key, one per random quantity.
key_z1, key_eps, key_delta = random.split(key, num=3)

In [5]:
# Zero-mean, observation-sized noise: one draw per (sample, timestep),
# giving an array of shape (n_samples, timesteps, observation_size).
# The covariance is R. Q is the covariance of the state-sized `eps` draw, and
# using it here too left R unused; because both are 2x2 the mix-up raised no
# shape error.
delta = random.multivariate_normal(
    key_delta,
    jnp.zeros(observation_size),
    R,
    shape=(n_samples, timesteps),
)

In [13]:
# Zero-mean, state-sized noise with covariance Q: one draw per (sample, timestep).
eps = random.multivariate_normal(
    key_eps,
    jnp.zeros(state_size),
    Q,
    shape=(n_samples, timesteps),
)
# Display the resulting shape: (n_samples, timesteps, state_size).
eps.shape

(5, 20, 2)

In [14]:
# Out-of-bounds probe: axis 1 of `eps` has only `timesteps` (20) entries, yet
# index 200 does not raise. JAX's documented indexing semantics clamp an
# out-of-range read to the valid range, so this presumably shows the last
# timestep for every sample.
eps[:, 200]

DeviceArray([[-2.8453205 , -1.0356324 ],
             [-1.462031  ,  0.11376988],
             [ 1.0670778 , -0.8557862 ],
             [ 0.8361672 ,  2.6661315 ],
             [-1.1375958 ,  3.5206814 ]], dtype=float32)

In [20]:
# Same out-of-bounds probe on a 1-D array: draw 10 integers in [1, 10) and read
# position 100000. No error is raised; a single element comes back.
key = random.PRNGKey(314)
random.randint(key, shape=(10,), minval=1, maxval=10)[100000]

DeviceArray(3, dtype=int32)

## vmap test

In [21]:
def important_step(x):
    """Multiply the notebook-level matrix ``C`` by ``x``.

    Args:
        x: Array-like operand for ``C @ x``; its leading dimension must be
            compatible with the number of columns of ``C``.

    Returns:
        The matrix product ``C @ x``.
    """
    projected = C @ x
    return projected

In [23]:
# Ten 2-vectors of integers in [1, 10), used as a batch of inputs below.
# NOTE(review): `key` was already consumed by earlier draws; that is fine for a
# smoke test, but derive a fresh subkey if these samples need to be independent.
xsamp = random.randint(key, shape=(10, 2), minval=1, maxval=10)

In [25]:
import jax

In [28]:
# Map `important_step` over the leading axis of `xsamp`, so each row is
# multiplied by C independently; the result has one output row per input row.
jax.vmap(important_step, in_axes=0)(xsamp)

DeviceArray([[ 2.7       ,  2.5       ],
             [ 0.79999995,  5.        ],
             [ 1.8       ,  5.5       ],
             [ 2.3       , 11.5       ],
             [ 0.89999986,  8.5       ],
             [ 4.8       ,  7.        ],
             [-0.20000005,  4.5       ],
             [-1.7       ,  9.5       ],
             [ 8.1       ,  7.5       ],
             [-0.4000001 ,  9.        ]], dtype=float32)