# Demo - KF with Unknown Parameters

In [None]:
import sys, os
from pyprojroot import here

# walk up the directory tree to find the project root (marked by ".home")
root = here(project_files=[".home"])

# append to path
sys.path.append(str(root))

%load_ext autoreload
%autoreload 2

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
from typing import NamedTuple
from jax.random import multivariate_normal, split
from tqdm.notebook import tqdm, trange
from jax.random import multivariate_normal
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions


import matplotlib.pyplot as plt

## Simulating Data

### State Transition Dynamics

We assume that we can fully describe the state when we have the `(x,y)` coordinates of the position and the `(x,y)` velocity. So we can write this as:

$$
\mathbf{z}_t = 
\begin{bmatrix}
z_t^1 \\ z_t^2 \\ \dot{z}_t^1 \\ \dot{z}_t^2
\end{bmatrix}
$$


where $z_t^d$ is the position along coordinate $d$ and $\dot{z}_t^d$ is the corresponding velocity.

We can describe the dynamics of the system using the following system of equations:

$$
\begin{aligned}
z_t^1 &= z_{t-1}^1 + \Delta_t \dot{z}_{t-1}^1 + \epsilon_t^1 \\
z_t^2 &= z_{t-1}^2 + \Delta_t \dot{z}_{t-1}^2 + \epsilon_t^2 \\
\dot{z}_t^1 &= \dot{z}_{t-1}^1 + \epsilon_t^3 \\
\dot{z}_t^2 &= \dot{z}_{t-1}^2 + \epsilon_t^4 \\
\end{aligned}
$$

This is a very simple formulation: it uses a first-order approximation of the change in position based on the velocity, and it assumes constant velocity. We also include noise terms because part of the dynamics is not modeled, i.e. there are random accelerations and position changes.


We can also put this into matrix formulation like so:

$$
\mathbf{z}_t = \mathbf{A}_t \mathbf{z}_{t-1} + \boldsymbol{\epsilon}_t
$$

where:

$$
\mathbf{A}_t = 
\begin{bmatrix}
1 & 0 & \Delta_t & 0 \\
0 & 1 & 0 & \Delta_t \\
0 & 0 & 1 & 0 \\
0 & 0 & 0 & 1 \\
\end{bmatrix}, \;\; \mathbf{A}_t \in \mathbb{R}^{4\times 4}
$$


---
### Emissions Model

We can only observe the positions (not the velocities), and only with noise, so each observation is a lower-dimensional, 2-D vector. The system of equations is as follows:

$$
\begin{aligned}
x_t^1 &= z_t^1 + \delta_t^1 \\
x_t^2 &= z_t^2 + \delta_t^2 \\
\end{aligned}
$$

This is a very simple model in which each measurement is the position part of the state plus some noise.

We can write this in an abbreviated matrix formulation:

$$
\mathbf{x}_t = \mathbf{C}_t \mathbf{z}_t + \boldsymbol{\delta}_t
$$

where:

$$
\mathbf{C}_t = 
\begin{bmatrix}
1 & 0 & 0 & 0 \\
0 & 1 & 0 & 0 \\
\end{bmatrix}, \;\; \mathbf{C}_t \in \mathbb{R}^{2 \times 4}
$$

## Model

* [x] Modeling Noises Only
* [x] Modeling States/Observations
* [ ] Using Conditioning Notation
* [ ] Using Plate Notation

In [None]:
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan
from numpyro import diagnostics, infer

```python
def gaussian_hmm(obs=None, time_steps: int=10):
    """Scalar linear-Gaussian HMM with unknown transition/emission coefficients.

    Args:
        obs: optional observed sequence. When given, its leading dimension
            replaces ``time_steps`` and the ``x`` sites are observed.
        time_steps: number of steps to generate when ``obs`` is None.

    Returns:
        Tuple ``(z, x)`` of latent states and emissions stacked over time.
    """
    # The data, when present, decides the sequence length.
    n_steps = time_steps if obs is None else obs.shape[0]

    # Unknown scalar coefficients, each with a standard-normal prior.
    trans = numpyro.sample("trans", dist.Normal(0, 1))
    emit = numpyro.sample("emi", dist.Normal(0, 1))

    def step(z_last, x_t):
        # AR(1)-style latent transition, then a noisy scaled emission.
        z_t = numpyro.sample("z", dist.Normal(trans * z_last, 1))
        x_draw = numpyro.sample("x", dist.Normal(emit * z_t, 1), obs=x_t)
        return z_t, (z_t, x_draw)

    # Initial latent state.
    z_init = numpyro.sample("z0", dist.Normal(0, 1))

    _, (zs, xs) = scan(step, z_init, obs, length=n_steps)
    return (zs, xs)
```

## Parameters

In [None]:
# init prior dist
# Sizes: state = (pos_1, pos_2, vel_1, vel_2); observation = (pos_1, pos_2).
state_dim = 4
obs_dim = 2

# ================================
# prior over the initial state
# ================================
mu0 = jnp.array([8.0, 5.0, 1.0, 0.0])
Sigma0 = 1e-4 * jnp.eye(state_dim)
prior_dist = dist.MultivariateNormal(mu0, Sigma0)

# ================================
# transition model (constant velocity)
# ================================
dt = 0.1
step_std = 0.1

# Identity plus dt on the offset-2 diagonal: pos_d <- pos_d + dt * vel_d.
trans_mat = jnp.eye(state_dim) + dt * jnp.eye(state_dim, k=2)
trans_noise_param = step_std**2  # variance, not std
trans_noise_mat = trans_noise_param * jnp.eye(state_dim)
trans_noise = dist.MultivariateNormal(jnp.zeros(state_dim), trans_noise_mat)

# ================================
# emission model (positions only)
# ================================
noise_std = 0.1

emiss_mat = jnp.eye(obs_dim, state_dim)
emiss_noise_param = noise_std**2  # variance, not std
emiss_noise_mat = emiss_noise_param * jnp.eye(obs_dim)
emiss_noise = dist.MultivariateNormal(jnp.zeros(obs_dim), emiss_noise_mat)

In [None]:
# Display the prior mean (the simulators pin z0 to this value).
prior_dist.mean

## Simulations

In [None]:
def true_simulated_kalman_filter(
    trans_mat: jnp.ndarray,
    trans_noise_cov: jnp.ndarray,
    emiss_mat: jnp.ndarray,
    emiss_noise_cov: jnp.ndarray,
    prior_dist,
    time_steps: int = 100,
    x_obs: jnp.ndarray = None,
    x_obs_mask: jnp.ndarray = None,
):
    """Linear-Gaussian state-space model with every parameter known.

    The initial state is not sampled. It is fixed to ``prior_dist.mean`` and
    recorded as the deterministic site ``z0``. ``scan`` then generates the
    latent transitions (site ``z``) and emissions (site ``x``).

    Args:
        trans_mat: state transition matrix A.
        trans_noise_cov: covariance of the transition noise.
        emiss_mat: emission matrix C.
        emiss_noise_cov: covariance of the emission noise.
        prior_dist: distribution whose ``mean`` is the initial state.
        time_steps: sequence length, used when ``x_obs`` is None.
        x_obs: optional 2-D observations with time on the first axis. When
            given, they condition the ``x`` sites and set the length.
        x_obs_mask: accepted but currently unused.

    Returns:
        ``(z, x)``: latent states and emissions stacked along time.
    """
    if x_obs is not None:
        # Two-way unpacking also requires x_obs to be 2-D.
        time_steps, _ = x_obs.shape

    # Pin the initial state to the prior mean instead of sampling it.
    z_start = numpyro.deterministic("z0", prior_dist.mean)

    def transition_and_emit(z_prev, x_t):
        z_loc = jnp.dot(trans_mat, z_prev)
        z_t = numpyro.sample(
            "z",
            dist.MultivariateNormal(loc=z_loc, covariance_matrix=trans_noise_cov),
        )

        x_loc = jnp.dot(emiss_mat, z_t)
        x_draw = numpyro.sample(
            "x",
            dist.MultivariateNormal(loc=x_loc, covariance_matrix=emiss_noise_cov),
            obs=x_t,
        )
        return z_t, (z_t, x_draw)

    _, (z, x) = scan(
        f=transition_and_emit, init=z_start, xs=x_obs, length=time_steps
    )
    return z, x

#### Sampling (Unconditional)

In [None]:
time_steps = 80

# Draw one ground-truth trajectory (z_true) and its noisy measurements (x_true)
# from the fully specified model. x_obs=None means nothing is observed.
with numpyro.handlers.seed(rng_seed=123):
    z_true, x_true = true_simulated_kalman_filter(
        prior_dist=prior_dist,
        trans_mat=trans_mat,
        trans_noise_cov=trans_noise_mat,
        emiss_mat=emiss_mat,
        emiss_noise_cov=emiss_noise_mat,
        time_steps=time_steps,
        x_obs=None,
    )

In [None]:
fig, ax = plt.subplots()

# Latent positions as a line, noisy measurements as points.
ax.plot(z_true[..., 0], z_true[..., 1], color="black", label="True State")
ax.scatter(
    x_true[..., 0], x_true[..., 1], label="Measurements", color="red", alpha=0.5
)

ax.set(xlabel="x-position", ylabel="y-position")
ax.legend()
plt.show()

## Model (Unknown Parameters)

In [None]:
# JITTER = 1e-5

# def kalman_filter(
#     trans_mat: jnp.ndarray,
#     emiss_mat: jnp.ndarray,
#     prior_dist,
#     time_steps: int=100,
#     x_obs: jnp.ndarray=None,
# ):
#     if x_obs is not None:
#         time_steps, n_dims = x_obs.shape

#     # noise parameters
#     trans_noise = numpyro.param("trans_noise", init_value=0.1 * jnp.ones(4), constraint=dist.constraints.positive)

#     emiss_noise = numpyro.param("emiss_noise", init_value=0.1 * jnp.ones(2), constraint=dist.constraints.positive)

#     trans_noise_cov = jnp.diag(JITTER + trans_noise)
#     emiss_noise_cov = jnp.diag(JITTER + emiss_noise)


#     # ==================
#     # sample from prior
#     # ==================
#     z0 = numpyro.deterministic("z0", prior_dist.mean)#numpyro.sample("z0", prior_dist)

#     # Model
#     def body(z_prev, x_prev):

#         # transition
#         z = numpyro.sample("z", dist.MultivariateNormal(loc=jnp.dot(trans_mat, z_prev), covariance_matrix=trans_noise_cov))

#         # sample noise
#         x = numpyro.sample("x", dist.MultivariateNormal(loc=jnp.dot(emiss_mat, z), covariance_matrix=emiss_noise_cov), obs=x_prev)

#         return z, (z, x)

#     _, (z, x) = scan(f=body, init=(z0), xs=x_obs, length=time_steps)


#     return z, x

In [None]:
JITTER = 1e-5


def kalman_filter(
    trans_mat: jnp.ndarray,
    emiss_mat: jnp.ndarray,
    prior_dist,
    time_steps: int = 80,
    x_obs: jnp.ndarray = None,
):
    """Linear-Gaussian state-space model with learnable isotropic noise scales.

    The transition and emission noise standard deviations are the
    ``numpyro.param`` sites ``trans_noise`` and ``emiss_noise``, both
    constrained positive and initialised to 1.0. The covariances are
    ``(std**2 + JITTER) * I``.

    Args:
        trans_mat: square state transition matrix A. Its size sets the state
            dimension.
        emiss_mat: emission matrix C. Its number of rows sets the observation
            dimension.
        prior_dist: distribution whose ``mean`` is used as the fixed initial
            state (deterministic site ``z0``).
        time_steps: sequence length, used when ``x_obs`` is None.
        x_obs: optional 2-D observations with time on the first axis. When
            given, the ``x`` sites are conditioned on them and their length
            overrides ``time_steps``.

    Returns:
        ``(z, x)``: latent states and emissions stacked along time.
    """
    if x_obs is not None:
        time_steps, _ = x_obs.shape

    state_dim = trans_mat.shape[0]
    obs_dim = emiss_mat.shape[0]

    # noise parameters (standard deviations; squared below)
    trans_noise = numpyro.param(
        "trans_noise", init_value=1.0, constraint=dist.constraints.positive
    )
    emiss_noise = numpyro.param(
        "emiss_noise", init_value=1.0, constraint=dist.constraints.positive
    )

    # Put JITTER on the diagonal only. Adding it to the whole matrix would
    # also add it to every off-diagonal entry.
    trans_noise_cov = (trans_noise**2 + JITTER) * jnp.eye(state_dim)
    emiss_noise_cov = (emiss_noise**2 + JITTER) * jnp.eye(obs_dim)

    # ==================
    # initial state (fixed to the prior mean, not sampled)
    # ==================
    z0 = numpyro.deterministic("z0", prior_dist.mean)

    def body(z_prev, _):
        # transition
        z = numpyro.sample(
            "z",
            dist.MultivariateNormal(
                loc=jnp.dot(trans_mat, z_prev), covariance_matrix=trans_noise_cov
            ),
        )

        # emission
        x = numpyro.sample(
            "x",
            dist.MultivariateNormal(
                loc=jnp.dot(emiss_mat, z), covariance_matrix=emiss_noise_cov
            ),
        )

        return z, (z, x)

    # With x_obs=None the x sites are sampled rather than observed.
    with numpyro.handlers.condition(data={"x": x_obs}):
        _, (z, x) = scan(f=body, init=z0, xs=None, length=time_steps)

    return z, x

In [None]:
time_steps = 80

# One seeded run of the learnable-noise model with x conditioned on the
# simulated measurements. The noise params keep their init values here.
with numpyro.handlers.seed(rng_seed=42):
    z_sim, x_sim = kalman_filter(
        prior_dist=prior_dist,
        trans_mat=trans_mat,
        emiss_mat=emiss_mat,
        time_steps=time_steps,
        x_obs=x_true,
    )

print(z_sim.shape, x_sim.shape)

In [None]:
fig, ax = plt.subplots()

# Truth (solid), the seeded model run (dashed), measurements (points).
ax.plot(z_true[..., 0], z_true[..., 1], color="black", label="True State")
ax.plot(
    z_sim[..., 0],
    z_sim[..., 1],
    color="blue",
    label="Predicted State",
    linestyle="--",
)
ax.scatter(
    x_true[..., 0], x_true[..., 1], label="Measurements", color="red", alpha=0.5
)

ax.set(xlabel="x-position", ylabel="y-position")
ax.legend()
plt.show()

### Samples (Prior)

In [None]:
num_samples = 5
rng_key_prior = jax.random.PRNGKey(123)

# Prior predictive: nothing is observed (x_obs=None), so z and x are both sampled.
predictive = infer.Predictive(kalman_filter, num_samples=num_samples)
prior_samples = predictive(
    rng_key_prior,
    prior_dist=prior_dist,
    trans_mat=trans_mat,
    emiss_mat=emiss_mat,
    time_steps=time_steps,
    x_obs=None,
)

In [None]:
# Latent states and emissions, each with a leading sample axis.
z_sim, x_sim = prior_samples["z"], prior_samples["x"]

In [None]:
fig, ax = plt.subplots()

# One latent trajectory (line) and its measurements (points) per prior draw.
for i in range(num_samples):
    ax.plot(z_sim[i, ..., 0], z_sim[i, ..., 1], color="black", label="True State")
    ax.scatter(
        x_sim[i, ..., 0],
        x_sim[i, ..., 1],
        label="Measurements",
        color="red",
        alpha=0.5,
    )

ax.set(xlabel="x-position", ylabel="y-position")
plt.show()

### Samples (Posterior)

In [None]:
num_samples = 5
rng_key_prior = jax.random.PRNGKey(123)

# NOTE(review): posterior_samples=prior_samples replays the prior draws, so
# z here comes from the prior rather than a fitted posterior. Confirm intent.
predictive = infer.Predictive(
    kalman_filter,
    posterior_samples=prior_samples,
    num_samples=num_samples,
    return_sites=["z", "x"],
)
predictive_posterior = predictive(
    rng_key_prior,
    prior_dist=prior_dist,
    trans_mat=trans_mat,
    emiss_mat=emiss_mat,
    time_steps=time_steps,
    x_obs=x_true,
)

In [None]:
# Pull the requested return sites out of the predictive dict.
z_pred, x_pred = predictive_posterior["z"], predictive_posterior["x"]

In [None]:
fig, ax = plt.subplots()

# One dashed latent position trajectory per sample.
for i in range(num_samples):
    ax.plot(
        z_pred[i, ..., 0], z_pred[i, ..., 1],
        color="black", label="Predicted State", linestyle="--",
    )

ax.set(xlabel="x-position", ylabel="y-position")
plt.show()

## Inference

In [None]:
def init_kf_model(
    trans_mat: jnp.ndarray,
    emiss_mat: jnp.ndarray,
    prior_dist,
    time_steps: int = 80,
):
    """Bind the fixed model pieces and return a one-argument model.

    The returned callable takes only the observations (or None) and forwards
    them, with the bound arguments, to ``kalman_filter`` as ``x_obs``.
    """
    def _model(x):
        # x is passed through as x_obs; None gives an unconditional run.
        return kalman_filter(
            trans_mat=trans_mat,
            emiss_mat=emiss_mat,
            prior_dist=prior_dist,
            time_steps=time_steps,
            x_obs=x,
        )

    return _model

In [None]:
%%time
from numpyro import diagnostics, infer, optim
from numpyro.infer.autoguide import AutoDelta

# optimizers
rng_key_infer = jax.random.PRNGKey(666)
lr = 1e-2
adam = optim.Adam(lr)

kf_model = init_kf_model(trans_mat, emiss_mat, prior_dist, time_steps)

guide = AutoDelta(kf_model)
# def guide(x, time_steps=30):
#     return None

n_epochs = 100

# Inference
svi = infer.SVI(kf_model, guide, adam, infer.Trace_ELBO())
# svi_result = svi.run(rng_key_infer, n_epochs, x)

# svi_result.params

In [None]:
# Fit on the noisy measurements. The noise-free latent positions
# (z_true[..., :2]) contain no emission noise and would bias emiss_noise to ~0.
svi_result = svi.run(rng_key_infer, n_epochs, x_true)

In [None]:
# Training curve: one SVI loss value per optimization step.
fig, ax = plt.subplots()
ax.plot(svi_result.losses)
plt.show()

In [None]:
# The model squares emiss_noise into the covariance, so it is a standard
# deviation: compare it with the true std, not the variance emiss_noise_param.
svi_result.params["emiss_noise"], noise_std

In [None]:
# The model squares trans_noise into the covariance, so it is a standard
# deviation: compare it with the true std, not the variance trans_noise_param.
svi_result.params["trans_noise"], step_std

In [None]:
rng_key_posterior = jax.random.PRNGKey(777)

# Posterior prediction: z is drawn from the fitted guide rather than the prior,
# with x conditioned on the measurements. "z" is a guide site, so it must be
# requested explicitly or Predictive would omit it from the output.
predictive = infer.Predictive(
    kf_model,
    guide=guide,
    params=svi_result.params,
    num_samples=10,
    return_sites=["z", "x"],
)
posterior_predictive = predictive(rng_key_posterior, x_true)

In [None]:
# Latent states and emissions, each with a leading sample axis.
z_pred, x_pred = posterior_predictive["z"], posterior_predictive["x"]

In [None]:
# Shapes: emissions first, then latent states.
(x_pred.shape, z_pred.shape)

In [None]:
z_pred = posterior_predictive["z"]
# 5% / 50% / 95% quantiles of the latent state across samples (axis 0).
# x_pred is the observed data, so summarizing it would just reproduce x_true.
z_lb, z_mu, z_ub = jnp.quantile(z_pred, jnp.array([0.05, 0.5, 0.95]), axis=0)

In [None]:
fig, ax = plt.subplots()

# Truth (solid) vs. the 50%-quantile trajectory z_mu (dashed), plus measurements.
ax.plot(z_true[..., 0], z_true[..., 1], color="black", label="True State")
ax.plot(
    z_mu[..., 0],
    z_mu[..., 1],
    color="blue",
    label="Predicted State",
    linestyle="--",
)
ax.scatter(
    x_true[..., 0], x_true[..., 1], label="Measurements", color="red", alpha=0.5
)

ax.set(xlabel="x-position", ylabel="y-position")
plt.show()

In [None]:
# Compare the truth's shape with the summarized prediction's shape.
(z_true.shape, z_mu.shape)

In [None]:
fig, ax = plt.subplots()

# z_pred carries a leading sample axis (samples, time, state). Indexing it
# directly gives 2-D arrays, and matplotlib would then draw one line per time
# step. Summarize across samples first to get a single velocity trajectory.
z_vel_mu = jnp.median(z_pred, axis=0)

ax.plot(z_true[..., 2], z_true[..., 3], color="black", label="True Velocity")
ax.plot(
    z_vel_mu[..., 2],
    z_vel_mu[..., 3],
    color="blue",
    label="Predicted Velocity",
    linestyle="--",
)

ax.set(xlabel="x-velocity", ylabel="y-velocity")
plt.show()

### Samples

In [None]:
rng_key_posterior = jax.random.PRNGKey(777)

# Unconditional simulation (x=None) with the learned noise parameters plugged in.
predictive = infer.Predictive(
    kf_model,
    params=svi_result.params,
    num_samples=5,
)
posterior_predictive = predictive(rng_key_posterior, None)

In [None]:
# Simulated latent states and emissions, each with a leading sample axis.
z_sim_post, x_sim_post = posterior_predictive["z"], posterior_predictive["x"]

In [None]:
fig, ax = plt.subplots()

# Iterate over the samples actually drawn (leading axis of z_sim_post) rather
# than the `num_samples` global, which was set in an unrelated earlier cell.
for z_sample in z_sim_post:
    ax.plot(
        z_sample[..., 0],
        z_sample[..., 1],
        color="black",
        label="Predicted State",
        linestyle="--",
    )

ax.set(xlabel="x-position", ylabel="y-position")
plt.show()