In [None]:
import os
import json
import pickle
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib as mpl

# Torch device handle pinned to the CPU. None of the visible cells reference
# it — presumably kept for cells or helpers outside this view (TODO confirm).
device = torch.device("cpu")

In [None]:
# Build the cached asset used by the plotting cell: ground-truth states plus
# the ensemble-mean trajectories of the SSLS and APF runs.

# The run directory lives one level above the current working directory.
workdir = os.path.join(
    os.path.dirname(os.getcwd()),
    'lorenz_results/lorenz96-241031-150134',
)

with open(os.path.join(workdir, 'config.json'), 'r') as f:
    cfg = json.load(f)

dim = cfg['dynamics']['dim']
steps = cfg['dynamics']['steps']
dt = cfg['dynamics']['dt']
noise_std = cfg['measurement']['noise_std'] # 0.5
n_train = cfg["train"]["n_train"] # 500

# Single row filled with the configured prior mean; it is prepended to each
# mean trajectory below as the initial guess. Computed once because both
# filters share the same config.
guess_init = np.ones((1, dim)) * cfg["dynamics"]["prior_mean"]

# `np.load` on an .npz returns an archive that holds the file open; the
# context manager closes it once the arrays have been read out.
with np.load(os.path.join(workdir, 'results.npz')) as ssls_results:
    states = ssls_results['states'] # (steps+1, dim)
    observations = ssls_results['observations'] # (steps+1, dim)
    ssls_states = ssls_results['assimilated_states'] # (steps+1, nsamples, dim)

# Average over axis 1 (the sample axis, per the shape comment above), then
# stack the initial-guess row on top. `np.concatenate` is used instead of
# `np.concat` because the latter only exists in NumPy >= 2.0.
ssls_mean = np.mean(ssls_states, axis=1)
ssls_mean = np.concatenate((guess_init, ssls_mean), axis=0)

# NOTE(review): the APF file name hard-codes sigma0.5 / nensemble500, which
# look like the noise_std / n_train values noted above — confirm whether the
# name should be derived from cfg instead.
with np.load('../lorenz_results/apf_sigma0.5_nensemble500.npz') as apf_results:
    apf_states = apf_results['assimilated_states'] # (steps+1, nsamples, dim)

apf_mean = np.mean(apf_states, axis=1)
apf_mean = np.concatenate((guess_init, apf_mean), axis=0)

# Cache the three trajectories so the plotting cell can run without the raw
# result archives.
with open("../asset/Lorenz96_evolution.pkl", "wb") as file:
    pickle.dump((states, ssls_mean, apf_mean), file)

In [None]:
# Plot the first nine state components: ground truth (scatter) against the
# SSLS and APF mean trajectories (lines), one component per panel.

# NOTE(review): pickle.load can execute arbitrary code embedded in the file.
# This asset is written by the data-preparation cell of this notebook; do not
# point this path at files from an untrusted source.
with open("../asset/Lorenz96_evolution.pkl", "rb") as file:
    states, ssls_mean, apf_mean = pickle.load(file)

mpl.rcdefaults()
mpl.style.use("../configs/mplrc")
mpl.rc("figure.subplot", wspace=0.2, hspace=0.6)

markevery = 1
time_step = 0.05  # spacing of the time axis; presumably cfg['dynamics']['dt'] — TODO confirm

# Size the time axis from the loaded data rather than a hard-coded count.
# The mean trajectories carry one extra leading row compared with `states`,
# which is why the truth is plotted against t[1:].
t = np.arange(ssls_mean.shape[0]) * time_step
assert apf_mean.shape[0] == t.size, "APF and SSLS mean trajectories differ in length"
assert states.shape[0] == t.size - 1, "expected `states` to be one row shorter than the mean trajectories"

fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(7, 4))
# axes.flat walks the grid row by row, so panel k shows state component k.
for k, panel in enumerate(axes.flat):
    panel.scatter(
        t[1::markevery], states[::markevery, k],
        label='Truth', color='C0', marker='o', facecolors='none', s=4, linewidths=1.,
    )
    panel.plot(t, ssls_mean[:, k], label='SSLS', color='C1', markevery=markevery)
    panel.plot(t, apf_mean[:, k], label='APF', color='C2', markevery=markevery)
    panel.set_title(r"$\bf{x_{" f"{k+1}" r"}}$")
    panel.set_xlim([0, 5.1])
    panel.set_xticks([0, 1, 2, 3, 4, 5])

# One shared legend, anchored above the top-middle panel.
axes[0, 1].legend(bbox_to_anchor=(0.2, 1.3), loc='lower left', ncol=3)

# Shared axis captions; Figure.text positions are in figure coordinates.
fig.text(0.08, 0.5, 'State', fontsize=9, rotation='vertical')
fig.text(0.43, 0.03, 'Time step', fontsize=9)
fig.savefig('../asset/Lorenz96_evolution.pdf', dpi=600, bbox_inches='tight', pad_inches=0.1)