In [None]:
import sys, os
from pyprojroot import here


# spider up the directory tree until a ".local" entry marks the project root
root = here(project_files=[".local"])
# `local` used to be computed with the exact same call, so it is the same path;
# the name is kept in case other cells refer to it
local = root

# make the project importable; skip the append when the path is already
# present so re-running this cell does not pile up duplicate sys.path entries
if str(root) not in sys.path:
    sys.path.append(str(root))

%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn
import torch.distributions as dist
from einops import repeat, rearrange

## Torch Implementation

In [None]:
state_dim = 8
obs_dim = 4
n_batch = 1


def fn(x):
    """Prepend a batch axis of size ``n_batch`` (read at call time) to ``x``."""
    return repeat(x, "... -> B ...", B=n_batch)


# matrices
F = torch.randn(state_dim, state_dim)  # state transition
R_noise = 0.5**2 * torch.ones(obs_dim)  # diagonal of the observation noise
R_cov = torch.diag(R_noise)  # observation noise as a full matrix
H = torch.randn(size=(obs_dim, state_dim))  # observation operator
# process noise acts on the state, so it is (state_dim, state_dim);
# it was previously sized with obs_dim
Q = 0.5**2 * torch.eye(state_dim, state_dim)


# states
x = torch.randn(size=(state_dim,))
x_batch = fn(x)
P = 0.01 * torch.eye(state_dim)
P_batch = fn(P)

# pred obs
mu_pred = torch.randn(size=(obs_dim,))
mu_pred_batch = fn(mu_pred)

Sigma_pred = 10 * torch.eye(obs_dim)
Sigma_pred_batch = fn(Sigma_pred)

# observations
# 0/1 flag per observation dimension (`high` is exclusive)
mask = torch.randint(low=0, high=2, size=(obs_dim,))
mask_batch = torch.randint(
    low=0,
    high=2,
    size=(
        n_batch,
        obs_dim,
    ),
)
obs = torch.randn(size=(obs_dim,))
obs_batch = fn(obs)

In [None]:
R_noise.shape, R_cov.shape

## Batches

In [None]:
# batched noise diagonal: one copy of R_noise per batch element
R_noise_masked = fn(R_noise)
assert tuple(R_noise_masked.shape) == (n_batch, obs_dim)

# one identity matrix per batch element
identities = torch.eye(obs_dim).unsqueeze(0).repeat(n_batch, 1, 1)
assert tuple(identities.shape) == (n_batch, obs_dim, obs_dim)

# one copy of the observation operator per batch element
H_masked = H.unsqueeze(0).repeat(n_batch, 1, 1)
assert tuple(H_masked.shape) == (n_batch, obs_dim, state_dim)

In [None]:
# maskv = mask.unsqueeze(0)
# cov_mask = (maskv.to(torch.bool) + maskv.to(torch.bool).T)
# cov_mask = (maskv + maskv.T)
# maskv.shape, cov_mask.shape

### Masked Parameter Matrices

#### Mask Observation Operator

$$
\mathbf{H}_t = 
\begin{bmatrix}
\mathbf{H}_t^{\text{obs}} \\
\mathbf{H}_t^{\text{missing}}
\end{bmatrix} =
\begin{bmatrix}
\mathbf{H}_t^{\text{obs}} \\
\mathbf{0}
\end{bmatrix}
$$

In [None]:
from torchkf.kf import mask_observation_operator

In [None]:
# H_masked = mask_observation_operator(H, mask)
# assert H_masked.shape == H.shape

# H_masked_batched = mask_observation_operator(fn(H), fn(mask))
# assert H_masked_batched.shape == fn(H).shape

#### Mask Noise Operator

$$
\mathbf{R}_t = 
\begin{bmatrix}
\mathbf{R}_{11t}^{\text{obs}} & \mathbf{R}_{12t}^{\text{cross}}\\
\mathbf{R}_{21t}^{\text{cross}} & \mathbf{R}_{22t}^{\text{missing}}
\end{bmatrix} =
\begin{bmatrix}
\mathbf{R}_{11t}^{\text{obs}} & \mathbf{0}\\
\mathbf{0} & \mathbf{I}
\end{bmatrix}
$$

In [None]:
def update_batched(x, P, obs, H, R, mask=None):
    """Kalman measurement update for a batch of states (Joseph form).

    Parameters
    ----------
    x : torch.Tensor, shape=(Batch, StateDim)
        the predicted state means
    P : torch.Tensor, shape=(Batch, StateDim, StateDim)
        the predicted state covariances
    obs : torch.Tensor, shape=(Batch, ObsDim)
        the observations
    H : torch.Tensor, shape=(ObsDim, StateDim)
        the observation operator
    R : torch.Tensor, shape=(ObsDim,) or (ObsDim, ObsDim)
        the observation noise, either as its diagonal or as a full matrix
    mask : torch.Tensor, optional
        forwarded as-is to the ``torchkf.kf`` masking helpers
        (NOTE(review): layout and 0/1 convention are defined by those
        helpers and are not visible here -- confirm before relying on it)

    Returns
    -------
    x : torch.Tensor, shape=(Batch, StateDim)
        the updated state means
    P : torch.Tensor, shape=(Batch, StateDim, StateDim)
        the updated state covariances
    """
    # create masks
    if mask is not None:
        # imported locally: the notebook-level import of this helper only
        # happens in a later cell, so a fresh kernel would not have it yet
        from torchkf.kf import mask_observation_noise_diag

        H = mask_observation_operator(H, mask)
        R = mask_observation_noise_diag(R, mask)

    # a 1-D R is the noise diagonal; expand it so `+ R` adds to the diagonal
    # only instead of broadcasting over every row
    if R.ndim == 1:
        R = torch.diag(R)

    # innovation covariance S = H P H^T + R
    pred_sigma = H @ P @ H.transpose(-1, -2) + R

    # Kalman gain K = P H^T S^{-1}, obtained by solving S X = H P^T through a
    # Cholesky factorisation (no explicit inverse) and transposing X
    chol = torch.linalg.cholesky(pred_sigma)
    K = torch.cholesky_solve(H @ P.transpose(-1, -2), chol).transpose(-1, -2)

    # innovation and state mean update
    v = obs - (H @ x.unsqueeze(-1)).squeeze(-1)
    x = x + (K @ v.unsqueeze(-1)).squeeze(-1)

    identity = torch.eye(P.shape[-1], device=P.device, dtype=P.dtype)

    # joseph form for numerical stability
    residual = identity - K @ H
    P = residual @ P @ residual.transpose(-1, -2) + K @ R @ K.transpose(-1, -2)

    return x, P

In [None]:
x_batch.shape, P_batch.shape, obs_batch.shape, H.shape, R_noise.shape

In [None]:
# x_new, P_new = update_batched(x_batch, P_batch, obs_batch, H, R_noise, mask=mask, diag=True)

# assert x_new.shape == x_batch.shape
# assert P_new.shape == P_batch.shape

### Masked Likelihood

In [None]:
import math

# variance for which a 1-D Gaussian evaluated at its own mean has log-density 0:
# -0.5 * log(2 * pi * var) == 0  <=>  var == 1 / (2 * pi)
INV2PI = (2 * math.pi) ** -1


def masked_multivariate_likelihood(x, mean, cov, mask=None):
    """Masked Gaussian log-probability for full covariance matrices

    Entries where ``mask == 0`` are treated as missing: they are decoupled
    from the other dimensions and set up so that they contribute exactly 0
    to the log-probability. The result therefore equals the log-probability
    of the remaining entries under their marginal Gaussian.

    Parameters
    ----------
    x : torch.Tensor, shape=(..., Dim)
        the points at which the density is evaluated
    mean : torch.Tensor, shape=(..., Dim)
        the mean of the Gaussian
    cov : torch.Tensor, shape=(..., Dim, Dim)
        the full covariance matrix of the Gaussian
    mask : torch.Tensor, shape=(Dim,) or (..., Dim), optional
        nonzero = keep the entry, 0 = treat the entry as missing;
        if None, the plain log-probability is returned

    Returns
    -------
    log_prob : torch.Tensor, shape=(...)
        the log-probability of ``x``
    """

    if mask is not None:
        observed = mask != 0
        missing = ~observed

        # fill x and mean with zeros so the missing residual is exactly 0
        x = x.masked_fill(missing, 0)
        mean = mean.masked_fill(missing, 0)

        # ensure masked entries are independent: drop every covariance that
        # involves at least one missing dimension (row OR column)
        both_observed = observed.unsqueeze(-1) & observed.unsqueeze(-2)
        cov = cov.masked_fill(~both_observed, 0)

        # ensure masked entries return log likelihood of 0: only the
        # DIAGONAL entries of missing dimensions get variance 1 / (2 * pi)
        on_diagonal = torch.eye(cov.shape[-1], device=cov.device).bool()
        cov = cov.masked_fill(on_diagonal & missing.unsqueeze(-1), INV2PI)

    return dist.MultivariateNormal(mean, cov).log_prob(x)

In [None]:
# Toy inputs for the masked likelihood.
# NOTE: this cell rebinds obs_dim, x and mask from the earlier setup cell.
n_batches = 1
obs_dim = 3
# fix the RNG; the draws below depend on being made in exactly this order
torch.manual_seed(234)
x = torch.randn(n_batches, obs_dim)
# one 0/1 flag per observation dimension (the upper bound 2 is exclusive)
mask = torch.randint(0, 2, size=(obs_dim,))
mean = torch.randn(n_batches, obs_dim)
cov = torch.randn(n_batches, obs_dim, obs_dim)
# A @ A^T makes every matrix symmetric positive semi-definite
cov = cov @ cov.transpose(1, 2)

In [None]:
mask

In [None]:
-masked_multivariate_likelihood(x, mean, cov, mask=mask).mean()

## Demo

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# noisy sine / cosine pair sampled at Nt points on [0, 10]
Nt = 100
t = np.linspace(0, 10, Nt)
y = np.stack([np.sin(t), np.cos(t)]).T + np.random.normal(0, 0.1, (Nt, 2))
# `high` is exclusive, so high=2 is required to draw both 0 and 1
# (high=1 could only ever produce zeros)
mask = torch.randint(low=0, high=2, size=(Nt, 2))
plt.scatter(t, y[:, 0])

plt.scatter(t, y[:, 1])
plt.plot(t, np.sin(t))
plt.plot(t, np.cos(t))
plt.show()

In [None]:
# float32 tensor copy of the observations with a leading batch axis of size 1
y_train = torch.Tensor(y)[None, ...]
# tile that single sequence ten times along the batch axis
y_train_batches = repeat(y_train, "1 ... -> B ...", B=10)

## Model

In [None]:
from torchkf.kf import transition_predict, emission_predict, update_step, predict_step
from torchkf.kf import (
    mask_observation_noise_diag,
    masked_multivariate_likelihood,
    mask_observation_operator,
)

In [None]:
class DiscreteKalmanFilter(nn.Module):
    """Discrete-time linear Gaussian state space model fitted by Kalman filtering.

    Learnable parameters are the transition matrix ``F``, the observation
    operator ``H`` and the square roots of the diagonal process / observation
    noise (``pre_Q`` / ``pre_R``). ``forward`` returns the negative
    log-likelihood of a batch of sequences under the one-step-ahead
    predictive distributions.

    Parameters
    ----------
    obs_dim : int
        the dimension of the observations
    latent_dim : int
        the dimension of the latent state
    """

    def __init__(self, obs_dim, latent_dim):
        super().__init__()
        self.obs_dim = obs_dim
        self.latent_dim = latent_dim

        # State update: x_{k} = F @ x_{k-1} + q
        self.F = nn.Parameter(torch.randn(latent_dim, latent_dim) / latent_dim)
        self.pre_Q = nn.Parameter(torch.ones(latent_dim))

        # Emission: y_{k} = H @ x_{k} + r
        self.H = nn.Parameter(torch.randn(obs_dim, latent_dim) / latent_dim)
        self.pre_R = nn.Parameter(torch.ones(obs_dim))

        # Priors: registered as buffers so they follow the module across
        # devices / dtypes; non-persistent so the state_dict keys are unchanged
        self.register_buffer("x0", torch.zeros(latent_dim), persistent=False)
        self.register_buffer("P0", torch.eye(latent_dim), persistent=False)

    @property
    def obs_noise(self):
        """
        Calculate observation noise covariance, diag(pre_R**2)
        """
        # built from the parameter itself so it lives on the parameter's device
        return torch.diag(self.pre_R**2)

    @property
    def trans_noise(self):
        """
        Calculate process noise covariance, diag(pre_Q**2)
        """
        return torch.diag(self.pre_Q**2)

    def forward(self, z, mask=None):
        """
        Forward pass for the Kalman Filter

        Parameters
        ----------
        z : torch.Tensor, shape=(Batch,Time,ObsDim)
            the observed values
        mask : torch.Tensor, shape=(Time,ObsDim), optional
            the mask for the observations, shared across the batch

        Returns
        -------
        loss : torch.Tensor, scalar
            the negative log-likelihood of the observed sequences, averaged
            over time steps and over the batch
        """
        assert isinstance(z, torch.Tensor)
        n_time = z.shape[1]

        *_, log_probs = self.filter_forward(z, mask=mask)

        # average over time, then over the batch
        log_probs = log_probs.sum(dim=1) / n_time
        return -log_probs.mean()

    def predict(self, x, P):
        """
        Update state mean and covariance p(x_{k} | x_{k-1}) and calculate mean and
        covariance in the observation space in the case of discrete time steps

        Returns the predicted state mean, state covariance, observation mean
        and observation covariance, in that order.
        """
        n_batch = x.shape[0]

        # same arguments as the predict step inside `filter_forward`
        x, P, pred_mean, pred_sigma = predict_step(
            x, P, self.F, self.trans_noise, self.H, self.obs_noise
        )

        assert x.shape == (n_batch, self.latent_dim)
        assert P.shape == (n_batch, self.latent_dim, self.latent_dim)
        return x, P, pred_mean, pred_sigma

    def update(self, x, P, z, pred_sigma, mask=None):
        """
        Update state x and P after the observation,
        outputs filtered state and covariance

        Parameters
        ----------
        x : torch.Tensor, shape=(Batch,StateDim)
            the predicted state means
        P : torch.Tensor, shape=(Batch,StateDim,StateDim)
            the predicted state covariances
        z : torch.Tensor, shape=(Batch,ObsDim)
            the observations
        pred_sigma : torch.Tensor, shape=(Batch,ObsDim,ObsDim)
            the predicted observation (innovation) covariances
        mask : torch.Tensor, optional
            forwarded to the masking helpers
            (NOTE(review): expected layout is defined by those helpers)

        Returns
        -------
        x : torch.Tensor, shape=(Batch,StateDim)
            the filtered state means
        P : torch.Tensor, shape=(Batch,StateDim,StateDim)
            the filtered state covariances
        """
        assert x.ndim == 2
        assert P.ndim == 3
        assert z.ndim == 2

        n_batch = x.shape[0]

        H = self.H
        # the update needs the OBSERVATION noise (ObsDim x ObsDim); the
        # transition noise has the wrong meaning and, in general, the wrong shape
        R = self.obs_noise

        # create masks
        if mask is not None:
            H = mask_observation_operator(H, mask)
            R = mask_observation_noise_diag(R, mask)

        # Kalman gain K = P H^T S^{-1}: solve S X = H P^T through a Cholesky
        # factor (more stable than an explicit inverse) and transpose X
        jitter = 1e-6 * torch.eye(
            pred_sigma.shape[-1], device=pred_sigma.device, dtype=pred_sigma.dtype
        )
        L = torch.linalg.cholesky(pred_sigma + jitter)
        K = torch.cholesky_solve(H @ P.transpose(1, 2), L).transpose(1, 2)

        # v = z - H @ x
        v = z - torch.einsum("ij,kj->ki", H, x)
        # x = x + K @ v
        x = x + torch.einsum("bso,bo->bs", K, v)

        identity = torch.eye(*P.shape[1:], device=P.device, dtype=P.dtype)

        # joseph form for numerical stability, using the same (possibly
        # masked) R that enters the gain
        t1 = identity - K @ H
        P = t1 @ P @ t1.transpose(1, 2) + K @ R @ K.transpose(1, 2)

        assert x.shape == (n_batch, self.latent_dim)
        assert P.shape == (n_batch, self.latent_dim, self.latent_dim)
        return x, P

    def filter_forward(self, obs, mask=None):
        """
        Iterate input data in case of discrete time steps

        Parameters
        ----------
        obs : torch.Tensor, shape=(Batch,Time,ObsDim)
            the observations
        mask : torch.Tensor, shape=(Time,ObsDim), optional
            the mask for the observations; indexed along its first axis
            by time step and shared across the batch

        Returns
        -------
        pred_means : torch.Tensor, shape=(Batch,Time,ObsDim)
            the one-step-ahead predicted observation means
        pred_sigmas : torch.Tensor, shape=(Batch,Time,ObsDim,ObsDim)
            the one-step-ahead predicted observation covariances
        xs : torch.Tensor, shape=(Batch,Time,StateDim)
            the filtered state means
        Ps : torch.Tensor, shape=(Batch,Time,StateDim,StateDim)
            the filtered state covariances
        log_probs : torch.Tensor, shape=(Batch,Time)
            the log probabilities of the innovations
        """
        n_batch, n_time, _ = obs.shape

        pred_means, pred_sigmas, log_probs = [], [], []
        xs, Ps = [], []

        # do prior
        x = repeat(self.x0, "... -> B ...", B=n_batch)
        P = repeat(self.P0, "... -> B ...", B=n_batch)

        for i in range(n_time):
            obs_i = obs[:, i, :]
            mask_i = mask[i, :] if mask is not None else None

            # predict step
            x, P, pred_mean, pred_sigma = predict_step(
                x, P, self.F, self.trans_noise, self.H, self.obs_noise
            )

            # log likelihood on innovations
            lprob = masked_multivariate_likelihood(
                obs_i,
                pred_mean,
                pred_sigma,
                mask_i,
            )

            # update step
            x, P = update_step(
                obs_i,
                x,
                P,
                self.H,
                self.obs_noise,
                pred_sigma,
                mask_i,
            )

            # save predictive observations
            pred_means.append(pred_mean)
            pred_sigmas.append(pred_sigma)
            xs.append(x)
            Ps.append(P)
            log_probs.append(lprob)

        # collapse all variables together (along time dimension)
        pred_means = torch.stack(pred_means, dim=1)
        pred_sigmas = torch.stack(pred_sigmas, dim=1)
        log_probs = torch.stack(log_probs, dim=1)
        xs = torch.stack(xs, dim=1)
        Ps = torch.stack(Ps, dim=1)

        return pred_means, pred_sigmas, xs, Ps, log_probs

    def forecasting(self, T, x, P):
        """
        forecast means and sigmas over given time period

        Parameters
        ----------
        T : int,
            the time steps after final states to forecast
        x : torch.Tensor, shape=(Batch,StateDim)
            the final state mean before forecasting window
        P : torch.Tensor, shape=(Batch,StateDim,StateDim)
            the final state cov before the forecasting window

        Returns
        -------
        pred_means : torch.Tensor,  shape=(Batch,T,ObsDim)
            the predicted mean observations
        pred_sigmas : torch.Tensor, shape=(Batch,T,ObsDim,ObsDim)
            the predicted cov observations
        """
        assert isinstance(T, int)
        assert T > 0
        pred_means, pred_sigmas = [], []

        # repeated predict steps without any measurement update
        for _ in range(T):
            x, P, pred_mean, pred_sigma = predict_step(
                x, P, self.F, self.trans_noise, self.H, self.obs_noise
            )
            pred_means.append(pred_mean)
            pred_sigmas.append(pred_sigma)

        pred_means = torch.stack(pred_means, dim=1)
        pred_sigmas = torch.stack(pred_sigmas, dim=1)
        return pred_means, pred_sigmas

In [None]:
# x = torch.randn((10, 5))
# P = torch.randn((10, 5,5))
# H = torch.randn((2, 5))
# # y = torch.matmul(H.unsqueeze(0), x.unsqueeze(1))
# Px = H.matmul(P).matmul(H.t())
# Px_ = torch.einsum("ij,kjl,ml->kim", H, P, H)
# torch.testing.assert_equal(Px, Px_)
# Px.shape,

In [None]:
# `y_mask` is only created in the "Masked" section further down, so this cell
# raised a NameError on a fresh kernel; batch the demo `mask` that is defined
# together with `y` instead
y_mask_batches = repeat(mask, "... -> B ...", B=y_train_batches.shape[0])

In [None]:
model = DiscreteKalmanFilter(obs_dim=2, latent_dim=20)

# outs = model.filter_forward(y_train)

In [None]:
model.trans_noise.shape

In [None]:
outs = model.filter_forward(y_train_batches)

In [None]:
outs[0].shape

In [None]:
loss = model.forward(y_train_batches)

In [None]:
loss

## Training

In [None]:
model = DiscreteKalmanFilter(obs_dim=2, latent_dim=20)
optim = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
from tqdm.notebook import tqdm, trange

In [None]:
n_iterations = 1_000

# loss value of every iteration, kept for the loss curve
losses = []

with trange(n_iterations) as pbar:

    for i in pbar:
        # one optimisation step on the negative log-likelihood
        optim.zero_grad()
        loss = model(y_train)
        loss.backward()
        optim.step()

        # record and report the loss that was computed before the step
        losses.append(loss.item())
        pbar.set_description(f"Iter {i}, loss: {loss:.4f}")

In [None]:
plt.plot(losses)

In [None]:
with torch.no_grad():

    pred_mu, pred_sigma, x, P, _ = model.filter_forward(y_train)
    pred_mu = pred_mu.numpy()
    pred_sigma = pred_sigma.numpy()


i_batch = 0

for dim, color in enumerate(("C0", "C1")):
    mu = pred_mu[i_batch, :, dim]
    # pred_sigma holds covariance matrices, so its diagonal is a variance;
    # the 95% band needs the standard deviation
    std = np.sqrt(pred_sigma[i_batch, :, dim, dim])

    plt.scatter(t, y_train[i_batch, :, dim], c=color)
    plt.plot(np.linspace(0, 10, Nt), mu, c=color)
    plt.fill_between(
        np.linspace(0, 10, Nt),
        mu - 1.96 * std,
        mu + 1.96 * std,
        color=color,
        alpha=0.2,
    )

# plt.ylim(-1.1, 1.1)

plt.show()

In [None]:
with torch.no_grad():

    pred_mu, pred_sigma, x, P, _ = model.filter_forward(y_train)

    # forecasting
    Nt_fore = 100
    pred_mu_fore, pred_sigma_fore = model.forecasting(Nt_fore, x[:, -1, :], P[:, -1, :])

    pred_mu = torch.cat([pred_mu, pred_mu_fore], dim=1)
    pred_sigma = torch.cat([pred_sigma, pred_sigma_fore], dim=1)

    pred_mu = pred_mu.numpy()
    pred_sigma = pred_sigma.numpy()


i_batch = 0

# extend the time axis with the sample spacing of `t`, so the filtered part
# lines up with the data and the forecast continues at the same step size
t_all = t[0] + (t[1] - t[0]) * np.arange(Nt + Nt_fore)

plt.figure()

for dim, color in enumerate(("C0", "C1")):
    mu = pred_mu[i_batch, :, dim]
    # pred_sigma holds covariance matrices, so its diagonal is a variance;
    # the 95% band needs the standard deviation
    std = np.sqrt(pred_sigma[i_batch, :, dim, dim])

    # plot the data the model was fitted on (`y_masked` does not exist yet
    # at this point of the notebook)
    plt.scatter(t, y_train[i_batch, :, dim], c=color)
    plt.plot(t_all, mu, c=color)
    plt.fill_between(
        t_all,
        mu - 1.96 * std,
        mu + 1.96 * std,
        color=color,
        alpha=0.2,
    )

# plt.ylim(-1.1, 1.1)

plt.show()

## Masked

In [None]:
Nt

In [None]:
t = np.linspace(0, 10, Nt)
y = np.stack([np.sin(t), np.cos(t)]).T + np.random.normal(0, 0.1, (Nt, 2))
# NOTE: np.random.choice samples WITH replacement by default, so fewer than
# Nt - 20 distinct time steps end up hidden in the first channel
idx = np.random.choice(np.arange(0, t.shape[0]), size=(Nt - 20))
y_masked = y.copy()
y_masked[idx, 0] = np.nan
# the second channel is never observed
y_masked[:, 1] = np.nan
# 1 = observed, 0 = missing: the masked likelihood in this notebook treats
# `mask == 0` as the missing entries, so the flag must be the negation of isnan
mask = (~np.isnan(y_masked)).astype(np.float32)

plt.scatter(t, y_masked[:, 0])
plt.scatter(t, y_masked[:, 1])
plt.plot(t, np.sin(t))
plt.plot(t, np.cos(t))
plt.show()

In [None]:
# float32 tensor with a leading batch axis; NaNs (missing values) become 0
y_train = torch.nan_to_num(torch.Tensor(y_masked)[None, ...])
# tile the single sequence ten times along the batch axis
y_train_batches = repeat(y_train, "1 ... -> B ...", B=10)
# the mask as a float32 tensor, same layout as `mask`
y_mask = torch.Tensor(mask)
y_train.shape, y_mask.shape, y_train_batches.shape

In [None]:
# y_mask

In [None]:
y_train.shape, y_mask.shape

In [None]:
model = DiscreteKalmanFilter(obs_dim=2, latent_dim=8)
optim = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
n_iterations = 1_000

# loss value of every iteration, kept for the loss curve
losses = []

with trange(n_iterations) as pbar:

    for i in pbar:
        # one optimisation step on the masked negative log-likelihood
        optim.zero_grad()
        loss = model(y_train, y_mask)
        loss.backward()
        optim.step()

        # record and report the loss that was computed before the step
        losses.append(loss.item())
        pbar.set_description(f"Iter {i}, loss: {loss:.4f}")

In [None]:
plt.plot(losses)

In [None]:
with torch.no_grad():

    pred_mu, pred_sigma, x, P, _ = model.filter_forward(y_train, y_mask)
    pred_mu = pred_mu.numpy()
    pred_sigma = pred_sigma.numpy()


i_batch = 0

for dim, color in enumerate(("C0", "C1")):
    mu = pred_mu[i_batch, :, dim]
    # pred_sigma holds covariance matrices, so its diagonal is a variance;
    # the 95% band needs the standard deviation
    std = np.sqrt(pred_sigma[i_batch, :, dim, dim])

    # NaN (missing) points are simply not drawn by scatter
    plt.scatter(t, y_masked[:, dim], c=color)
    plt.plot(np.linspace(0, 10, Nt), mu, c=color)
    plt.fill_between(
        np.linspace(0, 10, Nt),
        mu - 1.96 * std,
        mu + 1.96 * std,
        color=color,
        alpha=0.2,
    )

# plt.ylim(-1.1, 1.1)

plt.show()

In [None]:
with torch.no_grad():

    pred_mu, pred_sigma, x, P, _ = model.filter_forward(y_train, y_mask)

    # forecasting
    Nt_fore = 100
    pred_mu_fore, pred_sigma_fore = model.forecasting(Nt_fore, x[:, -1, :], P[:, -1, :])

    pred_mu = torch.cat([pred_mu, pred_mu_fore], dim=1)
    pred_sigma = torch.cat([pred_sigma, pred_sigma_fore], dim=1)

    pred_mu = pred_mu.numpy()
    pred_sigma = pred_sigma.numpy()


i_batch = 0

# extend the time axis with the sample spacing of `t`, so the filtered part
# lines up with the data and the forecast continues at the same step size
t_all = t[0] + (t[1] - t[0]) * np.arange(Nt + Nt_fore)

plt.figure()

for dim, color in enumerate(("C0", "C1")):
    mu = pred_mu[i_batch, :, dim]
    # pred_sigma holds covariance matrices, so its diagonal is a variance;
    # the 95% band needs the standard deviation
    std = np.sqrt(pred_sigma[i_batch, :, dim, dim])

    # NaN (missing) points are simply not drawn by scatter
    plt.scatter(t, y_masked[:, dim], c=color)
    plt.plot(t_all, mu, c=color)
    plt.fill_between(
        t_all,
        mu - 1.96 * std,
        mu + 1.96 * std,
        color=color,
        alpha=0.2,
    )

# plt.ylim(-1.1, 1.1)

plt.show()