# Evaluation of a selection of models on 60km -> 2.2km-4x over Birmingham

In [None]:
%reload_ext autoreload

%autoreload 2

%reload_ext dotenv
%dotenv

import functools
import math
import string

import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

from mlde_analysis.data import prep_eval_data
from mlde_analysis import plot_map
from mlde_analysis.ccs import compute_changes, plot_changes, plot_tp_fd, plot_hist_per_tp, bootstrap_seasonal_mean_pr_change_samples
from mlde_analysis.distribution import normalized_mean_bias, normalized_std_bias, plot_freq_density, plot_mean_biases, plot_std_biases, rms_mean_bias, rms_std_bias
from mlde_analysis.wet_dry import threshold_exceeded_prop_stats, threshold_exceeded_prop_change, plot_threshold_exceedence_errors, THRESHOLDS
from mlde_analysis.display import VAR_RANGES, pretty_table
from mlde_utils import cp_model_rotated_pole, TIME_PERIODS

In [None]:
# Render every figure in this notebook at 300 dots per inch.
matplotlib.rcParams.update({'figure.dpi': 300})

In [None]:
from mlde_analysis.default_params import *

In [None]:
# Render `desc` as Markdown; being the last expression of the cell, the
# resulting object is shown as the cell output.
# NOTE(review): `desc` is not defined in this notebook - presumably it comes
# from the star-import of mlde_analysis.default_params (or injected notebook
# parameters); confirm.
IPython.display.Markdown(desc)

In [None]:
# Build the evaluation data and the model specs from the notebook parameters.
# Further down, EVAL_DS is used as a mapping of source name (e.g. "CPM") to a
# dataset, and MODELS as a mapping of source name to {model label: spec}.
EVAL_DS, MODELS = prep_eval_data(
    sample_configs,
    dataset_configs,
    derived_variables_config,
    eval_vars,
    split,
    ensemble_members=ensemble_members,
    samples_per_run=samples_per_run,
)
# Last expression of the cell: show the loaded data as the cell output.
EVAL_DS

In [None]:
# Target field of the "CPM" dataset for each evaluated variable, keyed by
# variable name.
CPM_DAS = {
    variable: EVAL_DS["CPM"][f"target_{variable}"]
    for variable in eval_vars
}

# Flat lookup from model label to its spec, with the owning source recorded
# under "source". The spec is the right-hand operand of the dict union, so a
# "source" key already present in a spec takes precedence.
MODELLABEL2SPEC = {
    label: {"source": source_name} | model_spec
    for source_name, source_models in MODELS.items()
    for label, model_spec in source_models.items()
}

# Predicted field for each evaluated variable, gathered from every dataset in
# EVAL_DS and concatenated along the "model" dimension.
PRED_DAS = {
    variable: xr.concat(
        [source_ds[f"pred_{variable}"] for source_ds in EVAL_DS.values()],
        dim="model",
    )
    for variable in eval_vars
}

In [None]:
# Seasons covered by the per-season significance analysis further down: the
# list is used both to filter the "time.season" groups and to fix the order in
# which the seasons are plotted and tabulated.
CCS_SEASONS = ["DJF", "MAM", "JJA", "SON"]

In [None]:
# Per source, keep only the models whose spec has a truthy "CCS" entry
# (a missing entry counts as False).
CCS_MODELS = {
    source_name: {
        label: model_spec
        for label, model_spec in source_models.items()
        if model_spec.get("CCS", False)
    }
    for source_name, source_models in MODELS.items()
}

# Predictions restricted to the labels of the CCS models, per variable.
CCS_PRED_DAS = {
    variable: PRED_DAS[variable].sel(
        model=[
            label
            for source_models in CCS_MODELS.values()
            for label in source_models
        ]
    )
    for variable in eval_vars
}

# One dataset holding the CPM targets alongside the CCS model predictions,
# with the "model" coordinate selected in CCS_MODELS order.
CCS_DS = xr.combine_by_coords(
    [*CPM_DAS.values(), *CCS_PRED_DAS.values()]
).sel(
    model=[
        label
        for source_models in CCS_MODELS.values()
        for label in source_models
    ]
)

## Seasonal domain mean changes

In [None]:
def mean_change(ds):
    """Return the future-minus-historic difference of the overall mean of ``ds``.

    ``ds`` must expose a ``time_period`` entry labelled with "historic" and
    "future". Each period is selected with ``where(..., drop=True)`` and then
    averaged over every dimension (``dim=...``).
    """
    period = ds["time_period"]
    historic_mean = ds.where(period == "historic", drop=True).mean(dim=...)
    future_mean = ds.where(period == "future", drop=True).mean(dim=...)
    return future_mean - historic_mean

def per_model_change(ds):
    """Apply ``mean_change`` separately to each group of ``ds`` along "model"."""
    by_model = ds.groupby("model")
    return by_model.map(mean_change)

# For each evaluated variable, tabulate the seasonal mean change of the CPM
# and of every CCS model as percentages.
for var in eval_vars:
    IPython.display.display_markdown(f"### {var}", raw=True)

    cpm_da = CCS_DS[f"target_{var}"]
    emu_da = CCS_DS[f"pred_{var}"]

    # Historic CPM mean per season: the reference for the "% of CPM historic"
    # columns.
    historic_only = cpm_da.where(cpm_da["time_period"] == "historic", drop=True)
    hist_cpm_means = historic_only.groupby("time.season").mean(dim=...)

    # Future-minus-historic mean per season (and per model for the emulators).
    cpm_change = cpm_da.groupby("time.season").map(mean_change)
    emu_change = emu_da.groupby("time.season").map(per_model_change)

    # Column name -> values, in the order the columns are merged.
    columns = {
        "CPM change (% of CPM historic)": 100 * cpm_change / np.abs(hist_cpm_means),
        "Emulator change (% of CPM historic)": 100 * emu_change / np.abs(hist_cpm_means),
        "Difference (% of CPM change)": 100 * (emu_change - cpm_change) / np.abs(cpm_change),
        "Difference (% of CPM historic)": 100 * (emu_change - cpm_change) / np.abs(hist_cpm_means),
    }
    table_ds = xr.merge([values.rename(name) for name, values in columns.items()])

    pretty_table(table_ds)

### Significance testing on seasonal domain mean changes

Significance is assessed using bootstrapped distributions of the changes and their confidence intervals (CIs).

In [None]:
# For every evaluated variable and every CCS model:
#   1. obtain samples of the historic and future seasonal domain means from
#      bootstrap_seasonal_mean_pr_change_samples (project helper; presumably
#      bootstrap resamples - confirm against mlde_analysis.ccs),
#   2. express the sampled changes as a percentage of (a) the CPM mean change
#      and (b) the historic CPM mean ("clim_" variables),
#   3. plot the per-season distributions and tabulate the alpha/2 and
#      1 - alpha/2 quantiles of every quantity.

# Significance level of the two-sided interval; loop-invariant, so set once.
alpha = 0.05
ci_quantiles = [alpha / 2, 1 - (alpha / 2)]

# Dimensions averaged over to obtain a single domain mean per time period.
mean_dims = ["grid_latitude", "grid_longitude", "time", "ensemble_member"]

# Histogram styling shared by all panels.
hist_style = dict(bins=50, alpha=0.75, histtype="step", linewidth=1, density=True)

for var in eval_vars:
    IPython.display.display_markdown(f"#### {var}", raw=True)

    for model, model_ds in CCS_DS.groupby("model"):
        IPython.display.display_markdown(f"##### {model}", raw=True)
        fig, axd = plt.subplot_mosaic(
            [["model", "cpm", "difference"], ["clim model", "clim cpm", "clim difference"]],
            figsize=(9, 6),
            constrained_layout=True,
        )

        # (axes key, variable in the seasonal dataset, panel title, draw legend?)
        # Only the first panel carries the season legend.
        panels = [
            ("model", "emu", f"{model}", True),
            ("cpm", "cpm", "CPM", False),
            ("difference", "difference", f"Differences {model}", False),
            ("clim difference", "clim_difference", f"Differences {model} relative to hist clim", False),
            ("clim model", "clim_emu", f"{model} relative to hist clim", False),
            ("clim cpm", "clim_cpm", "CPM relative to hist clim", False),
        ]

        seasonal_changes = {}
        for season, season_ds in model_ds.groupby("time.season"):
            if season not in CCS_SEASONS:
                continue

            season_cpm = season_ds[f"target_{var}"]
            # model_ds holds a single model, so drop the length-1 "model" dim.
            season_pred = season_ds[f"pred_{var}"].squeeze("model")

            # Point estimates from the CPM, used as normalisation references.
            period = season_cpm["time_period"]
            hist_season_cpm_mean = season_cpm.where(period == "historic", drop=True).mean(dim=mean_dims)
            fut_season_cpm_mean = season_cpm.where(period == "future", drop=True).mean(dim=mean_dims)
            season_cpm_mean_change = fut_season_cpm_mean - hist_season_cpm_mean

            hist_mean_samples, fut_mean_samples = bootstrap_seasonal_mean_pr_change_samples(
                season_cpm, season_pred, nsamples=100_000
            )

            # Sampled future-minus-historic changes, computed once and reused
            # for both normalisations.
            cpm_change_samples = fut_mean_samples[f"target_{var}"] - hist_mean_samples[f"target_{var}"]
            emu_change_samples = fut_mean_samples[f"pred_{var}"] - hist_mean_samples[f"pred_{var}"]

            # As a percentage of the CPM mean change.
            mean_cpm_change_samples = (100 * cpm_change_samples / np.abs(season_cpm_mean_change)).rename("cpm")
            mean_emu_change_samples = (100 * emu_change_samples / np.abs(season_cpm_mean_change)).rename("emu")
            differences = (mean_emu_change_samples - mean_cpm_change_samples).rename("difference")

            # As a percentage of the historic CPM mean.
            clim_mean_cpm_change_samples = (100 * cpm_change_samples / np.abs(hist_season_cpm_mean)).rename("clim_cpm")
            clim_mean_emu_change_samples = (100 * emu_change_samples / np.abs(hist_season_cpm_mean)).rename("clim_emu")
            clim_differences = (clim_mean_emu_change_samples - clim_mean_cpm_change_samples).rename("clim_difference")

            seasonal_changes[season] = xr.merge([
                mean_cpm_change_samples,
                mean_emu_change_samples,
                differences,
                clim_mean_cpm_change_samples,
                clim_mean_emu_change_samples,
                clim_differences,
            ])

        # Iterate CCS_SEASONS (not the dict) so seasons appear in a fixed order.
        for season in CCS_SEASONS:
            IPython.display.display_markdown(f"###### {season}", raw=True)
            season_changes = seasonal_changes[season]

            for panel_key, var_name, title, with_legend in panels:
                ax = axd[panel_key]
                season_changes[var_name].plot.hist(ax=ax, label=season, **hist_style)
                if with_legend:
                    ax.legend()
                # Set after plotting so it replaces any title set by the plot call.
                ax.set_title(title)

            # Lower and upper bounds of the two-sided interval for every quantity.
            IPython.display.display_html(
                season_changes.quantile(ci_quantiles).to_dataframe().round(2).to_html(),
                raw=True,
            )

        plt.show()

In [None]:
from mlde_analysis import significance

# CCS_SEASONS = ["DJF", "JJA"]
# for season, season_ds in EVAL_DS["GCM"].isel(model=[0]).groupby("time.season"):
#     if season not in CCS_SEASONS:
#         continue
#     fig = plt.figure(figsize=(5.5, 2.5), layout="compressed")
#     axd = fig.subplot_mosaic(np.array(list(TIME_PERIODS.keys())).reshape(1, -1), sharex=True, subplot_kw=dict(projection=cp_model_rotated_pole))
    
#     for tp_idx, tp_key in enumerate(TIME_PERIODS.keys()):
#         tp_season_ds = season_ds.where(season_ds["time_period"] == tp_key, drop=True)

#         for model, ds in tp_season_ds.groupby("model"):
#             ttest_result = significance.significance_test(ds.squeeze())
#             ds["pvalue"] = xr.Variable(["grid_latitude", "grid_longitude"], ttest_result.pvalue)
    
#             N = len(ds["grid_longitude"]) * len(ds["grid_latitude"])
#             alpha_fdr = 0.1
#             pvalue_threshold = alpha_fdr*np.arange(1, N+1, step=1)/N        
#             sorted_pvalues = np.sort(np.ravel(ds["pvalue"]))
#             p_values_less_than_alpha_fdr_frac = np.nonzero(np.cumprod(sorted_pvalues <= pvalue_threshold))[0]
#             if len(p_values_less_than_alpha_fdr_frac) == 0:
#                 # no local tests are below the controlled FDR
#                 p_fdr_star = 0. 
#             else:
#                 idx_star = p_values_less_than_alpha_fdr_frac.max()
#                 p_fdr_star = sorted_pvalues[idx_star]
            
#             ax=axd[tp_key]
#             # plot_map(ds["pvalue"], ax=ax, add_colorbar=True, style=None)
#             plot_map(ds["pvalue"] <= p_fdr_star, ax=ax, add_colorbar=True, style=None, vmin=0, vmax=1)
#             ax.set_title(f"{tp_key}")
#             fig.suptitle(f"{season} {model}")
#     plt.show()