# Evaluation of distribution of models on 60km -> 2.2km-4x over Birmingham

In [None]:
%reload_ext autoreload

%autoreload 2

%reload_ext dotenv
%dotenv

import functools
import itertools
import math
import string

import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import xarray as xr
import xskillscore as xss

from mlde_utils import cp_model_rotated_pole

from mlde_notebooks.data import prep_eval_data
from mlde_notebooks.display import pretty_table
from mlde_notebooks.spatial_correlation import plot_correlations, compute_correlations

In [None]:
# Render every figure in this notebook at 300 dots per inch.
matplotlib.rc("figure", dpi=300)

In [None]:
# Star-import the notebook's default parameters.
# NOTE(review): names used in later cells without a visible definition (desc,
# sample_configs, dataset_configs, derived_variables_config, split,
# ensemble_members, samples_per_run) presumably come from here — confirm.
from mlde_notebooks.default_params import *
# Variables to evaluate. Assigned after the star import so that this value
# takes precedence over any same-named default the import may provide.
eval_vars=["pr", "relhum150cm", "tmean150cm"]

In [None]:
# Wrap `desc` in a Markdown display object; as the final expression of the cell
# it becomes the cell's rendered output.
# NOTE(review): `desc` is not defined in the visible cells — presumably it is
# provided by the default_params star import; confirm.
IPython.display.Markdown(desc)

In [None]:
# Build the evaluation data for the configured samples and datasets,
# restricted to the variables listed in `eval_vars`.
EVAL_DS, MODELS = prep_eval_data(
    sample_configs,
    dataset_configs,
    derived_variables_config,
    eval_vars,
    split,
    ensemble_members=ensemble_members,
    samples_per_run=samples_per_run,
)
# Show the prepared data as the cell output.
EVAL_DS

In [None]:
# Map each evaluated variable name to its "target_<var>" entry taken from the
# "CPM" member of EVAL_DS.
CPM_DAS = {
    var: EVAL_DS["CPM"][f"target_{var}"]
    for var in eval_vars
}

## Figure: Spatial Correlation

* Spearman

In [None]:
def extract_and_compute_correlations(ds, var_pair):
    """Compare predicted and target Spearman correlations for a pair of variables.

    Looks up ``pred_<name>`` and ``target_<name>`` in ``ds`` for the first two
    entries of ``var_pair`` and hands the four results to
    ``compute_correlations`` with ``xss.spearman_r`` as the correlation function.

    Args:
        ds: Object indexable by the ``pred_*`` / ``target_*`` variable names.
        var_pair: Sequence whose first two items are the variable names.

    Returns:
        Whatever ``compute_correlations`` returns for the four inputs.
    """
    first_var, second_var = var_pair[0], var_pair[1]
    return compute_correlations(
        ds[f"pred_{first_var}"],
        ds[f"pred_{second_var}"],
        ds[f"target_{first_var}"],
        ds[f"target_{second_var}"],
        corr_f=xss.spearman_r,
    )

### Seasonal

In [None]:
# For each entry of EVAL_DS and each pair of evaluated variables: compute the
# per-season correlation comparison, plot the DJF and JJA results, then
# tabulate the root-mean-square of "Corr diff" over the grid.
for source in EVAL_DS:
    # Merge this entry's dataset with the CPM target arrays.
    ds = xr.merge([EVAL_DS[source], *CPM_DAS.values()])
    for var_pair in itertools.combinations(eval_vars, 2):
        corr_ds = ds.groupby("time.season").map(
            extract_and_compute_correlations,
            var_pair=var_pair,
        )

        for season, season_corr_ds in corr_ds.groupby("season"):
            # Only DJF and JJA are plotted; the table below still uses every
            # season present in corr_ds.
            if season not in ("DJF", "JJA"):
                continue
            IPython.display.display_markdown(
                f"{season} {var_pair} Spearman", raw=True
            )
            fig, _ = plot_correlations(season_corr_ds)
            plt.show()

        pretty_table(
            np.sqrt(
                (corr_ds["Corr diff"] ** 2).mean(
                    dim=["grid_latitude", "grid_longitude"]
                )
            ),
            round=4,
        )

### Season & Time Period

In [None]:
def _correlations_by_time_period(season_ds, var_pair):
    """Apply the correlation comparison to each ``time_period`` group of ``season_ds``."""
    return season_ds.groupby("time_period").map(
        extract_and_compute_correlations,
        var_pair=var_pair,
    )


# For each entry of EVAL_DS and each pair of evaluated variables: compute the
# correlation comparison per season and time period, plot the DJF/JJA results
# for the "future" and "historic" periods, then tabulate the root-mean-square
# of "Corr diff" over the grid.
for source in EVAL_DS:
    # Merge this entry's dataset with the CPM target arrays.
    ds = xr.merge([EVAL_DS[source], *CPM_DAS.values()])
    for var_pair in itertools.combinations(eval_vars, 2):
        IPython.display.display_markdown(f"#### {var_pair}", raw=True)
        corr_ds = ds.groupby("time.season").map(
            _correlations_by_time_period,
            var_pair=var_pair,
        )

        for season, season_corr_ds in corr_ds.groupby("season"):
            # Only DJF and JJA are plotted.
            if season not in ("DJF", "JJA"):
                continue
            for tp, stratum_corr_ds in season_corr_ds.groupby("time_period"):
                stratum = f"{tp} {season}"
                # Only the "future" and "historic" periods are plotted.
                if tp not in ("future", "historic"):
                    continue
                IPython.display.display_markdown(f"##### {stratum}", raw=True)

                fig, _ = plot_correlations(stratum_corr_ds)
                fig.suptitle(
                    f"{stratum} {var_pair} Spearman", fontsize="xx-small"
                )
                plt.show()

        # The summary covers every season/time period present in corr_ds.
        IPython.display.display_markdown("##### RMSE summary", raw=True)
        pretty_table(
            np.sqrt(
                (corr_ds["Corr diff"] ** 2).mean(
                    dim=["grid_latitude", "grid_longitude"]
                )
            ),
            round=4,
        )