In [None]:
import matplotlib.pyplot as plt
import xarray as xr
from pathlib import Path
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from statsmodels.stats import multitest as smm

# Apply matplotlib's "ggplot" style sheet to every figure created below.
plt.style.use("ggplot")

In [None]:
# Threshold below which a p-value counts as a rejection throughout the notebook.
REJECT_THR = 0.05

ctl_key = ("PertLim 1.0e-10", "PertLim 1.0e-10")
run_len = "1year"

# Second member of every comparison; the first member is always "ctl".
_compared_to_ctl = [
    "ctl",
    "effgw_oro-0p5pct",
    "effgw_oro-1pct",
    "effgw_oro-10pct",
    # "effgw_oro-20pct",
    "effgw_oro-30pct",
    "effgw_oro-40pct",
    "effgw_oro-50pct",
]
cases = [("ctl", _name) for _name in _compared_to_ctl]

# One bootstrap output file per comparison, in the same order as `cases`.
files = [
    Path(f"bootstrap_data/bootstrap_output.{run_len}.{_case_a}_{_case_b}_n500.nc")
    for _case_a, _case_b in cases
]
print(files)

In [None]:
# One two-panel figure per comparison:
#   left  - number of variables rejected per time step (raw and FDR-corrected)
#   right - number of bootstrap iterations whose rejection count exceeds the
#           failure threshold (raw and FDR-corrected)
# The FDR-corrected p-values are kept in `ks_pval_cr`, keyed by (case_a, case_b).
ks_pval_cr = {}

for _ix, _file in enumerate(files):
    case_a, case_b = cases[_ix]
    # File stems end in "_n<count>" (e.g. "..._n500"); strip the leading "n".
    n_iter = int(_file.stem.split("_")[-1][1:])

    # Pull everything needed into memory while the file is open so the handle
    # is released before plotting (the original never closed the dataset).
    with xr.open_dataset(_file) as ks_res:
        # presumably (iteration, variable, time) - TODO confirm against the writer
        ks_pval = ks_res["pval"].values
        time_step = np.arange(ks_res.time.shape[0])

    fig, axes = plt.subplots(1, 2, figsize=(15, 4))
    quantile = REJECT_THR * 100

    # Rejections counted along axis 1, then summarised across axis 0.
    n_reject = np.array((ks_pval < REJECT_THR).sum(axis=1))
    n_reject_uq = np.percentile(n_reject, 100 - quantile, axis=0)

    # Benjamini-Hochberg correction applied separately to each axis-0 slice,
    # over all of that slice's p-values at once.
    _pval_cr = []
    for jdx in range(ks_pval.shape[0]):
        _pval_cr.append(
            smm.fdrcorrection(
                ks_pval[jdx].flatten(),
                alpha=REJECT_THR,
                method="indep",
                is_sorted=False
            )[1].reshape(ks_pval[jdx].shape)
        )
    ks_pval_cr[(case_a, case_b)] = np.array(_pval_cr)
    n_reject_cr = np.array((ks_pval_cr[(case_a, case_b)] < REJECT_THR).sum(axis=1))
    n_reject_uq_cr = np.percentile(n_reject_cr, 100 - quantile, axis=0)

    # NOTE(review): the corrected series is labelled with the same percentile as
    # the raw one but plots the maximum over iterations; `n_reject_uq_cr` is the
    # matching percentile. Left as in the original - confirm which is intended.
    rejections = {
        f"{100 * (1 - REJECT_THR)}%": n_reject_uq,
        f"{100 * (1 - REJECT_THR)}% [Corrected]": n_reject_cr.max(axis=0)
    }
    width = 0.5
    for mult, (name, nreject) in enumerate(rejections.items()):
        offset = width * mult
        rect = axes[0].bar(time_step + offset, nreject, width=0.45, label=name)
        # Label each bar set exactly once. The original labelled the last set
        # again after the loop, drawing a second overlapping set of labels.
        axes[0].bar_label(rect, padding=3, color=rect[-1].get_facecolor())

    # NOTE(review): tick labels are hard-coded for exactly three time steps.
    axes[0].set_xticks(time_step, ["1-12", "2-13", "3-14"])

    axes[0].axhline(
        REJECT_THR * ks_pval.shape[1],
        color="#343",
        ls="-.",
        label=f"{REJECT_THR * 100}% of variables"
    )
    axes[0].legend()

    axes[0].set_title(
        f"Number of variables rejected at {(1 - REJECT_THR) * 100}% confidence"
    )
    axes[0].set_xlabel("Simulated month")
    axes[0].set_ylabel("N variables")

    # NOTE(review): the uncorrected failure threshold (13) is a magic number,
    # while the corrected one is derived from REJECT_THR - confirm its origin.
    test = {
        "Un-corrected": (n_reject > 13).sum(axis=0),
        "Corrected": (n_reject_cr > ks_pval.shape[1] * REJECT_THR).sum(axis=0)
    }

    for mult, (name, itest) in enumerate(test.items()):
        offset = width * mult
        rect = axes[1].bar(time_step + offset, itest, width=0.45, label=name)
        axes[1].bar_label(rect, padding=3, color=rect[-1].get_facecolor(), zorder=10)

    axes[1].legend()
    axes[1].set_xticks(time_step, ["1-12", "2-13", "3-14"])
    axes[1].set_xlabel("Simulated month")

    axes[1].set_title(f"Number of tests (of {ks_pval.shape[0]}) \"failing\"")

    fig.suptitle(f"{case_a} x {case_b}")
    fig.tight_layout()
    fig.savefig(f"plt_{case_a}-{case_b}_n{n_iter}.png")

In [None]:
def fmt_case(case):
    """Return the display label for a short case identifier.

    A few control identifiers map to fixed labels. Every other identifier is
    expected to embed a number written with "p" as the decimal point and an
    optional "pct" suffix (e.g. "effgw_oro-0p5pct"); the number is parsed and
    formatted behind a prefix chosen by the first matching substring.

    Raises:
        ValueError: if the numeric part cannot be parsed as a float.
        IndexError: if an identifier lacks the separator its family requires.
    """
    for fixed_id, fixed_label in (
        ("ctl", "Control"),
        ("ctl-2mo", "Control [2 mo]"),
        ("new-ctl", "Control [new]"),
    ):
        if case == fixed_id:
            return fixed_label

    def _number(token, drop_2mo=True):
        # "pct" is removed before "p" becomes the decimal point, so the "p"
        # in "pct" is never turned into a dot.
        cleaned = token.replace("pct", "").replace("p", ".")
        if drop_2mo:
            cleaned = cleaned.replace("-2mo", "")
        return float(cleaned)

    # "old-..." / "new-..." identifiers: the "-2mo" suffix is not stripped here.
    for family in ("old", "new"):
        if family in case:
            value = _number(case.replace(family + "-", ""), drop_2mo=False)
            return f"{family} {value:.1f}%"

    if "c1" in case:
        value = _number(case.split("-")[1])
        return f"clubb_c1 {value:.1f}%"

    if "gworo" in case or "effgw_oro" in case:
        pieces = case.split("-")
        # Identifiers containing "yr" carry the number in their last piece.
        chosen = pieces[-1] if "yr" in case else pieces[1]
        value = _number(chosen)
        return f"GW orog {value:.1f}%"

    if "pl" in case:
        value = _number(case.split("_")[1])
        return f"PertLim {value:.1e}"

    value = _number(case)
    return f"{value:.1f}%"

In [None]:
# Grouped bar chart over all comparisons: for each time step, one bar per case
# showing the upper quantile (over axis 0 of the counts) of the number of
# variables rejected after FDR correction. Also fills the per-case dictionaries
# (`n_reject`, `n_reject_cr`, `reject_test`, `reject_test_cr`) that later cells
# read, keyed by the formatted (case_a, case_b) labels.
fig, axes = plt.subplots(1, 1, figsize=(12, 6))
quantile = REJECT_THR * 100
reject_test = {}
n_reject = {}

reject_test_cr = {}
n_reject_cr = {}

# Time steps sit one unit apart on the x axis, so all cases together must fit
# inside one unit. A fixed width of 0.3 made seven-case groups 2.1 units wide
# and drew them on top of the neighbouring time steps.
bar_width = 0.8 / max(len(files), 1)

for idx, _file in enumerate(files):
    case_a, case_b = cases[idx]
    n_iter = int(_file.stem.split("_")[-1][1:])
    case_a = fmt_case(case_a)
    case_b = fmt_case(case_b)

    # Load what is needed and release the file handle straight away.
    with xr.open_dataset(_file) as ks_res:
        ks_pval = ks_res["pval"].values
        time_step = np.arange(ks_res.time.shape[0])

    n_reject[(case_a, case_b)] = np.array((ks_pval < REJECT_THR).sum(axis=1))
    n_reject_uq = np.percentile(n_reject[(case_a, case_b)], 100 - quantile, axis=0)

    # Benjamini-Hochberg correction, applied to each axis-0 slice on its own.
    _pval_cr = []
    print(f"APPLY FDR CORRECTION FOR {ks_pval.shape[0]} iters")
    for jdx in range(ks_pval.shape[0]):
        _pval_cr.append(
            smm.fdrcorrection(
                ks_pval[jdx].flatten(),
                alpha=REJECT_THR,
                method="indep",
                is_sorted=False
            )[1].reshape(ks_pval[jdx].shape)
        )
    # Kept under its own name so the `ks_pval_cr` dictionary built by the
    # per-case cell is not overwritten with an array.
    _pval_cr_arr = np.array(_pval_cr)

    n_reject_cr[(case_a, case_b)] = np.array((_pval_cr_arr < REJECT_THR).sum(axis=1))
    n_reject_uq_cr = np.percentile(n_reject_cr[(case_a, case_b)], 100 - quantile, axis=0)

    reject_test[(case_a, case_b)] = n_reject_uq
    reject_test_cr[(case_a, case_b)] = n_reject_uq_cr

    offset = bar_width * idx
    ln_uq_cr = axes.bar(
        time_step + offset,
        n_reject_uq_cr,
        width=bar_width,
        color=f"C{idx}",
        label=f"{case_a} x {case_b} [corr]",
        edgecolor="k",
        alpha=0.9,
    )
    axes.bar_label(ln_uq_cr, padding=3, color=ln_uq_cr[-1].get_facecolor())

# Put one tick under the middle of each group of bars.
axes.set_xticks(
    time_step + bar_width * (len(files) - 1) / 2,
    [str(_step) for _step in time_step],
)
axes.axhline(REJECT_THR * ks_pval.shape[1], color="#343", ls="-.", label=f"{REJECT_THR * 100} % of variables")
axes.set_title(f"Number of variables rejected at {(1 - REJECT_THR) * 100}% confidence")
axes.set_xlabel("Timestep")
axes.set_ylabel("N variables")

axes.legend()

fig.tight_layout()
fig.savefig("plt_all_cases_new.png")

In [None]:
# Histogram, one series per comparison, of the FDR-corrected rejection counts
# at the last index of axis 1 of each `n_reject_cr` entry.
fig, axis = plt.subplots(1, 1)
_ = axis.hist(
    [_counts[:, -1] for _counts in n_reject_cr.values()],
    label=[f"{_ca} x {_cb}" for (_ca, _cb) in n_reject_cr],
    edgecolor="k",
    # NOTE(review): counts of 28 or more fall outside these bins and are dropped.
    bins=np.arange(0, 30, 2),
)

# Take the iteration count from the plotted data. The original looked up
# n_reject[(case_a, case_a)], which relied on a loop variable leaked from an
# earlier cell and on a control-vs-control entry being present.
_n_boot = next(iter(n_reject_cr.values())).shape[0]

axis.axvline(REJECT_THR * ks_pval.shape[1], color="#343", ls="-.", label=f"{REJECT_THR * 100}% of variables")
axis.set_xlabel(f"Number of rejected (p < {REJECT_THR}) variables")
axis.set_ylabel(f"Frequency (of {_n_boot} iterations)")
_ = axis.legend()

In [None]:
# Quantiles of the corrected rejection counts (last element along the
# remaining axis) plotted against the size of the parameter change.
pcts = [0.5, 1, 10, 30, 40, 50]
# Keys into `n_reject_cr`: ("Control", "GW orog <pct>%") with one decimal place.
tests = [("Control", f"GW orog {_pct:.1f}%") for _pct in pcts]

for _label, _q, _line_kw in (
    ("UQ", 100 - quantile, {}),
    ("LQ", quantile, {}),
    ("Median", 50, {"color": "k"}),
):
    _series = [np.percentile(n_reject_cr[_key], _q, axis=0)[-1] for _key in tests]
    plt.plot(pcts, _series, '.-', label=_label, **_line_kw)

# NOTE(review): reference level is hard-coded to 6.
plt.axhline(6, color="k", ls="--", lw=1)
plt.xlabel("GW Orographic Parameter Change [%]")
plt.ylabel("Number of variables rejected [p < 0.05]")
plt.legend()

In [None]:
# For every comparison, count along axis 0 how many entries exceed the failure
# threshold, before (`test`) and after (`test_cr`) FDR correction.
_corrected_thr = REJECT_THR * ks_pval.shape[1]
test = {}
test_cr = {}
for _key in reject_test:
    # NOTE(review): 13 is a hard-coded threshold for the uncorrected counts.
    test[_key] = (n_reject[_key] > 13).sum(axis=0)
    test_cr[_key] = (n_reject_cr[_key] > _corrected_thr).sum(axis=0)

fig, axes = plt.subplots(1, 1, figsize=(12, 5))

# Same colour for both curves of a comparison; thick = raw, thin = corrected.
for idx, _case in enumerate(test):
    _first, _second = _case[0], _case[1]
    _colour = f"C{idx}"
    axes.plot(test[_case], color=_colour, label=f"{_first} x {_second}", lw=2.3)
    axes.plot(test_cr[_case], color=_colour, ls="-", label=f"{_first} x {_second} [corr]", lw=1.0)
axes.set_ylabel("Number of iterations")
axes.set_xlabel("Time step")
axes.set_title("")
axes.legend()

In [None]:
# Last element of each corrected failure count, plotted against `pcts`.
fig, axes = plt.subplots(1, 1, figsize=(12, 5))

_last_failing = [test_cr[_key][-1] for _key in tests]
axes.plot(pcts, _last_failing, 'o-')
# NOTE(review): reference line hard-codes 500 and 0.05.
axes.axhline(500 * 0.05, color="k", ls="--")
axes.set_ylabel("Number of iterations failing")
axes.set_xlabel("Percent Change")
axes.set_title("")

In [None]:
# Histograms of the `rnd_idx` values stored in each bootstrap file: one row per
# file, one column for each of the first two entries along `rnd_idx`'s leading
# dimension, titled with the corresponding member of the comparison.
# squeeze=False keeps `axes` two-dimensional even when only one file is listed,
# so the [row, column] indexing below always works.
fig, axes = plt.subplots(
    len(files), 2, figsize=(10, 2.5 * len(files)), sharex=True, squeeze=False
)
bins = np.arange(0, 121, 2)
# Reference height drawn on every panel: 500 * 30 / (number of bins).
# presumably the expected count per bin for uniform sampling - TODO confirm
ex_val = 500 / ((bins.shape[0] - 1) / 30)

for file_ix, _file in enumerate(files):
    # Plot while the file is open, then release the handle.
    with xr.open_dataset(_file) as ks_res:
        for _case in [0, 1]:
            ax = axes[file_ix, _case]
            ax.axhline(ex_val, color="k", ls="--")
            _ = ks_res.rnd_idx[_case].plot.hist(edgecolor="k", bins=bins, ax=ax)
            ax.set_title(cases[file_ix][_case])
            ax.set_xlabel("")

In [None]:
# One heatmap per comparison showing, for the variables with the highest total
# rejection counts (above the 0.9 quantile of per-variable totals), the number
# of rejections summed over axis 0 of the p-value array.
fig, axes = plt.subplots(len(files), 1, figsize=(10, len(files) * 2))
bins = np.arange(0, 121, 4)
ex_val = 500 / ((bins.shape[0] - 1) / 30)

for file_ix, _file in enumerate(files):
    _raw_a, _raw_b = cases[file_ix]
    case_a = fmt_case(_raw_a)
    case_b = fmt_case(_raw_b)
    n_iter = int(_file.stem.split("_")[-1][1:])

    ks_res = xr.open_dataset(_file)
    ks_pval = ks_res.pval
    _rejected = ks_pval < REJECT_THR
    reject_by_var = pd.DataFrame(_rejected.sum(axis=0).T, columns=ks_res.vars)

    # Keep only the variables whose total exceeds the 0.9 quantile of totals.
    _totals = reject_by_var.sum()
    mask = (_totals > _totals.quantile(.9))

    reject_by_var.index.name = "Time Step"
    _panel = axes[file_ix]
    sns.heatmap(reject_by_var.T[mask], ax=_panel, vmin=0, vmax=500)
    _panel.set_title(f"{case_a} x {case_b}")
fig.tight_layout()