In [5]:
# import modules
import uproot, sys, time, math, pickle, os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import awkward as ak
from tqdm import tqdm
import seaborn as sns
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import train_test_split
from matplotlib.ticker import FormatStrFormatter
import matplotlib.ticker as ticker
from scipy.special import betainc
from scipy.stats import norm

# import config functions
from jet_faking_plot_config import getWeight, zbi, sample_dict, getVarDict
from plot_var import variables, variables_data, ntuple_names, ntuple_names_BDT

# Set up plot defaults
import matplotlib as mpl

# Apply the seaborn theme FIRST. sns.set() installs seaborn's plotting context,
# which rewrites font-related rcParams (font.size among them); when it ran after
# the assignments below, the font.size chosen here was silently discarded.
sns.set(style="whitegrid")

mpl.rcParams['figure.figsize'] = 14.0,10.0  # default figure size: 14 in wide by 10 in high
mpl.rcParams['font.size'] = 20.0 # base font size: 20 pt

# Font sizes (in points) for the individual text elements of a plot.
font_size = {
    "xlabel": 17,
    "ylabel": 17,
    "xticks": 15,
    "yticks": 15,
    "legend": 14
}

# NOTE: "axes.labelsize" is shared by both axes, so only the "xlabel" entry is used.
plt.rcParams.update({
    "axes.labelsize": font_size["xlabel"],  # X and Y axis labels
    "xtick.labelsize": font_size["xticks"],  # X ticks
    "ytick.labelsize": font_size["yticks"],  # Y ticks
    "legend.fontsize": font_size["legend"]  # Legend
})

In [7]:
%%time
# Per-sample event arrays, appended by the loading loop in the same order as ntuple_names.
tot = []
data = pd.DataFrame()
# Cut-flow bookkeeping: one independent list per counter.
unweighted_bcut = []
weighted_bcut = []
unweighted_acut = []
weighted_acut = []
# Samples to process, in order. This rebinds the `ntuple_names` imported from plot_var.
ntuple_names = [
    'ggHyyd',
    'Zjets',
    'Zgamma',
    'Wgamma',
    'Wjets',
    'gammajet_direct',
    'data23',
]

def test(fb):
    """Print how many entries of ``fb['met_tst_et']`` are None.

    Purely diagnostic: ``fb`` is not modified and nothing is returned.
    """
    none_count = ak.sum(ak.is_none(fb['met_tst_et']))
    print("Number of none values: ", none_count)

def print_cut(ntuple_name, fb, label):
    """Print the unweighted and weighted event counts of ``fb``.

    Parameters
    ----------
    ntuple_name : str
        Sample name forwarded to ``getWeight``.
    fb : array
        Events of that sample.
    label : str
        Free text inserted in the printed lines (e.g. 'before cut').
    """
    print(f"Unweighted Events {label}: ", len(fb))
    # 'data23' is the only sample whose weights are requested with jet_faking=True.
    extra = {'jet_faking': True} if ntuple_name == 'data23' else {}
    print(f"Weighted Events {label}: ", sum(getWeight(fb, ntuple_name, **extra)))

# Load every sample, apply its pre-selection plus the common muon requirement,
# and append the surviving events to `tot` (same order as `ntuple_names`).
for i in range(len(ntuple_names)):
    ucut, wcut = [], []
    start_time = time.time()
    ntuple_name = ntuple_names[i]
    is_data = ntuple_name == 'data23'

    if is_data: # data
        path = "/data/fpiazza/ggHyyd/Ntuples/MC23d/withVertexBDT/data23_y_BDT_score.root"
        print('processing file: ', path)
        # The context manager closes the ROOT file once the branches are in memory.
        with uproot.open(path) as f:
            fb = f['nominal'].arrays(variables_data, library="ak")
        fb['VertexBDTScore'] = fb['BDTScore'] # renaming BDTScore to ensure this is recognized as Vertex BDT Score

        fb = fb[ak.num(fb['ph_eta']) > 0]     # for abs(ak.firsts(fb['ph_eta'])) to have value to the reweighting

        mask1 = (ak.firsts(fb['ph_topoetcone40'])-2450.)/ak.firsts(fb['ph_pt']) > 0.1   # jet_faking_photon cut
        fb = fb[mask1]
        fb = fb[fb['n_ph_baseline'] == 1]

    else: # MC
        path = f"/data/tmathew/ntups/mc23d/{ntuple_name}_y.root"
        path_BDT = f"/data/fpiazza/ggHyyd/Ntuples/MC23d/withVertexBDT/mc23d_{ntuple_name}_y_BDT_score.root"
        print('processing file: ', path)
        with uproot.open(path) as f:
            fb = f['nominal'].arrays(variables, library="ak")

        # add BDT score to fb
        with uproot.open(path_BDT) as f_BDT:
            fb_BDT = f_BDT['nominal'].arrays(["event", "BDTScore"], library="ak")

        # The BDT friend file must list exactly the same events in the same order.
        # Fail loudly: continuing would leave fb without 'VertexBDTScore'.
        if len(fb) != len(fb_BDT):
            raise RuntimeError(
                f"{ntuple_name}: {len(fb)} events in {path} but {len(fb_BDT)} in {path_BDT}"
            )
        tmp = fb["event"] == fb_BDT["event"]
        if not np.all(tmp):
            raise RuntimeError(
                f"{ntuple_name}: event numbers in {path_BDT} are not aligned with {path}"
            )
        fb["VertexBDTScore"] = fb_BDT["BDTScore"]

        fb = fb[ak.num(fb['ph_eta']) > 0]     # for abs(ak.firsts(fb['ph_eta'])) to have value to the reweighting
        fb = fb[fb['n_ph'] == 1]

        # Zjets and Wjets (rule out everything except for e->gamma)
        if ntuple_name == 'Zjets' or ntuple_name == 'Wjets':
            mask = ak.firsts(fb['ph_truth_type']) == 2
            fb = fb[mask]

        # goodPV on signal only
        if ntuple_name == 'ggHyyd':
            fb = fb[ak.num(fb['pv_z']) > 0]
            good_pv_tmp = (np.abs(ak.firsts(fb['pv_truth_z']) - ak.firsts(fb['pv_z'])) <= 0.5)
            fb = fb[good_pv_tmp]

    print_cut(ntuple_name, fb, 'before cut')
    # Record the same weighted yield that print_cut reports: data23 weights are
    # requested with jet_faking=True, every other sample with the default call.
    wcut.append(sum(
        getWeight(fb, ntuple_name, jet_faking=True) if is_data else getWeight(fb, ntuple_name)
    ))

    # Only selection currently applied to every sample: exactly one muon.
    fb = fb[fb['n_mu'] == 1]

    # Selections that are currently switched off (each was followed by a wcut entry):
    #   n_mu_baseline == 0, n_el_baseline == 0, n_tau_baseline == 0
    #   trigger_HLT_g50_tight_xe40_cell_xe70_pfopufit_80mTAC_L1eEM26M == 1
    #   leading ph_pt >= 50000, met_tst_et >= 100000, n_jet_central <= 4
    #   mT(leading photon, met_tst) / 1000 >= 100

    print_cut(ntuple_name, fb, 'after basic cut')

    ucut.append(len(fb))

    unweighted_acut.append(ucut)
    weighted_acut.append(wcut)
    test(fb) # check for none value

    print(f"Reading Time for {ntuple_name}: {(time.time()-start_time)} seconds\n")

    tot.append(fb)

    # Drop the references to the large arrays before the next sample is read.
    fb = 0
    fb_BDT = 0
    tmp = 0


processing file:  /data/tmathew/ntups/mc23d/ggHyyd_uy.root


FileNotFoundError: [Errno 2] No such file or directory: '/data/tmathew/ntups/mc23d/ggHyyd_uy.root'

In [4]:
# Save data after basic cut to a csv file for BDT input
Vars = [
    'balance', 'VertexBDTScore', 'dmet',
    'dphi_jj', 'dphi_met_central_jet', 'dphi_met_phterm',
    'dphi_met_ph', 'dphi_met_jetterm', 'dphi_phterm_jetterm',
    'dphi_ph_centraljet1',
    'ph_pt', 'ph_eta', 'ph_phi',
    'jet_central_eta', 'jet_central_pt1', 'jet_central_pt2',
    'jetterm', 'jetterm_sumet',
    'metsig', 'metsigres',
    'met', 'met_noJVT', 'metplusph',
    'failJVT_jet_pt1', 'softerm', 'n_jet_central',
    'mt',
]

data_list = []

# One DataFrame per sample: the requested variables, the event weight,
# the sample name and a signal (1) / background (0) label.
for j, process in enumerate(ntuple_names):
    fb = tot[j]

    data_dict = {}
    for var in Vars:
        var_config = getVarDict(fb, process, var_name=var)
        data_dict[var] = var_config[var]['var']

    weights = getWeight(fb, process)
    n_events = len(weights)
    label = int(process == 'ggHyyd')
    data_dict.update({
        'weights': weights,
        'process': [process] * n_events,
        'label': [label] * n_events,
    })

    df_temp = pd.DataFrame(data_dict)
    data_list.append(df_temp)

# Stack all samples into one table with a fresh 0..N-1 index.
df_all = pd.concat(data_list, ignore_index=True)
df_all.head()

df_all.to_csv("/data/jlai/ntups/csv/jet_faking_BDT_input_basic3.csv", index=False)


In [3]:
signal_name = 'ggHyyd'  # Define signal dataset
cut_name = 'basic'

def getCutDict():
    """Return the threshold grids scanned for each cut variable.

    The result maps a variable name to a dict with a 'lowercut' and/or an
    'uppercut' entry; each entry is the numpy array of candidate thresholds
    for that direction. Keys are inserted in the order listed below.
    """
    return {
        'VertexBDTScore': {
            'lowercut': np.arange(0, 0.4+0.1, 0.1),  # keep VertexBDTScore above the cut
        },
        'balance': {
            'lowercut': np.arange(0, 1.5 + 0.05, 0.05),  # keep balance above the cut
            'uppercut': np.arange(1.5, 6, 0.2),  # keep balance below the cut
        },
        'dmet': {
            'lowercut': np.arange(-30000, 10000 + 5000, 5000),  # keep dmet above the cut
            'uppercut': np.arange(10000, 100000 + 5000, 5000),  # keep dmet below the cut
        },
        'dphi_jj': {
            'uppercut': np.arange(1, 3.1 + 0.1, 0.1),  # keep dphi_jj below the cut
        },
        'dphi_met_jetterm': {
            'lowercut': np.arange(0, 1 + 0.05, 0.05),  # keep dphi_met_jetterm above the cut
            'uppercut': np.arange(0.5, 2 + 0.05, 0.05),  # keep dphi_met_jetterm below the cut
        },
        'dphi_met_phterm': {
            'lowercut': np.arange(1, 2 + 0.05, 0.05),  # keep dphi_met_phterm above the cut
            'uppercut': np.arange(2, 3.1 + 0.1, 0.1),  # keep dphi_met_phterm below the cut
        },
        'dphi_ph_centraljet1': {
            'lowercut': np.arange(0, 2.5 + 0.1, 0.1),  # keep dphi_ph_centraljet1 above the cut
            'uppercut': np.arange(1.5, 3.1 + 0.1, 0.1),  # keep dphi_ph_centraljet1 below the cut
        },
        'dphi_phterm_jetterm': {
            'lowercut': np.arange(1, 2.5 + 0.05, 0.05),  # keep dphi_phterm_jetterm above the cut
            'uppercut': np.arange(2, 4 + 0.1, 0.1),  # keep dphi_phterm_jetterm below the cut
        },
        'met': {
            'lowercut': np.arange(100000, 140000 + 5000, 5000),  # keep met above the cut
            'uppercut': np.arange(140000, 300000 + 5000, 5000),  # keep met below the cut
        },
        'metsig': {
            'lowercut': np.arange(0, 10 + 1, 1),  # keep metsig above the cut
            'uppercut': np.arange(10, 30 + 1, 1),  # keep metsig below the cut
        },
        # 'mt' is currently not scanned:
        # 'mt': {
        #     'lowercut': np.arange(80, 130+5, 5),
        #     'uppercut': np.arange(120, 230+5, 5),
        # },
        'n_jet_central': {
            'uppercut': np.arange(0, 8+1, 1),  # keep n_jet_central below the cut
        },
        'ph_eta': {
            'uppercut': np.arange(1, 2.5 + 0.05, 0.05),  # keep ph_eta below the cut
        },
        'ph_pt': {
            'lowercut': np.arange(50000, 100000 + 5000, 5000),  # keep ph_pt above the cut
            'uppercut': np.arange(100000, 300000 + 10000, 10000),  # keep ph_pt below the cut
        },
    }

cut_config = getCutDict()

In [3]:
signal_name = 'ggHyyd'  # Define signal dataset
cut_name = 'basic'

def getCutDict():
    """Return the threshold grids for the reduced set of scanned variables.

    Maps each variable name to a dict holding a 'lowercut' and/or 'uppercut'
    numpy array of candidate thresholds. Insertion order follows the listing.
    """
    cut_dict = dict(
        dphi_jj=dict(
            uppercut=np.arange(1, 3.14 + 0.01, 0.01),  # keep dphi_jj below the cut
        ),
        dphi_phterm_jetterm=dict(
            lowercut=np.arange(1, 2.5 + 0.05, 0.05),  # keep dphi_phterm_jetterm above the cut
            uppercut=np.arange(2, 4 + 0.1, 0.1),  # keep dphi_phterm_jetterm below the cut
        ),
        jet_central_eta=dict(
            lowercut=np.arange(-2.5, 0+0.01, 0.01),  # keep jet_central_eta above the cut
            uppercut=np.arange(0, 2.5+0.01, 0.01),  # keep jet_central_eta below the cut
        ),
        jet_central_pt2=dict(
            lowercut=np.arange(20000, 100000+1000, 1000),  # keep jet_central_pt2 above the cut
        ),
        metsigres=dict(
            lowercut=np.arange(8600, 15000, 100),
            uppercut=np.arange(12000, 60000, 100),
        ),
        met_noJVT=dict(
            lowercut=np.arange(50000, 120000, 100),
            uppercut=np.arange(100000, 250000, 100),
        ),
        softerm=dict(
            uppercut=np.arange(10000, 40000, 100),
        ),
        n_jet_central=dict(
            uppercut=np.arange(0, 8+1, 1),  # keep n_jet_central below the cut
        ),
    )
    return cut_dict

cut_config = getCutDict()

In [4]:
def get_best_cut(cut_values, significance_list):
    """Pick the scan point with the largest figure of merit.

    Returns ``(best_cut, best_sig, max_idx)``: the threshold at the maximum,
    the maximum itself, and its position (first occurrence on ties, as given
    by ``np.argmax``).
    """
    winner = np.argmax(significance_list)
    return cut_values[winner], significance_list[winner], winner

def calculate_significance(cut_var, cut_type, cut_values, tot2, ntuple_names, signal_name, getVarDict, getWeight):
    """Scan one variable and evaluate S/sqrt(B) at every candidate threshold.

    Parameters
    ----------
    cut_var : str
        Variable name understood by ``getVarDict``.
    cut_type : str
        'lowercut' keeps events with ``var >= cut``; 'uppercut' keeps ``var <= cut``.
    cut_values : iterable
        Candidate thresholds, scanned in order.
    tot2 : sequence
        Per-sample event arrays, aligned index by index with ``ntuple_names``.
    ntuple_names : sequence of str
        Sample names; the one equal to ``signal_name`` is the signal, the rest
        are summed as background.
    signal_name : str
    getVarDict, getWeight : callables
        Called as ``getVarDict(fb, process, var_name=cut_var)`` and
        ``getWeight(fb, process)``.

    Returns
    -------
    (sig_simple_list, sigacc_simple_list, acceptance_values)
        One entry per cut value: S/sqrt(B) (0 when B <= 0), that value times
        the signal acceptance, and the acceptance in percent. The acceptance
        is relative to signal events whose variable is not the -999 sentinel.

    Raises
    ------
    ValueError
        If ``cut_type`` is neither 'lowercut' nor 'uppercut'.
    """
    if cut_type not in ('lowercut', 'uppercut'):
        raise ValueError(f"Invalid cut type: {cut_type!r}")
    is_lower = cut_type == 'lowercut'

    # Variable values and weights do not depend on the threshold, so they are
    # fetched once per sample instead of once per (threshold, sample) pair.
    # NOTE(review): print_cut requests data23 weights with jet_faking=True; here
    # every sample uses the default getWeight call — confirm this is intended.
    sig_x, sig_w = None, None
    bkg_samples = []
    for i in range(len(ntuple_names)):
        fb = tot2[i]
        process = ntuple_names[i]
        x = getVarDict(fb, process, var_name=cut_var)[cut_var]['var']
        valid = x != -999  # -999 marks events where the variable is undefined
        x = x[valid]
        w = getWeight(fb, process)[valid]
        if process == signal_name:
            sig_x, sig_w = x, w
        else:
            bkg_samples.append((x, w))

    # Denominator of the acceptance; 0 when no signal sample is present.
    sig_total = ak.sum(sig_w) if sig_w is not None else 0

    def passing_yield(x, w, cut):
        """Weighted yield of the events that pass the threshold."""
        keep = x >= cut if is_lower else x <= cut
        return ak.sum(w[keep])

    sig_simple_list = []
    sigacc_simple_list = []
    acceptance_values = []

    for cut in cut_values:
        total_signal = passing_yield(sig_x, sig_w, cut) if sig_w is not None else 0
        total_bkg = sum(passing_yield(x, w, cut) for x, w in bkg_samples)

        sig_simple = total_signal / np.sqrt(total_bkg) if total_bkg > 0 else 0
        acceptance = total_signal / sig_total if sig_total > 0 else 0

        sig_simple_list.append(sig_simple)
        sigacc_simple_list.append(sig_simple * acceptance)
        acceptance_values.append(acceptance * 100)

    return sig_simple_list, sigacc_simple_list, acceptance_values

# os.makedirs("plot_data", exist_ok=True)
initial_cut = []
tot2 = tot

# < -- Initial Cut on all variables (maximize the significance * acceptance) -- >
for cut_var, cut_types in cut_config.items():
    for cut_type, cut_values in cut_types.items():
        sig_simple_list, sigacc_simple_list, acceptance_values = calculate_significance(
            cut_var, cut_type, cut_values, tot2, ntuple_names, signal_name, getVarDict, getWeight
        )

        best_cut, best_sig, idx = get_best_cut(cut_values, sigacc_simple_list)

        # An optimum sitting on either edge of the scan range is treated as
        # "no useful cut" for the initial pass: report it and move on.
        if idx in (0, len(sigacc_simple_list) - 1):
            print(cut_var, idx, len(sigacc_simple_list))
            continue

        result = dict(
            cut_var=cut_var,
            cut_type=cut_type,
            best_cut=best_cut,
            best_sig_x_acc=best_sig,
            significance=sig_simple_list[idx],
            acceptance=acceptance_values[idx],
        )

        print(result)
        # Only the cut definition itself is carried forward.
        initial_cut.append({key: result[key] for key in ("cut_var", "cut_type", "best_cut")})

{'cut_var': 'dphi_jj', 'cut_type': 'uppercut', 'best_cut': 2.7600000000000016, 'best_sig_x_acc': 0.3486244955136403, 'significance': 0.3722073510197181, 'acceptance': 93.6640543392093}
{'cut_var': 'dphi_phterm_jetterm', 'cut_type': 'lowercut', 'best_cut': 1.4500000000000004, 'best_sig_x_acc': 0.5118064096607197, 'significance': 0.6000596689628758, 'acceptance': 85.29258607653298}
{'cut_var': 'dphi_phterm_jetterm', 'cut_type': 'uppercut', 'best_cut': 3.000000000000001, 'best_sig_x_acc': 0.3464933370130756, 'significance': 0.3475026838622216, 'acceptance': 99.70954271836754}
jet_central_eta 0 251
jet_central_eta 250 251
jet_central_pt2 0 81
{'cut_var': 'metsigres', 'cut_type': 'lowercut', 'best_cut': 10300, 'best_sig_x_acc': 0.3561096438700701, 'significance': 0.36114757116108814, 'acceptance': 98.60502251896057}
{'cut_var': 'metsigres', 'cut_type': 'uppercut', 'best_cut': 34200, 'best_sig_x_acc': 0.352821986471691, 'significance': 0.36368004520505426, 'acceptance': 97.01439249237853}
{'

In [5]:
def apply_cut_to_fb(fb, process, var, cut_val, cut_type, getVarDict):
    """Return the events of ``fb`` that pass a single threshold cut.

    Parameters
    ----------
    fb : array
        Events of one sample.
    process : str
        Sample name forwarded to ``getVarDict``.
    var : str
        Variable to cut on; read from ``getVarDict(fb, process, var_name=var)``.
    cut_val : number
        Threshold.
    cut_type : str
        'lowercut' keeps ``var >= cut_val``; 'uppercut' keeps ``var <= cut_val``.
    getVarDict : callable

    Events whose variable equals the -999 sentinel are always dropped.

    Raises
    ------
    ValueError
        If ``cut_type`` is not one of the two supported values. Previously such
        a value was accepted and silently applied no threshold at all.
    """
    x = getVarDict(fb, process, var_name=var)[var]['var']
    defined = x != -999

    if cut_type == 'lowercut':
        mask = defined & (x >= cut_val)
    elif cut_type == 'uppercut':
        mask = defined & (x <= cut_val)
    else:
        raise ValueError(f"Invalid cut type: {cut_type!r}")

    return fb[mask]

def apply_all_cuts(tot2, ntuple_names, cut_list, getVarDict):
    """Apply every cut in ``cut_list``, in order, to every sample.

    ``tot2`` and ``ntuple_names`` are aligned index by index. Each entry of
    ``cut_list`` is a dict with the keys "cut_var", "best_cut" and "cut_type".
    A new list of filtered samples is returned; ``tot2`` itself is not modified.
    """
    selected = []
    for idx in range(len(tot2)):
        sample = tot2[idx]
        name = ntuple_names[idx]
        for spec in cut_list:
            sample = apply_cut_to_fb(
                sample, name, spec["cut_var"], spec["best_cut"], spec["cut_type"], getVarDict
            )
        selected.append(sample)
    return selected
    
def compute_total_significance(tot2, ntuple_names, signal_name, getVarDict, getWeight):
    """Return S/sqrt(B) computed from the summed weights of all samples.

    The sample named ``signal_name`` contributes to S, every other sample to B.
    Returns 0 when B is not positive. ``getVarDict`` is accepted for signature
    symmetry with the other helpers and is not used.
    """
    signal_sum, bkg_sum = 0, 0
    for idx, process in enumerate(ntuple_names):
        sample_yield = ak.sum(getWeight(tot2[idx], process))
        if process == signal_name:
            signal_sum += sample_yield
        else:
            bkg_sum += sample_yield

    if bkg_sum > 0:
        return signal_sum / np.sqrt(bkg_sum)
    return 0

# Apply the initial working points to every sample and report the combined significance.
tot2_initial_cut = apply_all_cuts(
    tot2, ntuple_names, cut_list=initial_cut, getVarDict=getVarDict
)
final_significance = compute_total_significance(
    tot2_initial_cut, ntuple_names, signal_name, getVarDict=getVarDict, getWeight=getWeight
)
print('after initial cutting, signficance: ', final_significance)

after initial cutting, signficance:  0.7898281381978879


In [6]:
def n_minus_1_optimizer(initial_cut, cut_config, tot2, ntuple_names, signal_name, getVarDict, getWeight, final_significance, max_iter=10, tolerance=1e-4):
    """Iteratively re-optimise each cut with all the other cuts applied (N-1 scan).

    For every cut in turn, the remaining cuts are applied to ``tot2``, the
    variable is re-scanned over ``cut_config[cut_var][cut_type]`` and the
    threshold with the highest S/sqrt(B) is taken. Passes repeat until no
    threshold moves by more than ``tolerance`` or ``max_iter`` passes were run.

    Parameters
    ----------
    initial_cut : list of dict
        Starting cuts, each with "cut_var", "cut_type" and "best_cut".
        The list and its dicts are left untouched.
    cut_config : dict
        Scan grids, indexed as ``cut_config[cut_var][cut_type]``.
    tot2, ntuple_names, signal_name, getVarDict, getWeight
        Forwarded to ``apply_all_cuts`` / ``calculate_significance``.
    final_significance : float
        Significance of the starting point; returned unchanged if no cut moves.
    max_iter : int
        Maximum number of passes over the cut list.
    tolerance : float
        Minimum threshold change that counts as an update.

    Returns
    -------
    (best_cuts, final_significance)
        The optimised cut list (new dict objects) and the significance found by
        the scan of the last cut that was updated.
    """
    # Copy every dict: a shallow list copy would share the dicts with the
    # caller, so updating "best_cut" below would rewrite `initial_cut` too.
    best_cuts = [dict(cut) for cut in initial_cut]
    iteration = 0
    converged = False

    while not converged and iteration < max_iter:
        converged = True
        print(f"\n--- Iteration {iteration + 1} ---")
        for i, cut in enumerate(best_cuts):
            # Apply all other cuts
            n_minus_1_cuts = best_cuts[:i] + best_cuts[i+1:]
            tot2_cut = apply_all_cuts(tot2, ntuple_names, n_minus_1_cuts, getVarDict)

            # Re-scan this variable
            cut_var = cut["cut_var"]
            cut_type = cut["cut_type"]
            cut_values = cut_config[cut_var][cut_type]

            sig_simple_list, sigacc_simple_list, _ = calculate_significance(
                cut_var, cut_type, cut_values, tot2_cut, ntuple_names,
                signal_name, getVarDict, getWeight
            )
            # Unlike the initial pass, the plain significance is maximised here.
            best_cut, best_sig, idx = get_best_cut(cut_values, sig_simple_list)

            if abs(best_cut - cut["best_cut"]) > tolerance:
                print(f"Updating {cut_var} ({cut_type}): {cut['best_cut']} → {best_cut}  (sig {final_significance:.2f} → {best_sig:.2f})")
                best_cuts[i]["best_cut"] = best_cut
                final_significance = best_sig
                converged = False  # at least one threshold moved: run another pass

        iteration += 1

    return best_cuts, final_significance

# < -- N-1 iterations until no further improvement (max significance) -- >
optimized_cuts, final_significance = n_minus_1_optimizer(
    initial_cut=initial_cut,
    cut_config=cut_config,
    tot2=tot2,
    ntuple_names=ntuple_names,
    signal_name=signal_name,
    getVarDict=getVarDict,
    getWeight=getWeight,
    final_significance=final_significance,
)


--- Iteration 1 ---
Updating dphi_jj (uppercut): 2.7600000000000016 → 1.5500000000000005  (sig 0.79 → 0.82)
Updating dphi_phterm_jetterm (lowercut): 1.4500000000000004 → 1.6500000000000006  (sig 0.82 → 0.86)
Updating dphi_phterm_jetterm (uppercut): 3.000000000000001 → 2.8000000000000007  (sig 0.86 → 0.89)
Updating metsigres (lowercut): 10300 → 10700  (sig 0.89 → 0.89)
Updating metsigres (uppercut): 34200 → 20100  (sig 0.89 → 1.09)
Updating met_noJVT (lowercut): 100000 → 103000  (sig 1.09 → 1.13)
Updating met_noJVT (uppercut): 249100 → 248300  (sig 1.13 → 1.13)

--- Iteration 2 ---
Updating dphi_jj (uppercut): 1.5500000000000005 → 1.5800000000000005  (sig 1.13 → 1.13)
Updating dphi_phterm_jetterm (lowercut): 1.6500000000000006 → 1.8000000000000007  (sig 1.13 → 1.15)
Updating metsigres (lowercut): 10700 → 11500  (sig 1.15 → 1.17)
Updating metsigres (uppercut): 20100 → 19300  (sig 1.17 → 1.17)
Updating met_noJVT (lowercut): 103000 → 103400  (sig 1.17 → 1.18)
Updating met_noJVT (uppercut)

In [7]:
# Display the current value of final_significance (when the cells are run in
# order, this is the value returned by n_minus_1_optimizer).
final_significance

1.176777241481177

In [8]:
# Cross-check: apply the optimised cuts to every sample and recompute S/sqrt(B).
tot2_optimized_cuts = apply_all_cuts(
    tot2, ntuple_names, cut_list=optimized_cuts, getVarDict=getVarDict
)
compute_total_significance(
    tot2_optimized_cuts, ntuple_names, signal_name, getVarDict=getVarDict, getWeight=getWeight
)

1.176777241481177

In [9]:
# < -- Save data after cuts to a csv file for BDT input -- >
Vars = [
    'balance', 'VertexBDTScore', 'dmet', 'dphi_jj',
    'dphi_met_central_jet', 'dphi_met_phterm', 'dphi_met_ph', 'dphi_met_jetterm',
    'dphi_phterm_jetterm', 'dphi_ph_centraljet1',
    'ph_pt', 'ph_eta', 'ph_phi',
    'jet_central_eta', 'jet_central_pt1', 'jet_central_pt2',
    'jetterm', 'jetterm_sumet',
    'metsig', 'metsigres',
    'met', 'met_noJVT', 'metplusph',
    'failJVT_jet_pt1', 'softerm', 'n_jet_central',
]

data_list = []

# Build one table per sample from the events that survive the optimised cuts.
for j, process in enumerate(ntuple_names):
    fb = tot2_optimized_cuts[j]

    data_dict = {}
    for var in Vars:
        var_config = getVarDict(fb, process, var_name=var)
        data_dict[var] = var_config[var]['var']

    weights = getWeight(fb, process)
    n_events = len(weights)
    label = 1 if process == 'ggHyyd' else 0  # signal flag
    data_dict['weights'] = weights
    data_dict['process'] = n_events * [process]
    data_dict['label'] = n_events * [label]

    df_temp = pd.DataFrame(data_dict)
    data_list.append(df_temp)

# Concatenate the per-sample tables under a fresh 0..N-1 index.
df_all = pd.concat(data_list, ignore_index=True)
df_all.head()

df_all.to_csv("/data/jlai/ntups/csv/jet_faking_BDT_input_basic_reduced.csv", index=False)


In [None]:
def compute_total_significance(tot2, ntuple_names, signal_name, getVarDict, getWeight):
    """Print each sample's weighted yield and return the overall S/sqrt(B).

    Redefines the helper of the same name with an added per-sample printout.
    The sample named ``signal_name`` is S; the others are summed into B.
    Returns 0 when B is not positive. ``getVarDict`` is not used.
    """
    signal_sum = 0
    bkg_sum = 0
    for i, process in enumerate(ntuple_names):
        weights = getWeight(tot2[i], process)
        print(process, ak.sum(weights))
        sample_yield = ak.sum(weights)
        if process != signal_name:
            bkg_sum += sample_yield
        else:
            signal_sum += sample_yield
    if not bkg_sum > 0:
        return 0
    return signal_sum / np.sqrt(bkg_sum)

# NOTE(review): `cut_tmp` is only defined in a later cell of this notebook, so
# this cell fails on a fresh kernel unless that cell has been run first.
tot2_cut = apply_all_cuts(tot2, ntuple_names, cut_list=cut_tmp, getVarDict=getVarDict)
compute_total_significance(
    tot2_cut, ntuple_names, signal_name, getVarDict=getVarDict, getWeight=getWeight
)


In [None]:
# Hand-written list of cuts used for a cross-check of the optimisation.
cut_tmp = [
    {'cut_var': 'VertexBDTScore', 'cut_type': 'lowercut', 'best_cut': 0.10},
    {'cut_var': 'dphi_met_phterm', 'cut_type': 'lowercut', 'best_cut': 1.35},
    {'cut_var': 'dphi_met_phterm', 'cut_type': 'uppercut', 'best_cut': 3},
    {'cut_var': 'metsig', 'cut_type': 'lowercut', 'best_cut': 6},
    {'cut_var': 'metsig', 'cut_type': 'uppercut', 'best_cut': 16},
    {'cut_var': 'ph_eta', 'cut_type': 'uppercut', 'best_cut': 1.75},
    {'cut_var': 'dmet', 'cut_type': 'lowercut', 'best_cut': -20000},
    {'cut_var': 'dmet', 'cut_type': 'uppercut', 'best_cut': 50000},
    {'cut_var': 'dphi_met_jetterm', 'cut_type': 'lowercut', 'best_cut': 0.05},
    {'cut_var': 'dphi_met_jetterm', 'cut_type': 'uppercut', 'best_cut': 0.65},
    {'cut_var': 'balance', 'cut_type': 'lowercut', 'best_cut': 0.90},
    {'cut_var': 'dphi_jj', 'cut_type': 'uppercut', 'best_cut': 2.4},
]

tot2_cut = apply_all_cuts(tot2, ntuple_names, cut_tmp, getVarDict)

# Keep the scan grid in one variable so the best-cut lookup indexes the same
# grid that was scanned. The lookup previously used the global `cut_values`
# left over from an earlier loop, which need not match this grid.
bdt_scan_values = np.arange(0, 0.4+0.1, 0.1)
sig_simple_list, sigacc_simple_list, acceptance_values = calculate_significance(
    'VertexBDTScore', 'lowercut', bdt_scan_values, tot2_cut, ntuple_names,
    signal_name, getVarDict, getWeight
)
best_cut, best_sig, idx = get_best_cut(bdt_scan_values, sigacc_simple_list)

sig_simple_list[idx], sigacc_simple_list[idx], acceptance_values[idx]


In [31]:
# Display the current list of optimised cuts (one dict per cut with the keys
# 'cut_var', 'cut_type' and 'best_cut').
optimized_cuts

[{'cut_var': 'VertexBDTScore', 'cut_type': 'lowercut', 'best_cut': 0.1},
 {'cut_var': 'balance', 'cut_type': 'lowercut', 'best_cut': 0.9},
 {'cut_var': 'dmet', 'cut_type': 'lowercut', 'best_cut': -20000},
 {'cut_var': 'dmet', 'cut_type': 'uppercut', 'best_cut': 35000},
 {'cut_var': 'dphi_jj', 'cut_type': 'uppercut', 'best_cut': 3.100000000000002},
 {'cut_var': 'dphi_met_jetterm', 'cut_type': 'lowercut', 'best_cut': 0.05},
 {'cut_var': 'dphi_met_jetterm',
  'cut_type': 'uppercut',
  'best_cut': 0.7000000000000002},
 {'cut_var': 'dphi_met_phterm',
  'cut_type': 'lowercut',
  'best_cut': 1.6000000000000005},
 {'cut_var': 'dphi_met_phterm',
  'cut_type': 'uppercut',
  'best_cut': 2.900000000000001},
 {'cut_var': 'dphi_ph_centraljet1',
  'cut_type': 'lowercut',
  'best_cut': 1.7000000000000002},
 {'cut_var': 'dphi_phterm_jetterm',
  'cut_type': 'lowercut',
  'best_cut': 1.9500000000000008},
 {'cut_var': 'dphi_phterm_jetterm',
  'cut_type': 'uppercut',
  'best_cut': 3.100000000000001},
 {'cu

In [None]:
def calculate_significance(cut_var, cut_type, best_cut, tot_1st_cut, ntuple_names, signal_name, getVarDict, getWeight):
    """Evaluate S/sqrt(B) for one variable at one or several thresholds.

    This redefinition keeps its own parameter names, but the body now uses
    them: previously it ignored ``best_cut`` and ``tot_1st_cut`` and read the
    notebook globals ``cut_values`` and ``tot2`` instead, so the result depended
    on whatever those globals held at call time.

    Parameters
    ----------
    cut_var : str
        Variable name understood by ``getVarDict``.
    cut_type : str
        'lowercut' keeps ``var >= cut``; 'uppercut' keeps ``var <= cut``.
    best_cut : number or sequence of numbers
        Threshold(s) to evaluate. A scalar is treated as a one-point scan.
    tot_1st_cut : sequence
        Per-sample event arrays, aligned index by index with ``ntuple_names``.
    ntuple_names : sequence of str
    signal_name : str
        Name of the signal sample; all other samples are background.
    getVarDict, getWeight : callables
        Called as ``getVarDict(fb, process, var_name=cut_var)`` and
        ``getWeight(fb, process)``.

    Returns
    -------
    (sig_simple_list, sigacc_simple_list, acceptance_values)
        One entry per threshold: S/sqrt(B) (0 when B <= 0), that value times
        the signal acceptance, and the acceptance in percent.

    Raises
    ------
    ValueError
        If ``cut_type`` is neither 'lowercut' nor 'uppercut'.
    """
    if cut_type not in ('lowercut', 'uppercut'):
        raise ValueError(f"Invalid cut type: {cut_type!r}")
    is_lower = cut_type == 'lowercut'

    # Accept both a single threshold and a whole scan grid.
    thresholds = np.atleast_1d(best_cut)

    # Values and weights are independent of the threshold: fetch them once.
    sig_x, sig_w = None, None
    bkg_samples = []
    for i in range(len(ntuple_names)):
        fb = tot_1st_cut[i]
        process = ntuple_names[i]
        x = getVarDict(fb, process, var_name=cut_var)[cut_var]['var']
        valid = x != -999  # -999 marks events where the variable is undefined
        x = x[valid]
        w = getWeight(fb, process)[valid]
        if process == signal_name:
            sig_x, sig_w = x, w
        else:
            bkg_samples.append((x, w))

    # Denominator of the acceptance; 0 when no signal sample is present.
    sig_total = ak.sum(sig_w) if sig_w is not None else 0

    def passing_yield(x, w, cut):
        """Weighted yield of the events that pass the threshold."""
        keep = x >= cut if is_lower else x <= cut
        return ak.sum(w[keep])

    sig_simple_list = []
    sigacc_simple_list = []
    acceptance_values = []

    for cut in thresholds:
        total_signal = passing_yield(sig_x, sig_w, cut) if sig_w is not None else 0
        total_bkg = sum(passing_yield(x, w, cut) for x, w in bkg_samples)

        sig_simple = total_signal / np.sqrt(total_bkg) if total_bkg > 0 else 0
        acceptance = total_signal / sig_total if sig_total > 0 else 0

        sig_simple_list.append(sig_simple)
        sigacc_simple_list.append(sig_simple * acceptance)
        acceptance_values.append(acceptance * 100)

    return sig_simple_list, sigacc_simple_list, acceptance_values

In [None]:
# NOTE(review): nothing changes between the 10 passes of the outer loop, so each
# pass repeats the same scan and `result` is simply overwritten — confirm intent.
for i in range(10):
    for cut_var, cut_types in cut_config.items():
        for cut_type, cut_values in cut_types.items():
            sig_simple_list, sigacc_simple_list, acceptance_values = calculate_significance(
                cut_var, cut_type, cut_values, tot2,
                ntuple_names, signal_name, getVarDict, getWeight,
            )

            # Best working point according to significance x acceptance.
            best_cut, best_sig, idx = get_best_cut(cut_values, sigacc_simple_list)
            result = dict(
                cut_var=cut_var,
                cut_type=cut_type,
                best_cut=best_cut,
                best_sig_x_acc=best_sig,
                significance=sig_simple_list[idx],
                acceptance=acceptance_values[idx],
            )

In [None]:
def sel(tot):
    """Return a new list holding the same per-sample arrays as ``tot``.

    Every event selection that used to be applied here is switched off, so
    each sample is passed through unchanged.
    """
    # Selections that are currently disabled (thresholds kept for reference):
    #   balance = (met_tst_et + leading ph_pt) / sum(jet_central_pt) >= 0.65,
    #             or balance undefined (-999, no central-jet pt)
    #   7 <= met_tst_sig <= 13
    #   |leading ph_eta| <= 1.75
    #   dphi(met_tst_phi, met_phterm_phi) >= 1.35
    #   -20000 <= met_tst_noJVT_et - met_tst_et <= 50000
    #   dphi between the two leading central jets <= 2.5 (-999 if fewer than 2 jets)
    #   dphi(met_tst_phi, met_jetterm_phi) <= 0.70 (-999 if met_jetterm_et == 0)
    return [tot[i] for i in range(len(tot))]

# sel() is bypassed: the scan runs directly on the loaded samples.
# tot2 = sel(tot)
tot2 = tot

# Signal sample name and a label for this cut stage.
signal_name, cut_name = 'ggHyyd', 'basic'

def getCutDict():
    """Return the threshold grids scanned for each cut variable.

    Maps a variable name to a dict with a 'lowercut' and/or 'uppercut' entry,
    each a numpy array of candidate thresholds.

    The vertex-BDT scan used to be stored under the key 'dphi_jj' and was then
    overwritten by the real 'dphi_jj' entry further down, so it never reached
    the caller. It is now stored under 'VertexBDTScore'.
    """
    cut_dict = {}

    cut_dict['VertexBDTScore'] = {
        'lowercut': np.arange(0, 0.4+0.1, 0.1) # keep VertexBDTScore above the cut
    }
    cut_dict['balance'] = {
        'lowercut': np.arange(0, 1.5 + 0.05, 0.05), # keep balance above the cut
        'uppercut': np.arange(5, 8 + 0.2, 0.2) # keep balance below the cut
    }
    cut_dict['dmet'] = {
        'lowercut': np.arange(-30000, 0 + 5000, 5000), # keep dmet above the cut
        'uppercut': np.arange(10000, 100000 + 5000, 5000), # keep dmet below the cut
    }
    cut_dict['dphi_jj'] = {
        'uppercut': np.arange(1, 3.1 + 0.1, 0.1) # keep dphi_jj below the cut
    }
    cut_dict['dphi_met_jetterm'] = {
        'lowercut': np.arange(0, 1 + 0.05, 0.05), # keep dphi_met_jetterm above the cut
        'uppercut': np.arange(0.5, 2 + 0.05, 0.05), # keep dphi_met_jetterm below the cut
    }
    cut_dict['dphi_met_phterm'] = {
        'lowercut': np.arange(1, 2 + 0.05, 0.05), # keep dphi_met_phterm above the cut
        'uppercut': np.arange(2, 3.1 + 0.1, 0.1), # keep dphi_met_phterm below the cut
    }
    cut_dict['dphi_ph_centraljet1'] = {
        'lowercut': np.arange(0, 2.5 + 0.1, 0.1), # keep dphi_ph_centraljet1 above the cut
        'uppercut': np.arange(1.5, 3.1 + 0.1, 0.1) # keep dphi_ph_centraljet1 below the cut
    }
    cut_dict['dphi_phterm_jetterm'] = {
        'lowercut': np.arange(1, 2.5 + 0.05, 0.05), # keep dphi_phterm_jetterm above the cut
        'uppercut': np.arange(2, 4 + 0.1, 0.1) # keep dphi_phterm_jetterm below the cut
    }
    cut_dict['met'] = {
        'lowercut': np.arange(100000, 140000 + 5000, 5000),  # keep met above the cut
        'uppercut': np.arange(140000, 300000 + 5000, 5000),  # keep met below the cut
    }
    cut_dict['metsig'] = {
        'lowercut': np.arange(0, 10 + 1, 1), # keep metsig above the cut
        'uppercut': np.arange(10, 30 + 1, 1), # keep metsig below the cut
    }
    cut_dict['mt'] = {
        'lowercut': np.arange(80, 130+5, 5), # keep mt above the cut
        'uppercut': np.arange(120, 230+5, 5) # keep mt below the cut
    }
    cut_dict['n_jet_central'] = {
        'uppercut': np.arange(0, 8+1, 1) # keep n_jet_central below the cut
    }
    cut_dict['ph_eta'] = {
        'uppercut': np.arange(1, 2.5 + 0.05, 0.05), # keep ph_eta below the cut
    }
    cut_dict['ph_pt'] = {
        'lowercut': np.arange(50000, 100000 + 5000, 5000),  # keep ph_pt above the cut
        'uppercut': np.arange(100000, 300000 + 10000, 10000),  # keep ph_pt below the cut
    }

    return cut_dict
cut_config = getCutDict()

def calculate_significance(cut_var, cut_type, cut_values):
    """Scan a one-sided cut on ``cut_var`` and evaluate significance per threshold.

    For every threshold in ``cut_values`` the weighted yields passing the cut
    are summed over the samples in the module-level ``tot2`` list (paired with
    ``ntuple_names``; the last entry of ``ntuple_names`` is deliberately left
    out of the scan). Entries whose variable equals the sentinel ``-999`` are
    dropped before the cut is applied and do not count towards the acceptance
    denominator.

    Parameters
    ----------
    cut_var : str
        Variable key passed to ``getVarDict``.
    cut_type : {'lowercut', 'uppercut'}
        ``'lowercut'`` keeps entries with ``x >= cut``; ``'uppercut'`` keeps
        entries with ``x <= cut``.
    cut_values : iterable of numbers
        Thresholds to scan.

    Returns
    -------
    tuple of nine lists, each aligned with ``cut_values``:
        S/sqrt(B), S/sqrt(S+B), S/sqrt(S+1.3B), ``zbi(S, B, sigma_b_frac=0.3)``,
        the same four quantities multiplied by the signal acceptance
        (as a fraction), and finally the signal acceptance in percent.

    Raises
    ------
    ValueError
        If ``cut_type`` is neither ``'lowercut'`` nor ``'uppercut'``.
    """
    if cut_type not in ('lowercut', 'uppercut'):
        raise ValueError("Invalid cut type")
    keep_above = cut_type == 'lowercut'

    # The variable values and event weights do not depend on the threshold, so
    # extract them once per sample instead of once per (threshold, sample).
    samples = []          # (is_signal, values, weights) with sentinels removed
    signal_total = 0      # weighted signal yield before the cut (acceptance denominator)
    for i in range(len(ntuple_names) - 1):  # last sample is excluded from the scan
        fb = tot2[i]
        process = ntuple_names[i]
        values = getVarDict(fb, process, var_name=cut_var)[cut_var]['var']
        valid = values != -999  # drop sentinel entries
        weights = getWeight(fb, process)[valid]
        is_signal = process == signal_name
        if is_signal:
            # ak.sum (not the builtin sum) so numerator and denominator of the
            # acceptance are accumulated the same way.
            signal_total = ak.sum(weights)
        samples.append((is_signal, values[valid], weights))

    sig_simple_list = []
    sig_s_plus_b_list = []
    sig_s_plus_1p3b_list = []
    sig_binomial_list = []

    sigacc_simple_list = []
    sigacc_s_plus_b_list = []
    sigacc_s_plus_1p3b_list = []
    sigacc_binomial_list = []

    acceptance_values = []  # signal acceptance in percent

    for cut in cut_values:
        total_signal = 0
        total_bkg = 0
        for is_signal, values, weights in samples:
            passed = (values >= cut) if keep_above else (values <= cut)
            yield_after_cut = ak.sum(weights[passed])
            if is_signal:
                total_signal = yield_after_cut
            else:
                total_bkg += yield_after_cut

        # All figures of merit are defined as 0 when there is no background left.
        if total_bkg > 0:
            sig_simple = total_signal / np.sqrt(total_bkg)
            sig_s_plus_b = (total_signal / np.sqrt(total_signal + total_bkg)
                            if (total_signal + total_bkg) > 0 else 0)
            sig_s_plus_1p3b = (total_signal / np.sqrt(total_signal + 1.3 * total_bkg)
                               if (total_signal + 1.3 * total_bkg) > 0 else 0)
            sig_binomial = zbi(total_signal, total_bkg, sigma_b_frac=0.3)
        else:
            sig_simple = sig_s_plus_b = sig_s_plus_1p3b = sig_binomial = 0

        # Fraction of the (sentinel-free) signal that survives the cut; 0 when
        # no signal sample was found or its total weight is not positive.
        acceptance = total_signal / signal_total if signal_total > 0 else 0
        acceptance_values.append(acceptance * 100)

        sig_simple_list.append(sig_simple)
        sig_s_plus_b_list.append(sig_s_plus_b)
        sig_s_plus_1p3b_list.append(sig_s_plus_1p3b)
        sig_binomial_list.append(sig_binomial)

        sigacc_simple_list.append(sig_simple * acceptance)
        sigacc_s_plus_b_list.append(sig_s_plus_b * acceptance)
        sigacc_s_plus_1p3b_list.append(sig_s_plus_1p3b * acceptance)
        sigacc_binomial_list.append(sig_binomial * acceptance)

    return (sig_simple_list, sig_s_plus_b_list, sig_s_plus_1p3b_list, sig_binomial_list,
            sigacc_simple_list, sigacc_s_plus_b_list, sigacc_s_plus_1p3b_list, sigacc_binomial_list,
            acceptance_values)

# Scan every configured (variable, cut type) pair and save a two-panel figure:
# top = S/√B versus threshold, bottom = (S/√B) × signal acceptance versus
# threshold with the acceptance percentage annotated at each point.
for cut_var, cut_types in cut_config.items():
    for cut_type, cut_values in cut_types.items():
        # All nine result lists are unpacked so the alternative metrics
        # (S/√(S+B), S/√(S+1.3B), binomial Z) stay available for plotting.
        (sig_simple_list, sig_s_plus_b_list, sig_s_plus_1p3b_list, sig_binomial_list,
         sigacc_simple_list, sigacc_s_plus_b_list, sigacc_s_plus_1p3b_list, sigacc_binomial_list,
         acceptance_values) = calculate_significance(cut_var, cut_type, cut_values)

        fig, (ax_top, ax_bot) = plt.subplots(2, 1, figsize=(8, 10), sharex=True)

        # Top panel: significance vs. cut
        ax_top.plot(cut_values, sig_simple_list, marker='o', label='S/√B')
        ax_top.set_ylabel('Significance')
        ax_top.set_title(f'Significance vs. {cut_var} ({cut_type})')
        ax_top.legend()
        ax_top.grid(True)

        # Bottom panel: significance × acceptance vs. cut
        ax_bot.plot(cut_values, sigacc_simple_list, marker='o', label='(S/√B) × Acceptance')
        for cut, sig_acc, acc_percent in zip(cut_values, sigacc_simple_list, acceptance_values):
            ax_bot.text(cut, sig_acc, f'{acc_percent:.1f}%',
                        fontsize=10, ha='right', va='bottom', color='purple')

        # The x label is the variable's display title from getVarDict.
        var_configs_tmp = getVarDict(tot2[0], signal_name, cut_var)
        ax_bot.set_xlabel(var_configs_tmp[cut_var]['title'])
        ax_bot.set_ylabel('Significance × Acceptance')
        ax_bot.set_title(f'Significance × Acceptance vs. {cut_var} ({cut_type})')

        # One tick per scanned threshold, two decimals, rotated for legibility.
        ax_bot.set_xticks(cut_values)
        ax_bot.xaxis.set_major_formatter(FormatStrFormatter('%.2f'))
        plt.setp(ax_bot.get_xticklabels(), rotation=45, ha='right')

        ax_bot.legend()
        ax_bot.grid(True)

        fig.tight_layout()
        out_path = f"/data/projects/campfire-workshop/dark_photon/plots/mc23d_{cut_name}cut/significance_{cut_var}_{cut_type}.png"
        # savefig does not create missing directories, so make sure it exists.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        fig.savefig(out_path)
        print(f"Successfully saved to {out_path}")
        plt.close(fig)  # release the figure explicitly; many are created in this loop

In [3]:
def sel(tot):
    """Apply the sequential event selection to every sample in ``tot``.

    Each element of ``tot`` is filtered independently, cut after cut, and the
    surviving events are collected in a new list with the same ordering.
    Cuts, in the order they are applied:

    1. ``BDTScore >= 0.1``
    2. |Δφ(met_tst, met_phterm)| >= 1.35
    3. ``5 <= met_tst_sig <= 16``
    4. |first ph_eta| <= 1.75
    5. ``-20000 <= met_tst_noJVT_et - met_tst_et <= 50000``
    6. |Δφ(met_tst, met_jetterm)| <= 0.65; events with ``met_jetterm_et == 0``
       get the sentinel -999 and therefore pass
    7. (met_tst_et + first ph_pt) / Σ jet_central_pt >= 0.70; events whose jet
       pT sum is 0 get the sentinel -999 and are kept explicitly
    8. |dphi_central_jj| <= 2.4; entries equal to -10 are mapped to the
       sentinel -999 and therefore pass

    Parameters
    ----------
    tot : sequence
        Per-sample event arrays exposing the branches named above.

    Returns
    -------
    list
        Filtered event arrays, one per input sample.
    """
    selected = []
    for events in tot:
        # 1. BDT score
        events = events[events['BDTScore'] >= 0.1]

        # 2. Angle between MET and the photon term, folded into [0, pi]
        dphi_met_ph = np.arccos(np.cos(events['met_tst_phi'] - events['met_phterm_phi']))
        events = events[dphi_met_ph >= 1.35]

        # 3. MET significance window
        met_sig = events['met_tst_sig']
        events = events[(met_sig >= 5) & (met_sig <= 16)]

        # 4. Leading-photon pseudorapidity
        abs_ph_eta = np.abs(ak.firsts(events['ph_eta']))
        events = events[abs_ph_eta <= 1.75]

        # 5. MET difference with / without JVT
        delta_met = events['met_tst_noJVT_et'] - events['met_tst_et']
        events = events[(delta_met >= -20000) & (delta_met <= 50000)]

        # 6. Angle between MET and the jet term; -999 marks "no jet term"
        dphi_met_jet = np.where(
            events['met_jetterm_et'] != 0,
            np.arccos(np.cos(events['met_tst_phi'] - events['met_jetterm_phi'])),
            -999,
        )
        events = events[dphi_met_jet <= 0.65]

        # 7. (MET + photon pT) over summed central-jet pT; -999 marks "no jets"
        jet_pt_sum = ak.sum(events['jet_central_pt'], axis=-1)
        has_jets = jet_pt_sum != 0
        # Divide by 1 where the sum is 0 so the division itself is always defined.
        ratio = (events['met_tst_et'] + ak.firsts(events['ph_pt'])) / ak.where(has_jets, jet_pt_sum, 1)
        balance = ak.where(has_jets, ratio, -999)
        events = events[(balance >= 0.70) | (balance == -999)]

        # 8. Dijet opening angle; -10 in the input goes through NaN to -999
        dphi_jj = events['dphi_central_jj']
        dphi_jj = ak.where(dphi_jj == -10, np.nan, dphi_jj)
        dphi_jj = np.arccos(np.cos(dphi_jj))
        dphi_jj = ak.where(np.isnan(dphi_jj), -999, dphi_jj)
        events = events[dphi_jj <= 2.4]

        selected.append(events)
    return selected

# Selection is currently bypassed: tot2 aliases the unfiltered samples.
# Uncomment the next line (and drop the alias) to apply the cuts in sel().
# tot2 = sel(tot)
tot2 = tot

# Tag interpolated into the output directory name ("mc23d_{cut_name}cut").
cut_name = 'basic'
# Only the keys of this dict are used to pick which variables get plotted below;
# here the request is restricted to 'dphi_jj'.
var_config = getVarDict(tot2[0], 'ggHyyd', var_name='dphi_jj')
# Alternative without var_name — presumably returns every configured variable; verify in getVarDict.
# var_config = getVarDict(tot2[0], 'ggHyyd')


def _ratio_over_sqrt(numerator, variance):
    """Return ``numerator / sqrt(variance)`` per bin, and 0 where ``variance <= 0``.

    Only the bins with positive ``variance`` are divided, so no divide-by-zero
    or invalid-value warnings are emitted for empty bins.
    """
    numerator = np.asarray(numerator, dtype=float)
    variance = np.asarray(variance, dtype=float)
    ratio = np.zeros_like(numerator)
    positive = variance > 0
    ratio[positive] = numerator[positive] / np.sqrt(variance[positive])
    return ratio


# For every requested variable: draw the stacked background + signal histogram
# with a per-bin significance panel, then a weighted ROC curve that treats the
# variable itself as the discriminant. Both figures are written to disk.
for var in var_config:
    bg_values, bg_weights, bg_colors, bg_labels = [], [], [], []
    signal_values, signal_weights = [], []
    signal_color = None
    signal_label = None

    for j, process in enumerate(ntuple_names):
        fb = tot2[j]
        # Separate name so the dict driving the outer loop is not rebound here.
        proc_config = getVarDict(fb, process, var_name=var)[var]

        x = proc_config['var']
        # Binning is taken from the last sample processed
        # (presumably identical for all samples — TODO confirm in getVarDict).
        bins = proc_config['bins']

        # A variable may carry its own weights; otherwise use the event weights.
        if 'weight' in proc_config:
            weights = proc_config['weight']
        else:
            weights = getWeight(fb, process)

        sample_info = sample_dict[process]
        color = sample_info['color']
        legend = sample_info['legend']

        if process == 'ggHyyd':  # signal
            signal_values.append(x)
            signal_weights.append(weights)
            signal_color = color
            signal_label = legend
        else:  # background
            bg_values.append(x)
            bg_weights.append(weights)
            bg_colors.append(color)
            bg_labels.append(legend)

    # ------------------------------------------------------------------
    # Figure 1: distributions (top) and per-bin significance (bottom)
    # ------------------------------------------------------------------
    fig, (ax_top, ax_bot) = plt.subplots(2, 1, figsize=(12, 13), gridspec_kw={'height_ratios': [9, 4]})

    ax_top.hist(bg_values, bins=bins, weights=bg_weights, color=bg_colors,
                label=bg_labels, stacked=True)
    ax_top.hist(signal_values, bins=bins, weights=signal_weights, color=signal_color,
                label=signal_label, histtype='step', linewidth=2)

    signal_all = np.concatenate(signal_values) if len(signal_values) > 0 else np.array([])
    signal_weights_all = np.concatenate(signal_weights) if len(signal_weights) > 0 else np.array([])
    bg_all = np.concatenate(bg_values) if len(bg_values) > 0 else np.array([])
    bg_weights_all = np.concatenate(bg_weights) if len(bg_weights) > 0 else np.array([])

    # Weighted yields per bin; bin_edges comes back as an array even if `bins` is a list.
    S_counts, bin_edges = np.histogram(signal_all, bins=bins, weights=signal_weights_all)
    B_counts, _ = np.histogram(bg_all, bins=bins, weights=bg_weights_all)
    bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])

    # Statistical uncertainty of a weighted histogram: sqrt(sum of w^2) per bin.
    if len(signal_all) > 0:
        sum_weights_sq, _ = np.histogram(signal_all, bins=bins, weights=signal_weights_all**2)
        ax_top.errorbar(bin_centers, S_counts, yerr=np.sqrt(sum_weights_sq), fmt='.', linewidth=2,
                        color=signal_color, capsize=0)

    ax_top.set_yscale('log')
    ax_top.set_ylim(0.0001, 1e11)
    ax_top.set_xlim(bin_edges[0], bin_edges[-1])
    ax_top.minorticks_on()
    ax_top.grid(True, which="both", linestyle="--", linewidth=0.5)
    ax_top.set_ylabel("Events")
    ax_top.legend(ncol=2)

    # Per-bin figures of merit (0 in bins where the denominator is not positive).
    sig_simple = _ratio_over_sqrt(S_counts, B_counts)
    sig_s_plus_b = _ratio_over_sqrt(S_counts, S_counts + B_counts)
    sig_s_plus_1p3b = _ratio_over_sqrt(S_counts, S_counts + 1.3 * B_counts)
    zbi_per_bin = np.array([
        zbi(s, b, sigma_b_frac=0.3)
        for s, b in zip(S_counts, B_counts)
    ])

    # Inclusive figures of merit, shown in the legend.
    total_signal = np.sum(S_counts)
    total_bkg = np.sum(B_counts)

    if total_bkg > 0:
        total_sig_simple = total_signal / np.sqrt(total_bkg)
        total_sig_s_plus_b = (total_signal / np.sqrt(total_signal + total_bkg)
                              if (total_signal + total_bkg) > 0 else 0)
        total_sig_s_plus_1p3b = (total_signal / np.sqrt(total_signal + 1.3 * total_bkg)
                                 if (total_signal + 1.3 * total_bkg) > 0 else 0)
        total_sig_binomial = zbi(total_signal, total_bkg, sigma_b_frac=0.3)
    else:
        total_sig_simple = total_sig_s_plus_b = total_sig_s_plus_1p3b = total_sig_binomial = 0

    ax_bot.step(bin_centers, sig_simple, where='mid', color='chocolate', linewidth=2,
                label=f"S/√B = {total_sig_simple:.4f}")
    ax_bot.step(bin_centers, sig_s_plus_b, where='mid', color='tomato', linewidth=2,
                label=f"S/√(S+B) = {total_sig_s_plus_b:.4f}")
    ax_bot.step(bin_centers, sig_s_plus_1p3b, where='mid', color='orange', linewidth=2,
                label=f"S/√(S+1.3B) = {total_sig_s_plus_1p3b:.4f}")
    ax_bot.step(bin_centers, zbi_per_bin, where='mid', color='plum', linewidth=2,
                label=f"Binomial ExpZ = {total_sig_binomial:.4f}")

    ax_bot.set_xlabel(proc_config['title'])
    ax_bot.set_ylabel("Significance")
    ax_bot.set_ylim(-0.8, 2)
    # Same x range as the top panel so the bins line up.
    ax_bot.set_xlim(bin_edges[0], bin_edges[-1])

    leg = ax_bot.legend()
    for text in leg.get_texts():
        text.set_color('purple')

    fig.tight_layout()
    out_dir = f"/data/jlai/dark_photon/jets_faking_photons/lumi135/mc23d_{cut_name}cut"
    # savefig does not create missing directories, so make sure it exists.
    os.makedirs(out_dir, exist_ok=True)
    # NOTE(review): the file name says "nodijet" although every entry of
    # ntuple_names is plotted above — confirm the intended naming.
    hist_path = f"{out_dir}/{var}_nodijet.png"
    fig.savefig(hist_path)
    print(f"successfully saved to {hist_path}")
    plt.close(fig)

    # ------------------------------------------------------------------
    # Figure 2: weighted ROC curve using the variable as the score
    # ------------------------------------------------------------------
    # NOTE(review): no sentinel filtering is applied here, so placeholder
    # values in the variable (if any) enter the ROC as ordinary scores.
    y_true = np.concatenate([np.ones_like(signal_all), np.zeros_like(bg_all)])
    y_scores = np.concatenate([signal_all, bg_all])
    y_weights = np.concatenate([signal_weights_all, bg_weights_all])

    fpr, tpr, _ = roc_curve(y_true, y_scores, sample_weight=y_weights)
    # auc() needs a monotonic x; sort by FPR before integrating.
    order = np.argsort(fpr)
    roc_auc = auc(fpr[order], tpr[order])

    roc_fig, roc_ax = plt.subplots(figsize=(8, 8))
    roc_ax.plot(fpr, tpr, lw=2, color='red', label=f'ROC curve (AUC = {roc_auc:.5f})')
    roc_ax.plot([0, 1], [0, 1], linestyle='--', color='gray', label='Random chance')
    roc_ax.set_xlabel("False Positive Rate")
    roc_ax.set_ylabel("True Positive Rate")
    roc_ax.set_title(f"ROC Curve for {var}")
    roc_ax.legend(loc="lower right")
    roc_ax.grid(True, which="both", linestyle="--", linewidth=0.5)
    roc_fig.tight_layout()
    roc_path = f"{out_dir}/roc_curve_{var}.png"
    roc_fig.savefig(roc_path)
    print(f"successfully saved to {roc_path}")
    plt.close(roc_fig)


  return impl(*broadcasted_args, **(kwargs or {}))


successfully saved to /data/jlai/dark_photon/jets_faking_photons/lumi135/mc23d_basiccut/dphi_jj_nodijet.png
successfully saved to /data/jlai/dark_photon/jets_faking_photons/lumi135/mc23d_basiccut/roc_curve_dphi_jj.png
