In [None]:
import numpy as np
import matplotlib.pyplot as plt
from statsmodels.stats import multitest as smm
from scipy import stats as sts

In [None]:
# ---- Experiment configuration ----
SEED = 0  # fixed seed so Restart & Run All reproduces every number and figure

n_y = 40
n_x = 40

n_vars = 40
n_ens = 30
n_diffs = 100
REJECT_THR = 0.05

# Seed the global legacy generator: every np.random.* call below (and in later
# cells) becomes reproducible when the notebook is run top to bottom.
np.random.seed(SEED)

# One constant offset per variable, broadcast over (ensemble, y, x).
var_offset = np.random.randn(n_vars)[:, None, None, None] * 1e2
# Systematic shift added to every value of the second dataset.
sys_diff = np.linspace(1e-5, 1, n_diffs)

nfail = []
nfail_corr = []
all_mean_1 = []
all_mean_2 = []
rmsd = []
ks_stats_0 = None

for _diff in sys_diff:
    # Both arrays are (n_vars, n_ens, n_y, n_x); data2 carries the extra shift `_diff`.
    data1 = np.random.randn(n_vars, n_ens, n_y, n_x) + var_offset
    data2 = np.random.randn(n_vars, n_ens, n_y, n_x) + var_offset + _diff

    # Spatial mean of every ensemble member -> (n_vars, n_ens).
    # Computed once here and reused by every statistic below.
    ens_mean_1 = data1.mean(axis=(-2, -1))
    ens_mean_2 = data2.mean(axis=(-2, -1))

    # Per-variable two-sample t-test across the ensemble axis.
    # t_stats is (n_vars, 2) with columns (statistic, p-value).
    t_res = sts.ttest_ind(ens_mean_1, ens_mean_2, axis=1)
    t_stats = np.column_stack([t_res.statistic, t_res.pvalue])

    # Per-variable two-sample KS test -> (n_vars, 2) with columns (statistic, p-value).
    ks_stats = np.array(
        [sts.mstats.ks_2samp(m1, m2) for m1, m2 in zip(ens_mean_1, ens_mean_2)]
    )

    # (n_vars, 2, n_ens): ensemble means of both datasets per variable; kept for
    # the optional per-variable histograms later in the notebook.
    means = np.stack([ens_mean_1, ens_mean_2], axis=1)

    if ks_stats_0 is None:
        # KS results of the first (smallest) shift, inspected in the BH table below.
        ks_stats_0 = ks_stats

    # Holm step-down correction of the KS p-values. The corrected p-values are
    # what gets compared to REJECT_THR below; `rej_corr` itself is not used.
    rej_corr, pval_corr, _, _ = smm.multipletests(
        ks_stats[:, 1], alpha=REJECT_THR, method="holm", is_sorted=False
    )

    # Grand mean over all variables and ensemble members.
    all_mean_1.append(ens_mean_1.mean())
    all_mean_2.append(ens_mean_2.mean())
    # Root-*mean*-square deviation between the two fields (the original summed
    # the squares, which yields a root-sum-square that scales with array size).
    rmsd.append(np.sqrt(np.mean((data1 - data2) ** 2)))

    nfail.append((ks_stats[:, 1] < REJECT_THR).sum())
    nfail_corr.append((pval_corr < REJECT_THR).sum())

nfail = np.array(nfail)
nfail_corr = np.array(nfail_corr)
all_mean_1 = np.array(all_mean_1)
all_mean_2 = np.array(all_mean_2)
rmsd = np.array(rmsd)

In [None]:
# Figure 1: fraction of variables failing the KS test vs. the imposed shift,
# and the grand means of the two datasets.
fig, (ax_rate, ax_mean) = plt.subplots(1, 2, figsize=(9, 5))
ax_rate.plot(sys_diff, nfail / n_vars, label="N Fail")
ax_rate.plot(sys_diff, nfail_corr / n_vars, label="N Fail [corr]")
# Reference line at the nominal rejection level.
ax_rate.axhline(REJECT_THR, color="k", ls="--", label="REJECT_THR")
ax_rate.set(
    xlabel="Systematic shift",
    ylabel="Fraction of variables failing",
    title="KS failures vs. shift",
)
ax_rate.legend()

ax_mean.plot(sys_diff, all_mean_1, color="C2", label="Data 1")
ax_mean.plot(sys_diff, all_mean_2, color="C3", label="Data 2")
ax_mean.set(xlabel="Systematic shift", ylabel="Grand mean", title="Grand means")
ax_mean.legend()
fig.tight_layout()

# Figure 2: how the failure fraction relates to the realised grand-mean gap.
mean_gap = np.abs(all_mean_2 - all_mean_1)
fig2, (ax_sq, ax_gap) = plt.subplots(1, 2, figsize=(9, 5))
ax_sq.plot(sys_diff, (all_mean_2 - all_mean_1) ** 2)
ax_sq.set(
    xlabel="Systematic shift",
    ylabel="(mean 2 - mean 1)$^2$",
    title="Squared grand-mean difference",
)

ax_gap.plot(mean_gap, nfail / n_vars, "x", label="N Fail")
ax_gap.plot(mean_gap, nfail_corr / n_vars, ".", label="N Fail [corr]")
ax_gap.set(
    xlabel="|mean 2 - mean 1|",
    ylabel="Fraction of variables failing",
    title="Failures vs. mean gap",
)
ax_gap.legend()
fig2.tight_layout()

In [None]:
# Per-rank comparison of the first (smallest-shift) run's sorted KS p-values
# p_(i) against the Benjamini-Hochberg thresholds REJECT_THR * i / m.
# Note: the BH procedure rejects every hypothesis up to the LARGEST rank that
# passes, so a row printed False can still be rejected under BH.
pvals = np.sort(ks_stats_0[:, 1])
p_n = REJECT_THR * (np.arange(pvals.shape[0]) + 1) / pvals.shape[0]
print("p_(i)\tthresh passes")
for p_sorted, p_thr in zip(pvals, p_n):
    print(f"{p_sorted:.4f}\t{p_thr:.4f} {p_sorted <= p_thr}")

In [None]:
# Optional diagnostic: per-variable histograms of the two ensembles' spatial
# means for the LAST shift of the simulation loop (means, t_stats, ks_stats and
# pval_corr all hold that final iteration's values). Title colour is red when
# the Holm-corrected KS p-value falls below REJECT_THR.
PLOT_PER_VARIABLE = False  # set True to draw the grid

if PLOT_PER_VARIABLE:
    n_panels = means.shape[0]
    n_cols = 8
    # Size the grid from the number of variables; the old fixed 2x5 grid
    # raised IndexError for more than 10 variables.
    n_rows = max(1, -(-n_panels // n_cols))
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(2.4 * n_cols, 2.4 * n_rows), squeeze=False
    )
    axes = axes.flatten()
    pass_col = "green"
    fail_col = "red"

    for _vix in range(n_panels):
        # means[_vix] is (2, n_ens); transpose so each dataset is one histogram series.
        axes[_vix].hist(means[_vix].T)
        _col = fail_col if pval_corr[_vix] < REJECT_THR else pass_col
        # Both numbers are p-values (column 1), not test statistics.
        axes[_vix].set_title(
            f"p_t={t_stats[_vix, 1]:.2e} p_ks={ks_stats[_vix, 1]:.2e}",
            color=_col,
            fontsize="small",
        )
    for _ax in axes[n_panels:]:
        _ax.set_visible(False)
    fig.tight_layout()

In [None]:
# Toy values for the threshold demo in the next cell.
# NOTE(review): |N(0, 1)| draws are not valid p-values (they can exceed 1);
# p-values under the null would be Uniform(0, 1). Confirm this is intended.
pvals = np.abs(np.random.randn(n_vars))
fig, ax = plt.subplots()
ax.hist(pvals, edgecolor="k")
ax.set(xlabel="value", ylabel="count", title="Toy 'p-values' drawn as |N(0, 1)|");

In [None]:
# Side-by-side listing: each sorted toy value next to its rank-scaled
# threshold 0.01 * i / m, plus whether it falls at or below that threshold.
n_pvals = pvals.shape[0]
p_n = 0.01 * np.arange(1, n_pvals + 1) / n_pvals
pvals_s = sorted(pvals)
for p_sorted, p_thr in zip(pvals_s, p_n):
    print(f"{p_sorted:.4f} \t {p_thr:.4f} {p_sorted <= p_thr}")