In [1]:
# import modules
import uproot, sys, time, math, pickle, os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import awkward as ak
from tqdm import tqdm
import seaborn as sns
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import train_test_split
from matplotlib.ticker import FormatStrFormatter
import matplotlib.ticker as ticker
from scipy.special import betainc
from scipy.stats import norm
from pathlib import Path

# import config functions
sys.path.append('/home/jlai/dark_photon/code/config')
from plot_config import getWeight, zbi, sample_dict, getVarDict
from plot_var import variables, variables_mc, ntuple_names
from n_1_iteration_functions import get_best_cut, calculate_significance, apply_cut_to_fb, apply_all_cuts, compute_total_significance, n_minus_1_optimizer
from perf_sig_plot import plot_performance, plot_significance, plot_n_1, calculate_significance2

# Set up plot defaults
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = 14.0,10.0  # Roughly 11 cm wde by 8 cm high  
mpl.rcParams['font.size'] = 20.0 # Use 14 point font
sns.set(style="whitegrid")

font_size = {
    "xlabel": 17,
    "ylabel": 17,
    "xticks": 15,
    "yticks": 15,
    "legend": 14,
    "title": 20
}

plt.rcParams.update({
    "axes.labelsize": font_size["xlabel"],  # X and Y axis labels
    "xtick.labelsize": font_size["xticks"],  # X ticks
    "ytick.labelsize": font_size["yticks"],  # Y ticks
    "legend.fontsize": font_size["legend"],  # Legend
    "axes.titlesize": font_size["title"] # Title
})


tot = []
signal_name = 'ggHyyd'
data = pd.DataFrame()

def test(fb):
    # checking if there are any none values
    mask = ak.is_none(fb['met_tst_et'])
    n_none = ak.sum(mask)
    print("Number of none values: ", n_none)
    # if n_none > 0:
    #     fb = fb[~mask]
    # print("Events after removing none values: ", len(fb), ak.sum(ak.is_none(fb['met_tst_et'])))

def print_cut(ntuple_name, fb, label):
    print(f"{ntuple_name} Unweighted Events {label}: ", len(fb))
    print(f"{ntuple_name} Weighted Events {label}: ", sum(getWeight(fb, ntuple_name)))
        
for i in range(len(ntuple_names)):
    start_time = time.time()
    ntuple_name = ntuple_names[i]
    path = f"/data/fpiazza/ggHyyd/NtuplesWithBDTSkim/{ntuple_name}_nominal_bdt.root"
    f = uproot.open(path)['nominal']
    if ntuple_name.startswith("mc"):
        fb = f.arrays(variables+variables_mc, library='ak')
        print_cut(ntuple_name, fb, 'before cut')
        
        fb = fb[ak.num(fb['ph_eta']) > 0]     # for abs(ak.firsts(fb['ph_eta'])) to have value to the reweighting
        fb = fb[fb['n_ph'] == 1]
        fb = fb[fb['n_el_baseline'] == 0]

        # goodPV on signal only
        if ntuple_name == 'ggHyyd':
            fb = fb[ak.num(fb['pv_z']) > 0]
            good_pv_tmp = (np.abs(ak.firsts(fb['pv_truth_z']) - ak.firsts(fb['pv_z'])) <= 0.5)
            fb = fb[good_pv_tmp]
            
        
    if (ntuple_name == "data23_y") or (ntuple_name == "data24_y"):  # jet-faking 
        fb = f.arrays(variables, library='ak')
        print_cut(ntuple_name, fb, 'before cut')

        fb = fb[ak.num(fb['ph_eta']) > 0]
        mask1 = (ak.firsts(fb['ph_topoetcone40'])-2450.)/ak.firsts(fb['ph_pt']) > 0.1   # jet_faking_photon cut
        fb = fb[mask1]
        fb = fb[fb['n_ph_baseline'] == 1]
        fb = fb[fb['n_el_baseline'] == 0]


    if (ntuple_name == "data23_eprobe") or (ntuple_name == "data24_eprobe"): # electron-faking
        fb = f.arrays(variables, library='ak')
        print_cut(ntuple_name, fb, 'before cut')
        
        fb = fb[fb['n_el'] == 1]
        fb = fb[fb['n_ph_baseline'] == 0]

        # using electron info for photon info
        fb['ph_pt'] = fb['el_pt']
        fb['ph_eta'] = fb['el_eta']
        fb['ph_phi'] = fb['el_phi']
        fb['dphi_met_phterm'] = fb['dphi_met_eleterm']  

    fb = fb[ak.num(fb['ph_pt']) > 0] # prevent none values in Tbranch
    fb = fb[ak.firsts(fb['ph_pt']) >= 50000] # ph_pt cut (basic cut)
    fb = fb[fb['n_mu_baseline'] == 0]
    fb = fb[fb['n_tau_baseline'] == 0]
    fb = fb[fb['trigger_HLT_g50_tight_xe40_cell_xe70_pfopufit_80mTAC_L1eEM26M']==1]
    fb = fb[fb['met_tst_et'] >= 100000] # MET cut (basic cut)
    fb = fb[fb['n_jet_central'] <= 3] # n_jet_central cut (basic cut)
    
    fb['VertexBDTScore'] = fb['BDTScore'] # renaming BDTScore to ensure this is recognized as Vertex BDT Score
    fb = fb[fb['VertexBDTScore'] > 0.1]
    
    mt_tmp = np.sqrt(2 * fb['met_tst_et'] * ak.firsts(fb['ph_pt']) * 
                    (1 - np.cos(fb['met_tst_phi'] - ak.firsts(fb['ph_phi'])))) / 1000
    mask1 = mt_tmp > 100
    mask2 = mt_tmp < 140
    fb = fb[mask1 * mask2]

    # ------ Adjustment --------
    fb['weights'] = getWeight(fb, ntuple_name)
    
    dphi_met_jetterm_tmp = fb['dphi_met_jetterm']
    cond = ak.fill_none(dphi_met_jetterm_tmp == -10, False)
    fb['dphi_met_jetterm'] = ak.where(cond, -999, dphi_met_jetterm_tmp)

    fb['dphi_met_phterm'] = np.arccos(np.cos(fb['dphi_met_phterm']))

    print_cut(ntuple_name, fb, 'after basic')

    test(fb) # check for none value

    print(f"Reading Time for {ntuple_name}: {(time.time()-start_time)} seconds\n")

    tot.append(fb)

    del fb 

# combining 23d + 23e {Zgamma (1, 6), Wgamma (2, 7), gammajet_direct (3, 8)}
# combining 2023 + 2024 {data_y (4, 9), data_eprobe (5, 10)}
tot_tmp = tot
tot = [tot_tmp[0]]
for i in tqdm(range(5)):
    tot.append(ak.concatenate([tot_tmp[i+1], tot_tmp[i+6]]))
ntuple_names = ["ggHyyd", "Zgamma", "Wgamma", "gammajet_direct", "data_y", "data_eprobe"]
del tot_tmp


mc23d_ggHyyd_y Unweighted Events before cut:  17999
mc23d_ggHyyd_y Weighted Events before cut:  1786.539416255438
mc23d_ggHyyd_y Unweighted Events after basic:  2627
mc23d_ggHyyd_y Weighted Events after basic:  266.43555898887837
Number of none values:  0
Reading Time for mc23d_ggHyyd_y: 0.7762739658355713 seconds

mc23d_Zgamma_y Unweighted Events before cut:  2520609
mc23d_Zgamma_y Weighted Events before cut:  15697.116266766878
mc23d_Zgamma_y Unweighted Events after basic:  19478
mc23d_Zgamma_y Weighted Events after basic:  191.54357598716345
Number of none values:  0
Reading Time for mc23d_Zgamma_y: 17.110783100128174 seconds

mc23d_Wgamma_y Unweighted Events before cut:  685525
mc23d_Wgamma_y Weighted Events before cut:  16946.649253377054
mc23d_Wgamma_y Unweighted Events after basic:  13933
mc23d_Wgamma_y Weighted Events after basic:  386.78848750579306
Number of none values:  0
Reading Time for mc23d_Wgamma_y: 4.599167346954346 seconds

mc23d_gammajet_direct_y Unweighted Events b

100%|██████████| 5/5 [00:03<00:00,  1.32it/s]


In [2]:
def compute_total_significance(tot2, ntuple_names, signal_name, getVarDict):
    signal_sum = 0
    bkg_sum = 0
    for i in range(len(ntuple_names)):
        fb = tot2[i]
        process = ntuple_names[i]
        weights = fb['weights']
        if process == signal_name:
            signal_sum += ak.sum(weights)
        else:
            bkg_sum += ak.sum(weights)
    return signal_sum / np.sqrt(bkg_sum) if bkg_sum > 0 else 0

sig_tmp = compute_total_significance(tot, ntuple_names, signal_name, getVarDict)
print("significance: ", sig_tmp)

significance:  1.5690984893427113


In [3]:
cuts1 = [
    {'cut_var': 'VertexBDTScore', 'cut_type': 'lowercut', 'best_cut': 0.1},
    {'cut_var': 'metsig', 'cut_type': 'lowercut', 'best_cut': 7},
    {'cut_var': 'ph_eta', 'cut_type': 'uppercut', 'best_cut': 1.75},
    {'cut_var': 'dphi_met_phterm', 'cut_type': 'lowercut', 'best_cut': 1.25},
    {'cut_var': 'dmet', 'cut_type': 'lowercut', 'best_cut': -20000},
    {'cut_var': 'dphi_jj', 'cut_type': 'uppercut', 'best_cut': 2.5},
    {'cut_var': 'dphi_met_jetterm', 'cut_type': 'uppercut', 'best_cut': 0.75},
]

tot2 = apply_all_cuts(tot, ntuple_names, cuts1, getVarDict)
sig_tmp = compute_total_significance(tot2, ntuple_names, signal_name, getVarDict)
print("significance: ", sig_tmp)

significance:  2.665468845273705


In [24]:
cuts1 = [
    {'cut_var': 'VertexBDTScore', 'cut_type': 'lowercut', 'best_cut': 0.1},
    {'cut_var': 'metsig', 'cut_type': 'lowercut', 'best_cut': 6},
    {'cut_var': 'ph_eta', 'cut_type': 'uppercut', 'best_cut': 1.75},
    {'cut_var': 'dphi_met_phterm', 'cut_type': 'lowercut', 'best_cut': 1.25},
    {'cut_var': 'dmet', 'cut_type': 'lowercut', 'best_cut': -10000},
    {'cut_var': 'dphi_jj', 'cut_type': 'uppercut', 'best_cut': 2.5},
    {'cut_var': 'dphi_met_jetterm', 'cut_type': 'uppercut', 'best_cut': 0.75},
]

tot2 = apply_all_cuts(tot, ntuple_names, cuts1, getVarDict)
sig_tmp = compute_total_significance(tot2, ntuple_names, signal_name, getVarDict)


for cut in cuts1:
    var = cut['cut_var']
    val = cut['best_cut']
    if cut['cut_type'] == 'uppercut':
        print(f"{var} < {val}")
    elif cut['cut_type'] == 'lowercut':
        print(f"{var} > {val}")
        
# print('after optimized cutting, signficance: ', final_significance)


print("significance: ", sig_tmp)

VertexBDTScore > 0.1
metsig > 6
ph_eta < 1.75
dphi_met_phterm > 1.25
dmet > -10000
dphi_jj < 2.5
dphi_met_jetterm < 0.75
significance:  2.5812533339885784


In [44]:
def getCutDict(): # same cut as the internal note
    cut_dict = {}
    
    cut_dict['VertexBDTScore'] = {
        'lowercut': np.arange(0.1, 0.36, 0.02),  # VertexBDTScore > cut
    }
    cut_dict['dmet'] = {
        'lowercut': np.arange(-30000, 10000 + 5000, 5000), # dmet > cut
    }
    cut_dict['metsig'] = {
        'lowercut': np.arange(0, 10 + 1, 1), # metsig > cut
    }
    cut_dict['dphi_met_phterm'] = {
        'lowercut': np.arange(0, 2 + 0.05, 0.05), # dphi_met_phterm > cut
    }
    cut_dict['dphi_met_jetterm'] = {
        'uppercut': np.arange(0.5, 1, 0.05), # dphi_met_jetterm < cut
    }
    cut_dict['ph_eta'] = {
        'uppercut': np.arange(1, 2.5 + 0.05, 0.05), # ph_eta < cut
    }
    cut_dict['dphi_jj'] = {
        'uppercut': np.arange(1, 3.1 + 0.05, 0.05) # dphi_jj < cut
    }
    cut_dict['balance'] = {
        'lowercut': np.arange(0.3, 1.5 + 0.05, 0.05), # balance > cut
    }
    cut_dict['jetterm'] = {
        'lowercut': np.arange(0, 150000+10000, 10000) # jetterm > cut
    }
    cut_dict['dphi_phterm_jetterm'] = {
        'lowercut': np.arange(1, 2.5 + 0.1, 0.1), # dphi_phterm_jetterm > cut
        'uppercut': np.arange(2, 4 + 0.1, 0.1) # dphi_phterm_jetterm < cut
    }
    cut_dict['metsigres'] = {
        'uppercut': np.arange(12000, 60000, 10000)
    }
    cut_dict['met_noJVT'] = {
        'lowercut': np.arange(50000, 120000, 10000),
    }
    return cut_dict
cut_config = getCutDict()

'''
def getCutDict():
    cut_dict = {}
    # Selection 1: same variables as in the internal note
    cut_dict['dmet'] = {
        'lowercut': np.arange(-30000, 10000 + 100, 100), # dmet > cut
        'uppercut': np.arange(10000, 100000 + 100, 100), # -10000 < dmet < cut
    }
    cut_dict['metsig'] = {
        'lowercut': np.arange(0, 10 + 1, 1), # metsig > cut
        'uppercut': np.arange(10, 30 + 1, 1), # metsig < cut 
    }
    cut_dict['dphi_met_phterm'] = {
        'lowercut': np.arange(1, 2 + 0.01, 0.01), # dphi_met_phterm > cut
    }
    cut_dict['dphi_met_jetterm'] = {
        'uppercut': np.arange(0.5, 1, 0.01), # dphi_met_jetterm < cut
    }
    cut_dict['ph_eta'] = {
        'uppercut': np.arange(1, 2.5 + 0.01, 0.01), # ph_eta < cut
    }
    cut_dict['dphi_jj'] = {
        'uppercut': np.arange(1, 3.14 + 0.01, 0.01) # dphi_jj < cut
    }

    # Selection 2
    cut_dict['balance'] = {
        'lowercut': np.arange(0.3, 1.5 + 0.05, 0.05), # balance > cut
    }
    cut_dict['jetterm'] = {
        'lowercut': np.arange(0, 150000+10000, 10000) # jetterm > cut
    }
    cut_dict['dphi_phterm_jetterm'] = {
        'lowercut': np.arange(1, 2.5 + 0.1, 0.1), # dphi_phterm_jetterm > cut
        'uppercut': np.arange(2, 4 + 0.1, 0.1) # dphi_phterm_jetterm < cut
    }
    cut_dict['metsigres'] = {
        'uppercut': np.arange(12000, 60000, 10000)
    }
    cut_dict['met_noJVT'] = {
        'lowercut': np.arange(50000, 120000, 10000),
    }
    
    return cut_dict
cut_config = getCutDict()
'''


"\ndef getCutDict():\n    cut_dict = {}\n    # Selection 1: same variables as in the internal note\n    cut_dict['dmet'] = {\n        'lowercut': np.arange(-30000, 10000 + 100, 100), # dmet > cut\n        'uppercut': np.arange(10000, 100000 + 100, 100), # -10000 < dmet < cut\n    }\n    cut_dict['metsig'] = {\n        'lowercut': np.arange(0, 10 + 1, 1), # metsig > cut\n        'uppercut': np.arange(10, 30 + 1, 1), # metsig < cut \n    }\n    cut_dict['dphi_met_phterm'] = {\n        'lowercut': np.arange(1, 2 + 0.01, 0.01), # dphi_met_phterm > cut\n    }\n    cut_dict['dphi_met_jetterm'] = {\n        'uppercut': np.arange(0.5, 1, 0.01), # dphi_met_jetterm < cut\n    }\n    cut_dict['ph_eta'] = {\n        'uppercut': np.arange(1, 2.5 + 0.01, 0.01), # ph_eta < cut\n    }\n    cut_dict['dphi_jj'] = {\n        'uppercut': np.arange(1, 3.14 + 0.01, 0.01) # dphi_jj < cut\n    }\n\n    # Selection 2\n    cut_dict['balance'] = {\n        'lowercut': np.arange(0.3, 1.5 + 0.05, 0.05), # balance 

In [50]:
signal_name='ggHyyd'
initial_cut = []
# tot2 = tot  # return the initial cut

# < -- Initial Cut on all variables (maximize the significance * acceptance) -- > 
for cut_var, cut_types in cut_config.items():
    for cut_type, cut_values in cut_types.items():
        sig_simple_list, sigacc_simple_list, acceptance_values = calculate_significance(
            tot, ntuple_names, getVarDict, cut_var, cut_type, cut_values
        )

        best_cut, best_sig, idx = get_best_cut(cut_values, sigacc_simple_list) 
        
        if idx == 0 or idx == len(sigacc_simple_list) - 1: # I chose to use index to indicate not to make unnecessary cut (for initial cut)
            print(cut_var, idx, len(sigacc_simple_list))
            continue
            
        result = {
            "cut_var": cut_var,
            "cut_type": cut_type,
            "best_cut": best_cut,
            "best_sig_x_acc": best_sig,
            "significance": sig_simple_list[idx],
            "acceptance": acceptance_values[idx]
        }

        print(result)
        initial_cut.append(dict(list(result.items())[:3]))

VertexBDTScore 0 13
{'cut_var': 'dmet', 'cut_type': 'lowercut', 'best_cut': -20000, 'best_sig_x_acc': 1.6198364927642324, 'significance': 1.6342900266183504, 'acceptance': 99.11560777960415}
{'cut_var': 'metsig', 'cut_type': 'lowercut', 'best_cut': 6, 'best_sig_x_acc': 2.0972604819261904, 'significance': 2.2788018308797002, 'acceptance': 92.03347362225752}
{'cut_var': 'dphi_met_phterm', 'cut_type': 'lowercut', 'best_cut': 1.05, 'best_sig_x_acc': 1.6135286667758535, 'significance': 1.6401586619859487, 'acceptance': 98.37637688186514}
{'cut_var': 'dphi_met_jetterm', 'cut_type': 'uppercut', 'best_cut': 0.8500000000000003, 'best_sig_x_acc': 1.5696778160158058, 'significance': 1.5696768086370052, 'acceptance': 100.00006417746603}
{'cut_var': 'ph_eta', 'cut_type': 'uppercut', 'best_cut': 2.4000000000000012, 'best_sig_x_acc': 1.5690994963503613, 'significance': 1.5690984893427113, 'acceptance': 100.00006417746603}
dphi_jj 42 43
{'cut_var': 'balance', 'cut_type': 'lowercut', 'best_cut': 0.8999

In [51]:
initial_cut.append({'cut_var': 'VertexBDTScore', 'cut_type': 'lowercut', 'best_cut': 0.1})

In [10]:
# my cut
initial_cuts = [
    {'cut_var': 'VertexBDTScore', 'cut_type': 'lowercut', 'best_cut': 0.1},
    {'cut_var': 'metsig', 'cut_type': 'lowercut', 'best_cut': 6},
    {'cut_var': 'ph_eta', 'cut_type': 'uppercut', 'best_cut': 1.75},
    {'cut_var': 'dphi_met_phterm', 'cut_type': 'lowercut', 'best_cut': 1.25},
    {'cut_var': 'dmet', 'cut_type': 'lowercut', 'best_cut': -10000},
    {'cut_var': 'dphi_jj', 'cut_type': 'uppercut', 'best_cut': 2.5},
    {'cut_var': 'dphi_met_jetterm', 'cut_type': 'uppercut', 'best_cut': 0.75},
]


# internal note cut (just in case)
# initial_cuts = [
#     {'cut_var': 'VertexBDTScore', 'cut_type': 'lowercut', 'best_cut': 0.1},
#     {'cut_var': 'metsig', 'cut_type': 'lowercut', 'best_cut': 6},
#     {'cut_var': 'ph_eta', 'cut_type': 'uppercut', 'best_cut': 1.75},
#     {'cut_var': 'dphi_met_phterm', 'cut_type': 'lowercut', 'best_cut': 1.25},
#     {'cut_var': 'dmet', 'cut_type': 'lowercut', 'best_cut': -20000},
#     {'cut_var': 'dphi_jj', 'cut_type': 'uppercut', 'best_cut': 2.5},
#     {'cut_var': 'dphi_met_jetterm', 'cut_type': 'uppercut', 'best_cut': 0.75},
# ]

In [27]:
tot2_initial_cut = apply_all_cuts(tot, ntuple_names, initial_cut, getVarDict)
final_significance = compute_total_significance(tot2_initial_cut, ntuple_names, signal_name, getVarDict)
print('after initial cutting, signficance: ', final_significance)

after initial cutting, signficance:  2.390334094119089


In [52]:
%%time

# < -- n-1 iterations until no further improvement (max significance) -- >
optimized_cuts, final_significance = n_minus_1_optimizer(
    initial_cut, cut_config, tot, ntuple_names, signal_name, getVarDict, final_significance, allow_drop=False
)
print('after optimized cutting, signficance: ', final_significance)



--- Iteration 1 ---
Updating dmet (lowercut): -20000 → -25000  (N-1 2.387 → with-cut 2.393)
Updating metsig (lowercut): 6 → 8  (N-1 2.290 → with-cut 2.482)
Updating dphi_met_phterm (lowercut): 1.05 → 1.25  (N-1 2.466 → with-cut 2.503)
Updating ph_eta (uppercut): 2.4000000000000012 → 1.7500000000000007  (N-1 2.503 → with-cut 2.603)
Updating balance (lowercut): 0.8999999999999999 → 0.9999999999999998  (N-1 2.540 → with-cut 2.662)
Updating metsigres (uppercut): 22000 → 42000  (N-1 2.747 → with-cut 2.747)
Updating met_noJVT (lowercut): 100000 → 90000  (N-1 2.745 → with-cut 2.751)
Updating VertexBDTScore (lowercut): 0.1 → 0.22000000000000003  (N-1 2.751 → with-cut 2.829)

--- Iteration 2 ---
Updating dphi_met_phterm (lowercut): 1.25 → 1.35  (N-1 2.771 → with-cut 2.830)
Updating jetterm (lowercut): 80000 → 90000  (N-1 2.830 → with-cut 2.830)

--- Iteration 3 ---
optimized cuts, end of iteration
after optimized cutting, signficance:  2.830240585331781
CPU times: user 36 s, sys: 116 ms, total

In [53]:
print( ' < -- Final Optimized Cuts -- > ')
# print(optimized_cuts)

for cut in optimized_cuts:
    var = cut['cut_var']
    val = cut['best_cut']
    if cut['cut_type'] == 'uppercut':
        print(f"{var} < {val}")
    elif cut['cut_type'] == 'lowercut':
        print(f"{var} > {val}")
        
print('after optimized cutting, signficance: ', final_significance)

 < -- Final Optimized Cuts -- > 
dmet > -25000.0
metsig > 8.0
dphi_met_phterm > 1.35
dphi_met_jetterm < 0.8500000000000003
ph_eta < 1.7500000000000007
balance > 0.9999999999999998
jetterm > 90000.0
metsigres < 42000.0
met_noJVT > 90000.0
VertexBDTScore > 0.22000000000000003
after optimized cutting, signficance:  2.830240585331781


In [40]:
signal_name = 'ggHyyd'

# --- helpers ---
def weight_sum(fb, ntuple_name):
    return ak.sum(fb['weights'])

def s_over_sqrt_b(S, B):
    return S/np.sqrt(B) if B > 0 else 0.0

def zbi(S, B, sigma_b_frac=0.30):
    # Binomial significance with background uncertainty
    if B <= 0:
        return 0.0
    tau   = 1.0 / (B * sigma_b_frac * sigma_b_frac)
    n_on  = S + B
    n_off = B * tau
    P_Bi  = betainc(n_on, n_off + 1, 1.0 / (1.0 + tau))
    if P_Bi <= 0:
        return 0.0
    return float(norm.ppf(1.0 - P_Bi))

# Minimal, branchless Δφ in [0, π]
def dphi(a, b):
    return np.abs((a - b + np.pi) % (2*np.pi) - np.pi)

# --- define the "further selection" cuts as functions that return a boolean mask ---
def cut_balance_gt_0p91(fb):
    sumet_tmp = fb['jet_central_vecSumPt']
    expr = (fb['met_tst_et'] + ak.firsts(fb['ph_pt'])) / ak.where(sumet_tmp != 0, sumet_tmp, 1)
    balance = ak.where(sumet_tmp != 0, expr, -999) 
    return (balance == -999) | (balance > 1)

def cut_met_jetterm_et_gt_92GeV(fb):
    return fb['met_jetterm_et'] > 80_000

def cut_dphi_phterm_jetterm_window(fb, low=1.6, high=3.1):
    # angle defined only if met_jetterm_et>0; missing → pass
    cond  = fb['met_jetterm_et'] > 0
    angle = ak.where(cond, dphi(fb['met_phterm_phi'], fb['met_jetterm_phi']), 0.0)
    return (~cond) | ((angle > low) & (angle < high))

def cut_metsigres_lt_36GeV(fb):
    # metsigres = MET / METsig  (in MeV; 36 GeV = 36000 MeV)
    # Guard against nonpositive sig; if sig<=0, we conservatively fail the cut.
    sig_ok = fb['met_tst_sig'] > 0
    ratio  = ak.where(sig_ok, fb['met_tst_et'] / fb['met_tst_sig'], np.inf)
    return ratio < 42_000

def cut_met_noJVT_gt_90GeV(fb):
    return fb['met_tst_noJVT_et'] > 100_000

# Bundle all the single-cut masks you want to test
CUTS = [
    ("balance>1",              cut_balance_gt_0p91),
    ("met_jetterm_et>80GeV",      cut_met_jetterm_et_gt_92GeV),
    # ("1.6<dphi(phterm,jetterm)<3.1", cut_dphi_phterm_jetterm_window),
    ("metsigres<42GeV",           cut_metsigres_lt_36GeV),
    ("met_noJVT>100GeV",           cut_met_noJVT_gt_90GeV),
]

def significance_by_single_cuts(tot, ntuple_names, signal_name="ggHyyd", sigma_b_frac=0.30, include_zbi=True):
    """
    For each cut in CUTS: apply only that cut on top of the baseline 'tot',
    then compute S, B, S/sqrt(B) (and ZBi).
    Returns a pandas DataFrame.
    """
    # Precompute per-sample weights once (so we don't rebuild weights inside each cut)
    weights = []
    for i, fb in enumerate(tot):
        name = ntuple_names[i]
        w = fb['weights']
        weights.append(w)

    rows = []
    # Also compute baseline (no extra cut) significance for reference
    S0 = 0.0
    B0 = 0.0
    for i, fb in enumerate(tot):
        name = ntuple_names[i]
        wsum = float(ak.sum(weights[i]))
        if name == signal_name:
            S0 += wsum
        else:      # exclude data from B
            B0 += wsum
    base = {
        "cut": "(no extra cut)",
        "S": S0,
        "B": B0,
        "S/sqrt(B)": s_over_sqrt_b(S0, B0)
    }
    if include_zbi:
        base["ZBi(30%)"] = zbi(S0, B0, sigma_b_frac)
    rows.append(base)

    # Now evaluate each single cut
    for label, cut_fn in CUTS:
        S = 0.0
        B = 0.0
        for i, fb in enumerate(tot):
            name = ntuple_names[i]
            m = cut_fn(fb)                      # boolean mask (Awkward-friendly)
            w = weights[i][m]                   # apply mask to precomputed event weights
            wsum = float(ak.sum(w))
            if name == signal_name:
                S += wsum
            else:
                B += wsum

        row = {
            "cut": label,
            "S": S,
            "B": B,
            "S/sqrt(B)": s_over_sqrt_b(S, B)
        }
        if include_zbi:
            row["ZBi(30%)"] = zbi(S, B, sigma_b_frac)
        rows.append(row)

    df = pd.DataFrame(rows)
    # Nicely formatted view (optional)
    with pd.option_context('display.float_format', '{:,.3f}'.format):
        print(df.to_string(index=False))
    return df


initial_cuts = [
    {'cut_var': 'VertexBDTScore', 'cut_type': 'lowercut', 'best_cut': 0.1},
    {'cut_var': 'metsig', 'cut_type': 'lowercut', 'best_cut': 7},
    {'cut_var': 'ph_eta', 'cut_type': 'uppercut', 'best_cut': 1.75},
    {'cut_var': 'dphi_met_phterm', 'cut_type': 'lowercut', 'best_cut': 1.2},
    {'cut_var': 'dmet', 'cut_type': 'lowercut', 'best_cut': -15000},
    {'cut_var': 'dphi_jj', 'cut_type': 'uppercut', 'best_cut': 2.35},
    {'cut_var': 'dphi_met_jetterm', 'cut_type': 'uppercut', 'best_cut': 0.85},
]
tot2_initial_cut = apply_all_cuts(tot, ntuple_names, initial_cuts, getVarDict)

# --- RUN IT ---
sig_table = significance_by_single_cuts(tot2_initial_cut, ntuple_names, signal_name='ggHyyd', sigma_b_frac=0.30, include_zbi=True)


                 cut       S         B  S/sqrt(B)  ZBi(30%)
      (no extra cut) 180.522 4,496.285      2.692    -0.065
           balance>1 177.398 3,956.804      2.820    -0.050
met_jetterm_et>80GeV 180.522 4,491.873      2.693    -0.065
     metsigres<42GeV 180.522 4,488.710      2.694    -0.065
    met_noJVT>100GeV 180.011 4,456.464      2.697    -0.065


In [55]:
optimized_cuts

[{'cut_var': 'dmet', 'cut_type': 'lowercut', 'best_cut': -25000.0},
 {'cut_var': 'metsig', 'cut_type': 'lowercut', 'best_cut': 8.0},
 {'cut_var': 'dphi_met_phterm', 'cut_type': 'lowercut', 'best_cut': 1.35},
 {'cut_var': 'dphi_met_jetterm',
  'cut_type': 'uppercut',
  'best_cut': 0.8500000000000003},
 {'cut_var': 'ph_eta', 'cut_type': 'uppercut', 'best_cut': 1.7500000000000007},
 {'cut_var': 'balance',
  'cut_type': 'lowercut',
  'best_cut': 0.9999999999999998},
 {'cut_var': 'jetterm', 'cut_type': 'lowercut', 'best_cut': 90000.0},
 {'cut_var': 'metsigres', 'cut_type': 'uppercut', 'best_cut': 42000.0},
 {'cut_var': 'met_noJVT', 'cut_type': 'lowercut', 'best_cut': 90000.0},
 {'cut_var': 'VertexBDTScore',
  'cut_type': 'lowercut',
  'best_cut': 0.22000000000000003}]

In [56]:
optimized_cuts2 = [{'cut_var': 'dmet', 'cut_type': 'lowercut', 'best_cut': -15000.0},
 {'cut_var': 'metsig', 'cut_type': 'lowercut', 'best_cut': 7.0},
 {'cut_var': 'dphi_met_phterm', 'cut_type': 'lowercut', 'best_cut': 1.20},
 {'cut_var': 'dphi_met_jetterm',
  'cut_type': 'uppercut',
  'best_cut': 0.8500000000000003},
 {'cut_var': 'ph_eta', 'cut_type': 'uppercut', 'best_cut': 1.7500000000000007},
 {'cut_var': 'balance',
  'cut_type': 'lowercut',
  'best_cut': 0.9999999999999998},
 {'cut_var': 'jetterm', 'cut_type': 'lowercut', 'best_cut': 80000.0},
 {'cut_var': 'metsigres', 'cut_type': 'uppercut', 'best_cut': 42000.0},
 {'cut_var': 'met_noJVT', 'cut_type': 'lowercut', 'best_cut': 100000.0},
 {'cut_var': 'VertexBDTScore',
  'cut_type': 'lowercut',
  'best_cut': 0.1000000000000003}]

In [57]:
tot2_optimized_cuts = apply_all_cuts(tot, ntuple_names, optimized_cuts2, getVarDict)

print('< -- Sum of weight each process -- >')

for i in range(len(tot2_optimized_cuts)):
    print(ntuple_names[i], ak.sum(tot2_optimized_cuts[i]['weights']))

< -- Sum of weight each process -- >
ggHyyd 186.12543
Zgamma 527.19464
Wgamma 1009.02875
gammajet_direct 42.67856
data_y 1054.260999999995
data_eprobe 1854.7567422406833


In [66]:
initial_cuts = [
    {'cut_var': 'VertexBDTScore', 'cut_type': 'lowercut', 'best_cut': 0.1},
    {'cut_var': 'metsig', 'cut_type': 'lowercut', 'best_cut': 6},
    {'cut_var': 'ph_eta', 'cut_type': 'uppercut', 'best_cut': 1.75},
    {'cut_var': 'dphi_met_phterm', 'cut_type': 'lowercut', 'best_cut': 1.25},
    {'cut_var': 'dmet', 'cut_type': 'lowercut', 'best_cut': -10000},
    {'cut_var': 'dphi_jj', 'cut_type': 'uppercut', 'best_cut': 2.5},
    {'cut_var': 'dphi_met_jetterm', 'cut_type': 'uppercut', 'best_cut': 0.75},
]
print('< -- Sum of weight each process -- >')
tot_tmp = apply_all_cuts(tot, ntuple_names, initial_cuts, getVarDict)
for i in range(len(tot)):
    print(ntuple_names[i], ak.sum(tot_tmp[i]['weights']))
print(f'significance : {compute_total_significance(tot_tmp, ntuple_names, signal_name, getVarDict)}')

< -- Sum of weight each process -- >
ggHyyd 188.84154
Zgamma 497.98044
Wgamma 980.3159
gammajet_direct 109.93676
data_y 1708.3380000000134
data_eprobe 2055.646897520867
significance : 2.5812533339885784


In [71]:
# cuts = [
#     {'cut_var': 'VertexBDTScore', 'cut_type': 'lowercut', 'best_cut': 0.1},
#     {'cut_var': 'metsig', 'cut_type': 'lowercut', 'best_cut': 7},
#     {'cut_var': 'ph_eta', 'cut_type': 'uppercut', 'best_cut': 1.75},
#     {'cut_var': 'dphi_met_phterm', 'cut_type': 'lowercut', 'best_cut': 1.20},
#     {'cut_var': 'dmet', 'cut_type': 'lowercut', 'best_cut': -15000},
#     {'cut_var': 'dphi_jj', 'cut_type': 'uppercut', 'best_cut': 2.35},
#     {'cut_var': 'dphi_met_jetterm', 'cut_type': 'uppercut', 'best_cut': 0.85},
# ]
cuts = [{'cut_var': 'dmet', 'cut_type': 'lowercut', 'best_cut': -15000.0},
 {'cut_var': 'metsig', 'cut_type': 'lowercut', 'best_cut': 7.0},
 {'cut_var': 'dphi_met_phterm', 'cut_type': 'lowercut', 'best_cut': 1.20},
 {'cut_var': 'dphi_met_jetterm',
  'cut_type': 'uppercut',
  'best_cut': 0.8500000000000003},
 {'cut_var': 'ph_eta', 'cut_type': 'uppercut', 'best_cut': 1.7500000000000007},
 {'cut_var': 'balance',
  'cut_type': 'lowercut',
  'best_cut': 0.9999999999999998},
 {'cut_var': 'jetterm', 'cut_type': 'lowercut', 'best_cut': 80000.0},
 {'cut_var': 'metsigres', 'cut_type': 'uppercut', 'best_cut': 42000.0},
 {'cut_var': 'met_noJVT', 'cut_type': 'lowercut', 'best_cut': 100000.0},
 {'cut_var': 'VertexBDTScore',
  'cut_type': 'lowercut',
  'best_cut': 0.1000000000000003}]
print('< -- Sum of weight each process -- >')
tot_tmp = apply_all_cuts(tot, ntuple_names, cuts, getVarDict)
for i in range(len(tot)):
    print(ntuple_names[i], ak.sum(tot_tmp[i]['weights']))
print(f'significance : {compute_total_significance(tot_tmp, ntuple_names, signal_name, getVarDict)}')

< -- Sum of weight each process -- >
ggHyyd 186.12543
Zgamma 527.19464
Wgamma 1009.02875
gammajet_direct 42.67856
data_y 1054.260999999995
data_eprobe 1854.7567422406833
significance : 2.7783257837777087


In [72]:
# path for plot storage
mt_val_dir = 'mt100_140'

# cut_name = 'basic'
# plot_performance(tot, ntuple_names, sample_dict, getVarDict, zbi, mt_val_dir, cut_name) # basic

cut_name = 'selection2'
plot_performance(tot_tmp, ntuple_names, sample_dict, getVarDict, zbi, mt_val_dir, cut_name) # selection

successfully saved to /home/jlai/dark_photon/main/mt100_140/selection2cut/n_ph.png
successfully saved to /home/jlai/dark_photon/main/mt100_140/selection2cut/n_ph_baseline.png
successfully saved to /home/jlai/dark_photon/main/mt100_140/selection2cut/n_el.png
successfully saved to /home/jlai/dark_photon/main/mt100_140/selection2cut/n_el_baseline.png
successfully saved to /home/jlai/dark_photon/main/mt100_140/selection2cut/n_mu_baseline.png
successfully saved to /home/jlai/dark_photon/main/mt100_140/selection2cut/n_tau_baseline.png
successfully saved to /home/jlai/dark_photon/main/mt100_140/selection2cut/mt.png
successfully saved to /home/jlai/dark_photon/main/mt100_140/selection2cut/metsig.png
successfully saved to /home/jlai/dark_photon/main/mt100_140/selection2cut/metsigres.png
successfully saved to /home/jlai/dark_photon/main/mt100_140/selection2cut/met.png
successfully saved to /home/jlai/dark_photon/main/mt100_140/selection2cut/met_noJVT.png
successfully saved to /home/jlai/dark_pho

  return impl(*broadcasted_args, **(kwargs or {}))


successfully saved to /home/jlai/dark_photon/main/mt100_140/selection2cut/central_jets_fraction.png
successfully saved to /home/jlai/dark_photon/main/mt100_140/selection2cut/balance.png
successfully saved to /home/jlai/dark_photon/main/mt100_140/selection2cut/dphi_jj.png
successfully saved to /home/jlai/dark_photon/main/mt100_140/selection2cut/VertexBDTScore.png


In [10]:
def sel(tot, n_1_name=None):
    """
    Apply baseline cuts to all fb in tot except the variable named by n_1_name.
    """
    import awkward as ak
    out = []
    for i in range(len(tot)):
        fb2 = tot[i]

        if n_1_name != "VertexBDTScore":
            VertexBDTScore_tmp = fb2['VertexBDTScore']
            fb2 = fb2[VertexBDTScore_tmp > 0.10]

        if n_1_name != "metsig":
            metsig_tmp = fb2['met_tst_sig']
            fb2 = fb2[(metsig_tmp > 7)]

        if n_1_name != "ph_eta":
            ph_eta_tmp = np.abs(ak.firsts(fb2['ph_eta']))
            fb2 = fb2[ph_eta_tmp < 1.75]

        if n_1_name != "dphi_met_phterm":
            dphi_met_phterm_tmp = fb2['dphi_met_phterm']
            fb2 = fb2[dphi_met_phterm_tmp > 1.20]
            
        if n_1_name != "dmet":
            dmet_tmp = fb2['dmet']
            fb2 = fb2[(dmet_tmp > -15000)]

        if n_1_name != "dphi_met_jetterm":
            dphi_met_jetterm_tmp = fb2['dphi_met_jetterm']
            fb2 = fb2[dphi_met_jetterm_tmp <= 0.85]

        if n_1_name != "dphi_jj":
            dphi_jj_tmp = fb2['dphi_central_jj']
            dphi_jj_tmp = ak.where(dphi_jj_tmp == -10, np.nan, dphi_jj_tmp)
            dphi_jj_tmp = np.arccos(np.cos(dphi_jj_tmp))
            dphi_jj_tmp = ak.where(np.isnan(dphi_jj_tmp), -999, dphi_jj_tmp)
            fb2 = fb2[dphi_jj_tmp < 2.35]

        if n_1_name != "balance":
            sumet_tmp = fb['jet_central_vecSumPt']
            expr = (fb['met_tst_et'] + ak.firsts(fb['ph_pt'])) / ak.where(sumet_tmp != 0, sumet_tmp, 1)
            balance_tmp = ak.where(sumet_tmp != 0, expr, -999) 
            mask_nan = balance_tmp == -999
            mask = mask_nan | (balance_tmp > 0.1)
            fb2 = fb2[mask]
            
        if n_1_name != "jetterm":
            fb2 = fb2[fb2['met_jetterm_et'] > 80_000]
            
        if n_1_name != "metsigres":
            sig_ok = fb2['met_tst_sig'] > 0
            ratio  = ak.where(sig_ok, fb2['met_tst_et'] / fb2['met_tst_sig'], np.inf)
            fb2 = fb2[ratio < 42_000]
            
        if n_1_name != "met_noJVT":
            fb2 = fb2[fb2['met_tst_noJVT_et'] > 90_000]

        out.append(fb2)
    return out

def getCutDict(n_1_name=None):
    cut_dict = {}
    if n_1_name is None or n_1_name == "VertexBDTScore":
        cut_dict['VertexBDTScore'] = {'lowercut': np.arange(0.10, 0.24, 0.02)}
    if n_1_name is None or n_1_name == "dmet":
        cut_dict['dmet'] = {'lowercut': np.arange(-30000, 10000 + 5000, 5000)}
    if n_1_name is None or n_1_name == "metsig":
        cut_dict['metsig'] = {'lowercut': np.arange(0, 10 + 1, 1)}
    if n_1_name is None or n_1_name == "dphi_met_phterm":
        cut_dict['dphi_met_phterm'] = {'lowercut': np.arange(1, 2 + 0.05, 0.05)}
    if n_1_name is None or n_1_name == "dphi_met_jetterm":
        cut_dict['dphi_met_jetterm'] = {'uppercut': np.arange(0.5, 1.00, 0.05)}
    if n_1_name is None or n_1_name == "ph_eta":
        cut_dict['ph_eta'] = {'uppercut': np.arange(1.0, 2.50 + 0.05, 0.05)}
    if n_1_name is None or n_1_name == "dphi_jj":
        cut_dict['dphi_jj'] = {'uppercut': np.arange(1.0, 3.10 + 0.05, 0.05)}
    if n_1_name is None or n_1_name == "balance":
        cut_dict['balance'] = {'lowercut': np.arange(0.3, 1.5 + 0.05, 0.05)}
    if n_1_name is None or n_1_name == "jetterm":
        cut_dict['jetterm'] = {'lowercut': np.arange(0, 150000+10000, 10000)}
    if n_1_name is None or n_1_name == "metsigres":
        cut_dict['metsigres'] = {'uppercut': np.arange(12000, 60000, 10000)}
    if n_1_name is None or n_1_name == "met_noJVT":
        cut_dict['met_noJVT'] = {'lowercut': np.arange(50000, 120000, 10000)}
    return cut_dict


def calculate_significance2(tot, ntuple_names, getVarDict, zbi, cut_var, cut_type, cut_values, signal_name="ggHyyd"):
    
    sig_simple_list = []
    sig_s_plus_b_list = []
    sig_s_plus_1p3b_list = []
    sig_binomial_list = []

    sigacc_simple_list = []
    sigacc_s_plus_b_list = []
    sigacc_s_plus_1p3b_list = []
    sigacc_binomial_list = []

    acceptance_values = []  # Store acceptance percentages

    for cut in cut_values:
        sig_after_cut = 0
        bkg_after_cut = []
        sig_events = 0
        
        for i in range(len(ntuple_names)):
            fb = tot[i]
            process = ntuple_names[i]
            var_config = getVarDict(fb, process, var_name=cut_var)
            x = var_config[cut_var]['var']
            mask_nan = x == -999
            
            if process == signal_name:
                sig_events = fb['weights']
                mask_cut = x > cut if cut_type == 'lowercut' else x < cut
                mask = mask_nan | mask_cut
                sig_after_cut = ak.sum(sig_events[mask])
            else:
                bkg_events = fb['weights']
                mask_cut = x > cut if cut_type == 'lowercut' else x < cut
                mask = mask_nan | mask_cut
                bkg_after_cut.append(ak.sum(bkg_events[mask]))

       # Now compute different types of significance
        total_bkg = sum(bkg_after_cut)
        total_signal = sig_after_cut

        # Avoid zero division carefully
        if total_bkg > 0:
            sig_simple = total_signal / np.sqrt(total_bkg)
            sig_s_plus_b = total_signal / np.sqrt(total_signal + total_bkg) if (total_signal + total_bkg) > 0 else 0
            sig_s_plus_1p3b = total_signal / np.sqrt(total_signal + 1.3 * total_bkg) if (total_signal + 1.3*total_bkg) > 0 else 0
            sig_binomial = zbi(total_signal, total_bkg, sigma_b_frac=0.3)
        else:
            sig_simple = sig_s_plus_b = sig_s_plus_1p3b = sig_binomial = 0

        # Acceptance
        acceptance = total_signal / sum(sig_events) if sum(sig_events) > 0 else 0
        acceptance_values.append(acceptance * 100)  # percentage

        # Save significance
        sig_simple_list.append(sig_simple)
        sig_s_plus_b_list.append(sig_s_plus_b)
        sig_s_plus_1p3b_list.append(sig_s_plus_1p3b)
        sig_binomial_list.append(sig_binomial)

        # Save significance × acceptance
        sigacc_simple_list.append(sig_simple * acceptance)
        sigacc_s_plus_b_list.append(sig_s_plus_b * acceptance)
        sigacc_s_plus_1p3b_list.append(sig_s_plus_1p3b * acceptance)
        sigacc_binomial_list.append(sig_binomial * acceptance)

    return (sig_simple_list, sig_s_plus_b_list, sig_s_plus_1p3b_list, sig_binomial_list,
            sigacc_simple_list, sigacc_s_plus_b_list, sigacc_s_plus_1p3b_list, sigacc_binomial_list,
            acceptance_values)


# --- config ---
mt_val_dir = 'mt100_140'
n_1_config = ["VertexBDTScore", "metsig", "ph_eta", "dmet", "dphi_jj", "dphi_met_phterm", "dphi_met_jetterm", "balance", "jetterm", "metsigres", "met_noJVT"]
signal_name = 'ggHyyd'
cut_name = 'n-1' 


def plot_n_1(tot, ntuple_names, sample_dict, getVarDict, zbi, getCutDict, sel, mt_val_dir, n_1_config, cut_name='n-1', signal_name="ggHyyd"):

    def ensure_dir(path_str):
        Path(path_str).mkdir(parents=True, exist_ok=True)
    
    def to_np(a):
        """Flatten awkward/numpy to 1D numpy array; empty-safe."""
        if a is None:
            return np.array([])
        try:
            import awkward as ak
            if hasattr(ak, "to_numpy"):
                return ak.to_numpy(ak.flatten(a, axis=None))
        except Exception:
            pass
        return np.asarray(a).ravel()
    
    def safe_concat(list_of_arrays):
        """Concatenate list of arrays safely (possibly empty)."""
        if len(list_of_arrays) == 0:
            return np.array([])
        arrs = [to_np(x) for x in list_of_arrays if x is not None]
        return np.concatenate(arrs) if len(arrs) else np.array([])
    
    def safe_hist(x, bins, w=None):
        if x.size == 0:
            return np.zeros(len(bins)-1, dtype=float), bins
        return np.histogram(x, bins=bins, weights=w)
    
    # ---------- N-1 scan + plots ----------
    
    out_base = f"/home/jlai/dark_photon/main/{mt_val_dir}/{cut_name}cut"
    ensure_dir(out_base)
    
    for cut_var_tmp in n_1_config:
        cut_config = getCutDict(n_1_name=cut_var_tmp)
        tot2 = sel(tot, n_1_name=cut_var_tmp)
    
        # --- Significance vs cut scans for this n-1 variable ---
        for cut_var, cut_types in cut_config.items():
            for cut_type, cut_values in cut_types.items():
                (sig_simple_list, sig_s_plus_b_list, sig_s_plus_1p3b_list, sig_binomial_list,
                 sigacc_simple_list, sigacc_s_plus_b_list, sigacc_s_plus_1p3b_list, sigacc_binomial_list,
                 acceptance_values) = calculate_significance2(tot, ntuple_names, getVarDict, zbi, cut_var, cut_type, cut_values)
    
                fig, (ax_top, ax_bot) = plt.subplots(2, 1, figsize=(8, 10), sharex=True)
    
                # Top: S/sqrt(B) and vertical line at max
                i_max = int(np.argmax(sig_simple_list)) if len(sig_simple_list) else 0
                max_tmp = float(cut_values[i_max]) if len(cut_values) else np.nan
                if len(cut_values):
                    ax_top.axvline(x=max_tmp, color='r', linestyle='--', label=f'Max S/√B at {max_tmp:.2f}')
                ax_top.plot(cut_values, sig_simple_list, marker='o', label='S/√B')
                # Uncomment if you want the other metrics on same plot:
                # ax_top.plot(cut_values, sig_s_plus_b_list, marker='s', label='S/√(S+B)')
                # ax_top.plot(cut_values, sig_s_plus_1p3b_list, marker='^', label='S/√(S+1.3B)')
                # ax_top.plot(cut_values, sig_binomial_list, marker='x', label='BinomialExpZ')
                ax_top.set_ylabel('Significance')
                ax_top.set_title(f'N-1: Significance vs. {cut_var} ({cut_type})')
                ax_top.grid(True)
                ax_top.legend()
    
                # Bottom: (S/√B) × Acceptance
                if len(cut_values):
                    ax_bot.axvline(x=max_tmp, color='r', linestyle='--')
                ax_bot.plot(cut_values, sigacc_simple_list, marker='o', label='(S/√B) × Acceptance')
    
                for i, acc in enumerate(acceptance_values):
                    ax_bot.text(cut_values[i], sigacc_simple_list[i], f'{acc:.1f}%',
                                fontsize=9, ha='right', va='bottom', color='purple')
    
                # Label with pretty var title
                var_cfg_for_label = getVarDict(tot2[0], signal_name, cut_var)
                ax_bot.set_xlabel(var_cfg_for_label[cut_var]['title'])
                ax_bot.set_ylabel('Significance × Acceptance')
                ax_bot.grid(True)
                ax_bot.legend()
    
                plt.tight_layout()
                out_path = f"{out_base}/significance_{cut_var}_{cut_type}.png"
                ensure_dir(Path(out_path).parent.as_posix())
                plt.savefig(out_path)
                print(f"Saved: {out_path}")
                plt.close()
    
        # --- N-1 distributions + per-bin significance & ROC for THIS variable ---
        var_cfg_sig = getVarDict(tot2[0], signal_name, var_name=cut_var_tmp)  # only request this var
        for var in var_cfg_sig:
            bg_vals, bg_wts, bg_cols, bg_labs = [], [], [], []
            sig_vals, sig_wts = [], []
            sig_col = None
            sig_lab = None
    
            # Build stacks
            for j in range(len(ntuple_names)):
                process = ntuple_names[j]
                fb2 = tot2[j]
                var_cfg = getVarDict(fb2, process, var_name=var)
                x = var_cfg[var]['var']
                bins = var_cfg[var]['bins']
                weights = fb2['weights']
    
                info = sample_dict[process]
                if process == signal_name:
                    sig_vals.append(x)
                    sig_wts.append(weights)
                    sig_col = info['color']; sig_lab = info['legend']
                else:
                    bg_vals.append(x); bg_wts.append(weights)
                    bg_cols.append(info['color']); bg_labs.append(info['legend'])
    
            # Convert/concat
            sig_all = safe_concat(sig_vals)
            sig_w_all = safe_concat(sig_wts)
            bg_all = safe_concat(bg_vals)
            bg_w_all = safe_concat(bg_wts)
    
            # Figure / axes
            fig, (ax_top, ax_bot) = plt.subplots(2, 1, figsize=(12, 13), gridspec_kw={'height_ratios':[9,4]})
    
            # Stacked BG + signal outline
            if len(bg_vals):
                ax_top.hist([to_np(v) for v in bg_vals], bins=bins,
                            weights=[to_np(w) for w in bg_wts], color=bg_cols,
                            label=bg_labs, stacked=True)
            if sig_all.size:
                ax_top.hist(sig_all, bins=bins, weights=sig_w_all, color=sig_col,
                            label=sig_lab, histtype='step', linewidth=2)
    
                # Signal error bars
                s_counts, s_edges = safe_hist(sig_all, bins=bins, w=sig_w_all)
                s2, _ = safe_hist(sig_all, bins=bins, w=sig_w_all**2 if sig_w_all.size else None)
                bin_centers = 0.5*(s_edges[:-1] + s_edges[1:])
                s_err = np.sqrt(s2)
                ax_top.errorbar(bin_centers, s_counts, yerr=s_err, fmt='.', linewidth=2,
                                color=sig_col, capsize=0)
    
            ax_top.set_yscale('log')
            ax_top.set_ylim(max(1e-4, 1e-6), 1e11)
            ax_top.set_xlim(bins[0], bins[-1])
            ax_top.minorticks_on()
            ax_top.grid(True, which="both", linestyle="--", linewidth=0.5)
            ax_top.set_ylabel("Events")
            ax_top.legend(ncol=2)
    
            # Per-bin significance curves
            S_counts, _ = safe_hist(sig_all, bins=bins, w=sig_w_all)
            B_counts, _ = safe_hist(bg_all, bins=bins, w=bg_w_all)
    
            sqrt_B = np.sqrt(np.clip(B_counts, 0, None))
            sqrt_SplusB = np.sqrt(np.clip(S_counts + B_counts, 0, None))
            sqrt_Splus1p3B = np.sqrt(np.clip(S_counts + 1.3*B_counts, 0, None))
    
            sig_simple = np.where(B_counts > 0, S_counts / sqrt_B, 0.0)
            sig_s_plus_b = np.where((S_counts + B_counts) > 0, S_counts / sqrt_SplusB, 0.0)
            sig_s_plus_1p3b = np.where((S_counts + 1.3*B_counts) > 0, S_counts / sqrt_Splus1p3B, 0.0)
    
            # Binomial ExpZ per bin
            zbi_per_bin = np.array([zbi(S_counts[i], B_counts[i], sigma_b_frac=0.3) for i in range(len(S_counts))])
    
            bin_centers = 0.5*(bins[:-1] + bins[1:])
    
            # Totals
            S_tot = float(np.sum(S_counts))
            B_tot = float(np.sum(B_counts))
            if B_tot > 0:
                tot_SsqrtB = S_tot / np.sqrt(B_tot)
                tot_SsqrtSB = S_tot / np.sqrt(S_tot + B_tot) if (S_tot + B_tot) > 0 else 0
                tot_SsqrtS1p3B = S_tot / np.sqrt(S_tot + 1.3*B_tot) if (S_tot + 1.3*B_tot) > 0 else 0
                tot_zbi = zbi(S_tot, B_tot, sigma_b_frac=0.3)
            else:
                tot_SsqrtB = tot_SsqrtSB = tot_SsqrtS1p3B = tot_zbi = 0.0
    
            ax_bot.step(bin_centers, sig_simple, where='mid', linewidth=2,
                        label=f"S/√B = {tot_SsqrtB:.4f}", color='chocolate')
            ax_bot.step(bin_centers, sig_s_plus_b, where='mid', linewidth=2,
                        label=f"S/√(S+B) = {tot_SsqrtSB:.4f}", color='tomato')
            ax_bot.step(bin_centers, sig_s_plus_1p3b, where='mid', linewidth=2,
                        label=f"S/√(S+1.3B) = {tot_SsqrtS1p3B:.4f}", color='orange')
            ax_bot.step(bin_centers, zbi_per_bin, where='mid', linewidth=2,
                        label=f"Binomial ExpZ = {tot_zbi:.4f}", color='plum')
    
            ax_bot.set_xlabel(var_cfg_sig[var]['title'])
            ax_bot.set_ylabel("Significance")
            ax_bot.grid(True, which="both", linestyle="--", linewidth=0.5)
            ax_bot.set_title("")
            leg = ax_bot.legend()
            for t in leg.get_texts():
                t.set_color('purple')
    
            plt.xlim(bins[0], bins[-1])
            plt.tight_layout()
    
            dist_dir = f"{out_base}"
            ensure_dir(dist_dir)
            out_png = f"{dist_dir}/{var}.png"
            plt.savefig(out_png)
            print(f"Saved: {out_png}")
            plt.close()
    
            # ROC using this single variable as score
            y_true = np.concatenate([np.ones_like(S_counts).repeat(1)])  # placeholder to silence lints
            # Build event-wise arrays, not binned:
            y_true = np.concatenate([np.ones(sig_all.shape[0], dtype=int), np.zeros(bg_all.shape[0], dtype=int)])
            y_scores = np.concatenate([sig_all, bg_all])
            y_w = np.concatenate([sig_w_all if sig_w_all.size else np.ones(sig_all.shape[0]),
                                  bg_w_all if bg_w_all.size else np.ones(bg_all.shape[0])])
    
            if y_scores.size and np.unique(y_true).size == 2:
                fpr, tpr, thr = roc_curve(y_true, y_scores, sample_weight=y_w)
                order = np.argsort(fpr)
                roc_auc = auc(fpr[order], tpr[order])
    
                plt.figure(figsize=(8, 8))
                plt.plot(fpr, tpr, lw=2, label=f'ROC (AUC = {roc_auc:.5f})')
                plt.plot([0, 1], [0, 1], linestyle='--', label='Random')
                plt.xlabel("False Positive Rate")
                plt.ylabel("True Positive Rate")
                plt.title(f"N-1 ROC — {var}")
                plt.legend(loc="lower right")
                plt.grid(True, which="both", linestyle="--", linewidth=0.5)
                plt.tight_layout()
                out_png = f"{dist_dir}/roc_curve_{var}.png"
                plt.savefig(out_png)
                print(f"Saved: {out_png}")
                plt.close()


plot_n_1(tot, ntuple_names, sample_dict, getVarDict, zbi, getCutDict, sel, mt_val_dir, n_1_config)


Saved: /home/jlai/dark_photon/main/mt100_140/n-1cut/significance_VertexBDTScore_lowercut.png


  sig_simple = np.where(B_counts > 0, S_counts / sqrt_B, 0.0)
  sig_s_plus_b = np.where((S_counts + B_counts) > 0, S_counts / sqrt_SplusB, 0.0)
  sig_s_plus_1p3b = np.where((S_counts + 1.3*B_counts) > 0, S_counts / sqrt_Splus1p3B, 0.0)


Saved: /home/jlai/dark_photon/main/mt100_140/n-1cut/VertexBDTScore.png
Saved: /home/jlai/dark_photon/main/mt100_140/n-1cut/roc_curve_VertexBDTScore.png
Saved: /home/jlai/dark_photon/main/mt100_140/n-1cut/significance_metsig_lowercut.png
Saved: /home/jlai/dark_photon/main/mt100_140/n-1cut/metsig.png
Saved: /home/jlai/dark_photon/main/mt100_140/n-1cut/roc_curve_metsig.png
Saved: /home/jlai/dark_photon/main/mt100_140/n-1cut/significance_ph_eta_uppercut.png


  sig_simple = np.where(B_counts > 0, S_counts / sqrt_B, 0.0)
  sig_s_plus_b = np.where((S_counts + B_counts) > 0, S_counts / sqrt_SplusB, 0.0)
  sig_s_plus_1p3b = np.where((S_counts + 1.3*B_counts) > 0, S_counts / sqrt_Splus1p3B, 0.0)


Saved: /home/jlai/dark_photon/main/mt100_140/n-1cut/ph_eta.png
Saved: /home/jlai/dark_photon/main/mt100_140/n-1cut/roc_curve_ph_eta.png
Saved: /home/jlai/dark_photon/main/mt100_140/n-1cut/significance_dmet_lowercut.png


  sig_simple = np.where(B_counts > 0, S_counts / sqrt_B, 0.0)
  sig_s_plus_b = np.where((S_counts + B_counts) > 0, S_counts / sqrt_SplusB, 0.0)
  sig_s_plus_1p3b = np.where((S_counts + 1.3*B_counts) > 0, S_counts / sqrt_Splus1p3B, 0.0)


Saved: /home/jlai/dark_photon/main/mt100_140/n-1cut/dmet.png
Saved: /home/jlai/dark_photon/main/mt100_140/n-1cut/roc_curve_dmet.png
Saved: /home/jlai/dark_photon/main/mt100_140/n-1cut/significance_dphi_jj_uppercut.png


  sig_simple = np.where(B_counts > 0, S_counts / sqrt_B, 0.0)
  sig_s_plus_b = np.where((S_counts + B_counts) > 0, S_counts / sqrt_SplusB, 0.0)
  sig_s_plus_1p3b = np.where((S_counts + 1.3*B_counts) > 0, S_counts / sqrt_Splus1p3B, 0.0)


Saved: /home/jlai/dark_photon/main/mt100_140/n-1cut/dphi_jj.png
Saved: /home/jlai/dark_photon/main/mt100_140/n-1cut/roc_curve_dphi_jj.png
Saved: /home/jlai/dark_photon/main/mt100_140/n-1cut/significance_dphi_met_phterm_lowercut.png


  sig_simple = np.where(B_counts > 0, S_counts / sqrt_B, 0.0)
  sig_s_plus_b = np.where((S_counts + B_counts) > 0, S_counts / sqrt_SplusB, 0.0)
  sig_s_plus_1p3b = np.where((S_counts + 1.3*B_counts) > 0, S_counts / sqrt_Splus1p3B, 0.0)


Saved: /home/jlai/dark_photon/main/mt100_140/n-1cut/dphi_met_phterm.png
Saved: /home/jlai/dark_photon/main/mt100_140/n-1cut/roc_curve_dphi_met_phterm.png
Saved: /home/jlai/dark_photon/main/mt100_140/n-1cut/significance_dphi_met_jetterm_uppercut.png


  sig_simple = np.where(B_counts > 0, S_counts / sqrt_B, 0.0)
  sig_s_plus_b = np.where((S_counts + B_counts) > 0, S_counts / sqrt_SplusB, 0.0)
  sig_s_plus_1p3b = np.where((S_counts + 1.3*B_counts) > 0, S_counts / sqrt_Splus1p3B, 0.0)


Saved: /home/jlai/dark_photon/main/mt100_140/n-1cut/dphi_met_jetterm.png
Saved: /home/jlai/dark_photon/main/mt100_140/n-1cut/roc_curve_dphi_met_jetterm.png


In [7]:
from pathlib import Path