# Evaluation of a selection of models on 60km -> 2.2km-4x over Birmingham

In [None]:
%reload_ext autoreload

%autoreload 2

%reload_ext dotenv
%dotenv

import functools
import math
import string

import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

from mlde_analysis.data import prep_eval_data
from mlde_analysis import plot_map
from mlde_analysis.ccs import compute_changes, plot_changes, plot_tp_fd, plot_hist_per_tp
from mlde_analysis.distribution import normalized_mean_bias, normalized_std_bias, plot_freq_density, plot_mean_biases, plot_std_biases, rms_mean_bias, rms_std_bias
from mlde_analysis.wet_dry import threshold_exceeded_prop_stats, threshold_exceeded_prop_change, plot_threshold_exceedence_errors, THRESHOLDS
from mlde_utils import cp_model_rotated_pole, TIME_PERIODS

In [None]:
# Use a high DPI for every figure produced by this notebook.
matplotlib.rcParams["figure.dpi"] = 300

In [None]:
from mlde_analysis.default_params import *

In [None]:
# Show the notebook description as Markdown. `desc` is not defined in this
# notebook; presumably it comes from the default_params star import (or a
# parameter override) -- TODO confirm. Must stay the cell's last expression
# so that it is rendered.
IPython.display.Markdown(desc)

In [None]:
# Load the evaluation data. EVAL_DS is indexed by source further down the
# notebook, and MODELS maps each source to its {model name: spec} entries.
EVAL_DS, MODELS = prep_eval_data(
    sample_configs,
    dataset_configs,
    derived_variables_config,
    eval_vars,
    split,
    ensemble_members=ensemble_members,
    samples_per_run=samples_per_run,
)
# Keep the datasets as the cell's final expression so the notebook displays them.
EVAL_DS

In [None]:
# CPM target DataArray for each evaluated variable, keyed by variable name.
CPM_DAS = {
    variable: EVAL_DS["CPM"][f"target_{variable}"]
    for variable in eval_vars
}

In [None]:
# Season labels passed to compute_changes and plot_changes in the
# "CCS mean" variability section of this notebook.
CCS_SEASONS = ["DJF", "MAM", "JJA", "SON"]

## Figure: Climate change signal

* Per time period freq density histogram
* Mean change diff: $(\mu_{ML}^{future} - \mu_{ML}^{hist})/\mu_{ML}^{hist} - (\mu_{CPM}^{future} - \mu_{CPM}^{hist})/\mu_{CPM}^{hist}$

In [None]:
# Keep only the models whose spec sets CCS to a truthy value, grouped by source.
ccs_models = {
    source: {
        name: model_spec
        for name, model_spec in source_models.items()
        if model_spec.get("CCS", False)
    }
    for source, source_models in MODELS.items()
}

# One merged Dataset per CCS model, holding its prediction of every evaluated variable.
ccs_pred_das = [
    xr.merge(
        [EVAL_DS[source][f"pred_{variable}"].sel(model=name) for variable in eval_vars]
    )
    for source, selected in ccs_models.items()
    for name in selected
]

# CPM targets combined with the per-model predictions concatenated along "model".
ccs_ds = xr.combine_by_coords(
    [*CPM_DAS.values(), xr.concat(ccs_pred_das, dim="model")]
)

In [None]:
# For every variable and CCS model: CPM and emulator per-time-period
# histograms side by side on shared axes.
for var in eval_vars:
    IPython.display.display_markdown(f"#### {var}", raw=True)

    # "pr" histograms use a fixed (0, 250) range; other variables pass None.
    hrange = (0, 250) if var == "pr" else None
    cpm_da = CPM_DAS[var]

    for source, mconfigs in ccs_models.items():
        for model in mconfigs:
            IPython.display.display_markdown(f"##### {model}", raw=True)

            pred_da = EVAL_DS[source][f"pred_{var}"].sel(model=model)

            fig, axd = plt.subplot_mosaic(
                [["cpm", "model"]],
                figsize=(3.5, 2),
                constrained_layout=True,
                sharex=True,
                sharey=True,
            )

            plot_hist_per_tp(cpm_da, axd["cpm"], title="CPM", hrange=hrange, legend=False)
            plot_hist_per_tp(pred_da, axd["model"], title="Emulator", hrange=hrange)
            # The y axis is shared, so the right-hand panel drops its own label.
            axd["model"].set_ylabel(None)

            plt.show()

In [None]:
# For every variable and CCS model: one plot_tp_fd figure comparing the
# emulator prediction against the CPM target.
for var in eval_vars:
    IPython.display.display_markdown(f"#### {var}", raw=True)

    # "pr" histograms use a fixed (0, 250) range; other variables pass None.
    hrange = (0, 250) if var == "pr" else None
    cpm_da = CPM_DAS[var]

    for source, mconfigs in ccs_models.items():
        for model, spec in mconfigs.items():
            IPython.display.display_markdown(f"##### {model}", raw=True)

            fd_fig = plt.figure(figsize=(5.5, 5.5), layout="compressed")
            plot_tp_fd(
                EVAL_DS[source][f"pred_{var}"].sel(model=model),
                cpm_da,
                fd_fig,
                source,
                model,
                spec,
                hrange=hrange,
            )

            plt.show()

## Figure: per Time Period distribution

* Frequency Density Histogram of rainfall intensities for each time period

In [None]:
# For every variable: one figure with a frequency-density panel per time
# period, overlaying each CCS model on the CPM target.
# NOTE: an earlier draft also sketched mean / std-dev bias map panels
# (plot_mean_biases / plot_std_biases) in this figure; they are not drawn here.
for var in eval_vars:
    IPython.display.display_markdown(f"#### {var}", raw=True)

    fig = plt.figure(layout='constrained', figsize=(3.5, 3.5))

    # A single column of panels, one row per TIME_PERIODS key.
    panel_layout = np.array(list(TIME_PERIODS)).reshape(-1, 1)
    axd = fig.subplot_mosaic(panel_layout, sharex=True, sharey=True)

    # "pr" histograms use a fixed (0, 250) range; other variables pass None.
    hrange = (0, 250) if var == "pr" else None
    cpm_da = CPM_DAS[var]

    for tp_key, tp_ds in ccs_ds.groupby("time_period"):
        ax = axd[tp_key]

        # One histogram entry per CCS model, drawn in the colour from its spec.
        emulator_hists = [
            dict(
                data=tp_ds[f"pred_{var}"]
                .sel(model=model)
                .where(EVAL_DS[source]["time_period"] == tp_key, drop=True),
                label=model,
                color=model_spec["color"],
            )
            for source, mconfigs in MODELS.items()
            for model, model_spec in mconfigs.items()
            if model_spec.get("CCS", False)
        ]

        plot_freq_density(
            emulator_hists,
            ax=ax,
            target_da=cpm_da.where(cpm_da["time_period"] == tp_key, drop=True),
            legend=(tp_key == "historic"),
            linewidth=1,
            title=tp_key,
            hrange=hrange,
        )

        # Only the "future" panel keeps its x-axis label visible.
        ax.xaxis.label.set_visible(tp_key == "future")

    plt.show()

In [None]:
def _rms_bias_by_time_period(var, bias_func, normalize, name):
    """Apply ``bias_func`` to every model's prediction within each time period.

    For each ``time_period`` group of ``ccs_ds``, the ``pred_{var}`` data is
    grouped by ``model`` and each group is passed to ``bias_func`` together
    with the CPM target restricted to that time period (``cpm_da=``) and the
    ``normalize`` flag.

    Args:
        var: Name of the evaluated variable (used to pick ``pred_{var}`` and
            ``CPM_DAS[var]``).
        bias_func: Callable accepting a prediction DataArray plus ``cpm_da``
            and ``normalize`` keyword arguments.
        normalize: Value forwarded as ``normalize=`` to ``bias_func``.
        name: Name given to the resulting DataArray.

    Returns:
        The per-time-period results concatenated along ``time_period`` and
        renamed to ``name``.
    """
    cpm_da = CPM_DAS[var]
    per_time_period = []
    for tp_key, tp_ds in ccs_ds.groupby("time_period"):
        cpm_tp_da = cpm_da.where(cpm_da["time_period"] == tp_key, drop=True)
        bias_da = (
            tp_ds[f"pred_{var}"]
            .groupby("model", squeeze=False)
            .map(functools.partial(bias_func, cpm_da=cpm_tp_da, normalize=normalize))
        )
        # Re-attach the time period as a length-1 dimension so the pieces can be concatenated.
        per_time_period.append(bias_da.expand_dims(dim={"time_period": [tp_key]}))
    return xr.concat(per_time_period, dim="time_period").rename(name)


# For every variable, tabulate absolute and relative RMS mean / std-dev biases
# per time period and model.
for var in eval_vars:
    units = CPM_DAS[var].attrs['units']

    rms_mean_biases = _rms_bias_by_time_period(var, rms_mean_bias, False, f"RMS Mean Bias ({units})")
    rms_std_biases = _rms_bias_by_time_period(var, rms_std_bias, False, f"RMS Std Dev Bias ({units})")

    relative_rms_mean_biases = _rms_bias_by_time_period(var, rms_mean_bias, True, "Relative RMS Mean Bias (%)")
    relative_rms_std_biases = _rms_bias_by_time_period(var, rms_std_bias, True, "Relative RMS Std Dev Bias (%)")

    bias_table = xr.merge(
        [rms_mean_biases, rms_std_biases, relative_rms_mean_biases, relative_rms_std_biases]
    ).to_dataframe().round(2)
    IPython.display.display_html(bias_table.to_html(), raw=True)

### Mean change maps

In [None]:
# For every variable and CCS model: maps of the DJF and JJA mean change.
mean_change_seasons = ["DJF", "JJA"]

for var in eval_vars:
    IPython.display.display_markdown(f"#### {var}", raw=True)
    for ccs_pred_da in ccs_pred_das:
        changes = compute_changes(
            [ccs_pred_da[f"pred_{var}"]],
            CPM_DAS[var],
            mean_change_seasons,
            stat_func=xr.DataArray.mean,
        )
        mean_change_fig = plt.figure(figsize=(5.5, 4.5), layout="compressed")
        plot_changes(
            changes,
            mean_change_seasons,
            mean_change_fig,
            show_change=[ccs_pred_da["model"].data.item()],
        )

        # Show each model's figure as soon as it is drawn (as the quantile and
        # variability cells do) instead of keeping every figure for this
        # variable open until the model loop has finished.
        plt.show()

### Q99 change maps

In [None]:
# For every variable, quantile and CCS model: maps of the DJF and JJA change
# in that quantile.
quantile_change_seasons = ["DJF", "JJA"]

for var in eval_vars:
    IPython.display.display_markdown(f"#### {var}", raw=True)
    for q in [0.99]:
        IPython.display.display_markdown(f"#### Quantile: {q}", raw=True)
        # The same compute_changes call as for the mean maps, with the q-quantile as the statistic.
        quantile_stat = functools.partial(xr.DataArray.quantile, q=q)
        for ccs_pred_da in ccs_pred_das:
            changes = compute_changes(
                [ccs_pred_da[f"pred_{var}"]],
                CPM_DAS[var],
                quantile_change_seasons,
                stat_func=quantile_stat,
            )

            quantile_change_fig = plt.figure(figsize=(5.5, 4.5), layout="compressed")
            plot_changes(
                changes,
                quantile_change_seasons,
                quantile_change_fig,
                show_change=[ccs_pred_da["model"].data.item()],
            )

            plt.show()

### CCS mean Variability

In [None]:
from sklearn.model_selection import StratifiedKFold

# Variability of the mean change maps: recompute them on three stratified
# subsamples of the (stratum, year) pairs for every variable and CCS model.
time_da = EVAL_DS["CPM"]["time"]

# One row per (stratum, dec_adjusted_year) pair, so folds are drawn over
# stratum-years rather than individual timestamps.
df = time_da.to_dataframe().drop_duplicates(["stratum", "dec_adjusted_year"])

skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)

# The fold membership depends only on the time axis, not on the variable or
# the model, so build the per-fold time selections once up front. With an
# integer random_state the splits are the same on every call to split().
fold_time_das = []
for _, test_idx in skf.split(df[["dec_adjusted_year"]], df["stratum"]):
    fold_df = df.iloc[test_idx]
    # "<stratum> <year>" labels, compared against the tp_season_year coordinate.
    fold_labels = fold_df["stratum"].str.cat(fold_df["dec_adjusted_year"].astype("str"), sep=' ').values
    fold_time_das.append(time_da.where(time_da["tp_season_year"].isin(fold_labels), drop=True))

for var in eval_vars:
    IPython.display.display_markdown(f"#### {var}", raw=True)
    for ccs_pred_da in ccs_pred_das:
        for fold_time_da in fold_time_das:
            ccs_pred_da_subsamples = ccs_pred_da[f"pred_{var}"].sel(time=fold_time_da.data)

            mean_changes = compute_changes(
                [ccs_pred_da_subsamples],
                EVAL_DS["CPM"][f"target_{var}"].sel(time=fold_time_da),
                CCS_SEASONS,
                stat_func=xr.DataArray.mean,
            )

            mean_change_fig = plt.figure(figsize=(5.5, 5.5), layout="compressed")

            plot_changes(
                mean_changes,
                CCS_SEASONS,
                mean_change_fig,
                show_change=[ccs_pred_da_subsamples["model"].data.item()],
            )

            plt.show()

# Threshold proportion change

### CCS threshold change: gridbox

In [None]:
# Predictions of every configured model, for each evaluated variable that has
# thresholds defined in THRESHOLDS.
model_das = {
    var: [
        EVAL_DS[source][f"pred_{var}"].sel(model=model)
        for source, mconfigs in MODELS.items()
        for model in mconfigs
    ]
    for var in THRESHOLDS
    if var in eval_vars
}

# Threshold-exceedance change statistics, keyed by variable and then threshold.
change_stats = {
    var: {
        threshold: threshold_exceeded_prop_stats(
            model_das[var],
            CPM_DAS[var],
            threshold=threshold,
            threshold_exceeded_prop_statistic=threshold_exceeded_prop_change,
        )
        for threshold in thresholds
    }
    for var, thresholds in THRESHOLDS.items()
    if var in eval_vars
}

# One figure per threshold of every evaluated variable.
for var, thresholds in THRESHOLDS.items():
    if var not in eval_vars:
        continue
    for threshold in thresholds:
        IPython.display.display_markdown(f"#### Threshold: {threshold}mm/day", raw=True)

        plot_threshold_exceedence_errors(
            change_stats[var][threshold]["change in % threshold exceeded"],
            style="change",
        )

        plt.show()

### CCS threshold change: domain mean

In [None]:
# One inline-styled HTML table per (variable, threshold): the change statistics
# averaged over the grid_latitude / grid_longitude dimensions.
dfs = [
    change_stats[var][threshold]
    .mean(dim=["grid_latitude", "grid_longitude"])
    .to_dataframe()
    .style.set_table_attributes("style='display:inline'")
    .set_caption(f"Threshold: {threshold}mm/day")
    .format(precision=1)
    .to_html()
    for var, thresholds in THRESHOLDS.items()
    if var in eval_vars
    for threshold in thresholds
]

# Join the fragments with str.join: unlike functools.reduce without an initial
# value, it does not raise when no evaluated variable has thresholds.
IPython.display.display_html("".join(dfs), raw=True)