In [None]:
# Move the working directory three levels up (presumably the repository root) so the
# `ml_hep_sim...` imports and the relative `saved` output path used later resolve.
# NOTE: the path is relative — re-running this cell moves up again; run it once per kernel.
%cd ../../../

In [None]:
import os

# Output directory (relative to the working directory set above) for the figures saved below.
saved = "ml_hep_sim/notebooks/article_notebooks/saved/"
# savefig does not create missing directories; make sure the target exists on a fresh checkout.
os.makedirs(saved, exist_ok=True)

In [None]:
from ml_hep_sim.notebooks.article_notebooks.test_runs import *
from ml_hep_sim.pipeline.pipes import *
from ml_hep_sim.pipeline.blocks import *

from ml_hep_sim.plotting.style import style_setup, set_size
from ml_hep_sim.stats.stat_plots import two_sample_plot

from ml_hep_sim.data_utils.higgs.process_higgs_dataset import LATEX_COLNAMES, LOG_BIN_RANGES

import matplotlib 
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
from tqdm import tqdm
import copy

# Apply the project's plotting defaults before any figure is created. Both helpers come from
# ml_hep_sim.plotting.style and are called only for their side effects (presumably global
# matplotlib settings such as figure size and palette) — their relative order may matter.
# NOTE(review): `seaborn_pallete` is spelled as in the call; keep it matching style_setup's signature.
set_size()
style_setup(seaborn_pallete=True)

In [None]:
# Training-set sizes to scan: 10 log-spaced values from 10^4 to 10^6, truncated to int.
num_train = np.logspace(start=4, stop=6, num=10).astype(int)

In [None]:
# Build the flow pipelines for every training-set size in `num_train`. Passing `num_train`
# itself (instead of re-typing the logspace expression) keeps the models in sync with the
# `num_train[i]` legend labels used by the plots below.
# The three leading positional flags are passed as False, exactly as before; see the
# run_*_pipeline definitions for their meaning.
# NOTE: only `maf_train` is consumed by the cells below; `glow_train` and `spline_train`
# are built but not used in this notebook.
glow_train = run_glow_pipeline(
    False,
    False,
    False,
    sig=False,
    num_flows=[10],
    num_train=num_train,
)

maf_train = run_maf_pipeline(
    False,
    False,
    False,
    sig=False,
    use_mog=True,
    use_maf=True,
    num_mogs=[10],
    num_train=num_train,
)

spline_train = run_spline_pipeline(
    False,
    False,
    False,
    sig=False,
    num_splines=[32],
    name_str="",
    num_train=num_train,
)

In [None]:
# For every trained MAF model: generate N samples, load the reference data and run the
# per-feature statistical tests. Each model is evaluated k times so the statistics can be
# averaged over generation noise in the next cell.
N = 1 * 10 ** 5  # generated samples per run
device = "cuda"

res = []  # res[i][j]: StatTestRunnerBlock results for model i (presumably num_train[i]), repeat j
k = 5     # repeats per model

for i, p in enumerate(tqdm(maf_train, desc="MAF models")):
    # The trained pipeline's config/trainer blocks do not change between repeats — unpack once.
    x_ConfigBuilderBlock, _, _, x_ModelTrainerBlock = p.pipes

    res_k = []
    for j in range(k):
        x1 = ModelLoaderBlock(device=device)(x_ConfigBuilderBlock, x_ModelTrainerBlock)

        x2 = DataGeneratorBlock(N, model_type="flow", chunks=10, device=device)(x1)
        x3 = GeneratedDataVerifierBlock(save_data=False, device=device)(x1, x2)

        x4 = DatasetBuilderBlock()(x_ConfigBuilderBlock)
        x5 = ReferenceDataLoaderBlock(device=device)(x4)

        x6 = StatTestRunnerBlock(use_results=False, add_dim=False)(x5, x3)

        # A fresh pipeline per repeat, so every repeat draws its own generated sample.
        pipe = Pipeline()
        pipe.compose(x1, x2, x3, x4, x5, x6)
        pipe.fit()

        res_k.append(pipe.pipes[-1].results)

    res.append(res_k)

In [None]:
def _stack_runs(runs, table, column):
    """Stack one result column over repeated runs.

    ``runs`` is one entry of ``res`` (the list of per-repeat results for one model);
    ``run[table][column].to_numpy()`` is taken from each run. Returns a float array of
    shape (n_features, n_runs), i.e. one column per repeat.
    """
    return np.stack(
        [np.asarray(run[table][column].to_numpy(), dtype=float) for run in runs],
        axis=1,
    )


# One result set per training-set size is required: the plots label column i with num_train[i].
assert len(res) == len(num_train), "expected one result set per training-set size"

mean_res_chi2 = []
mean_res_ks = []
mean_res_chi2_crit = []
mean_res_ks_crit = []

for res_k in res:
    # run[0] holds the chi2 table and run[1] the KS table; both carry a "crit" column.
    # Average each per-feature value over the repeats (shapes are taken from the data,
    # not hard-coded feature/repeat counts).
    mean_res_chi2.append(_stack_runs(res_k, 0, "chi2").mean(axis=1))
    mean_res_ks.append(_stack_runs(res_k, 1, "ks").mean(axis=1))
    mean_res_chi2_crit.append(_stack_runs(res_k, 0, "crit").mean(axis=1))
    mean_res_ks_crit.append(_stack_runs(res_k, 1, "crit").mean(axis=1))

In [None]:
# Arrange the averaged statistics as (feature, training-set size): row = feature, column = N train.
s_chi2, s_ks, s_chi2_crit, s_ks_crit = (
    np.array(per_size).T
    for per_size in (mean_res_chi2, mean_res_ks, mean_res_chi2_crit, mean_res_ks_crit)
)

In [None]:
# chi2 / chi2_crit per feature for every training-set size (rows: features, columns: N train).
# Ratios above 1 (dashed line) exceed the critical value.
chi2_ratio = s_chi2 / s_chi2_crit
n_features, n_sizes = chi2_ratio.shape
assert len(LATEX_COLNAMES) == n_features, "feature labels do not match the number of tested features"

# Order features by the ratio obtained with the last (largest) training-set size.
idx = np.argsort(chi2_ratio[:, -1])
ratio_sorted = chi2_ratio[idx]

x = np.arange(1, n_features + 1)
y_max = ratio_sorted.max(axis=1)  # largest ratio over all N, per feature
y_min = ratio_sorted.min(axis=1)  # smallest ratio over all N, per feature
y_ref = ratio_sorted[:, -1]       # ratio for the last training-set size (computed explicitly)

fig, ax = plt.subplots()

for c in range(n_sizes):
    rat = ratio_sorted[:, c]
    if c == n_sizes - 1:
        # Last training-set size: draw its full curve as the reference.
        ax.scatter(x, rat, facecolor="none", label=num_train[c], alpha=0.7, s=120, edgecolors="k", linewidths=2)
        ax.plot(x, rat, color="k", ls="-", lw=2, alpha=0.7)
    else:
        # Other sizes: mark only the features where this N attains the max / min ratio.
        max_rat = np.where(y_max == rat)[0]
        min_rat = np.where(y_min == rat)[0]
        ax.scatter(x[max_rat], rat[max_rat], facecolor="none", edgecolors=f"C{c}", label=num_train[c], s=120, linewidths=2, zorder=100)
        ax.scatter(x[min_rat], rat[min_rat], facecolor="none", edgecolors=f"C{c}", s=120, linewidths=2, zorder=100)

# Bands between the reference curve and the max (green) / min (red) envelopes.
ax.plot(x, y_max, c="g", ls="--", alpha=0.8, lw=2)
ax.fill_between(x, y_max, y_ref, alpha=0.1, color="g")

ax.plot(x, y_min, c="r", ls="--", alpha=0.8, lw=2)
ax.fill_between(x, y_min, y_ref, alpha=0.1, color="r")

ax.axhline(1, c="k", ls="--", alpha=0.8)
ax.set_yscale("log")
ax.set_ylim([0.5, 30])
ax.set_xlim([0, n_features + 1])
ax.set_xticks(x)
ax.set_xticklabels(np.array(LATEX_COLNAMES)[idx], rotation=90)
ax.legend(title="$N$ train", fontsize=14, ncol=2)
ax.set_title("MAFMADEMOG")
# Raw string: "\c" is an invalid escape sequence (SyntaxWarning on Python 3.12+).
ax.set_ylabel(r"$\chi^2/\chi^2_c$")
fig.tight_layout()
fig.savefig(saved + "mafmademog_N_chi2.pdf")

In [None]:
# rr = list(range(0, len(res), 1))
# # rr.append(9)

# for c, i in enumerate(rr):
#     rat = s_chi2[:, i] / s_chi2_crit[:, i]
    
#     if i == 0:
#         idx = np.argsort(rat)
#         y1 = rat[idx]
#         plt.plot(range(1, 19, 1), rat[idx], color=f"C{c}", lw=2)
#         plt.scatter(range(1, 19, 1), rat[idx], facecolor="none", label=num_train[i], alpha=1, s=120, edgecolors=f"C{c}", linewidths=2)
#     else:
#         plt.scatter(range(1, 19, 1), rat[idx], facecolor="none", label=num_train[i], alpha=1, s=120, edgecolors=f"C{c}", linewidths=2)
    
#     if i == len(rr) - 1:
#         y2 = rat[idx]
#         plt.plot(range(1, 19, 1), rat[idx], color=f"C{c}", ls='-', lw=2)

# plt.fill_between(range(1, 19, 1), y1, y2, alpha=0.1, color='k')
        
# plt.axhline(1, c='k', ls='--', alpha=0.8)
# plt.yscale("log")
# plt.ylim([0.5, 30])
# plt.xlim([0, 19])
# plt.xticks(range(1, len(LATEX_COLNAMES) + 1, 1), np.array(LATEX_COLNAMES)[idx], rotation=90)
# plt.legend(title="$N$ train", fontsize=14, ncol=2)
# plt.title("RQS")
# plt.ylabel("$\chi^2/\chi^2_c$")
# plt.tight_layout()
# plt.savefig(saved + "rqs_N_chi2.pdf")

In [None]:
# KS / KS_crit per feature for every training-set size (rows: features, columns: N train).
# Ratios above 1 (dashed line) exceed the critical value.
ks_ratio = s_ks / s_ks_crit
n_features, n_sizes = ks_ratio.shape
assert len(LATEX_COLNAMES) == n_features, "feature labels do not match the number of tested features"

# Order features by the ratio obtained with the last (largest) training-set size.
idx = np.argsort(ks_ratio[:, -1])
ratio_sorted = ks_ratio[idx]

x = np.arange(1, n_features + 1)
y_max = ratio_sorted.max(axis=1)  # largest ratio over all N, per feature
y_min = ratio_sorted.min(axis=1)  # smallest ratio over all N, per feature
# Reference curve computed explicitly, so it can never be a stale value left over from another cell.
y_ref = ratio_sorted[:, -1]

fig, ax = plt.subplots()

for c in range(n_sizes):
    rat = ratio_sorted[:, c]
    if c == n_sizes - 1:
        # Last training-set size: draw its full curve as the reference.
        ax.scatter(x, rat, facecolor="none", label=num_train[c], alpha=0.7, s=120, edgecolors="k", linewidths=2)
        ax.plot(x, rat, color="k", ls="-", lw=2, alpha=0.7)
    else:
        # Other sizes: mark only the features where this N attains the max / min ratio.
        max_rat = np.where(y_max == rat)[0]
        min_rat = np.where(y_min == rat)[0]
        ax.scatter(x[max_rat], rat[max_rat], facecolor="none", edgecolors=f"C{c}", label=num_train[c], s=120, linewidths=2, zorder=100)
        ax.scatter(x[min_rat], rat[min_rat], facecolor="none", edgecolors=f"C{c}", s=120, linewidths=2, zorder=100)

# Bands between the reference curve and the max (green) / min (red) envelopes.
ax.plot(x, y_max, c="g", ls="--", alpha=0.8, lw=2)
ax.fill_between(x, y_max, y_ref, alpha=0.1, color="g")

ax.plot(x, y_min, c="r", ls="--", alpha=0.8, lw=2)
ax.fill_between(x, y_min, y_ref, alpha=0.1, color="r")

ax.axhline(1, c="k", ls="--", alpha=0.8)
ax.set_yscale("log")
ax.set_ylim([0.5, 100])
ax.set_xlim([0, n_features + 1])
ax.set_xticks(x)
ax.set_xticklabels(np.array(LATEX_COLNAMES)[idx], rotation=90)
ax.legend(title="$N$ train", fontsize=14, ncol=2)
ax.set_ylabel("KS$/$KS$_c$")
ax.set_title("MAFMADEMOG")
fig.tight_layout()
fig.savefig(saved + "mafmademog_N_ks.pdf")

In [None]:
# rr = list(range(0, len(res), 1))
# # rr.append(9)

# for c, i in enumerate(rr):
#     rat = s_ks[:, i] / s_ks_crit[:, i]
    
#     if i == 0:
#         idx = np.argsort(rat)
#         y1 = rat[idx]
#         plt.plot(range(1, 19, 1), rat[idx], color=f"C{c}", lw=2)
#         plt.scatter(range(1, 19, 1), rat[idx], facecolor="none", label=num_train[i], alpha=1, s=120, edgecolors=f"C{c}", linewidths=2)
#     else:
#         plt.scatter(range(1, 19, 1), rat[idx], facecolor="none", label=num_train[i], alpha=1, s=120, edgecolors=f"C{c}", linewidths=2)
    
#     if i == len(rr) - 1:
#         y2 = rat[idx]
#         plt.plot(range(1, 19, 1), rat[idx], color=f"C{c}", ls='-', lw=2)

# plt.fill_between(range(1, 19, 1), y1, y2, alpha=0.1, color='k')
        
# plt.axhline(1, c='k', ls='--', alpha=0.8)
# plt.yscale("log")
# plt.ylim([0.5, 100])
# plt.xlim([0, 19])
# plt.xticks(range(1, len(LATEX_COLNAMES) + 1, 1), np.array(LATEX_COLNAMES)[idx], rotation=90)
# plt.legend(title="$N$ train", fontsize=14, ncol=2)
# plt.ylabel("KS$/$KS$_c$")
# plt.title("RQS")
# plt.tight_layout()
# plt.savefig(saved + "rqs_N_ks.pdf")