In [None]:
import os
import json
import pickle
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib as mpl

# Torch device handle pinned to the CPU.
# NOTE(review): `device` (and therefore the `torch` import) is not referenced by
# any of the cells visible in this notebook; remove both if nothing else needs them.
device = torch.device("cpu")

In [None]:
# Load one Lorenz-96 run (config, reference trajectory, SSLS ensemble) plus the
# matching APF baseline, prepend a shared prior ensemble at t=0, and cache the
# three arrays for the plotting cell.
# All paths are resolved against the parent of the working directory, i.e. the
# notebook is assumed to live one level below the repository root.
ROOT_DIR = os.path.dirname(os.getcwd())
RESULTS_DIR = os.path.join(ROOT_DIR, 'lorenz_results')
ASSET_DIR = os.path.join(ROOT_DIR, 'asset')
RUN_NAME = 'lorenz96-241031-150134'
SEED = 0  # seeds the random perturbation of the prior initial ensemble

workdir = os.path.join(RESULTS_DIR, RUN_NAME)

with open(os.path.join(workdir, 'config.json'), 'r') as f:
    cfg = json.load(f)

dim = cfg['dynamics']['dim']
steps = cfg['dynamics']['steps']
dt = cfg['dynamics']['dt']
noise_std = cfg['measurement']['noise_std']
n_ensemble = cfg['train']['n_train']

# NpzFile keeps the archive open until closed; the context manager releases the
# handle as soon as the arrays have been read into memory.
with np.load(os.path.join(workdir, 'results.npz')) as ssls_results:
    states = ssls_results['states']  # (steps+1, dim)
    observations = ssls_results['observations']  # (steps+1, dim)
    ssls_states = ssls_results['assimilated_states']  # (steps+1, nsamples, dim)

# Prior initial ensemble, shape (1, n_ensemble, dim). The SAME samples are
# prepended to both filters so their trajectories share a starting point.
# dtype=float guards against an integer prior_mean making the in-place += fail.
rng = np.random.default_rng(SEED)
guess_init = np.full((1, n_ensemble, dim), cfg['dynamics']['prior_mean'], dtype=float)
guess_init += rng.standard_normal(guess_init.shape) * cfg['dynamics']['prior_std']

# np.concatenate (not the NumPy>=2.0-only alias np.concat) keeps this working on NumPy 1.x.
ssls_states = np.concatenate((guess_init, ssls_states), axis=0)
ssls_mean = np.mean(ssls_states, axis=1)

apf_path = os.path.join(RESULTS_DIR, f'apf_sigma{noise_std}_nensemble{n_ensemble}.npz')
with np.load(apf_path) as apf_results:
    apf_states = apf_results['assimilated_states']  # (steps+1, nsamples, dim)
apf_states = np.concatenate((guess_init, apf_states), axis=0)
apf_mean = np.mean(apf_states, axis=1)

# The plotting cell draws the reference from t=1 and the ensembles from t=0,
# so both ensembles must have exactly one more time entry than the reference.
assert ssls_states.shape == apf_states.shape, (
    f"SSLS {ssls_states.shape} and APF {apf_states.shape} ensembles differ in shape"
)
assert ssls_states.shape[0] == states.shape[0] + 1, (
    f"expected ensembles with {states.shape[0] + 1} time entries, got {ssls_states.shape[0]}"
)

os.makedirs(ASSET_DIR, exist_ok=True)
with open(os.path.join(ASSET_DIR, 'Lorenz96_ensemble.pkl'), 'wb') as file:
    pickle.dump((states, ssls_states, apf_states), file)

In [None]:
# Reload the cached arrays so this figure cell can be re-run on its own.
# NOTE: pickle.load can execute arbitrary code; this file is the one written by
# the loading cell above — never point it at an untrusted pickle.
with open("../asset/Lorenz96_ensemble.pkl", "rb") as file:
    states, ssls_states, apf_states = pickle.load(file)
mpl.rcdefaults()
mpl.style.use("../configs/mplrc")
marker_style = {
    "marker": 'o',
    "markersize": 3,
    "markerfacecolor": (0, 0, 0, 0),
    "markeredgecolor": 'C0',
    "markeredgewidth": 0.8
}


def _ordinal(n: int) -> str:
    """Return ``n`` with its English ordinal suffix (1 -> '1st', 10 -> '10th', 22 -> '22nd')."""
    if 10 <= n % 100 <= 20:
        suffix = 'th'
    else:
        suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(n % 10, 'th')
    return f'{n}{suffix}'


# Time axis derived from the data (previously hard-coded to 100 steps). The
# ensembles carry one extra leading entry (prior ensemble at t=0), so the
# reference trajectory is drawn from t=1.
t = np.arange(ssls_states.shape[0])
assert states.shape[0] == t.size - 1, "reference must have one fewer time entry than the ensembles"
assert apf_states.shape[0] == t.size, "SSLS and APF ensembles must cover the same time steps"

plot_dim = states.shape[-1]  # the last state component (index -1) is plotted
markevery = 1  # draw a marker on every `markevery`-th reference point

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7, 2.5))
# markevery only thins the markers; the full trajectory is always drawn
# (the data was previously also pre-sliced, thinning it twice).
ax.plot(t[1:], states[:, -1], color='C0', markevery=markevery, zorder=99, **marker_style)
# One faint line per ensemble member; rasterized to keep the PDF small.
ax.plot(t, ssls_states[:, :, -1], color='C1', alpha=0.1, linewidth=0.2, rasterized=True)
ax.plot(t, apf_states[:, :, -1], color='C2', alpha=0.1, linewidth=0.2, rasterized=True)
ax.set_xlim(left=0)
custom_lines = [
    mpl.lines.Line2D([0], [0], color='C0', label='Ref. state', **marker_style),
    mpl.lines.Line2D([0], [0], color='C1', label='SSLS ensemble'),
    mpl.lines.Line2D([0], [0], color='C2', label='APF ensemble'),
]
ax.legend(handles=custom_lines, bbox_to_anchor=(0.1, 1.0), loc='lower left', ncol=3)
ax.set_xlabel('Time step')
ax.set_ylabel(f'{_ordinal(plot_dim)} dimension')
fig.savefig('../asset/Lorenz96_ensemble.pdf', dpi=600, bbox_inches='tight', pad_inches=0.1)
plt.show()