In [None]:
import xarray as xr
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
import scipy.stats as sts
import json
import random
import pandas as pd
from functools import partial
from statsmodels.stats import multitest as smm
import multiprocessing as mp
import seaborn as sns

In [None]:
def randomise_new(ens_min, ens_max, ens_size, with_repl=False, ncases=2, uniq=False):
    """Draw `ncases` lists of ensemble indices from the range [ens_min, ens_max].

    Parameters
    ----------
    ens_min, ens_max : int
        Inclusive bounds of the pool of ensemble indices.
    ens_size : int
        Number of indices drawn for each case; must be strictly smaller than
        the number of indices in the pool.
    with_repl : bool, optional
        If True, every index is drawn independently (repeats allowed) and
        `uniq` is ignored.
    ncases : int, optional
        Number of index lists to return.
    uniq : bool, optional
        Only used when `with_repl` is False. If True, one sample of
        ``ens_size * ncases`` indices is drawn and split between the cases, so
        no index is shared between cases. If False, each case is sampled on
        its own and cases may overlap.

    Returns
    -------
    selected : list of list
        `ncases` lists, each holding `ens_size` indices.

    """
    pool = list(range(ens_min, ens_max + 1))
    assert len(pool) > ens_size, "ENSEMBLE SIZE MUST BE SMALLER THAN ENSEMBLE RANGE"

    if with_repl:
        # Sampling with replacement: every draw is independent.
        return [
            [random.randint(ens_min, ens_max) for _ in range(ens_size)]
            for _ in range(ncases)
        ]

    if uniq:
        # One big draw without replacement, cut into consecutive chunks.
        drawn = random.sample(pool, ens_size * ncases)
        return [
            drawn[case * ens_size : (case + 1) * ens_size] for case in range(ncases)
        ]

    # Independent draws per case: unique within a case, may overlap across cases.
    return [random.sample(pool, ens_size) for _ in range(ncases)]


def rolling_mean_data(data, period_len=12, time_var="time"):
    """Return the rolling mean of `data` over `period_len` steps of `time_var`.

    Parameters
    ----------
    data : object with xarray-style ``rolling`` / ``dropna`` methods
        Input field to smooth.
    period_len : int, optional
        Window length passed to ``rolling``.
    time_var : str, optional
        Name of the dimension the window runs along.

    Returns
    -------
    Rolling-mean data with labels along `time_var` that contain missing
    values dropped.

    """
    windowed = data.rolling(**{time_var: period_len})
    averaged = windowed.mean()
    return averaged.dropna(time_var)


def ks_pval(data_x, data_y):
    """Return the p-value of a two-sample Kolmogorov-Smirnov test.

    Uses `scipy.stats.mstats.ks_2samp`; the p-value is the second element of
    its result.
    """
    return sts.mstats.ks_2samp(data_x, data_y)[1]


def cvm_2samp(data_x, data_y):
    """Return the p-value of a two-sample Cramer-von Mises test."""
    return sts.cramervonmises_2samp(data_x, data_y).pvalue


def anderson_pval(data_1, data_2):
    try:
        _res = sts.anderson_ksamp(
            [data_1, data_2], method=sts.PermutationMethod(n_resamples=100)
        )
    except ValueError:
        return 1.0
    return _res.pvalue


# Vectorised wrappers around the scalar p-value helpers. The gufunc signature
# "(n),(n)->()" makes each call consume the last axis of both inputs (which
# must share the same length n) and return one p-value per leading index.
ks_test_vec = np.vectorize(ks_pval, signature="(n),(n)->()")
cvm_test_vec = np.vectorize(cvm_2samp, signature="(n),(n)->()")
anderson_test_vec = np.vectorize(anderson_pval, signature="(n),(n)->()")


def mannwhitney(data_1, data_2):
    """Return Mann-Whitney U test p-values, with the samples laid out along axis 1."""
    result = sts.mannwhitneyu(data_1, data_2, axis=1)
    return result.pvalue


def epps_singleton(data_1, data_2):
    # print(data_1.shape, data_2.shape)
    try:
        _out = sts.epps_singleton_2samp(data_1, data_2, axis=1).pvalue
    except np.linalg.LinAlgError:
        _out = np.ones(data_1.shape[0])
    return _out


def test_all_times(data, ens_ids, test_fcn):
    """Run a two-sample test between experiments 0 and 1 at every time.

    Parameters
    ----------
    data : `xarray.DataArray`
        Field with `exp`, `ens` and `time` dimensions.
    ens_ids : sequence
        ``ens_ids[0]`` and ``ens_ids[1]`` are the positional `ens` indices
        selected from experiment 0 and experiment 1 respectively.
    test_fcn : callable
        Called as ``test_fcn(sample_0.T, sample_1.T)``; its result must fit a
        one-dimensional array along `time`.
        NOTE(review): the transpose presumably puts time first and ensemble
        last - confirm against the dimension order of the input files.

    Returns
    -------
    test_output : `xarray.DataArray` or None
        p-values on the `time` coordinate of `data`. If building the array
        raises ``ValueError``, the message is printed and None is returned.

    """
    sample_0 = data.isel(exp=0, ens=ens_ids[0])
    sample_1 = data.isel(exp=1, ens=ens_ids[1])
    pvalues = test_fcn(sample_0.T, sample_1.T)

    try:
        return xr.DataArray(data=pvalues, dims=("time",), coords={"time": data.time})
    except ValueError as _err:
        print(_err)
        return None


def bootstrap_test(ens_ids, data, test_fcn):
    """Apply `test_all_times` to every variable of `data` for one bootstrap draw.

    Parameters
    ----------
    ens_ids : sequence
        Pair of ensemble-index selections, forwarded to `test_all_times`.
    data : `xarray.Dataset`
        Dataset whose variables are tested one by one.
    test_fcn : callable
        Two-sample test function, forwarded to `test_all_times`.

    Returns
    -------
    `xarray.Dataset`
        One p-value array per variable of `data`.

    """
    # Dataset.map is the current name of the legacy Dataset.apply alias and is
    # what the rest of this notebook already uses.
    return data.map(test_all_times, ens_ids=ens_ids, test_fcn=test_fcn)


def convert_to_array(pvals_in):
    """Convert a p-value Dataset to an array of shape [n_iter, n_vars, n_times].

    Parameters
    ----------
    pvals_in : xarray.Dataset
        p-values for each output field at each bootstrap iteration and time

    Returns
    -------
    pvals : numpy.ndarray
        p-value array with shape
        [Num bootstrap iteration, Num output field, Num times]

    """
    # to_array() stacks the variables along a new leading axis; swapping the
    # first two axes moves that variable axis to the second position.
    stacked = pvals_in.to_array().values
    return np.swapaxes(stacked, 0, 1)

In [None]:
%%time
scratch = Path("/lcrc/group/e3sm/ac.mkelleher/scratch/chrys")
# Alternative data locations, kept for reference:
# scratch = Path("/home/mikek/Code/2025-09-16.F2010.ne30pg2_r05_oECv3_aavgs")
# scratch = Path("/home/mikek/Code/detclim_data")
# in_dirs = sorted(scratch.glob("*"))
in_dirs = [
    # Path(scratch, "20230321.F2010.ne4_oQU240.dtcl_pertlim_1e-10_n0120/run"),
    # Path(scratch, "20230721.F2010.ne4_oQU240.dtcl_clubb_c1_2p520000_n0120/run"),
    Path(scratch, "2025-09-16.F2010.ne30pg2_r05_oECv3_control", "run"),
    Path(scratch, "2025-09-16.F2010.ne30pg2_r05_oECv3_clubb_c1_2p520000", "run"),
]

# Each run directory holds several "*eam.h0.aavg.nc" files; they are stacked
# along a new "ens" dimension and loaded into memory.
_ds_ctl = xr.open_mfdataset(
    sorted(in_dirs[0].glob("*eam.h0.aavg.nc")), combine="nested", concat_dim="ens"
).load()

_ds_exp = xr.open_mfdataset(
    sorted(in_dirs[1].glob("*eam.h0.aavg.nc")), combine="nested", concat_dim="ens"
).load()

# Index 0 along "exp" is the first directory (control), index 1 the second.
_ds_all = xr.concat([_ds_ctl, _ds_exp], dim="exp")

# Read the list of variables to test; the context manager closes the file
# handle instead of leaving it to the garbage collector.
with open("../new_vars.json", "r", encoding="utf-8") as _var_file:
    dvars = json.load(_var_file)["default"]

_ds_all_mean = _ds_all[dvars].map(rolling_mean_data, period_len=12)
_emin = _ds_all_mean.ens.values.min()
_emax = _ds_all_mean.ens.values.max()

In [None]:
# Labels for the "exp" dimension, in the same order as the loaded directories.
# The second label now matches the loaded run directory
# ("..._clubb_c1_2p520000"); it previously read "clubb_c1_2p51000".
case_names = ["ctl", "clubb_c1_2p520000"]
# case_names = ["ctl", "ctl"]
_ds_all["exp"] = case_names
_ds_all.coords

In [None]:
ninst = 60  # number of bootstrap iterations
ens_size = 29  # ensemble members drawn per case in each iteration
RANDOM_SEED = 42

# When both cases carry the same name the comparison is a case against itself,
# so the two sub-ensembles are drawn without sharing any member.
unique = case_names[0] == case_names[1]

# Seed the generator so that Restart & Run All reproduces the same draws.
random.seed(RANDOM_SEED)
ens_sel = [
    randomise_new(_emin, _emax, ens_size=ens_size, ncases=2, uniq=unique)
    for _ in range(ninst)
]

In [None]:
%%time
# Kolmogorov-Smirnov test for every bootstrap draw, spread over 16 worker
# processes; the per-draw results are stacked along a new "iter" dimension.
ks_bootstrap_part = partial(
    bootstrap_test, test_fcn=ks_test_vec, data=_ds_all_mean[dvars]
)
with mp.Pool(16) as pool:
    _ks_draws = pool.map(ks_bootstrap_part, ens_sel)
pvals_out_ks = xr.concat(_ks_draws, dim="iter")

In [None]:
%%time
# Epps-Singleton test for every bootstrap draw, spread over 16 worker
# processes; the per-draw results are stacked along a new "iter" dimension.
es_bootstrap_part = partial(
    bootstrap_test, test_fcn=epps_singleton, data=_ds_all_mean[dvars]
)
with mp.Pool(16) as pool:
    _es_draws = pool.map(es_bootstrap_part, ens_sel)
pvals_out_es = xr.concat(_es_draws, dim="iter")

In [None]:
%%time
# Mann-Whitney U test for every bootstrap draw, spread over 16 worker
# processes; the per-draw results are stacked along a new "iter" dimension.
mw_bootstrap_part = partial(
    bootstrap_test, test_fcn=mannwhitney, data=_ds_all_mean[dvars]
)
with mp.Pool(16) as pool:
    _mw_draws = pool.map(mw_bootstrap_part, ens_sel)
pvals_out_mw = xr.concat(_mw_draws, dim="iter")

%%time
anderson_bootstrap_part = partial(
    bootstrap_test, data=_ds_all_mean[dvars], test_fcn=anderson_test_vec
)
with mp.Pool(16) as pool:
    pvals_out_anderson_100 = xr.concat(
        pool.map(anderson_bootstrap_part, ens_sel), dim="iter"
    )

In [None]:
%%time
# Cramer-von Mises test for every bootstrap draw, spread over 16 worker
# processes; the per-draw results are stacked along a new "iter" dimension.
cvm_bootstrap_part = partial(
    bootstrap_test, test_fcn=cvm_test_vec, data=_ds_all_mean[dvars]
)
with mp.Pool(16) as pool:
    _cvm_draws = pool.map(cvm_bootstrap_part, ens_sel)
pvals_out_cvm = xr.concat(_cvm_draws, dim="iter")

%%time
andr_bootstrap_part = partial(
    bootstrap_test, data=_ds_all_mean[dvars], test_fcn=anderson_test_vec
)
with mp.Pool(16) as pool:
    pvals_out_andr = xr.concat(
        pool.map(andr_bootstrap_part, ens_sel), dim="iter"
    )

In [None]:
# Sanity check on the converted array: expected order is (n_iter, n_vars, n_times).
ks_pvals = convert_to_array(pvals_out_ks)
print(np.shape(ks_pvals))

In [None]:
# p-value arrays per test, each with shape (n_iter, n_vars, n_times).
pvals_all = {
    "ks": convert_to_array(pvals_out_ks),
    "cvm": convert_to_array(pvals_out_cvm),
    "mw": convert_to_array(pvals_out_mw),
    # "es": convert_to_array(pvals_out_es),
    # "andr": convert_to_array(pvals_out_andr),
}

# One panel per test, so enabling another entry above does not run out of
# axes; squeeze=False keeps a 2-D axes array even for a single panel.
fig, axis = plt.subplots(
    len(pvals_all), 1, figsize=(12, 10), sharey=False, squeeze=False
)
axis = axis.flatten()
for idx, pvals_out in enumerate(pvals_all):
    # Last time step, arranged as (n_vars, n_iter) and sorted over variables
    # within each iteration. np.sort returns a copy: an in-place `.sort()`
    # here would act on a view and reorder the data held in `pvals_all`,
    # breaking the variable order for every later cell.
    pvals = np.sort(pvals_all[pvals_out][:, :, -1].T, axis=0)
    _ = axis[idx].semilogy(pvals, color="grey", lw=0.5)
    _ = axis[idx].semilogy(np.median(pvals, axis=1), color="k")
    _ = axis[idx].axhline(0.05, ls="--", color="green")
    axis[idx].set_title(pvals_out)
plt.tight_layout()

Methods
-------
- bonferroni : one-step correction
- sidak : one-step correction
- holm-sidak : step down method using Sidak adjustments
- holm : step-down method using Bonferroni adjustments
- simes-hochberg : step-up method (independent)
- hommel : closed method based on Simes tests (non-negative)
- fdr_bh : Benjamini/Hochberg (non-negative)
- fdr_by : Benjamini/Yekutieli (negative)
- fdr_tsbh : two stage fdr correction (non-negative)
- fdr_tsbky : two stage fdr correction (non-negative)


In [None]:
# Correction methods compared in each panel ("holm" is left out of this list).
methods = [
    "fdr_bh",
    "fdr_by",
    "bonferroni",
    "sidak",
    "holm-sidak",
    "simes-hochberg",
    "hommel",
    "fdr_tsbh",
    "fdr_tsbky",
]

fig, axis = plt.subplots(3, 2, figsize=(15, 15), sharey=False)
axis = axis.flatten()
for idx, pvals_out in enumerate(pvals_all):
    # Last time step as (n_vars, n_iter), sorted over variables within each
    # iteration. np.sort returns a copy, so `pvals_all` itself is not
    # reordered (an in-place `.sort()` would modify it through the view).
    pvals = np.sort(pvals_all[pvals_out][:, :, -1].T, axis=0)
    # Median uncorrected curve for reference.
    _ = axis[idx].semilogy(np.median(pvals, axis=1), color="k")
    for _method in methods:
        # Correct across variables separately for every iteration, giving an
        # array of shape (n_iter, n_vars).
        _pval_cr = np.array(
            [
                smm.multipletests(
                    pvals=pvals[:, kdx],
                    alpha=0.05,
                    method=_method,
                    is_sorted=False,
                )[1]
                for kdx in range(pvals.shape[1])
            ]
        )
        _ = axis[idx].semilogy(np.median(_pval_cr, axis=0), label=_method)
    _ = axis[idx].axhline(0.05, ls="--", color="green")
    axis[idx].set_title(pvals_out)
    axis[idx].grid(visible=True, ls="--")
    # Legend on the panel that holds the labelled lines; a bare plt.legend()
    # targets the current axes, which is the last (empty) panel of the grid.
    axis[idx].legend()

# Hide the panels of the 3x2 grid that received no data.
for _spare in axis[len(pvals_all):]:
    _spare.set_visible(False)

In [None]:
# Number of variables rejected at the 5% level in each bootstrap iteration,
# using the raw p-values of the last time step.
nreject = {
    mode: [
        (pvals_all[mode][i, :, -1] < 0.05).sum()
        for i in range(pvals_all[mode].shape[0])
    ]
    for mode in pvals_all
}

# Same count after a Benjamini/Hochberg FDR correction. The family of tests
# being corrected is the set of variables within one iteration, matching the
# uncorrected count above. Correcting across iterations for a single variable
# (indexing `[:, kdx, -1]`) would instead adjust each variable against its own
# bootstrap replicates.
nreject_cr = {}
for mode in pvals_all:
    _method = "fdr_bh"
    # Shape (n_iter, n_vars): one corrected p-value vector per iteration.
    _pval_cr = np.array(
        [
            smm.multipletests(
                pvals=pvals_all[mode][i, :, -1],
                alpha=0.05,
                method=_method,
                is_sorted=False,
            )[1]
            for i in range(pvals_all[mode].shape[0])
        ]
    )
    nreject_cr[mode] = [
        (_pval_cr[i, :] < 0.05).sum() for i in range(pvals_all[mode].shape[0])
    ]

In [None]:
# Histograms of the uncorrected rejection counts, one panel per test.
bins = np.arange(27, 45, 1)
fig, axes = plt.subplots(2, 2, figsize=(12, 5), sharex=True, sharey=True)
axes = axes.flatten()
for idx, (mode, _counts) in enumerate(nreject.items()):
    _panel = axes[idx]
    _panel.hist(_counts, bins=bins, edgecolor="k")
    _panel.set_title(mode)
plt.tight_layout()

In [None]:
# Box plots of rejection counts per test: uncorrected (left) against
# FDR-BH corrected (right).
_rejdf = pd.DataFrame(nreject).assign(Mode="Uncorrected")
_dfcorr = pd.DataFrame(nreject_cr).assign(Mode="FDR-BH")
nreject_df = pd.concat([_rejdf, _dfcorr])
fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharex=True)
for _frame, _panel in zip((_rejdf, _dfcorr), axes):
    sns.boxplot(_frame, ax=_panel, orient="h")

In [None]:
# Histograms of the corrected rejection counts, one panel per test.
plt.figure(figsize=(12, 5))
bins = np.arange(27, 45, 1)  # not used below: the histograms use 15 automatic bins
for idx, mode in enumerate(nreject):
    _counts_cr = nreject_cr[mode]
    plt.subplot(2, 2, idx + 1)
    plt.hist(_counts_cr, bins=15, edgecolor="k")
    plt.title(mode)
plt.tight_layout()

In [None]:
# Display the corrected rejection counts for the KS test, one entry per iteration.
nreject_cr["ks"]

In [None]:
# 5th / 50th / 95th percentile of the uncorrected rejection counts per test.
for mode, _counts in nreject.items():
    _p05, _p50, _p95 = np.percentile(_counts, [5, 50, 95])
    print(f"{mode:9s}: {_p05:.2f} - {_p50:.2f} - {_p95:.2f}")

In [None]:
def to_dataarray(pvals, data_vars, times, test_name):
    """Wrap a p-value array in a labelled `xarray.DataArray`.

    Parameters
    ----------
    pvals : array_like
        p-values with axes ordered (iter, vars, time).
    data_vars : sequence
        Labels for the `vars` axis.
    times : array_like
        Labels for the `time` axis.
    test_name : str
        Name of the statistical test, used in the descriptive attributes.

    Returns
    -------
    `xarray.DataArray`
        Array with dims ("iter", "vars", "time"); `iter` is labelled
        0 .. n_iter - 1.

    """
    label = f"{test_name}_pvalue"
    attributes = {
        "units": "",
        "desc": f"2-sample {test_name} p-value",
        "long_name": label,
        "short_name": label,
    }
    coordinates = {
        "iter": np.arange(pvals.shape[0]),
        "vars": data_vars,
        "time": times,
    }
    return xr.DataArray(
        data=pvals,
        coords=coordinates,
        dims=("iter", "vars", "time"),
        attrs=attributes,
    )

In [None]:
ds_out = {}
for _test in pvals_all:
    ds_out[f"{_test}_pval"] = to_dataarray(
        pvals_all[_test], dvars, pvals_out_ks.time, _test
    )
xr.Dataset(ds_out)

In [None]:
# KS against Cramer-von Mises p-values at the last time step, pooled over all
# iterations and variables.
_ks_last = pvals_all["ks"][:, :, -1].flatten()
_cvm_last = pvals_all["cvm"][:, :, -1].flatten()
plt.loglog(_ks_last, _cvm_last, ".", alpha=0.5)

In [None]:
# Mann-Whitney U test on the un-smoothed "T" field between the two
# experiments, with samples taken along axis 0.
_t_exp0 = _ds_all["T"].isel(exp=0).values
_t_exp1 = _ds_all["T"].isel(exp=1).values
mwu = sts.mannwhitneyu(_t_exp0, _t_exp1, axis=0)

In [None]:
# Epps-Singleton p-values on the un-smoothed "T" field between the two
# experiments, with samples taken along axis 0.
_t_first = _ds_all["T"].isel(exp=0).values
_t_second = _ds_all["T"].isel(exp=1).values
esp = sts.epps_singleton_2samp(_t_first, _t_second, axis=0).pvalue