In [None]:
import sys
import os
import glob
import json
from datetime import datetime
import numpy as np
import torch
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import pandas as pd
# Add the parent of the current working directory to the import path so the
# project-local `src` package imported below can be found.
# NOTE(review): this assumes the kernel's cwd is the notebook's own directory
# (one level below the repo root) — confirm when running from elsewhere.
sys.path.append(os.path.dirname(os.getcwd()))

from src.dynamics import Lorenz96
from src.measurements import Gaussian

# Target device used when constructing tensors in later cells.
device = torch.device("cpu")

In [None]:
# ---- Experiment configuration -------------------------------------------
# Seed shared by the NumPy and PyTorch global RNGs (both are seeded below,
# before any model object is created or any random draw is made).
seed = 42

# Settings forwarded verbatim to src.dynamics.Lorenz96.
dim, forcing = 1000, 8.0
dt, steps = 0.05, 100
solver = "Runge-Kutta"
prior_mean, prior_std = 6.0, 1.0
perturb_std = 0.5

# Standard deviation forwarded to src.measurements.Gaussian.
noise_std = 2.0

# Number of ensemble members used by the filter in the next cell.
n_ensemble = 2000

# ---- Seeding ------------------------------------------------------------
np.random.seed(seed)
torch.manual_seed(seed)

# ---- Model and observation operator -------------------------------------
dynamics = Lorenz96(
    dim=dim,
    prior_mean=prior_mean,
    prior_std=prior_std,
    dt=dt,
    forcing=forcing,
    perturb_std=perturb_std,
    solver=solver,
)
measurement = Gaussian(noise_std=noise_std)

# ---- Reference trajectory -----------------------------------------------
# Start from the uniform state where every component equals `forcing`, then
# nudge only the first component by 0.01 (presumably so the trajectory leaves
# the uniform state).
x0 = torch.full((1, dim), forcing, device=device)  # (1, dim)
x0[0, 0] += 0.01

states: torch.Tensor = dynamics.generate(x0=x0, steps=steps)  # (steps+1, dim)
observations: torch.Tensor = measurement.measure(states)  # (steps+1, dim)

In [None]:
# Ensemble Kalman filter. ensembles[k] holds the analysis ensemble at step k,
# one member per row; ensembles[0] is drawn from the Gaussian prior.
ensembles = torch.empty((steps + 1, n_ensemble, dim), device=device)
ensembles[0] = torch.randn((n_ensemble, dim), device=device) * prior_std + prior_mean

# Observation-error covariance R = noise_std^2 * I. It does not change between
# cycles, so build it once instead of on every iteration.
obs_cov = torch.eye(dim, device=device) * noise_std**2  # (dim, dim)

for i in range(steps):
    # Forecast: propagate every member one step with the model. Storing into
    # `ensembles` first keeps the stored dtype/device for the analysis below.
    ensembles[i + 1] = dynamics.transition(ensembles[i])  # (n_ensemble, dim)
    forecast = ensembles[i + 1]

    # Sample covariance across members (torch.cov expects variables as rows).
    cov = torch.cov(forecast.T)  # (dim, dim)

    # Kalman gain K = P (P + R)^{-1}. Members are rows, so the analysis is
    # x_a = x_f + (y - x_f) @ K^T, with K^T = (P + R)^{-1} P (P, R symmetric).
    # Solve the linear system instead of forming the explicit inverse: this is
    # cheaper and numerically more stable.
    gain_T = torch.linalg.solve(cov + obs_cov, cov)  # (dim, dim)

    # NOTE(review): the same observation vector is used for every member (no
    # perturbed observations). Stochastic EnKF variants perturb it per member;
    # without that, the analysis spread tends to be underestimated — confirm
    # this is the intended filter variant.
    ensembles[i + 1] = forecast + (observations[i + 1] - forecast) @ gain_T

# np.savez(
#     f'../lorenz_results/enkf_sigma{noise_std}_nensemble{n_ensemble}.npz',
#     assimilated_states=ensembles.numpy()
# )

In [None]:
# Ensemble mean at every stored time step (averaged over the member axis).
mean_estimation = torch.mean(ensembles, dim=1)  # (steps+1, dim)
# Physical time of each stored state: step index times the solver step dt.
t = np.arange(steps + 1) * dt
markevery = 1

# Plain NumPy copies for matplotlib (also valid if tensors live off-CPU).
truth_np = states.detach().cpu().numpy()
mean_np = mean_estimation.detach().cpu().numpy()

mpl.rcdefaults()
mpl.rc("mathtext", fontset="cm")
mpl.rc("font", family="serif", serif="DejaVu Serif")
mpl.rc("figure", dpi=600, titlesize=9)
mpl.rc("figure.subplot", wspace=0.2, hspace=0.6)
mpl.rc("axes", grid=False, labelsize=9, labelpad=0.5)
mpl.rc("axes.spines", top=False, right=False)
mpl.rc("xtick", labelsize=6, direction="out")
mpl.rc("ytick", labelsize=6, direction="out")
mpl.rc("xtick.major", pad=2)
mpl.rc("ytick.major", pad=2)
mpl.rc("grid", linestyle=":", alpha=0.8)
mpl.rc("lines", linewidth=1, markersize=5, markerfacecolor="none", markeredgecolor="auto", markeredgewidth=0.5)
mpl.rc("scatter", marker='o')
mpl.rc("legend", fontsize=9)

fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(7, 4))
# axes.flat walks panels row-major, so panel (row i, col j) gets k = 3*i + j.
# The plotted component index k matches the 1-based title x_{k+1}.
for k, axj in enumerate(axes.flat):
    axj.scatter(t[::markevery], truth_np[::markevery, k], label='Truth', color='C0', marker='o', facecolors='none', s=4, linewidths=1.)
    axj.plot(t, mean_np[:, k], label='EnKF', color='C1')
    axj.set_title(f"$x_{{{k + 1}}}$")
axes[0][1].legend(bbox_to_anchor=(0.9, 1.3), loc='lower left', ncol=2)
# Shared axis labels in figure coordinates (Figure.text defaults to transFigure).
fig.text(0.08, 0.5, 'State', fontsize=9, rotation='vertical')
fig.text(0.43, 0.03, 'Time', fontsize=9);
# plt.savefig('../asset/Lorenz96_evolution.pdf', dpi=600, bbox_inches='tight', pad_inches=0.1)