In [None]:
import matplotlib.pyplot as plt
import xarray as xr
from pathlib import Path
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from statsmodels.stats import multitest as smm
from matplotlib.ticker import ScalarFormatter
from detclim.results_plot import fmt_case

# Reset matplotlib to its built-in default style before any figure is drawn.
plt.style.use("default")

# Significance level: p-values below ALPHA count as rejections, and the
# (ALPHA * 100)th / (100 - ALPHA * 100)th percentiles are used as bounds.
ALPHA = 0.05
# run_len, rolling and niter are substituted into the bootstrap output file
# names ("bootstrap_output.{run_len}_{rolling}avg.{case}_{case}_n{niter}.nc").
run_len = "1year"
rolling = 12
niter = 1000
# File extension (and therefore output format) of the saved figures.
ext = "pdf"
# Rejected-field threshold drawn on the control plot and used when counting
# false positives; it is overwritten in the plotting cell if COMPUTE_THR is True.
CTL_THR = 11
# When True, read from "../bootstrap_data" and recompute CTL_THR from the data;
# when False, read from "../bootstrap_data_ctl" and keep the value above.
COMPUTE_THR = False

In [None]:
def correct_pvals(ks_pval, alpha=0.05, method: str = "fdr_bh"):
    """Apply a multiple-testing correction along axis 1 of a p-value array.

    Every 1-D slice taken along axis 1 (all other indices held fixed) is
    treated as one family of tests and passed to
    ``statsmodels.stats.multitest.multipletests``. The corrected p-values are
    written back to the same positions, so ``out[i, :, k]`` is the corrected
    version of ``ks_pval[i, :, k]``.

    The previous implementation gathered the corrected slices in
    ``(axis 0, last axis, axis 1)`` order and then reshaped that buffer directly
    to ``ks_pval.shape``, which scrambled values across axis 1 and the last
    axis whenever both had length greater than one.

    Parameters
    ----------
    ks_pval : array_like
        P-values with at least two dimensions; the correction runs over axis 1.
    alpha : float, default 0.05
        Family-wise error rate / FDR level forwarded to ``multipletests``.
    method : str, default "fdr_bh"
        Correction method name forwarded to ``multipletests``.

    Returns
    -------
    numpy.ndarray
        Corrected p-values with the same shape as ``ks_pval``.
    """
    pvals = np.asarray(ks_pval)

    if pvals.size == 0:
        # apply_along_axis refuses zero-length iteration dimensions; there is
        # nothing to correct, so hand back an empty float array of equal shape.
        return pvals.astype(float).reshape(pvals.shape)

    def _correct_family(family_pvals):
        # Element [1] of the multipletests result holds the corrected p-values.
        return smm.multipletests(
            pvals=family_pvals,
            alpha=alpha,
            method=method,
            is_sorted=False,
        )[1]

    # Correct each axis-1 slice in place so the output layout matches the input.
    return np.apply_along_axis(_correct_family, 1, pvals)

In [None]:
# Pairs of simulation names; both members are substituted into the bootstrap
# output file name and, after fmt_case, used together as the dictionary key.
cases = [
    ("ctl", "ctl"),
    ("effgw_oro-1p0pct",) * 2,
    ("effgw_oro-10p0pct",) * 2,
    ("effgw_oro-20p0pct",) * 2,
    ("effgw_oro-30p0pct",) * 2,
    ("effgw_oro-40p0pct",) * 2,
    ("effgw_oro-50p0pct",) * 2,
    ("clubb_c1-1p0pct",) * 2,
    ("clubb_c1-3p0pct",) * 2,
    ("clubb_c1-5p0pct",) * 2,
    ("clubb_c1-10p0pct",) * 2,
    ("opt-O1",) * 2,
    ("fastmath",) * 2,
]
pcts = [0.5, 1, 5, 10, 20, 30, 40, 50]

# The input directory depends on whether the control threshold is recomputed.
if COMPUTE_THR:
    data_path = Path("../", "bootstrap_data")
else:
    data_path = Path("../", "bootstrap_data_ctl")

files = [
    Path(
        data_path,
        "bootstrap_output.{}_{}avg.{}_{}_n{}.nc".format(run_len, rolling, *case, niter),
    )
    for case in cases
]

# Show which of the expected input files are actually present on disk.
print("[")
for _file in files:
    print(f"\t{_file}\t\t{_file.exists()}")
print("]")

# Per-case results, keyed by the formatted (case_a, case_b) pair.
ks_pval_cr = {}

n_reject = {}
reject_test = {}

n_reject_cr = {}
reject_test_cr = {}
rejections = {}

# Lower percentile used for the bounds below; the upper one is 100 - quantile.
quantile = ALPHA * 100

for (case_a, case_b), _file in zip(cases, files):
    # The iteration count is encoded as the trailing "_n<count>" of the file stem.
    n_iter = int(_file.stem.split("_")[-1][1:])
    case_a = fmt_case(case_a)
    case_b = fmt_case(case_b)
    _key = (case_a, case_b)

    # Open inside a context manager so the file handle is released as soon as
    # the p-values have been copied into memory (it was previously left open).
    with xr.open_dataset(_file) as ks_res:
        # NOTE(review): the array is assumed to be laid out as
        # (iteration, field, time window) — confirm against the bootstrap output.
        ks_pval = ks_res["pval"].values
        time_step = np.arange(ks_res.time.shape[0])

    # Number of axis-1 entries with p < ALPHA, per (axis 0, last axis) element.
    n_reject[_key] = np.array((ks_pval < ALPHA).sum(axis=1))

    n_reject_mean = np.median(n_reject[_key], axis=0)
    n_reject_lq = np.percentile(n_reject[_key], quantile, axis=0)
    n_reject_uq = np.percentile(n_reject[_key], 100 - quantile, axis=0)

    # Same counts after the multiple-testing correction of the p-values.
    ks_pval_cr[_key] = correct_pvals(ks_pval, alpha=ALPHA)
    n_reject_cr[_key] = np.array((ks_pval_cr[_key] < ALPHA).sum(axis=1))

    n_reject_mean_cr = np.median(n_reject_cr[_key], axis=0)
    n_reject_lq_cr = np.percentile(n_reject_cr[_key], quantile, axis=0)
    n_reject_uq_cr = np.percentile(n_reject_cr[_key], 100 - quantile, axis=0)

    # Upper percentile of the rejection counts, uncorrected and corrected.
    rejections[_key] = {
        f"{100 * (1 - ALPHA)}%": n_reject_uq,
        f"{100 * (1 - ALPHA)}% [Corrected]": n_reject_uq_cr,
    }

In [None]:
# Long-format table: one row per (case, bootstrap iteration) holding the number
# of rejected fields at the selected time index plus per-case metadata.
reject_data = {
    "Case": [],
    "Rejected fields": [],
    "iteration": [],
    "Parameter": [],
    "Median": [],
    "Pct": [],
}

# Index along the last axis of the rejection counts that is tabulated.
time_idx = 1
for _case, _counts in n_reject.items():
    _name = _case[0]

    # The last space-separated token, minus its final character, is parsed as
    # a percentage; names without a numeric token fall back to 0.
    _pct = _name.split(" ")[-1]
    try:
        _pct = float(_pct[:-1])
    except ValueError:
        _pct = 0

    # The two optimisation cases share one group; otherwise group by the first
    # token of the case name.
    if _name in ("opt-O1", "fastmath"):
        _parameter = "Optimization"
    else:
        _parameter = _name.split(" ")[0]

    _nrej = _counts[:, time_idx]
    _n_rows = len(_nrej)
    if _n_rows == 0:
        # No iterations for this case: nothing to add (and no median to take).
        continue

    # Per-case values are computed once instead of once per iteration row.
    _median = np.median(_nrej)

    reject_data["Case"].extend([_name] * _n_rows)
    reject_data["Parameter"].extend([_parameter] * _n_rows)
    reject_data["Rejected fields"].extend(_nrej)
    reject_data["iteration"].extend(range(_n_rows))
    reject_data["Median"].extend([_median] * _n_rows)
    reject_data["Pct"].extend([_pct] * _n_rows)

reject_data_frame = pd.DataFrame(reject_data)

In [None]:
# Horizontal box plot of rejected-field counts per case, with the control
# threshold marked by a dashed vertical line and annotated with its value.
fig, axis = plt.subplots(1, 1, figsize=(12.5 / 2.54, 6.25 / 2.54), dpi=600)

with sns.plotting_context(context="paper", font_scale=0.5, rc=None):
    _box_options = {
        "orient": "h",
        "x": "Rejected fields",
        "y": "Case",
        "hue": "Parameter",
        "palette": "Set2",
        "ax": axis,
        "zorder": 4,
        "linewidth": 0.5,
        "whis": (5, 95),
        "fliersize": 1,
    }
    _ = sns.boxplot(reject_data_frame, **_box_options)

    # Keep the full label for the first case of each group; later cases of the
    # same group show only their last token.
    _seen_prefixes = []
    _newlabels = []
    for _label in axis.get_yticklabels():
        _words = _label.get_text().split(" ")
        if _words[0] in _seen_prefixes:
            _label.set_text(_words[-1])
        else:
            _seen_prefixes.append(_words[0])
        _newlabels.append(_label)

    axis.set_yticklabels(_newlabels)
    axis.tick_params(labelsize=8)

    # Optionally derive the threshold from the data: the 50th percentile of the
    # per-case (1 - ALPHA) quantiles of the rejected-field counts.
    if COMPUTE_THR:
        _case_quantiles = reject_data_frame.groupby("Case")[
            "Rejected fields"
        ].quantile(1 - ALPHA)
        CTL_THR = float(np.percentile(_case_quantiles, 50))

    axis.axvline(CTL_THR, ls="--", lw=0.6, color="k")

    axis.set_ylabel("")
    axis.set_xlabel(f"Rejected fields at \u03b1={ALPHA:.2f}", fontsize=8)
    axis.grid(visible=True, ls="--", lw=0.2, zorder=0)

    # Annotate the threshold line with its (rounded) value.
    _txtx = CTL_THR + 1
    _txty = len(cases) * 0.985
    print(f"ADD THE TEXT AT {_txtx}, {_txty}")
    _text_box = {
        "boxstyle": "round",
        "edgecolor": "white",
        "facecolor": "grey",
        "alpha": 0.6,
    }
    axis.text(
        float(_txtx),
        _txty,
        f"{CTL_THR:.0f}",
        horizontalalignment="right",
        verticalalignment="top",
        zorder=10,
        fontsize=6,
        bbox=_text_box,
    )

    plt.tight_layout()
    plt.savefig(f"plt_control_nrej_a{ALPHA}.{ext}")

In [None]:
# Count, per case, the rows whose rejected-field count exceeds CTL_THR.
# A constant helper column "N" is added to the frame and summed per group.
reject_data_frame["N"] = 1
_above_threshold = reject_data_frame["Rejected fields"] > CTL_THR
reject_data_frame.loc[_above_threshold].groupby("Case")["N"].sum()

In [None]:
# Row labels for the positions along the last axis of the rejection counts.
_window_labels = {0: "0-11", 1: "1-12", 2: "2-13"}

# Per case (name escaped for LaTeX), count the axis-0 entries whose number of
# rejected fields exceeds CTL_THR (uncorrected) or is non-zero (corrected).
false_positives = {
    _key[0].replace("%", r"\%").replace("_", " "): np.sum(
        n_reject[_key] > CTL_THR, axis=0
    )
    for _key in n_reject
}
false_positives_cr = {
    _key[0].replace("%", r"\%").replace("_", " "): np.sum(
        n_reject_cr[_key] > 0, axis=0
    )
    for _key in n_reject_cr
}

# One row per case, one column per window label.
false_positives = pd.DataFrame(false_positives).rename(_window_labels).T
false_positives_cr = pd.DataFrame(false_positives_cr).rename(_window_labels).T

# Convert counts to rates with the configured number of bootstrap iterations
# rather than a hard-coded 1000, so the rates follow the `niter` setting.
false_positives_rate_cr = false_positives_cr / niter
false_positives_rate = false_positives / niter

_label = f"{100 * (1 - ALPHA):.0f}"
_label = r"\textbf{" + _label + r"th \%tile}"

# Column means are taken before any summary row is appended so that the
# summary rows do not feed back into the statistics.
fpr_mean = false_positives_rate.mean()
fpr_cr_mean = false_positives_rate_cr.mean()

false_positives_rate.loc[_label] = false_positives_rate.quantile(
    1 - ALPHA, interpolation="nearest"
)
false_positives_rate_cr.loc[_label] = false_positives_rate_cr.quantile(
    1 - ALPHA, interpolation="nearest"
)

mean_label = r"\textbf{Mean}"
false_positives_rate.loc[mean_label] = fpr_mean
false_positives_rate_cr.loc[mean_label] = fpr_cr_mean


In [None]:
# Tag each count table with its mode, then stack both and reshape to long form:
# one row per (simulation, mode, window) with the false-positive count.
false_positives["Mode"] = "Uncorrected"
false_positives_cr["Mode"] = "Corrected"

all_fps = (
    pd.concat([false_positives, false_positives_cr])
    .assign(Sim=lambda frame: frame.index)
    .melt(id_vars=["Sim", "Mode"], var_name="Months", value_name="False Positives")
)

In [None]:
# Bar plot of the (1 - ALPHA) percentile of false-positive counts per window,
# split by corrected / uncorrected mode.
ptile = 100 * (1 - ALPHA)


def _estr(x):
    """Return the `ptile`-th percentile of x, snapped to an observed value."""
    return np.percentile(x, ptile, method="closest_observation")


ornl_colours = {
    "green": "#007833",
    "bgreen": "#84b641",
    "orange": "#DE762D",
    "teal": "#1A9D96",
    "red": "#88332E",
    "blue": "#5091CD",
    "gold": "#FECB00",
}
_bar_palette = [ornl_colours.get("teal"), ornl_colours.get("orange")]

fig, axis = plt.subplots(1, 1, figsize=(12.5 / 2.54, 10 / 2.54), dpi=120)

# The "0-11" window is excluded from the figure.
_plotted_fps = all_fps.query("Months != '0-11'")
_plt = sns.barplot(
    data=_plotted_fps,
    x="Months",
    y="False Positives",
    hue="Mode",
    estimator=_estr,
    errorbar=("ci", ptile),
    palette=_bar_palette,
    saturation=1.0,
    ax=axis,
)

# Reference line at ALPHA times the length of axis 0 of `ks_pval`, which is the
# array left over from the last file processed in the loading loop.
_reference_level = ks_pval.shape[0] * ALPHA
axis.axhline(_reference_level, ls="--", lw=2, color="k", zorder=0)
sns.despine(ax=axis, offset=5, trim=True)

_ = _plt.legend().set_title("")
_ = _plt.set_ylabel(f"False Positives {ptile:.0f}th %tile")
plt.tight_layout()

_alpha_tag = f"{ALPHA:.2f}".replace(".", "p")
fig.savefig(f"plt_false_pos_{_alpha_tag}.{ext}")

In [None]:
# Print a LaTeX table of the "2-13" false-positive rates, one row per
# simulation, with the uncorrected and corrected values side by side.
false_positives_rate.index.name = "Simulation"
false_positives_rate_cr.index.name = "Simulation"

_uncorrected = false_positives_rate[["2-13"]].rename(columns={"2-13": "Uncorrected"})
_corrected = false_positives_rate_cr[["2-13"]].rename(columns={"2-13": "Corrected"})

all_fprs = pd.concat([_uncorrected, _corrected], axis=1)
print(all_fprs.to_latex(float_format="{:.03f}".format))

In [None]:
# Display the uncorrected "2-13" rates in ascending order.
false_positives_rate.loc[:, "2-13"].sort_values()

In [None]:
# Display the corrected "2-13" rates in ascending order.
false_positives_rate_cr.loc[:, "2-13"].sort_values()

In [None]:
# NOTE(review): scratch arithmetic with hard-coded values; presumably
# (1 - ALPHA) * len(cases) — confirm, or remove the cell if no longer needed.
0.95 * 13