# Evaluation of a selection of models on 60km -> 2.2km-4x over Birmingham

In [None]:
%reload_ext autoreload

%autoreload 2

%reload_ext dotenv
%dotenv
    
import math
import os

import cftime
import iris
import iris.analysis.cartography
import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pysteps
import scipy
import xarray as xr

from mlde_notebooks import plot_map, show_samples, distribution_figure, plot_mean_bias, plot_std_bias, scatter_plots, seasonal_distribution_figure, compute_gridspec, freq_density_plot
from mlde_notebooks import plot_psd, plot_fss
from mlde_utils import cp_model_rotated_pole
from mlde_utils.utils import prep_eval_data
from mlde_notebooks import create_map_fig, qq_plot, STYLES, plot_spread_skill

In [None]:
def reasonable_quantiles(da):
    """Return quantile levels that become denser towards 1.0.

    The number of elements in ``da`` decides how far into the upper tail the
    levels go: with ``limit = int(log10(1 / da.size))``, every decade exponent
    from -1 down to ``limit`` contributes 9 evenly spaced levels
    (0.1..0.9, then 0.91..0.99, then 0.991..0.999, ...). A final level of
    1.0 is always appended.

    Args:
        da: Any object exposing a ``size`` attribute (number of elements).

    Returns:
        1-D ``numpy.ndarray`` of increasing quantile levels ending in 1.0.

    Side effects:
        Prints the computed ``limit`` exponent.
    """
    limit = int(np.log10(1 / da.size))
    print(limit)

    levels_per_decade = []
    for exponent in range(-1, limit - 1, -1):
        # Start one step above the previous decade's last level and stop one
        # step (10**exponent) short of 1.
        first_level = (1 - 10 ** (exponent + 1)) + (10 ** exponent)
        last_level = 1 - 10 ** exponent
        levels_per_decade.append(np.linspace(first_level, last_level, 9))

    # Always finish with the maximum.
    levels_per_decade.append([1.0])
    return np.concatenate(levels_per_decade)

In [None]:
# Named 20-year windows on the 360-day calendar. Each window is a
# (start, end) pair running from 1 December 12:00 of the given year to
# 30 November 12:00 twenty years later.
TIME_PERIODS = {
    period: (
        cftime.Datetime360Day(first_year, 12, 1, 12, 0, 0, 0, has_year_zero=True),
        cftime.Datetime360Day(first_year + 20, 11, 30, 12, 0, 0, 0, has_year_zero=True),
    )
    for period, first_year in (
        ("historic", 1980),
        ("present", 2020),
        ("future", 2060),
    )
}

In [None]:
# Dataset split handed to the evaluation data loader.
split = "val"

# Zero-padded ensemble member IDs to include (02, 03 and 14 are not in the list).
ensemble_members = [
    f"{member:02d}" for member in (1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15)
]

samples_per_run = 3

# Model runs to evaluate, grouped by the source of their inputs.
# Each entry names the model, its checkpoint, the input transform, the label
# used to refer to it, the dataset it was sampled on and whether the model
# is deterministic.
data_configs = {
    "CPM": [
        # Disabled run, kept for reference:
        # {
        #     "fq_model_id": "score-sde/subvpsde/xarray_cncsnpp_continuous/bham-4x_12em_PslT4V4_random-season-IstanTsqrturrecen",
        #     "checkpoint": "epoch-20",
        #     "input_xfm": "stan",
        #     "label": "Diff",
        #     "dataset": "bham_gcmx-4x_12em_psl-temp4th-vort4th_eqvt_random-season",
        #     "deterministic": False,
        # },
        {
            "fq_model_id": "score-sde/subvpsde/xarray_12em_cncsnpp_continuous/bham-4x_12em_PslT4V4_random-season-IstanTsqrturrecen-no-loc-spec",
            "checkpoint": "epoch-20",
            "input_xfm": "stan",
            "label": "Diff no locspec",
            "dataset": "bham_gcmx-4x_12em_psl-temp4th-vort4th_eqvt_random-season",
            "deterministic": False,
        },
        {
            "fq_model_id": "u-net/bham-4x_12em_PslT4V4_random-season-IstanTsqrturrecen",
            "checkpoint": "epoch-100",
            "input_xfm": "stan",
            "label": "U-Net",
            "dataset": "bham_gcmx-4x_12em_psl-temp4th-vort4th_eqvt_random-season",
            "deterministic": True,
        },
        {
            "fq_model_id": "id-linpr",
            "checkpoint": "epoch-0",
            "input_xfm": "none",
            "label": "Bilinear",
            "deterministic": True,
            "dataset": "bham_gcmx-4x_12em_linpr_eqvt_random-season",
        },
    ],
    "GCM": [
        # The two GCM runs share model and checkpoint; they differ only in
        # the input transform ("pixelmmsstan" vs "stan") and label.
        {
            "fq_model_id": "score-sde/subvpsde/xarray_cncsnpp_continuous/bham-4x_12em_PslT4V4_random-season-IstanTsqrturrecen",
            "checkpoint": "epoch-20",
            "input_xfm": "pixelmmsstan",
            "label": "GCM Diff bc",
            "dataset": "bham_60km-4x_12em_psl-temp4th-vort4th_eqvt_random-season",
            "deterministic": False,
        },
        {
            "fq_model_id": "score-sde/subvpsde/xarray_cncsnpp_continuous/bham-4x_12em_PslT4V4_random-season-IstanTsqrturrecen",
            "checkpoint": "epoch-20",
            "input_xfm": "stan",
            "label": "GCM Diff",
            "dataset": "bham_60km-4x_12em_psl-temp4th-vort4th_eqvt_random-season",
            "deterministic": False,
        },
    ],
}

# Default datasets for comparisons such as PSD, which need CPM-based
# high-resolution precipitation and GCM-based low-resolution precipitation
# respectively.
# NOTE(review): both keys currently point at the same "bham_60km-4x" dataset,
# unlike gcm_lr_lin_pr_dataset / cpm_hr_pr_dataset below - confirm this is
# intended.
simulation_pr_datasets = {
    "GCM": "bham_60km-4x_linpr_random-season",
    "CPM": "bham_60km-4x_linpr_random-season",
}
gcm_lr_lin_pr_dataset = "bham_60km-4x_12em_linpr_eqvt_random-season"
cpm_hr_pr_dataset = "bham_gcmx-4x_12em_linpr_eqvt_random-season"

# Free-text description of this comparison (placeholder text for now).
desc = """
Describe in more detail the models being compared
"""

In [None]:
# Wrap the free-text description in a Markdown display object; as the cell's
# last expression it becomes the cell output.
IPython.display.Markdown(desc)

In [None]:
# Labels of the deterministic models, keyed by input source.
det_models = {
    source: [cfg["label"] for cfg in run_configs if cfg["deterministic"]]
    for source, run_configs in data_configs.items()
}

# Labels of the stochastic (non-deterministic) models, keyed by input source.
stoch_models = {
    source: [cfg["label"] for cfg in run_configs if not cfg["deterministic"]]
    for source, run_configs in data_configs.items()
}

# One prepared evaluation dataset per input source, built from that source's
# list of run configs.
merged_ds = {
    source: prep_eval_data(
        data_config,
        split,
        ensemble_members=ensemble_members,
        samples_per_run=samples_per_run,
    )
    for source, data_config in data_configs.items()
}

# The "target_pr" variable of the CPM dataset, reused by later cells.
cpm_pr = merged_ds["CPM"]["target_pr"]

# Last expression of the cell: show the prepared datasets.
merged_ds

### Fractions Skill Score (FSS)

In [None]:
# Use the 50th, 75th, 90th and 99th percentiles of `cpm_pr` as the FSS
# thresholds, and echo them so the plot can be read against them.
ds = merged_ds["CPM"]
thresholds = cpm_pr.quantile([0.5, 0.75, 0.9, 0.99]).values
print(thresholds)

# Draw the FSS plot for the CPM dataset onto a constrained-layout figure.
fig = plt.figure(layout="constrained")
plot_fss(fig, ds, thresholds)