# Kalman Filter (Numpyro) - Object Tracking

In [None]:
import sys, os
from pyprojroot import here

# spyder up to find the root
root = here(project_files=[".home"])

# append to path
sys.path.append(str(root))

%load_ext autoreload
%autoreload 2

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
from typing import NamedTuple
from jax.random import multivariate_normal, split
from tqdm.notebook import tqdm, trange
from jax.random import multivariate_normal
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

import matplotlib.pyplot as plt

## Simulating Data

### State Transition Dynamics

We assume the state is fully described by the `(x,y)` coordinates of the position and the `(x,y)` velocity, so we can write it as:

$$
\mathbf{z}_t = 
\begin{bmatrix}
z_t^1 \\ z_t^2 \\ \dot{z}_t^1 \\ \dot{z}_t^2
\end{bmatrix}
$$


where $z_t^d$ is the position along dimension $d$ and $\dot{z}_t^d$ is the velocity along dimension $d$.

We can describe the dynamics of the system using the following system of equations:

$$
\begin{aligned}
z_t^1 &= z_{t-1}^1 + \Delta_t \dot{z}_{t-1}^1 + \epsilon_t^1 \\
z_t^2 &= z_{t-1}^2 + \Delta_t \dot{z}_{t-1}^2 + \epsilon_t^2 \\
\dot{z}_t^1 &= \dot{z}_{t-1}^1 + \epsilon_t^3 \\
\dot{z}_t^2 &= \dot{z}_{t-1}^2 + \epsilon_t^4 \\
\end{aligned}
$$

This is a very simple formulation: the change in position is a first-order approximation based on the previous velocity, and the velocity itself is assumed constant. We also include noise terms because some of the dynamics are not captured by this model, i.e. there are random accelerations and position changes.


We can also put this into matrix formulation like so:

$$
\mathbf{z}_t = \mathbf{A}_t \mathbf{z}_{t-1} + \boldsymbol{\epsilon}_t
$$

where:

$$
\mathbf{A}_t = 
\begin{bmatrix}
1 & 0 & \Delta_t & 0 \\
0 & 1 & 0 & \Delta_t \\
0 & 0 & 1 & 0 \\
0 & 0 & 0 & 1 \\
\end{bmatrix}, \;\; \mathbf{A}_t \in \mathbb{R}^{4\times 4}
$$


---
### Emissions Model

We only observe the (noisy) positions, not the velocities, so each observation is a lower-dimensional 2-D vector. The system of equations is as follows:

$$
\begin{aligned}
x_t^1 &= z_t^1 + \delta_t^1 \\
x_t^2 &= z_t^2 + \delta_t^2 \\
\end{aligned}
$$

This is a very simple model in which each observation is the position read directly from the state, plus some noise.

We can write this in an abbreviated matrix formulation:

$$
\mathbf{x}_t = \mathbf{C}_t \mathbf{z}_t + \boldsymbol{\delta}_t
$$

where:

$$
\mathbf{C}_t = 
\begin{bmatrix}
1 & 0 & 0 & 0 \\
0 & 1 & 0 & 0 \\
\end{bmatrix}, \;\; \mathbf{C}_t \in \mathbb{R}^{2 \times 4}
$$

## Model

* [x] Modeling Noises Only
* [x] Modeling States/Observations
* [ ] Using Conditioning Notation
* [ ] Using Plate Notation

In [None]:
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan
from numpyro import diagnostics, infer

```python
def gaussian_hmm(obs=None, time_steps: int=10):
    """Scalar linear-Gaussian HMM with sampled transition/emission coefficients.

    Latent chain:  z_0 ~ Normal(0, 1),  z_t ~ Normal(trans * z_{t-1}, 1)
    Emissions:     x_t ~ Normal(emit * z_t, 1)

    Args:
        obs: optional observation sequence. When given, its leading dimension
            overrides ``time_steps`` and each step's ``x`` site is conditioned
            on the corresponding entry.
        time_steps: number of steps to unroll when ``obs`` is None.

    Returns:
        Tuple ``(z, x)`` of the latent and emitted sequences stacked by ``scan``.
    """
    # An observed sequence dictates the horizon.
    horizon = time_steps if obs is None else obs.shape[0]

    # Global coefficients of the transition and emission models.
    trans = numpyro.sample("trans", dist.Normal(0, 1))
    emit = numpyro.sample("emi", dist.Normal(0, 1))

    def step(z_last, x_now):
        # `x_now` is the observation for the current step (None when unobserved).
        z_now = numpyro.sample("z", dist.Normal(trans * z_last, 1))
        x_draw = numpyro.sample("x", dist.Normal(emit * z_now, 1), obs=x_now)
        return z_now, (z_now, x_draw)

    # Initial latent state, then unroll the chain.
    z_init = numpyro.sample("z0", dist.Normal(0, 1))
    _, (z, x) = scan(step, z_init, obs, length=horizon)

    return z, x
```

### Model I - Noises Only

In [None]:
# Dimensions: the state is (x, y, vx, vy); only the (x, y) positions are observed.
state_dim = 4
obs_dim = 2
pos_dim = state_dim // 2  # velocities start right after the position block

# =================
# prior over the initial state
# =================
mu0 = jnp.array([8.0, 5.0, 1.0, 0.0])
Sigma0 = 1e-4 * jnp.eye(state_dim)

prior_dist = dist.MultivariateNormal(mu0, Sigma0)

# =================
# transition model
# =================
dt = 0.1
step_std = 0.1

# Constant-velocity A: identity plus dt coupling each position to its velocity.
trans_mat = jnp.eye(state_dim) + dt * jnp.eye(state_dim, k=pos_dim)
trans_noise_mat = step_std**2 * jnp.eye(state_dim)
trans_noise = dist.MultivariateNormal(jnp.zeros(state_dim), trans_noise_mat)

# =================
# emission model
# =================
noise_std = 0.1

# C selects the position components of the state.
emiss_mat = jnp.eye(N=obs_dim, M=state_dim)
emiss_noise_mat = noise_std**2 * jnp.eye(obs_dim)
emiss_noise = dist.MultivariateNormal(jnp.zeros(obs_dim), emiss_noise_mat)

### Numpyro - Simplest

In [None]:
def simulated_kalman_filter_simple(
    trans_mat: jnp.ndarray,
    trans_noise_cov: jnp.ndarray,
    emiss_mat: jnp.ndarray,
    emiss_noise_cov: jnp.ndarray,
    prior_dist,
    time_steps: int = 100,
    x_obs: jnp.ndarray = None,
    x_obs_mask: jnp.ndarray = None,
):
    """Linear-Gaussian state-space model unrolled with a Python loop.

        z_0 ~ prior_dist
        z_t ~ MVN(trans_mat @ z_{t-1}, trans_noise_cov)
        x_t ~ MVN(emiss_mat @ z_t, emiss_noise_cov)

    Each step registers its own sample sites ``z_t{t}`` and ``x_t{t}``.

    Args:
        trans_mat: state-transition matrix.
        trans_noise_cov: covariance of the transition noise.
        emiss_mat: emission (observation) matrix.
        emiss_noise_cov: covariance of the emission noise.
        prior_dist: distribution of the initial state ``z0``.
        time_steps: number of steps to unroll when ``x_obs`` is None.
        x_obs: optional 2-D array of observations. When given, its length
            overrides ``time_steps`` and each ``x_t{t}`` site is conditioned
            on ``x_obs[t]``.
        x_obs_mask: accepted but currently unused.

    Returns:
        ``(zs, xs)``: latent states and emissions stacked along a new leading
        time axis.
    """
    if x_obs is not None:
        # The observed sequence fixes the horizon; unpacking also rejects non-2-D input.
        time_steps, _ = x_obs.shape

    # sample the initial state from the prior
    z = numpyro.sample("z0", prior_dist)

    xs, zs = [], []

    for t in range(time_steps):
        # transition: z_t | z_{t-1}
        z = numpyro.sample(
            f"z_t{t}",
            dist.MultivariateNormal(
                loc=jnp.dot(trans_mat, z), covariance_matrix=trans_noise_cov
            ),
        )

        # emission: x_t | z_t, conditioned on the data when it is provided
        # (previously x_obs was never passed, so conditioning silently did nothing)
        x = numpyro.sample(
            f"x_t{t}",
            dist.MultivariateNormal(
                loc=jnp.dot(emiss_mat, z), covariance_matrix=emiss_noise_cov
            ),
            obs=None if x_obs is None else x_obs[t],
        )

        xs.append(x)
        zs.append(z)

    return jnp.stack(zs, axis=0), jnp.stack(xs, axis=0)

#### Sampling (Unconditional)

In [None]:
time_steps = 80

# Run the model ONCE with the trace handler active, so `params` records exactly
# the draws that produced z_s / x_s. A second get_trace call under the same seed
# handler would consume fresh keys and record a different draw.
with numpyro.handlers.seed(rng_seed=123):
    with numpyro.handlers.trace() as params:
        z_s, x_s = simulated_kalman_filter_simple(
            trans_mat=trans_mat,
            trans_noise_cov=trans_noise_mat,
            emiss_mat=emiss_mat,
            emiss_noise_cov=emiss_noise_mat,
            prior_dist=prior_dist,
            time_steps=time_steps,
            x_obs=None,
        )
print(z_s.shape, x_s.shape)

In [None]:
# names of the sample sites recorded in the trace
list(params.keys())

In [None]:
fig, ax = plt.subplots()

# Latent trajectory (position components) versus the noisy position measurements.
true_xy = z_s[..., :2]
meas_xy = x_s[..., :2]

ax.plot(true_xy[..., 0], true_xy[..., 1], label="True State", color="black")
ax.scatter(meas_xy[..., 0], meas_xy[..., 1], alpha=0.5, color="red", label="Measurements")

ax.set(xlabel="x-position", ylabel="y-position")
ax.legend()
plt.show()

### Numpyro - Scan Function

In [None]:
def simulated_kalman_filter(
    trans_mat: jnp.ndarray,
    trans_noise_cov: jnp.ndarray,
    emiss_mat: jnp.ndarray,
    emiss_noise_cov: jnp.ndarray,
    prior_dist,
    time_steps: int = 100,
    x_obs: jnp.ndarray = None,
    x_obs_mask: jnp.ndarray = None,
):
    """Linear-Gaussian state-space model unrolled with numpyro's ``scan``.

        z_0 ~ prior_dist
        z_t ~ MVN(trans_mat @ z_{t-1}, trans_noise_cov)
        x_t ~ MVN(emiss_mat @ z_t, emiss_noise_cov)

    The per-step draws are stacked under the sample sites ``"z"`` and ``"x"``.

    Args:
        trans_mat: state-transition matrix.
        trans_noise_cov: covariance of the transition noise.
        emiss_mat: emission (observation) matrix.
        emiss_noise_cov: covariance of the emission noise.
        prior_dist: distribution of the initial state ``z0``.
        time_steps: number of steps to unroll when ``x_obs`` is None.
        x_obs: optional 2-D observation array. When given, its length sets the
            horizon and each step's ``x`` is conditioned on the matching row.
        x_obs_mask: accepted but currently unused.

    Returns:
        ``(z, x)``: latent states and emissions stacked along the time axis.
    """
    if x_obs is not None:
        # Unpacking enforces 2-D observations; only the length is needed.
        time_steps, _ = x_obs.shape

    def transition_emission(z_last, x_t):
        # `x_t` is this step's slice of `x_obs` (None when unobserved).
        z_mean = jnp.dot(trans_mat, z_last)
        z_next = numpyro.sample(
            "z",
            dist.MultivariateNormal(loc=z_mean, covariance_matrix=trans_noise_cov),
        )

        x_mean = jnp.dot(emiss_mat, z_next)
        x_next = numpyro.sample(
            "x",
            dist.MultivariateNormal(loc=x_mean, covariance_matrix=emiss_noise_cov),
            obs=x_t,
        )

        # carry the new state forward and emit (state, observation) for stacking
        return z_next, (z_next, x_next)

    # initial state from the prior
    z_init = numpyro.sample("z0", prior_dist)

    _, (z_seq, x_seq) = scan(
        f=transition_emission, init=z_init, xs=x_obs, length=time_steps
    )
    return z_seq, x_seq

#### Faster

The `scan` version is faster, **especially** for longer time series.

In [None]:
%%timeit
# Benchmark the Python-loop model: each of the 1_000 steps registers its own
# "z_t{t}" / "x_t{t}" sample sites.
# NOTE(review): names assigned inside a %%timeit cell are, as far as I know, not
# kept in the notebook namespace afterwards. Confirm before relying on z_s / x_s here.
time_steps = 1_000

with numpyro.handlers.seed(rng_seed=123):
    z_s, x_s = simulated_kalman_filter_simple(
        trans_mat=trans_mat,
        trans_noise_cov=trans_noise_mat,
        emiss_mat=emiss_mat,
        emiss_noise_cov=emiss_noise_mat,
        prior_dist=prior_dist,
        time_steps=time_steps,
        x_obs=None,
    )

In [None]:
%%timeit
# Benchmark the scan-based model on the same 1_000 steps. Here the step body is
# handed to `scan`, and the draws are stacked under single "z" / "x" sites.
time_steps = 1_000

with numpyro.handlers.seed(rng_seed=123):
    z_s_, x_s_ = simulated_kalman_filter(
        trans_mat=trans_mat,
        trans_noise_cov=trans_noise_mat,
        emiss_mat=emiss_mat,
        emiss_noise_cov=emiss_noise_mat,
        prior_dist=prior_dist,
        time_steps=time_steps,
        x_obs=None,
    )

In [None]:
fig, ax = plt.subplots()

# NOTE(review): this plots z_s / x_s. If %%timeit cells do not persist their
# assignments, these still hold the earlier 80-step sample rather than the scan output.
true_xy = z_s[..., :2]
meas_xy = x_s[..., :2]

ax.plot(true_xy[..., 0], true_xy[..., 1], label="True State", color="black")
ax.scatter(meas_xy[..., 0], meas_xy[..., 1], alpha=0.5, color="red", label="Measurements")

ax.set(xlabel="x-position", ylabel="y-position")
ax.legend()
plt.show()

#### Clean Parameters

In [None]:
time_steps = 50

# Trace one draw of the loop-based model to inspect its sample sites:
# "z0" plus one "z_t{t}" / "x_t{t}" pair per step.
with numpyro.handlers.seed(rng_seed=123):
    params = numpyro.handlers.trace(simulated_kalman_filter_simple).get_trace(
        trans_mat=trans_mat,
        trans_noise_cov=trans_noise_mat,
        emiss_mat=emiss_mat,
        emiss_noise_cov=emiss_noise_mat,
        prior_dist=prior_dist,
        time_steps=time_steps,
        x_obs=None,
    )
list(params.keys())

In [None]:
time_steps = 1_000

# Trace one draw of the scan-based model: it exposes only "z0", "z" and "x",
# with the per-step draws of "z"/"x" stacked along a leading time axis.
with numpyro.handlers.seed(rng_seed=123):
    params = numpyro.handlers.trace(simulated_kalman_filter).get_trace(
        trans_mat=trans_mat,
        trans_noise_cov=trans_noise_mat,
        emiss_mat=emiss_mat,
        emiss_noise_cov=emiss_noise_mat,
        prior_dist=prior_dist,
        time_steps=time_steps,
        x_obs=None,
    )
list(params.keys())

In [None]:
# shape of the stacked latent-state draws recorded at the "z" site
jnp.shape(params["z"]["value"])

## Sampling

### Unconditional

In [None]:
num_samples = 5
rng_key_prior = jax.random.PRNGKey(123)

# Draw `num_samples` independent trajectories from the scan-based generative model.
# NOTE: `time_steps` is whatever value the most recently executed cell assigned.
predictive = infer.Predictive(simulated_kalman_filter, num_samples=num_samples)
prior_samples = predictive(
    rng_key_prior,
    x_obs=None,
    time_steps=time_steps,
    prior_dist=prior_dist,
    trans_mat=trans_mat,
    trans_noise_cov=trans_noise_mat,
    emiss_mat=emiss_mat,
    emiss_noise_cov=emiss_noise_mat,
)

In [None]:
# latent trajectories and measurements, one row per prior draw
z_s, x_s = (prior_samples[site] for site in ("z", "x"))

print(z_s.shape, x_s.shape)

In [None]:
fig, ax = plt.subplots()

# Overlay every prior draw: latent path in black, noisy measurements in red.
for i in range(num_samples):
    z_i, x_i = z_s[i], x_s[i]
    ax.plot(z_i[..., 0], z_i[..., 1], label="True State", color="black")
    ax.scatter(x_i[..., 0], x_i[..., 1], alpha=0.5, color="red", label="Measurements")

ax.set(xlabel="x-position", ylabel="y-position")
plt.show()

### Conditional

In [None]:
# Use the first prior draw as the reference ("true") trajectory and its measurements.
true_state, true_obs = prior_samples["z"][0], prior_samples["x"][0]

In [None]:
fig, ax = plt.subplots()

# A single reference trajectory, so draw it once. Looping over num_samples
# only stacked identical copies of the same artists on top of each other.
ax.plot(true_state[..., 0], true_state[..., 1], color="black", label="True State")
ax.scatter(
    true_obs[..., 0], true_obs[..., 1], label="Measurements", color="red", alpha=0.5
)

ax.set(xlabel="x-position", ylabel="y-position")
ax.legend()
plt.show()

In [None]:
# create posterior samples
posterior_samples = {
    "z": true_state[None, ...],
    "x": true_obs[None, ...],
}

In [None]:
# Replay the model with the "z" / "x" sites fixed to the reference trajectory built
# in `posterior_samples`; previously that dict was ignored and `prior_samples` was
# passed instead. Predictive runs once per leading-axis entry of the supplied
# samples and overrides a mismatching `num_samples` (with a warning), so
# num_samples is taken from that batch size.
num_samples = posterior_samples["x"].shape[0]

predictive = infer.Predictive(
    simulated_kalman_filter,
    posterior_samples=posterior_samples,
    num_samples=num_samples,
    return_sites=["z", "x"],
)
posterior_predictive = predictive(
    rng_key_prior,
    trans_mat=trans_mat,
    trans_noise_cov=trans_noise_mat,
    emiss_mat=emiss_mat,
    emiss_noise_cov=emiss_noise_mat,
    prior_dist=prior_dist,
    time_steps=time_steps,
)

In [None]:
# shapes of the replayed measurements vs. the conditioning measurements
(jnp.shape(posterior_predictive["x"]), jnp.shape(posterior_samples["x"]))

In [None]:
# latent trajectories and measurements returned by the conditional Predictive
z_s_, x_s_ = (posterior_predictive[site] for site in ("z", "x"))

print(z_s_.shape, x_s_.shape)

In [None]:
# shape of the replayed latent trajectories
jnp.shape(z_s_)

In [None]:
fig, ax = plt.subplots()

# Iterate over the draws actually returned. `num_samples` can disagree with the
# Predictive batch size, and an out-of-range JAX index is clamped silently
# (repeating the last draw) rather than raising.
n_draws = z_s_.shape[0]
for i in range(n_draws):
    ax.plot(z_s_[i, ..., 0], z_s_[i, ..., 1], color="black", label="True State")
    ax.scatter(
        x_s_[i, ..., 0], x_s_[i, ..., 1], label="Measurements", color="red", alpha=0.5
    )

ax.set(xlabel="x-position", ylabel="y-position")
plt.show()