In [None]:
import xarray as xr
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
import scipy.stats as sts
import json
import random
from functools import partial
import multiprocessing as mp

from dask.distributed import Client

In [None]:
def ks_all_times(data, ens_ids):
    """Perform a 2-sample K-S test between two experiments at every time.

    Parameters
    ----------
    data : xarray.DataArray
        Data with ``exp``, ``ens`` and ``time`` dimensions. Once a single
        ``exp`` is selected, only ``ens`` and ``time`` may remain.
    ens_ids : sequence of two sequences of int
        Positional ``ens`` indices for experiment 0 (``ens_ids[0]``) and
        experiment 1 (``ens_ids[1]``), e.g. one draw from ``randomise_new``.
        The two index sets may have different lengths.

    Returns
    -------
    xarray.DataArray
        K-S test p-values along ``time`` (the test statistic is discarded).

    """
    # isel is positional: ens_ids index the ens axis, they are not labels.
    data_1 = data.isel(exp=0, ens=ens_ids[0])
    data_2 = data.isel(exp=1, ens=ens_ids[1])

    # Move ens to the last axis so np.vectorize uses it as the core (sample)
    # dimension and loops over time, whatever the input dimension order.
    data_1 = data_1.transpose(..., "ens")
    data_2 = data_2.transpose(..., "ens")

    # (n),(m): the two samples need not be the same size.
    ks_test = np.vectorize(sts.mstats.ks_2samp, signature="(n),(m)->(),()")
    _, ks_pval = ks_test(data_1, data_2)

    return xr.DataArray(
        data=ks_pval, dims=("time",), coords={"time": data.time}
    )
# Vectorised scipy.stats.mstats.ks_2samp: the test runs along the last axis of
# both inputs and loops over every leading axis. Two outputs per test:
# (statistic, pvalue).
ks_test_vec = np.vectorize(sts.mstats.ks_2samp, signature="(n),(n)->(),()")


def ks_vec(data_1, data_2):
    """Return K-S (statistic, pvalue) arrays for samples on the last axis.

    Both inputs must share the same last-axis length; leading axes are
    broadcast against each other.
    """
    statistic, pvalue = ks_test_vec(data_1, data_2)
    return statistic, pvalue

def randomise_new(ens_min, ens_max, ens_size, with_repl=False, ncases=2, rng=None):
    """Draw ``ncases`` random sets of ensemble-member indices.

    Parameters
    ----------
    ens_min, ens_max : int
        Inclusive bounds of the index range to draw from.
    ens_size : int
        Number of indices in each set. Must be strictly smaller than the
        number of indices in ``[ens_min, ens_max]``.
    with_repl : bool, optional
        If False (default) each set holds distinct indices (drawn without
        replacement); if True indices may repeat within a set.
    ncases : int, optional
        Number of independent sets to draw (default 2).
    rng : random.Random, optional
        Source of randomness. Defaults to the global ``random`` module, so
        ``random.seed`` controls reproducibility; pass a seeded
        ``random.Random`` instance for an isolated stream.

    Returns
    -------
    list of list of int
        ``ncases`` lists, each of length ``ens_size``.

    Raises
    ------
    ValueError
        If ``ens_size`` is not smaller than the size of the index range.
    """
    rng = random if rng is None else rng
    ens_idx = list(range(ens_min, ens_max + 1))
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(ens_idx) <= ens_size:
        raise ValueError("ENSEMBLE SIZE MUST BE SMALLER THAN ENSEMBLE RANGE")
    if with_repl:
        return [
            [rng.randint(ens_min, ens_max) for _ in range(ens_size)]
            for _ in range(ncases)
        ]
    return [rng.sample(ens_idx, ens_size) for _ in range(ncases)]


def rolling_mean_data(data, period_len=12, time_var="time"):
    """Return the trailing rolling mean of ``data`` along ``time_var``.

    Parameters
    ----------
    data : xarray.DataArray or xarray.Dataset
        Input data with a ``time_var`` dimension.
    period_len : int, optional
        Window length, in samples along ``time_var`` (default 12).
    time_var : str, optional
        Name of the dimension to roll over (default ``"time"``).

    Returns
    -------
    Same type as ``data``
        Rolling mean with NaN-containing steps along ``time_var`` dropped.
    """
    window = {time_var: period_len}
    smoothed = data.rolling(**window).mean()
    # With xarray's default min_periods the first (period_len - 1) steps are
    # NaN; dropna removes those (and any other step containing a NaN).
    return smoothed.dropna(time_var)


def ks_bootstrap(idx, data):
    """Compute K-S p-values along time for every variable in ``data``.

    ``idx`` comes first so ``partial(ks_bootstrap, data=...)`` can be mapped
    over a list of ensemble draws with ``Pool.map``.

    Parameters
    ----------
    idx : sequence of two sequences of int
        One ensemble draw: positional ``ens`` indices for experiment 0 and 1.
    data : xarray.Dataset
        Each data variable is passed to ``ks_all_times``.

    Returns
    -------
    xarray.Dataset
        Per-variable p-values along ``time``.
    """
    # Dataset.map is the supported spelling; Dataset.apply is a deprecated alias.
    return data.map(ks_all_times, ens_ids=idx)


def cvm_2samp(data_x, data_y):
    """Return the p-value of a two-sample Cramer-von Mises test.

    Parameters
    ----------
    data_x, data_y : array_like
        1-D samples to compare; they may have different lengths.

    Returns
    -------
    float
        ``pvalue`` from ``scipy.stats.cramervonmises_2samp``; the test
        statistic is discarded.
    """
    return sts.cramervonmises_2samp(data_x, data_y).pvalue


# Vectorised form: applies the test along the last axis of each input and
# loops over all leading axes. (n),(m) lets the two samples differ in size.
cvm_test_vec = np.vectorize(cvm_2samp, signature="(n),(m)->()")


def cvm_all_times(data_c, ens_ids):
    """Perform a 2-sample Cramer-von Mises test between two experiments at every time.

    Parameters
    ----------
    data_c : xarray.DataArray
        Data with ``exp``, ``ens`` and ``time`` dimensions. Once a single
        ``exp`` is selected, only ``ens`` and ``time`` may remain.
    ens_ids : sequence of two sequences of int
        Positional ``ens`` indices for experiment 0 (``ens_ids[0]``) and
        experiment 1 (``ens_ids[1]``).

    Returns
    -------
    xarray.DataArray
        Cramer-von Mises p-values along ``time``.
    """
    # isel is positional: ens_ids index the ens axis, they are not labels.
    # Put ens last so cvm_test_vec treats it as the sample axis and loops
    # over time, independent of the input dimension order.
    data_1 = data_c.isel(exp=0, ens=ens_ids[0]).transpose(..., "ens")
    data_2 = data_c.isel(exp=1, ens=ens_ids[1]).transpose(..., "ens")

    cvm_pval = cvm_test_vec(data_1, data_2)

    return xr.DataArray(
        data=cvm_pval, dims=("time",), coords={"time": data_c.time}
    )

def cvm_bootstrap(idx, data):
    """Compute Cramer-von Mises p-values along time for every variable in ``data``.

    Parameters
    ----------
    idx : sequence of two sequences of int
        One ensemble draw: positional ``ens`` indices for experiment 0 and 1.
    data : xarray.Dataset
        Each data variable is passed to ``cvm_all_times``.

    Returns
    -------
    xarray.Dataset
        Per-variable p-values along ``time``.
    """
    # Dataset.map is the supported spelling; Dataset.apply is a deprecated alias.
    return data.map(cvm_all_times, ens_ids=idx)

def anderson_pval(data_1, data_2):
    try:
        _res = sts.anderson_ksamp([data_1, data_2], method=sts.PermutationMethod(n_resamples=1000))
    except ValueError:
        return 1.
    return _res.pvalue

anderson_test_vec = np.vectorize(anderson_pval, signature="(n),(n)->()")

def anderson_all_times(data, ens_ids):
    """Perform a 2-sample Anderson-Darling test between two experiments at every time.

    Parameters
    ----------
    data : xarray.DataArray
        Data with ``exp``, ``ens`` and ``time`` dimensions. Once a single
        ``exp`` is selected, only ``ens`` and ``time`` may remain.
    ens_ids : sequence of two sequences of int
        Positional ``ens`` indices for experiment 0 (``ens_ids[0]``) and
        experiment 1 (``ens_ids[1]``).

    Returns
    -------
    xarray.DataArray
        Anderson-Darling permutation p-values along ``time``.
    """
    # isel is positional: ens_ids index the ens axis, they are not labels.
    # Put ens last so anderson_test_vec treats it as the sample axis and
    # loops over time, independent of the input dimension order.
    data_1 = data.isel(exp=0, ens=ens_ids[0]).transpose(..., "ens")
    data_2 = data.isel(exp=1, ens=ens_ids[1]).transpose(..., "ens")
    _pval = anderson_test_vec(data_1, data_2)

    return xr.DataArray(
        data=_pval, dims=("time",), coords={"time": data.time}
    )

    
def anderson_bootstrap(idx, data):
    """Compute Anderson-Darling p-values along time for every variable in ``data``.

    Parameters
    ----------
    idx : sequence of two sequences of int
        One ensemble draw: positional ``ens`` indices for experiment 0 and 1.
    data : xarray.Dataset
        Each data variable is passed to ``anderson_all_times``.

    Returns
    -------
    xarray.Dataset
        Per-variable p-values along ``time``.
    """
    # Dataset.map is the supported spelling; Dataset.apply is a deprecated alias.
    return data.map(anderson_all_times, ens_ids=idx)

In [None]:
%%time
scratch = Path("/home/mikek/Code/2025-09-16.F2010.ne30pg2_r05_oECv3_aavgs")
in_dirs = sorted(scratch.glob("*"))
_ds_ctl = xr.open_mfdataset(
    sorted(in_dirs[1].glob("*.nc")), combine="nested", concat_dim="ens"
)

_ds_exp = xr.open_mfdataset(
    sorted(in_dirs[0].glob("*.nc")), combine="nested", concat_dim="ens"
)

_ds_all = xr.concat([_ds_ctl, _ds_exp], dim="exp")
dvars = json.loads(
    open("../new_vars.json", "r", encoding="utf-8").read()
)["default"]

_ds_all_mean = _ds_all[dvars].map(rolling_mean_data, period_len=12).load()
_emin = _ds_all_mean.ens.values.min()
_emax = _ds_all_mean.ens.values.max()

In [None]:
# Seed the stdlib RNG that `randomise_new` draws from so the ensemble
# selections are reproducible across kernel restarts.
SEED = 42
random.seed(SEED)

ninst = 100  # number of bootstrap iterations
ens_size = 20  # ensemble members drawn per experiment in each iteration
# One entry per iteration: [indices for exp=0 (control), indices for exp=1].
ens_sel = [randomise_new(_emin, _emax, ens_size=ens_size, ncases=2) for _ in range(ninst)]

In [None]:
%%time
ks_bootsrap_part = partial(ks_bootstrap, data=_ds_all_mean[dvars])
with mp.Pool(16) as pool:
    pvals_out_ks = xr.concat(pool.map(ks_bootsrap_part, ens_sel), dim="iter")

In [None]:
%%time
anderson_bootstrap_part = partial(anderson_bootstrap, data=_ds_all_mean[dvars])
with mp.Pool(16) as pool:
    pvals_out_anderson = xr.concat(pool.map(anderson_bootstrap_part, ens_sel), dim="iter")


In [None]:
%%time
cvm_bootstrap_part = partial(cvm_bootstrap, data=_ds_all_mean[dvars])
with mp.Pool(16) as pool:
    pvals_out_cvm = xr.concat(pool.map(cvm_bootstrap_part, ens_sel), dim="iter")

In [None]:
# P-values at time index 2 of each bootstrap result, one array per test with
# rows = variables and columns = bootstrap iterations (presumably; each
# variable is assumed to have dims (iter, time) — confirm).
# NOTE(review): time index 2 is a hard-coded choice — confirm it is intended.
_pvals_by_test = {
    "ks": pvals_out_ks,
    "cvm": pvals_out_cvm,
    "anderson": pvals_out_anderson,
}
pvals_all = {}
for _test_name, _pvals_ds in _pvals_by_test.items():
    _at_time = _pvals_ds.isel(time=2)
    pvals_all[_test_name] = np.array(
        [_at_time[_var].values for _var in _at_time.data_vars]
    )

fig, axis = plt.subplots(1, 3, figsize=(12, 5), sharey=True)
axis[0].set_ylabel("p-value")
for ax, (test_name, pvals) in zip(axis, pvals_all.items()):
    # np.sort returns a sorted copy: each column (iteration) is ordered by
    # p-value while pvals_all keeps its per-variable row order intact.
    pvals_sorted = np.sort(pvals, axis=0)
    ax.semilogy(pvals_sorted, color="grey", lw=0.5)
    ax.semilogy(pvals_sorted.mean(axis=1), color="k")
    ax.axhline(0.05, ls="--", color="green")
    ax.set_title(test_name)
    ax.set_xlabel("variable rank (ascending p-value)")

In [None]:
# For each test: per bootstrap iteration (column), the number of variables
# (rows) whose p-value is below 0.05.
nreject = {
    mode: list((pvals < 0.05).sum(axis=0))
    for mode, pvals in pvals_all.items()
}

In [None]:
# Distribution over bootstrap iterations of the number of variables with
# p < 0.05, one panel per test.
fig_hist, axes_hist = plt.subplots(1, 3, figsize=(12, 5))
axes_hist[0].set_ylabel("bootstrap iterations")
for ax, (mode, counts) in zip(axes_hist, nreject.items()):
    ax.hist(counts, bins=15, edgecolor="k")
    ax.set_title(mode)
    ax.set_xlabel("variables with p < 0.05")