# Evaluation of a selection of models on 60km -> 2.2km-4x over Birmingham

In [None]:
%reload_ext autoreload

%autoreload 2

%reload_ext dotenv
%dotenv

import functools
import math
import string

import IPython
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

from mlde_notebooks.data import prep_eval_and_model_data
from mlde_notebooks import plot_map
from mlde_notebooks.ccs import compute_changes, plot_changes, plot_tp_fd, bootstrap_seasonal_mean_pr_change_samples
from mlde_utils import cp_model_rotated_pole, TIME_PERIODS

In [None]:
import matplotlib

# Raise the resolution of every figure rendered by this notebook to 300 dpi.
matplotlib.rc("figure", dpi=300)

In [None]:
# Name of the dataset split to evaluate.
split = "val"

# Zero-padded ids of the ensemble members to include.
ensemble_members = [f"{member:02d}" for member in (1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15)]

# Passed to prep_eval_and_model_data as its samples_per_run argument.
samples_per_run = 3

# The same diffusion model is evaluated with both CPM-derived and GCM-derived inputs.
_DIFFUSION_MODEL_ID = "score-sde/subvpsde/xarray_12em_cncsnpp_continuous/bham-4x_12em_PslS4T4V4_random-season-IstanTsqrturrecen-no-loc-spec"
# Dataset shared by the diffusion and U-Net entries under the "CPM" source.
_CPM_SOURCE_DATASET = "bham_gcmx-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season"

# Model specs grouped by input source. Specs with a truthy "CCS" entry are the
# ones picked up by the climate change signal cells further down.
data_configs = {
    "CPM": [
        {
            "fq_model_id": _DIFFUSION_MODEL_ID,
            "checkpoint": "epoch-20",
            "input_xfm": "bham_gcmx-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-stan",
            "label": "Diffusion (cCPM)",
            "dataset": _CPM_SOURCE_DATASET,
            "deterministic": False,
            "PSD": True,
            "CCS": True,
            "color": "blue",
            "order": 10,
        },
        {
            "fq_model_id": "u-net/bham-4x_12em_PslS4T4V4_random-season-IstanTsqrturrecen",
            "checkpoint": "epoch-100",
            "input_xfm": "stan",
            "label": "U-Net (cCPM)",
            "dataset": _CPM_SOURCE_DATASET,
            "deterministic": True,
            "PSD": True,
            "color": "orange",
            "order": 1,
        },
        {
            "fq_model_id": "id-linpr",
            "checkpoint": "epoch-0",
            "input_xfm": "none",
            "label": "Bilinear cCPM",
            "deterministic": True,
            "dataset": "bham_gcmx-4x_12em_linpr_eqvt_random-season",
            "color": "red",
            "order": 0,
        },
    ],
    "GCM": [
        {
            "fq_model_id": _DIFFUSION_MODEL_ID,
            "checkpoint": "epoch-20",
            "input_xfm": "bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-pixelmmsstan",
            "label": "Diffusion (GCM)",
            "dataset": "bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season",
            "deterministic": False,
            "CCS": True,
            "color": "green",
            "order": 20,
        },
    ],
}

# Placeholder text rendered as Markdown by the next cell.
desc = """
Describe in more detail the models being compared
"""

In [None]:
# Wrap the description text in a Markdown display object; as the final
# expression of the cell it becomes the cell's rendered output.
IPython.display.Markdown(desc)

In [None]:
# Load the evaluation data for every configured model. `merged_ds` is indexed
# by source name below ("CPM", "GCM"); `MODELS` maps source -> model -> spec.
merged_ds, MODELS = prep_eval_and_model_data(
    data_configs,
    split,
    ensemble_members=ensemble_members,
    samples_per_run=samples_per_run,
)
merged_ds

In [None]:
# "target_pr" of the CPM source; the cells below compare each emulator's
# "pred_pr" against it.
cpm_pr = merged_ds["CPM"]["target_pr"]

## Figure: Climate change signal

* Per time period frequency density histogram
* Mean change diff: $(\mu_{ML}^{future} - \mu_{ML}^{hist})/\mu_{ML}^{hist} - (\mu_{CPM}^{future} - \mu_{CPM}^{hist})/\mu_{CPM}^{hist}$

In [None]:
# Per source, keep only the models whose spec sets a truthy "CCS" entry.
ccs_models = {
    source: {name: cfg for name, cfg in source_models.items() if cfg.get("CCS", False)}
    for source, source_models in MODELS.items()
}

# One predicted-precipitation DataArray per selected model, in source order.
ccs_pred_pr_das = [
    merged_ds[source]["pred_pr"].sel(model=name)
    for source, selected in ccs_models.items()
    for name in selected
]

# CPM target and all selected predictions (stacked along "model") in one dataset.
ccs_ds = xr.combine_by_coords([cpm_pr, xr.concat(ccs_pred_pr_das, dim="model")])

In [None]:
# One frequency-density figure per selected model, each under its own heading.
for source, mconfigs in ccs_models.items():
    for model, spec in mconfigs.items():
        IPython.display.display_markdown(f"#### {model}", raw=True)

        pred_pr = merged_ds[source]["pred_pr"].sel(model=model)
        fd_fig = plt.figure(figsize=(5.5, 5.5), layout="compressed")
        plot_tp_fd(pred_pr, cpm_pr, fd_fig, source, model, spec)

        plt.show()

### Seasonal domain mean changes

In [None]:
def mean_change(ds):
    """Return the future mean minus the historic mean of ``ds``.

    Each mean is taken over every dimension of the values whose
    ``time_period`` coordinate equals ``"future"`` / ``"historic"``.
    """
    time_period = ds["time_period"]
    historic_mean = ds.where(time_period == "historic", drop=True).mean(dim=...)
    future_mean = ds.where(time_period == "future", drop=True).mean(dim=...)
    return future_mean - historic_mean

def per_model_change(ds):
    """Apply ``mean_change`` separately to each ``model`` group of ``ds``."""
    by_model = ds.groupby("model")
    return by_model.map(mean_change)
                                   
# Seasonal mean of the CPM target over the historic period: the "CPM historic"
# reference that the percentages below are relative to.
target_pr = ccs_ds["target_pr"]
historic_target_pr = target_pr.where(target_pr["time_period"] == "historic", drop=True)
hist_cpm_means = historic_target_pr.groupby("time.season").mean(dim=...)

# Future-minus-historic mean change per season, for the CPM and for each emulator.
cpm_change = target_pr.groupby("time.season").map(mean_change)
emu_change = ccs_ds["pred_pr"].groupby("time.season").map(per_model_change)

hist_cpm_scale = np.abs(hist_cpm_means)
emu_minus_cpm = emu_change - cpm_change

table_ds = xr.merge([
    (100 * cpm_change / hist_cpm_scale).rename("CPM change (% of CPM historic)"),
    (100 * emu_change / hist_cpm_scale).rename("Emulator change (% of CPM historic)"),
    (100 * emu_minus_cpm / np.abs(cpm_change)).rename("Difference (% of CPM change)"),
    (100 * emu_minus_cpm / hist_cpm_scale).rename("Difference (% of CPM historic)"),
])

table_html = table_ds.round(1).to_dataframe().to_html()
IPython.display.display_html(table_html, raw=True)

#### Significance testing on seasonal domain mean changes

With bootstrapped distribution and CIs

In [None]:
CCS_SEASONS = ["DJF", "MAM", "JJA", "SON"]

# Dimensions averaged over to reduce CPM precipitation to one domain-mean value.
_CPM_MEAN_DIMS = ["grid_latitude", "grid_longitude", "time", "ensemble_member"]
# Number of bootstrap resamples requested for each season.
_N_BOOTSTRAP_SAMPLES = 100_000
# Two-sided level of the quantile interval tabulated for each season.
_CI_ALPHA = 0.05


def _period_mean(da, time_period):
    """Return the mean of ``da`` over ``_CPM_MEAN_DIMS`` for one ``time_period`` value."""
    return da.where(da["time_period"] == time_period, drop=True).mean(dim=_CPM_MEAN_DIMS)


def _as_percentage(change, reference, name):
    """Return ``change`` as a percentage of ``|reference|``, renamed to ``name``."""
    return (100 * change / np.abs(reference)).rename(name)


def _seasonal_change_samples(season_cpm_pr, season_pred_pr):
    """Return bootstrapped mean-change samples for a single season.

    The result holds six variables: ``cpm``, ``emu`` and ``difference`` are
    percentages of the CPM's own mean change; ``clim_cpm``, ``clim_emu`` and
    ``clim_difference`` are percentages of the CPM historic mean.
    """
    hist_cpm_mean = _period_mean(season_cpm_pr, "historic")
    fut_cpm_mean = _period_mean(season_cpm_pr, "future")
    cpm_mean_change = fut_cpm_mean - hist_cpm_mean

    # The helper returns historic and future samples, each indexable by
    # "target_pr" (CPM) and "pred_pr" (emulator).
    hist_samples, fut_samples = bootstrap_seasonal_mean_pr_change_samples(
        season_cpm_pr, season_pred_pr, nsamples=_N_BOOTSTRAP_SAMPLES
    )
    cpm_change_samples = fut_samples["target_pr"] - hist_samples["target_pr"]
    emu_change_samples = fut_samples["pred_pr"] - hist_samples["pred_pr"]

    cpm = _as_percentage(cpm_change_samples, cpm_mean_change, "cpm")
    emu = _as_percentage(emu_change_samples, cpm_mean_change, "emu")
    clim_cpm = _as_percentage(cpm_change_samples, hist_cpm_mean, "clim_cpm")
    clim_emu = _as_percentage(emu_change_samples, hist_cpm_mean, "clim_emu")

    return xr.merge([
        cpm,
        emu,
        (emu - cpm).rename("difference"),
        clim_cpm,
        clim_emu,
        (clim_emu - clim_cpm).rename("clim_difference"),
    ])


for source, mconfigs in ccs_models.items():
    for model in mconfigs:
        pred_pr = merged_ds[source]["pred_pr"].sel(model=model)

        IPython.display.display_markdown(f"#### {model}", raw=True)
        fig, axd = plt.subplot_mosaic([["model", "cpm", "difference"], ["clim model", "clim cpm", "clim difference"]], figsize=(9, 6), constrained_layout=True)

        seasonal_changes = {}
        for season in CCS_SEASONS:
            season_cpm_pr = cpm_pr.where(cpm_pr["time.season"] == season, drop=True)
            # Mask the emulator output by its own time axis rather than the
            # CPM's, so the selection stays right even if the two differ.
            season_pred_pr = pred_pr.where(pred_pr["time.season"] == season, drop=True)

            seasonal_changes[season] = _seasonal_change_samples(season_cpm_pr, season_pred_pr)

        # (mosaic axis key, variable in the seasonal dataset, axis title)
        panels = [
            ("model", "emu", f"{model}"),
            ("cpm", "cpm", "CPM"),
            ("difference", "difference", f"Differences {model}"),
            ("clim difference", "clim_difference", f"Differences {model} relative to hist clim"),
            ("clim model", "clim_emu", f"{model} relative to hist clim"),
            ("clim cpm", "clim_cpm", "CPM relative to hist clim"),
        ]

        for season in CCS_SEASONS:
            IPython.display.display_markdown(f"##### {season}", raw=True)

            for axis_key, variable, title in panels:
                ax = axd[axis_key]
                seasonal_changes[season][variable].plot.hist(bins=50, ax=ax, label=season, alpha=0.75, histtype="step", linewidth=1, density=True)
                # Set after plotting so this title is the one left on the axis.
                ax.set_title(title)

            # Only the emulator panel carries the season legend.
            axd["model"].legend()

            IPython.display.display_html(seasonal_changes[season].quantile([_CI_ALPHA/2, 1-(_CI_ALPHA/2)]).to_dataframe().round(2).to_html(), raw=True)

        plt.show()

In [None]:
from mlde_notebooks import significance

# NOTE: everything below in this cell is disabled (commented out), so the
# `significance` import above is not used by any live code in this cell.
# If re-enabled, the code would, for the DJF and JJA seasons of the first model
# in the GCM-sourced dataset, call significance.significance_test per time
# period, store the per-grid-cell p-values, and map the cells whose p-value is
# at or below a threshold derived from alpha_fdr = 0.1.
# CCS_SEASONS = ["DJF", "JJA"]
# for season, season_ds in merged_ds["GCM"].isel(model=[0]).groupby("time.season"):
#     if season not in CCS_SEASONS:
#         continue
#     fig = plt.figure(figsize=(5.5, 2.5), layout="compressed")
#     axd = fig.subplot_mosaic(np.array(list(TIME_PERIODS.keys())).reshape(1, -1), sharex=True, subplot_kw=dict(projection=cp_model_rotated_pole))
    
#     for tp_idx, tp_key in enumerate(TIME_PERIODS.keys()):
#         tp_season_ds = season_ds.where(season_ds["time_period"] == tp_key, drop=True)

#         for model, ds in tp_season_ds.groupby("model"):
#             ttest_result = significance.significance_test(ds.squeeze())
#             ds["pvalue"] = xr.Variable(["grid_latitude", "grid_longitude"], ttest_result.pvalue)
    
#             N = len(ds["grid_longitude"]) * len(ds["grid_latitude"])
#             alpha_fdr = 0.1
#             pvalue_threshold = alpha_fdr*np.arange(1, N+1, step=1)/N        
#             sorted_pvalues = np.sort(np.ravel(ds["pvalue"]))
#             p_values_less_than_alpha_fdr_frac = np.nonzero(np.cumprod(sorted_pvalues <= pvalue_threshold))[0]
#             if len(p_values_less_than_alpha_fdr_frac) == 0:
#                 # no local tests are below the controlled FDR
#                 p_fdr_star = 0. 
#             else:
#                 idx_star = p_values_less_than_alpha_fdr_frac.max()
#                 p_fdr_star = sorted_pvalues[idx_star]
            
#             ax=axd[tp_key]
#             # plot_map(ds["pvalue"], ax=ax, add_colorbar=True, style=None)
#             plot_map(ds["pvalue"] <= p_fdr_star, ax=ax, add_colorbar=True, style=None, vmin=0, vmax=1)
#             ax.set_title(f"{tp_key}")
#             fig.suptitle(f"{season} {model}")
#     plt.show()

### Mean change maps

In [None]:
# One figure of DJF and JJA mean changes per selected emulator.
mean_change_seasons = ["DJF", "JJA"]
for ccs_pred_pr_da in ccs_pred_pr_das:
    # Each DataArray holds exactly one model; its name selects what to show.
    model_name = ccs_pred_pr_da["model"].data.item()

    changes = compute_changes([ccs_pred_pr_da], cpm_pr, mean_change_seasons, stat_func=xr.DataArray.mean)
    mean_change_fig = plt.figure(figsize=(5.5, 4.5), layout="compressed")
    plot_changes(changes, mean_change_seasons, mean_change_fig, show_change=[model_name])

    plt.show()

### Q99 change maps

In [None]:
# Same maps as the mean changes, but with a quantile as the statistic.
quantile_change_seasons = ["DJF", "JJA"]
for q in [0.99]:
    IPython.display.display_markdown(f"#### Quantile: {q}", raw=True)
    quantile_stat = functools.partial(xr.DataArray.quantile, q=q)

    for ccs_pred_pr_da in ccs_pred_pr_das:
        model_name = ccs_pred_pr_da["model"].data.item()

        changes = compute_changes([ccs_pred_pr_da], cpm_pr, quantile_change_seasons, stat_func=quantile_stat)
        quantile_change_fig = plt.figure(figsize=(5.5, 4.5), layout="compressed")
        plot_changes(changes, quantile_change_seasons, quantile_change_fig, show_change=[model_name])

        plt.show()

### CCS mean Variability

In [None]:
from sklearn.model_selection import StratifiedKFold

time_da = merged_ds["CPM"]["time"]

# One row per distinct (stratum, dec_adjusted_year) pair: the units split into folds.
df = time_da.to_dataframe().drop_duplicates(["stratum", "dec_adjusted_year"])

skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)

# The folds depend only on the time axis, and the integer random_state makes
# the shuffled split reproducible, so the per-fold time selections are built
# once here instead of once per model.
fold_time_das = []
for _, test_idx in skf.split(df[["dec_adjusted_year"]], df["stratum"]):
    fold_df = df.iloc[test_idx]
    # Labels of the form "<stratum> <dec_adjusted_year>", matched against the
    # tp_season_year coordinate to pick the timestamps belonging to this fold.
    fold_labels = fold_df["stratum"].str.cat(fold_df["dec_adjusted_year"].astype("str"), sep=' ').values
    fold_time_das.append(time_da.where(time_da["tp_season_year"].isin(fold_labels), drop=True))

# For every selected emulator, plot the seasonal mean changes on each fold's
# subsample to show how much the change signal varies between subsamples.
for ccs_pred_pr_da in ccs_pred_pr_das:
    for fold_time_da in fold_time_das:
        ccs_pred_pr_da_subsamples = ccs_pred_pr_da.sel(time=fold_time_da.data)

        mean_changes = compute_changes([ccs_pred_pr_da_subsamples], merged_ds["CPM"]["target_pr"].sel(time=fold_time_da), CCS_SEASONS, stat_func=xr.DataArray.mean)

        mean_change_fig = plt.figure(figsize=(5.5, 5.5), layout="compressed")

        plot_changes(mean_changes, CCS_SEASONS, mean_change_fig, show_change=[ccs_pred_pr_da_subsamples["model"].data.item()])

        plt.show()