# Evaluation of a selection of models on 60km -> 2.2km-4x over Birmingham

In [None]:
%reload_ext autoreload

%autoreload 2

%reload_ext dotenv
%dotenv

import functools
import math
import string

import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import xarray as xr

from mlde_analysis.utils import chained_groupby_map
from mlde_analysis.data import prep_eval_data
from mlde_analysis import plot_map, SUBREGIONS
from mlde_analysis.bootstrap import resample_examples
from mlde_analysis.ccs import compute_changes, plot_changes, plot_tp_fd, plot_hist_per_tp, plot_ccs_fc_figure, ccs_fc_da
from mlde_analysis.distribution import PER_GRIDBOX_QUANTILES, normalized_mean_bias, normalized_std_bias, plot_freq_density, plot_mean_biases, plot_std_biases, rms_mean_bias, rms_std_bias, compute_metrics, DIST_THRESHOLDS
from mlde_analysis.fractional_contribution import compute_fractional_contribution, frac_contrib_change, fc_bins
from mlde_analysis.wet_dry import threshold_exceeded_prop_stats, threshold_exceeded_prop_change, plot_threshold_exceedence_errors, THRESHOLDS
from mlde_analysis.display import VAR_RANGES, pretty_table
from mlde_utils import cp_model_rotated_pole, TIME_PERIODS

In [None]:
# Render every figure in this notebook at 300 dpi.
matplotlib.rcParams.update({"figure.dpi": 300})

In [None]:
from mlde_analysis.default_params import *

In [None]:
# Show the notebook description as rendered Markdown. As the last expression of
# the cell it is displayed automatically. `desc` presumably comes from the
# default_params star-import — confirm.
IPython.display.Markdown(data=desc)

In [None]:
# Load evaluation data. As used below:
# - EVAL_DS is a mapping of source -> dataset holding `pred_<var>` and `target_<var>`;
# - MODELS maps source -> {model label: spec}.
EVAL_DS, MODELS = prep_eval_data(
    sample_configs,
    dataset_configs,
    derived_variables_config,
    eval_vars,
    split,
    ensemble_members=ensemble_members,
    samples_per_run=samples_per_run,
)
# Last expression of the cell: display the loaded data.
EVAL_DS

In [None]:
# CPM target DataArray for each evaluated variable, taken from the "CPM" entry.
CPM_DAS = {v: EVAL_DS["CPM"][f"target_{v}"] for v in eval_vars}

# Predictions for each variable, concatenated along the "model" dimension
# across every entry of EVAL_DS.
PRED_DAS = {
    v: xr.concat([eval_ds[f"pred_{v}"] for eval_ds in EVAL_DS.values()], dim="model")
    for v in eval_vars
}

# Model label -> spec, with each spec tagged by its source (a "source" key in
# the spec itself wins). "CPM" is added as a black reference entry.
MODELLABEL2SPEC = {
    label: {"source": src, **spec}
    for src, src_models in MODELS.items()
    for label, spec in src_models.items()
}
MODELLABEL2SPEC["CPM"] = {"source": "CPM", "color": "black"}

In [None]:
# Seasons (compared against xarray "time.season" labels) that the CCS figures
# in this notebook are drawn for; all other seasons are skipped.
CCS_SEASONS = ["DJF", "JJA"]

In [None]:
# Per source, the models whose spec has a truthy "CCS" entry. CCS presumably
# means climate change signal — confirm. Only these models enter the analysis.
CCS_MODELS = {
    source: {model: spec for model, spec in mconfigs.items() if spec.get("CCS", False)}
    for source, mconfigs in MODELS.items()
}

# Flat list of CCS model labels across all sources. Built once and reused for
# every variable rather than being rebuilt inside the comprehension.
_ccs_model_labels = [model for models in CCS_MODELS.values() for model in models]
CCS_PRED_DAS = {var: PRED_DAS[var].sel(model=_ccs_model_labels) for var in eval_vars}

# One dataset holding the CPM targets and the CCS-model predictions for every variable.
CCS_DS = xr.combine_by_coords([*CPM_DAS.values(), *CCS_PRED_DAS.values()])

## Figure: fractional contribution per time period and season

* Fractional contribution of rainfall intensities for each time period and season
* Change in fractional contribution of rainfall intensities from the historical to the future period, for each season

In [None]:
# Fractional contribution per season, computed only for "pr", with one figure
# per season listed in CCS_SEASONS. Both predictions and CPM targets are split
# by "time.season".
for var in (v for v in eval_vars if v in ["pr"]):
    fcdata = ccs_fc_da(
        CCS_DS[f"pred_{var}"],
        CCS_DS[f"target_{var}"],
        extra_pred_dims=["time.season"],
        extra_cpm_dims=["time.season"],
    )

    for season, season_fc in fcdata.groupby("season"):
        if season in CCS_SEASONS:
            IPython.display.display_markdown(f"#### {season}", raw=True)

            fig = plt.figure(layout='constrained', figsize=(3.5, 6))
            plot_ccs_fc_figure(
                fig,
                season_fc,
                palette={label: spec["color"] for label, spec in MODELLABEL2SPEC.items()},
            )
            plt.show()

### Bootstrapped fractional contribution per season

In [None]:
# Fractional contribution per season ("pr" only). CPM uncertainty comes from
# applying resample_examples to the CPM targets within each time_period group,
# which adds an "iteration" dimension. Model predictions are used as-is.
# errorbar=("pi", 90) is presumably forwarded to seaborn (90% percentile
# interval) — confirm in plot_ccs_fc_figure.

# Colour per model label. It does not depend on the loop variables, so it is
# built once instead of once per figure.
palette = {label: spec["color"] for label, spec in MODELLABEL2SPEC.items()}

for var in eval_vars:
    if var != "pr":
        continue

    for season, season_ds in CCS_DS.groupby("time.season"):
        if season not in CCS_SEASONS:
            continue
        IPython.display.display_markdown(f"#### {season}", raw=True)

        bs_cpm_da = chained_groupby_map(
            season_ds[f"target_{var}"], ["time_period"], resample_examples, niterations=niterations
        )
        fcdata = ccs_fc_da(season_ds[f"pred_{var}"], bs_cpm_da, extra_cpm_dims=["iteration"])

        fig = plt.figure(layout='constrained', figsize=(3.5, 6))
        plot_ccs_fc_figure(fig, fcdata, palette=palette, errorbar=("pi", 90))
        plt.show()

In [None]:
# Per-model fractional contribution per season ("pr" only). Both the CPM targets
# and each model's predictions are resampled with resample_examples within each
# time_period group, so both carry an "iteration" dimension.

# Colour per model label; loop-invariant, so built once.
palette = {label: spec["color"] for label, spec in MODELLABEL2SPEC.items()}

for var in eval_vars:
    if var != "pr":
        continue
    for season, season_ds in CCS_DS.groupby("time.season"):
        if season not in CCS_SEASONS:
            continue
        IPython.display.display_markdown(f"#### {season}", raw=True)

        # The CPM resampling does not depend on the model, so do it once per
        # season. Previously it was redone inside the model loop: that repeated
        # the work per model and gave each model's figure a different CPM sample.
        bs_cpm_da = chained_groupby_map(
            season_ds[f"target_{var}"], ["time_period"], resample_examples, niterations=niterations
        )

        for model, model_pred_da in season_ds[f"pred_{var}"].groupby("model", squeeze=False):
            IPython.display.display_markdown(f"##### {model}", raw=True)

            bs_pred_da = chained_groupby_map(
                model_pred_da.squeeze("model"), ["time_period"], resample_examples, niterations=niterations
            )
            fcdata = ccs_fc_da(
                bs_pred_da.expand_dims(model=[model]),
                bs_cpm_da,
                extra_pred_dims=["iteration"],
                extra_cpm_dims=["iteration"],
            )

            fig = plt.figure(layout='constrained', figsize=(3.5, 6))
            plot_ccs_fc_figure(fig, fcdata, palette=palette, errorbar=("pi", 90))
            plt.show()

In [None]:
# Per-model fractional contribution per season ("pr" only). Only the model
# predictions are resampled (resample_examples within each time_period group,
# adding an "iteration" dimension); the CPM targets are passed through unchanged.
for var in (v for v in eval_vars if v in ["pr"]):
    for season, season_ds in CCS_DS.groupby("time.season"):
        if season in CCS_SEASONS:
            IPython.display.display_markdown(f"#### {season}", raw=True)

            for model, model_pred_da in season_ds[f"pred_{var}"].groupby("model", squeeze=False):
                IPython.display.display_markdown(f"##### {model}", raw=True)

                bs_pred_da = chained_groupby_map(
                    model_pred_da.squeeze("model"),
                    ["time_period"],
                    resample_examples,
                    niterations=niterations,
                )
                fcdata = ccs_fc_da(
                    bs_pred_da.expand_dims(model=[model]),
                    season_ds[f"target_{var}"],
                    extra_pred_dims=["iteration"],
                )

                fig = plt.figure(layout='constrained', figsize=(3.5, 6))
                plot_ccs_fc_figure(
                    fig,
                    fcdata,
                    palette={label: spec["color"] for label, spec in MODELLABEL2SPEC.items()},
                    errorbar=("pi", 90),
                )
                plt.show()