In [None]:
import xarray as xr
import matplotlib.pyplot as plt
from cartopy import crs as ccr
from pathlib import Path
import scipy.stats as sts
import numpy as np
import random
import json
img_fmt = "pdf"  # file format/extension used by every plt.savefig call in this notebook
plt.style.use("ggplot")  # global matplotlib style applied to all figures below

In [None]:
# Root of the scratch area; each case's output lives in <scratch>/<case dir>/run.
scratch = Path("/lcrc/group/e3sm/ac.mkelleher/scratch/chrys/")
# Case pairs previously compared with this notebook (swap in as needed):
# case_abbr = ["ctl_2mo", "1pct_2mo"]
# case_abbr = ["ctl", "new-2p5pct"]
# case_abbr = ["new-ctl", "gworo-10pct"]
# case_abbr = ["pert-1e-10", "pert-1e-14"]
# case_abbr = ["ctl", "ctl14"]
case_abbr = ["ctl", "clubb_c1-10p0pct"]
run_len = "1year"

# case_db.json maps run length -> case abbreviation -> case directory name;
# that directory name is also the prefix of the case's output file names.
with open("case_db.json", "r", encoding="utf-8") as _cdb:
    cases = json.load(_cdb)

# Fail early with a readable message instead of a bare KeyError further down.
_missing = [_case for _case in case_abbr if _case not in cases[run_len]]
if _missing:
    raise KeyError(f"Cases {_missing} not found under {run_len!r} in case_db.json")

case_dirs = {_case: Path(scratch, cases[run_len][_case], "run") for _case in case_abbr}
ninst = 120  # ensemble instances per case; instance files are numbered 1..ninst
REJECT_THR = 0.05  # significance level used for every K-S test below

# Seed the stdlib RNG so the random ensemble sub-sampling in later cells is
# reproducible under Restart & Run All.
RANDOM_SEED = 0
random.seed(RANDOM_SEED)

In [None]:
def _instance_files(case):
    """Return {instance number: sorted area-averaged EAM files} for one case.

    Instance numbers run 1..ninst and appear zero-padded to 4 digits in the
    file names, which are prefixed with the case's directory name.
    """
    prefix = cases[run_len][case]
    run_dir = case_dirs[case]
    return {
        inst: sorted(run_dir.glob(f"{prefix}.eam_{inst:04d}*aavg.nc"))
        for inst in range(1, ninst + 1)
    }


files = {_case: _instance_files(_case) for _case in case_abbr}

In [None]:
# Open every instance of each case (all of its files as one dataset, times left
# undecoded) and stack the instances along a new "ens" dimension.
# NOTE(review): xr.concat adds "ens" without coordinate labels, so later cells
# appear to address members purely by position (isel) — confirm if that changes.
ens_data = {
    _case: xr.concat(
        [
            xr.open_mfdataset(member_files, decode_times=False)
            for member_files in files[_case].values()
        ],
        dim="ens",
    )
    for _case in case_abbr
}

In [None]:
# Number of ensemble members drawn (without replacement) from each case for the
# single-draw diagnostics below.
N_TEST = 30

# Random test subset of each case, as positional indices along "ens".
ens_shuffle = {
    _case: random.sample(list(ens_data[_case].ens.values), N_TEST)
    for _case in case_abbr
}

# For each case pick one randomly chosen member that is NOT in its test subset;
# it serves as the reference ("left-out") series in the spread plots. The test
# subset is always drawn from the case's own members, so only the members
# outside it are candidates; an explicit error replaces the opaque failure when
# no member is left over.
ens_loo = {}
for _case in case_abbr:
    remaining = sorted(set(ens_data[_case].ens.values).difference(ens_shuffle[_case]))
    if not remaining:
        raise ValueError(
            f"Case {_case!r} has {ens_data[_case].sizes['ens']} members; need more "
            f"than N_TEST={N_TEST} to leave one out."
        )
    ens_loo[_case] = random.choice(remaining)

ens_loo

In [None]:
# (ens_data - ens_data.mean(dim="ens"))["U"].plot.line(x="time")
# Variable shown in the diagnostic plots below and the time window to plot.
test_var = "U"
tslice = slice(0, None)

# Each case restricted to its random test members (positional "ens" indices)
# and to the chosen time window.
data_a, data_b = [
    ens_data[case_abbr[_k]][test_var].isel(ens=ens_shuffle[case_abbr[_k]], time=tslice)
    for _k in (0, 1)
]
times = data_a.time.values

In [None]:
# Four-panel diagnostic for `test_var`, comparing the two cases' test subsets:
#   (0, 0) each test member's deviation from a reference series,
#   (0, 1) subset means plus the RMS difference between paired members,
#   (1, 0) subset standard deviations,
#   (1, 1) two-sample K-S statistic and p-value at every time step.
# True: deviations from each subset's own mean; False: from the left-out member.
plot_diff_mean = False
fig, axes = plt.subplots(2, 2, figsize=(10, 8))

# --- (0, 0) ensemble spread -------------------------------------------------
for _data, _case, _color in ((data_a, case_abbr[0], "C0"), (data_b, case_abbr[1], "C1")):
    if plot_diff_mean:
        # Deviation from the mean of the case's own test subset.
        _spread = _data - _data.mean(dim="ens")
        _line_kw = {}
    else:
        # Deviation from the case's left-out member (ens_loo): a randomly chosen
        # member of the same case that is not in the test subset.
        _spread = _data - ens_data[_case][test_var].isel(ens=ens_loo[_case])
        _line_kw = {"lw": 0.5}
    _spread.plot.line(
        x="time", ax=axes[0, 0], label=_case, color=_color, add_legend=False, **_line_kw
    )
axes[0, 0].set_title(f"{test_var} ensemble spread")

# --- (0, 1) subset means and RMS difference ---------------------------------
aline, = data_a.mean(dim="ens").plot(ax=axes[0, 1], label=case_abbr[0])
bline, = data_b.mean(dim="ens").plot(ax=axes[0, 1], label=case_abbr[1])

ax_diff = axes[0, 1].twinx()
# Root-mean-square difference across members. Members are paired presumably by
# position along "ens" (it carries no coordinate labels), an arbitrary pairing
# since both subsets are random draws. The mean must be taken before the square
# root: sqrt-then-mean would give the mean absolute difference, not an RMSD.
rmsd = np.sqrt(((data_a - data_b) ** 2).mean(dim="ens"))
diffline, = rmsd.plot(ax=ax_diff, label="RMSD", color="grey")

ax_diff.set_ylabel(f"{test_var} difference")
axes[0, 1].legend(handles=[aline, bline, diffline])
axes[0, 1].set_title(f"{test_var} mean")

# --- (1, 0) subset standard deviations --------------------------------------
data_a.std(dim="ens").plot(ax=axes[1, 0], label=case_abbr[0])
data_b.std(dim="ens").plot(ax=axes[1, 0], label=case_abbr[1])
axes[1, 0].set_title(f"{test_var} std dev")

# --- (1, 1) per-time-step two-sample K-S test -------------------------------
# Same computation as ks_all_times, which is only defined further down.
ks_time = [
    sts.ks_2samp(data_a.isel(time=_it).values, data_b.isel(time=_it).values)
    for _it in range(data_a.time.shape[0])
]
ks_stat = np.array([_ks.statistic for _ks in ks_time])
ks_pval = np.array([_ks.pvalue for _ks in ks_time])
rejected = ks_pval < REJECT_THR

ax_pval = axes[1, 1].twinx()
ks_line, = axes[1, 1].plot(times, ks_stat, label="Statistic", lw=1)
pv_line, = ax_pval.plot(times, ks_pval, color="C1", label="P-value", lw=1)
# Mark the time steps at which equality of the two distributions is rejected.
ax_pval.plot(times[rejected], ks_pval[rejected], "C1o", ms=2)
axes[1, 1].set_ylim([0, 1.0])
ax_pval.axhline(REJECT_THR, color="C1", ls="--", alpha=0.5)

axes[1, 1].set_title(f"{test_var} K-S Test")
axes[1, 1].legend(handles=[ks_line, pv_line])
axes[1, 1].set_ylabel("Test statistic", color=ks_line.get_color())
ax_pval.set_ylabel("Test p-value", color=pv_line.get_color())

for _ax in axes.flat:
    _ax.grid(visible=True)
fig.tight_layout()
fig.savefig(f"plt_{case_abbr[0]}x{case_abbr[1]}_ensemble_{test_var}_4panel.{img_fmt}")

In [None]:
# Single-panel view: each test member's deviation from its case's left-out
# member, shifted by the first test member's series (data_x[0]).
fig, axis = plt.subplots(1, 1, figsize=(8, 4))
# NOTE(review): data_a.shape[0] is the length of the "ens" axis (xr.concat puts
# "ens" first), so this bounds the *time* window by the number of test members.
# Confirm whether the full time axis (data_a.sizes["time"]) was intended.
time_slice = slice(0, data_a.shape[0])

legend_lines = []
for _data, _case, _color in ((data_a, case_abbr[0], "C0"), (data_b, case_abbr[1], "C1")):
    _loo_member = ens_data[_case][test_var].isel(ens=ens_loo[_case])
    _deviation = (_data - _loo_member).isel(time=time_slice)
    _lines = (_data[0] + _deviation).plot.line(
        x="time", ax=axis, label=_case, color=_color, lw=0.6, add_legend=False
    )
    # One representative line per case is enough for the legend.
    legend_lines.append(_lines[0])

axis.legend(legend_lines, case_abbr)
fig.savefig(f"plt_{case_abbr[0]}x{case_abbr[1]}_ensemble_{test_var}.{img_fmt}")

In [None]:
def ks_all_times(data_a, data_b, dim="time"):
    """Run a two-sample Kolmogorov-Smirnov test at every position along ``dim``.

    At each index ``i`` along ``dim``, the samples ``data_a.isel({dim: i}).values``
    and ``data_b.isel({dim: i}).values`` are compared with
    ``scipy.stats.ks_2samp``.

    Parameters
    ----------
    data_a, data_b : xarray.DataArray
        Samples to compare. They are presumably ensemble members along the
        remaining (1-D) dimension; confirm for multi-dimensional inputs. Both
        must have the same length along ``dim``.
    dim : str, optional
        Dimension to iterate over. Defaults to ``"time"``.

    Returns
    -------
    stat, pval : numpy.ndarray
        K-S statistic and p-value at each position along ``dim``. Both are
        empty arrays when ``dim`` has length zero.

    Raises
    ------
    ValueError
        If ``data_a`` and ``data_b`` differ in length along ``dim``.
    """
    n_steps = data_a.sizes[dim]
    # Without this check a shorter data_b raised IndexError and a longer one
    # was silently truncated to data_a's length.
    if data_b.sizes[dim] != n_steps:
        raise ValueError(
            f"data_a and data_b differ in length along {dim!r}: "
            f"{n_steps} vs {data_b.sizes[dim]}"
        )
    results = [
        sts.ks_2samp(data_a.isel({dim: _it}).values, data_b.isel({dim: _it}).values)
        for _it in range(n_steps)
    ]
    stat = np.array([_res.statistic for _res in results])
    pval = np.array([_res.pvalue for _res in results])
    return stat, pval

In [None]:
%%time
# Repeatedly draw random test subsets of both cases and K-S test every
# time-varying variable at every time step. Results: ks_stat / ks_pval with shape
# (niter, len(tested_vars), n_time).
with open("run_scripts/new_vars.json", "r", encoding="utf-8") as _vars_file:
    data_vars = sorted(json.load(_vars_file)["default"])
niter = 2
N_TEST = 30  # members drawn from each case per iteration
ks_stat = []
ks_pval = []
tested_vars = []  # variables with a "time" dimension, in row order of ks_stat / ks_pval

for i in range(niter):
    if (i % 10 == 0) or (i == niter - 1):
        print(i)
    # Iteration-local names so the single-draw ens_shuffle / ens_loo / data_a /
    # data_b / test_var used by the plotting cells above are not overwritten.
    iter_subset = {
        _case: random.sample(list(ens_data[_case].ens.values), N_TEST)
        for _case in case_abbr
    }
    # Members outside this iteration's subset; only used by the within-case
    # (null) comparison left commented out below.
    iter_rest = {
        _case: sorted(set(ens_data[_case].ens.values).difference(iter_subset[_case]))
        for _case in case_abbr
    }
    ks_stat_i = []
    ks_pval_i = []

    for var_name in data_vars:
        sub_a = ens_data[case_abbr[0]][var_name].isel(ens=iter_subset[case_abbr[0]])
        # Null-test alternative: compare case A's subset with the rest of case A.
        # sub_b = ens_data[case_abbr[0]][var_name].isel(ens=iter_rest[case_abbr[0]])
        sub_b = ens_data[case_abbr[1]][var_name].isel(ens=iter_subset[case_abbr[1]])
        # Only variables that vary in time are tested; others are skipped.
        if "time" not in sub_a.dims:
            continue
        _stat, _pval = ks_all_times(sub_a, sub_b)
        ks_stat_i.append(_stat)
        ks_pval_i.append(_pval)
        if i == 0:
            tested_vars.append(var_name)

    ks_stat.append(np.array(ks_stat_i))
    ks_pval.append(np.array(ks_pval_i))

ks_stat = np.array(ks_stat)
ks_pval = np.array(ks_pval)

In [None]:
# Number of variables rejected at each time step, one line per resampling
# iteration. ks_pval is (iteration, variable, time) as built by the resampling
# loop, so summing the rejection mask over axis 1 counts rejected variables.
# REJECT_THR comes from the configuration cell; it is no longer re-assigned
# here, where it would silently override the configured value.
fig, axis = plt.subplots(1, 1, figsize=(6, 3))
n_reject_per_iter = (ks_pval < REJECT_THR).sum(axis=1)  # (iteration, time)
axis.plot(n_reject_per_iter.T)
# Reference: REJECT_THR x number of variables, the count expected by chance if
# both cases came from the same distribution (treating variables as independent).
axis.axhline(REJECT_THR * ks_pval.shape[1], color="black", ls="--")
axis.set_title(f"Number of variables rejected at {(1 - REJECT_THR) * 100}% confidence")
axis.set_xlabel("Timestep")
axis.set_ylabel("N variables")

In [None]:
# Spread across resampling iterations of the number of variables rejected at
# each time step: the median and the `quantile` / (100 - quantile) percentiles,
# plus the chance-level count. REJECT_THR comes from the configuration cell.
fig, axis = plt.subplots(1, 1, figsize=(6, 3))
n_reject = (ks_pval < REJECT_THR).sum(axis=1)  # (iteration, time)
n_reject_median = np.median(n_reject, axis=0)
n_reject_mean = n_reject_median  # legacy name kept for compatibility; holds the MEDIAN

quantile = 10
n_reject_lq = np.percentile(n_reject, quantile, axis=0)
n_reject_uq = np.percentile(n_reject, 100 - quantile, axis=0)
n_reject_std = n_reject.std(axis=0)  # not plotted; kept for inspection

axis.plot(n_reject_median, color="black", lw=1.5, label="Median")
axis.plot(n_reject_lq, color="darkblue", lw=1.0, ls="-", label=f"{quantile}%")
axis.plot(n_reject_uq, color="darkred", lw=1.0, ls="-", label=f"{100 - quantile}%")
# Chance-level count. Uses REJECT_THR instead of a hard-coded 0.05 so the
# reference line follows the configured threshold.
axis.axhline(REJECT_THR * ks_pval.shape[1], color="#343", ls="-.")
axis.legend()
axis.set_title(f"Number of variables rejected at {(1 - REJECT_THR) * 100}% confidence")
axis.set_xlabel("Timestep")
axis.set_ylabel("N variables")
# Threshold encoded for the file name, e.g. 0.05 -> "0p05".
_reject = f"{REJECT_THR:.2f}".replace(".", "p")
fig.tight_layout()
fig.savefig(f"plt_nreject_{case_abbr[0]}-{case_abbr[1]}_a{_reject}_n{niter}.{img_fmt}")