# Evaluation of a selection of models on 60km -> 2.2km-4x over Birmingham

In [None]:
%reload_ext autoreload

%autoreload 2

%reload_ext dotenv
%dotenv

import functools
import math
import string

import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import xarray as xr

from mlde_analysis.utils import chained_groupby_map
from mlde_analysis.data import prep_eval_data
from mlde_analysis import plot_map, SUBREGIONS
from mlde_analysis.bootstrap import resample_examples
from mlde_analysis.ccs import compute_changes, plot_changes, plot_tp_fd, plot_hist_per_tp, plot_ccs_fc_figure, ccs_fc_da
from mlde_analysis.distribution import QUANTILES, normalized_mean_bias, normalized_std_bias, plot_freq_density, plot_mean_biases, plot_std_biases, rms_mean_bias, rms_std_bias, compute_metrics, DIST_THRESHOLDS
from mlde_analysis.fractional_contribution import compute_fractional_contribution, frac_contrib_change, fc_bins
from mlde_analysis.wet_dry import threshold_exceeded_prop_stats, threshold_exceeded_prop_change, plot_threshold_exceedence_errors, THRESHOLDS
from mlde_analysis.display import VAR_RANGES, pretty_table
from mlde_utils import cp_model_rotated_pole, TIME_PERIODS

In [None]:
# Set the default figure resolution for every figure created after this cell.
matplotlib.rcParams.update({"figure.dpi": 300})

In [None]:
from mlde_analysis.default_params import *

In [None]:
# Render the run description as Markdown. It must remain the cell's last bare
# expression so the notebook displays it. `desc` is not defined in the visible
# cells (presumably it comes from the default_params star import — confirm).
IPython.display.Markdown(desc)

In [None]:
# Build the evaluation datasets and the model specifications. None of the
# arguments are defined in the visible cells (presumably they come from the
# default_params star import — confirm).
EVAL_DS, MODELS = prep_eval_data(
    sample_configs,
    dataset_configs,
    derived_variables_config,
    eval_vars,
    split,
    ensemble_members=ensemble_members,
    samples_per_run=samples_per_run,
)

# Kept as the last bare expression so the notebook renders the result.
EVAL_DS

In [None]:
# Target DataArray of each evaluated variable, read from the "CPM" entry of EVAL_DS.
CPM_DAS = {
    variable: EVAL_DS["CPM"][f"target_{variable}"]
    for variable in eval_vars
}

# Predictions of each evaluated variable, gathered from every entry of EVAL_DS
# and concatenated along the "model" dimension.
PRED_DAS = {
    variable: xr.concat(
        [source_ds[f"pred_{variable}"] for source_ds in EVAL_DS.values()],
        dim="model",
    )
    for variable in eval_vars
}

# Flat lookup from model label to its spec, with the owning source recorded
# under "source" (a "source" key inside the spec itself takes precedence).
MODELLABEL2SPEC = {
    label: {"source": source_name, **model_spec}
    for source_name, source_models in MODELS.items()
    for label, model_spec in source_models.items()
}
# The CPM entry is always set to this fixed spec, replacing any model labelled "CPM".
MODELLABEL2SPEC["CPM"] = {"source": "CPM", "color": "black"}

In [None]:
# Seasons for which the per-season figures in this notebook are drawn; groups
# for any other season are skipped by the plotting loops.
CCS_SEASONS = ["DJF", "JJA"]

In [None]:
# Keep only the models whose spec has a truthy "CCS" entry, grouped by source
# exactly as in MODELS.
CCS_MODELS = {
    source_name: {
        label: model_spec
        for label, model_spec in source_models.items()
        if model_spec.get("CCS", False)
    }
    for source_name, source_models in MODELS.items()
}

# Restrict each variable's predictions to the CCS model labels.
CCS_PRED_DAS = {
    variable: PRED_DAS[variable].sel(
        model=[
            label
            for source_models in CCS_MODELS.values()
            for label in source_models
        ]
    )
    for variable in eval_vars
}

# Single dataset holding the CPM targets and the CCS model predictions.
CCS_DS = xr.combine_by_coords([*CPM_DAS.values(), *CCS_PRED_DAS.values()])

## Figure: per Time Period per season frequency density

* Frequency density histogram of each evaluated variable (e.g. rainfall intensity) for each season and time period

In [None]:
# One figure per evaluated variable: a grid of frequency density panels with
# one row per time period and one column per season in CCS_SEASONS.
for var in eval_vars:
    fig = plt.figure(
        layout="constrained",
        figsize=(3.5 * len(CCS_SEASONS), 2 * len(TIME_PERIODS)),
    )
    # Panel labels of the form "<season> <time period>" double as mosaic keys.
    spec = np.array(
        [
            [f"{season_label} {period_label}" for season_label in CCS_SEASONS]
            for period_label in TIME_PERIODS
        ]
    )

    axd = fig.subplot_mosaic(spec, sharex=True)
    for season, seasonal_ccs_ds in CCS_DS.groupby("time.season"):
        # Only seasons listed in CCS_SEASONS have a panel.
        if season in CCS_SEASONS:
            for tp_key, tp_ds in seasonal_ccs_ds.groupby("time_period"):
                ax = axd[f"{season} {tp_key}"]
                hist_data = [
                    dict(
                        data=model_ds[f"pred_{var}"],
                        label=model,
                        color=MODELLABEL2SPEC[model]["color"],
                    )
                    for model, model_ds in tp_ds.groupby("model")
                ]
                hrange = VAR_RANGES[var]
                plot_freq_density(
                    hist_data,
                    ax=ax,
                    target_da=tp_ds[f"target_{var}"],
                    # Only the DJF historic panel is asked to draw a legend.
                    legend=(tp_key == "historic" and season == "DJF"),
                    linewidth=1,
                    title=f"{season} {tp_key}",
                    hrange=hrange,
                )

                # Show the x-axis label only on the "future" panels.
                ax.xaxis.label.set_visible(tp_key == "future")

    plt.show()

In [None]:
# Per variable: distribution metrics for every (season, time period) group of
# CCS_DS, shown as a table.
for var in eval_vars:
    IPython.display.display_markdown(f"### {var}", raw=True)

    # Both helpers read `var` from the enclosing loop and are only called
    # within the same iteration, so they always see the current variable.
    def _metrics_for_time_period(tp_ds):
        return compute_metrics(
            tp_ds[f"pred_{var}"],
            tp_ds[f"target_{var}"],
            thresholds=DIST_THRESHOLDS[var],
        )

    def _metrics_for_season(season_ds):
        return season_ds.groupby("time_period").map(_metrics_for_time_period)

    metrics_ds = CCS_DS.groupby("time.season").map(_metrics_for_season)

    pretty_table(
        metrics_ds,
        round=4,
        dim_order=["season", "time_period", "model"],
        caption="Distribution metrics per season and time period",
    )

In [None]:
# Per variable: change in each quantile (future minus historic) for every
# season, for each model and for the CPM target, shown as a table.
def _quantiles_over_all_dims(gda):
    """Return the QUANTILES of ``gda`` computed over all of its dimensions."""
    return gda.quantile(q=QUANTILES, dim=...)


for var in eval_vars:
    IPython.display.display_markdown(f"### {var}", raw=True)

    # Quantiles of the predictions per (model, season, time period).
    pred_quantiles_da = chained_groupby_map(
        CCS_DS[f"pred_{var}"],
        ["model", "time.season", "time_period"],
        _quantiles_over_all_dims,
    )
    # Quantiles of the target per (season, time period), given a "model"
    # dimension labelled "CPM" so they can sit alongside the predictions.
    cpm_quantiles_da = chained_groupby_map(
        CCS_DS[f"target_{var}"],
        ["time.season", "time_period"],
        _quantiles_over_all_dims,
    ).expand_dims(model=["CPM"])

    quantiles_da = xr.concat(
        [pred_quantiles_da, cpm_quantiles_da],
        dim="model",
    ).rename("quantile value")

    quantiles_change_da = quantiles_da.sel(time_period="future") - quantiles_da.sel(time_period="historic")

    pretty_table(
        quantiles_change_da,
        round=1,
        # No bootstrapping happens in this cell, so the caption describes the
        # plain future-minus-historic difference that is actually tabulated.
        caption="Seasonal quantile change (future minus historic)",
        dim_order=["season", "quantile", "model"],
    )

## Figure: per subdomain per time period per season distribution

In [None]:
# No longer used by this cell; the import is kept in case other cells rely on the name.
from collections import deque

# Per variable and per season in CCS_SEASONS: a grid of frequency density
# panels with one row per subregion and one column per time period.
for var in eval_vars:
    IPython.display.display_markdown(f"### {var}", raw=True)

    # The "pr" variable is drawn on a log y-axis, every other variable on a linear one.
    yscale = "log" if var == "pr" else "linear"

    for season, seasonal_ccs_ds in CCS_DS.groupby("time.season"):
        if season not in CCS_SEASONS:
            continue

        fig = plt.figure(figsize=(3.5 * len(TIME_PERIODS), 2.5 * len(SUBREGIONS)), layout="compressed")

        # Panel labels of the form "<subregion> <time period>" double as mosaic keys.
        spec = np.array([[f"{srname} {tp_key}" for tp_key in TIME_PERIODS] for srname in SUBREGIONS])
        axd = fig.subplot_mosaic(spec, sharex=True)

        for srname, srdefn in SUBREGIONS.items():
            for tp_key, tp_ds in seasonal_ccs_ds.groupby("time_period"):
                # srdefn holds the isel() indexers that cut out this subregion.
                tpsrseason_cpm_da = tp_ds[f"target_{var}"].isel(**srdefn)

                srseason_hist_data = [
                    dict(
                        data=model_ds[f"pred_{var}"].isel(**srdefn),
                        label=model,
                        color=MODELLABEL2SPEC[model]["color"],
                    )
                    for model, model_ds in tp_ds.groupby("model")
                ]

                ax = axd[f"{srname} {tp_key}"]
                plot_freq_density(
                    srseason_hist_data,
                    ax=ax,
                    target_da=tpsrseason_cpm_da,
                    linewidth=1,
                    hrange=VAR_RANGES[var],
                    yscale=yscale,
                )
                ax.set_title(f"{srname} {season} {tp_key}", size="small")

        # Letter the panels "a.", "b.", ... in row-major order of the mosaic.
        # A plain loop is used because annotating is a side effect; at most 26
        # panels (one per lowercase letter) receive a label.
        for panel_letter, axlabel in zip(string.ascii_lowercase, spec.ravel().tolist()):
            axd[axlabel].annotate(
                f"{panel_letter}.",
                xy=(-0.1, 1.0),
                xycoords=("axes fraction", "axes fraction"),
                weight="bold",
                ha="left",
                va="bottom",
            )

        plt.show()

## Figure: per Time Period per season fractional contribution

* Fractional contribution of rainfall intensities for each time period and season
* Change in fractional contribution of rainfall intensities from historic to future for each season

In [None]:
# Fractional contribution figure, one per season in CCS_SEASONS.
for var in eval_vars:
    # This figure is only produced for the "pr" variable.
    if var != "pr":
        continue

    fcdata = ccs_fc_da(
        CCS_DS[f"pred_{var}"],
        CCS_DS[f"target_{var}"],
        extra_pred_dims=["time.season"],
        extra_cpm_dims=["time.season"],
    )

    for season, season_fc in fcdata.groupby("season"):
        if season in CCS_SEASONS:
            IPython.display.display_markdown(f"#### {season}", raw=True)

            fig = plt.figure(layout="constrained", figsize=(3.5, 6))
            # Colour each model label (and the CPM) as configured in MODELLABEL2SPEC.
            palette = {
                label: model_spec["color"]
                for label, model_spec in MODELLABEL2SPEC.items()
            }
            plot_ccs_fc_figure(fig, season_fc, palette=palette)
            plt.show()

### Bootstrapped per season fractional contribution

In [None]:
# Apply resample_examples to CCS_DS via chained_groupby_map over time period
# and season (presumably once per (time period, season) group — confirm
# against chained_groupby_map).
bs_ccs_ds = chained_groupby_map(
    CCS_DS,
    ["time_period", "time.season"],
    resample_examples,
    niterations=niterations,
)

for var in eval_vars:
    # This figure is only produced for the "pr" variable.
    if var != "pr":
        continue

    # NOTE(review): only the target comes from the resampled dataset; the
    # predictions are taken from CCS_DS without an "iteration" dimension.
    # This looks deliberate given the differing extra dims — confirm.
    fcdata = ccs_fc_da(
        CCS_DS[f"pred_{var}"],
        bs_ccs_ds[f"target_{var}"],
        extra_pred_dims=["time.season"],
        extra_cpm_dims=["time.season", "iteration"],
    )

    for season, season_fc in fcdata.groupby("season"):
        if season in CCS_SEASONS:
            IPython.display.display_markdown(f"#### {season}", raw=True)

            fig = plt.figure(layout="constrained", figsize=(3.5, 6))
            # Colour each model label (and the CPM) as configured in MODELLABEL2SPEC.
            palette = {
                label: model_spec["color"]
                for label, model_spec in MODELLABEL2SPEC.items()
            }
            plot_ccs_fc_figure(fig, season_fc, palette=palette, errorbar=("pi", 90))
            plt.show()