In [None]:
from dask.distributed import Client
import dask
import numpy as np
import dask.array as da
from scipy import stats as sts
import matplotlib.pyplot as plt
import random

import xarray as xr
from pathlib import Path
import json

# Select matplotlib's "ggplot" style sheet for the figures drawn in this notebook.
plt.style.use("ggplot")

In [None]:
# Optional settings, currently disabled:
#   scheduler_file = "scheduler.json"
#   dask.config.config["distributed"]["dashboard"]["link"] = "{JUPYTERHUB_SERVICE_PREFIX}proxy/{host}:{port}/status"

# Start the dask.distributed client with eight workers, one thread apiece.
client = Client(threads_per_worker=1, n_workers=8)

# Keep the client as the cell's last expression so its summary is displayed.
client

In [None]:
# --- Tunable constants ------------------------------------------------------
# Highest ensemble instance number searched for when globbing per-instance files.
ninst = 120
# P-value threshold below which a K-S test counts as a rejection.
REJECT_THR = 0.05

# --- Cases ------------------------------------------------------------------
# Short label -> full case name of every available experiment.
cases = {
    "ctl": "20221130.F2010.ne4_oQU240.dtcl_control_n0030",
    "5pct": "20221205.F2010.ne4_oQU240.dtcl_zmconv_c0_0p00201_n0030",
    "10pct": "20221201.F2010.ne4_oQU240.dtcl_zmconv_c0_0p0022_n0030",
    "50pct": "20221206.F2010.ne4_oQU240.dtcl_zmconv_c0_0p0030_n0030",
}
# The two cases compared in this notebook: [reference, perturbed].
case_abbr = ["ctl", "5pct"]

# --- Paths ------------------------------------------------------------------
scratch = Path("/lcrc/group/e3sm/ac.mkelleher/scratch/chrys/")
# Each selected case keeps its output under <scratch>/<case name>/run.
case_dirs = {abbr: scratch / cases[abbr] / "run" for abbr in case_abbr}

In [None]:
# Locate the area-averaged EAM output of every ensemble instance:
# files[case][inst] is the sorted list of files matching that instance number.
files = {
    _case: {
        inst: sorted(case_dirs[_case].glob(f"{cases[_case]}.eam_{inst:04d}*aavg.nc"))
        for inst in range(1, ninst + 1)
    }
    for _case in case_abbr
}

# Open the first matching file of each instance (time values left undecoded)
# and stack the instances of a case along a new "ens" dimension.
ens_data = {}
for _case in case_abbr:
    members = []
    for inst, inst_files in files[_case].items():
        if not inst_files:
            # An empty glob would otherwise surface as a bare IndexError on
            # [0], with no hint about which instance or directory is at fault.
            raise FileNotFoundError(
                f"No '*aavg.nc' file found for instance {inst:04d} of case "
                f"{cases[_case]} in {case_dirs[_case]}"
            )
        members.append(xr.open_dataset(inst_files[0], decode_times=False))
    ens_data[_case] = xr.concat(members, dim="ens")

In [None]:
# Per case: every area-averaged EAM file found in the run directory, sorted.
files = {}
for _case in case_abbr:
    files[_case] = sorted(case_dirs[_case].glob(f"{cases[_case]}.eam*aavg.nc"))

# Open each file (decode_times is left at xarray's default here) and join the
# datasets of a case along a new "ens" dimension.
ens_data = {}
for _case in case_abbr:
    opened = []
    for _file in files[_case]:
        opened.append(xr.open_dataset(_file))
    ens_data[_case] = xr.concat(opened, dim="ens")

In [None]:
# Show the time coordinate of the last case in case_abbr. Indexing through
# case_abbr avoids depending on a loop variable left behind by another cell.
ens_data[case_abbr[-1]]["time"]

In [None]:
# @dask.delayed
def ks_rand_sel_xarray(data_1, data_2):
    """Run a two-sample K-S test for each position along the second axis.

    For every index ``i`` in ``range(data_1.shape[1])`` the slices
    ``data_1[:, i]`` and ``data_2[:, i]`` are compared with
    ``scipy.stats.ks_2samp(..., method="asymp")``.

    Parameters
    ----------
    data_1, data_2
        Two-dimensional, positionally indexable arrays. ``data_2`` must have
        at least as many entries along its second axis as ``data_1``.

    Returns
    -------
    dask array
        One row per tested index, holding what ``ks_2samp`` returned for it
        (callers read column 0 as the statistic and column 1 as the p-value).
    """
    n_columns = data_1.shape[1]
    ks_results = []
    for column in range(n_columns):
        sample_a = data_1[:, column]
        sample_b = data_2[:, column]
        ks_results.append(sts.ks_2samp(sample_a, sample_b, method="asymp"))
    return da.array(ks_results)

In [None]:
# Show variable "T" of the last case in case_abbr. Indexing through case_abbr
# avoids depending on a loop variable left behind by another cell.
ens_data[case_abbr[-1]]["T"]

In [None]:
%%time
random.seed(101_114)
futures = []
data_vars = sorted(json.load(open("run_scripts/new_vars.json"))["default"])
time_var = "time"
ens_var = "ens"
test_size = 30
n_iter = 5

idx_0 = [
    random.sample(
        list([_ for _ in range(ens_data[case_abbr[0]][ens_var].shape[0])]), test_size
    )
    for _ in range(n_iter)
]
idx_1 = [
    random.sample(
        list([_ for _ in range(ens_data[case_abbr[1]][ens_var].shape[0])]), test_size
    )
    for _ in range(n_iter)
]

for rse in range(n_iter):
    var_futures = []
    data_0 = ens_data[case_abbr[0]].isel(**{ens_var: idx_0[rse]})
    data_1 = ens_data[case_abbr[1]].isel(**{ens_var: idx_1[rse]})

    for test_var in data_vars:
        var_futures.append(
            client.submit(ks_rand_sel_xarray, data_0[test_var], data_1[test_var])
        )
    futures.append(var_futures)

results = da.array(dask.compute(*client.gather(futures)))
ks_stat = results[..., 0]
ks_pval = results[..., 1]

In [None]:
fig, axis = plt.subplots(1, 1, figsize=(6, 3))

# REJECT_THR comes from the configuration cell; it is deliberately not
# re-assigned here so that changing it there affects this figure as well.

# Number of variables whose p-value falls below the threshold (sum over axis 1,
# the variable axis of ks_pval).
n_reject = np.array((ks_pval < REJECT_THR).sum(axis=1))

# Summary statistics across axis 0 of n_reject.
n_reject_mean = np.median(n_reject, axis=0)  # NOTE: named "mean" but holds the median
quantile = 10
n_reject_lq = np.percentile(n_reject, quantile, axis=0)
n_reject_uq = np.percentile(n_reject, 100 - quantile, axis=0)
n_reject_std = n_reject.std(axis=0)  # not plotted in this cell

axis.plot(n_reject_mean, color="black", lw=1.5, label="Median")
axis.plot(n_reject_lq, color="darkblue", lw=1.0, ls="-", label=f"{quantile}%")
axis.plot(n_reject_uq, color="darkred", lw=1.0, ls="-", label=f"{100 - quantile}%")

# Reference line at REJECT_THR times the number of variables. This was a
# hard-coded 0.05, which would drift from the threshold used to count
# rejections if REJECT_THR were ever changed.
axis.axhline(REJECT_THR * ks_pval.shape[1], color="#343", ls="-.")
axis.legend()

axis.set_title(f"Number of variables rejected at {(1 - REJECT_THR) * 100}% confidence")
axis.set_xlabel("Timestep")
axis.set_ylabel("N variables")

# Threshold rendered with "p" in place of the decimal point (e.g. "0p05");
# not used within this cell.
_reject = f"{REJECT_THR:.2f}".replace(".", "p")
fig.tight_layout()

In [None]:
# Ensemble positions that the K-S tests actually used, laid out as
# (case, iteration, sample). Reusing idx_0 / idx_1 keeps the stored indices
# identical to the members that were tested; drawing new random samples here
# would record a selection that was never used.
random_index = np.array([idx_0, idx_1])

# Quick sanity check: mean of all selected positions.
random_index.mean()

In [None]:
# Coordinates shared by the statistic and p-value arrays. The iteration axis is
# sized from n_iter rather than a literal so it follows the driver cell.
out_coords = {
    "iter": np.arange(n_iter),
    "vars": data_vars,
    "time": ens_data[case_abbr[0]]["time"],
}
out_dims = ("iter", "vars", "time")

# K-S test statistic. The metadata on this array and on the p-value array
# below was previously swapped.
ks_stat_xr = xr.DataArray(
    np.array(ks_stat),
    coords=out_coords,
    dims=out_dims,
    attrs={
        "units": "",
        "desc": "2-sample K-S test statistic",
        "long_name": "kolmogorov_smirnov_test_statistic",
        "short_name": "ks_stat",
    },
)

# K-S test p-value.
ks_pval_xr = xr.DataArray(
    np.array(ks_pval),
    coords=out_coords,
    dims=out_dims,
    attrs={
        "units": "",
        "desc": "2-sample K-S test P-value",
        "long_name": "kolmogorov_smirnov_test_p_value",
        "short_name": "ks_pval",
    },
)

# Ensemble-member positions per case and iteration. Dimension names are passed
# explicitly instead of being inferred from the coords mapping, and the
# coordinate lengths are taken from the array itself.
rnd_idx = xr.DataArray(
    random_index,
    coords={
        "case": np.arange(random_index.shape[0]),
        "iter": out_coords["iter"],
        "index": np.arange(random_index.shape[2]),
    },
    dims=("case", "iter", "index"),
    attrs={
        "units": "",
        "desc": "Index of ensemble members for each case and iteration",
    },
)

ks_ds = xr.Dataset({"stat": ks_stat_xr, "pval": ks_pval_xr, "rnd_idx": rnd_idx})

In [None]:
# Write the K-S results dataset to "testout.nc" (path relative to the working directory).
ks_ds.to_netcdf("testout.nc")

In [None]:
# Run the ncdump command-line tool (flags -c and -h) on the file written above to inspect it.
!ncdump -ch testout.nc