In [None]:
import sys
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib as mpl
sys.path.append(os.path.dirname(os.getcwd()))

from src.dynamics import Lorenz96
from src.measurements import Linear

# Use the first CUDA GPU when one is available; otherwise fall back to CPU so the
# notebook still runs (previously the first allocation on "cuda:0" raised without a GPU).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# ---- Experiment configuration ------------------------------------------------
seed = 42

# Lorenz-96 model / prior over the initial ensemble
dim = 20                 # number of state components
forcing = 8.0            # forcing constant handed to Lorenz96 (also used for x0 below)
dt = 0.05                # time step handed to Lorenz96
solver = "Runge-Kutta"   # integrator name handed to Lorenz96
perturb_std = 0.1        # handed to Lorenz96 as perturb_std
prior_mean = 2.0
prior_std = 1.0
steps = 100              # number of transitions generated / assimilated

# Observation model
noise_std = 0.5          # handed to Linear as noise_std; also used as the likelihood std

# Particle counts
n_ensemble = 500
n_resampling = n_ensemble

# Seed numpy and torch before any model is built or any data is generated.
np.random.seed(seed)
torch.manual_seed(seed)

lorenz_kwargs = dict(
    dim=dim,
    prior_mean=prior_mean,
    prior_std=prior_std,
    dt=dt,
    forcing=forcing,
    perturb_std=perturb_std,
    solver=solver,
)
dynamics = Lorenz96(**lorenz_kwargs)
measurement = Linear(noise_std=noise_std)

# True initial state: every component equal to `forcing` (the uniform equilibrium of
# the standard Lorenz-96 equations, assuming Lorenz96 implements that form), with a
# small kick on the first component to leave the equilibrium.
x0 = torch.full((1, dim), forcing, device=device)  # (1, dim)
x0[0, 0] += 0.01

# Ground-truth trajectory and its noisy observations.
# Shapes assumed (steps+1, dim) per the original annotations — TODO confirm in src.
states: torch.Tensor = dynamics.generate(x0=x0, steps=steps)
observations: torch.Tensor = measurement.measure(states)

In [None]:
# Auxiliary particle filter (APF) over the generated observations.
#
# Previous version: the step-3 "draw" only re-indexed the look-ahead particles, and the
# two step-4 likelihoods were computed from the same expression, so the second-stage
# weights cancelled to uniform and the filter reduced to a bootstrap PF.  Here the
# look-ahead point is the plain transition, ancestors are chosen by the first-stage
# weights, the process noise is drawn *after* selection, and the second-stage weight
# is p(y | x) / p(y | mu_ancestor).

# Std of the extra Gaussian perturbation added on top of `dynamics.transition` when
# propagating particles (0.5 is the value the original loop hard-coded).
process_noise_std = 0.5


def gaussian_loglik(y: torch.Tensor, x: torch.Tensor, std: float) -> torch.Tensor:
    """Unnormalised isotropic Gaussian log-likelihood log N(y; x_j, std^2 I) per row of x.

    Args:
        y: observation, broadcastable against ``x``.
        x: particles, shape (n, dim).
        std: observation noise standard deviation.

    Returns:
        Tensor of shape (n,). The constant term is omitted because only
        normalised weights (softmax) are ever taken.
    """
    diff = y - x
    return -torch.einsum("ij,ij->i", diff, diff) / (2 * std**2)


ensembles = torch.empty((steps + 1, n_ensemble, dim), device=device)
ensembles[0] = torch.randn((n_ensemble, dim)) * prior_std + prior_mean

for i in range(steps):
    # Observation at k+1; broadcasts over the particle axis.
    y = observations[i + 1]

    # Steps 1-2: look-ahead point mu_k^j and first-stage weights lambda_k^j ∝ p(y | mu_k^j).
    mu = dynamics.transition(ensembles[i])
    lam = torch.softmax(gaussian_loglik(y, mu, noise_std), dim=0)

    # Step 3: choose ancestors by lambda, then draw from the transition density
    # (look-ahead point plus process noise) for each chosen ancestor.
    ancestors = torch.multinomial(lam, num_samples=n_resampling, replacement=True)
    mu_anc = mu[ancestors]
    proposals = mu_anc + torch.randn_like(mu_anc) * process_noise_std

    # Step 4: second-stage importance weights w ∝ p(y | x) / p(y | mu_ancestor),
    # computed in log space so underflow cannot produce 0/0 = NaN.
    log_w = gaussian_loglik(y, proposals, noise_std) - gaussian_loglik(y, mu_anc, noise_std)
    w = torch.softmax(log_w, dim=0)

    # Step 5: resample the posterior ensemble according to w.
    resample_idx = torch.multinomial(w, num_samples=n_ensemble, replacement=True)
    ensembles[i + 1] = proposals[resample_idx]

save_path = f'../lorenz_results/apf_sigma{noise_std}_nensemble{n_ensemble}.npz'
# np.savez does not create missing directories; make sure the results folder exists.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
np.savez(
    save_path,
    assimilated_states=ensembles.cpu().numpy(),
)

In [None]:
# Results figure (currently disabled): plots the APF ensemble mean against the true
# trajectory for the first 9 state components on a 3x3 grid of subplots.
# Reads `ensembles`, `states`, `steps` and `dt` from the cells above.
# To draw the figure, uncomment everything below.
# NOTE(review): figure dpi is set to 600, which makes the rendered image large.
# mean_estimation = torch.mean(ensembles, dim=1)
# t = np.arange(steps+1) * dt
# markevery=1

# mpl.rcdefaults()
# mpl.rc("mathtext", fontset="cm")
# mpl.rc("font", family="serif", serif="DejaVu Serif")
# mpl.rc("figure", dpi=600, titlesize=9)
# mpl.rc("figure.subplot", wspace=0.2, hspace=0.6)
# mpl.rc("axes", grid=False, labelsize=9, labelpad=0.5)
# mpl.rc("axes.spines", top=False, right=False)
# mpl.rc("xtick", labelsize=6, direction="out")
# mpl.rc("ytick", labelsize=6, direction="out")
# mpl.rc("xtick.major", pad=2)
# mpl.rc("ytick.major", pad=2)
# mpl.rc("grid", linestyle=":", alpha=0.8)
# mpl.rc("lines", linewidth=1, markersize=5, markerfacecolor="none", markeredgecolor="auto", markeredgewidth=0.5)
# mpl.rc("scatter", marker='o')
# mpl.rc("legend", fontsize=9)

# fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(7, 4))
# for i in range(3):
#     ax = axes[i]
#     for j in range(3):
#         axj = ax[j]
#         axj.scatter(t[::markevery], states.cpu()[:, j+3*i][::markevery], label='Truth', color='C0', marker='o', facecolors='none', s=4, linewidths=1.)
#         axj.plot(t, mean_estimation.cpu()[:, j+3*i], label='APF', color='C1', markevery=markevery)
#         axj.set_title(f"$x_{{{i*3+j+1}}}$")
# axes[0][1].legend(bbox_to_anchor=(0.9, 1.3), loc='lower left', ncol=2)
# plt.text(0.08, 0.5, 'State', transform=plt.gcf().transFigure, fontsize=9, rotation='vertical')
# plt.text(0.43, 0.03, 'Time step', transform=plt.gcf().transFigure, fontsize=9)