# Evaluation of a selection of models on 60km -> 2.2km-4x over Birmingham

In [None]:
%reload_ext autoreload

%autoreload 2

import itertools
import math
import os
from collections import defaultdict

import iris
import iris.analysis.cartography
import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pysteps
import seaborn as sns
import xarray as xr

from mlde_utils.utils import cp_model_rotated_pole, plot_grid, prep_eval_data, plot_examples, distribution_figure, plot_mean_bias, plot_std_bias, plot_psd, scatter_plots, seasonal_distribution_figure
from mlde_utils.plotting import qq_plot

In [None]:
# Probability levels used for the QQ comparisons: nine evenly spaced levels per
# decade of exceedance probability (0.1..0.9, 0.91..0.99, 0.991..0.999, ...),
# ending at 1 - 1e-7 so the extreme upper tail is resolved.
quantiles = np.concatenate(
    [
        np.linspace((1 - 10**(exponent + 1)) + 10**exponent, 1 - 10**exponent, 9)
        for exponent in range(-1, -8, -1)
    ]
)

# Axis limits shared by the QQ plots.
ymin, ymax = 0, 300

# Per source of variability: (dimension that indexes the realisations,
# samples_per_run passed to prep_eval_data).
_variability_settings = {
    "training": ("model", 1),
    "sampling": ("sample_id", 20),
}
variabilty_src_to_key = {src: dim for src, (dim, _) in _variability_settings.items()}
variabilty_src_to_sample_runs = {src: runs for src, (_, runs) in _variability_settings.items()}

In [None]:
# Common path prefix of the evaluated model runs; each run appends "-<run id>".
_RUN_PREFIX = "score-sde/subvpsde/xarray_cncsnpp_continuous/bham-4x_PslTV_random-IstanTsqrturrecen"


def _val_config(run_ids):
    """Return a one-element "val" config list evaluating `run_ids` at epoch 100.

    A fresh dict is built on every call so the configs share no mutable state.
    The "train" split and a "GCM" dataset
    ("bham_60km-4x_psl-temp-vort_random") were previously configured here but
    are currently disabled.
    """
    return [
        {
            "datasets": {"CPM": "bham_gcmx-4x_psl-temp-vort_random"},
            "runs": [[f"{_RUN_PREFIX}-{run_id}", "epoch-100", f"Run {run_id}"] for run_id in run_ids],
        }
    ]


data_configs = {
    # Spread across independently trained runs 1..20.
    "training": {"val": _val_config(range(1, 21))},
    # Spread across repeated samples from a single trained run.
    "sampling": {"val": _val_config([1])},
}

In [None]:
# Load evaluation data for each source of variability and draw QQ plots of the
# predicted precipitation against the CPM target.
# quantile_data[variability source][split] holds "target_quantiles" and
# "sample_quantiles" (one quantile curve per realisation) for the next cells.
quantile_data = defaultdict(dict)

for (variabilty_src, split) in itertools.product(["sampling", "training"], ["val"]):
    IPython.display.display_html(f"<h1>{variabilty_src} {split}</h1>", raw=True)

    # Dimension indexing the realisations whose spread is measured.
    variability_dim = variabilty_src_to_key[variabilty_src]
    merged_ds = xr.merge(
        [
            prep_eval_data(
                c["datasets"],
                c["runs"],
                split,
                samples_per_run=variabilty_src_to_sample_runs[variabilty_src],
            )
            for c in data_configs[variabilty_src][split]
        ]
    )

    target_pr = merged_ds.sel(source="CPM")["target_pr"]
    # Computed once and reused below (previously evaluated twice).
    target_quantiles = target_pr.quantile(quantiles)

    # Reduce over every prediction dimension except the realisation index and
    # the source, so each realisation keeps its own quantile curve per source.
    reduce_dims = [
        dim for dim in merged_ds["pred_pr"].dims if dim not in (variability_dim, "source")
    ]

    quantile_data[variabilty_src][split] = dict(
        target_quantiles=target_quantiles,
        sample_quantiles=merged_ds["pred_pr"].groupby(variability_dim).quantile(quantiles, dim=reduce_dims),
    )

    # Use one source list for both the panel layout and the panel lookups so
    # the keys always match.
    sources = merged_ds["source"].values
    fig, axes = plt.subplot_mosaic(
        [[f"{source} Quantiles" for source in sources]],
        figsize=(16.5, 5.5),
        constrained_layout=True,
    )

    for source in sources:
        ax = axes[f"{source} Quantiles"]
        qq_plot(ax, target_pr, merged_ds.sel(source=source), quantiles)
        ax.set_xlim(ymin, ymax)
        ax.set_ylim(ymin, ymax)

    plt.show()

In [None]:
# For each source of variability, plot how much each predicted quantile varies
# across realisations (std over the realisation dimension) against the
# corresponding CPM target quantile.
for (variabilty_src, split) in itertools.product(["sampling", "training"], ["val"]):
    IPython.display.display_html(f"<h1>{variabilty_src} {split}</h1>", raw=True)

    split_quantiles = quantile_data[variabilty_src][split]
    sample_quantiles = split_quantiles["sample_quantiles"]
    # Take the sources from this entry's own data rather than from `merged_ds`,
    # which would only reflect the last iteration of the previous cell.
    sources = sample_quantiles["source"].values

    # Std of each quantile across realisations: a "quantile" column plus one
    # column per source (assumes the DataArray is ordered (quantile, source)).
    sample_quantile_std = (
        sample_quantiles.std(dim=[variabilty_src_to_key[variabilty_src]])
        .to_pandas()
        .reset_index()
    )
    target_quantile_df = split_quantiles["target_quantiles"].to_pandas().rename('CPM pr').reset_index()
    # Long format: one row per (CPM target quantile, source).
    quantile_std = sample_quantile_std.merge(target_quantile_df).melt(
        id_vars='CPM pr',
        value_vars=sources,
        value_name="Model quantile std",
        var_name="source",
    )

    # All sources share a single axes (distinguished by hue), so draw onto one
    # explicit axes instead of whichever mosaic panel happened to be current.
    fig, ax = plt.subplots(figsize=(16.5, 5.5), constrained_layout=True)
    sns.lineplot(data=quantile_std, x='CPM pr', y="Model quantile std", hue="source", ax=ax)
    # Raw string: "\s" is not a valid escape sequence in a normal literal.
    ax.set(ylabel=r"Model quantile $\sigma$")
    ax.set_ylim(0, 20)
    for source in sources:
        source_std = quantile_std[quantile_std["source"] == source]
        ax.fill_between(source_std["CPM pr"], source_std["Model quantile std"], alpha=0.5)

    # Show now so each figure appears beneath its own heading.
    plt.show()

In [None]:
# Inspect the spread (std across realisations) of the CPM-source sample
# quantiles. NOTE(review): relies on `variabilty_src` / `split` left over from
# the last iteration of the loop above.
_cpm_sample_quantile_std = (
    quantile_data[variabilty_src][split]["sample_quantiles"]
    .std(dim=[variabilty_src_to_key[variabilty_src]])
    .sel(source="CPM")
)
_cpm_sample_quantile_std.to_pandas().reset_index()

In [None]:
# Inspect the CPM target quantiles as a table, with the values in a 'CPM pr'
# column. NOTE(review): relies on `variabilty_src` / `split` left over from the
# last iteration of the loop above.
_target_quantile_series = quantile_data[variabilty_src][split]["target_quantiles"].to_pandas()
_target_quantile_series.rename('CPM pr').reset_index()