# Evaluation of the spatial structure of model samples: 60km -> 2.2km-4x over Birmingham

In [None]:
# Reload edited modules automatically before executing each cell.
%reload_ext autoreload

%autoreload 2

# Load environment variables from a .env file (python-dotenv IPython extension).
%reload_ext dotenv
%dotenv

import functools
import math
import string

import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

from mlde_notebooks.data import prep_eval_data
from mlde_notebooks.psd import plot_psd, pysteps_rapsd

In [None]:
# Render all figures in this notebook at 300 DPI.
matplotlib.rcParams.update({"figure.dpi": 300})

In [None]:
from mlde_notebooks.default_params import *

In [None]:
# Render the run description as Markdown; `desc` presumably comes from the
# `default_params` star import above (NOTE(review): confirm).
IPython.display.Markdown(desc)

In [None]:
# Load the evaluation dataset and the per-source model configurations.
EVAL_DS, MODELS = prep_eval_data(
    sample_configs,
    dataset_configs,
    derived_variables_config,
    eval_vars,
    split,
    ensemble_members=ensemble_members,
    samples_per_run=samples_per_run,
)
EVAL_DS

In [None]:
# Map each evaluated variable name to its CPM target DataArray.
CPM_DAS = {
    var_name: EVAL_DS["CPM"][f"target_{var_name}"]
    for var_name in eval_vars
}

## Figure: structure

* PSD: radially averaged power spectral density of the CPM targets compared with each model's samples

In [None]:
# Grid spacing passed to the RAPSD. 8.8 = 2.2 * 4, presumably km for the
# 2.2km-4x grid named in the notebook title (NOTE(review): confirm units).
PSD_PIXEL_SIZE = 8.8


def _mean_rapsd(da, example_dims, pixel_size=PSD_PIXEL_SIZE):
    """Return the example-averaged RAPSD of `da`, with the freq=0 bin dropped.

    The dimensions listed in `example_dims` are stacked into a single
    "example" dimension, so each (grid_latitude, grid_longitude) field is one
    example. The RAPSD is computed per example by `pysteps_rapsd`, averaged
    over "example", and the zero-frequency entry is removed.
    """
    examples = da.stack(example=list(example_dims)).transpose(
        "example", "grid_latitude", "grid_longitude"
    )
    return (
        pysteps_rapsd(examples, pixel_size=pixel_size)
        .mean(dim="example")
        .drop_sel(freq=0)
    )


# One single-panel PSD figure per evaluated variable: CPM targets vs. every
# model's samples.
for var in eval_vars:
    IPython.display.display_markdown(f"### {var}", raw=True)
    structure_fig = plt.figure(figsize=(4, 3), layout="constrained")
    axd = structure_fig.subplot_mosaic([[var]], sharey=True, sharex=False)

    # CPM targets are indexed by ensemble member and time only.
    cpm_hr_rapsd = _mean_rapsd(CPM_DAS[var], ["ensemble_member", "time"])

    # Model predictions additionally carry a sample_id dimension.
    pred_rapsds = [
        {
            "label": model,
            "color": spec["color"],
            "data": _mean_rapsd(
                EVAL_DS[source][f"pred_{var}"].sel(model=model),
                ["ensemble_member", "sample_id", "time"],
            ),
        }
        for source, mconfigs in MODELS.items()
        for model, spec in mconfigs.items()
    ]

    ax = axd[var]
    plot_psd(cpm_hr_rapsd, pred_rapsds, ax=ax)

    plt.show()