# Sampling variability in dataset

In [None]:
%reload_ext autoreload

%autoreload 2

%reload_ext dotenv
%dotenv

import math
import string

import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import scipy
import seaborn as sns
import xarray as xr
import xskillscore as xs

from mlde_utils import cp_model_rotated_pole, dataset_split_path
from mlde_analysis import plot_map
from mlde_analysis.data import open_dataset_split, attach_eval_coords, attach_derived_variables
from mlde_analysis.display import pretty_table
from mlde_analysis.distribution import xr_hist, hist_dist
from mlde_analysis.examples import em_timestamps

In [None]:
# Raise the default figure resolution for every figure created below.
matplotlib.rcParams['figure.dpi'] = 300

In [None]:
# --- Notebook parameters -------------------------------------------------

# Dataset name, split and ensemble members handed to open_dataset_split.
dataset = "bham64_ccpm-4x_12em_mv"
split = "val"
ensemble_members = ["01", "04"]

# Name of the variable selected from the opened dataset for analysis.
var = "target_swbgt"

# Configuration handed to attach_derived_variables.
derived_variables_config = {
    "swbgt": [
        "mlde_analysis.derived_variables.swbgt",
        {"temp": "tmean150cm", "rh": "relhum150cm"},
    ],
}

# Thresholds used for the exceedance proportions, and (separately) the
# thresholds drawn as vertical lines on the histogram plots.
thresholds = [5, 25]
plotted_thresholds = [5, 25]

# Number of bootstrap iterations requested when resampling.
niterations = 10

In [None]:
# Build the working dataset in three steps: open the requested split,
# attach the evaluation coordinates, then attach the configured derived
# variables (with the "target" prefix).
DS = open_dataset_split(dataset, split, ensemble_members)
DS = attach_eval_coords(DS)
DS = attach_derived_variables(DS, derived_variables_config, prefixes=["target"])
DS

In [None]:
# Select the single variable under study; every later cell works on `da`.
da = DS[var]
da

In [None]:
def resample_examples(da):
    """Bootstrap-resample the examples of ``da``.

    The ``ensemble_member`` and ``time`` dimensions are stacked into a single
    ``member`` dimension, which is resampled with replacement for
    ``niterations`` iterations (read from the notebook parameters) and then
    unstacked back to the original two dimensions.
    """
    stacked = da.stack(member=["ensemble_member", "time"])
    resampled = xs.resampling.resample_iterations_idx(
        stacked, niterations, "member", replace=True
    )
    return resampled.unstack("member")


# Resample separately within each "stratum" group.
bs_da = da.groupby("stratum").map(resample_examples)

bs_da

## RMS bias

In [None]:
def _rms(group_da):
    """Root-mean-square over every element of ``group_da``."""
    return np.sqrt((group_da ** 2).mean())


# Mean: per-iteration grid of the mean over examples, expressed as a
# difference from the across-iteration average of those grids.
per_it_mean_grid = bs_da.mean(dim=["ensemble_member", "time"])
it_mean_mean_grid = per_it_mean_grid.mean(dim="iteration")
bs_mean_relative_bias_grid = per_it_mean_grid - it_mean_mean_grid
if var == "target_pr":
    # For target_pr the difference is made relative to the average grid.
    bs_mean_relative_bias_grid = bs_mean_relative_bias_grid / it_mean_mean_grid
bs_rmsb_mean = (
    bs_mean_relative_bias_grid.groupby("iteration")
    .map(_rms)
    .rename("bs_mean_rmsb")
)

# Standard deviation: same procedure applied to the per-iteration std grid.
per_it_std_grid = bs_da.std(dim=["ensemble_member", "time"])
it_mean_std_grid = per_it_std_grid.mean(dim="iteration")
bs_std_relative_bias_grid = per_it_std_grid - it_mean_std_grid
if var == "target_pr":
    # NOTE(review): this branch scales by 100 (a percentage) while the mean
    # branch above does not - confirm whether the difference is intended.
    bs_std_relative_bias_grid = 100 * bs_std_relative_bias_grid / it_mean_std_grid
bs_rmsb_std = (
    bs_std_relative_bias_grid.groupby("iteration")
    .map(_rms)
    .rename("bs_std_rmsb")
)

bs_rmsb = xr.merge([bs_rmsb_mean, bs_rmsb_std])

# Per-iteration values, then their 90th percentile across iterations.
pretty_table(bs_rmsb, round=4)
pretty_table(bs_rmsb.quantile([0.9], dim="iteration"), round=4)

In [None]:
# Plot the bias grids of the mean and of the standard deviation, faceted so
# that each bootstrap iteration gets its own column.
bs_mean_relative_bias_grid.plot(col="iteration")
bs_std_relative_bias_grid.plot(col="iteration")

## Frequency density histograms

In [None]:
# Shared bin edges (50 bins) computed from the un-resampled data, so that
# every histogram below is directly comparable.
bins = np.histogram_bin_edges(da, bins=50)

In [None]:
# Histogram of the un-resampled data on the shared bin edges.
split_hist_da = xr_hist(da, bins)

# One histogram per bootstrap iteration, using the same bin edges.
bs_hists = bs_da.groupby("iteration").map(xr_hist, bins=bins)

bs_hists

In [None]:
# Overlay the per-iteration histograms as step outlines, once with a log
# y-axis and once with a linear one.
for yscale in ["log", "linear"]:
    fig, ax = plt.subplots(figsize=(5.5, 3.5), layout="compressed")

    for itidx, group_da in bs_hists.groupby("iteration"):
        ax.stairs(group_da, bins, fill=False, linewidth=1)

    # Mark each plotted threshold with a dashed vertical line.
    for threshold in plotted_thresholds:
        ax.axvline(threshold, color="k", linestyle="--", linewidth=1)

    ax.set_yscale(yscale)
    ax.set_xlabel(xr.plot.utils.label_from_attrs(da=da))
    ax.set_ylabel("Freq. density")
    plt.show()

In [None]:
# Same comparison, but letting the plotting call histogram the resampled
# data itself for each iteration (log and linear y-axis).
for log in [True, False]:
    fig, ax = plt.subplots(figsize=(5.5, 3.5), layout="compressed")

    for itidx, group_da in bs_da.groupby("iteration"):
        group_da.plot.hist(
            bins=bins,
            density=True,
            histtype="step",
            log=log,
            ax=ax,
        )

    plt.show()

In [None]:
# Long-format table of the per-iteration histograms, labelling each bin by
# its left edge, for use with seaborn.
data = (
    bs_hists.assign_coords({"bins": bins[:-1]})
    .to_pandas()
    .reset_index()
    .melt(id_vars="iteration")
)

# Line plot across iterations with a 90% percentile-interval band, drawn
# once per y-axis scale.
for yscale in ["log", "linear"]:
    g_results = sns.lineplot(
        data=data,
        x="bins",
        y="value",
        errorbar=("pi", 90),
    )
    g_results.set_yscale(yscale)
    plt.show()

### Frequency density weight over threshold

In [None]:
def _exceedence_proportion(group_da, threshold):
    """Proportion of the counted values in ``group_da`` that exceed ``threshold``."""
    return group_da.where(group_da > threshold).count() / group_da.count()


# Per-iteration exceedance proportion, one entry along a new "threshold"
# dimension for every configured threshold.
per_it_thshd_exceedence_prop_da = xr.concat(
    [
        bs_da.groupby("iteration")
        .map(_exceedence_proportion, threshold=threshold)
        .expand_dims(dict(threshold=[threshold]))
        for threshold in thresholds
    ],
    dim="threshold",
).rename("threshold_exceedence")

# Deviation of each iteration from the across-iteration mean proportion.
per_it_thshd_exceedence_prop_diff_da = (
    per_it_thshd_exceedence_prop_da
    - per_it_thshd_exceedence_prop_da.mean(dim="iteration")
).rename("threshold_exceedence_diff")

pretty_table(per_it_thshd_exceedence_prop_da, round=4)
pretty_table(
    per_it_thshd_exceedence_prop_da.quantile([0.05, 0.95], dim="iteration"),
    round=5,
)

pretty_table(per_it_thshd_exceedence_prop_diff_da, round=8)
pretty_table(
    per_it_thshd_exceedence_prop_diff_da.quantile([0.05, 0.95], dim="iteration"),
    round=8,
)

### Jensen–Shannon (JS) distance variability of histograms

In [None]:
# Distance of each iteration's histogram from the histogram averaged over
# all iterations.
bs_distances = bs_hists.groupby("iteration").map(
    hist_dist,
    target_hist_da=bs_hists.mean(dim="iteration"),
)

pretty_table(bs_distances, round=4)
pretty_table(
    bs_distances.quantile([0.9], dim="iteration"),
    round=4,
)
bs_distances.plot.hist()

In [None]:
# Distance of each iteration's histogram from the histogram of the
# un-resampled data (this overwrites the previous `bs_distances`).
bs_distances = bs_hists.groupby("iteration").map(
    hist_dist,
    target_hist_da=split_hist_da,
)

pretty_table(bs_distances, round=4)
pretty_table(
    bs_distances.quantile([0.9], dim="iteration"),
    round=4,
)

bs_distances.plot.hist()

In [None]:
# Cross-check of `bs_distances`: recompute the Jensen-Shannon distance between
# the un-resampled histogram and each iteration's histogram directly with
# NumPy/SciPy.
split_hist = np.histogram(da, bins=bins, density=True)[0]

# Build the array straight from the groups so its length always matches the
# number of iterations actually present in `bs_da`.
a = np.array(
    [
        scipy.spatial.distance.jensenshannon(
            split_hist,
            np.histogram(group_da, bins=bins, density=True)[0],
        )
        for _, group_da in bs_da.groupby("iteration")
    ]
)

# The two results come from different code paths, so compare with a
# floating-point tolerance rather than exact equality.
np.isclose(a, np.asarray(bs_distances))

## Correlation coefficient

In [None]:
# Reduce both the resampled and the un-resampled data to a domain mean over
# the two grid dimensions.
domain_dims = ["grid_latitude", "grid_longitude"]
per_it_domain_mean = bs_da.mean(dim=domain_dims)
split_domain_mean = da.mean(dim=domain_dims)

# Correlation, per iteration, between the two domain-mean series across all
# (ensemble_member, time) positions.
per_it_corr = xr.corr(
    per_it_domain_mean,
    split_domain_mean,
    dim=["ensemble_member", "time"],
).rename("corr")

pretty_table(per_it_corr, round=2)
per_it_corr.mean()