# Evaluation of a selection of models on 60km -> 2.2km-4x over Birmingham

In [None]:
%reload_ext autoreload

%autoreload 2

%reload_ext dotenv
%dotenv
    
import IPython
import matplotlib.pyplot as plt

from mlde_utils.utils import prep_eval_data
from mlde_notebooks import plot_spread_skill

In [None]:
# Data split that every model in this notebook is evaluated on.
split = "val"

# Ensemble member ids to include, as two-digit zero-padded strings.
ensemble_members = [
    f"{member:02d}" for member in (1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15)
]

# Number of samples requested per run when the evaluation data is prepared.
samples_per_run = 3

# Run configurations to evaluate, grouped by the source of their inputs.
data_configs = {
    "CPM": [
        dict(
            fq_model_id="score-sde/subvpsde/xarray_cncsnpp_continuous/bham-4x_12em_PslT4V4_random-season-IstanTsqrturrecen",
            checkpoint="epoch-20",
            input_xfm="stan",
            label="Diffusion 12em",
            dataset="bham_gcmx-4x_12em_psl-temp4th-vort4th_eqvt_random-season",
            deterministic=False,
        ),
    ],
}

# Labels (matching the "label" entries above) of the CPM models to highlight.
highlighted_cpm_models = ["Diffusion 12em"]

# Default datasets for comparisons such as PSD, which need CPM-based hi-res
# precip and GCM-based lo-res precip respectively.
# NOTE(review): both sources currently name the same dataset - confirm that
# this is intended.
simulation_pr_datasets = dict(
    GCM="bham_60km-4x_linpr_random-season",
    CPM="bham_60km-4x_linpr_random-season",
)
gcm_lr_lin_pr_dataset = "bham_60km-4x_12em_linpr_eqvt_random-season"
cpm_hr_pr_dataset = "bham_gcmx-4x_12em_linpr_eqvt_random-season"

# Free-text description of the comparison, rendered as Markdown further down.
desc = "\nDescribe in more detail the models being compared\n"

In [None]:
# Wrap the description text in a Markdown display object; it is the cell's
# final expression, so the notebook shows it as the cell output.
IPython.display.Markdown(desc)

In [None]:
# Keep, per source, only the run configs whose "deterministic" entry is
# exactly False; every other config is dropped for the rest of the notebook.
data_configs = {
    source: [cfg for cfg in source_cfgs if cfg["deterministic"] is False]
    for source, source_cfgs in data_configs.items()
}

In [None]:
# Per source: labels of the run configs flagged as deterministic.
det_models = {
    source: [cfg["label"] for cfg in source_cfgs if cfg["deterministic"]]
    for source, source_cfgs in data_configs.items()
}

# Per source: labels of the run configs not flagged as deterministic.
stoch_models = {
    source: [cfg["label"] for cfg in source_cfgs if not cfg["deterministic"]]
    for source, source_cfgs in data_configs.items()
}

# Per source: evaluation data built by prep_eval_data from that source's run
# configs (presumably one merged dataset per source - confirm against
# mlde_utils).
merged_ds = {
    source: prep_eval_data(
        source_cfgs,
        split,
        ensemble_members=ensemble_members,
        samples_per_run=samples_per_run,
    )
    for source, source_cfgs in data_configs.items()
}

# Final expression so the notebook displays the prepared data.
merged_ds

## Spread-skill

https://journals.ametsoc.org/view/journals/hydr/15/4/jhm-d-14-0008_1.xml?tab_body=fulltext-display

https://journals.ametsoc.org/view/journals/aies/2/2/AIES-D-22-0061.1.xml

https://www.sciencedirect.com/science/article/pii/S0021999107000812

In [None]:
# Narrow the CPM evaluation data to the stochastic models only, selecting
# along "model" by label.
cpm_ds = merged_ds["CPM"].sel(model=stoch_models["CPM"])

# One square figure holding a single panel keyed "Spread-Skill".
fig = plt.figure(layout="constrained", figsize=(5.5, 5.5))
axd = fig.subplot_mosaic(
    [
        ["Spread-Skill"],
    ]
)
ax = axd["Spread-Skill"]

# Draw the spread-skill diagnostic for the selected models onto that panel.
plot_spread_skill(cpm_ds, ax)

plt.show()

## CRPS