In [None]:
import xarray as xr
import matplotlib.pyplot as plt
from cartopy import crs as ccr
from pathlib import Path
import scipy.stats as sts
import numpy as np
plt.style.use("ggplot")

In [None]:
# Run configuration: the case name, the number of ensemble instances, and
# the directory holding that case's `run/aavg/combined` output files.
ninst = 8  # ensemble instances; file names carry instance numbers 1..ninst
case = "20221128.F2010.ne4_oQU240.dtcl_control"
scratch = Path("/lcrc/group/e3sm/ac.mkelleher/scratch/chrys/")
case_dir = scratch / case / "run" / "aavg" / "combined"

In [None]:
# Map each ensemble instance number (1..ninst) to the sorted list of `.nc`
# files in case_dir named `<case>.eam_NNNN*`, where NNNN is the instance
# number zero-padded to four digits.
files = dict(
    (inst_id, sorted(case_dir.glob(f"{case}.eam_{inst_id:04d}*.nc")))
    for inst_id in range(1, ninst + 1)
)

In [None]:
# Open one file per ensemble instance and stack the datasets along a new
# "ens" dimension (ens index 0 corresponds to instance 1).
# Only the first sorted match per instance is used; presumably the
# `combined` directory holds a single file per instance -- TODO confirm.
inst_datasets = []
for inst, inst_files in files.items():
    if not inst_files:
        # Fail with a clear message instead of an IndexError on `[0]`.
        raise FileNotFoundError(
            f"No files matching {case}.eam_{inst:04d}*.nc in {case_dir}"
        )
    # decode_times=False keeps the time coordinate as raw numbers rather
    # than converting it to datetimes.
    inst_datasets.append(xr.open_dataset(inst_files[0], decode_times=False))
ens_data = xr.concat(inst_datasets, dim="ens")

In [None]:
# Ensemble diagnostics for one variable, each plotted against time:
# anomalies from the ensemble mean, the ensemble mean, and the ensemble
# standard deviation. The bottom-right panel is filled by the K-S test below.
test_var = "Q"
fig, axes = plt.subplots(2, 2, figsize=(10, 8))

# Select the variable once and reduce only it, rather than reducing and
# subtracting every variable in the Dataset and then picking one out.
test_da = ens_data[test_var]
test_mean = test_da.mean(dim="ens")

(test_da - test_mean).plot.line(x="time", ax=axes[0, 0])
axes[0, 0].set_title(f"{test_var} ensemble spread")
test_mean.plot(ax=axes[0, 1])
axes[0, 1].set_title(f"{test_var} mean")
test_da.std(dim="ens").plot(ax=axes[1, 0])
axes[1, 0].set_title(f"{test_var} std dev")

# Two-sample Kolmogorov-Smirnov test at every time step, comparing the first
# half of the ensemble members against the second half. Time steps with
# p-value below `ks_alpha` are highlighted on the p-value curve.
ks_alpha = 0.05  # significance level used to highlight time steps
ks_da = ens_data[test_var]
# Split point derived from the actual ensemble size. The previous hard-coded
# slice(0, 4) / slice(4, 9) only matched an 8-member ensemble.
n_half = ks_da.sizes["ens"] // 2
assert n_half >= 1, "K-S split needs at least 2 ensemble members"

ks_time = [
    sts.ks_2samp(
        ks_da.isel(time=_it, ens=slice(0, n_half)),
        ks_da.isel(time=_it, ens=slice(n_half, None)),
    )
    for _it in range(ks_da.sizes["time"])
]

ks_stat = np.array([_ks.statistic for _ks in ks_time])
ks_pval = np.array([_ks.pvalue for _ks in ks_time])
times = ens_data.time.values
significant = ks_pval < ks_alpha

ks_ax = axes[1, 1]
ks_ax.plot(times, ks_stat, label="Statistic")
ks_ax.plot(times, ks_pval, color="C1", label="P-value")
ks_ax.plot(times[significant], ks_pval[significant], "C1o", label=f"p < {ks_alpha}")
ks_ax.set_title(f"{test_var} K-S Test")
ks_ax.set_xlabel("time")
# Attach the legend to the K-S panel explicitly; plt.legend() targets
# whichever axes pyplot currently considers active.
ks_ax.legend()
for _ax in axes.flatten():
    _ax.grid(visible=True)
plt.tight_layout()