# Sampling variability in the dataset

In [None]:
%reload_ext autoreload

%autoreload 2

%reload_ext dotenv
%dotenv

import math
import string

import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import scipy
import seaborn as sns
import xarray as xr
import xskillscore as xs

from mlde_utils import cp_model_rotated_pole, dataset_split_path
from mlde_analysis import plot_map
from mlde_analysis.data import open_dataset_split, attach_eval_coords, attach_derived_variables
from mlde_analysis.display import pretty_table, VAR_RANGES
from mlde_analysis.distribution import xr_hist, hist_dist
from mlde_analysis.bootstrap import resample_examples
from mlde_analysis.fractional_contribution import fc_bins, fc_binval, compute_fractional_contribution, frac_contrib_change, plot_fractional_contribution

In [None]:
# Render every figure in this notebook at high resolution.
matplotlib.rcParams.update({"figure.dpi": 300})

In [None]:
# --- Notebook parameters ---
# Which dataset split and ensemble members to load.
dataset = "demo-ccpm_mv"
split = "val"
ensemble_members = ["01", "04"]

# Variable to analyse.
var = "target_pr"

# Extra variables to derive from the loaded fields: name -> [function path, argument mapping].
derived_variables_config = {
    "swbgt": [
        "mlde_analysis.derived_variables.swbgt",
        {"temp": "tmean150cm", "rh": "relhum150cm"},
    ]
}

# Thresholds (not used by the cells shown here).
thresholds = [5, 25]
plotted_thresholds = [5, 25]

# Number of bootstrap resamples.
niterations = 10

In [None]:
# Load the split for the chosen ensemble members, then decorate it step by step:
# first the evaluation coordinates, then the configured derived variables
# (prefixes=["target"] presumably limits derivation to target_* fields — confirm).
DS = open_dataset_split(dataset, split, ensemble_members)
DS = attach_eval_coords(DS)
DS = attach_derived_variables(DS, derived_variables_config, prefixes=["target"])
DS

In [None]:
# Pull out the variable named by `var` from DS and display it.
da = DS[var]
da

In [None]:
# Keep only summer (JJA) and winter (DJF) samples from the historic and future periods.
in_season = da["time.season"].isin(["JJA", "DJF"])
in_period = da["time_period"].isin(["historic", "future"])
da = da.where(in_season & in_period, drop=True)


def _bootstrap_each_season(period_da):
    """Bootstrap-resample one time period separately within each season."""
    return period_da.groupby("time.season").map(resample_examples, niterations=niterations)


bs_da = da.groupby("time_period").map(_bootstrap_each_season)

bs_da

## Per-time-period frequency density histograms

In [None]:
# Fixed bin edges taken from the variable's configured display range (rather than
# from the data), so every season, iteration and time period shares the same bins.
value_range = VAR_RANGES[var.replace("target_", "")]
hist_bins = np.histogram_bin_edges([], bins=50, range=value_range)


def _hists_by_iteration(season_da):
    """Histogram each bootstrap iteration of one season, split by time period."""
    def _hists_by_period(iteration_da):
        return iteration_da.groupby("time_period").map(xr_hist, bins=hist_bins)

    return season_da.groupby("iteration").map(_hists_by_period)


bs_hists = bs_da.groupby("time.season").map(_hists_by_iteration)

bs_hists

In [None]:
# One colour per time period. Each bootstrap iteration is drawn as a faint outline,
# so the spread between outlines shows the sampling variability. Each season is drawn
# twice, with a log and a linear y-axis.
colors = {"future": "tab:blue", "historic": "tab:orange"}
for season, season_bs_hists in bs_hists.groupby("season"):
    for yscale in ["log", "linear"]:
        fig = plt.figure(figsize=(3.5, 2), layout="compressed")

        ax = fig.add_subplot()
        legend_handles = []
        for tp, tp_hists in season_bs_hists.groupby("time_period"):
            for position, (_, group_da) in enumerate(tp_hists.groupby("iteration")):
                stairs_artist = ax.stairs(
                    group_da,
                    hist_bins,
                    fill=False,
                    linewidth=1,
                    color=colors[tp],
                    label=tp,
                    alpha=0.5,
                )
                # Keep exactly one legend handle per time period. This uses the loop
                # position, not the iteration label, so the legend does not silently
                # vanish when iteration labels do not start at 0.
                if position == 0:
                    legend_handles.append(stairs_artist)
        ax.legend(handles=legend_handles)
        ax.set_yscale(yscale)
        ax.set_xlabel(xr.plot.utils.label_from_attrs(da=da))
        ax.set_ylabel("Freq. density")
        ax.set_title(f"{season}")
        plt.show()

## Fractional contribution change

In [None]:
# Bin edges for the fractional-contribution analysis.
fcbins = fc_bins()


def _fc_change_per_iteration(season_da):
    """Fractional-contribution change for every bootstrap iteration of one season."""
    return season_da.groupby("iteration").map(frac_contrib_change, bins=fcbins)


bs_fc_change = bs_da.groupby("time.season").map(_fc_change_per_iteration)

In [None]:
# Representative value of each fractional-contribution bin. It depends only on
# fcbins, so compute it once rather than once per bootstrap iteration.
fc_bin_values = fc_binval(fcbins)

for season, season_bs_fc_change in bs_fc_change.groupby("season"):
    fig = plt.figure(layout='constrained', figsize=(4.5, 3))
    axd = fig.subplot_mosaic([["Change"]])
    ax = axd["Change"]

    # One translucent line per bootstrap iteration; the spread of the lines shows
    # the sampling variability of the historic-to-future change.
    frcontrib_change_data = [
        dict(data=(group_da.values, fc_bin_values), label=f"CPM {itidx}", color="tab:blue", source="CPM")
        for itidx, group_da in season_bs_fc_change.groupby("iteration")
    ]

    plot_fractional_contribution(frcontrib_change_data, ax=ax, title=f"{season} Change from Historic to Future", alpha=0.25, linewidth=1, legend=False, ylim=[-0.4, 0.4])
    # Show each season's figure as it is finished, as the other plotting cells do.
    plt.show()

In [None]:
for season, fc_change_da in bs_fc_change.groupby("season"):
    # Reshape to long format: one row per (iteration, bin) with the change in a
    # "value" column. The bin coordinate is relabelled with each bin's representative value.
    binned = fc_change_da.assign_coords({"bins": fc_binval(fcbins)})
    data = binned.to_pandas().reset_index().melt(id_vars="iteration")

    fig = plt.figure(layout='constrained', figsize=(4.5, 3))
    ax = fig.subplot_mosaic([["Change"]])["Change"]

    # Central line across iterations, shaded with a 90% percentile interval.
    g_results = sns.lineplot(
        data=data,
        x="bins",
        y="value",
        errorbar=("pi", 90),
        linewidth=1,
        ax=ax,
    )
    g_results.set(xscale="log")
    g_results.set(
        title=f"{season} Change from Historic to Future",
        xlabel="Precip (mm/day)",
        ylabel="Change in frac. contrib.",
        xlim=[0.1, 200.0],
        ylim=[-0.4, 0.4],
    )
    plt.show()

In [None]:
for season, fc_change_da in bs_fc_change.groupby("season"):
    # Long-format table: one row per (iteration, bin), with the bin's representative value.
    binned = fc_change_da.assign_coords({"bins": fc_binval(fcbins)})
    data = binned.to_pandas().reset_index().melt(id_vars="iteration")

    # No ax is passed, so seaborn draws on matplotlib's current axes
    # (default figure size), unlike the previous cell.
    g_results = sns.lineplot(data=data, x="bins", y="value", errorbar=("pi", 90), linewidth=1)
    g_results.set(xscale="log")
    axis_title = f"{season} Change from Historic to Future"
    g_results.set(
        title=axis_title,
        xlabel="Precip (mm/day)",
        ylabel="Change in frac. contrib.",
        xlim=[0.1, 200.0],
        ylim=[-0.4, 0.4],
    )
    plt.show()

## Change in quantiles

In [None]:
# Extreme quantiles 0.99, 0.999, ..., 1 - 1e-9 (two to nine "nines").
qs = 1 - np.power(10.0, np.arange(-2, -10, -1))

# Quantiles for each season and bootstrap iteration. `dim=...` reduces over every
# remaining dimension of the group.
# NOTE(review): bs_da holds both the historic and future periods, so this pools
# them; if per-period quantiles (or their change) are intended, also group by
# time_period — confirm.
bs_quantiles_da = bs_da.groupby("time.season").map(
    lambda season_bs_da: season_bs_da.groupby("iteration").quantile(q=qs, dim=...)
).rename("quantiles")


def _across_iterations(stat, name):
    """Summarise the bootstrap spread of each quantile per season.

    ``stat`` receives the per-season values grouped by quantile and must return the
    reduced DataArray; the combined result is renamed to ``name``.
    """
    return bs_quantiles_da.groupby("season").map(
        lambda season_bs_da: stat(season_bs_da.groupby("quantile"))
    ).rename(name)


# `drop_vars` replaces the deprecated `DataArray.drop(name)`; both remove the
# "quantile" coordinate from the 5th/95th percentile summaries.
_ = pretty_table(
    xr.merge([
        _across_iterations(lambda g: g.quantile(q=0.05, dim=...), "5th%ile").drop_vars("quantile"),
        _across_iterations(lambda g: g.mean(...), "mean"),
        _across_iterations(lambda g: g.quantile(q=0.95, dim=...), "95th%ile").drop_vars("quantile"),
    ]),
    round=1,
    caption="Bootstrapped quantile spread",
)

In [None]:
for season, season_bs_quantiles_da in bs_quantiles_da.groupby("season"):
    # Relabel each quantile q by its number of nines, round(-log10(1 - q)),
    # e.g. 0.99 -> 2, 0.999 -> 3, and call that dimension "nines".
    n_nines = -np.log10(1 - season_bs_quantiles_da["quantile"]).round().astype(int)
    season_bs_quantiles_da = season_bs_quantiles_da.assign_coords(quantile=n_nines).rename(quantile="nines")

    # Scatter: one point per bootstrap iteration at each quantile.
    fig = plt.figure(layout='constrained', figsize=(3.5, 2.5))
    ax = fig.subplot_mosaic([["Change"]])["Change"]
    season_bs_quantiles_da.plot.scatter(ax=ax, alpha=0.5, add_legend=False, s=5)
    plt.show()

    # Box plot: the same values summarised per quantile, on the current axes.
    long_df = season_bs_quantiles_da.to_pandas().reset_index().melt(id_vars="iteration")
    box_ax = sns.boxplot(data=long_df, x="nines", y="value")
    box_ax.set(title=f"{season}")
    plt.show()