# Evaluation of a selection of models on 60km -> 2.2km-4x over Birmingham

In [None]:
%reload_ext autoreload

%autoreload 2

%reload_ext dotenv
%dotenv

import functools
import math
import string

import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import xarray as xr

from mlde_analysis.utils import chained_groupby_map
from mlde_analysis.data import prep_eval_data
from mlde_analysis import plot_map
from mlde_analysis.bootstrap import resample_examples
from mlde_analysis.ccs import compute_changes, plot_changes, plot_tp_fd, plot_hist_per_tp, plot_ccs_fc_figure, ccs_fc_da
from mlde_analysis.distribution import QUANTILES, normalized_mean_bias, normalized_std_bias, plot_freq_density, plot_mean_biases, plot_std_biases, rms_mean_bias, rms_std_bias, compute_metrics, DIST_THRESHOLDS
from mlde_analysis.fractional_contribution import compute_fractional_contribution, frac_contrib_change, fc_bins
from mlde_analysis.wet_dry import threshold_exceeded_prop_stats, threshold_exceeded_prop_change, plot_threshold_exceedence_errors, THRESHOLDS
from mlde_analysis.display import VAR_RANGES, pretty_table
from mlde_utils import cp_model_rotated_pole, TIME_PERIODS

In [None]:
# Render every figure in this notebook at 300 dpi.
matplotlib.rcParams["figure.dpi"] = 300

In [None]:
# Notebook parameters. The names used below without a visible definition
# (desc, sample_configs, dataset_configs, derived_variables_config, eval_vars,
# split, ensemble_members, samples_per_run, niterations) are presumably
# supplied by this star import — TODO confirm against
# mlde_analysis.default_params.
from mlde_analysis.default_params import *

In [None]:
# Render the run description as Markdown. This has to remain the final
# expression of the cell for the notebook to display it.
IPython.display.Markdown(desc)

In [None]:
# Build the evaluation data. EVAL_DS is indexed by source further down
# (e.g. EVAL_DS["CPM"]) and MODELS maps source -> {model label -> spec}.
EVAL_DS, MODELS = prep_eval_data(
    sample_configs,
    dataset_configs,
    derived_variables_config,
    eval_vars,
    split,
    ensemble_members=ensemble_members,
    samples_per_run=samples_per_run,
)
# Bare trailing expression so the notebook renders the result.
EVAL_DS

In [None]:
# Target (CPM) data array for each evaluated variable.
CPM_DAS = {
    name: EVAL_DS["CPM"][f"target_{name}"]
    for name in eval_vars
}

# Predictions for each evaluated variable, with the per-source datasets
# concatenated along the "model" dimension.
PRED_DAS = {
    name: xr.concat(
        [source_ds[f"pred_{name}"] for source_ds in EVAL_DS.values()],
        dim="model",
    )
    for name in eval_vars
}

# Flat lookup from model label to its spec, with the owning source recorded
# under "source" (a "source" key already in the spec takes precedence).
MODELLABEL2SPEC = {
    label: {"source": src} | cfg
    for src, by_label in MODELS.items()
    for label, cfg in by_label.items()
}
# The CPM itself is plotted alongside the models, in black.
MODELLABEL2SPEC["CPM"] = {"source": "CPM", "color": "black"}

In [None]:
# Seasons handed to compute_changes / plot_changes for the change maps below.
CCS_SEASONS = ["DJF", "JJA"]

In [None]:
# Per source, keep only the models whose spec sets a truthy "CCS" flag.
# Sources with no such model stay in the mapping with an empty dict.
CCS_MODELS = {
    source: {
        model: spec
        for model, spec in mconfigs.items()
        if spec.get("CCS", False)
    }
    for source, mconfigs in MODELS.items()
}

# Predictions restricted to the CCS-flagged models, per evaluated variable.
CCS_PRED_DAS = {
    var: PRED_DAS[var].sel(
        model=[model for models in CCS_MODELS.values() for model in models]
    )
    for var in eval_vars
}

# Single dataset holding the CPM targets and the CCS model predictions.
CCS_DS = xr.combine_by_coords([*CPM_DAS.values(), *CCS_PRED_DAS.values()])

## Figure: Climate change signal

* Per-time-period frequency density histogram
* Difference in relative mean change (emulator minus CPM): $(\mu_{ML}^{future} - \mu_{ML}^{hist})/\mu_{ML}^{hist} - (\mu_{CPM}^{future} - \mu_{CPM}^{hist})/\mu_{CPM}^{hist}$

In [None]:
# For each variable and each CCS model: a two-panel figure of the
# per-time-period histograms, CPM on the left and the emulator on the right.
for var in eval_vars:
    IPython.display.display_markdown(f"#### {var}", raw=True)
    for model, model_ds in CCS_DS.groupby("model"):
        IPython.display.display_markdown(f"##### {model}", raw=True)

        cpm_da = model_ds[f"target_{var}"]
        pred_da = model_ds[f"pred_{var}"]
        hrange = VAR_RANGES[var]

        fig, axd = plt.subplot_mosaic(
            [["cpm", "model"]],
            figsize=(3.5, 2),
            constrained_layout=True,
            sharex=True,
            sharey=True,
        )

        # The legend is switched off on the CPM panel only.
        plot_hist_per_tp(cpm_da, axd["cpm"], title="CPM", hrange=hrange, legend=False)
        plot_hist_per_tp(pred_da, axd["model"], title="Emulator", hrange=hrange)
        # The y axis is shared, so the right-hand panel drops its label.
        axd["model"].set_ylabel(None)

        plt.show()

In [None]:
# For each variable and each CCS model: the plot_tp_fd figure comparing the
# model's predictions with the CPM target.
for var in eval_vars:
    IPython.display.display_markdown(f"#### {var}", raw=True)
    for model, model_ds in CCS_DS.groupby("model"):
        IPython.display.display_markdown(f"##### {model}", raw=True)

        fd_fig = plt.figure(figsize=(5.5, 5.5), layout="compressed")
        pred_da = model_ds[f"pred_{var}"]

        # Only "pr" gets an explicit histogram range; other variables pass None.
        hrange = (0, 250) if var == "pr" else None

        model_spec = MODELLABEL2SPEC[model]
        plot_tp_fd(
            pred_da,
            CPM_DAS[var],
            fd_fig,
            model_spec["source"],
            model,
            model_spec,
            hrange=hrange,
        )

        plt.show()

## Figure: per Time Period distribution

* Frequency Density Histogram of rainfall intensities for each time period

In [None]:
# For each variable: one figure with a frequency-density panel per time
# period, every CCS model overlaid on the CPM target.
for var in eval_vars:
    IPython.display.display_markdown(f"#### {var}", raw=True)

    fig = plt.figure(layout="constrained", figsize=(3.5, 3.5))

    # Single-column mosaic: one row per key of TIME_PERIODS.
    dist_spec = np.reshape(list(TIME_PERIODS.keys()), (-1, 1))
    spec = dist_spec
    axd = fig.subplot_mosaic(spec, sharex=True, sharey=True)

    for tp_key, tp_ds in CCS_DS.groupby("time_period"):
        ax = axd[tp_key]

        hist_data = [
            {
                "data": label_ds[f"pred_{var}"],
                "label": label,
                "color": MODELLABEL2SPEC[label]["color"],
            }
            for label, label_ds in tp_ds.groupby("model")
        ]
        hrange = VAR_RANGES[var]
        plot_freq_density(
            hist_data,
            ax=ax,
            target_da=tp_ds[f"target_{var}"],
            # Legend only on the "historic" panel.
            legend=(tp_key == "historic"),
            linewidth=1,
            title=tp_key,
            hrange=hrange,
        )

        # The x axis is shared; only the "future" panel keeps its label.
        ax.xaxis.label.set_visible(tp_key == "future")

    plt.show()

In [None]:
# For each variable: distribution metrics of predictions against the target,
# computed separately for every time period and shown as a table.
for var in eval_vars:
    IPython.display.display_markdown(f"### {var}", raw=True)

    def _metrics_for_period(tp_ds):
        # Metrics for one time-period slice of CCS_DS, for the current variable.
        return compute_metrics(
            tp_ds[f"pred_{var}"],
            tp_ds[f"target_{var}"],
            thresholds=DIST_THRESHOLDS[var],
        )

    metrics_ds = CCS_DS.groupby("time_period").map(_metrics_for_period)
    pretty_table(metrics_ds, round=4, dim_order=["time_period", "model"])

In [None]:
# For each variable: change in quantile values from "historic" to "future",
# for every CCS model and for the CPM target.
for var in eval_vars:
    IPython.display.display_markdown(f"### {var}", raw=True)

    def _quantiles_over_all_dims(gda):
        # dim=... asks xarray to reduce over every dimension of the group.
        return gda.quantile(q=QUANTILES, dim=...)

    pred_quantiles = chained_groupby_map(
        CCS_DS[f"pred_{var}"],
        ["model", "time_period"],
        _quantiles_over_all_dims,
    )
    # The target has no per-model grouping here; give it a "CPM" model entry
    # so it can be concatenated with the predictions.
    cpm_quantiles = chained_groupby_map(
        CCS_DS[f"target_{var}"],
        ["time_period"],
        _quantiles_over_all_dims,
    ).expand_dims(model=["CPM"])

    quantiles_da = xr.concat(
        [pred_quantiles, cpm_quantiles],
        dim="model",
    ).rename("quantile value")

    future_quantiles = quantiles_da.sel(time_period="future")
    historic_quantiles = quantiles_da.sel(time_period="historic")
    quantiles_change_da = future_quantiles - historic_quantiles

    # NOTE(review): the caption mentions a bootstrapped spread, but nothing in
    # this cell resamples the data — confirm the caption is still accurate.
    pretty_table(
        quantiles_change_da,
        round=1,
        caption="Annual quantile change bootstrapped spread",
        dim_order=["quantile", "model"],
    )

## Figure: Fractional contribution

* Fractional contribution of rainfall intensities for each time period
* Change in fractional contribution of rainfall intensities from historic to future

In [None]:
# Fractional contribution figure; only produced for the "pr" variable.
for var in eval_vars:
    if var != "pr":
        continue

    fcdata = ccs_fc_da(CCS_DS[f"pred_{var}"], CCS_DS[f"target_{var}"])
    fig = plt.figure(layout="constrained", figsize=(3.5, 6))
    # Colour each series by the colour configured for its model label.
    palette = {label: spec["color"] for label, spec in MODELLABEL2SPEC.items()}
    plot_ccs_fc_figure(fig, fcdata, palette=palette)
    plt.show()

### Bootstrapped fractional contributions

In [None]:
# Apply resample_examples within every (time period, season) group of CCS_DS.
# Presumably this adds the "iteration" dimension that the cells below pass as
# an extra dimension — TODO confirm against mlde_analysis.bootstrap.
bs_ccs_ds = chained_groupby_map(
    CCS_DS,
    ["time_period", "time.season"],
    resample_examples,
    niterations=niterations,
)

In [None]:
# Fractional contribution figure ("pr" only) using the resampled CPM target
# against the un-resampled predictions.
for var in eval_vars:
    if var != "pr":
        continue

    fcdata = ccs_fc_da(
        CCS_DS[f"pred_{var}"],
        bs_ccs_ds[f"target_{var}"],
        extra_cpm_dims=["iteration"],
    )

    fig = plt.figure(layout="constrained", figsize=(3.5, 6))
    # Colour each series by the colour configured for its model label.
    palette = {label: spec["color"] for label, spec in MODELLABEL2SPEC.items()}
    plot_ccs_fc_figure(fig, fcdata, palette=palette, errorbar=("pi", 90))
    plt.show()

In [None]:
# Fractional contribution figure ("pr" only) using the resampled predictions
# against the un-resampled CPM target.
for var in eval_vars:
    if var != "pr":
        continue

    fcdata = ccs_fc_da(
        bs_ccs_ds[f"pred_{var}"],
        CCS_DS[f"target_{var}"],
        extra_pred_dims=["iteration"],
    )
    fig = plt.figure(layout="constrained", figsize=(3.5, 6))
    # Colour each series by the colour configured for its model label.
    palette = {label: spec["color"] for label, spec in MODELLABEL2SPEC.items()}
    plot_ccs_fc_figure(fig, fcdata, palette=palette, errorbar=("pi", 90))
    plt.show()

## Change maps

### Mean change maps

In [None]:
# For each variable and each CCS model: maps of the change in the mean,
# for the seasons listed in CCS_SEASONS.
for var in eval_vars:
    IPython.display.display_markdown(f"#### {var}", raw=True)
    for model, ccs_pred_da in CCS_PRED_DAS[var].groupby("model"):
        # Each group holds a single model, so the "model" dimension is dropped
        # before the data is handed to compute_changes.
        changes = compute_changes(
            [ccs_pred_da.squeeze("model")],
            CPM_DAS[var],
            CCS_SEASONS,
            stat_func=xr.DataArray.mean,
        )
        change_fig = plt.figure(figsize=(5.5, 4.5), layout="compressed")
        plot_changes(changes, CCS_SEASONS, change_fig, show_change=[model])

        # Show each figure as soon as it is drawn, as the other per-model
        # cells do, rather than leaving every model's figure open until the
        # variable's loop has finished.
        plt.show()

### Q99 change maps

In [None]:
# For each variable, quantile and CCS model: maps of the change in that
# quantile, for the seasons listed in CCS_SEASONS.
for var in eval_vars:
    IPython.display.display_markdown(f"#### {var}", raw=True)
    for q in [0.99]:
        for model, ccs_pred_da in CCS_PRED_DAS[var].groupby("model"):
            # Each group holds a single model, so the "model" dimension is
            # dropped before the data is handed to compute_changes.
            changes = compute_changes(
                [ccs_pred_da.squeeze("model")],
                CPM_DAS[var],
                CCS_SEASONS,
                stat_func=functools.partial(xr.DataArray.quantile, q=q),
            )
            change_fig = plt.figure(figsize=(5.5, 4.5), layout="compressed")
            plot_changes(changes, CCS_SEASONS, change_fig, show_change=[model])

            # Show each figure as soon as it is drawn, as the other per-model
            # cells do, rather than leaving every model's figure open until
            # the quantile's loop has finished.
            plt.show()

### CCS mean Variability