In [None]:
import os
import pickle
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
# Directory holding the per-file classifier predictions (read with pd.read_csv
# below). Its name is also inspected by the plotting cells to pick the tuned cut.
filepath = "/DATA_MASTER_PATH/23_hike_pinunu-background/2308_zoptical-zanalyze_final_vars/pred_norm_conv1plus_withgencls_FINAL/"


def filesel(s):
    """Return True if file name ``s`` belongs to a sample that should be loaded.

    Kept: signal files (containing "sig_mal") and 2pi-background files
    (containing "bkg_2p0_mal"). Other selections used previously:
    ``"_0." in s``, or accepting every file (``True``).
    """
    wanted_tags = ("sig_mal", "bkg_2p0_mal")
    return any(tag in s for tag in wanted_tags)

# List the directory once and sort it: os.listdir returns names in arbitrary
# order, and histsum reads only a capped number of files, so sorting keeps the
# selected files (and therefore the histograms) reproducible between runs.
candidate_files = sorted(
    s for s in os.listdir(filepath)
    if os.path.isfile(os.path.join(filepath, s)) and filesel(s)
)
files_2pi = [s for s in candidate_files if "2p0_mal" in s]
files_sig = [s for s in candidate_files if "sig_mal" in s]
# NOTE(review): with the current filesel, a "lambda" file is kept only if its
# name also contains "sig_mal" or "bkg_2p0_mal", so this list is usually empty;
# extend filesel before switching the background to files_lam.
files_lam = [s for s in candidate_files if "lambda" in s]

# Background sample used in the rest of the notebook.
files_bkg = files_2pi

In [None]:
def _n_decays_from_name(fname):
    """Return the number of decays encoded in a prediction file name.

    Every ``_<int>M`` / ``_<int>G`` token immediately followed by ``_``
    (e.g. ``_5M_``, ``_200M_``, ``_20G_``) contributes ``<int> * 1e6`` or
    ``<int> * 1e9``; a name without such a token yields 0.
    """
    scale = {"M": 1e6, "G": 1e9}
    return sum(
        int(count) * scale[unit]
        for count, unit in re.findall(r"_(\d+)([MG])(?=_)", fname)
    )


def histsum(files, path=None, max_files=100, columns=("pred0", "W")):
    """Concatenate prediction CSV files and count the decays they represent.

    Parameters
    ----------
    files : iterable of str
        File names, relative to ``path``, read in the given order.
    path : str, optional
        Directory containing the files; defaults to the notebook-level
        ``filepath``.
    max_files : int or None
        Read at most this many files (default 100, as before); None reads all.
    columns : iterable of str
        Columns kept from each file when present. "pred0" is always kept and
        must exist in every file; "W" (per-event weight) is kept when the file
        has it. NOTE(review): if only some files carry "W", the missing
        entries become NaN after concatenation.

    Returns
    -------
    preds : pandas.DataFrame
        Rows of all files read, renumbered 0..N-1.
    n_decays : float
        Decays encoded in the names of the files that were actually read, so
        that normalisations of the form ``constant / n_decays`` match the
        events in ``preds`` (previously files beyond the cap were counted
        but not read).

    Raises
    ------
    ValueError
        If there is no file to read.
    KeyError
        If a file has no "pred0" column.
    """
    if path is None:
        path = filepath  # notebook-level prediction directory
    files = list(files)
    selected = files if max_files is None else files[:max_files]
    if not selected:
        raise ValueError("histsum: no prediction files to read")
    if len(selected) < len(files):
        print("reading only the first %d of %d files" % (len(selected), len(files)))

    frames = []
    n_decays = 0
    for ifile, fname in enumerate(selected):
        n_new = _n_decays_from_name(fname)
        if n_new == 0:
            # A file without a decay token silently skews the normalisation.
            print("warning: no decay count found in file name %s" % fname)
        n_decays += n_new
        print("opened %d/%d, with %d new decays" % (ifile + 1, len(selected), n_new))

        frame = pd.read_csv(os.path.join(path, fname))
        if "pred0" not in frame.columns:
            raise KeyError("%s has no 'pred0' column" % fname)
        keep = ["pred0"] + [c for c in columns if c != "pred0" and c in frame.columns]
        frames.append(frame[keep])

    # Concatenate once (not once per file) and renumber rows so the result
    # has no duplicate index labels.
    preds = pd.concat(frames, ignore_index=True)
    return preds, n_decays
    
preds_bkg, n_decays_bkg = histsum(files_bkg)
preds_sig, n_decays_sig = histsum(files_sig)

# Per-event scale factors of the form constant / (number of decays loaded).
# NOTE(review): the provenance of 15483 (signal) and 446e9 / 8.22e13
# (2pi / other background) is not recorded here; document it.
normfact_sig = 15483 / n_decays_sig
if files_bkg == files_2pi:
    normfact_bkg = 446e9 / n_decays_bkg
else:
    normfact_bkg = 8.22e13 / n_decays_bkg

In [None]:
# Older prediction files carry no per-event weight column "W": give every
# event unit weight so the weighted histograms further down still work.
for frame_name, frame in (("preds_bkg", preds_bkg), ("preds_sig", preds_sig)):
    if "W" not in frame.columns:
        print("manually adding weights to " + frame_name)
        frame["W"] = 1

# Optional: placeholder (-9999) columns for the final phase-space variables,
# for old prediction files that lack them. Disabled; set the flag to enable.
add_dummy_phase_space_vars = False
if add_dummy_phase_space_vars:
    for var in ("Vertex_xRec_Z", "Vertex_pRecPi_T"):
        for frame_name, frame in (("preds_bkg", preds_bkg), ("preds_sig", preds_sig)):
            if var not in frame.columns:
                print("manually adding " + var + " to " + frame_name)
                frame[var] = -9999

In [None]:
# Load the two trained BDT classifiers (signal vs 2pi, signal vs lambda).
# NOTE(review): pickle.load can execute arbitrary code stored in the file;
# only point these paths at pickles you produced yourself.
fileclsnames = [
    "/eos/user/m/msoldani/succo/postdocs/23-25_lnf/hike_sensitivity/pinunu_vs_background/classification/23_09_bdt_first_results_SHORTBL/signal_vs_2pi_normalised/bdt_ab.pickle",
    "/eos/user/m/msoldani/succo/postdocs/23-25_lnf/hike_sensitivity/pinunu_vs_background/classification/23_09_bdt_first_results_SHORTBL/signal_vs_lambda_normalised/bdt_ab.pickle",
]


def _load_pickle(path):
    """Unpickle and return the object stored at ``path``."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


clss = [_load_pickle(name) for name in fileclsnames]
print("classifiers loaded")

In [None]:
# Baseline threshold taken from the first classifier. Some pickles store a
# tuple whose first element is the classifier mapping, so unwrap that case.
cls = clss[0]
if type(cls) is tuple:
    cls = cls[0]
output_cut = cls["output_cut"]
cut0 = output_cut[output_cut["used_for_evaluation"]]

# Scan 50 thresholds from the original cut up to 0.510 and count events on
# each side of every threshold. Counts are unweighted; sum the "W" column of
# the selected rows instead for weighted yields.
cuts = np.linspace(cut0, 0.510, 50)
pred0_bkg = preds_bkg["pred0"].to_numpy()
pred0_sig = preds_sig["pred0"].to_numpy()
score_cut_0 = [int(np.count_nonzero(pred0_bkg < x)) for x in cuts]  # background below x
score_cut_1 = [int(np.count_nonzero(pred0_sig > x)) for x in cuts]  # signal above x
score_cut_2 = [int(np.count_nonzero(pred0_bkg > x)) for x in cuts]  # background above x

In [None]:
# Classifier-output ("pred0") distributions of background and signal, each
# weighted by W times its normalisation factor; top panel linear, bottom log.
# Vertical lines: the classifier's original cut (black) and, for the
# "_norm_conv1plus_withgencls" predictions, the re-tuned cut (red).
# Earlier prediction sets used 0.5018 / 0.5030 / 0.5027 ("_norm") and
# 0.5055 ("_norm_conv1plus").
if "_norm_conv1plus_withgencls" in filepath:
    if "SHORTBL" in filepath:
        new_cut, new_cut_label = 0.5065, "new cut 0.5065"
    else:
        new_cut, new_cut_label = 0.5030, "new cut 0.5030"
else:
    new_cut, new_cut_label = None, None

fig, (ax_lin, ax_log) = plt.subplots(nrows=2, figsize=(7, 5))
w_bkg = preds_bkg["W"] * normfact_bkg
w_sig = preds_sig["W"] * normfact_sig

preds_bkg["pred0"].hist(ax=ax_lin, bins=100, histtype="step", label="true background", weights=w_bkg)
preds_sig["pred0"].hist(ax=ax_lin, bins=100, histtype="step", label="true signal", weights=w_sig)

ax_log.set_yscale("log")
preds_bkg["pred0"].hist(ax=ax_log, bins=100, histtype="step", weights=w_bkg)
preds_sig["pred0"].hist(ax=ax_log, bins=100, histtype="step", weights=w_sig)

ax_lin.axvline(cut0, color="black", label="original cut")
if new_cut is not None:
    ax_lin.axvline(new_cut, color="red", ls="-", label=new_cut_label)

ax_log.axvline(cut0, color="black")
if new_cut is not None:
    ax_log.axvline(new_cut, color="red", ls="-")

fig.legend(loc="upper left")
fig.tight_layout()

In [None]:
# Events above each threshold x, relative to the count above the original cut
# (cuts[0] == cut0), for background and signal. Left panel linear, right log.
# Earlier prediction sets used 0.5018 / 0.5030 / 0.5027 ("_norm") and
# 0.5055 ("_norm_conv1plus") as new cuts.
if "_norm_conv1plus_withgencls" in filepath:
    short_bl = "SHORTBL" in filepath
    new_cut = 0.5065 if short_bl else 0.5030
    new_cut_label = "new cut 0.5065" if short_bl else "new cut 0.5030"
    new_cut_xlim = (0.4995, 0.507) if short_bl else (0.4995, 0.504)
else:
    new_cut = None

fig, axs = plt.subplots(figsize=(10, 5), ncols=2)
bkg_above_rel = np.array(score_cut_2) / score_cut_2[0]
sig_above_rel = np.array(score_cut_1) / score_cut_1[0]

for panel, ax in enumerate(axs):
    is_left = panel == 0  # only the left panel contributes legend entries
    ax.axvline(cut0, color="black", label="old cut" if is_left else None)
    if new_cut is not None:
        ax.axvline(new_cut, color="red", ls="-", label=new_cut_label if is_left else None)
        ax.set_xlim(new_cut_xlim)
    if not is_left:
        ax.set_yscale("log")
    ax.plot(cuts, bkg_above_rel, color="C0", label="bkg > x" if is_left else None)
    ax.plot(cuts, sig_above_rel, color="C1", label="sig > x" if is_left else None)
    ax.grid()

fig.legend(loc="upper left")
fig.tight_layout()