In [None]:
import os
import glob
import json
from datetime import datetime
import numpy as np
import torch
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns


# CPU torch device handle.
# NOTE(review): `device` is not referenced by any of the cells visible here — confirm it is still needed.
device = torch.device("cpu")

In [None]:
# Inclusive time window used to select result folders by the timestamp
# embedded in their names.
start_time = datetime(2024, 10, 31, 0, 0, 0)
end_time = datetime(2024, 10, 31, 23, 59, 59)
# glob.glob returns matches in arbitrary (filesystem-dependent) order; sort them
# so the runs are processed in the same order on every machine and re-run.
all_folders = sorted(glob.glob('../lorenz_results/lorenz*'))

def get_timestamp_from_folder(folder_name):
    """Parse the timestamp embedded in a result-folder name.

    The name is split on '-'; the second and third fields are concatenated
    and parsed with the format '%y%m%d%H%M%S'
    (e.g. 'lorenz-241031-123456' -> 2024-10-31 12:34:56).

    Args:
        folder_name: Name of the folder to parse.

    Returns:
        datetime: The parsed timestamp.

    Raises:
        ValueError: If the name has fewer than three '-'-separated fields, or
            if the concatenated fields do not match '%y%m%d%H%M%S'.
    """
    parts = folder_name.split('-')
    if len(parts) < 3:
        # Indexing parts[1]/parts[2] would raise IndexError here; report a
        # malformed name as ValueError, the same exception strptime raises,
        # so callers need to handle only one exception type.
        raise ValueError(
            f"folder name {folder_name!r} needs at least three '-'-separated fields"
        )
    return datetime.strptime(parts[1] + parts[2], '%y%m%d%H%M%S')

def quantile_loss(target, pred, q):
    """Twice the pinball (quantile) loss at level ``q``, summed along axis 1.

    Args:
        target: Array of reference values.
        pred: Array holding the predicted ``q``-quantile for each entry of ``target``.
        q: Quantile level.

    Returns:
        ``2 * sum(|(pred - target) * (1[target <= pred] - q)|)`` reduced over axis 1.
    """
    # 1.0 where the prediction is at or above the target, 0.0 elsewhere.
    at_or_above = (target <= pred) * 1.0
    weighted_error = (pred - target) * (at_or_above - q)
    return 2 * np.abs(weighted_error).sum(axis=1)

def calc_quantile_CRPS(states, assimilated_states):
    """Quantile-based, normalised CRPS estimate for every step.

    For each quantile level in 0.05, 0.10, ... (step 0.05, below 1.0) the
    ensemble quantile is taken over axis 1 of ``assimilated_states`` and
    scored against ``states`` with ``quantile_loss``. Each score is divided
    by the per-step sum of ``|states|``, and the scores are averaged over
    the levels.

    Args:
        states: Reference states, (steps, dim).
        assimilated_states: Ensemble members, (steps, nsamples, dim).

    Returns:
        Array of shape (steps, ) with one value per step.
    """
    levels = np.arange(0.05, 1.0, 0.05)
    normaliser = np.abs(states).sum(axis=1)  # (steps, )
    total = np.zeros(states.shape[0])
    for level in levels:
        level_pred = np.quantile(assimilated_states, level, axis=1)  # (steps, dim)
        total += quantile_loss(states, level_pred, level) / normaliser  # (steps, )
    return total / len(levels)

# Keep only the folders whose name-embedded timestamp lies inside
# [start_time, end_time]; names that cannot be parsed are reported and skipped.
filtered_folders = []
for folder in all_folders:
    folder_name = os.path.basename(folder)
    try:
        timestamp = get_timestamp_from_folder(folder_name)
    except ValueError:
        print(f"{folder_name} doesn't match naming conventions")
        continue
    if not (start_time <= timestamp <= end_time):
        continue
    filtered_folders.append(folder)

def _ensemble_metrics(truth, ensemble):
    """Per-step verification metrics of an ensemble against reference states.

    Args:
        truth: Reference states, (steps, dim).
        ensemble: Ensemble members, (steps, nsamples, dim).

    Returns:
        Tuple ``(rmse, crps, spread, coverage)``, each of shape (steps, ).
    """
    ensemble_mean = np.mean(ensemble, axis=1)  # (steps, dim)
    rmse = np.mean((truth - ensemble_mean) ** 2, axis=1) ** 0.5  # (steps, )
    # Spread: per-component ensemble standard deviation (ddof=0), averaged over dim.
    variances = np.var(ensemble, axis=1, ddof=0)  # (steps, dim)
    spread = np.mean(np.sqrt(variances), axis=1)  # (steps, )
    # Coverage: fraction of components strictly inside the 10%-90% ensemble interval.
    low_quantile = np.quantile(ensemble, q=0.1, axis=1)  # (steps, dim)
    high_quantile = np.quantile(ensemble, q=0.9, axis=1)  # (steps, dim)
    coverage = np.sum((truth > low_quantile) & (truth < high_quantile), axis=1) / truth.shape[1]
    crps = calc_quantile_CRPS(truth, ensemble)  # (steps, )
    return rmse, crps, spread, coverage


def _metric_rows(n, noise_std, method, truth, ensemble):
    """Build one row [n, noise_std, rmse, crps, spread, coverage, method] per step."""
    return [
        [n, noise_std, rmse, crps, spread, coverage, method]
        for rmse, crps, spread, coverage in zip(*_ensemble_metrics(truth, ensemble))
    ]


df = []

# --- SSLS runs: one folder per run, with its settings stored in config.json ---
states = None  # reference trajectory of the most recently loaded SSLS run
for workdir in filtered_folders:
    with open(os.path.join(workdir, 'config.json'), 'r') as f:
        cfg = json.load(f)
    noise_std = cfg['measurement']['noise_std']
    n_train = cfg['train']['n_train']
    # Use the archive as a context manager so its file handle is closed promptly.
    with np.load(os.path.join(workdir, 'results.npz')) as results:
        states = results['states']  # (steps, dim)
        ssls_states = results['assimilated_states']  # (steps, nsamples, dim)
    df.extend(_metric_rows(n_train, noise_std, "SSLS", states, ssls_states))

# --- APF runs: one archive per (noise level, ensemble size) pair ---
for noise_std in [0.5, 1.0, 2.0]:
    for n_ensemble in [50, 200, 500, 1000, 2000]:
        with np.load(f'../lorenz_results/apf_sigma{noise_std}_nensemble{n_ensemble}.npz') as results:
            apf_states = results['assimilated_states']  # (steps, nsamples, dim)
            # Score the APF ensemble against the reference stored with that run
            # when the archive has one; otherwise fall back to the reference of
            # the last SSLS run, which is what this loop previously relied on.
            # NOTE(review): the fallback assumes every run shares one reference
            # trajectory — confirm against the experiment setup.
            apf_truth = results['states'] if 'states' in results.files else states
        if apf_truth is None:
            raise RuntimeError(
                "No reference states available for the APF metrics: the APF archive "
                "has no 'states' entry and no SSLS run was loaded."
            )
        df.extend(_metric_rows(n_ensemble, noise_std, "APF", apf_truth, apf_states))


# Column labels, listed in the same order as the entries of each collected row.
metric_columns = [
    "n",
    r"$\sigma_{obs}$",
    "Average RMSE",
    "CRPS",
    "Average spread",
    "Average coverage probability",
    "Method",
]
# `df` is rebound here from the list of rows to the DataFrame built from it.
df = pd.DataFrame(data=df, columns=metric_columns)

# Persist the table so it can be reloaded without recomputing the metrics.
df.to_pickle("../asset/Lorenz96_metrics.pkl")

In [None]:
# Reload the saved metrics table and reset matplotlib to the project style.
df = pd.read_pickle("../asset/Lorenz96_metrics.pkl")
mpl.rcdefaults()
mpl.style.use("../configs/mplrc")
mpl.rc("figure.subplot", wspace=0.25, hspace=0.4)

sigma_col = r"$\sigma_{obs}$"
panel_metrics = [
    "Average RMSE",
    "CRPS",
    "Average spread",
    "Average coverage probability",
]
# Rows with sigma_obs == 2.0 are left out of the figure.
shown = df[df[sigma_col] != 2.0]

# One panel per metric: colour encodes the method, line style/marker the noise level.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(7, 4.5))
for ax, metric_name in zip(axes.flatten(), panel_metrics):
    sns.lineplot(
        shown,
        x="n",
        y=metric_name,
        style=sigma_col,
        hue="Method",
        hue_order=["APF", "SSLS"],
        palette=["C2", "C1"],
        markers=["o", "^"],
        errorbar=None,
        ax=ax,
        markerfacecolor='auto',
        markeredgecolor='auto',
        markersize=5,
    )
    ax.set_xticks([50, 200, 500, 1000, 2000])
    ax.grid()
    ax.set_xlabel("Ensemble size")
    # Drop the per-panel legend; a single shared legend is added below.
    ax.legend().remove()

# Hand-built legend entries: two for the methods (colour), two for the noise levels.
black_marker = dict(
    color='k',
    markeredgecolor='k',
    markerfacecolor='k',
    markeredgewidth=1,
    markersize=5,
)
custom_lines = [
    mpl.lines.Line2D([0], [0], color='C1', label='SSLS'),
    mpl.lines.Line2D([0], [0], color='C2', label='APF'),
    mpl.lines.Line2D([0], [0], marker='o', label=r'$\sigma_{\rm obs} = 0.5$', **black_marker),
    mpl.lines.Line2D([0], [0], marker='^', linestyle='--', label=r'$\sigma_{\rm obs} = 1.0$', **black_marker),
]
axes[0][0].legend(handles=custom_lines, bbox_to_anchor=(0.35, 1.1), loc='lower left', ncol=4)
plt.savefig('../asset/Lorenz96_metrics.pdf', dpi=600, bbox_inches='tight', pad_inches=0.1)