In [None]:
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
import numpyro

numpyro.set_platform("cpu")  # backend selection; numpyro documents "cpu", "gpu" and "tpu" as valid values

## Discrete Hidden Markov Model

### Log Prob

In [None]:
def logmatmulexp(x, y):
    """Numerically stable ``log(exp(x) @ exp(y))``, shaped as a scan step.

    Args:
        x: Log-space matrix (or batch of matrices); contracted over its
            last axis.
        y: Log-space matrix (or batch of matrices); contracted over its
            second-to-last axis.

    Returns:
        ``(log_product, None)``. The trailing ``None`` lets the function be
        used directly as the body of ``jax.lax.scan``, which expects a
        ``(carry, output)`` pair.
    """
    x_shift = x.max(-1, keepdims=True)
    y_shift = y.max(-2, keepdims=True)
    # A row of x (or column of y) that is entirely -inf has a -inf maximum;
    # subtracting it would evaluate -inf - (-inf) = nan. Shifting by 0 in
    # that case keeps the result at -inf. Finite inputs are unaffected.
    x_shift = jnp.where(jnp.isfinite(x_shift), x_shift, 0.0)
    y_shift = jnp.where(jnp.isfinite(y_shift), y_shift, 0.0)
    product = jnp.exp(x - x_shift) @ jnp.exp(y - y_shift)
    return jnp.log(product) + x_shift + y_shift, None

### Transition Functions

#### Sequential

In [None]:
@jax.jit
def sequential(x_init, xs):
    """Reduce ``xs`` with a left-to-right scan, then apply ``x_init``.

    ``xs[0]`` seeds the scan carry and ``xs[1:]`` are folded in one at a
    time with ``logmatmulexp``. The resulting matrix is then combined with
    ``x_init`` (promoted to a single row) and the last axis is reduced with
    ``logsumexp``.
    """
    first, remaining = xs[0], xs[1:]
    chain, _ = jax.lax.scan(logmatmulexp, first, remaining)
    init_row = jnp.expand_dims(x_init, -2)
    joint, _ = logmatmulexp(init_row, chain)
    return logsumexp(joint.squeeze(-2), -1)

#### Forward

In [None]:
@jax.jit
def forward(x_init, xs):
    """Scan over ``xs`` carrying a single row that starts as ``x_init``.

    ``x_init`` is promoted to a one-row matrix and used as the initial scan
    carry; each element of ``xs`` is folded in with ``logmatmulexp``. The
    final row is reduced over its last axis with ``logsumexp``.
    """
    init_row = jnp.expand_dims(x_init, -2)
    final_row, _ = jax.lax.scan(logmatmulexp, init_row, xs)
    return logsumexp(final_row.squeeze(-2), -1)

#### Parallel

In [None]:
@jax.jit
def parallel(x_init, xs):
    """Reduce ``xs`` by repeatedly contracting adjacent pairs.

    Each pass pairs up neighbouring matrices along axis ``-3`` and merges
    every pair with ``logmatmulexp``; an unpaired trailing matrix is carried
    over unchanged. Passes repeat until one matrix remains, which is then
    combined with ``x_init`` and reduced with ``logsumexp``.
    """
    leading_shape = xs.shape[:-3]
    n_states = xs.shape[-1]
    # The loop condition depends only on static shapes, so it runs as plain
    # Python while the function is being traced.
    while xs.shape[-3] > 1:
        n_steps = xs.shape[-3]
        n_pairs = n_steps // 2
        paired = xs[..., : 2 * n_pairs, :, :].reshape(
            leading_shape + (n_pairs, 2, n_states, n_states)
        )
        left = paired[..., 0, :, :]
        right = paired[..., 1, :, :]
        merged, _ = logmatmulexp(left, right)
        if n_steps > 2 * n_pairs:
            # Odd length: keep the unpaired last matrix for the next pass.
            leftover = xs[..., -1:, :, :]
            merged = jnp.concatenate((merged, leftover), axis=-3)
        xs = merged
    init_row = jnp.expand_dims(x_init, -2)
    joint, _ = logmatmulexp(init_row, xs.squeeze(-3))
    return logsumexp(joint.squeeze(-2), -1)

### Testing

In [None]:
# Random inputs for comparing the three reductions.
dim = 3
x = jax.random.normal(jax.random.PRNGKey(0), (2000, dim, dim))
x_init = jax.random.normal(jax.random.PRNGKey(1), (dim,))

# Show all three results side by side so they can be compared; previously
# only the last call was displayed and the first two were discarded.
# (These calls also trigger jit compilation of each function.)
sequential(x_init, x), parallel(x_init, x), forward(x_init, x)

In [None]:
%timeit y = sequential(x_init, x).copy()
%timeit y = parallel(x_init, x).copy()
%timeit y = forward(x_init, x).copy()


## Continuous Hidden Markov Model

In [None]:
import pathlib
from typing import Dict, Optional, Tuple

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro import diagnostics, infer
from numpyro.contrib.control_flow import scan

### From Scratch

In [None]:
def kf_model(
    x_obs: Optional[jnp.ndarray] = None,
    seq_len: int = 0,
    batch: int = 0,
    x_dim: int = 1,
    future_steps: int = 0,
    z_dim: int = 3,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Linear-Gaussian state-space model written as a NumPyro ``scan``.

    Args:
        x_obs: Observations in **time-major** layout,
            ``shape = (seq_len, batch, x_dim)``. When given, ``seq_len``,
            ``batch`` and ``x_dim`` are overwritten from its shape. ``None``
            samples the emissions instead of observing them.
        seq_len: Number of scan steps when ``x_obs`` is ``None``.
        batch: Overwritten from ``x_obs`` when available but not otherwise
            read by this function; the latent state is one ``(z_dim,)``
            vector.
        x_dim: Dimension of each emission when ``x_obs`` is ``None``.
        future_steps: Added to ``seq_len`` to form the scan length.
            NOTE(review): ``x_obs`` is passed to ``scan`` as ``xs``; a
            non-zero value together with observations makes ``length``
            differ from the leading axis of ``xs`` -- confirm this
            combination is supported before using it.
        z_dim: Dimension of the latent state.

    Returns:
        ``(z, x)``: the per-step latent states and emissions collected by
        ``scan``.
    """

    if x_obs is not None:
        seq_len, batch, x_dim = x_obs.shape

    # Transition Distributions
    trans = numpyro.sample(
        "trans", dist.Normal(jnp.zeros((z_dim, z_dim)), jnp.ones((z_dim, z_dim)))
    )
    z_std = numpyro.sample("z_std", dist.Gamma(jnp.ones(z_dim), jnp.ones(z_dim)))

    # Emission Distributions
    emit = numpyro.sample(
        "emit", dist.Normal(jnp.zeros((x_dim, z_dim)), jnp.ones((x_dim, z_dim)))
    )
    x_std = numpyro.sample("x_std", dist.Gamma(jnp.ones(x_dim), jnp.ones(x_dim)))

    # Initial latent state
    z0 = jnp.zeros((z_dim,))

    def body(z_prev, x_t):
        # x_t is the observation for the current step (None when sampling).

        # transition
        z = numpyro.sample("z", dist.Normal(jnp.dot(trans, z_prev), z_std))

        # emission
        x = numpyro.sample("x", dist.Normal(jnp.dot(emit, z), x_std), obs=x_t)

        return z, (z, x)

    # Loop through data
    _, (z, x) = scan(body, z0, x_obs, length=seq_len + future_steps)

    return z, x


#     def transition_fn(
#         carry: Tuple[jnp.ndarray], t: jnp.ndarray
#     ) -> Tuple[Tuple[jnp.ndarray], jnp.ndarray]:

#         z_prev, *_ = carry
#         z = numpyro.sample("z", dist.Normal(jnp.matmul(z_prev, trans), z_std))
#         numpyro.sample("x", dist.Normal(jnp.matmul(z, emit), x_std))
#         return (z,), None

#     z_init = jnp.zeros((batch, z_dim))
#     with numpyro.handlers.condition(data={"x": x}):
#         scan(transition_fn, (z_init,), jnp.arange(seq_len + future_steps))

#### Without Observations

In [None]:
# Sizes for drawing one trajectory without passing any observations.
seq_len = 40
x_dim = 2
z_dim = 3
batch = 10

# The seed handler supplies the PRNG key that the numpyro.sample calls need
# when the model is run directly.
with numpyro.handlers.seed(rng_seed=123):
    z_s, x_s = kf_model(seq_len=seq_len, x_dim=x_dim, z_dim=z_dim, batch=batch)

In [None]:
# Index 0 of the last axis of the sampled emissions.
plt.plot(x_s[..., 0])

#### Observations

In [None]:
# NOTE(review): `x` is whatever the notebook last bound to that name. On a
# top-to-bottom run that is the random array from the discrete-HMM
# "Testing" cell, not a dedicated dataset -- confirm this is intended.
with numpyro.handlers.seed(rng_seed=123):
    z_s, x_s = kf_model(x)

print(z_s.shape, x_s.shape)

In [None]:
# Index 0 of the last axis of the emissions returned by the observed run.
plt.plot(x_s[..., 0])

In [None]:
# Split one root key into separate keys for the prior, inference and
# posterior stages.
rng_key = random.PRNGKey(0)
rng_key, rng_key_prior, rng_key_infer, rng_key_posterior = random.split(rng_key, 4)

In [None]:
# prior
# Take the sizes from the current `x` (this rebinds the notebook-level
# seq_len / batch / x_dim).
seq_len, batch, x_dim = x.shape

# x_obs is passed as None, so the "x" site is sampled rather than observed.
predictive = infer.Predictive(kf_model, num_samples=20)
prior_samples = predictive(
    rng_key_prior, None, seq_len=seq_len, batch=batch, x_dim=x_dim, future_steps=0
)

In [None]:
# List the sampled site names. The previous `[prior_samples.keys()]` wrapped
# the keys view in a one-element list instead of listing the keys.
list(prior_samples.keys())

In [None]:
# Shape of the "z" draws returned by Predictive (num_samples=20 above).
prior_samples["z"].shape

In [None]:
# `z` is never bound at notebook level (it is local to kf_model), so the
# old `jnp.stack(z, axis=0)` raised NameError on a fresh run. `z_s` holds
# the per-step latents returned by scan, so its shape can be read directly.
z_s.shape

In [None]:
# `z` only exists inside kf_model; use the notebook-level `z_s` returned by
# the observed run so this cell works after Restart & Run All.
z_s.shape, x.shape

### Sequential

In [None]:
def model(
    x: Optional[jnp.ndarray] = None,
    seq_len: int = 0,
    batch: int = 0,
    x_dim: int = 1,
    future_steps: int = 0,
    z_dim: int = 3,
) -> None:
    """Batched linear-Gaussian state-space model (random-walk Kalman filter).

    The model registers the sample sites ``trans``, ``z_std``, ``emit``,
    ``x_std``, ``z`` and ``x`` and returns nothing; read results from a
    trace, ``Predictive`` or MCMC samples.

    Args:
        x: Observations in **time-major** layout,
            ``shape = (seq_len, batch, x_dim)``. When given, ``seq_len``,
            ``batch`` and ``x_dim`` are overwritten from its shape and the
            ``"x"`` site is conditioned on it.
        seq_len: Number of time steps when ``x`` is ``None``.
        batch: Batch size when ``x`` is ``None``.
        x_dim: Dimension of each emission when ``x`` is ``None``.
        future_steps: Extra scan steps appended after ``seq_len``.
        z_dim: Dimension of the latent state.
    """

    if x is not None:
        seq_len, batch, x_dim = x.shape

    # Transition Distributions
    trans = numpyro.sample(
        "trans", dist.Normal(jnp.zeros((z_dim, z_dim)), jnp.ones((z_dim, z_dim)))
    )
    z_std = numpyro.sample("z_std", dist.Gamma(jnp.ones(z_dim), jnp.ones(z_dim)))

    # Emission Distributions
    emit = numpyro.sample(
        "emit", dist.Normal(jnp.zeros((z_dim, x_dim)), jnp.ones((z_dim, x_dim)))
    )
    x_std = numpyro.sample("x_std", dist.Gamma(jnp.ones(x_dim), jnp.ones(x_dim)))

    def transition_fn(
        carry: Tuple[jnp.ndarray], t: jnp.ndarray
    ) -> Tuple[Tuple[jnp.ndarray], jnp.ndarray]:
        # One step: sample the next latent state, then its emission.
        z_prev, *_ = carry
        z = numpyro.sample("z", dist.Normal(jnp.matmul(z_prev, trans), z_std))
        numpyro.sample("x", dist.Normal(jnp.matmul(z, emit), x_std))
        return (z,), None

    # Initial latent state, one row per batch element. (This used to be
    # computed twice; the duplicate assignment has been removed.)
    z_init = jnp.zeros((batch, z_dim))
    with numpyro.handlers.condition(data={"x": x}):
        scan(transition_fn, (z_init,), jnp.arange(seq_len + future_steps))

In [None]:
def _load_dataset(n_batches: int = 10) -> jnp.ndarray:

    x0 = jnp.concatenate(
        [
            np.random.randn(10, n_batches),
            np.random.randn(10, n_batches) + 1,
            np.random.randn(10, n_batches) + 1.2,
            np.random.randn(10, n_batches) + 2,
        ]
    )

    print(x0.shape)

    x1 = jnp.concatenate(
        [
            np.random.randn(10, n_batches) - 0.2,
            np.random.randn(10, n_batches) - 1,
            np.random.randn(10, n_batches) - 2.7,
            np.random.randn(10, n_batches) - 4.2,
        ]
    )

    x = jnp.concatenate([x0[..., None], x1[..., None]], axis=-1)
    assert isinstance(x, jnp.ndarray)

    return x

In [None]:
# Number of independent series to generate.
n_batches = 2

# Rebinds the notebook-level `x` to the synthetic dataset.
x = _load_dataset(n_batches)
x.shape

In [None]:
# `model` returns None (its values live in sample sites), so `y = model(x)`
# left `y` as None. Record a trace under the seed handler and read the site
# value from it instead.
with numpyro.handlers.seed(rng_seed=123):
    model_trace = numpyro.handlers.trace(model).get_trace(x)

# "x" is conditioned on the data inside `model`; use model_trace["z"] to
# inspect the sampled latent states instead.
y = model_trace["x"]["value"]

In [None]:
# NOTE(review): this needs `y` to be an array; `model` itself returns None.
y.shape

In [None]:
# One panel per observed channel, showing batch element 0 only.
fig, ax = plt.subplots(nrows=2)
for panel, channel in zip(ax, range(2)):
    panel.plot(x[..., 0, channel])
plt.show()

In [None]:
# Rebuild the prior / inference / posterior keys from the same root seed
# (0) as the earlier key cell.
rng_key = random.PRNGKey(0)
rng_key, rng_key_prior, rng_key_infer, rng_key_posterior = random.split(rng_key, 4)

In [None]:
# prior
# The leading None means no observations; *x.shape fills seq_len, batch and
# x_dim positionally, and future_steps=20 extends the scan past seq_len.
predictive = infer.Predictive(model, num_samples=10)
prior_samples = predictive(rng_key_prior, None, *x.shape, future_steps=20)

In [None]:
# (seq_len, batch, x_dim) of the synthetic dataset.
x.shape

In [None]:
# Inference
# Fit `model`, the same function the prior and posterior Predictive cells
# use. The cell previously ran NUTS on `kf_model`, whose "emit" and "z"
# sites have different shapes, so its samples did not match `model`.
# seq_len / batch / x_dim are taken from x.shape inside the model, so the
# stale notebook-level values are no longer passed.
kernel = infer.NUTS(model)
mcmc = infer.MCMC(kernel, num_warmup=100, num_samples=100)
mcmc.run(rng_key_infer, x, future_steps=0)
posterior_samples = mcmc.get_samples()

In [None]:
# Posterior prediction
# No observations are passed (None); sizes come from x.shape and the scan
# runs 20 steps past the observed length.
predictive = infer.Predictive(model, posterior_samples=posterior_samples)
posterior_predictive = predictive(rng_key_posterior, None, *x.shape, future_steps=20)

In [None]:
def _save_results(
    x: jnp.ndarray,
    prior_samples: Dict[str, jnp.ndarray],
    posterior_samples: Dict[str, jnp.ndarray],
    posterior_predictive: Dict[str, jnp.ndarray],
    num_train: int,
) -> None:
    """Save the sample dictionaries and a two-panel prediction plot.

    Writes ``prior_samples.npz``, ``posterior_samples.npz``,
    ``posterior_predictive.npz`` and ``kalman_multi.png`` into the current
    directory.

    Args:
        x: Observed data; only batch element 0 and channels 0 and 1 are
            plotted.
        prior_samples: Prior draws keyed by site name.
        posterior_samples: Posterior draws keyed by site name.
        posterior_predictive: Posterior-predictive draws; must contain
            ``"x"`` with the sample axis first and the time axis second.
        num_train: Number of leading time steps treated as the training
            range; later steps are plotted as forecast.
    """

    root = pathlib.Path("./")
    root.mkdir(exist_ok=True)

    # The first file name used to be misspelled "piror_samples.npz".
    jnp.savez(root / "prior_samples.npz", **prior_samples)
    jnp.savez(root / "posterior_samples.npz", **posterior_samples)
    jnp.savez(root / "posterior_predictive.npz", **posterior_predictive)

    x_pred = posterior_predictive["x"]

    # Training range: mean prediction and HPDI band.
    x_pred_trn = x_pred[:, :num_train]
    x_hpdi_trn = diagnostics.hpdi(x_pred_trn)
    t_train = np.arange(num_train)

    # Forecast range: everything after num_train.
    x_pred_tst = x_pred[:, num_train:]
    x_hpdi_tst = diagnostics.hpdi(x_pred_tst)
    num_test = x_pred_tst.shape[1]
    t_test = np.arange(num_train, num_train + num_test)

    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    n_panels = 2  # one panel per plotted channel
    plt.figure(figsize=(12, 12))

    for channel in range(n_panels):
        plt.subplot(n_panels, 1, channel + 1)
        plt.plot(x[..., 0, channel], label="ground truth", color=colors[0])

        plt.plot(
            t_train,
            x_pred_trn[..., 0, channel].mean(0),
            label="prediction",
            color=colors[1],
        )
        plt.fill_between(
            t_train,
            x_hpdi_trn[0, :, 0, channel],
            x_hpdi_trn[1, :, 0, channel],
            alpha=0.3,
            color=colors[1],
        )

        plt.plot(
            t_test,
            x_pred_tst[..., 0, channel].mean(0),
            label="forecast",
            color=colors[2],
        )
        plt.fill_between(
            t_test,
            x_hpdi_tst[0, :, 0, channel],
            x_hpdi_tst[1, :, 0, channel],
            alpha=0.3,
            color=colors[2],
        )

        plt.legend()

    plt.tight_layout()
    # Save next to the .npz files (root is the current directory).
    plt.savefig(root / "kalman_multi.png")
    plt.close()

In [None]:
# num_train = len(x): steps beyond the observed length are plotted as forecast.
_save_results(x, prior_samples, posterior_samples, posterior_predictive, len(x))