In [None]:
import argparse
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import xarray as xr
from matplotlib.ticker import ScalarFormatter
from statsmodels.stats import multitest as smm

import detclim.plot_bootstrap_results as pbr

# Start from matplotlib's default style sheet for every figure in this notebook.
plt.style.use("default")
# Rejection-count thresholds, keyed first by significance level and then by the
# short name of the statistical test. Drawn as horizontal reference lines on the
# "rejected variables" figure.
# NOTE(review): the 0.01 row has an "es" key where the 0.05 row has "wsr";
# confirm whether that asymmetry is intended.
REJ_THR = dict(
    [
        (0.01, dict(ks=6, cvm=9, mw=9, es=8)),
        (0.05, dict(ks=11, cvm=16, mw=16, wsr=15)),
    ]
)

In [None]:
def plot_tests_failed(
    case_data: dict,
    idx: int,
    params: dict,
    fig_spec: dict,
    style: dict,
    stest: str,
):
    """
    Plot, per perturbed parameter, how many bootstrap tests were globally significant.

    One panel is drawn per entry of ``params``. Each panel has one semilog line per
    p-value correction method for the selected statistical test, plus a dashed
    horizontal reference line at ``alpha * n_iter``. The figure is saved to
    ``plt_nfailed_tests_{stest}_{idx}_a{alpha}.{ext}`` in the working directory,
    with the ``.`` in alpha replaced by ``p`` (e.g. ``a0p05``).

    Parameters
    ----------
    case_data : dict
        Maps parameter name -> list of ``CaseData`` objects (one per percent
        change). Metadata (``methods``, ``stests``, ``alpha``, ``n_iter``) is read
        from the first object of each list; the whole list is passed to
        ``pbr.ntests``.
    idx : int
        Index forwarded to ``pbr.ntests`` and embedded in the output filename.
        NOTE(review): its meaning is defined by ``pbr.ntests``; confirm there.
    params : dict
        Maps parameter name (a key of ``case_data``) -> human-readable label.
        Panels follow this dict's order. Must not be empty.
    fig_spec : dict
        ``"orient"`` (contains "horiz" for side-by-side panels, otherwise panels
        are stacked vertically), ``"width"`` (used for both panel dimensions),
        ``"dpi"`` and ``"ext"`` (output file extension).
    style : dict
        ``"colors"`` (indexed by panel), ``"lstyle"`` and ``"markers"`` (indexed
        by correction method); all three are cycled if shorter than needed.
        Also ``"linewidth"``, ``"label_fontsize"``, ``"legend_fontsize"``; the
        dict is passed on to ``pbr.style_axis``.
    stest : str
        Short name of the statistical test to plot (a key of the ``pbr.ntests``
        result and of ``CaseData.stests``).

    Returns
    -------
    matplotlib.figure.Figure
        The created figure.

    Raises
    ------
    ValueError
        If ``params`` is empty.
    """
    if not params:
        raise ValueError("params must contain at least one parameter to plot")

    nparams = len(params)
    # Decide the layout once so panel arrangement and axis labelling agree for
    # any orient string (not only "horiz"/"vert").
    horizontal = "horiz" in fig_spec["orient"].lower()
    if horizontal:
        nrows, ncols = 1, nparams
        figsize = (nparams * fig_spec["width"], fig_spec["width"])
    else:
        nrows, ncols = nparams, 1
        figsize = (fig_spec["width"], nparams * fig_spec["width"])

    # squeeze=False always yields a 2-D array of Axes, so flatten() also works
    # for a single panel (a bare Axes has no .flatten()).
    fig, axis = plt.subplots(
        nrows,
        ncols,
        figsize=figsize,
        dpi=fig_spec["dpi"],
        sharey=True,
        squeeze=False,
    )
    axis = axis.flatten()

    colors = style["colors"]
    lstyles = style["lstyle"]
    markers = style["markers"]

    for ixp, _param in enumerate(params):
        _meta = case_data[_param][0]
        _ax = axis[ixp]
        _ntests = pbr.ntests(case_data[_param], idx)[stest]

        for ixm, _method in enumerate(_meta.methods):
            _method_hum = f"{_meta.methods[_method]} - {_meta.stests[stest]}"
            _ax.semilogy(
                _ntests["pct_change"],
                _ntests[_method],
                color=colors[ixp % len(colors)],
                linestyle=lstyles[ixm % len(lstyles)],
                marker=markers[ixm % len(markers)],
                label=f"{params[_param]}: {_method_hum}",
            )
        _ax.set_ylim([1, 2500])

        # Reference line at alpha * n_iter bootstrap iterations (presumably the
        # count expected to be significant by chance alone).
        _alpha = _meta.alpha
        _alphapct = _alpha * _meta.n_iter
        _ax.axhline(
            _alphapct,
            color="k",
            ls="--",
            lw=style["linewidth"],
            label=f"{_alpha * 100:.0f}% of tests ({_alphapct:.0f})",
        )

        # Only label the outer edges of the shared axes.
        if horizontal:
            x_check = True
            y_check = ixp == 0
        else:
            x_check = ixp == nparams - 1
            y_check = True

        if x_check:
            _ax.set_xlabel("Parameter Change [%]", fontsize=style["label_fontsize"])

        if y_check:
            _ax.set_ylabel(
                f"Tests with global\nsignificance [p < {_alpha:.2f}]",
                fontsize=style["label_fontsize"],
            )

        # Built after axhline so the reference line appears in the legend.
        _ax.legend(fontsize=style["legend_fontsize"])

        for _axis_obj in (_ax.xaxis, _ax.yaxis):
            _axis_obj.set_major_formatter(ScalarFormatter(useOffset=True))

        pbr.style_axis(_ax, style)

    fig.tight_layout()
    # The filename uses the alpha of the last plotted parameter (params is
    # non-empty, so _meta is always bound here).
    _alphastr = f"{_meta.alpha:.02f}".replace(".", "p")
    fig.savefig(f"plt_nfailed_tests_{stest}_{idx}_a{_alphastr}.{fig_spec['ext']}")
    return fig

In [None]:
ext = "png"
vert = False
run_len = "1year"
rolling = 12
test_size = 30
niter = 1000
alpha = 0.05

assert 0.0 < alpha < 1.0, f"ALPHA {alpha} not in (0, 1)"

# Percent changes applied to each perturbed parameter.
pcts = {
    "effgw_oro": [1, 10, 20, 30, 40, 50],
    "clubb_c1": [1, 3, 5, 10],
    "zmconv_c0_ocn": [0.5, 1, 3, 5],
}
params_hum = {
    "effgw_oro": "GW Orog",
    "clubb_c1": "CLUBB C1",
    "zmconv_c0_ocn": "ZM Conv C0-Ocean",
}

# All (control, test) case pairs and their human-readable labels, in load order.
# The perturbed-parameter cases come first; the control-vs-control case is
# appended last.
cases = []
cases_hum = []
case_data = {}
_casefile = (
    "bootstrap_output.{runlen}_{rolling}avg_"
    "ts{test_size}.{case[0]}_{case[1]}_n{niter}.nc"
)


def _bootstrap_path(case: tuple) -> Path:
    """Return the absolute path of the bootstrap output file for a (control, test) pair."""
    return Path(
        "../bootstrap_data",
        _casefile.format(
            runlen=run_len,
            rolling=rolling,
            test_size=test_size,
            case=case,
            niter=niter,
        ),
    ).resolve()


# Add the cases for all the tested parameters.
for _param, _param_pcts in pcts.items():
    # ZM convection perturbations are paired with the "ctl-miller" control run.
    _ctlcase = "ctl-miller" if "zmconv" in _param else "ctl"
    case_data[_param] = []
    for _pct in _param_pcts:
        _pct_str = f"{_pct:.1f}".replace(".", "p")
        _case = (_ctlcase, f"{_param}-{_pct_str}pct")
        cases.append(_case)
        # Keep one decimal only for sub-1% changes (e.g. 0.5%).
        _pct_fmt = ".1f" if _pct < 1 else ".0f"
        cases_hum.append(("Control", f"{params_hum[_param]} {_pct:{_pct_fmt}}%"))
        case_data[_param].append(
            pbr.CaseData(_bootstrap_path(_case), *_case, alpha, pct_change=_pct)
        )

fig_width = 12.5 / 2.54  # 12.5 cm in inches
fig_spec = {
    "dpi": 300,
    "width": fig_width,
    "height": fig_width / 3.75,
    "ext": ext,
    "orient": "vert" if vert else "horiz",
}

# Control-vs-control case. Its file must be built from this ("ctl", "ctl") pair,
# not from cases[0], which is the first perturbed-parameter case.
_case = ("ctl", "ctl")
cases.append(_case)
_casehum = ("Control", "Control")
cases_hum.append(_casehum)
ctl_file = _bootstrap_path(_case)

case_data["ctl"] = pbr.CaseData(ctl_file, *_case, alpha)

In [None]:
# Shared styling passed to plot_tests_failed: line width, font sizes, and the
# colour / line-style / marker sequences it indexes per panel and per method.
style = dict(
    linewidth=1.0,
    label_fontsize=12,
    legend_fontsize=8,
    tick_fontsize=10,
    colors=["C1", "C4", "C6"],
    lstyle=["-", "--", "-.", ":"],
    markers=["o", "x", "+", "h", ".", "*"],
)
# Index forwarded to pbr.ntests / plot_tests_failed (other values tried: 0, 1).
_ti = 2
# One saved figure per statistical test known to the control case.
for stest in case_data["ctl"].stests:
    plot_tests_failed(
        case_data=case_data,
        idx=_ti,
        stest=stest,
        params=params_hum,
        fig_spec=fig_spec,
        style=style,
    )

In [None]:
def _collect_rejections(case_data: dict, qtile: float = 95.0):
    """
    Tabulate the final bootstrap results per perturbed parameter, test and method.

    Parameters
    ----------
    case_data : dict
        Maps parameter name -> list of CaseData objects (one per percent change).
        The "ctl" entry is skipped.
    qtile : float
        Key into ``CaseData.reject_qtiles[stest]`` used to build ``nrej_vars``.

    Returns
    -------
    nfailed : dict
        ``nfailed[param][stest][method]`` -> list holding the last element of
        ``case.ntests[stest][method]`` for each case, in case order.
    nrej_vars : dict
        ``nrej_vars[param][stest][method]`` -> list holding the last element of
        ``case.reject_qtiles[stest][qtile][method]`` for each case.
    pct_change : dict
        ``pct_change[param]`` -> list of each case's ``pct_change``.
    """
    nfailed, nrej_vars, pct_change = {}, {}, {}
    for param, param_cases in case_data.items():
        if param == "ctl":
            continue
        pct_change[param] = [case.pct_change for case in param_cases]
        nfailed[param] = {}
        nrej_vars[param] = {}
        if not param_cases:
            continue
        # Test and method names are read once from the first case; every case
        # of a parameter is assumed to provide the same ones.
        first = param_cases[0]
        for stest in first.ntests:
            methods = list(first.ntests[stest])
            nfailed[param][stest] = {
                method: [case.ntests[stest][method][-1] for case in param_cases]
                for method in methods
            }
            nrej_vars[param][stest] = {
                method: [
                    case.reject_qtiles[stest][qtile][method][-1]
                    for case in param_cases
                ]
                for method in methods
            }
    return nfailed, nrej_vars, pct_change


nfailed, nrej_vars, pct_change = _collect_rejections(case_data)

In [None]:
# Line style for each p-value correction method.
lstyle = dict(uncor="-", fdr_bh="--", fdr_by="-.", bonferroni=":")

# Named hex colours available to the figures below.
ornl_colours = dict(
    green="#007833",
    bgreen="#84b641",
    orange="#DE762D",
    teal="#1A9D96",
    red="#88332E",
    blue="#5091CD",
    gold="#FECB00",
)

# Colour and marker for each statistical test, keyed by its short name.
colors = {
    _test_name: ornl_colours[_shade]
    for _test_name, _shade in (
        ("ks", "green"),
        ("cvm", "blue"),
        ("wsr", "red"),
        ("mw", "orange"),
    )
}
markers = dict(ks="o", cvm=".", wsr="H", mw="x")
# Figure 1: number of globally significant bootstrap tests vs. parameter change.
# One panel per perturbed parameter; one line per (statistical test, correction
# method) pair.
fig, axes = plt.subplots(
    1, len(nfailed), figsize=(5 * len(nfailed), 4), sharey=True, squeeze=False
)
axes = axes.flatten()

for idx, param in enumerate(nfailed):
    ax = axes[idx]
    for stest, by_method in nfailed[param].items():
        for method, counts in by_method.items():
            ax.semilogy(
                pct_change[param],
                counts,
                color=colors[stest],
                ls=lstyle[method],
                marker=markers[stest],
                label=f"{stest} - {method}",
            )
    ax.set_title(f"{params_hum[param]}")
    ax.set_ylim([10, 1500])
    ax.set_xlabel("Parameter Change [%]", fontsize=style["label_fontsize"])
    ax.grid(visible=True, ls="--", lw=0.5)
    # Reference line at niter * alpha tests.
    ax.axhline(niter * alpha, color="k", ls=":")

    # The y axis is shared, so only the first panel gets a label.
    if idx == 0:
        ax.set_ylabel(
            f"Tests with global\nsignificance [p < {case_data[param][0].alpha:.2f}]",
            fontsize=style["label_fontsize"],
        )

    for _axis_obj in (ax.xaxis, ax.yaxis):
        _axis_obj.set_major_formatter(ScalarFormatter(useOffset=True))

# A single legend on the last panel, drawn once with the configured font size.
_ = axes[-1].legend(fontsize=style["legend_fontsize"])


# Figure 2: number of rejected variables (reject_qtiles entry at the 95.0 key)
# vs. parameter change, for the uncorrected and FDR-BH corrected tests. Dashed
# horizontal lines mark each test's REJ_THR threshold.
fig, axes = plt.subplots(
    1, len(nfailed), figsize=(5 * len(nfailed), 4), sharey=True, squeeze=False
)
axes = axes.flatten()

# Thresholds for the configured alpha. Tests (or alpha values) without an entry
# in REJ_THR simply get no reference line instead of raising KeyError.
rej_thr = REJ_THR.get(alpha, {})

for idx, param in enumerate(nfailed):
    ax = axes[idx]
    for stest in nfailed[param]:
        for method in ("uncor", "fdr_bh"):
            ax.plot(
                pct_change[param],
                nrej_vars[param][stest][method],
                color=colors[stest],
                ls=lstyle[method],
                label=f"{stest} - {method}",
                marker=".",
            )
        if stest in rej_thr:
            ax.axhline(rej_thr[stest], color=colors[stest], ls="--", lw=1, alpha=0.7)

    ax.set_title(f"{params_hum[param]}")
    ax.set_xlabel("Parameter Change [%]", fontsize=style["label_fontsize"])
    ax.grid(visible=True, ls="--", lw=0.5)

axes[0].set_ylabel("Rejected variables", fontsize=style["label_fontsize"])
_ = axes[-1].legend(fontsize=style["legend_fontsize"])


In [None]:
# Figure 3: difference in globally significant test counts between a corrected
# method and the uncorrected tests (corrected minus uncorrected), per perturbed
# parameter. The "wsr" test is left out.
# Uses its own line-style dict so the global `lstyle` (which includes "uncor")
# is not overwritten.
diff_lstyle = {"fdr_bh": "-", "fdr_by": ":", "bonferroni": "-."}
diff_methods = ["fdr_bh"]  # add "fdr_by" / "bonferroni" to compare them too
skip_tests = {"wsr"}

fig, axes = plt.subplots(
    1, len(nfailed), figsize=(5 * len(nfailed), 4), sharey=True, squeeze=False
)
axes = axes.flatten()

for idx, param in enumerate(nfailed):
    ax = axes[idx]
    for stest, by_method in nfailed[param].items():
        if stest in skip_tests:
            continue
        uncor_counts = np.asarray(by_method["uncor"])
        for method in diff_methods:
            ax.plot(
                pct_change[param],
                np.asarray(by_method[method]) - uncor_counts,
                color=colors[stest],
                label=f"{stest} - {method}",
                ls=diff_lstyle[method],
            )
    ax.set_title(f"{params_hum[param]}")
    ax.set_xlabel("Parameter Change [%]", fontsize=style["label_fontsize"])
    ax.grid(visible=True, ls="--", lw=0.5)
    ax.axhline(0, color="k", ls=":")

axes[0].set_ylabel(
    "Corrected - uncorrected\nsignificant tests", fontsize=style["label_fontsize"]
)
# Trailing semicolon suppresses the Legend repr in the cell output.
axes[-1].legend(fontsize=style["legend_fontsize"]);