In [None]:
import os
import random
import re
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as sts
import xarray as xr
from cartopy import crs as ccr
plt.style.use("ggplot")

In [None]:
# Location of the E3SM multi-instance run to analyse. The scratch root can be
# overridden with the E3SM_SCRATCH environment variable so the notebook can be
# run by users other than the original author; the default is unchanged.
scratch = Path(
    os.environ.get("E3SM_SCRATCH", "/lcrc/group/e3sm/ac.mkelleher/scratch/chrys/")
)

# Previously analysed cases, kept for reference; uncomment one to switch.
# case = "20221130.F2010.ne4_oQU240.dtcl_control_n0030"
# case = "20221201.F2010.ne4_oQU240.dtcl_zmconv_c0_0p0022_n0030"
# case = "20221205.F2010.ne4_oQU240.dtcl_zmconv_c0_0p00201_n0030"
# case = "20230321.F2010.ne4_oQU240.dtcl_pertlim_1e-10_n0120"
case = "20230322.F2010.ne4_oQU240.dtcl_pertlim_1e-14_n0120"

# History files are read from <scratch>/<case>/run. Some runs apparently keep
# them in a "combined" subdirectory instead; append it here if needed.
case_dir = Path(scratch, case, "run")

# Number of ensemble instances, parsed from the case name's "_nNNNN" suffix so
# it cannot drift out of sync when `case` is switched (e.g. the n0030 cases
# above would otherwise be searched for 120 instances).
ninst_match = re.search(r"_n(\d+)$", case)
if ninst_match is None:
    raise ValueError(
        f"Cannot infer the instance count from case name {case!r}; "
        "expected it to end with an '_nNNNN' suffix"
    )
ninst = int(ninst_match.group(1))

In [None]:
# Map each 1-based instance number to the sorted list of its area-averaged
# history files ("<case>.eam_NNNN*aavg.nc"). Sorting is lexical, which
# presumably follows the date order encoded in the file names.
files = {
    inst: sorted(case_dir.glob(f"{case}.eam_{inst:04d}*aavg.nc"))
    for inst in range(1, ninst + 1)
}

# Fail fast with a clear message. Without these checks, an instance with no
# files only surfaces later as an opaque open_mfdataset error. Instances with
# differing file counts would be misaligned in time when concatenated along
# "ens": NaN-padded or an alignment error, depending on xarray's concat `join`.
missing = [inst for inst, paths in files.items() if not paths]
if missing:
    raise FileNotFoundError(
        f"No '*aavg.nc' files found in {case_dir} for instance(s): {missing}"
    )
file_counts = {inst: len(paths) for inst, paths in files.items()}
if len(set(file_counts.values())) > 1:
    raise ValueError(
        f"Instances have differing numbers of history files: {file_counts}"
    )


In [None]:
# Open each instance's history files as a single Dataset (times left as raw
# numbers via decode_times=False). Then stack the instances along a new "ens"
# dimension, in the insertion order of the `files` dict.
ens_data = xr.concat(
    [
        xr.open_mfdataset(inst_paths, decode_times=False)
        for inst_paths in files.values()
    ],
    dim="ens",
)

In [None]:
# Show the concatenated ensemble Dataset as a quick sanity check of its dimensions and variables.
ens_data

In [None]:
# Draw two disjoint, equally sized sets of ensemble members. The K-S test
# below then checks whether both sets look like samples from the same
# distribution. The draw is seeded so the split, and every result derived
# from it, is reproducible; change SHUFFLE_SEED to test sensitivity to the split.
SHUFFLE_SEED = 42
N_PER_SET = 30

n_members = ens_data.sizes["ens"]
if 2 * N_PER_SET > n_members:
    raise ValueError(
        f"Need at least {2 * N_PER_SET} ensemble members for two sets of "
        f"{N_PER_SET}, but only {n_members} are loaded"
    )
# Sample positional indices (0 .. n_members-1): the sets are consumed by
# `isel`, which selects by position rather than by `ens` coordinate value.
ens_shuffle = random.Random(SHUFFLE_SEED).sample(range(n_members), 2 * N_PER_SET)
set_a = ens_shuffle[:N_PER_SET]
set_b = ens_shuffle[N_PER_SET:]

In [None]:
# Compare the two member sets for one variable in four panels:
#   - member anomalies from the full-ensemble mean
#   - the mean of each set
#   - the spread of each set
#   - a per-time-step two-sample Kolmogorov-Smirnov test of set A vs set B
test_var = "T"
ALPHA = 0.05  # significance level used to highlight K-S p-values
SET_COLORS = {"Set A": "C0", "Set B": "C1"}  # same colors in every panel

data_a = ens_data[test_var].isel(ens=set_a)
data_b = ens_data[test_var].isel(ens=set_b)
# Full-ensemble statistics, computed once and reused below.
ens_mean = ens_data[test_var].mean(dim="ens")
ens_std = ens_data[test_var].std(dim="ens")

fig, axes = plt.subplots(2, 2, figsize=(10, 8))

# Top left: each member's anomaly from the full-ensemble mean, colored by set.
(data_a - ens_mean).plot.line(
    x="time", ax=axes[0, 0], label="Set A", color=SET_COLORS["Set A"], add_legend=False
)
(data_b - ens_mean).plot.line(
    x="time", ax=axes[0, 0], label="Set B", color=SET_COLORS["Set B"], add_legend=False
)
# Every member line carries its set's label; keep a single legend entry per set.
handles, labels = axes[0, 0].get_legend_handles_labels()
unique_entries = dict(zip(labels, handles))
axes[0, 0].legend(unique_entries.values(), unique_entries.keys())
axes[0, 0].set_title(f"{test_var} ensemble spread")

# Top right: mean of the full ensemble and of each set.
ens_mean.plot(ax=axes[0, 1], label="Overall", color="k")
data_a.mean(dim="ens").plot(ax=axes[0, 1], label="Set A", color=SET_COLORS["Set A"])
data_b.mean(dim="ens").plot(ax=axes[0, 1], label="Set B", color=SET_COLORS["Set B"])
axes[0, 1].legend()
axes[0, 1].set_title(f"{test_var} mean")

# Bottom left: standard deviation across members, full ensemble and per set.
ens_std.plot(ax=axes[1, 0], label="Overall", color="k")
data_a.std(dim="ens").plot(ax=axes[1, 0], label="Set A", color=SET_COLORS["Set A"])
data_b.std(dim="ens").plot(ax=axes[1, 0], label="Set B", color=SET_COLORS["Set B"])
axes[1, 0].legend()
axes[1, 0].set_title(f"{test_var} std dev")

# Bottom right: two-sample K-S test of set A vs set B at each time step.
# Materialize each set once with time as the leading axis. If the data are
# dask-backed (as open_mfdataset produces), this avoids a separate
# computation for every time step inside the loop.
vals_a = data_a.transpose("time", ...).values
vals_b = data_b.transpose("time", ...).values
ks_time = [sts.ks_2samp(row_a, row_b) for row_a, row_b in zip(vals_a, vals_b)]

ks_stat = np.array([_ks.statistic for _ks in ks_time])
ks_pval = np.array([_ks.pvalue for _ks in ks_time])
times = ens_data.time.values
# Note: with many time steps, roughly a fraction ALPHA of them are expected
# below ALPHA by chance even if both sets share one distribution.
significant = ks_pval < ALPHA
axes[1, 1].plot(times, ks_stat, label="Statistic")
axes[1, 1].plot(times, ks_pval, color="C1", label="P-value")
axes[1, 1].plot(times[significant], ks_pval[significant], "C1o")
axes[1, 1].set_title(f"{test_var} K-S Test")
# Attach the legend to this panel explicitly rather than whichever axes
# pyplot currently considers "current".
axes[1, 1].legend()

for _ax in axes.flat:
    _ax.grid(visible=True)
fig.tight_layout()