In [None]:
import matplotlib.pyplot as plt
import xarray as xr
from pathlib import Path
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from statsmodels.stats import multitest as smm

# Select matplotlib's "ggplot" style sheet for every figure created below.
plt.style.use("ggplot")

In [None]:
def fmt_case(case):
    """Build a human-readable label from a raw case identifier.

    The three control names map to fixed labels.  Names containing one of the
    markers ``old``, ``new``, ``c1``, ``gworo`` or ``pl`` (checked in that
    order) are rendered as a prefix plus the number encoded in the name, where
    ``p`` stands for the decimal point and a ``pct`` suffix is dropped
    (``"2p5pct"`` -> ``2.5``).  Any other name is parsed as a bare number and
    shown as a percentage; if that parse fails the name is returned unchanged.

    Parameters
    ----------
    case : str
        Raw case identifier, e.g. ``"ctl"``, ``"new-2p5pct"``.

    Returns
    -------
    str
        Label for plot titles and legends.

    Raises
    ------
    ValueError, IndexError
        From the marker branches when the encoded number is malformed or the
        expected separator is missing; only the final fallback is guarded.
    """

    def _number(token, drop_2mo=True):
        # "p" is the decimal separator in case names; "pct" is a unit suffix.
        token = token.replace("pct", "").replace("p", ".")
        if drop_2mo:
            token = token.replace("-2mo", "")
        return float(token)

    # Exact-match control names come first.
    for name, label in (
        ("ctl", "Control"),
        ("ctl-2mo", "Control [2 mo]"),
        ("new-ctl", "Control [new]"),
    ):
        if case == name:
            return label

    # "old-<num>pct" / "new-<num>pct": the "-2mo" suffix is NOT stripped here.
    for prefix in ("old", "new"):
        if prefix in case:
            value = _number(case.replace(f"{prefix}-", ""), drop_2mo=False)
            return f"{prefix} {value:.1f}%"

    # "<name>-<num>pct[...]": the number is the second dash-separated field.
    for marker, prefix in (("c1", "clubb_c1"), ("gworo", "GW orog")):
        if marker in case:
            value = _number(case.split("-")[1])
            return f"{prefix} {value:.1f}%"

    # "<name>_<num>": the number is the second underscore-separated field.
    if "pl" in case:
        value = _number(case.split("_")[1])
        return f"PertLim {value:.1e}"

    # Fallback: a bare percentage, or the untouched name if it is not numeric.
    try:
        value = _number(case)
    except ValueError:
        return case
    return f"{value:.1f}%"

In [None]:
# ---- Analysis configuration ------------------------------------------------
REJECT_THR = 0.05
all_files = sorted(Path("./bootstrap_data").glob("bootstrap_output_*new*n500.nc"))
ctl_key = ("Control", "Control")

# Which set of runs to analyse: "1year" or "1month".
run_len = "1year"
# run_len = "1month"

# Each entry is a (case_a, case_b) pair; comment entries in or out to change
# which comparisons are loaded.
_cases_1month = [
    ("new-ctl", "new-ctl"),
    # ("ctl", "effgw_oro-0p5pct"),
    # ("ctl", "effgw_oro-1pct"),
    # ("ctl", "effgw_oro-10pct"),
    ("new-ctl", "new-2p5pct"),
    ("new-ctl", "new-5pct"),
    ("new-ctl", "new-10pct"),
]
_cases_other = [
    ("ctl", "ctl"),
    # ("ctl", "ctl-next"),
    # ("ctl", "effgw_oro-0p5pct"),
    ("ctl", "effgw_oro-1p0pct"),
    ("ctl", "effgw_oro-10p0pct"),
    ("ctl", "effgw_oro-20p0pct"),
    # ("ctl", "effgw_oro-30p0pct"),
    # ("ctl", "effgw_oro-40p0pct"),
    ("ctl", "effgw_oro-50p0pct"),
    # ("ctl", "clubb_c1-1p0pct"),
    # ("ctl", "clubb_c1-3p0pct"),
    # ("ctl", "clubb_c1-5p0pct"),
    # ("ctl", "clubb_c1-10p0pct"),
]
# Anything other than "1month" selects the second list.
cases = _cases_1month if run_len == "1month" else _cases_other

# One bootstrap output file per case pair, in the same order as `cases`.
_template = 'bootstrap_data/bootstrap_output.{}.{}_{}_n500.nc'
files = [Path(_template.format(run_len, _a, _b)) for _a, _b in cases]

# Report which of the expected files are actually present.
for _file in files:
    print(_file, _file.exists())

In [None]:
# Per-case results, keyed by the raw (case_a, case_b) tuples listed in `cases`.
ks_pval_cr = {}   # FDR-adjusted p-values
n_reject = {}     # rejected-variable counts from the raw p-values
n_reject_cr = {}  # rejected-variable counts from the adjusted p-values
test = {}         # per time step: iterations whose raw count exceeds the cutoff
test_cr = {}      # per time step: iterations whose adjusted count exceeds the cutoff


def _fdr_adjust(pvals):
    """Return FDR-adjusted p-values, correcting each slice along axis 0 on its own."""
    return np.array([
        smm.fdrcorrection(
            _p.flatten(), alpha=REJECT_THR, method="indep", is_sorted=False
        )[1].reshape(_p.shape)
        for _p in pvals
    ])


for _ix, _file in enumerate(files):
    case_key = cases[_ix]
    case_a, case_b = case_key
    # File stems end in "_n<count>"; strip the leading "n".
    n_iter = int(_file.stem.split("_")[-1][1:])

    # Pull what is needed into memory and release the file handle right away.
    with xr.open_dataset(_file) as ks_res:
        ks_pval = ks_res["pval"].values
        time_step = np.arange(ks_res.time.shape[0])  # * 2

    fig, axes = plt.subplots(1, 2, figsize=(15, 4))
    quantile = REJECT_THR * 100

    # Count rejections over axis 1 (the axis the "% of variables" line refers to),
    # then summarise across axis 0 (bootstrap iterations).
    n_reject[case_key] = np.array((ks_pval < REJECT_THR).sum(axis=1))
    n_reject_mean = np.median(n_reject[case_key], axis=0)
    n_reject_lq = np.percentile(n_reject[case_key], quantile, axis=0)
    n_reject_uq = np.percentile(n_reject[case_key], 100 - quantile, axis=0)

    ks_pval_cr[case_key] = _fdr_adjust(ks_pval)

    n_reject_cr[case_key] = np.array((ks_pval_cr[case_key] < REJECT_THR).sum(axis=1))
    n_reject_mean_cr = np.median(n_reject_cr[case_key], axis=0)
    n_reject_lq_cr = np.percentile(n_reject_cr[case_key], quantile, axis=0)
    n_reject_uq_cr = np.percentile(n_reject_cr[case_key], 100 - quantile, axis=0)

    # --- Left panel: number of rejected variables per time step -------------
    axes[0].plot(time_step, n_reject_mean, color="black", lw=1.5, label="Median")
    axes[0].plot(time_step, n_reject_mean_cr, color="black", lw=1.5, ls="--", label="Median [Corrected]")

    # The band runs from the `quantile`-th to the (100 - quantile)-th percentile,
    # so it covers the central (100 - 2 * quantile) percent of iterations.
    band_pct = 100 - 2 * quantile
    axes[0].fill_between(time_step, n_reject_lq, n_reject_uq, color="C5", alpha=0.8, label=f"{band_pct}% CI")
    axes[0].fill_between(time_step, n_reject_lq_cr, n_reject_uq_cr, color="C6", alpha=0.2, label=f"{band_pct}% CI [corr]")

    axes[0].axhline(REJECT_THR * ks_pval.shape[1], color="#343", ls="-.", label=f"{REJECT_THR * 100}% of variables")
    axes[0].legend()

    axes[0].set_title(f"Number of variables rejected at {(1 - REJECT_THR) * 100}% confidence")
    axes[0].set_xlabel("Timestep")
    axes[0].set_ylabel("N variables")

    # --- Right panel: iterations counted as "failing" per time step ---------
    # NOTE(review): the cutoffs 13 (raw) and 1 (corrected) are hard-coded;
    # confirm they are still the intended thresholds for this variable set.
    test[case_key] = (n_reject[case_key] > 13).sum(axis=0)
    test_cr[case_key] = (n_reject_cr[case_key] > 1).sum(axis=0)

    axes[1].plot(test[case_key], label="Un-corrected")
    axes[1].plot(test_cr[case_key], label="Corrected")
    axes[1].axhline(
        REJECT_THR * ks_pval.shape[0],
        label=f"{REJECT_THR * 100}% of {ks_pval.shape[0]} tests",
        color="k",
        ls="--",
        lw=0.8
    )
    axes[1].set_title(f"Number of tests (of {ks_pval.shape[0]}) \"failing\"")
    axes[1].set_ylim([0, int(ks_pval.shape[0] * 1.1)])
    axes[1].legend()

    _reject = f"{REJECT_THR:.2f}".replace(".", "p")
    fig.suptitle(f"{case_a} x {case_b}")
    plt.tight_layout()
    plt.savefig(f"plt_{case_a}-{case_b}_n{n_iter}.png")

In [None]:
# Heat-map summary: one row per case pair, one column per time step, showing
# how many iterations were counted as failing (uncorrected, corrected, and the
# difference between the two).
fig2, axes2 = plt.subplots(3, 1, figsize=(12, 8), sharex=True)


def _pct_label(case_name):
    """Return the percentage encoded in a case name as a tick label.

    Names ending in ``<num>pct`` (with ``p`` as decimal point) give that number,
    names containing ``ctl`` give ``"0"``, anything else is returned unchanged.
    """
    token = case_name.rsplit("-", 1)[-1]
    if token.endswith("pct"):
        try:
            return f"{float(token[:-3].replace('p', '.')):g}"
        except ValueError:
            return case_name
    return "0" if "ctl" in case_name else case_name


# Row labels and positions follow `cases`, so they stay in step with the data
# when entries are commented in or out.
pct = [_pct_label(_case[1]) for _case in cases]
yidx = np.arange(len(cases))
test_all = np.array([test[_case] for _case in cases])
test_all_cr = np.array([test_cr[_case] for _case in cases])
test_diff = test_all_cr - test_all
times = np.arange(1, test_all.shape[1] + 1)

edge_style = {"edgecolor": "grey", "lw": .5, "ls": "--"}

_cf1 = axes2[0].pcolormesh(
    times,
    yidx,
    test_all,
    **edge_style,
)
axes2[0].set_yticks(yidx, pct)
axes2[0].set_title("Uncorrected")
axes2[0].set_ylabel("% Parameter change")

_cf2 = axes2[1].pcolormesh(
    times,
    yidx,
    test_all_cr,
    **edge_style,
)
axes2[1].set_yticks(yidx, pct)
axes2[1].set_title("Corrected")
axes2[1].set_ylabel("% Parameter change")

_cf3 = axes2[2].pcolormesh(
    times,
    yidx,
    test_diff,
    vmin=-100,
    vmax=100,
    cmap="RdBu_r",
    **edge_style,
)
axes2[2].set_yticks(yidx, pct)
axes2[2].set_title("Corrected - Uncorrected")
axes2[2].set_ylabel("% Parameter change")
axes2[2].set_xlabel("Simulation time [month]")

# Attach each colorbar to its own panel explicitly rather than relying on
# matplotlib to work out which axes to take the space from.
fig2.colorbar(_cf1, ax=axes2[0], pad=0.02)
fig2.colorbar(_cf2, ax=axes2[1], pad=0.02)
fig2.colorbar(_cf3, ax=axes2[2], extend="both", pad=0.02)
fig2.tight_layout()

In [None]:
# One small panel per case pair showing the failing-iteration curves computed
# earlier.  `ks_pval` and `n_iter` are whatever the last executed loop left
# behind; only the iteration count and the file-name suffix are taken from them.
fig, axes = plt.subplots(3, 2, figsize=(14, 7))
axes = axes.ravel()

n_tests = ks_pval.shape[0]

for _ix, _file in enumerate(files):
    case_a, case_b = cases[_ix]
    _panel = axes[_ix]

    _panel.plot(test[(case_a, case_b)], label="Un-corrected")
    _panel.plot(test_cr[(case_a, case_b)], label="Corrected")
    # Reference level: REJECT_THR expressed as a number of iterations.
    _panel.axhline(
        REJECT_THR * n_tests,
        label=f"{REJECT_THR * 100}% of {n_tests} tests",
        color="k",
        ls="--",
        lw=0.8,
    )
    _panel.set_title(f"{case_a} x {case_b}")

# Only the last populated panel carries the legend.
axes[_ix].legend()
fig.suptitle(f"Number of tests (of {n_tests}) \"failing\"")
fig.tight_layout()
fig.savefig(f"plt_zmconv_c0_4panel_n{n_iter}.pdf")

In [None]:
# Overlay, for every case pair, the upper percentile of the number of rejected
# variables per time step (solid: raw p-values, dashed: FDR-adjusted).
# The result dicts are rebuilt here and keyed by the *formatted* case labels.
fig, axes = plt.subplots(1, 1, figsize=(12, 6))
quantile = REJECT_THR * 100
reject_test = {}
n_reject = {}

reject_test_cr = {}
n_reject_cr = {}

for idx, _file in enumerate(files):
    case_a, case_b = cases[idx]
    # File stems end in "_n<count>"; strip the leading "n".
    n_iter = int(_file.stem.split("_")[-1][1:])
    case_a = fmt_case(case_a)
    case_b = fmt_case(case_b)
    case_key = (case_a, case_b)

    # Pull what is needed into memory and release the file handle right away.
    with xr.open_dataset(_file) as ks_res:
        ks_pval = ks_res["pval"].values
        time_step = np.arange(ks_res.time.shape[0])

    n_reject[case_key] = np.array((ks_pval < REJECT_THR).sum(axis=1))
    n_reject_mean = np.median(n_reject[case_key], axis=0)
    n_reject_lq = np.percentile(n_reject[case_key], quantile, axis=0)
    n_reject_uq = np.percentile(n_reject[case_key], 100 - quantile, axis=0)

    # FDR-adjust the p-values of each iteration (axis 0) independently.
    print(f"APPLY FDR CORRECTION FOR {ks_pval.shape[0]} iters")
    ks_pval_cr = np.array([
        smm.fdrcorrection(
            _pval.flatten(),
            alpha=REJECT_THR,
            method="indep",
            is_sorted=False
        )[1].reshape(_pval.shape)
        for _pval in ks_pval
    ])

    n_reject_cr[case_key] = np.array((ks_pval_cr < REJECT_THR).sum(axis=1))
    n_reject_mean_cr = np.median(n_reject_cr[case_key], axis=0)
    n_reject_lq_cr = np.percentile(n_reject_cr[case_key], quantile, axis=0)
    n_reject_uq_cr = np.percentile(n_reject_cr[case_key], 100 - quantile, axis=0)

    reject_test[case_key] = n_reject_uq
    reject_test_cr[case_key] = n_reject_uq_cr

    # `plot` returns a one-element list; unpack it to get the line itself.
    (ln_uq,) = axes.plot(time_step, n_reject_uq, color=f"C{idx}", lw=1.5, label=f"{case_a} x {case_b}")
    (ln_uq_cr,) = axes.plot(time_step, n_reject_uq_cr, color=f"C{idx}", lw=1.5, ls="--", label=f"{case_a} x {case_b} [corr]")

    # Shade between the lower and upper percentile of the raw counts.
    axes.fill_between(time_step, n_reject_lq, n_reject_uq, color=ln_uq.get_color(), alpha=0.4)

axes.axhline(REJECT_THR * ks_pval.shape[1], color="#343", ls="-.", label=f"{REJECT_THR * 100} % of variables")
axes.set_title(f"Number of variables rejected at {(1 - REJECT_THR) * 100}% confidence")
axes.set_xlabel("Timestep")
axes.set_ylabel("N variables")

axes.legend()

_reject = f"{REJECT_THR:.2f}".replace(".", "p")
plt.tight_layout()
plt.savefig("plt_all_cases_new.png")

In [None]:
# Histogram of the number of rejected variables at the final time step,
# control-vs-control against control-vs-perturbed.
# `case_a` / `case_b` are not set in this cell: they are whatever the most
# recently executed loop left behind.
fig, axis = plt.subplots(1, 1)

# Compare like with like: both samples are the uncorrected counts, which is
# also what the "% of variables" reference line below applies to.
_ctl_counts = n_reject[(case_a, case_a)][:, -1]
_prt_counts = n_reject[(case_a, case_b)][:, -1]

_ = axis.hist(
    [_ctl_counts, _prt_counts],
    label=[f"{case_a} x {case_a} [ctl]", f"{case_a} x {case_b} [prt]"],
    edgecolor="k",
    bins=np.arange(0, 22, 2),
)

axis.axvline(REJECT_THR * ks_pval.shape[1], color="#343", ls="-.", label=f"{REJECT_THR * 100}% of variables")
axis.set_xlabel(f"Number of rejected (p < {REJECT_THR}) variables")
axis.set_ylabel(f"Frequency (of {_ctl_counts.shape[0]} iterations)")
_ = axis.legend()

In [None]:
# For every case pair in `reject_test`, count per time step how many iterations
# reject more variables than a cutoff, then draw raw (thick) and corrected
# (thin) curves in the same colour.
# NOTE(review): the raw cutoff is the literal 13 while the corrected cutoff is
# REJECT_THR * ks_pval.shape[1]; confirm both are intended.
_cutoff_cr = REJECT_THR * ks_pval.shape[1]

test = {}
test_cr = {}
for _key in reject_test:
    test[_key] = (n_reject[_key] > 13).sum(axis=0)
    test_cr[_key] = (n_reject_cr[_key] > _cutoff_cr).sum(axis=0)

fig, axes = plt.subplots(1, 1, figsize=(12, 5))

for idx, _case in enumerate(test):
    _name = f"{_case[0]} x {_case[1]}"
    _colour = f"C{idx}"
    axes.plot(test[_case], color=_colour, label=_name, lw=2.3)
    axes.plot(test_cr[_case], color=_colour, ls="-", label=f"{_name} [corr]", lw=1.0)

axes.set_ylabel("Number of iterations")
axes.set_xlabel("Time step")
axes.set_title("")
axes.legend()

In [None]:
# For every file, histogram the two rows of `rnd_idx` (one column per row),
# with a dashed horizontal reference line at `ex_val`.
fig, axes = plt.subplots(len(files), 2, figsize=(10, 2.5 * len(files)), sharex=True)
bins = np.arange(0, 121, 2)
# NOTE(review): 500 and 30 are hard-coded; 500 matches the "n500" in the file
# names - confirm what 30 stands for before reusing this with other inputs.
ex_val = 500 / ((bins.shape[0] - 1) / 30)

for file_ix, _file in enumerate(files):
    case_a, case_b = cases[file_ix]
    n_iter = int(_file.stem.split("_")[-1][1:])

    # Plot while the dataset is open, then release the file handle.
    with xr.open_dataset(_file) as ks_res:
        for _case in [0, 1]:
            axes[file_ix, _case].axhline(ex_val, color="k", ls="--")
            _ = ks_res.rnd_idx[_case].plot.hist(edgecolor="k", bins=bins, ax=axes[file_ix, _case])
            axes[file_ix, _case].set_title(cases[file_ix][_case])
            axes[file_ix, _case].set_xlabel("")

In [None]:
# For every file, count per variable and time step how many iterations have
# p < REJECT_THR, and draw a heat map of the variables whose total count lies
# above the 0.9 quantile of all variable totals.
# squeeze=False keeps `axes` two-dimensional even when there is a single file.
fig, axes = plt.subplots(len(files), 1, figsize=(10, 5), squeeze=False)
bins = np.arange(0, 121, 4)
ex_val = 500 / ((bins.shape[0] - 1) / 30)

for file_ix, _file in enumerate(files):
    case_a, case_b = cases[file_ix]
    case_a = fmt_case(case_a)
    case_b = fmt_case(case_b)
    n_iter = int(_file.stem.split("_")[-1][1:])

    # Build the in-memory table while the dataset is open, then release the handle.
    with xr.open_dataset(_file) as ks_res:
        ks_pval = ks_res.pval
        reject_by_var = pd.DataFrame((ks_pval < REJECT_THR).sum(axis=0).T, columns=ks_res.vars)

    # Keep only the variables with the highest total rejection counts.
    mask = (reject_by_var.sum() > reject_by_var.sum().quantile(.9))

    reject_by_var.index.name = "Time Step"
    top_rejected = reject_by_var.T[mask]
    sns.heatmap(top_rejected, ax=axes[file_ix, 0], vmin=0, vmax=500)
    axes[file_ix, 0].set_title(f"{case_a} x {case_b}")
    # Rank by the count in the column labelled 2.
    print(top_rejected.sort_values(2, ascending=False))
fig.tight_layout()