In [None]:
import sys, os
from pyprojroot import here


# Walk up from the current directory to the project root (marked by a
# `.home` file) so project-local packages such as `filterjax` are importable.
root = here(project_files=[".home"])

# Guard the append so re-running this cell does not add duplicate entries.
if str(root) not in sys.path:
    sys.path.append(str(root))

In [None]:
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt

plt.rcParams["axes.spines.right"] = False
plt.rcParams["axes.spines.top"] = False
from jax import random


from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions
from filterjax._src.models.lgssm import StateSpaceModelDiag
import jax

np.random.seed(123)
%load_ext autoreload
%autoreload 2

In [None]:
Nt = 100
t = np.linspace(0, 10, Nt)
# Noise-free ground truth, shape (Nt, 2): column 0 is sin(t), column 1 is cos(t).
x = np.stack([np.sin(t), np.cos(t)]).T
# Noisy observations of the ground truth.
y = x + np.random.normal(0, 0.1, (Nt, 2))
# Fully-observed mask (nothing missing). The previous
# `np.random.randint(low=0, high=1, ...)` could only ever return 0 because
# `high` is exclusive, so build the all-zero mask explicitly. NumPy returns
# `low` for a one-value range without drawing, so the global RNG stream used
# by later cells is unchanged.
mask = np.zeros((Nt, 2), dtype=int)

plt.scatter(t, y[:, 0], label="y1 (noisy)")
plt.scatter(t, y[:, 1], label="y2 (noisy)")
plt.plot(t, np.sin(t), label="sin(t)")
plt.plot(t, np.cos(t), label="cos(t)")
plt.xlabel("t")
plt.legend()
plt.show()

In [None]:
# Add a leading batch axis: (Nt, 2) -> (1, Nt, 2).
# NOTE: all three are overwritten by the masked-data cells that follow.
y_true = x[None, ...]
y_train = y[None, ...]
y_mask = mask[None, ...]
y_train.shape, y_mask.shape

In [None]:
t = np.linspace(0, 10, Nt)
y = np.stack([np.sin(t), np.cos(t)]).T + np.random.normal(0, 0.1, (Nt, 2))

# Randomly hide observations, independently per channel. `np.random.choice`
# samples with replacement here, so each channel loses up to Nt - 20 points
# (duplicate indices collapse). The second channel gets its own draw: the
# previous `idx[::-1]` selected exactly the same rows as `idx`, so both
# channels were masked at identical time steps.
idx = np.random.choice(np.arange(0, t.shape[0]), size=(Nt - 20))
idx_1 = np.random.choice(np.arange(0, t.shape[0]), size=(Nt - 20))
y_masked = y.copy()
y_masked[idx, 0] = np.nan
y_masked[idx_1, 1] = np.nan
# 1.0 where the observation is missing (NaN), 0.0 where it is observed.
mask = np.isnan(y_masked).astype(np.float64)

plt.scatter(t, y_masked[:, 0], label="y1 (observed)")
plt.scatter(t, y_masked[:, 1], label="y2 (observed)")
plt.plot(t, np.sin(t), label="sin(t)")
plt.plot(t, np.cos(t), label="cos(t)")
plt.xlabel("t")
plt.legend()
plt.show()

In [None]:
# Convert to JAX arrays with a leading batch axis, shape (1, Nt, 2).
# NaNs are replaced by 0.0 so the model only sees finite values; which
# entries are missing is carried separately by `y_mask` (1.0 = missing).
y_true = jnp.array(x[None, ...])
y_train = jnp.array(jnp.nan_to_num(y_masked[None, ...]))
y_mask = jnp.array(mask[None, ...])
y_train.shape, y_mask.shape

In [None]:
# Sanity check: the mask should contain only 0.0 (observed) and 1.0 (missing).
y_mask.min(), y_mask.max()

In [None]:
# Re-display the shapes; both should be (1, Nt, 2).
y_train.shape, y_mask.shape

In [None]:
# Dedicated RNG so the random matrices below do not depend on the global seed.
rng = np.random.RandomState(123)

num_timesteps = y_train.shape[1]
state_size = 6
obs_size = 2

# Prior over the initial latent state: zero mean, diagonal Sigma0 of 0.01
# (whether this is a scale or a variance is up to StateSpaceModelDiag — TODO confirm).
mu0 = jnp.zeros(state_size)
Sigma0 = jnp.ones(state_size) * 0.01

# Transition: random (state_size x state_size) matrix scaled by 1/state_size,
# with a diagonal noise level of 0.2.
# transition_matrix = jnp.eye(state_size)
noise_std = 0.2
transition_matrix = jnp.asarray(rng.randn(state_size, state_size) / state_size)
transition_noise = noise_std * jnp.ones(state_size)

# Observation: random (obs_size x state_size) matrix scaled by 1/state_size,
# with a diagonal noise level of 0.1.
noise_std = 0.1
observation_matrix = jnp.asarray(rng.randn(obs_size, state_size) / state_size)
observation_noise = noise_std * jnp.ones(obs_size)

In [None]:
# Expected: (obs_size, state_size) == (2, 6) and (obs_size,) == (2,).
observation_matrix.shape, observation_noise.shape

In [None]:
def build_model(params):
    """Build a ``StateSpaceModelDiag`` directly from a parameter dict.

    No transformation is applied: every entry is passed to the model exactly
    as given, so noise scales and ``Sigma0`` presumably must already be valid
    (positive) values. This definition is replaced later in the notebook by a
    training version that applies ``softplus``.

    Args:
        params: dict with keys ``"prior_mu"``, ``"prior_sigma"``,
            ``"trans"``, ``"obs"``, ``"trans_mat"`` and ``"obs_mat"``.

    Returns:
        The constructed ``StateSpaceModelDiag``.
    """
    # Keyword arguments are evaluated left to right, so the dict keys are read
    # in the same order as before.
    return StateSpaceModelDiag(
        mu0=params["prior_mu"],
        Sigma0=params["prior_sigma"],
        transition_noise=params["trans"],
        observation_noise=params["obs"],
        transition_matrix=params["trans_mat"],
        observation_matrix=params["obs_mat"],
    )

In [None]:
# Initial parameters; the keys must match those read by `build_model`.
init_params = {
    "prior_mu": mu0,
    "prior_sigma": Sigma0,
    "trans_mat": transition_matrix,
    "trans": transition_noise,
    "obs_mat": observation_matrix,
    "obs": observation_noise,
}

# Untrained model, used by the sampling / filtering / smoothing demos below.
model = build_model(init_params)

In [None]:
# Shapes of the model's prior over the initial state (presumably a tfp distribution).
model.prior_dist.event_shape, model.prior_dist.batch_shape

## Draw Samples

In [None]:
%%time

# Draw sample trajectories from the untrained model built from `init_params`.
n_samples = 5
n_time_steps = 100
sample_prior = True
seed = 123  # 31415
key = random.PRNGKey(seed)

# noise_samples = model.transition_noise.sample(sample_shape=n_samples, seed=key)
state_samples, obs_samples = model.sample(
    key, n_time_steps=n_time_steps, n_samples=n_samples, sample_prior=sample_prior
)

print(state_samples.shape, obs_samples.shape)

In [None]:
# One figure per sampled trajectory, showing its sampled observations.
for i in range(n_samples):
    plt.figure()
    # Uncomment to overlay the sampled latent states:
    # plt.plot(state_samples[i, ...], linestyle="-", linewidth=3, marker="")
    plt.plot(obs_samples[i, ...], linestyle="", marker=".", color="red")
    plt.show()

### Forward Filter

In [None]:
# NOTE(review): stale leftover. `build_model` reads the keys "prior_mu",
# "prior_sigma", "trans", "obs", "trans_mat" and "obs_mat", so this dict would
# raise KeyError. The `model` built from `init_params` above is reused instead.
# init_params = {"prior": mu0, "trans": transition_noise, "obs": observation_noise}

# model = build_model(init_params)

In [None]:
%%time

(
    filtered_z_means,
    filtered_z_covs,
    filtered_x_means,
    filtered_x_covs,
    log_probs,
) = model.forward_filter(y_train, y_mask)
assert filtered_z_means.shape == (1, y_train.shape[1], state_size)
assert filtered_z_covs.shape == (1, y_train.shape[1], state_size, state_size)

assert filtered_x_means.shape == (1, y_train.shape[1], 2)
assert filtered_x_covs.shape == (1, y_train.shape[1], 2, 2)

In [None]:
# Optional: marginal NLL from the filter output. The per-step log-probabilities
# are presumably the `log_probs` returned by `model.forward_filter` above
# (the old `sols[2]` variable no longer exists):
# marginal_log_likelihood = jnp.sum(log_probs, axis=-1)
# nll = - jnp.mean(marginal_log_likelihood)
# nll

In [None]:
i = 0  # batch index to plot (y_train has a single batch element)
plt.figure()
plt.plot(y_true[i, ..., 0], linestyle="-", linewidth=3, marker="", label="True (x1)")
plt.plot(y_true[i, ..., 1], linestyle="-", linewidth=3, marker="", label="True (x2)")
plt.plot(
    filtered_x_means[i, ..., 0],
    linestyle="--",
    linewidth=3,
    marker="",
    label="Filtered (x1)",
)
plt.plot(
    filtered_x_means[i, ..., 1],
    linestyle="--",
    linewidth=3,
    marker="",
    label="Filtered (x2)",
)
# Plot `y_masked` rather than `y_train`: `y_train` has every missing value
# replaced by 0.0 (nan_to_num), which would draw fake observations at zero.
# Matplotlib simply skips the NaNs in `y_masked` (no batch axis; valid for i == 0).
plt.plot(y_masked[..., 0], linestyle="", marker=".", color="red", label="Observations")
plt.plot(
    y_masked[..., 1],
    linestyle="",
    marker=".",
    color="red",
)
plt.xlabel("time step")
plt.legend()
plt.show()



## Backwards Smoothing

In [None]:
%%time

# smoothed_z_means, smoothed_z_covs, j = model.backward_smoothing_pass(filtered_z_means, filtered_z_covs)
# assert smoothed_z_means.shape == (1, X_Train.shape[1], 1)
# assert smoothed_z_covs.shape == (1, X_Train.shape[1], 1, 1)

smoothed_x_means, smoothed_x_covs = model.posterior_marginals(y_train, y_mask)

In [None]:
i = 0  # batch index to plot; set here so the cell does not rely on earlier state
plt.figure()
plt.plot(y_true[i, ..., 0], linestyle="-", linewidth=3, marker="", label="True (x1)")
plt.plot(y_true[i, ..., 1], linestyle="-", linewidth=3, marker="", label="True (x2)")
# Both lines are smoothed means (one per observation dimension).
plt.plot(
    smoothed_x_means[i, ..., 0], linestyle="-", linewidth=3, marker="", label="Smoothed (x1)"
)
plt.plot(
    smoothed_x_means[i, ..., 1], linestyle="-", linewidth=3, marker="", label="Smoothed (x2)"
)
# Plot `y_masked` rather than `y_train`: `y_train` has missing values replaced
# by 0.0, which would draw fake observations at zero; NaNs are skipped.
plt.plot(y_masked[..., 0], linestyle="", marker=".", color="red", label="Observations")
plt.plot(
    y_masked[..., 1],
    linestyle="",
    marker=".",
    color="red",
)
plt.xlabel("time step")
plt.legend()
plt.show()

## Training

### Model

In [None]:
def build_model(params):
    """Build a ``StateSpaceModelDiag`` from *unconstrained* parameters.

    This training version replaces the earlier ``build_model``. ``softplus``
    maps the raw ``"trans"``, ``"obs"`` and ``"prior_sigma"`` entries to
    strictly positive values, so an unconstrained optimizer can update them
    freely. ``"prior_mu"`` and both matrices are used as given.

    Args:
        params: dict with keys ``"prior_mu"``, ``"prior_sigma"``,
            ``"trans"``, ``"obs"``, ``"trans_mat"`` and ``"obs_mat"``.

    Returns:
        The constructed ``StateSpaceModelDiag``.
    """
    keys = ("prior_mu", "prior_sigma", "trans", "obs", "trans_mat", "obs_mat")
    (
        prior_mean,
        raw_prior_sigma,
        raw_trans_noise,
        raw_obs_noise,
        trans_mat,
        obs_mat,
    ) = (params[key] for key in keys)

    # Positivity constraints on the noise scales and prior spread.
    trans_noise = jax.nn.softplus(raw_trans_noise)
    obs_noise = jax.nn.softplus(raw_obs_noise)
    prior_sigma = jax.nn.softplus(raw_prior_sigma)

    return StateSpaceModelDiag(
        mu0=prior_mean,
        Sigma0=prior_sigma,
        transition_matrix=trans_mat,
        transition_noise=trans_noise,
        observation_matrix=obs_mat,
        observation_noise=obs_noise,
    )

### Parameters

In [None]:
# Same seed as the earlier cell, so the random matrices are reproducible.
rng = np.random.RandomState(123)

num_timesteps = y_train.shape[1]
state_size = 6
obs_size = 2

# NOTE: the training `build_model` applies softplus to "prior_sigma", "trans"
# and "obs", so those entries are *unconstrained* values here: 1.0 becomes
# softplus(1.0) ≈ 1.313 inside the model.
mu0 = jnp.zeros(state_size)
Sigma0 = jnp.ones(state_size)

# Transition matrix and (raw) transition noise.
# transition_matrix = jnp.eye(state_size)
noise_std = 1.0
transition_matrix = jnp.asarray(rng.randn(state_size, state_size) / state_size)
transition_noise = noise_std * jnp.ones(state_size)

# Observation matrix and (raw) observation noise.
noise_std = 1.0
observation_matrix = jnp.asarray(rng.randn(obs_size, state_size) / state_size)
observation_noise = noise_std * jnp.ones(obs_size)

init_params = {
    "prior_mu": mu0,
    "prior_sigma": Sigma0,
    "trans_mat": transition_matrix,
    "trans": transition_noise,
    "obs_mat": observation_matrix,
    "obs": observation_noise,
}

#### Objective Function

In [None]:
def objective_func(params, obs, mask):
    """Training loss: negative log-likelihood of ``obs`` under the model.

    Args:
        params: parameter dict accepted by ``build_model``.
        obs: observations, passed straight to ``negative_log_likelihood``.
        mask: missing-data mask (or ``None``), passed alongside ``obs``.

    Returns:
        The value of ``negative_log_likelihood``. It must be a scalar, since
        this function is differentiated with ``jax.value_and_grad``.
    """
    return build_model(params).negative_log_likelihood(obs, mask)

In [None]:
# Loss with mask=None; presumably every entry (including the zero-filled
# missing ones) is treated as observed — TODO confirm in StateSpaceModelDiag.
objective_func(init_params, y_train, None)

In [None]:
# All-zero mask: under this notebook's convention (1.0 = missing) every entry
# counts as observed, assuming the library uses the same convention.
objective_func(init_params, y_train, jnp.zeros_like(y_train))

In [None]:
# All-ones mask: under the same assumed convention every entry is missing.
objective_func(init_params, y_train, jnp.ones_like(y_train))

In [None]:
# Loss with the actual missing-data mask used for training.
objective_func(init_params, y_train, y_mask)

#### L-BFGS

In [None]:
import jaxopt

# Fit all parameters jointly with SciPy's L-BFGS-B through jaxopt's wrapper;
# the `obs` and `mask` keyword arguments are forwarded to `objective_func`.
solver = jaxopt.ScipyMinimize(fun=objective_func, method="L-BFGS-B")
# Alternative: pure-JAX L-BFGS.
# solver = jaxopt.LBFGS(fun=objective_func, maxiter=10_000)
soln = solver.run(init_params, obs=y_train, mask=y_mask)

In [None]:
# Final objective value (negative log-likelihood) reached by L-BFGS.
soln.state.fun_val

In [None]:
# Fitted raw parameters (softplus is applied to the noise/prior entries in build_model).
soln.params

In [None]:
# NOTE: this L-BFGS model is replaced by the Adam-trained model in the
# "Gradient Descent" section before the Results cells use `model`.
model = build_model(soln.params)

### Gradient Descent

In [None]:
import optax

In [None]:
def make_adam_optimizer(
    learning_rate,
    lr_schedule=None,
    b1=0.9,
    b2=0.999,
    eps=1e-8,
):
    """Make a gradient-clipped Adam optimizer that *minimizes* a loss.

    The resulting parameter update is
    ``-learning_rate * lr_schedule(step) * adam_direction(clipped_grads)``.

    Args:
        learning_rate: Constant step-size multiplier.
        lr_schedule: Optional ``step -> scale`` schedule. It is *multiplied*
            by ``learning_rate``, so pass a dimensionless schedule (e.g. one
            starting at 1.0); a schedule that already starts at the learning
            rate would scale it twice. ``None`` gives a constant learning rate.
        b1: Adam exponential decay rate for the first moment.
        b2: Adam exponential decay rate for the second moment.
        eps: Adam numerical-stability constant.

    Returns:
        An optax gradient transformation built with ``optax.chain``.
    """
    transforms = [
        # Clip the gradient by its global norm.
        optax.clip_by_global_norm(1.0),
        optax.scale_by_adam(b1=b1, b2=b2, eps=eps),
    ]
    if lr_schedule is not None:
        transforms.append(optax.scale_by_schedule(lr_schedule))
    # Negative scale = gradient *descent*: the loss (here the negative
    # log-likelihood) is minimized.
    transforms.append(optax.scale(-learning_rate))
    return optax.chain(*transforms)


def make_cosine_lr_schedule(init_lr, total_steps):
    """Cosine-annealing learning-rate schedule.

    The returned ``schedule(step)`` decays from ``init_lr`` at step 0 to 0 at
    ``step == total_steps`` along half a cosine period, and stays at 0 after.

    Args:
        init_lr: Learning rate at step 0.
        total_steps: Number of steps over which to anneal; must be positive.

    Returns:
        A function mapping a (possibly traced) step count to a learning rate.

    Raises:
        ValueError: if ``total_steps`` is not positive.
    """
    if total_steps <= 0:
        raise ValueError(f"total_steps must be positive, got {total_steps}")

    def schedule(step):
        # Clamp progress to [0, 1]: without it cos(t * pi) climbs back up for
        # step > total_steps and the learning rate would grow again.
        t = jnp.clip(step / total_steps, 0.0, 1.0)
        return 0.5 * init_lr * (1 + jnp.cos(t * np.pi))

    return schedule

In [None]:
lr = 0.05
decay_steps = 200
n_iterations = 3_000

# `make_adam_optimizer` multiplies the schedule by `lr`, so the schedule must be
# a dimensionless multiplier. Using `init_value=lr` here scaled the learning
# rate twice (effective lr ≈ lr**2 = 0.0025). With `init_value=1.0` the
# effective learning rate decays from `lr` to `0.95 * lr` over `decay_steps`
# steps and then stays constant.
cosine_decay_schedule = optax.cosine_decay_schedule(
    init_value=1.0, decay_steps=decay_steps, alpha=0.95
)

# Alternative (also a multiplier, so start it at 1.0):
# schedule = make_cosine_lr_schedule(1.0, total_steps=n_iterations)
tx = make_adam_optimizer(lr, cosine_decay_schedule)

In [None]:
# Start gradient descent from the same initial values as L-BFGS, as a new
# top-level dict so that rebinding `params` never touches `init_params`.
params = dict(init_params)

In [None]:
# Initialise the optimizer state for the current `params`, and JIT-compile a
# function returning (loss, grads) w.r.t. its first argument (`params`).
# Alternative plain optimizer: tx = optax.adam(learning_rate=0.0005)
opt_state = tx.init(params)
loss_grad_fn = jax.jit(jax.value_and_grad(objective_func))

In [None]:
from tqdm.notebook import trange

In [None]:
losses = []  # per-iteration training loss (negative log-likelihood)

# Adam loop: evaluate loss and gradients, step the optimizer, update `params`.
# NOTE: the loop variable overwrites the global `i`; the Results filtering plot
# resets `i = 0` before indexing with it.
# n_iterations = 3_001
with trange(n_iterations) as pbar:
    for i in pbar:
        loss_val, grads = loss_grad_fn(params, y_train, y_mask)
        losses.append(loss_val)
        updates, opt_state = tx.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        pbar.set_description(f"Loss: {loss_val:.4f}")

In [None]:
# Training curve: one negative log-likelihood value per Adam iteration.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(xlabel="iteration", ylabel="negative log-likelihood", title="Adam training loss")
plt.show()

In [None]:
# Adam-fitted raw parameters (softplus is applied to the noise/prior entries in build_model).
params

In [None]:
# Rebuild the model from the Adam-fitted parameters; the Results cells below use it.
model = build_model(params)

In [None]:
# Inspect the learned noise levels via private attributes; presumably these hold
# the softplus-transformed values passed in by build_model — TODO confirm.
model._transition_noise, model._observation_noise

## Results

### Filtering

In [None]:
%%time

# Re-run the filter with the trained model; only the observation-space
# filtered means and covariances are kept for plotting.
*_, filtered_x_means, filtered_x_covs, _ = model.forward_filter(y_train, y_mask)

In [None]:
i = 0  # batch index to plot (y_train has a single batch element)
plt.figure()
plt.plot(y_true[i, ..., 0], linestyle="-", linewidth=3, marker="", label="True (x1)")
plt.plot(y_true[i, ..., 1], linestyle="-", linewidth=3, marker="", label="True (x2)")
plt.plot(
    filtered_x_means[i, ..., 0],
    linestyle="--",
    linewidth=3,
    marker="",
    label="Filtered (x1)",
)
plt.plot(
    filtered_x_means[i, ..., 1],
    linestyle="--",
    linewidth=3,
    marker="",
    label="Filtered (x2)",
)
# Raw masked data: NaNs (missing points) are skipped by matplotlib.
plt.plot(y_masked[..., 0], linestyle="", marker=".", color="red", label="Observations")
plt.plot(
    y_masked[..., 1],
    linestyle="",
    marker=".",
    color="red",
)
plt.xlabel("time step")
plt.legend()
plt.show()

### Smoothing

In [None]:
# Smoothed observation-space marginals under the trained model.
smoothed_x_means, smoothed_x_covs = model.posterior_marginals(y_train, y_mask)

In [None]:
i = 0  # batch index to plot; set here so the cell does not rely on earlier state
plt.figure()
plt.plot(
    y_true[i, ..., 0],
    linestyle="-",
    linewidth=3,
    marker="",
    color="black",
    label="True State",
)
plt.plot(
    y_true[i, ..., 1],
    linestyle="-",
    linewidth=3,
    marker="",
    color="black",
)
# Both lines are smoothed means (one per observation dimension).
plt.plot(
    smoothed_x_means[i, ..., 0],
    linestyle="-",
    linewidth=3,
    marker="",
    color="blue",
    label="Smoothed (x1)",
)
plt.plot(
    smoothed_x_means[i, ..., 1],
    linestyle="-",
    linewidth=3,
    marker="",
    color="green",
    label="Smoothed (x2)",
)
# Raw masked data: NaNs (missing points) are skipped by matplotlib.
plt.plot(y_masked[..., 0], linestyle="", marker=".", color="red", label="Observations")
plt.plot(
    y_masked[..., 1],
    linestyle="",
    marker=".",
    color="red",
)
plt.xlabel("time step")
plt.legend()
plt.show()

### Samples

In [None]:
%%time

# Draw sample trajectories from the trained model (built from the Adam-fitted `params`).
n_samples = 10
n_time_steps = 100
sample_prior = True
seed = 123  # 31415
key = random.PRNGKey(seed)

# noise_samples = model.transition_noise.sample(sample_shape=n_samples, seed=key)
state_samples, obs_samples = model.sample(
    key, n_time_steps=n_time_steps, n_samples=n_samples, sample_prior=sample_prior
)

print(state_samples.shape, obs_samples.shape)

In [None]:
# One figure per sampled trajectory from the trained model.
for i in range(n_samples):
    plt.figure()
    # Uncomment to overlay the sampled latent states:
    # plt.plot(state_samples[i, ...], linestyle="-", linewidth=3, marker="")
    plt.plot(obs_samples[i, ...], linestyle="", marker=".", color="red")
    plt.show()