Reference: Chapter 4.2.1 of https://project-archive.inf.ed.ac.uk/msc/20204379/msc_proj.pdf

__Linear Gaussian State Space Model__
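$$
\begin{aligned}
p(x_1) & = \mathcal{N}_{x_1}(\mathbf{0}, \mathbf{I}) \\
f(x_t | x_{t-1}) & = \mathcal{N}_{x_t}(Ax_{t-1}+c, \mathbf{R}) \\
g(y_t | x_t) & = \mathcal{N}_{y_t}(Cx_t+g, \mathbf{Q}) \\
\end{aligned}
$$

Here $A$ and $C$ are the dynamics and observation matrices, and $c$ and $g$ are constant offsets. The offset $g$ is a constant vector and is unrelated to the observation density $g(\cdot \mid \cdot)$. $\mathbf{R}$ and $\mathbf{Q}$ are the dynamics and observation noise covariances, respectively. This is the reverse of the common Kalman-filter convention, in which $Q$ denotes process noise. The experiments below use a one-dimensional state with $A = C = \mathbf{I}$, $c = g = 0$, $\mathbf{R} = 5\mathbf{I}$, $\mathbf{Q} = 0.2\mathbf{I}$ and $T = 10$ steps.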

In [None]:
import os, sys, pickle
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch.distributions import MultivariateNormal
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
# Make the parent of the current working directory importable so that the
# project-local ``src`` package imported below can be found.
# NOTE(review): assumes the notebook is launched from a direct subdirectory of
# the repository root — confirm.
sys.path.append(os.path.dirname(os.getcwd()))

from src.networks import MLP
from src.utils import assemble_grad_potential, langevin_sampler

# Device on which the tensors and the network below are allocated.
device = torch.device("cpu")

In [None]:
# Seed both NumPy and PyTorch so that the simulated data and every later
# sampling step are reproducible.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# State dimension and number of time steps.
dim = 1
T = 10

# Shared building blocks.
I = torch.eye(dim, device=device)
zeros = torch.zeros((dim,), device=device)

# Linear maps and constant offsets of the model:
#   x_t = A x_{t-1} + c + noise,   y_t = C x_t + g + noise.
A, C = I, I
c, g = 0., 0.

# Noise covariances: R for the dynamics, Q for the observations.
R = 5 * I
Q = 0.2 * I
Q_inv = torch.linalg.inv(Q)

# Parameters of the initial-state distribution p(x_1) = N(mu1, cov1).
mu1 = torch.zeros((dim, ), device=device)
cov1 = I

# Zero-mean samplers for the additive dynamics / observation noise.
DynamicsNoiseGenerator = MultivariateNormal(loc=zeros, covariance_matrix=R)
ObservationNoiseGenerator = MultivariateNormal(loc=zeros, covariance_matrix=Q)

def transition(x):
    """Propagate a batch of states one step through the linear-Gaussian dynamics.

    Draws ``x_t ~ N(A x_{t-1} + c, R)`` independently for every row of ``x``.

    Args:
        x: Tensor whose rows are states ``x_{t-1}``, shape (batch, dim).

    Returns:
        Tensor of shape (batch, dim) holding the propagated, noise-perturbed states.
    """
    # Rows are states, so the column-vector map ``A x`` becomes ``x @ A.T``
    # (the same convention as the Kalman filter's ``A @ mu``).
    loc = x @ A.T + c
    return loc + DynamicsNoiseGenerator.sample((x.shape[0],))

def observation(x):
    """Draw noisy observations for a batch of states.

    Samples ``y_t ~ N(C x_t + g, Q)`` independently for every row of ``x``.

    Args:
        x: Tensor whose rows are states ``x_t``, shape (batch, dim).

    Returns:
        Tensor with one observation per row of ``x``.
    """
    # Rows are states, so the column-vector map ``C x`` becomes ``x @ C.T``
    # (the same convention as the Kalman filter's ``C @ mu_hat``).
    loc = x @ C.T + g
    return loc + ObservationNoiseGenerator.sample((x.shape[0],))

def score_likelihood_fn(x, y):
    """Score of the Gaussian observation likelihood with respect to the state.

    For ``y ~ N(C x + g, Q)`` the gradient is
    ``grad_x log g(y | x) = C^T Q^{-1} (y - C x - g)``. With states stored as
    rows this is ``(y - x C^T - g) Q^{-1} C`` (``Q^{-1}`` is symmetric).

    Args:
        x: Tensor whose rows are states, shape (batch, dim).
        y: Observation, broadcast against every row of ``x``.

    Returns:
        Tensor of shape (batch, dim): the likelihood score at each state.
    """
    residual = y - x @ C.T - g
    return residual @ Q_inv @ C

# Simulate one trajectory x[0..T-1] of the model and its noisy observations y.
x = torch.empty((T, dim), device=device)
x_true = torch.empty((T, dim), device=device)  # NOTE(review): not referenced anywhere in this notebook
y = torch.empty((T, dim), device=device)

# The trajectory starts exactly at the prior mean mu1 (it is not drawn from
# N(mu1, cov1)).
x[0] = mu1
for t in range(1, T):
    # transition() works on batches, so pass and receive single-row slices.
    x[t:t + 1] = transition(x[t - 1:t])

# Observations are drawn only after the whole trajectory exists; this keeps
# the random-number stream in "all dynamics, then all observations" order.
for t in range(T):
    y[t:t + 1] = observation(x[t:t + 1])

# Exact filtering posteriors p(x_t | y_{1:t}) computed with the Kalman filter;
# they are the reference the SSLS ensembles are compared against.
mu = mu1
cov = cov1
mu_post = torch.empty((T, dim), device=device)
cov_post = torch.empty((T, dim, dim), device=device)
for i in range(T):
    if i == 0:
        # p(x_1) = N(mu1, cov1) is already the prior for the first observation,
        # so no prediction step precedes it. The SSLS runs likewise assimilate
        # y[0] into their initial ensemble without a transition.
        mu_hat, cov_hat = mu, cov
    else:
        # prediction: p(x_t | y_{1:t-1}) = N(A mu + c, A cov A^T + R)
        mu_hat = A @ mu + c
        cov_hat = A @ cov @ A.T + R
    # innovation covariance S and Kalman gain K = cov_hat C^T S^{-1}
    S = C @ cov_hat @ C.T + Q
    K = cov_hat @ C.T @ torch.linalg.inv(S)
    # residual (innovation), including the observation offset g
    residual = y[i] - (C @ mu_hat + g)
    # update
    mu = mu_hat + K @ residual
    cov = (I - K @ C) @ cov_hat
    mu_post[i] = mu
    cov_post[i] = cov

print(x[-1], mu)

In [None]:
workdir = "../linear_results/exact_prior"
os.makedirs(workdir, exist_ok=True)
# SSLS hyperparameters; the inexact-prior cell below reuses them.
ntrain = 500           # ensemble size (number of particles)
denoising_sigma = 0.2  # noise scale used for denoising score matching
lr = 1e-3
nepoch = 1000          # score-matching epochs per assimilation step
lmc_steps = 1000       # Langevin steps per assimilation step
lmc_stepsize = 1e-3
anneal_init = 1e-3
anneal_decay = 0.5
anneal_steps = 1

# Initial ensemble ~ N(mu0, std0^2); here this equals p(x_1) = N(0, 1).
mu0 = 0.
std0 = 1.

prior = torch.randn((ntrain, dim), device=device) * std0 + mu0
model = MLP(dim=dim, widths=[32, 64], use_bn=True).to(device)
assimilated_states = torch.empty((T, ntrain, dim), device=device)

with tqdm(range(T), maxinterval=50.0, desc="state step", file=sys.stdout) as pbar:
    for i in pbar:
        prior_mean, prior_std = prior.mean(dim=0), prior.std(dim=0)
        # Guard against a collapsed ensemble: a zero std would turn the
        # normalization below into inf/NaN.
        prior_std = prior_std.clamp_min(torch.finfo(prior.dtype).eps)
        # normalize states for stable input to the network
        normalized_prior = (prior - prior_mean) / prior_std
        # model predicts the noise from the noised normalized state, as in DDPM
        # normalized_score_fn predicts the score for the normalized state
        normalized_score_fn = lambda x: -model(x) / denoising_sigma

        # Denoising score matching on the normalized ensemble. The network is
        # warm-started from the previous step; only the optimizer is fresh.
        model.train()
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        dataset = TensorDataset(normalized_prior)
        loader = DataLoader(dataset, batch_size=ntrain, shuffle=True)
        for _ in range(nepoch):
            for (x0,) in loader:
                x0 = x0.to(device)  # (B, *state_shape)
                z = torch.randn_like(x0, device=device)
                xt = x0 + denoising_sigma * z
                score = normalized_score_fn(xt)
                # score * sigma == -model(xt): this regresses model(xt) onto z.
                loss = nn.MSELoss()(score * denoising_sigma, -z)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        model.eval()

        # Posterior sampling and prior update
        # score_fn predicts the score for the original (unnormalized) states
        # Y = aX + b => s_X(x) = a s_Y(ax+b)
        # The lambdas read prior_mean / prior_std / model when called, so they
        # are only valid within this iteration.
        score_fn = lambda x: normalized_score_fn((x - prior_mean) / prior_std) / prior_std
        # potential gradient = - score_likelihood - score_prior
        grad_potential_fn = assemble_grad_potential(
            y=y[i],
            score_likelihood=score_likelihood_fn,
            score_prior=score_fn,
        )

        # Mean squared magnitude of the likelihood and prior scores on the
        # current ensemble; reported in the progress bar to compare their scales.
        with torch.no_grad():
            score_likelihood = torch.mean(score_likelihood_fn(prior, y[i]) ** 2)
            score_prior = torch.mean(score_fn(prior) ** 2)

        with torch.no_grad():
            posterior = langevin_sampler(
                grad_potential_fn=grad_potential_fn,
                x=prior,
                steps=lmc_steps,
                dt=lmc_stepsize,
                anneal_init=anneal_init,
                anneal_decay=anneal_decay,
                anneal_steps=anneal_steps,
            )  # (n_train, *shape)
            assimilated_states[i] = posterior
            prior = transition(posterior)

        # Postprocessing: errors of the ensemble point estimates w.r.t. the true state.
        mean_estimation = torch.mean(posterior, dim=0)  # (*shape, )
        median_estimation = torch.median(posterior, dim=0).values  # (*shape, )
        mean_rmse = torch.sqrt(torch.mean((mean_estimation - x[i]) ** 2))
        median_rmse = torch.sqrt(torch.mean((median_estimation - x[i]) ** 2))
        pbar.set_postfix(
            {
                "mean(RMSE)": mean_rmse.item(),
                "median(RMSE)": median_rmse.item(),
                "E|s_lik|^2": score_likelihood.item(),
                "E|s_prior|^2": score_prior.item(),
            },
            refresh=False,
        )

np.savez(
    os.path.join(workdir, "results.npz"),
    assimilated_states=assimilated_states.cpu().numpy(),  # (steps, ntrain, d)
)

In [None]:
# Ensemble mean at each step (SSLS) next to the exact Kalman posterior mean.
print(np.squeeze(np.mean(assimilated_states.cpu().numpy(), axis=1)))
print(np.squeeze(mu_post.cpu().numpy()))

In [None]:
# Ensemble spread at each step (SSLS) next to the exact Kalman posterior std.
print(np.squeeze(np.std(assimilated_states.cpu().numpy(), axis=1)))
print(np.squeeze(cov_post.cpu().numpy() ** 0.5))

In [None]:
workdir = "../linear_results/inexact_prior"
os.makedirs(workdir, exist_ok=True)
# Deliberately mis-specified initial ensemble: centred at -10 instead of the
# prior mean 0. All other SSLS hyperparameters come from the previous cell.
mu0 = -10.
std0 = 1.

prior = torch.randn((ntrain, dim), device=device) * std0 + mu0
model = MLP(dim=dim, widths=[32, 64], use_bn=True).to(device)
assimilated_states = torch.empty((T, ntrain, dim), device=device)

with tqdm(range(T), maxinterval=50.0, desc="state step", file=sys.stdout) as pbar:
    for i in pbar:
        prior_mean, prior_std = prior.mean(dim=0), prior.std(dim=0)
        # Guard against a collapsed ensemble: a zero std would turn the
        # normalization below into inf/NaN.
        prior_std = prior_std.clamp_min(torch.finfo(prior.dtype).eps)
        # normalize states for stable input to the network
        normalized_prior = (prior - prior_mean) / prior_std
        # model predicts the noise from the noised normalized state, as in DDPM
        # normalized_score_fn predicts the score for the normalized state
        normalized_score_fn = lambda x: -model(x) / denoising_sigma

        # Denoising score matching on the normalized ensemble. The network is
        # warm-started from the previous step; only the optimizer is fresh.
        model.train()
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        dataset = TensorDataset(normalized_prior)
        loader = DataLoader(dataset, batch_size=ntrain, shuffle=True)
        for _ in range(nepoch):
            for (x0,) in loader:
                x0 = x0.to(device)  # (B, *state_shape)
                z = torch.randn_like(x0, device=device)
                xt = x0 + denoising_sigma * z
                score = normalized_score_fn(xt)
                # score * sigma == -model(xt): this regresses model(xt) onto z.
                loss = nn.MSELoss()(score * denoising_sigma, -z)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        model.eval()

        # Posterior sampling and prior update
        # score_fn predicts the score for the original (unnormalized) states
        # Y = aX + b => s_X(x) = a s_Y(ax+b)
        # The lambdas read prior_mean / prior_std / model when called, so they
        # are only valid within this iteration.
        score_fn = lambda x: normalized_score_fn((x - prior_mean) / prior_std) / prior_std
        # potential gradient = - score_likelihood - score_prior
        grad_potential_fn = assemble_grad_potential(
            y=y[i],
            score_likelihood=score_likelihood_fn,
            score_prior=score_fn,
        )

        # Mean squared magnitude of the likelihood and prior scores on the
        # current ensemble; reported in the progress bar to compare their scales.
        with torch.no_grad():
            score_likelihood = torch.mean(score_likelihood_fn(prior, y[i]) ** 2)
            score_prior = torch.mean(score_fn(prior) ** 2)

        with torch.no_grad():
            posterior = langevin_sampler(
                grad_potential_fn=grad_potential_fn,
                x=prior,
                steps=lmc_steps,
                dt=lmc_stepsize,
                anneal_init=anneal_init,
                anneal_decay=anneal_decay,
                anneal_steps=anneal_steps,
            )  # (n_train, *shape)
            assimilated_states[i] = posterior
            prior = transition(posterior)

        # Postprocessing: errors of the ensemble point estimates w.r.t. the true state.
        mean_estimation = torch.mean(posterior, dim=0)  # (*shape, )
        median_estimation = torch.median(posterior, dim=0).values  # (*shape, )
        mean_rmse = torch.sqrt(torch.mean((mean_estimation - x[i]) ** 2))
        median_rmse = torch.sqrt(torch.mean((median_estimation - x[i]) ** 2))
        pbar.set_postfix(
            {
                "mean(RMSE)": mean_rmse.item(),
                "median(RMSE)": median_rmse.item(),
                "E|s_lik|^2": score_likelihood.item(),
                "E|s_prior|^2": score_prior.item(),
            },
            refresh=False,
        )

np.savez(
    os.path.join(workdir, "results.npz"),
    assimilated_states=assimilated_states.cpu().numpy(),  # (steps, ntrain, d)
)

In [None]:
# Inexact-prior run: ensemble mean per step vs. exact Kalman posterior mean.
print(np.squeeze(np.mean(assimilated_states.cpu().numpy(), axis=1)))
print(np.squeeze(mu_post.cpu().numpy()))

In [None]:
# Inexact-prior run: ensemble spread per step vs. exact Kalman posterior std.
print(np.squeeze(np.std(assimilated_states.cpu().numpy(), axis=1)))
print(np.squeeze(cov_post.cpu().numpy() ** 0.5))

In [None]:
# Gather both experiments' ensembles together with the Kalman reference
# posteriors into a single pickle consumed by the plotting cell.
workdirs = ["../linear_results/exact_prior", "../linear_results/inexact_prior"]
states = []
for workdir in workdirs:
    # np.load on an .npz returns an NpzFile that keeps the file open; the
    # context manager closes it once the array has been read into memory.
    with np.load(os.path.join(workdir, "results.npz")) as data:
        states.append(data["assimilated_states"])

# The output directory may not exist yet on a fresh checkout.
os.makedirs("../asset", exist_ok=True)
with open("../asset/linear.pkl", "wb") as file:
    pickle.dump((states, mu_post.cpu(), cov_post.cpu()), file)
    

In [None]:
# Load the collected results (written by the previous cell).
with open("../asset/linear.pkl", "rb") as file:
    states, mu_post, cov_post = pickle.load(file)
mpl.rcdefaults()
mpl.style.use('../configs/mplrc')
mpl.rc("figure.subplot", wspace=0.35, hspace=0.4)


nrows = 2
ncols = 5
# Only every other assimilation step is plotted: column j shows step index
# stride * j (0-based), i.e. time t_{stride * j + 1}.
stride = 2
fig, axes = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    figsize=(7, 2.5)
    )

# Row 0: exact-prior run, row 1: inexact-prior run (order of ``states``).
# Distinct loop names keep the loaded ``states`` list intact.
for run_states, row_axes in zip(states, axes):
    for step_states, axi in zip(run_states[::stride], row_axes):
        sns.kdeplot(step_states.squeeze(), ax=axi, color='C1', legend=False)

# Overlay the exact Kalman posterior (KDE of samples) in both rows.
for mu, cov, ax1, ax2 in zip(mu_post[::stride], cov_post[::stride], axes[0], axes[1]):
    truth = MultivariateNormal(loc=mu, covariance_matrix=cov).sample((10000, )).cpu().squeeze()
    sns.kdeplot(truth, ax=ax1, legend=False, fill=True, linewidth=0., color='C0', alpha=0.5)
    sns.kdeplot(truth, ax=ax2, legend=False, fill=True, linewidth=0., color='C0', alpha=0.5)
    # +/- 4 posterior standard deviations of the first state coordinate,
    # as plain floats for matplotlib.
    center = mu.cpu()[0].item()
    half_width = 4 * cov.cpu()[0, 0].sqrt().item()
    xlim = (center - half_width, center + half_width)
    ax1.set_xlim(xlim)
    ax2.set_xlim(xlim)

for ax in axes.flat:
    ax.set_ylabel('')

axes[0][0].set_ylabel('Exact Prior')
axes[1][0].set_ylabel('Inexact Prior')

# Label each column with the actual time step it shows.
for j, ax in enumerate(axes[1]):
    ax.set_xlabel(r'$\bf{t_{' + str(stride * j + 1) + r'}}$')

custom_lines = [
    mpl.lines.Line2D([0], [0], color='C1', label='SSLS ensemble density'),
    mpl.patches.Patch(facecolor='C0', alpha=0.5, label='Ground truth posterior')
]
axes[0][0].legend(handles=custom_lines, bbox_to_anchor=(0.5, 1.0), loc='lower left', ncol=2)

plt.savefig('../asset/Linear.pdf', dpi=600, bbox_inches='tight', pad_inches=0.1)