# Kalman Filter (Numpyro) - Object Tracking

In [None]:
import sys, os
from pyprojroot import here

# locate the project root via pyprojroot, using the `.here` marker file
# (presumably by searching upward from the working directory — confirm)
root = here(project_files=[".here"])

# put the project root on the import path
# (presumably so the local `jaxkf` package imported below resolves)
sys.path.append(str(root))

%load_ext autoreload
%autoreload 2

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
from typing import NamedTuple
from jax.random import multivariate_normal, split
from tqdm.notebook import tqdm, trange
from jax.random import multivariate_normal
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions
import jaxkf._src.lgssm as lgssm


import matplotlib.pyplot as plt

## Simulating Data

### State Transition Dynamics

We assume that we can fully describe the state when we have the `(x,y)` coordinates of the position and the `(x,y)` velocity. So we can write this as:

$$
\mathbf{z}_t = 
\begin{bmatrix}
z_t^1 \\ z_t^2 \\ \dot{z}_t^1 \\ \dot{z}_t^2
\end{bmatrix}
$$


where $z_t^d$ is the position coordinate along dimension $d$ and $\dot{z}_t^d$ is the corresponding velocity.

We can describe the dynamics of the system using the following system of equations:

$$
\begin{aligned}
z_t^1 &= z_{t-1}^1 + \Delta_t \dot{z}_{t-1}^1 + \epsilon_t^1 \\
z_t^2 &= z_{t-1}^2 + \Delta_t \dot{z}_{t-1}^2 + \epsilon_t^2 \\
\dot{z}_t^1 &= \dot{z}_{t-1}^1 + \epsilon_t^3 \\
\dot{z}_t^2 &= \dot{z}_{t-1}^2 + \epsilon_t^4 \\
\end{aligned}
$$

This is a very simple formulation which takes a first order approximation to the change in position based on speed and we also assume constant velocity. Note, we also include some noise because we assume that some of the dynamics are noisy, i.e. there are random acceleration and position changes in the model. 


We can also put this into matrix formulation like so:

$$
\mathbf{z}_t = \mathbf{A}_t \mathbf{z}_{t-1} + \boldsymbol{\epsilon}_t
$$

where:

$$
\mathbf{A}_t = 
\begin{bmatrix}
1 & 0 & \Delta_t & 0 \\
0 & 1 & 0 & \Delta_t \\
0 & 0 & 1 & 0 \\
0 & 0 & 0 & 1 \\
\end{bmatrix}, \;\; \mathbf{A}_t \in \mathbb{R}^{4\times 4}
$$


---
### Emissions Model

We can only observe the positions (not the velocities), so each observation is a lower-dimensional vector with 2 entries. The system of equations is as follows:

$$
\begin{aligned}
x_t^1 &= z_t^1 + \delta_t^1 \\
x_t^2 &= z_t^2 + \delta_t^2 \\
\end{aligned}
$$

This is a very simple model where we assume we can extract the direct positions (plus some noise) from the state.

We can write this in an abbreviated matrix formulation:

$$
\mathbf{x}_t = \mathbf{C}_t \mathbf{z}_t + \boldsymbol{\delta}_t
$$

where:

$$
\mathbf{C}_t = 
\begin{bmatrix}
1 & 0 & 0 & 0 \\
0 & 1 & 0 & 0 \\
\end{bmatrix}, \;\; \mathbf{C}_t \in \mathbb{R}^{2 \times 4}
$$

## Model

* [x] Modeling Noises Only
* [x] Modeling States/Observations
* [ ] Using Conditioning Notation
* [x] Using Plate Notation

In [None]:
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan
from numpyro import diagnostics, infer

In [None]:
def simulated_kalman_filter(
    trans_mat: jnp.ndarray,
    trans_noise_cov: jnp.ndarray,
    emiss_mat: jnp.ndarray,
    emiss_noise_cov: jnp.ndarray,
    mu0: jnp.ndarray,
    Sigma0: jnp.ndarray,
    time_steps: int = 100,
    x_obs: jnp.ndarray = None,
    batch: int = 1,
    x_obs_mask: jnp.ndarray = None,
):
    """NumPyro model of a linear-Gaussian state space model.

    The initial state is drawn from ``MultivariateNormal(mu0, Sigma0)``; every
    scan step then samples

        z_t ~ MultivariateNormal(trans_mat @ z_{t-1}, trans_noise_cov)
        x_t ~ MultivariateNormal(emiss_mat @ z_t, emiss_noise_cov)

    Args:
        trans_mat: state transition matrix.
        trans_noise_cov: covariance of the transition noise.
        emiss_mat: emission matrix mapping a state to an observation.
        emiss_noise_cov: covariance of the emission noise.
        mu0: mean of the initial state distribution.
        Sigma0: covariance of the initial state distribution.
        time_steps: number of scan steps; overridden by ``x_obs`` when given.
        x_obs: optional observations laid out as ``(batch, time, dim)``. When
            given, the ``x`` site is conditioned on them and ``batch`` and
            ``time_steps`` are taken from their shape.
        batch: number of independent trajectories; overridden by ``x_obs``
            when given.
        x_obs_mask: accepted for interface compatibility; currently unused.

    Returns:
        Tuple ``(z, x)`` of the per-step states and observations collected by
        the scan, i.e. with time on the leading axis followed by the batch axis.
    """
    if x_obs is not None:
        batch, time_steps, _ = x_obs.shape
        # the scan iterates over the leading axis, so time has to come first
        x_obs = jnp.swapaxes(x_obs, 0, 1)

    # Record every model matrix in the trace under its own name. Each site must
    # be given its *own* value: `numpyro.deterministic` returns what it is
    # passed, so handing it the wrong array silently rebinds the parameter.
    trans_mat = numpyro.deterministic("trans_mat", trans_mat)
    trans_noise_cov = numpyro.deterministic("trans_noise_cov", trans_noise_cov)
    emiss_mat = numpyro.deterministic("emiss_mat", emiss_mat)
    emiss_noise_cov = numpyro.deterministic("emiss_noise_cov", emiss_noise_cov)

    # ==================
    # sample from prior
    # ==================
    prior_dist = dist.MultivariateNormal(mu0, Sigma0)
    z0 = numpyro.sample("z0", prior_dist, sample_shape=(batch,))

    # applies one matrix to every row of a (batch, dim) array
    fn_vec_dot = jax.vmap(jnp.dot, in_axes=(None, 0))

    # ==================
    # Model
    # ==================
    def body(z_prev, x_prev):
        # `x_prev` is the observation for the current step (None when sampling)
        with numpyro.plate("batches", batch, dim=-1):
            # transition
            z_mean = fn_vec_dot(trans_mat, z_prev)
            z = numpyro.sample("z", dist.MultivariateNormal(z_mean, trans_noise_cov))
            # emission
            x_mean = fn_vec_dot(emiss_mat, z)
            x = numpyro.sample(
                "x", dist.MultivariateNormal(x_mean, emiss_noise_cov), obs=x_prev
            )
        return z, (z, x)

    # loop through data
    _, (z, x) = scan(body, z0, x_obs, length=time_steps)

    return z, x

In [None]:
# dimensions of the latent state (2 positions + 2 velocities) and of a measurement
state_dim = 4
obs_dim = 2

# time increment and noise scales
dt = 0.1
step_std = 0.1
noise_std = 0.02

# =================
# initial state
# =================
mu0 = jnp.array([8.0, 5.0, 1.0, 0.0])
Sigma0 = 1e-5 * jnp.eye(state_dim)
prior_dist = dist.MultivariateNormal(mu0, Sigma0)

# =================
# transition model
# =================
# identity plus dt on the second super-diagonal: position += dt * velocity
trans_mat = jnp.eye(state_dim) + dt * jnp.eye(state_dim, k=2)
trans_noise_cov = step_std**2 * jnp.eye(state_dim)
trans_noise = dist.MultivariateNormal(jnp.zeros(state_dim), trans_noise_cov)

# =================
# emission model
# =================
# picks the first obs_dim state entries (the positions)
emiss_mat = jnp.eye(N=obs_dim, M=state_dim)
emiss_noise_cov = noise_std**2 * jnp.eye(obs_dim)
emiss_noise = dist.MultivariateNormal(jnp.zeros(obs_dim), emiss_noise_cov)

In [None]:
# # Inference
# kernel = infer.NUTS(simulated_kalman_filter)
# mcmc = infer.MCMC(kernel, num_warmup=100, num_samples=100)
# mcmc.run(rng_key_infer,
#     trans_mat=trans_mat, trans_noise_cov=trans_noise_cov,
#     emiss_mat=emiss_mat, emiss_noise_cov=emiss_noise_cov,
#     mu0=mu0, Sigma0=Sigma0,
#     time_steps=time_steps,
#     x_obs=x_s,
#     batch=0)
# posterior_samples = mcmc.get_samples()

#### Prior Samples (Propagate)

In [None]:
# number of scan steps and of independent trajectories to draw
time_steps = 80
batch = 5

# run the model once under a fixed seed; with x_obs=None nothing is observed,
# so both the latent states and the measurements are sampled
with numpyro.handlers.seed(rng_seed=123):
    z_samples_prior, x_samples_prior = simulated_kalman_filter(
        trans_mat=trans_mat,
        trans_noise_cov=trans_noise_cov,
        emiss_mat=emiss_mat,
        emiss_noise_cov=emiss_noise_cov,
        mu0=mu0,
        Sigma0=Sigma0,
        time_steps=time_steps,
        x_obs=None,
        batch=batch,
    )
# np.testing.assert_array_almost_equal(x_s, x_s_)
# np.testing.assert_array_almost_equal(z_s, z_s_)
# display the shapes of the sampled states and measurements
z_samples_prior.shape, x_samples_prior.shape

In [None]:
fig, ax = plt.subplots()

# the directly sampled arrays are unpacked as (time, trajectory, dim)
n_time_steps, n_samples, n_dims = x_samples_prior.shape

# NOTE: the unconditional `break` at the end of the body means only the first
# trajectory (i_sample == 0) is actually drawn
for i_sample in range(n_samples):
    # latent path: first two state entries plotted against each other
    ax.plot(
        z_samples_prior[..., i_sample, 0],
        z_samples_prior[..., i_sample, 1],
        color="black",
        label="True State",
    )
    # ax.plot(x_s[..., 0], x_s[..., 1], color="tab:red", linestyle="--", label="Noisy Latent")
    # sampled measurements for the same trajectory
    ax.scatter(
        x_samples_prior[..., i_sample, 0],
        x_samples_prior[..., i_sample, 1],
        label="Measurements",
        color="red",
        alpha=0.5,
    )
    break
ax.set(xlabel="x-position", ylabel="y-position")
plt.legend()
plt.show()

#### Prior Samples (Predictive)

In [None]:
# one trajectory per model run, num_samples independent runs
time_steps = 80
batch = 1
num_samples = 5
rng_key_prior = jax.random.PRNGKey(42)
# prior predictive: no posterior samples are supplied and x_obs=None,
# so every site is drawn from the model itself
predictive = infer.Predictive(simulated_kalman_filter, num_samples=num_samples)
prior_samples = predictive(
    rng_key_prior,
    trans_mat=trans_mat,
    trans_noise_cov=trans_noise_cov,
    emiss_mat=emiss_mat,
    emiss_noise_cov=emiss_noise_cov,
    mu0=mu0,
    Sigma0=Sigma0,
    time_steps=time_steps,
    x_obs=None,
    batch=batch,
)

In [None]:
# NOTE(review): `posterior_samples` is only assigned in the MCMC cell further
# down, so this cell cannot run in top-to-bottom order — confirm intent
[posterior_samples.keys()]

In [None]:
def fn():
    """Zero-argument wrapper running the model on `x_obs` with the notebook's parameters."""
    return simulated_kalman_filter(
        trans_mat=trans_mat,
        trans_noise_cov=trans_noise_cov,
        emiss_mat=emiss_mat,
        emiss_noise_cov=emiss_noise_cov,
        mu0=mu0,
        Sigma0=Sigma0,
        x_obs=x_obs,
    )


# log-likelihood of the observed sites, evaluated for the posterior samples
ll = numpyro.infer.util.log_likelihood(fn, posterior_samples)

In [None]:
# list the sites for which a log-likelihood was returned
[ll.keys()]

In [None]:
# shape of the log-likelihood array for the observed site "x"
ll["x"].shape

In [None]:
fig, ax = plt.subplots()

# Predictive stacks its draws on the leading axis, so the arrays arrive as
# (draw, time, batch, dim). Drop only the singleton batch axis: a bare
# `squeeze()` would also remove the draw axis whenever num_samples == 1.
z_s_samples = prior_samples["z"].squeeze(axis=2)
x_s_samples = prior_samples["x"].squeeze(axis=2)

# axis 0 indexes the draws (it is the axis indexed by `i_sample` below)
n_samples, n_time_steps, n_dims = z_s_samples.shape

# NOTE: the unconditional `break` keeps the figure to the first draw only
for i_sample in range(num_samples):
    # latent path: first two state entries plotted against each other
    ax.plot(
        z_s_samples[i_sample, ..., 0],
        z_s_samples[i_sample, ..., 1],
        color="black",
        label="True State",
    )
    # sampled measurements for the same draw
    ax.scatter(
        x_s_samples[i_sample, ..., 0],
        x_s_samples[i_sample, ..., 1],
        label="Measurements",
        color="red",
        alpha=0.5,
    )
    break
ax.set(xlabel="x-position", ylabel="y-position")
plt.legend()
plt.show()

## Inference

### MCMC

In [None]:
from einops import rearrange

# the directly sampled measurements are (time, batch, dim), while the model
# unpacks x_obs.shape as (batch, time, dim) — move the batch axis to the front
x_obs = rearrange(x_samples_prior, "T B D -> B T D")
x_obs.shape

In [None]:
rng_key_infer = jax.random.PRNGKey(123)

# Inference: NUTS over the latent sites of the model, conditioning on x_obs.
# time_steps and batch are not passed because the model reads them from
# x_obs.shape when observations are given.
kernel = infer.NUTS(simulated_kalman_filter)
mcmc = infer.MCMC(kernel, num_warmup=100, num_samples=100)
mcmc.run(
    rng_key_infer,
    x_obs=x_obs,
    trans_mat=trans_mat,
    trans_noise_cov=trans_noise_cov,
    emiss_mat=emiss_mat,
    emiss_noise_cov=emiss_noise_cov,
    mu0=mu0,
    Sigma0=Sigma0,
)
posterior_samples = mcmc.get_samples()

In [None]:
# sampled site names and the shape of the posterior draws of the latent states
[posterior_samples.keys()], posterior_samples["z"].shape

In [None]:
time_steps = 80
num_samples = 5

# posterior predictive: the posterior draws of "z0"/"z" are substituted into
# the model and only the measurements "x" are re-sampled (x_obs is left out on
# purpose so that "x" is not conditioned).
predictive = infer.Predictive(
    simulated_kalman_filter,
    posterior_samples=posterior_samples,
    return_sites=["x", "z", "z0"],
)
posterior_predictive = predictive(
    rng_key_prior,
    trans_mat=trans_mat,
    trans_noise_cov=trans_noise_cov,
    emiss_mat=emiss_mat,
    emiss_noise_cov=emiss_noise_cov,
    mu0=mu0,
    Sigma0=Sigma0,
    # Without x_obs the model falls back to its defaults (100 steps, 1
    # trajectory), which do not match the substituted posterior draws; pass
    # the scan length and batch size explicitly so they agree.
    time_steps=time_steps,
    # posterior draws of "z0" are (draw, batch, state)
    batch=posterior_samples["z0"].shape[1],
    # x_obs=x_obs
)

In [None]:
# list the sites returned by the posterior predictive run
[posterior_predictive.keys()]

In [None]:
fig, ax = plt.subplots()

z_s_samples = posterior_predictive["z"]
x_s_samples = posterior_predictive["x"]

# axis 0 indexes the draws (it is the axis indexed by `i_sample` below),
# followed by time, trajectory and dimension
n_samples, n_time_steps, _, n_dims = z_s_samples.shape

# NOTE: the unconditional `break` keeps the figure to the first draw only
for i_sample in range(num_samples):
    # latent path of trajectory 0: first two state entries against each other
    ax.plot(
        z_s_samples[i_sample, :, 0, 0],
        z_s_samples[i_sample, :, 0, 1],
        color="black",
        label="True State",
    )
    # re-sampled measurements for the same draw and trajectory
    ax.scatter(
        x_s_samples[i_sample, :, 0, 0],
        x_s_samples[i_sample, :, 0, 1],
        label="Measurements",
        color="red",
        alpha=0.5,
    )

    break
ax.set(xlabel="x-position", ylabel="y-position")
plt.legend()
plt.show()

### SVI