# Demo - KF with Unknown Parameters

In [None]:
import sys, os
from pyprojroot import here

# spyder up to find the root
root = here(project_files=[".home"])

# append to path
sys.path.append(str(root))

%load_ext autoreload
%autoreload 2

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
from typing import NamedTuple
from jax.random import multivariate_normal, split
from tqdm.notebook import tqdm, trange
from jax.random import multivariate_normal
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions


import matplotlib.pyplot as plt

## Simulating Data

### State Transition Dynamics

We assume that the state is fully described by the `(x,y)` coordinates of the position and the `(x,y)` velocity, so we can write it as:

$$
\mathbf{z}_t = 
\begin{bmatrix}
z_t^1 \\ z_t^2 \\ \dot{z}_t^1 \\ \dot{z}_t^2
\end{bmatrix}
$$


where $z_t^d$ is the $d$-th position coordinate and $\dot{z}_t^d$ is the corresponding velocity.

We can describe the dynamics of the system using the following system of equations:

$$
\begin{aligned}
z_t^1 &= z_{t-1}^1 + \Delta_t \dot{z}_{t-1}^1 + \epsilon_t^1 \\
z_t^2 &= z_{t-1}^2 + \Delta_t \dot{z}_{t-1}^2 + \epsilon_t^2 \\
\dot{z}_t^1 &= \dot{z}_{t-1}^1 + \epsilon_t^3 \\
\dot{z}_t^2 &= \dot{z}_{t-1}^2 + \epsilon_t^4 \\
\end{aligned}
$$

This is a very simple formulation. The change in position is a first-order approximation based on the velocity, and the velocity is assumed to be constant apart from noise. We include noise terms because some of the dynamics are random, i.e. the model allows for random accelerations and position changes.


We can also put this into matrix formulation like so:

$$
\mathbf{z}_t = \mathbf{A}_t \mathbf{z}_{t-1} + \boldsymbol{\epsilon}_t
$$

where:

$$
\mathbf{A}_t = 
\begin{bmatrix}
1 & 0 & \Delta_t & 0 \\
0 & 1 & 0 & \Delta_t \\
0 & 0 & 1 & 0 \\
0 & 0 & 0 & 1 \\
\end{bmatrix}, \;\; \mathbf{A}_t \in \mathbb{R}^{4\times 4}
$$


---
### Emissions Model

We can only observe the positions (not the velocities), so each observation is a lower-dimensional, 2-D vector. The system of equations is as follows:

$$
\begin{aligned}
x_t^1 &= z_t^1 + \delta_t^1 \\
x_t^2 &= z_t^2 + \delta_t^2 \\
\end{aligned}
$$

This is a very simple model where we assume that each measurement is the position part of the state plus some noise.

We can write this in an abbreviated matrix formulation:

$$
\mathbf{x}_t = \mathbf{C}_t \mathbf{z}_t + \boldsymbol{\delta}_t
$$

where:

$$
\mathbf{C}_t = 
\begin{bmatrix}
1 & 0 & 0 & 0 \\
0 & 1 & 0 & 0 \\
\end{bmatrix}, \;\; \mathbf{C}_t \in \mathbb{R}^{2 \times 4}
$$

## Model

* [x] Modeling Noises Only
* [x] Modeling States/Observations
* [ ] Using Conditioning Notation
* [ ] Using Plate Notation

In [None]:
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan
from numpyro import diagnostics, infer

```python
def gaussian_hmm(obs=None, time_steps: int=10):
    """Scalar Gaussian HMM whose transition and emission coefficients are latent.

    Generative story:
        trans, emi ~ N(0, 1)
        z_0        ~ N(0, 1)
        z_t        ~ N(trans * z_{t-1}, 1)
        x_t        ~ N(emi * z_t, 1)   (observed when ``obs`` is given)

    Args:
        obs: optional observed sequence; its leading axis is scanned over as
            time and, when given, its length overrides ``time_steps``.
        time_steps: number of steps to generate when ``obs`` is None.

    Returns:
        Tuple ``(z, x)`` of latent states and emissions stacked by ``scan``.
    """
    n_steps = time_steps if obs is None else obs.shape[0]

    # Global parameters (sampled in the same order as before: trans, emi, z0).
    trans = numpyro.sample("trans", dist.Normal(0, 1))
    emit = numpyro.sample("emi", dist.Normal(0, 1))
    z_init = numpyro.sample("z0", dist.Normal(0, 1))

    def step(z_prev, x_t):
        z_t = numpyro.sample("z", dist.Normal(trans * z_prev, 1))
        x_t = numpyro.sample("x", dist.Normal(emit * z_t, 1), obs=x_t)
        return z_t, (z_t, x_t)

    _, (z, x) = scan(step, z_init, obs, length=n_steps)
    return z, x
```

## Parameters

In [None]:
# init prior dist
mu0 = jnp.array([8.0, 5.0, 1.0, 0.0])
Sigma0 = 1e-4 * jnp.eye(4)

prior_dist = dist.MultivariateNormal(mu0, Sigma0)

# =================
# transition model
# =================
state_dim = 4
dt = 0.1
step_std = 0.1

trans_mat = jnp.eye(4) + dt * jnp.eye(4, k=2)
trans_noise_param = step_std**2
trans_noise_mat = trans_noise_param * jnp.eye(state_dim)
trans_noise = dist.MultivariateNormal(jnp.zeros(state_dim), trans_noise_mat)

# =================
# emission model
# =================
noise_std = 0.1
obs_dim = 2

emiss_mat = jnp.eye(N=2, M=4)
emiss_noise_param = noise_std**2
emiss_noise_mat = emiss_noise_param * jnp.eye(obs_dim)
emiss_noise = dist.MultivariateNormal(jnp.zeros(obs_dim), emiss_noise_mat)

## Simulations

In [None]:
def true_simulated_kalman_filter(
    trans_mat: jnp.ndarray,
    trans_noise_cov: jnp.ndarray,
    emiss_mat: jnp.ndarray,
    emiss_noise_cov: jnp.ndarray,
    prior_dist,
    time_steps: int = 100,
    x_obs: jnp.ndarray = None,
    x_obs_mask: jnp.ndarray = None,
):
    """Simulate (or condition) a linear-Gaussian state-space model.

    The initial state is pinned to ``prior_dist.mean`` and recorded as the
    deterministic site ``"z0"``; it is not sampled. Each scan step then draws
    ``z_t ~ N(trans_mat @ z_{t-1}, trans_noise_cov)`` and
    ``x_t ~ N(emiss_mat @ z_t, emiss_noise_cov)``. When ``x_obs`` is given,
    ``x_t`` is observed.

    Args:
        trans_mat: state transition matrix.
        trans_noise_cov: transition noise covariance.
        emiss_mat: emission matrix.
        emiss_noise_cov: emission noise covariance.
        prior_dist: distribution whose ``.mean`` is used as the initial state.
        time_steps: sequence length when ``x_obs`` is None.
        x_obs: optional 2-D observations; its first axis is time and
            overrides ``time_steps``.
        x_obs_mask: accepted for signature compatibility; not used.

    Returns:
        ``(z, x)`` stacked along a leading time axis by ``scan``.
    """
    if x_obs is not None:
        # Unpacking (rather than indexing) also rejects non-2-D observations.
        time_steps, _ = x_obs.shape

    initial_state = numpyro.deterministic("z0", prior_dist.mean)

    def transition_and_emit(state, observed):
        state_loc = jnp.dot(trans_mat, state)
        next_state = numpyro.sample(
            "z",
            dist.MultivariateNormal(loc=state_loc, covariance_matrix=trans_noise_cov),
        )

        obs_loc = jnp.dot(emiss_mat, next_state)
        emitted = numpyro.sample(
            "x",
            dist.MultivariateNormal(loc=obs_loc, covariance_matrix=emiss_noise_cov),
            obs=observed,
        )
        return next_state, (next_state, emitted)

    _, (z, x) = scan(
        f=transition_and_emit, init=initial_state, xs=x_obs, length=time_steps
    )
    return z, x

In [None]:
time_steps = 80

with numpyro.handlers.seed(rng_seed=123):
    z_true, x_true = true_simulated_kalman_filter(
        trans_mat=trans_mat,
        trans_noise_cov=trans_noise_mat,
        emiss_mat=emiss_mat,
        emiss_noise_cov=emiss_noise_mat,
        prior_dist=prior_dist,
        time_steps=time_steps,
        x_obs=None,
    )

In [None]:
fig, ax = plt.subplots()

ax.plot(z_true[..., 0], z_true[..., 1], color="black", label="True State")
# ax.plot(x_s[..., 0], x_s[..., 1], color="tab:red", linestyle="--", label="Noisy Latent")
ax.scatter(x_true[..., 0], x_true[..., 1], label="Measurements", color="red", alpha=0.5)

ax.set(xlabel="x-position", ylabel="y-position")
plt.legend()
plt.show()

## Model (Unknown Parameters)

In [None]:
JITTER = 1e-5


def kalman_filter(
    trans_mat: jnp.ndarray,
    emiss_mat: jnp.ndarray,
    prior_dist,
    time_steps: int = 80,
    x_obs: jnp.ndarray = None,
):
    """Linear-Gaussian state-space model with unknown isotropic noise scales.

    The transition and emission noise standard deviations get
    ``HalfCauchy()`` priors (sites ``"trans_noise"`` and ``"emiss_noise"``).
    The covariances are ``(s**2 + JITTER) * I``. Each scan step draws
    ``z_t ~ N(trans_mat @ z_{t-1}, trans_cov)`` and
    ``x_t ~ N(emiss_mat @ z_t, emiss_cov)``. The ``"x"`` site is conditioned
    on ``x_obs`` when it is given.

    Args:
        trans_mat: square state-transition matrix; its size sets the state
            dimension.
        emiss_mat: emission matrix; its row count sets the observation
            dimension.
        prior_dist: accepted for interface compatibility.
            NOTE(review): it is not used; ``z0`` gets an independent
            standard-normal prior instead.
        time_steps: sequence length when ``x_obs`` is None.
        x_obs: optional 2-D observations; its first axis is time and
            overrides ``time_steps``.

    Returns:
        ``(z, x)`` stacked along a leading time axis by ``scan``.
    """
    if x_obs is not None:
        time_steps, _ = x_obs.shape

    state_dim = trans_mat.shape[0]
    obs_dim = emiss_mat.shape[0]

    # noise parameters
    trans_noise = numpyro.sample("trans_noise", dist.HalfCauchy())
    emiss_noise = numpyro.sample("emiss_noise", dist.HalfCauchy())

    # The jitter belongs on the diagonal only. Adding the scalar to
    # ``s**2 * I`` would put JITTER into every off-diagonal entry as well.
    trans_noise_cov = (trans_noise**2 + JITTER) * jnp.eye(state_dim)
    emiss_noise_cov = (emiss_noise**2 + JITTER) * jnp.eye(obs_dim)

    # ==================
    # sample from prior
    # ==================
    z0 = numpyro.sample("z0", dist.Normal(), sample_shape=(state_dim,))

    def body(z_prev, _):
        # transition
        z = numpyro.sample(
            "z",
            dist.MultivariateNormal(
                loc=jnp.dot(trans_mat, z_prev), covariance_matrix=trans_noise_cov
            ),
        )

        # emission (observed through the ``condition`` handler below)
        x = numpyro.sample(
            "x",
            dist.MultivariateNormal(
                loc=jnp.dot(emiss_mat, z), covariance_matrix=emiss_noise_cov
            ),
        )

        return z, (z, x)

    # Conditioning on ``None`` leaves "x" unobserved, so this also serves as
    # the prior-predictive model.
    with numpyro.handlers.condition(data={"x": x_obs}):
        _, (z, x) = scan(f=body, init=z0, xs=None, length=time_steps)

    return z, x

In [None]:
time_steps = 80

with numpyro.handlers.seed(rng_seed=42):
    z_sim, x_sim = kalman_filter(
        trans_mat=trans_mat,
        emiss_mat=emiss_mat,
        prior_dist=prior_dist,
        time_steps=time_steps,
        x_obs=x_true,
    )

print(z_sim.shape, x_sim.shape)

In [None]:
fig, ax = plt.subplots()

ax.plot(z_true[..., 0], z_true[..., 1], color="black", label="True State")
ax.plot(
    z_sim[..., 0], z_sim[..., 1], color="blue", label="Predicted State", linestyle="--"
)
# ax.plot(x_s[..., 0], x_s[..., 1], color="tab:red", linestyle="--", label="Noisy Latent")
ax.scatter(x_true[..., 0], x_true[..., 1], label="Measurements", color="red", alpha=0.5)

ax.set(xlabel="x-position", ylabel="y-position")
plt.legend()
plt.show()

### Samples (Prior)

In [None]:
num_samples = 5
rng_key_prior = jax.random.PRNGKey(123)

# prior
predictive = infer.Predictive(kalman_filter, num_samples=num_samples)
prior_samples = predictive(
    rng_key_prior,
    trans_mat=trans_mat,
    emiss_mat=emiss_mat,
    prior_dist=prior_dist,
    time_steps=time_steps,
    x_obs=None,
)

In [None]:
z_sim = prior_samples["z"]
x_sim = prior_samples["x"]

In [None]:
fig, ax = plt.subplots()

for i in range(num_samples):
    ax.plot(z_sim[i, ..., 0], z_sim[i, ..., 1], color="black", label="True State")
    # ax.plot(x_s[..., 0], x_s[..., 1], color="tab:red", linestyle="--", label="Noisy Latent")
    ax.scatter(
        x_sim[i, ..., 0], x_sim[i, ..., 1], label="Measurements", color="red", alpha=0.5
    )

ax.set(xlabel="x-position", ylabel="y-position")
# plt.legend()
plt.show()

### Samples (Posterior)

In [None]:
num_samples = 5
rng_key_prior = jax.random.PRNGKey(123)

# prior
predictive = infer.Predictive(
    kalman_filter,
    posterior_samples=prior_samples,
    num_samples=num_samples,
    return_sites=["z", "x"],
)
predictive_posterior = predictive(
    rng_key_prior,
    trans_mat=trans_mat,
    emiss_mat=emiss_mat,
    prior_dist=prior_dist,
    time_steps=time_steps,
    x_obs=x_true,
)

In [None]:
z_pred = predictive_posterior["z"]
x_pred = predictive_posterior["x"]

In [None]:
fig, ax = plt.subplots()

for i in range(num_samples):
    # ax.plot(z_true[i, ..., 0], z_true[i, ..., 1], color="black", label="True State")
    ax.plot(
        z_pred[i, ..., 0],
        z_pred[i, ..., 1],
        color="black",
        label="Predicted State",
        linestyle="--",
    )
    # ax.plot(x_s[..., 0], x_s[..., 1], color="tab:red", linestyle="--", label="Noisy Latent")
    # ax.scatter(x_pred[i, ..., 0], x_pred[i, ..., 1], label="Measurements", color="red", alpha=0.5)

ax.set(xlabel="x-position", ylabel="y-position")
# plt.legend()
plt.show()

## Inference

### MAP Estimation

In [None]:
def init_kf_model(
    trans_mat: jnp.ndarray,
    emiss_mat: jnp.ndarray,
    prior_dist,
    time_steps: int = 80,
):
    """Bind the fixed system matrices and return a model of the observations only.

    The returned callable takes a single positional argument (the
    observations, or None) and forwards it as ``x_obs`` to ``kalman_filter``
    together with the values captured here. That makes it suitable for
    ``SVI.run`` / ``MCMC.run`` / ``Predictive``, which pass data positionally.
    """
    bound_kwargs = dict(
        trans_mat=trans_mat,
        emiss_mat=emiss_mat,
        prior_dist=prior_dist,
        time_steps=time_steps,
    )

    def model_given_obs(x):
        return kalman_filter(**bound_kwargs, x_obs=x)

    return model_given_obs

In [None]:
%%time
from numpyro import diagnostics, infer, optim
from numpyro.infer.autoguide import AutoDelta

# optimizers
rng_key_infer = jax.random.PRNGKey(666)
lr = 1e-2
adam = optim.Adam(lr)

kf_model = init_kf_model(trans_mat, emiss_mat, prior_dist, time_steps)

guide = AutoDelta(kf_model)
# def guide(x, time_steps=30):
#     return None

n_epochs = 100

# Inference
svi = infer.SVI(kf_model, guide, adam, infer.Trace_ELBO())
# svi_result = svi.run(rng_key_infer, n_epochs, x)

# svi_result.params

In [None]:
# Fit to the noisy measurements `x_true`. The noise-free latent positions
# `z_true[..., :2]` contain no emission noise for the model to estimate.
svi_result = svi.run(rng_key_infer, n_epochs, x_true)

In [None]:
fig, ax = plt.subplots()

ax.plot(svi_result.losses)

plt.show()

In [None]:
# svi_result.params["emiss_noise"], emiss_noise_param

In [None]:
# svi_result.params["trans_noise"], trans_noise_param

In [None]:
svi_result.params["z0_auto_loc"]

In [None]:
rng_key_posterior = jax.random.PRNGKey(777)

# Posterior prediction
predictive = infer.Predictive(kf_model, params=svi_result.params, num_samples=10)
posterior_predictive = predictive(rng_key_posterior, x_true)

In [None]:
z_pred = posterior_predictive["z"]
x_pred = posterior_predictive["x"]

In [None]:
x_pred.shape, z_pred.shape

In [None]:
z_pred = posterior_predictive["z"]
# Summarise the *latent state* samples (axis 0 = sample axis). The observed
# site "x" only echoes the conditioning data.
z_lb, z_mu, z_ub = jnp.quantile(z_pred, jnp.array([0.05, 0.5, 0.95]), axis=0)

In [None]:
fig, ax = plt.subplots()

ax.plot(z_true[..., 0], z_true[..., 1], color="black", label="True State")
ax.plot(
    z_mu[..., 0], z_mu[..., 1], color="blue", label="Predicted State", linestyle="--"
)
# ax.plot(x_s[..., 0], x_s[..., 1], color="tab:red", linestyle="--", label="Noisy Latent")
ax.scatter(x_true[..., 0], x_true[..., 1], label="Measurements", color="red", alpha=0.5)

ax.set(xlabel="x-position", ylabel="y-position")
# plt.legend()
plt.show()

In [None]:
z_true.shape, z_mu.shape

In [None]:
# z_pred stacks posterior-predictive samples along axis 0. Plotting that 3-D
# array directly draws one line per time step across samples, so summarise
# with the per-step median first.
z_vel_mu = jnp.median(z_pred, axis=0)

fig, ax = plt.subplots()

ax.plot(z_true[..., 2], z_true[..., 3], color="black", label="True Velocity")
ax.plot(
    z_vel_mu[..., 2],
    z_vel_mu[..., 3],
    color="blue",
    label="Predicted Velocity",
    linestyle="--",
)

ax.set(xlabel="x-velocity", ylabel="y-velocity")
plt.legend()
plt.show()

### Samples

In [None]:
rng_key_posterior = jax.random.PRNGKey(777)

# Posterior prediction
predictive = infer.Predictive(kf_model, params=svi_result.params, num_samples=1)
posterior_predictive = predictive(rng_key_posterior, None)

In [None]:
z_sim_post = posterior_predictive["z"]
x_sim_post = posterior_predictive["x"]

In [None]:
fig, ax = plt.subplots()

# Loop over the samples actually drawn (axis 0 of z_sim_post), not the
# `num_samples` variable left over from an earlier cell. They can differ,
# and indexing past the drawn samples raises.
for i in range(z_sim_post.shape[0]):
    ax.plot(
        z_sim_post[i, ..., 0],
        z_sim_post[i, ..., 1],
        color="black",
        label="Predicted State",
        linestyle="--",
    )

ax.set(xlabel="x-position", ylabel="y-position")
plt.show()

### MCMC

In [None]:
kf_model = init_kf_model(trans_mat, emiss_mat, prior_dist, time_steps)

# Inference
kernel = infer.NUTS(kf_model)
num_warmup = 100
num_mcmc_samples = 200
mcmc = infer.MCMC(kernel, num_warmup=num_warmup, num_samples=num_mcmc_samples)

In [None]:
# Condition on the noisy measurements `x_true`, not the noise-free latent
# positions `z_true[:, :2]`.
mcmc.run(rng_key_infer, x_true)

In [None]:
# mcmc.print_summary()

In [None]:
posterior_samples = mcmc.get_samples()

In [None]:
plt.figure()

plt.hist(posterior_samples["emiss_noise"], bins=25)

plt.show()

In [None]:
plt.figure()

plt.hist(posterior_samples["trans_noise"], bins=25)
plt.show()

In [None]:
rng_key_posterior = jax.random.PRNGKey(777)

# Posterior prediction
predictive = infer.Predictive(
    kf_model, posterior_samples=posterior_samples, return_sites=["x", "z"]
)
posterior_predictive = predictive(rng_key_posterior, x_true)

In [None]:
z_pred = posterior_predictive["z"]
x_pred = posterior_predictive["x"]

In [None]:
z_pred = posterior_predictive["z"]
# Summarise the *latent state* samples (axis 0 = sample axis). The observed
# site "x" only echoes the conditioning data.
z_lb, z_mu, z_ub = jnp.quantile(z_pred, jnp.array([0.05, 0.5, 0.95]), axis=0)

In [None]:
fig, ax = plt.subplots()

ax.plot(z_true[..., 0], z_true[..., 1], color="black", label="True State")
ax.plot(
    z_mu[..., 0], z_mu[..., 1], color="blue", label="Predicted State", linestyle="--"
)

# ax.plot(x_s[..., 0], x_s[..., 1], color="tab:red", linestyle="--", label="Noisy Latent")
ax.scatter(x_true[..., 0], x_true[..., 1], label="Measurements", color="red", alpha=0.5)

ax.set(xlabel="x-position", ylabel="y-position")
# plt.legend()
plt.show()

In [None]:
%%time
from numpyro import diagnostics, infer, optim
from numpyro.infer.autoguide import AutoDelta

# optimizers
rng_key_infer = jax.random.PRNGKey(666)
lr = 1e-2
adam = optim.Adam(lr)

kf_model = init_kf_model(trans_mat, emiss_mat, prior_dist, time_steps)

guide = AutoDelta(kf_model)
# def guide(x, time_steps=30):
#     return None

n_epochs = 100

# Inference
svi = infer.SVI(kf_model, guide, adam, infer.Trace_ELBO())
# svi_result = svi.run(rng_key_infer, n_epochs, x)

# svi_result.params

### Numpyro - Distributions (Deterministic) Everywhere

In [None]:
def simulated_kalman_filter(
    trans_mat: jnp.ndarray,
    trans_noise_cov: jnp.ndarray,
    emiss_mat: jnp.ndarray,
    emiss_noise_cov: jnp.ndarray,
    mu0: jnp.ndarray,
    Sigma0: jnp.ndarray,
    time_steps: int = 100,
    x_obs: jnp.ndarray = None,
    x_obs_mask: jnp.ndarray = None,
):
    """Linear-Gaussian state-space model with every component recorded as a site.

    The system matrices are registered as deterministic sites. ``z0`` is
    sampled from ``MultivariateNormal(mu0, Sigma0)``. Each scan step then
    draws ``z_t ~ N(trans_mat @ z_{t-1}, trans_noise_cov)`` and
    ``x_t ~ N(emiss_mat @ z_t, emiss_noise_cov)``.

    Args:
        trans_mat, trans_noise_cov: transition matrix and noise covariance.
        emiss_mat, emiss_noise_cov: emission matrix and noise covariance.
        mu0, Sigma0: mean and covariance of the initial-state prior.
        time_steps: sequence length when ``x_obs`` is None.
        x_obs: optional 2-D observations; its first axis is time and
            overrides ``time_steps``.
        x_obs_mask: optional boolean mask scanned along its first axis. A
            step contributes to the log-density only if every entry of its
            mask slice is True.

    Returns:
        ``(z, x)`` stacked along a leading time axis by ``scan``.
    """
    if x_obs is not None:
        time_steps, _ = x_obs.shape

    # Record each system matrix under its own name and keep its own value.
    trans_mat = numpyro.deterministic("trans_mat", trans_mat)
    trans_noise_cov = numpyro.deterministic("trans_noise_cov", trans_noise_cov)
    emiss_mat = numpyro.deterministic("emiss_mat", emiss_mat)
    emiss_noise_cov = numpyro.deterministic("emiss_noise_cov", emiss_noise_cov)

    # ==================
    # sample from prior
    # ==================
    z0 = numpyro.sample("z0", dist.MultivariateNormal(mu0, Sigma0))

    def body(z_prev, step_inputs):
        x_t, x_t_mask = step_inputs

        # transition
        z = numpyro.sample(
            "z",
            dist.MultivariateNormal(
                loc=jnp.dot(trans_mat, z_prev), covariance_matrix=trans_noise_cov
            ),
        )

        # emission
        emission = dist.MultivariateNormal(
            loc=jnp.dot(emiss_mat, z), covariance_matrix=emiss_noise_cov
        )
        if x_obs_mask is None:
            x = numpyro.sample("x", emission, obs=x_t)
        else:
            # A MultivariateNormal yields ONE log-prob per step. A per-step
            # mask with an obs_dim axis would be broadcast against it and
            # count that log-prob obs_dim times, so collapse it to a single
            # flag: the step counts only when all components are observed.
            with numpyro.handlers.mask(mask=jnp.all(x_t_mask)):
                x = numpyro.sample("x", emission, obs=x_t)

        return z, (z, x)

    _, (z, x) = scan(f=body, init=z0, xs=(x_obs, x_obs_mask), length=time_steps)

    return z, x

#### Samples (Propagated)

In [None]:
time_steps = 80

with numpyro.handlers.seed(rng_seed=42):
    z_s, x_s = simulated_kalman_filter(
        trans_mat=trans_mat,
        trans_noise_cov=trans_noise_mat,
        emiss_mat=emiss_mat,
        emiss_noise_cov=emiss_noise_mat,
        mu0=mu0,
        Sigma0=Sigma0,
        time_steps=time_steps,
        x_obs=None,
    )

print(z_s.shape, x_s.shape)

In [None]:
fig, ax = plt.subplots()

ax.plot(z_s[..., 0], z_s[..., 1], color="black", label="True State")
# ax.plot(x_s[..., 0], x_s[..., 1], color="tab:red", linestyle="--", label="Noisy Latent")
ax.scatter(x_s[..., 0], x_s[..., 1], label="Measurements", color="red", alpha=0.5)

ax.set(xlabel="x-position", ylabel="y-position")
plt.legend()
plt.show()

In [None]:
time_steps = 80
x_obs_mask = np.ones_like(x_s, dtype=bool)
x_obs_mask[::2] = False
x_obs_mask = jnp.asarray(x_obs_mask, dtype=bool)

with numpyro.handlers.seed(rng_seed=42):
    z_s_, x_s_ = simulated_kalman_filter(
        trans_mat=trans_mat,
        trans_noise_cov=trans_noise_mat,
        emiss_mat=emiss_mat,
        emiss_noise_cov=emiss_noise_mat,
        mu0=mu0,
        Sigma0=Sigma0,
        time_steps=time_steps,
        x_obs=x_s,
        x_obs_mask=x_obs_mask,
    )

print(z_s_.shape, x_s_.shape)

In [None]:
np.testing.assert_array_almost_equal(x_s, x_s_)

In [None]:
fig, ax = plt.subplots()

ax.plot(z_s_[..., 0], z_s_[..., 1], color="black", label="True State")
# ax.plot(x_s[..., 0], x_s[..., 1], color="tab:red", linestyle="--", label="Noisy Latent")
ax.scatter(x_s_[..., 0], x_s_[..., 1], label="Measurements", color="red", alpha=0.5)

ax.set(xlabel="x-position", ylabel="y-position")
plt.legend()
plt.show()

In [None]:
num_samples = 5
rng_key_prior = jax.random.PRNGKey(123)

# prior
predictive = infer.Predictive(simulated_kalman_filter, num_samples=num_samples)
prior_samples = predictive(
    rng_key_prior,
    trans_mat=trans_mat,
    trans_noise_cov=trans_noise_mat,
    emiss_mat=emiss_mat,
    emiss_noise_cov=emiss_noise_mat,
    mu0=mu0,
    Sigma0=Sigma0,
    time_steps=time_steps,
    x_obs=None,
)

z_s = prior_samples["z"]
x_s = prior_samples["x"]

In [None]:
fig, ax = plt.subplots()

for i in range(num_samples):
    ax.plot(z_s[i, ..., 0], z_s[i, ..., 1], color="black", label="True State")
    # ax.plot(x_s[..., 0], x_s[..., 1], color="tab:red", linestyle="--", label="Noisy Latent")
    ax.scatter(
        x_s[i, ..., 0], x_s[i, ..., 1], label="Measurements", color="red", alpha=0.5
    )

ax.set(xlabel="x-position", ylabel="y-position")
# plt.legend()
plt.show()

## Multivariate

In [None]:
def fn_vdot(mat, x):
    """Apply ``mat`` to every row of ``x`` (``x @ mat.T`` for 2-D ``x``) via ``vmap``."""
    return jax.vmap(jnp.dot, in_axes=(None, 0))(mat, x)


def simulated_kalman_filter(
    trans_mat: jnp.ndarray,
    trans_noise_cov: jnp.ndarray,
    emiss_mat: jnp.ndarray,
    emiss_noise_cov: jnp.ndarray,
    mu0: jnp.ndarray,
    Sigma0: jnp.ndarray,
    time_steps: int = 100,
    x_obs: jnp.ndarray = None,
    n_batch: int = 1,
):
    """Batched linear-Gaussian state-space model.

    Samples ``n_batch`` independent trajectories inside a ``"batch"`` plate.
    The initial state comes from ``MultivariateNormal(mu0, Sigma0)``. Each
    scan step then draws ``z_t ~ N(trans_mat @ z_{t-1}, trans_noise_cov)``
    and ``x_t ~ N(emiss_mat @ z_t, emiss_noise_cov)``, applied row-wise to
    the batch.

    Args:
        trans_mat, trans_noise_cov: transition matrix and noise covariance.
        emiss_mat, emiss_noise_cov: emission matrix and noise covariance.
        mu0, Sigma0: mean and covariance of the initial-state prior.
        time_steps: sequence length when ``x_obs`` is None.
        x_obs: optional observations laid out as
            ``(n_batch, time_steps, obs_dim)``; overrides ``n_batch`` and
            ``time_steps``.
        n_batch: number of trajectories when ``x_obs`` is None.

    Returns:
        ``(z, x)`` stacked along a leading time axis and passed through
        ``squeeze()``, so singleton axes (e.g. ``n_batch == 1``) are dropped.
    """
    if x_obs is not None:
        n_batch, time_steps, _ = x_obs.shape

    # Transition Functions
    trans_mat = numpyro.deterministic("trans_mat", trans_mat)
    trans_noise_cov = numpyro.deterministic("trans_noise_cov", trans_noise_cov)

    # Emission Functions
    emiss_mat = numpyro.deterministic("emiss_mat", emiss_mat)
    emiss_noise_cov = numpyro.deterministic("emiss_noise_cov", emiss_noise_cov)

    # ==================
    # sample from prior
    # ==================
    prior_dist = dist.MultivariateNormal(mu0, Sigma0)
    with numpyro.plate("batch", n_batch):
        z0 = numpyro.sample("z0", prior_dist)

    def body(z_prev, x_prev):
        with numpyro.plate("batch", n_batch):
            # TRANSITION DYNAMICS (row-wise over the batch)
            z = numpyro.sample(
                "z",
                dist.MultivariateNormal(
                    loc=fn_vdot(trans_mat, z_prev), covariance_matrix=trans_noise_cov
                ),
            )
            # EMISSION DYNAMICS (row-wise over the batch)
            x = numpyro.sample(
                "x",
                dist.MultivariateNormal(
                    loc=fn_vdot(emiss_mat, z), covariance_matrix=emiss_noise_cov
                ),
                obs=x_prev,
            )

        return z, (z, x)

    # scan iterates over axis 0, so move the observations to time-major.
    if x_obs is not None:
        x_obs = jnp.swapaxes(x_obs, 0, 1)

    _, (z, x) = scan(f=body, init=z0, xs=x_obs, length=time_steps)

    return z.squeeze(), x.squeeze()

In [None]:
time_steps = 80
n_batch = 5

with numpyro.handlers.seed(rng_seed=42):
    z_s, x_s = simulated_kalman_filter(
        trans_mat=trans_mat,
        trans_noise_cov=trans_noise_mat,
        emiss_mat=emiss_mat,
        emiss_noise_cov=emiss_noise_mat,
        mu0=mu0,
        Sigma0=Sigma0,
        time_steps=time_steps,
        x_obs=None,
        n_batch=n_batch,
    )

print(z_s.shape, x_s.shape)

In [None]:
time_steps = 80
n_batch = 5
x_obs_init = jnp.swapaxes(x_s, 0, 1)
x_obs_mask = jnp.ones(x_obs_init.shape[1:]).astype(bool)
x_obs_mask = x_obs_mask.at[..., 1].set(False)
print(x_obs_init.shape, x_obs_mask.shape)
with numpyro.handlers.seed(rng_seed=42):
    z_s, x_s_ = simulated_kalman_filter(
        trans_mat=trans_mat,
        trans_noise_cov=trans_noise_mat,
        emiss_mat=emiss_mat,
        emiss_noise_cov=emiss_noise_mat,
        mu0=mu0,
        Sigma0=Sigma0,
        time_steps=time_steps,
        x_obs=x_obs_init,
        n_batch=n_batch,
    )

print(z_s.shape, x_s_.shape)

In [None]:
np.testing.assert_array_almost_equal(x_s, x_s_)

In [None]:
num_samples = 5
rng_key_prior = jax.random.PRNGKey(123)
n_batch = 1

# prior
predictive = infer.Predictive(simulated_kalman_filter, num_samples=num_samples)
prior_samples = predictive(
    rng_key_prior,
    trans_mat=trans_mat,
    trans_noise_cov=trans_noise_mat,
    emiss_mat=emiss_mat,
    emiss_noise_cov=emiss_noise_mat,
    mu0=mu0,
    Sigma0=Sigma0,
    time_steps=time_steps,
    x_obs=None,
    n_batch=n_batch,
)

z_s = prior_samples["z"]
x_s = prior_samples["x"]

In [None]:
z_s.shape

### Numpyro - Introducing Masks

In [None]:
# Apply `mat` to every row of `x`: vmap over axis 0 of `x` with `mat` shared,
# i.e. the batched equivalent of `x @ mat.T` for a 2-D `x`.
fn_vdot = lambda mat, x: jax.vmap(jnp.dot, in_axes=(None, 0))(mat, x)


def simulated_kalman_filter(
    trans_mat: jnp.ndarray,
    trans_noise_cov: jnp.ndarray,
    emiss_mat: jnp.ndarray,
    emiss_noise_cov: jnp.ndarray,
    mu0: jnp.ndarray,
    Sigma0: jnp.ndarray,
    time_steps: int = 100,
    x_obs: jnp.ndarray = None,
    x_obs_mask: jnp.ndarray = None,
    n_batch: int = 1,
):
    """Batched linear-Gaussian state-space model with optional observation masks.

    Samples ``n_batch`` trajectories inside a ``"batch"`` plate. The initial
    state comes from ``MultivariateNormal(mu0, Sigma0)``. Each scan step then
    draws ``z_t ~ N(trans_mat @ z_{t-1}, trans_noise_cov)`` and
    ``x_t ~ N(emiss_mat @ z_t, emiss_noise_cov)``, row-wise over the batch.
    The function prints shape information while tracing (debug output).

    Args:
        trans_mat, trans_noise_cov: transition matrix and noise covariance.
        emiss_mat, emiss_noise_cov: emission matrix and noise covariance.
        mu0, Sigma0: mean and covariance of the initial-state prior.
        time_steps: sequence length when ``x_obs`` is None.
        x_obs: optional observations, unpacked as
            ``(n_batch, time_steps, n_dims)``; overrides ``n_batch`` and
            ``time_steps``.
        x_obs_mask: optional mask passed to ``Distribution.mask`` on the "x"
            site. NOTE(review): unlike ``x_obs`` it is NOT swapped to
            time-major before the scan (the swap is commented out below), so
            its own axis 0 is scanned as time — confirm the layout callers
            should use.
        n_batch: number of trajectories when ``x_obs`` is None.

    Returns:
        ``(z, x)`` stacked by ``scan`` along a leading time axis; presumably
        ``(time_steps, n_batch, dim)`` — TODO confirm.
    """
    if x_obs is not None:
        n_batch, time_steps, n_dims = x_obs.shape

    # Transition Functions
    trans_mat = numpyro.deterministic("trans_mat", trans_mat)
    trans_noise_cov = numpyro.deterministic("trans_noise_cov", trans_noise_cov)

    # Emission Functions
    emiss_mat = numpyro.deterministic("emiss_mat", emiss_mat)
    emiss_noise_cov = numpyro.deterministic("emiss_noise_cov", emiss_noise_cov)

    # ==================
    # sample from prior
    # ==================
    prior_dist = dist.MultivariateNormal(mu0, Sigma0)

    with numpyro.plate("batch", n_batch):
        z0 = numpyro.sample("z0", prior_dist)
        print(z0.shape)

    # Model
    def body(z_prev, x_prev):
        # xs is the tuple (x_obs, x_obs_mask); either per-step slice may be None.
        x_prev, x_prev_mask = x_prev
        # transition
        try:
            # Debug print; a None slice has no `.shape`, hence the except.
            print("x Before:", x_prev.shape, x_prev_mask.shape)

        except AttributeError:
            pass
        print("z Before:", z_prev.shape)
        with numpyro.plate("batch", n_batch, subsample_size=None):
            print("z Before (Batched):", z_prev.shape)
            # TRANSITION DYNAMICS

            # vectorize multiplication
            print(trans_mat.shape, z_prev.shape)
            z = fn_vdot(trans_mat, z_prev)
            print("z After Mul (Batch x Dim):", z.shape)
            # sample
            z = numpyro.sample(
                "z", dist.MultivariateNormal(loc=z, covariance_matrix=trans_noise_cov)
            )
            print("z After Sample (Batch x Dim):", z.shape)
            # EMISSION DYNAMICS

            # vectorize multiplication
            print(emiss_mat.shape, z.shape)
            x = fn_vdot(emiss_mat, z)
            print("x After Mul:", x.shape)
            # print("masks:", x.shape, x_prev.shape, x_prev_mask.shape)
            if x_prev_mask is not None:
                print("MASKING!")
                # NOTE(review): this assignment is a no-op.
                x_prev_mask = x_prev_mask
                print(
                    "obs (Batch x Dim):",
                    x_prev.shape,
                    "| mask (Dim):",
                    x_prev_mask.shape,
                )
                # NOTE(review): `.mask(...)` wraps a MultivariateNormal, whose
                # log-prob has one value per batch element (no per-component
                # term). A per-step mask shaped (Dim,), as the print above
                # suggests, would need to broadcast against the (n_batch,)
                # batch shape and cannot mask individual components — verify
                # the intended mask shape.
                x = numpyro.sample(
                    "x",
                    dist.MultivariateNormal(
                        loc=x, covariance_matrix=emiss_noise_cov
                    ).mask(x_prev_mask),
                    obs=x_prev,
                )
            #     with numpyro.handlers.mask(mask=x_prev_mask.T):
            # # sample
            #         x = numpyro.sample("x", dist.MultivariateNormal(loc=x, covariance_matrix=emiss_noise_cov), obs=x_prev)
            else:
                # x_prev_mask is None on this branch, so obs_mask=None here.
                x = numpyro.sample(
                    "x",
                    dist.MultivariateNormal(loc=x, covariance_matrix=emiss_noise_cov),
                    obs=x_prev,
                    obs_mask=x_prev_mask,
                )
            print("x After Sample:", x.shape)

        return z, (z, x)

    # Reorder (n_batch, time_steps, n_dims) -> (time_steps, n_batch, n_dims) so
    # that scan steps over time.
    if x_obs is not None:
        x_obs = jnp.swapaxes(x_obs, 0, 1)

    # if x_obs_mask is not None:
    #     x_obs_mask = jnp.swapaxes(x_obs_mask, 0, 1)

    _, (z, x) = scan(f=body, init=(z0), xs=(x_obs, x_obs_mask), length=time_steps)

    return z, x

In [None]:
from einops import repeat

In [None]:
time_steps = 80
n_batch = 5

with numpyro.handlers.seed(rng_seed=42):
    z_s, x_s = simulated_kalman_filter(
        trans_mat=trans_mat,
        trans_noise_cov=trans_noise_mat,
        emiss_mat=emiss_mat,
        emiss_noise_cov=emiss_noise_mat,
        mu0=mu0,
        Sigma0=Sigma0,
        time_steps=time_steps,
        x_obs=None,
        n_batch=n_batch,
    )

print(z_s.shape, x_s.shape)

In [None]:
fig, ax = plt.subplots()

# The batch axis is the second-to-last one (z_s[..., i, :]); loop over its
# actual size rather than the unrelated `num_samples` variable.
for i in range(z_s.shape[-2]):
    ax.plot(z_s[..., i, 0], z_s[..., i, 1], color="black", label="True State")
    ax.scatter(
        x_s[..., i, 0], x_s[..., i, 1], label="Measurements", color="red", alpha=0.5
    )

ax.set(xlabel="x-position", ylabel="y-position")
plt.show()

In [None]:
time_steps = 80
n_batch = 5
x_obs_init = jnp.swapaxes(x_s, 0, 1)
x_obs_mask = jnp.ones(x_obs_init.shape, dtype=bool)
x_obs_mask = x_obs_mask.at[..., 1].set(False)
print("obs:", x_obs_init.shape, "| mask:", x_obs_mask.shape)
with numpyro.handlers.seed(rng_seed=42):
    z_s_, x_s_ = simulated_kalman_filter(
        trans_mat=trans_mat,
        trans_noise_cov=trans_noise_mat,
        emiss_mat=emiss_mat,
        emiss_noise_cov=emiss_noise_mat,
        mu0=mu0,
        Sigma0=Sigma0,
        time_steps=time_steps,
        x_obs=x_obs_init,
        x_obs_mask=None,
        n_batch=n_batch,
    )

print(z_s_.shape, x_s_.shape)

In [None]:
fig, ax = plt.subplots()

# The batch axis is the second-to-last one (z_s_[..., i, :]); loop over its
# actual size rather than the unrelated `num_samples` variable.
for i in range(z_s_.shape[-2]):
    ax.plot(z_s_[..., i, 0], z_s_[..., i, 1], color="black", label="True State")
    ax.scatter(
        x_s_[..., i, 0], x_s_[..., i, 1], label="Measurements", color="red", alpha=0.5
    )

ax.set(xlabel="x-position", ylabel="y-position")
plt.show()

In [None]:
time_steps = 80
n_batch = 5
x_obs_init = jnp.swapaxes(x_s, 0, 1)
x_obs_mask = jnp.ones(x_obs_init.shape[1:], dtype=bool)
x_obs_mask = x_obs_mask.at[..., 1].set(False)
print("obs:", x_obs_init.shape, "| mask:", x_obs_mask.shape)
with numpyro.handlers.seed(rng_seed=42):
    z_s_, x_s_ = simulated_kalman_filter(
        trans_mat=trans_mat,
        trans_noise_cov=trans_noise_mat,
        emiss_mat=emiss_mat,
        emiss_noise_cov=emiss_noise_mat,
        mu0=mu0,
        Sigma0=Sigma0,
        time_steps=time_steps,
        x_obs=x_obs_init,
        x_obs_mask=x_obs_mask,
        n_batch=n_batch,
    )

print(z_s_.shape, x_s_.shape)

In [None]:
fig, ax = plt.subplots()

# The batch axis is the second-to-last one (z_s_[..., i, :]); loop over its
# actual size rather than the unrelated `num_samples` variable.
for i in range(z_s_.shape[-2]):
    ax.plot(z_s_[..., i, 0], z_s_[..., i, 1], color="black", label="True State")
    ax.scatter(
        x_s_[..., i, 0], x_s_[..., i, 1], label="Measurements", color="red", alpha=0.5
    )

ax.set(xlabel="x-position", ylabel="y-position")
plt.show()

In [None]:
# Toy data for experimenting with `obs_mask`: four 3-D rows, of which only
# rows 1 and 2 are marked as observed.
mask = np.array([False, True, True, False])
data = np.full((4, 3), 7.0)


def model():
    """Shared 3-D latent ``x`` with a plate of 4 partially-masked observations ``y``.

    Prints the shapes it sees while running (debug output) and returns the
    pair ``(x, y)``.
    """
    identity = np.eye(3)
    latent = numpyro.sample("x", dist.MultivariateNormal(np.zeros(3), identity))
    print(latent.shape)

    n_rows = len(data)
    with numpyro.plate("plate", n_rows):
        print(n_rows, latent.shape, data.shape, mask.shape)
        observed = numpyro.sample(
            "y",
            dist.MultivariateNormal(latent, identity),
            obs=data,
            obs_mask=mask,
        )
    return latent, observed

In [None]:
with numpyro.handlers.seed(rng_seed=42):
    x, y = model()
x.shape, y.shape

In [None]:
mask.shape, data.shape

In [None]:
# Index (counted from the end) of the single element to mask; the values
# 1, 5 and 10 are the ones listed as alternatives to try.
mask_last = 1

N = 10
mask = np.ones(N, dtype=bool)
mask[-mask_last] = False


def model(data, mask):
    """Plate of ``N`` latents with a masked Delta factor and a scaled Normal likelihood.

    Inside the mask, ``y`` is a ``Delta`` at ``x`` carrying ``log_density=1.0``.
    ``obs`` is a ``Normal(x, 1)`` likelihood on ``data`` with its log-density
    scaled by 2. Shapes are printed as a side effect; returns None.
    """
    print(data.shape, mask.shape)
    with numpyro.plate("N", N):
        x = numpyro.sample("x", dist.Normal(0, 1))
        print(x.shape)
        with numpyro.handlers.mask(mask=mask):
            y = numpyro.sample("y", dist.Delta(x, log_density=1.0))
            print("mask:", mask.shape, x.shape, y.shape)
            with numpyro.handlers.scale(scale=2):
                print("scale:", x.shape, data.shape)
                numpyro.sample("obs", dist.Normal(x, 1), obs=data)

In [None]:
# Random `data` and latent `x`, both of length N, for evaluating `model`.
data = jax.random.normal(jax.random.PRNGKey(0), (N,))
x = jax.random.normal(jax.random.PRNGKey(1), (N,))

data.shape, x.shape, mask.shape

In [None]:
# Joint log-density of `model` with the latent sites fixed to x (and the Delta
# site `y` set to the same values); `log_density` returns (log_joint, trace),
# so `[0]` keeps only the log-joint.
log_joint = numpyro.infer.util.log_density(model, (data, mask), {}, {"x": x, "y": x})[0]
log_joint

In [None]:
# Unconditioned 80-step draw from the simulator.
# NOTE(review): passes `trans_noise_mat` / `emiss_noise_mat`, which are not
# defined in this chunk (later cells use `trans_noise_cov` / `emiss_noise_cov`);
# presumably from an earlier cell — confirm.
time_steps = 80

with numpyro.handlers.seed(rng_seed=42):
    z_s, x_s = simulated_kalman_filter(
        trans_mat=trans_mat,
        trans_noise_cov=trans_noise_mat,
        emiss_mat=emiss_mat,
        emiss_noise_cov=emiss_noise_mat,
        mu0=mu0,
        Sigma0=Sigma0,
        time_steps=time_steps,
        x_obs=None,
    )

print(z_s.shape, x_s.shape)

In [None]:
def simulated_kalman_filter(
    trans_mat: jnp.ndarray,
    trans_noise: jnp.ndarray,
    emiss_mat: jnp.ndarray,
    emiss_noise: jnp.ndarray,
    prior_dist,
    time_steps: int = 100,
    x_obs: jnp.ndarray = None,
    batch: int = 1,
    x_obs_mask: jnp.ndarray = None,
):
    """Simulate a linear-Gaussian state-space model, one ``scan`` per trajectory.

    The latent state evolves as ``z_t = trans_mat @ z_{t-1} + e_z`` and is
    emitted as ``x_t = emiss_mat @ z_t + e_x``. ``trans_noise`` and
    ``emiss_noise`` are the noise *distributions* (the call sites pass numpyro
    ``MultivariateNormal`` objects despite the array type hints).

    Args:
        trans_mat: transition matrix.
        trans_noise: distribution of the additive transition noise.
        emiss_mat: emission matrix.
        emiss_noise: distribution of the additive emission noise.
        prior_dist: distribution of the initial state ``z0``.
        time_steps: number of steps (overridden by ``x_obs``).
        x_obs: optional observations of shape ``(batch, time_steps, obs_dim)``.
            When given, the emission-noise site is observed at the residual
            ``x_obs[t] - emiss_mat @ z_t`` (for zero-mean noise this scores the
            observation), so the returned ``x`` reproduces ``x_obs``.
            Previously ``x_obs`` only set the shapes and was otherwise ignored.
        batch: number of trajectories (overridden by ``x_obs``).
        x_obs_mask: currently unused.

    Returns:
        ``(z, x)`` with leading axes ``(batch, time_steps)``.

    NOTE(review): ``jax.vmap`` closes over the seed handler's key instead of
    batching it, so noise draws may be identical across trajectories (only
    ``z0`` differs) — verify.
    """
    if x_obs is not None:
        batch, time_steps, _ = x_obs.shape

    # one initial state per trajectory
    z0 = numpyro.sample("z0", prior_dist, sample_shape=(batch,))

    def body(z_prev, x_prev):
        # transition: z_t = A z_{t-1} + noise
        z = jnp.dot(trans_mat, z_prev) + numpyro.sample("z_noise", trans_noise)

        # emission: x_t = H z_t + noise, with the noise observed at the residual
        # whenever an observation is available
        x_mean = jnp.dot(emiss_mat, z)
        residual = None if x_prev is None else x_prev - x_mean
        x = x_mean + numpyro.sample("x_noise", emiss_noise, obs=residual)

        return z, (z, x)

    def run_scan(states, xs):
        return scan(body, states, xs, length=time_steps)

    # vectorise over trajectories; observations are only mapped when present
    in_axes = (0, None) if x_obs is None else (0, 0)
    _, (z, x) = jax.vmap(run_scan, in_axes=in_axes)(z0, x_obs)

    return z, x

In [None]:
# Draw `batch` unconditioned trajectories with the noise-distribution signature.
time_steps = 80
batch = 5

with numpyro.handlers.seed(rng_seed=123):
    z_s, x_s = simulated_kalman_filter(
        trans_mat=trans_mat,
        trans_noise=trans_noise,
        emiss_mat=emiss_mat,
        emiss_noise=emiss_noise,
        prior_dist=prior_dist,
        time_steps=time_steps,
        x_obs=None,
        batch=batch,
    )

print(z_s.shape, x_s.shape)

In [None]:
fig, ax = plt.subplots()

batch = 5
# trajectories are indexed on the leading axis here (batch-major layout)
for i in range(batch):
    ax.plot(z_s[i, ..., 0], z_s[i, ..., 1], color="black", label="True State")
    # ax.plot(x_s[..., 0], x_s[..., 1], color="tab:red", linestyle="--", label="Noisy Latent")
    ax.scatter(
        x_s[i, ..., 0], x_s[i, ..., 1], label="Measurements", color="red", alpha=0.5
    )

ax.set(xlabel="x-position", ylabel="y-position")
plt.legend()
plt.show()

### Model II - States/Obs

In [None]:
def simulated_kalman_filter(
    trans_mat: jnp.ndarray,
    trans_noise_cov: jnp.ndarray,
    emiss_mat: jnp.ndarray,
    emiss_noise_cov: jnp.ndarray,
    mu0: jnp.ndarray,
    Sigma0: jnp.ndarray,
    time_steps: int = 100,
    x_obs: jnp.ndarray = None,
    batch: int = 1,
    x_obs_mask: jnp.ndarray = None,
):
    """Linear-Gaussian state-space model with explicit ``z`` / ``x`` sample sites.

    For each trajectory, ``z_0 ~ N(mu0, Sigma0)``,
    ``z_t ~ N(trans_mat @ z_{t-1}, trans_noise_cov)`` and
    ``x_t ~ N(emiss_mat @ z_t, emiss_noise_cov)``. Trajectories are handled with
    ``jax.vmap`` and time with ``scan``.

    Args:
        trans_mat: transition matrix.
        trans_noise_cov: transition noise covariance.
        emiss_mat: emission matrix.
        emiss_noise_cov: emission noise covariance.
        mu0: mean of the initial state.
        Sigma0: covariance of the initial state.
        time_steps: number of steps (overridden by ``x_obs``).
        x_obs: optional observations of shape ``(batch, time_steps, obs_dim)``;
            the ``"x"`` site is conditioned on them.
        batch: number of trajectories (overridden by ``x_obs``).
        x_obs_mask: currently unused.

    Returns:
        ``(z, x)`` with leading axes ``(batch, time_steps)``.
    """
    if x_obs is not None:
        batch, time_steps, _ = x_obs.shape

    # Record the model matrices as deterministic sites. Each site must wrap its
    # own argument: previously all four re-recorded `trans_mat`, so the
    # emission used the transition matrix and both noise covariances were
    # replaced by it.
    trans_mat = numpyro.deterministic("trans_mat", trans_mat)
    trans_noise_cov = numpyro.deterministic("trans_noise_cov", trans_noise_cov)
    emiss_mat = numpyro.deterministic("emiss_mat", emiss_mat)
    emiss_noise_cov = numpyro.deterministic("emiss_noise_cov", emiss_noise_cov)

    # one initial state per trajectory
    prior_dist = dist.MultivariateNormal(mu0, Sigma0)
    z0 = numpyro.sample("z0", prior_dist, sample_shape=(batch,))

    def body(z_prev, x_prev):
        # transition
        z_mean = jnp.dot(trans_mat, z_prev)
        z = numpyro.sample("z", dist.MultivariateNormal(z_mean, trans_noise_cov))

        # emission (observed when x_prev is not None)
        x_mean = jnp.dot(emiss_mat, z)
        x = numpyro.sample(
            "x", dist.MultivariateNormal(x_mean, emiss_noise_cov), obs=x_prev
        )

        return z, (z, x)

    def run_scan(states, xs):
        return scan(body, states, xs, length=time_steps)

    # vectorise over trajectories; observations are only mapped when present
    in_axes = (0, None) if x_obs is None else (0, 0)
    _, (z, x) = jax.vmap(run_scan, in_axes=in_axes)(z0, x_obs)

    return z, x

In [None]:
# state = (x, y, vx, vy); only the two positions are observed
state_dim = 4
obs_dim = 2

# init prior dist
mu0 = jnp.array([8.0, 5.0, 1.0, 0.0])
Sigma0 = 1e-5 * jnp.eye(state_dim)

prior_dist = dist.MultivariateNormal(mu0, Sigma0)

# =================
# transition model
# =================
dt = 0.1
step_std = 0.1

# identity plus dt on the super-diagonal offset by obs_dim:
# each position gains dt times its matching velocity
trans_mat = jnp.eye(state_dim) + dt * jnp.eye(state_dim, k=obs_dim)
trans_noise_cov = step_std**2 * jnp.eye(state_dim)
trans_noise = dist.MultivariateNormal(jnp.zeros(state_dim), trans_noise_cov)

# =================
# emission model
# =================
noise_std = 0.02

# selects the first obs_dim (position) components of the state
emiss_mat = jnp.eye(N=obs_dim, M=state_dim)
emiss_noise_cov = noise_std**2 * jnp.eye(obs_dim)
emiss_noise = dist.MultivariateNormal(jnp.zeros(obs_dim), emiss_noise_cov)

#### Sampling

In [None]:
# Two unconditioned trajectories from whichever `simulated_kalman_filter`
# definition is currently live in the kernel.
time_steps = 80
batch = 2

with numpyro.handlers.seed(rng_seed=123):
    z_s, x_s = simulated_kalman_filter(
        trans_mat=trans_mat,
        trans_noise_cov=trans_noise_cov,
        emiss_mat=emiss_mat,
        emiss_noise_cov=emiss_noise_cov,
        mu0=mu0,
        Sigma0=Sigma0,
        time_steps=time_steps,
        x_obs=None,
        batch=batch,
    )

print(z_s.shape, x_s.shape)

#### Viz

In [None]:
fig, ax = plt.subplots()

# NOTE(review): `[..., i, k]` treats the second-to-last axis as the trajectory
# index, i.e. a (time, batch, dim) layout. The vmap-based definition directly
# above returns (batch, time, dim); confirm which definition produced `z_s`.
for i in range(batch):
    ax.plot(z_s[..., i, 0], z_s[..., i, 1], color="black", label="True State")
    # ax.plot(x_s[..., 0], x_s[..., 1], color="tab:red", linestyle="--", label="Noisy Latent")
    ax.scatter(
        x_s[..., i, 0], x_s[..., i, 1], label="Measurements", color="red", alpha=0.5
    )

ax.set(xlabel="x-position", ylabel="y-position")
plt.legend()
plt.show()

#### Conditioned

In [None]:
# Re-run conditioned on `x_s` and check the returned observations equal it.
# Note `batch` is ignored when `x_obs` is given (it is read from `x_obs`).
# NOTE(review): the assertion compares `x_s` with the conditioned output after
# swapping its first two axes, so it only holds when the live definition's
# output layout is swapped relative to `x_s` — confirm which one is in scope.
time_steps = 80
batch = 5

with numpyro.handlers.seed(rng_seed=123):
    z_samples_cond, x_samples_cond = simulated_kalman_filter(
        trans_mat=trans_mat,
        trans_noise_cov=trans_noise_cov,
        emiss_mat=emiss_mat,
        emiss_noise_cov=emiss_noise_cov,
        mu0=mu0,
        Sigma0=Sigma0,
        time_steps=time_steps,
        x_obs=x_s,
        batch=batch,
    )
np.testing.assert_array_almost_equal(x_s, jnp.swapaxes(x_samples_cond, 0, 1))

### Distribution

In [None]:
def simulated_kalman_filter(
    trans_mat: jnp.ndarray,
    trans_noise_cov: jnp.ndarray,
    emiss_mat: jnp.ndarray,
    emiss_noise_cov: jnp.ndarray,
    mu0: jnp.ndarray,
    Sigma0: jnp.ndarray,
    time_steps: int = 100,
    x_obs: jnp.ndarray = None,
    batch: int = 1,
    x_obs_mask: jnp.ndarray = None,
):
    """Linear-Gaussian state-space model batched with a ``numpyro.plate``.

    Same generative process as the vmap variant (``z_0 ~ N(mu0, Sigma0)``,
    ``z_t ~ N(trans_mat @ z_{t-1}, trans_noise_cov)``,
    ``x_t ~ N(emiss_mat @ z_t, emiss_noise_cov)``). Here a single ``scan``
    runs over time while the trajectories live on a plate at ``dim=-1``.

    Args:
        trans_mat: transition matrix.
        trans_noise_cov: transition noise covariance.
        emiss_mat: emission matrix.
        emiss_noise_cov: emission noise covariance.
        mu0: mean of the initial state.
        Sigma0: covariance of the initial state.
        time_steps: number of steps (overridden by ``x_obs``).
        x_obs: optional observations of shape ``(batch, time_steps, obs_dim)``;
            they are made time-major internally and condition the ``"x"`` site.
        batch: number of trajectories (overridden by ``x_obs``).
        x_obs_mask: currently unused.

    Returns:
        ``(z, x)`` with leading axes ``(time_steps, batch)``.
    """
    if x_obs is not None:
        batch, time_steps, _ = x_obs.shape
        # scan iterates over the leading axis, so go time-major
        x_obs = jnp.swapaxes(x_obs, 0, 1)

    # Record the model matrices as deterministic sites. Each site must wrap its
    # own argument: previously all four re-recorded `trans_mat`, so the
    # emission used the transition matrix and both noise covariances were
    # replaced by it.
    trans_mat = numpyro.deterministic("trans_mat", trans_mat)
    trans_noise_cov = numpyro.deterministic("trans_noise_cov", trans_noise_cov)
    emiss_mat = numpyro.deterministic("emiss_mat", emiss_mat)
    emiss_noise_cov = numpyro.deterministic("emiss_noise_cov", emiss_noise_cov)

    # one initial state per trajectory: (batch, state_dim)
    prior_dist = dist.MultivariateNormal(mu0, Sigma0)
    z0 = numpyro.sample("z0", prior_dist, sample_shape=(batch,))

    # `M @ v` applied to every row of a (batch, dim) array
    fn_vec_dot = jax.vmap(jnp.dot, in_axes=(None, 0))

    def body(z_prev, x_prev):
        with numpyro.plate("batches", batch, dim=-1):
            # transition
            z_mean = fn_vec_dot(trans_mat, z_prev)
            z = numpyro.sample("z", dist.MultivariateNormal(z_mean, trans_noise_cov))

            # emission (observed when x_prev is not None)
            x_mean = fn_vec_dot(emiss_mat, z)
            x = numpyro.sample(
                "x", dist.MultivariateNormal(x_mean, emiss_noise_cov), obs=x_prev
            )
        return z, (z, x)

    # loop through time
    _, (z, x) = scan(body, z0, x_obs, length=time_steps)

    return z, x

In [None]:
# # Inference
# kernel = infer.NUTS(simulated_kalman_filter)
# mcmc = infer.MCMC(kernel, num_warmup=100, num_samples=100)
# mcmc.run(rng_key_infer,
#     trans_mat=trans_mat, trans_noise_cov=trans_noise_cov,
#     emiss_mat=emiss_mat, emiss_noise_cov=emiss_noise_cov,
#     mu0=mu0, Sigma0=Sigma0,
#     time_steps=time_steps,
#     x_obs=x_s,
#     batch=0)
# posterior_samples = mcmc.get_samples()

#### Prior Samples (Propagate)

In [None]:
# Propagate `batch` unconditioned trajectories through the live simulator
# definition; with the plate-based variant the outputs are time-major.
time_steps = 80
batch = 10

with numpyro.handlers.seed(rng_seed=123):
    z_samples_prior, x_samples_prior = simulated_kalman_filter(
        trans_mat=trans_mat,
        trans_noise_cov=trans_noise_cov,
        emiss_mat=emiss_mat,
        emiss_noise_cov=emiss_noise_cov,
        mu0=mu0,
        Sigma0=Sigma0,
        time_steps=time_steps,
        x_obs=None,
        batch=batch,
    )
# np.testing.assert_array_almost_equal(x_s, x_s_)
# np.testing.assert_array_almost_equal(z_s, z_s_)
z_samples_prior.shape, x_samples_prior.shape

In [None]:
fig, ax = plt.subplots()

# presumably (time, batch, dim): `[..., i_sample, k]` selects one trajectory
n_time_steps, n_samples, n_dims = x_samples_prior.shape

for i_sample in range(n_samples):
    ax.plot(
        z_samples_prior[..., i_sample, 0],
        z_samples_prior[..., i_sample, 1],
        color="black",
        label="True State",
    )
    # ax.plot(x_s[..., 0], x_s[..., 1], color="tab:red", linestyle="--", label="Noisy Latent")
    ax.scatter(
        x_samples_prior[..., i_sample, 0],
        x_samples_prior[..., i_sample, 1],
        label="Measurements",
        color="red",
        alpha=0.5,
    )
    break  # NOTE: only the first trajectory is drawn
ax.set(xlabel="x-position", ylabel="y-position")
plt.legend()
plt.show()

#### Prior Samples (Predictive)

In [None]:
# Prior predictive draws of the simulator via `infer.Predictive`.
# NOTE(review): `num_samples` is first assigned in a later cell and
# `rng_key_prior` is not defined in this chunk; confirm execution order.
time_steps = 80
batch = 1

# prior
predictive = infer.Predictive(simulated_kalman_filter, num_samples=num_samples)
prior_samples = predictive(
    rng_key_prior,
    trans_mat=trans_mat,
    trans_noise_cov=trans_noise_cov,
    emiss_mat=emiss_mat,
    emiss_noise_cov=emiss_noise_cov,
    mu0=mu0,
    Sigma0=Sigma0,
    time_steps=time_steps,
    x_obs=None,
    batch=batch,
)

In [None]:
prior_samples["x"].shape

In [None]:
fig, ax = plt.subplots()


# `.squeeze()` drops the size-1 batch axis (batch = 1)
z_s_samples = prior_samples["z"].squeeze()
x_s_samples = prior_samples["x"].squeeze()

# NOTE(review): the leading axis is Predictive's draw axis, so the first name
# unpacked here is really the number of draws; none of these names are used.
n_time_steps, n_samples, n_dims = z_s_samples.shape

for i_sample in range(num_samples):
    ax.plot(
        z_s_samples[i_sample, ..., 0],
        z_s_samples[i_sample, ..., 1],
        color="black",
        label="True State",
    )
    # ax.plot(x_s[..., 0], x_s[..., 1], color="tab:red", linestyle="--", label="Noisy Latent")
    ax.scatter(
        x_s_samples[i_sample, ..., 0],
        x_s_samples[i_sample, ..., 1],
        label="Measurements",
        color="red",
        alpha=0.5,
    )
    break  # NOTE: only the first draw is plotted
ax.set(xlabel="x-position", ylabel="y-position")
plt.legend()
plt.show()

### MCMC

In [None]:
from einops import rearrange

# The simulator expects batch-major observations: (time, batch, dim) -> (batch, time, dim)
x_obs = rearrange(x_s, "T B D -> B T D")
x_obs.shape

In [None]:
# Inference: NUTS over the simulator's latent sites, conditioned on `x_obs`
# (batch and time_steps are read from `x_obs`).
kernel = infer.NUTS(simulated_kalman_filter)
mcmc = infer.MCMC(kernel, num_warmup=100, num_samples=100)
mcmc.run(
    rng_key_infer,
    x_obs=x_obs,
    trans_mat=trans_mat,
    trans_noise_cov=trans_noise_cov,
    emiss_mat=emiss_mat,
    emiss_noise_cov=emiss_noise_cov,
    mu0=mu0,
    Sigma0=Sigma0,
)
posterior_samples = mcmc.get_samples()

In [None]:
[posterior_samples.keys()], posterior_samples["z"].shape

In [None]:
time_steps = 80
num_samples = 5  # number of posterior draws plotted by the next cells

# posterior predictive, using the MCMC draws
predictive = infer.Predictive(
    simulated_kalman_filter,
    posterior_samples=posterior_samples,
    return_sites=["x", "z", "z0"],
)
posterior_predictive = predictive(
    rng_key_prior,
    trans_mat=trans_mat,
    trans_noise_cov=trans_noise_cov,
    emiss_mat=emiss_mat,
    emiss_noise_cov=emiss_noise_cov,
    mu0=mu0,
    Sigma0=Sigma0,
    x_obs=x_obs,
)

In [None]:
[posterior_predictive.keys()]

In [None]:
fig, ax = plt.subplots()


z_s_samples = posterior_predictive["z"]
x_s_samples = posterior_predictive["x"]

# Draws are stacked as (draw, time, batch, dim) with the time-major simulator.
# Plot the first `num_samples` draws of batch element 0. The previous loop
# ignored its loop variable and drew the fixed `i_sample = 1` every iteration.
n_draws = z_s_samples.shape[0]

for i_sample in range(min(num_samples, n_draws)):
    ax.plot(
        z_s_samples[i_sample, ..., 0, 0],
        z_s_samples[i_sample, ..., 0, 1],
        color="black",
        label="True State",
    )
    ax.scatter(
        x_s_samples[i_sample, ..., 0, 0],
        x_s_samples[i_sample, ..., 0, 1],
        label="Measurements",
        color="red",
        alpha=0.5,
    )

ax.set(xlabel="x-position", ylabel="y-position")
plt.legend()
plt.show()

In [None]:
fig, ax = plt.subplots()

# `z_s` / `x_s` are time-major, (time, batch, dim) (cf. the
# `rearrange(x_s, "T B D -> B T D")` cell), so trajectory `b` is `[:, b, :]`.
# The previous loop used a fixed `i_sample` for the x-coordinate and the loop
# index for the y-coordinate, pairing values from different trajectories.
for b in range(min(num_samples, z_s.shape[1])):
    ax.plot(z_s[:, b, 0], z_s[:, b, 1], color="black", label="True State")
    ax.scatter(
        x_s[:, b, 0],
        x_s[:, b, 1],
        label="Measurements",
        color="red",
        alpha=0.5,
    )

ax.set(xlabel="x-position", ylabel="y-position")
plt.legend()
plt.show()

## Observed Samples

In [None]:
def simulated_kalman_filter(
    trans_mat: jnp.ndarray,
    trans_noise_cov: jnp.ndarray,
    emiss_mat: jnp.ndarray,
    emiss_noise_cov: jnp.ndarray,
    mu0: jnp.ndarray,
    Sigma0: jnp.ndarray,
    time_steps: int = 100,
    x_obs: jnp.ndarray = None,
    batch: int = 1,
    x_obs_mask: jnp.ndarray = None,
):
    """Plate-based linear-Gaussian state-space model using plain matrix products.

    For each of ``batch`` trajectories: ``z_0 ~ N(mu0, Sigma0)``,
    ``z_t ~ N(trans_mat @ z_{t-1}, trans_noise_cov)`` and
    ``x_t ~ N(emiss_mat @ z_t, emiss_noise_cov)``. Time is handled by one
    ``scan`` and trajectories by a ``numpyro.plate`` at ``dim=-1``.

    Args:
        trans_mat: transition matrix.
        trans_noise_cov: transition noise covariance.
        emiss_mat: emission matrix.
        emiss_noise_cov: emission noise covariance.
        mu0: mean of the initial state.
        Sigma0: covariance of the initial state.
        time_steps: number of steps (overridden by ``x_obs``).
        x_obs: optional observations of shape ``(batch, time_steps, obs_dim)``;
            they are made time-major internally and condition the ``"x"`` site.
        batch: number of trajectories (overridden by ``x_obs``).
        x_obs_mask: currently unused.

    Returns:
        ``(z, x)`` with leading axes ``(time_steps, batch)``.
    """
    if x_obs is not None:
        batch, time_steps, _ = x_obs.shape
        # scan iterates over the leading axis, so go time-major
        x_obs = jnp.swapaxes(x_obs, 0, 1)

    # Record the model matrices as deterministic sites. Each site must wrap its
    # own argument: previously all four re-recorded `trans_mat`.
    trans_mat = numpyro.deterministic("trans_mat", trans_mat)
    trans_noise_cov = numpyro.deterministic("trans_noise_cov", trans_noise_cov)
    emiss_mat = numpyro.deterministic("emiss_mat", emiss_mat)
    emiss_noise_cov = numpyro.deterministic("emiss_noise_cov", emiss_noise_cov)

    prior_dist = dist.MultivariateNormal(mu0, Sigma0)
    # One initial state per trajectory. A single (state_dim,) z0 made the scan
    # carry change shape to (batch, state_dim) after the first step.
    z0 = numpyro.sample("z0", prior_dist, sample_shape=(batch,))

    def body(z_prev, x_prev):
        with numpyro.plate("batches", batch, dim=-1):
            # transition applied row-wise to the (batch, state_dim) carry;
            # `jnp.dot(trans_mat, z_prev)` would contract over the batch axis
            z = z_prev @ trans_mat.T
            z = numpyro.sample("z", dist.MultivariateNormal(z, trans_noise_cov))

            # emission (observed when x_prev is not None)
            x = z @ emiss_mat.T
            x = numpyro.sample(
                "x", dist.MultivariateNormal(x, emiss_noise_cov), obs=x_prev
            )

        return z, (z, x)

    # loop through time
    _, (z, x) = scan(body, z0, x_obs, length=time_steps)

    return z, x

In [None]:
# Unconditioned draw with a single trajectory. `batch` was set to `None` while
# the call hard-coded `batch=1`; use one consistent value.
time_steps = 80
batch = 1

with numpyro.handlers.seed(rng_seed=123):
    z_s_, x_s_ = simulated_kalman_filter(
        trans_mat=trans_mat,
        trans_noise_cov=trans_noise_cov,
        emiss_mat=emiss_mat,
        emiss_noise_cov=emiss_noise_cov,
        mu0=mu0,
        Sigma0=Sigma0,
        time_steps=time_steps,
        x_obs=None,
        batch=batch,
    )
# np.testing.assert_array_almost_equal(x_s, x_s_)
# np.testing.assert_array_almost_equal(z_s, z_s_)
z_s_.shape, x_s_.shape

In [None]:
# # init prior dist
# mu0 = jnp.array([0., 0., 1., -1.])
# Sigma0 = jnp.eye(4)

# # =================
# # transition model
# # =================
# dt = 0.01
# trans_mat = jnp.eye(4) + dt * jnp.eye(4, k=2)

# a = jnp.array([[dt**3/3, dt**2/2], [dt**2/2, dt]])
# b = jnp.eye(2)
# trans_noise = jnp.kron(a,b)

# # =================
# # emission model
# # =================
# r = 0.5

# emiss_mat = jnp.eye(N=2, M=4)
# emiss_noise = r**2 * jnp.eye(2)

## Plate Notation

In [None]:
def simulated_kalman_filter(
    trans_mat: jnp.ndarray,
    trans_noise_cov: jnp.ndarray,
    emiss_mat: jnp.ndarray,
    emiss_noise_cov: jnp.ndarray,
    mu0: jnp.ndarray,
    Sigma0: jnp.ndarray,
    time_steps: int = 100,
    x_obs: jnp.ndarray = None,
    batch: int = 1,
    x_obs_mask: jnp.ndarray = None,
):
    """Plate-based linear-Gaussian state-space model with a step counter in the carry.

    For each of ``batch`` trajectories: ``z_0 ~ N(mu0, Sigma0)``,
    ``z_t ~ N(trans_mat @ z_{t-1}, trans_noise_cov)`` and
    ``x_t ~ N(emiss_mat @ z_t, emiss_noise_cov)``. The scan carry is
    ``(z, t)``, where ``t`` counts steps.

    Args:
        trans_mat: transition matrix.
        trans_noise_cov: transition noise covariance.
        emiss_mat: emission matrix.
        emiss_noise_cov: emission noise covariance.
        mu0: mean of the initial state.
        Sigma0: covariance of the initial state.
        time_steps: number of steps (overridden by ``x_obs``).
        x_obs: optional observations of shape ``(batch, time_steps, obs_dim)``;
            they are made time-major internally and condition the ``"x"`` site.
        batch: number of trajectories (overridden by ``x_obs``).
        x_obs_mask: currently unused.

    Returns:
        ``(z, x)`` with leading axes ``(time_steps, batch)``.
    """
    if x_obs is not None:
        batch, time_steps, _ = x_obs.shape
        # Go time-major only when observations exist: the former unconditional
        # `jnp.swapaxes(x_obs, 0, 1)` crashed for `x_obs=None`.
        x_obs = jnp.swapaxes(x_obs, 0, 1)

    # Record the model matrices as deterministic sites. Each site must wrap its
    # own argument: previously all four re-recorded `trans_mat`.
    trans_mat = numpyro.deterministic("trans_mat", trans_mat)
    trans_noise_cov = numpyro.deterministic("trans_noise_cov", trans_noise_cov)
    emiss_mat = numpyro.deterministic("emiss_mat", emiss_mat)
    emiss_noise_cov = numpyro.deterministic("emiss_noise_cov", emiss_noise_cov)

    prior_dist = dist.MultivariateNormal(mu0, Sigma0)
    # One initial state per trajectory, so the carry keeps a fixed
    # (batch, state_dim) shape on every step.
    z0 = numpyro.sample("z0", prior_dist, sample_shape=(batch,))

    def body(carry, x_prev):
        z_prev, t = carry
        # The distributions below have batch shape (batch,) with the state on
        # the event axis, so the trajectory plate belongs at dim=-1; dim=-2
        # would broadcast that batch shape against (batch, 1).
        with numpyro.plate("batch", batch, dim=-1):
            # transition, applied row-wise: z_t = A z_{t-1}
            z = z_prev @ trans_mat.T
            z = numpyro.sample("z", dist.MultivariateNormal(z, trans_noise_cov))

            # emission (observed when x_prev is not None)
            x = z @ emiss_mat.T
            x = numpyro.sample(
                "x", dist.MultivariateNormal(x, emiss_noise_cov), obs=x_prev
            )

        return (z, t + 1), (z, x)

    # loop through time
    _, (z, x) = scan(body, (z0, 0), x_obs, length=time_steps)

    return z, x

In [None]:
# Smoke test: unconditioned draw of `batch` trajectories from the
# plate-notation simulator defined above.
time_steps = 80
batch = 5

with numpyro.handlers.seed(rng_seed=123):
    z_s_, x_s_ = simulated_kalman_filter(
        trans_mat=trans_mat,
        trans_noise_cov=trans_noise_cov,
        emiss_mat=emiss_mat,
        emiss_noise_cov=emiss_noise_cov,
        mu0=mu0,
        Sigma0=Sigma0,
        time_steps=time_steps,
        x_obs=None,
        batch=batch,
    )
# np.testing.assert_array_almost_equal(x_s, x_s_)
# np.testing.assert_array_almost_equal(z_s, z_s_)

In [None]:
def simulate_data(num_samples: int = 1, time_steps: int = 10, dt=0.1, r=0.5, seed=123):
    """Simulate a noisy constant-velocity trajectory with explicit Python loops.

    The state is ``(x, y, vx, vy)``. ``trans_mat`` adds ``dt * velocity`` to
    each position, and the transition noise covariance is
    ``kron([[dt^3/3, dt^2/2], [dt^2/2, dt]], I_2)``. Only the positions are
    observed, with noise standard deviation ``r``.

    Args:
        num_samples: number of initial states drawn from ``N([0, 0, 1, -1], I)``.
            NOTE: only the first one is simulated, because the outer loop
            breaks after one trajectory.
        time_steps: number of transitions per trajectory.
        dt: time step.
        r: emission noise standard deviation.
        seed: seed for ``jax.random.PRNGKey``.

    Returns:
        dict with keys
        ``"state_true"``: noise-free one-step predictions ``F z_{t-1}``,
        shape ``(time_steps, 4)``;
        ``"state_noise"``: sampled states, shape ``(time_steps, 4)``;
        ``"meas_noise"``: noisy measurements, shape ``(time_steps, 2)``.
    """
    key = jax.random.PRNGKey(seed)

    # init prior dist
    mu0 = jnp.array([0.0, 0.0, 1.0, -1.0])
    Sigma0 = jnp.eye(4)
    prior_dist = dist.MultivariateNormal(loc=mu0, covariance_matrix=Sigma0)

    # =================
    # transition model
    # =================
    trans_mat = jnp.eye(4) + dt * jnp.eye(4, k=2)
    a = jnp.array([[dt**3 / 3, dt**2 / 2], [dt**2 / 2, dt]])
    b = jnp.eye(2)
    transition_noise = jnp.kron(a, b)

    # =================
    # emission model
    # =================
    obs_mat = jnp.eye(2, M=4)
    obs_noise = r**2 * jnp.eye(2)

    # the noise distributions do not change between steps, so build them once
    trans_noise_dist = dist.MultivariateNormal(covariance_matrix=transition_noise)
    obs_noise_dist = dist.MultivariateNormal(
        loc=jnp.zeros(2), covariance_matrix=obs_noise
    )

    # ==================
    # sample from prior
    # ==================
    key, key_prior = jax.random.split(key, 2)
    z_inits = prior_dist.sample(key=key_prior, sample_shape=(num_samples,))
    assert z_inits.shape == (num_samples, 4)

    # results
    states = {"state_true": [], "state_noise": [], "meas_noise": []}

    # loop over initial states (only the first is used, see `break` below)
    for z_init in tqdm(z_inits):
        z_prev = z_init
        state_true, state_noise, meas_noise = [], [], []
        for t in trange(time_steps):

            key, key_trans, key_obs = jax.random.split(key, 3)

            # noise-free propagation: z = F z
            z_true = trans_mat @ z_prev
            assert z_true.shape == (4,)
            state_true.append(z_true[None, :])

            # transition: z = F z + eps
            noise = trans_noise_dist.sample(key=key_trans, sample_shape=(1,))
            z_prev = z_true.squeeze() + noise.squeeze()
            assert z_prev.shape == (4,)
            state_noise.append(z_prev)

            # emission: x = H z + eps. Uses its own key; the previous code
            # reused `key_trans`, tying measurement noise to transition noise.
            noise = obs_noise_dist.sample(key=key_obs)
            x_prev = obs_mat @ z_prev.squeeze() + noise.squeeze()
            assert x_prev.shape == (2,)
            meas_noise.append(x_prev)

        states["state_true"].append(jnp.vstack(state_true))
        states["state_noise"].append(jnp.vstack(state_noise))
        states["meas_noise"].append(jnp.vstack(meas_noise))

        # stop after the first trajectory; downstream cells plot the result
        # as a single (time_steps, dim) track
        break

    states["state_true"] = jnp.vstack(states["state_true"])
    states["state_noise"] = jnp.vstack(states["state_noise"])
    states["meas_noise"] = jnp.vstack(states["meas_noise"])

    return states

In [None]:
# `simulate_data` keeps only its first trajectory (its outer loop breaks after
# one), so each array below is (time_steps, dim).
states = simulate_data(10, time_steps=50)

In [None]:
states["state_true"].shape, states["state_noise"].shape

In [None]:
fig, ax = plt.subplots()

# noise-free predictions, sampled states and noisy position measurements
ax.plot(
    states["state_true"][..., 0],
    states["state_true"][..., 1],
    color="black",
    label="True State",
)
ax.plot(
    states["state_noise"][..., 0],
    states["state_noise"][..., 1],
    color="tab:red",
    linestyle="--",
    label="Noisy Latent",
)
ax.scatter(
    states["meas_noise"][..., 0],
    states["meas_noise"][..., 1],
    label="Measurements",
    color="red",
    alpha=0.5,
)

ax.set(xlabel="x-position", ylabel="y-position")
plt.legend()
plt.show()

## Numpyro Model

In [None]:
dt = 0.1
r = 0.5


def gaussian_hmm(obs=None, time_steps: int = 15):
    """Constant-velocity linear-Gaussian HMM with known parameters.

    Uses the module-level ``dt`` (time step) and ``r`` (emission noise std).
    The transition noise covariance is
    ``kron([[dt^3/3, dt^2/2], [dt^2/2, dt]], I_2)``, the emission covariance
    is ``r^2 I_2``, and ``z0 ~ N([0, 0, 1, -1], I)``.

    Args:
        obs: optional observations of shape ``(time_steps, obs_dim)``. When
            given, the emission-noise site is observed at the residual
            ``obs[t] - obs_mat @ z_t``. This scores
            ``log N(obs[t]; obs_mat @ z_t, r^2 I)``, and the returned ``x``
            equals ``obs``. Previously the raw observation was used as the noise
            value, which scored the wrong quantity and returned
            ``obs_mat @ z_t + obs[t]``.
        time_steps: number of steps when ``obs`` is ``None``.

    Returns:
        ``(z, x)``: states ``(time_steps, 4)`` and observations ``(time_steps, 2)``.
    """
    # extract shapes from observations
    if obs is not None:
        time_steps, _ = obs.shape

    # =================
    # transition model
    # =================
    trans_mat = jnp.eye(4) + dt * jnp.eye(4, k=2)
    trans_mat = numpyro.deterministic("trans_mat", trans_mat)

    a = jnp.array([[dt**3 / 3, dt**2 / 2], [dt**2 / 2, dt]])
    transition_noise = jnp.kron(a, jnp.eye(2))

    # =================
    # emission model
    # =================
    emission_mat = numpyro.deterministic("obs_mat", jnp.eye(2, M=4))
    emission_noise = r**2 * jnp.eye(2)

    def body(z_prev, x_prev):
        # transition: z_t = A z_{t-1} + noise
        z = trans_mat @ z_prev
        noise_z = numpyro.sample(
            "trans_noise", dist.MultivariateNormal(covariance_matrix=transition_noise)
        )
        z = numpyro.deterministic("z", z + noise_z)

        # emission: x_t = H z_t + noise, with the noise observed at the residual
        # whenever an observation is available
        x = emission_mat @ z
        residual = None if x_prev is None else x_prev - x
        noise_x = numpyro.sample(
            "emiss_noise",
            dist.MultivariateNormal(covariance_matrix=emission_noise),
            obs=residual,
        )
        x = numpyro.deterministic("x", x + noise_x)

        return z, (z, x)

    # prior dist
    mu0 = jnp.array([0.0, 0.0, 1.0, -1.0])
    Sigma0 = jnp.eye(4)
    z0 = numpyro.sample(
        "z0", dist.MultivariateNormal(loc=mu0, covariance_matrix=Sigma0)
    )

    # scan over time (obs, if given, is consumed one row per step)
    _, (z, x) = scan(body, z0, obs, length=time_steps)

    return (z, x)

### Observations

In [None]:
# obs = states["meas_noise"]
# T = obs.shape[0]
# D_obs = obs.shape[1]

In [None]:
# obs.shape

### Generative

In [None]:
# One unconditioned draw. `gaussian_hmm` returns (states, observations), so
# here `x` holds the latent states and `y` the observations.
with numpyro.handlers.seed(rng_seed=314):
    x, y = gaussian_hmm(None, time_steps=50)

In [None]:
t_axes = jnp.arange(x.shape[0])  # time index, reused by later cells

fig, ax = plt.subplots()

ax.plot(x[..., 0], x[..., 1], label="Gen. State", color="black", linestyle="--")
ax.scatter(y[..., 0], y[..., 1], label="Gen. Observation", color="tab:red", alpha=0.5)

ax.set(
    xlabel="x-Position",
    ylabel="y-position",
)

plt.legend()
plt.show()

### Prior

In [None]:
%%time


# Prior prediction
predictive = infer.Predictive(gaussian_hmm, num_samples=100)

# without x (no observations, so every site is sampled)
prior_predictive = predictive(rng_key_prior, time_steps=50)

# extract samples, stacked as (draw, time, dim)
x_samples = prior_predictive["x"]
z_samples = prior_predictive["z"]

In [None]:
fig, ax = plt.subplots()

# position tracks (state dims 0 and 1) of the first 10 draws; `.T` puts time
# on the leading axis so each column is drawn as one line
ax.plot(
    z_samples[:10, ..., 0].T,
    z_samples[:10, ..., 1].T,
    label="Gen. State",
    color="black",
    linestyle="-",
)
ax.set(title="Generated States")
plt.show()

In [None]:
fig, ax = plt.subplots()

# observation tracks of the first 10 draws; `.T` puts time on the leading
# axis so each column is one line (labelled as observations, not states)
ax.plot(
    x_samples[:10, ..., 0].T,
    x_samples[:10, ..., 1].T,
    label="Gen. Obs",
    color="black",
    linestyle="-",
)
ax.set(title="Generated Obs")
plt.show()

In [None]:
n_plots = 3
rand_int = 1  # index of the first draw shown

# `n_plots` now drives the number of panels (it was unused); squeeze=False
# keeps `axes` 2-D even when n_plots == 1
fig, axes = plt.subplots(nrows=n_plots, figsize=(7, 10), squeeze=False)

for offset, iax in enumerate(axes[:, 0]):

    i = offset + rand_int

    iax.plot(
        z_samples[i, ..., 0],
        z_samples[i, ..., 1],
        label="Gen. State",
        color="black",
        linestyle="-",
    )
    iax.scatter(
        x_samples[i, ..., 0],
        x_samples[i, ..., 1],
        label="Gen. Obs",
        color="red",
        linestyle="--",
    )
    # label every panel, not only the last one
    iax.set(
        xlabel="x-Position",
        ylabel="y-position",
    )

plt.legend()
plt.show()

## Unknown Model

In [None]:
def gaussian_hmm_unknown(obs=None, time_steps: int = 15):
    """Constant-velocity HMM with diagonal (independent ``Normal``) noise.

    Uses the module-level ``dt`` and ``r``. The transition noise has scale 1
    per state component, the emission noise has scale ``r**2`` per observed
    component, and ``z0 ~ Normal(0, 1)`` per component.

    Args:
        obs: optional observations, expected shape ``(time_steps, 2)``. When
            given, the emission-noise site is observed at the residual
            ``obs[t] - obs_mat @ z_t``, so the data enter the likelihood and the
            returned ``x`` equals ``obs``. Previously ``obs`` only set
            ``time_steps`` and never reached any sample site.
        time_steps: number of steps when ``obs`` is ``None``.

    Returns:
        ``(z, x)``: states ``(time_steps, 4)`` and observations ``(time_steps, 2)``.
    """
    # extract shapes from observations
    if obs is not None:
        time_steps, *_ = obs.shape

    # =================
    # transition model
    # =================
    trans_mat = jnp.eye(4) + dt * jnp.eye(4, k=2)
    trans_mat = numpyro.deterministic("trans_mat", trans_mat)

    transition_noise = jnp.ones(4)

    # =================
    # emission model
    # =================
    emission_mat = numpyro.deterministic("obs_mat", jnp.eye(2, M=4))

    emission_noise = r**2 * jnp.ones(2)

    def body(z_prev, x_prev):
        # transition: z_t = A z_{t-1} + noise
        z = trans_mat @ z_prev
        noise_z = numpyro.sample("trans_noise", dist.Normal(scale=transition_noise))
        z = numpyro.deterministic("z", z + noise_z)

        # emission from the *current* state, as in `gaussian_hmm` (this used
        # `z_prev` before); the noise is observed at the residual when data exist
        x = emission_mat @ z
        residual = None if x_prev is None else x_prev - x
        noise_x = numpyro.sample(
            "emiss_noise", dist.Normal(scale=emission_noise), obs=residual
        )
        x = numpyro.deterministic("x", x + noise_x)

        return z, (z, x)

    # prior dist (independent standard normals per component)
    mu0 = jnp.array([0.0, 0.0, 0.0, 0.0])
    z0_scale = jnp.ones(4)
    z0 = numpyro.sample("z0", dist.Normal(loc=mu0, scale=z0_scale))

    # scan over time (obs, if given, is consumed one row per step)
    _, (z, x) = scan(body, z0, obs, length=time_steps)

    return (z, x)

### Prior

In [None]:
# `x_samples` stacks the 100 prior-predictive draws as (draw, time, obs_dim),
# but the model expects a single (time, obs_dim) sequence, so pass one draw.
# Passing the whole stack made `time_steps` equal to the number of draws.
with numpyro.handlers.seed(rng_seed=314):
    x, y = gaussian_hmm_unknown(x_samples[0], time_steps=50)

In [None]:
x.shape,

In [None]:
# Prior prediction
predictive = infer.Predictive(
    gaussian_hmm_unknown, num_samples=100, return_sites=["x", "z"]
)

# without x (no observations, so every site is sampled)
prior_predictive = predictive(rng_key_prior, time_steps=50)

# extract samples, stacked as (draw, time, dim)
x_samples_prior = prior_predictive["x"]
z_samples_prior = prior_predictive["z"]

In [None]:
fig, ax = plt.subplots()

# position tracks of the first 10 draws; `.T` puts time on the leading axis
ax.plot(
    z_samples_prior[:10, ..., 0].T,
    z_samples_prior[:10, ..., 1].T,
    label="Gen. State",
    color="black",
    linestyle="-",
)
ax.set(title="Generated States")
plt.show()

In [None]:
fig, ax = plt.subplots()

# observation tracks of the first 10 draws; `.T` puts time on the leading
# axis so each column is one line (labelled as observations, not states)
ax.plot(
    x_samples_prior[:10, ..., 0].T,
    x_samples_prior[:10, ..., 1].T,
    label="Gen. Obs",
    color="black",
    linestyle="-",
)
ax.set(title="Generated Observations")
plt.show()

In [None]:
n_plots = 3
rand_int = 1  # index of the first draw shown

# `n_plots` now drives the number of panels (it was unused); squeeze=False
# keeps `axes` 2-D even when n_plots == 1
fig, axes = plt.subplots(nrows=n_plots, figsize=(7, 10), squeeze=False)

for offset, iax in enumerate(axes[:, 0]):

    i = offset + rand_int

    iax.plot(
        z_samples_prior[i, ..., 0],
        z_samples_prior[i, ..., 1],
        label="Gen. State",
        color="black",
        linestyle="-",
    )
    iax.scatter(
        x_samples_prior[i, ..., 0],
        x_samples_prior[i, ..., 1],
        label="Gen. Obs",
        color="red",
        linestyle="--",
    )
    # label every panel, not only the last one
    iax.set(
        xlabel="x-Position",
        ylabel="y-position",
    )

plt.legend()
plt.show()

## Training

In [None]:
from numpyro import diagnostics, infer, optim

In [None]:
# Inference
# Condition on a single observed sequence. `x_samples_proi` was a NameError,
# and the prior-predictive stack is (draw, time, obs_dim) while the model
# expects one (time, obs_dim) sequence.
# NOTE(review): confirm `x_samples_prior[0]` (rather than e.g. `x_samples[0]`
# from the known model) is the intended training data.
kernel = infer.NUTS(gaussian_hmm_unknown)
mcmc = infer.MCMC(kernel, num_warmup=200, num_samples=100)
mcmc.run(rng_key_infer, x_samples_prior[0])
posterior_samples = mcmc.get_samples()

### Posterior

In [None]:
%%time
# Posterior prediction: re-run the model with the MCMC draws substituted for
# the latent sites; no `obs` is passed here.
predictive = infer.Predictive(
    gaussian_hmm_unknown,
    posterior_samples=posterior_samples,
    return_sites=["x", "z"],
    num_samples=100,
)
posterior_predictive = predictive(rng_key_posterior, time_steps=50)

In [None]:
# extract samples, stacked as (draw, time, dim)
x_samples_learned = posterior_predictive["x"]
z_samples_learned = posterior_predictive["z"]

In [None]:
fig, ax = plt.subplots()

# NOTE(review): state dims 2 and 3 are the velocities (the state is
# (z1, z2, dz1, dz2)), whereas the other state plots use positions 0 and 1;
# confirm this is intended.
ax.plot(
    z_samples_learned[:10, ..., 2].T,
    z_samples_learned[:10, ..., 3].T,
    label="Gen. State",
    color="black",
    linestyle="-",
)
ax.set(title="Generated States")
plt.show()

In [None]:
fig, ax = plt.subplots()

# observation tracks of the first 10 posterior-predictive draws; `.T` puts
# time on the leading axis (labelled as observations, not states)
ax.plot(
    x_samples_learned[:10, ..., 0].T,
    x_samples_learned[:10, ..., 1].T,
    label="Gen. Obs",
    color="black",
    linestyle="-",
)
ax.set(title="Generated Observations")
plt.show()

In [None]:
n_plots = 3
rand_int = 1  # index of the first draw shown

# `n_plots` now drives the number of panels (it was unused); squeeze=False
# keeps `axes` 2-D even when n_plots == 1
fig, axes = plt.subplots(nrows=n_plots, figsize=(7, 10), squeeze=False)

for offset, iax in enumerate(axes[:, 0]):

    i = offset + rand_int

    iax.plot(
        z_samples_learned[i, ..., 0],
        z_samples_learned[i, ..., 1],
        label="Gen. State",
        color="black",
        linestyle="-",
    )
    iax.scatter(
        x_samples_learned[i, ..., 0],
        x_samples_learned[i, ..., 1],
        label="Gen. Obs",
        color="red",
        linestyle="--",
    )
    # label every panel, not only the last one
    iax.set(
        xlabel="x-Position",
        ylabel="y-position",
    )

plt.legend()
plt.show()

In [None]:
# NOTE(review): `x_pred` is first assigned in a later cell (SVI section), so
# this cell relies on notebook execution order.
x_lb, x_mu, x_ub = jnp.quantile(x_pred, jnp.array([0.05, 0.5, 0.95]), axis=0)

In [None]:
# 5/50/95% bands over time against the observations.
# NOTE(review): `obs` only appears in commented-out code in this chunk, and the
# bands have a trailing obs_dim axis; confirm shapes before plotting.
fig, ax = plt.subplots()

ax.plot(t_axes, x_lb, label="State")
ax.plot(t_axes, x_mu, label="State")
ax.plot(t_axes, x_ub, label="State")
ax.scatter(t_axes, obs, label="Observations", color="Red")

ax.set(
    xlabel="Time",
    ylabel="Signal",
)

plt.legend()

In [None]:
%%time
# NOTE(review): `T` is only defined in commented-out code in this chunk, and
# `posterior_samples` here comes from the `gaussian_hmm_unknown` MCMC run.
# Posterior prediction
predictive = infer.Predictive(gaussian_hmm, posterior_samples=posterior_samples)
posterior_predictive = predictive(rng_key_posterior, time_steps=T)

In [None]:
posterior_predictive

In [None]:
# `gaussian_hmm` returns (states, observations)
with numpyro.handlers.seed(rng_seed=rng_key_prior):
    x, temp = gaussian_hmm(x_pred[0])

In [None]:
fig, ax = plt.subplots()

ax.plot(t_axes, temp, label="State")
ax.scatter(t_axes, obs, label="Observations", color="Red")

ax.set(
    xlabel="Time",
    ylabel="Signal",
)

plt.legend()

## Training - SVI

In [None]:
from typing import Optional
from numpyro.distributions import constraints


def guide(x: Optional[jnp.ndarray] = None, time_steps: int = 30) -> None:
    """Hand-written guide for a latent site ``"z"``.

    Learns a multiplier ``phi`` and a positive std ``sigma`` and proposes
    ``z ~ Normal(x * phi, sigma)``.

    NOTE: ``x`` must be provided; with ``x=None`` the product ``x * phi``
    fails. ``time_steps`` mirrors the model's signature but is not used.
    """
    if x is not None:
        # kept for parity with the model's shape handling; not used below
        time_steps = x.shape[0]

    scale_phi = numpyro.param("phi", jnp.ones(1))
    std = numpyro.param("sigma", 0.05 * jnp.ones(1), constraint=constraints.positive)
    loc = x * scale_phi
    numpyro.sample("z", dist.Normal(loc, std))

In [None]:
from numpyro.infer.autoguide import AutoNormal, AutoDelta

In [None]:
%%time
# Optimizer
lr = 1e-3
adam = optim.Adam(lr)


# AutoDelta guide for `gaussian_hmm`; this rebinds `guide`, shadowing the
# hand-written guide function defined earlier.
guide = AutoDelta(gaussian_hmm)

# number of SVI steps passed to `svi.run`
n_epochs = 100

# Inference
svi = infer.SVI(gaussian_hmm, guide, adam, infer.Trace_ELBO())
svi_result = svi.run(rng_key_infer, n_epochs, x)

In [None]:
# Inspect the parameters learned by SVI.
svi_result.params

In [None]:
fig, ax = plt.subplots()

# Loss value recorded at each SVI step, plotted against the step index.
ax.plot(range(len(svi_result.losses)), svi_result.losses)

plt.show()

In [None]:
# Posterior prediction
predictive = infer.Predictive(gaussian_hmm, params=svi_result.params, num_samples=10)
posterior_predictive = predictive(rng_key_posterior, time_steps=T)

#### Results - X

In [None]:
# Take the predicted "x" site and reduce it to pointwise 5% / 50% / 95%
# quantiles across samples (axis 0).
x_pred = posterior_predictive["x"]
x_lb, x_mu, x_ub = jnp.quantile(
    x_pred, q=jnp.array([0.05, 0.5, 0.95]), axis=0
)

In [None]:
# Plot the 5% / median / 95% predictive quantiles against the observations.
fig, ax = plt.subplots()

# distinct labels so the legend tells the three curves apart
ax.plot(t_axes, x_lb, label="5% quantile", linestyle="--")
ax.plot(t_axes, x_mu, label="Median")
ax.plot(t_axes, x_ub, label="95% quantile", linestyle="--")
ax.scatter(t_axes, obs, label="Observations", color="Red")

ax.set(
    xlabel="Time",
    ylabel="Signal",
)

plt.legend()

In [None]:
%%time
# optimizers
lr = 1e-2
adam = optim.Adam(lr)

# def guide(x=None, seq_len: int=0, batch:int=0, x_dim: int=1, future_steps=0, z_dim: int=2, ):
#     return None


n_epochs = 50_000

# Inference
svi = infer.SVI(kf_model, guide, adam, infer.Trace_ELBO())
svi_result = svi.run(rng_key_infer, n_epochs, x)

In [None]:
# Loss value recorded at each step of the kf_model SVI run.
plt.plot(svi_result.losses)

In [None]:
# Inspect the parameters learned for kf_model.
svi_result.params

In [None]:
# Posterior prediction
predictive = infer.Predictive(kf_model, params=svi_result.params, num_samples=20)
posterior_predictive = predictive(rng_key_posterior, None, *x.shape, future_steps=10)

In [None]:
# Compare predicted vs. observed shapes.
# NOTE(review): `x_pred` is only rebound to the kf_model predictions in the
# next cell, so on a top-to-bottom run this shows the earlier `x_pred`.
x_pred.shape, x.shape

In [None]:
x_pred = posterior_predictive["x"]

# Pointwise 5% / 50% / 95% bands across samples (axis 0).
# `jnp.percentile` expects q in [0, 100]; the previous [0.1, 0.5, 0.95]
# selected the 0.1st/0.5th/0.95th percentiles, so all three curves sat at the
# bottom of the distribution. The levels now match the 0.05/0.5/0.95
# quantiles used elsewhere in this notebook.
lb, pred, ub = jnp.percentile(x_pred, jnp.array([5.0, 50.0, 95.0]), axis=0)

In [None]:
fig, ax = plt.subplots()

# Observed series and the median prediction, both sliced at index 0 of axis 1.
ax.plot(x[:, 0, :], label="samples")
ax.plot(pred[:, 0, :], label="Preds")

ax.legend()
plt.show()

In [None]:
# Shape of the kf_model predictive samples for site "x".
x_pred.shape

In [None]:
# 2x2 diagonal matrix of ones (the identity); the bare `d` displays it.
d = jnp.diag(jnp.ones(2))
d

In [None]:
# initialize Kalman Filter: 2-D state, 1-D observations
state_dim = 2
observation_dim = 1

# init transition model
transition_matrix = jnp.array([[1.0, 1.0], [0.0, 1.0]])  # state transition matrix
transition_noise = 1e-4 * jnp.eye(state_dim)  # state uncertainty

# check sizes
assert transition_matrix.shape == (state_dim, state_dim)
assert transition_noise.shape == (state_dim, state_dim)

# init emission model: only the first state component is observed
observation_matrix = jnp.array([[1.0, 0.0]])  # emission matrix
# sized by `observation_dim` (previously a hard-coded `jnp.eye((1))`)
observation_noise = 50.0 * jnp.eye(observation_dim)  # emission uncertainty

assert observation_matrix.shape == (observation_dim, state_dim)
assert observation_noise.shape == (observation_dim, observation_dim)


# Prior parameter distribution
mu0 = jnp.array([2.0, 0.0])
Sigma0 = jnp.eye(state_dim) * 1.0

assert mu0.shape == (state_dim,)
assert Sigma0.shape == (state_dim, state_dim)

In [None]:
# Initialize the zero-mean transition / observation noise distributions and the
# initial-state prior, then assemble the linear-Gaussian state-space model.
transition_noise_dist = lgssm.MultivariateNormal(jnp.zeros(state_dim), transition_noise)

observation_noise_dist = lgssm.MultivariateNormal(
    jnp.zeros(observation_dim), observation_noise
)

initial_state_prior_dist = lgssm.MultivariateNormal(mu0, Sigma0)

# NOTE: this rebinds the global `kf_model`, which earlier cells passed to
# `infer.SVI` / `infer.Predictive` as a model function.
kf_model = lgssm.LinearGaussianStateSpaceModel(
    transition_matrix,
    transition_noise_dist,
    observation_matrix,
    observation_noise_dist,
    initial_state_prior_dist,
)

In [None]:
%%time

# Run the forward filter on the first observation sequence.
log_probs, mus, sigmas, mus_cond, sigmas_cond = kf_model.forward_filter(obs_samples[0])


mus.shape, sigmas.shape

In [None]:
fig, ax = plt.subplots()

# First observation sequence with the first component of `mus` overlaid.
ax.scatter(ts[0], obs_samples[0], color="red", label="Observations")
ax.plot(ts[0], mus[:, 0], linestyle="--", label="State (Filtered)")
ax.set(xlabel="Time", ylabel="Signal")

ax.legend()
plt.show()

In [None]:
# Same plot as for `mus`, but overlaying the first component of `mus_cond`.
fig, ax = plt.subplots()


ax.scatter(ts[0], obs_samples[0], color="red", label="Observations")
# NOTE(review): this series is `mus_cond`, yet it reuses the "State (Filtered)"
# label of the `mus` plot — one of the two labels is probably wrong; confirm
# which output of `forward_filter` is the filtered mean.
ax.plot(ts[0], mus_cond[:, 0], label="State (Filtered)", linestyle="--")

ax.set(xlabel="Time", ylabel="Signal")

plt.legend()
plt.show()

In [None]:
# NOTE(review): `state` is only bound by the sampling loop in the next cell, so
# on a top-to-bottom run this line relies on state from a previous execution.
state.shape, obs.shape

In [None]:
# Roll the model forward one key at a time, collecting states and observations.
all_states, all_obs = [], []

state = state_init

for i_step in tqdm(sample_keys):

    # kalman step: `sample_step` appears to return (carry, (state_t, obs_t)).
    # The outputs are unpacked under their own names so that the carry, not the
    # output, feeds the next step (the old `state, (state, obs)` pattern
    # overwrote the carry). This also avoids clobbering the global `obs` that
    # the plotting cells use for observations.
    state, (state_t, obs_t) = kf_model.sample_step(state, i_step)

    # append
    all_states.append(state_t)
    all_obs.append(obs_t)


all_states = jnp.vstack(all_states)
all_obs = jnp.vstack(all_obs)

In [None]:
fig, ax = plt.subplots()

# First state component from the sampling loop vs. the sampled observations.
ax.plot(time_steps, all_states[:, 0], color="green", label="True State")
ax.scatter(time_steps, all_obs, color="red", alpha=0.4, label="Observations")

ax.legend()
plt.show()

In [None]:
num_time_steps = 100

# Draw a batch of 10 trajectories, each `num_time_steps` long.
all_states, all_obs = kf_model.sample(
    seed=123, sample_shape=10, num_timesteps=num_time_steps
)
# was `states.shape`: `states` is not bound here, the sampled array is `all_states`
all_states.shape, all_obs.shape

In [None]:
fig, ax = plt.subplots()

# Trajectory at batch index 3: first state component and its observations.
ax.plot(time_steps, all_states[3, :, 0], label="True State")
ax.scatter(time_steps, all_obs[3, :, 0], label="Observations")

ax.legend()
plt.show()

In [None]:
# Shape of the prior mean.
mu0.shape

In [None]:
# One sampling step starting from the prior mean.
# NOTE(review): `key` is not bound in this section — confirm it exists.
kf_model.sample_step(mu0, key)

---

## Filtering

1. Run the forward filter on a batch of inputs.
2. Initialize the prior mean and covariance.
3. Loop through the Kalman step, starting from `(mu0, cov0)`.

In [None]:
from jaxkf._src.functional.ops import kalman_step

In [None]:
num_time_steps = 15

states_preds, states_corrs = [], []

for i_t_step in trange(num_time_steps):

    # kalman step
    # TODO: this call was left unfinished (unclosed parenthesis), which made
    # the whole cell a SyntaxError. Fill in the arguments of `kalman_step`
    # (presumably the previous mean/covariance, starting from mu0 / Sigma0) and
    # append the results to `states_preds` / `states_corrs`.
    # state_pred, state_corrected = kalman_step(...)

    pass

### Kalman Filter Step

### Predict Step

## Smoothing

### Posterior Marginals (Alternative)

## Log Probability

## Model

##### Initialize Parameters

In [None]:
# Bundle the model matrices into a KFParams container.
# NOTE(review): `KFParams`, `F`, `R`, `H` and `Q` are not defined in this
# section — confirm they are created in a cell not shown here.
kf_params = KFParams(F=F, R=R, H=H, Q=Q)