# Evaluation of a selection of models on 60km -> 2.2km-4x over Birmingham

In [None]:
%reload_ext autoreload

%autoreload 2

%reload_ext dotenv
%dotenv

import functools
import math
import string

import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import xarray as xr

from mlde_analysis.data import prep_eval_data
from mlde_analysis import plot_map, distribution_figure, scatter_plots
from mlde_analysis.distribution import normalized_mean_bias, normalized_std_bias, plot_freq_density, plot_mean_biases, plot_std_biases, rms_mean_bias, rms_std_bias
from mlde_analysis.psd import plot_psd, pysteps_rapsd
from mlde_analysis.uncertainty import plot_spread_error
from mlde_analysis.wet_dry import threshold_exceeded_prop_stats, threshold_exceeded_prop, threshold_exceeded_prop_error, threshold_exceeded_prop_change, plot_threshold_exceedence_errors
from mlde_utils import cp_model_rotated_pole, TIME_PERIODS

In [None]:
# Render every figure in this notebook at 100 dots per inch.
matplotlib.rc("figure", dpi=100)

In [None]:
from mlde_analysis.default_params import *

In [None]:
# Show the notebook description as rendered Markdown. This is left as the final
# expression of the cell so that the notebook displays its value.
# NOTE(review): `desc` is presumably supplied by the default_params star-import — confirm.
IPython.display.Markdown(desc)

In [None]:
# Load the evaluation data. EVAL_DS is indexed by source name below
# (e.g. EVAL_DS["CPM"]) and iterated over as a mapping of datasets.
EVAL_DS, MODELS = prep_eval_data(
    sample_configs,
    dataset_configs,
    derived_variables_config,
    eval_vars,
    split,
    ensemble_members=ensemble_members,
    samples_per_run=samples_per_run,
)
# Echo the loaded data as the cell output.
EVAL_DS

In [None]:
# Target data arrays taken from the "CPM" dataset, keyed by evaluation variable name.
CPM_DAS = {name: EVAL_DS["CPM"][f"target_{name}"] for name in eval_vars}

## Figure: distribution

* Frequency Density Histograms of rainfall intensities, split up by sample_id, ensemble_member and random time partition

In [None]:
# Unseeded generator: the random time partition drawn from it differs on every run.
rng = np.random.default_rng(seed=None)

# Create a random partitioning of the data by time.
# The function is mapped over groups in an attempt to partition independently
# by season and time period.
def partition(group):
    """Return ``group`` with a random 0/1 ``time_subset`` coordinate along ``time``.

    One integer in {0, 1} is drawn from the module-level ``rng`` for every
    entry of ``group["time"]``.
    """
    subset_ids = rng.integers(2, size=group["time"].shape)
    return group.assign_coords(time_subset=("time", subset_ids))
# Draw the partition once, using the first evaluation variable's target data
# grouped by "stratum", and keep only the raw 0/1 labels.
# NB assumes all variables have same time dimension
random_time_partition = (
    CPM_DAS[eval_vars[0]]
    .groupby("stratum")
    .map(partition)["time_subset"]
    .values
)

In [None]:
import gc

# Panel labels: the model names from every evaluation dataset, followed by "CPM".
model_labels = np.concatenate(
    [source_ds["model"].values for source_ds in EVAL_DS.values()],
    axis=0,
)
labels = np.append(model_labels, "CPM")

def plot_sampling_variability(ax, da, label, hue, bins, var, target_da):
    """Plot density histograms of ``da`` split by ``hue`` over a shaded reference.

    Args:
        ax: Matplotlib axes to draw on.
        da: Data array to plot; converted with ``to_dataframe()`` and read
            from the column named by ``var``.
        label: Title for the axes.
        hue: Name of the column used to split ``da`` into separate outlines.
        bins: Bin specification shared by both histograms.
        var: Name of the column holding the values to histogram.
        target_da: Reference data array, drawn as faint filled black bars.
    """
    column = f"{var}"
    # Settings common to the reference and the per-subset histograms; each
    # histogram is normalised to a density on its own (common_norm=False).
    shared_kwargs = dict(
        ax=ax,
        x=column,
        bins=bins,
        stat="density",
        common_norm=False,
        legend=False,
    )

    # Reference distribution as a translucent backdrop.
    sns.histplot(
        target_da.to_dataframe(),
        element="bars",
        color="black",
        alpha=0.2,
        linewidth=0,
        **shared_kwargs,
    )
    # One unfilled step outline per value of `hue`.
    sns.histplot(
        da.to_dataframe(),
        hue=hue,
        element="step",
        fill=False,
        **shared_kwargs,
    )

    ax.set_yscale("log")
    ax.set_title(label)

# For each evaluation variable, draw one row of panels (one per model, then
# CPM) showing the density histogram of each random time subset.
# NOTE: a per-ensemble-member variant (hue="ensemble_member" on a separate
# figure) was sketched here previously and is currently disabled.
for var in eval_vars:
    IPython.display.display_markdown(f"### {var}", raw=True)

    ts_fig = plt.figure(layout="constrained", figsize=(10, 3))
    ts_axd = ts_fig.subplot_mosaic(
        labels.reshape(1, -1),
        sharex=True,
        sharey=True,
    )

    # The same random 0/1 time labels are attached to the target and to every prediction.
    time_subset_coord = ("time", random_time_partition)
    cpm_da = CPM_DAS[var].rename(var).assign_coords(time_subset=time_subset_coord)

    # Bin edges span the target and all predictions so every panel uses the same bins.
    pred_das = [ds[f"pred_{var}"] for ds in EVAL_DS.values()]
    xrange = (
        min(cpm_da.min(), *[pred_da.min() for pred_da in pred_das]),
        max(cpm_da.max(), *[pred_da.max() for pred_da in pred_das]),
    )
    bins = np.histogram_bin_edges([], bins=50, range=xrange)

    for source, ds in EVAL_DS.items():
        da = ds[f"pred_{var}"].rename(var).assign_coords(time_subset=time_subset_coord)

        for label, group_da in da.groupby("model"):
            plot_sampling_variability(
                ts_axd[label], group_da, label, "time_subset", bins, var, cpm_da
            )
            # Presumably here to release the intermediate DataFrames between panels.
            gc.collect()

    # The CPM target gets its own panel, drawn against itself as the reference.
    label = "CPM"
    plot_sampling_variability(
        ts_axd[label], cpm_da, label, "time_subset", bins, var, cpm_da
    )

    plt.show()