In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import polars as pl
import pickle
import gc
import psutil
from tqdm.notebook import tqdm

import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

from src.signal_categories import topological_category_labels, topological_category_colors, topological_category_hatches, topological_category_labels_latex
from src.signal_categories import filetype_category_labels, filetype_category_colors, filetype_category_hatches, filetype_category_labels_latex
from src.signal_categories import del1g_detailed_category_labels, del1g_detailed_category_colors, del1g_detailed_category_hatches, del1g_detailed_category_labels_latex, del1g_detailed_categories_dic
from src.signal_categories import del1g_simple_category_labels, del1g_simple_category_colors, del1g_simple_category_hatches, del1g_simple_category_labels_latex, del1g_simple_categories_dic
from src.signal_categories import train_category_queries, train_category_labels, train_category_colors, train_category_hatches, train_category_labels_latex

from src.file_locations import intermediate_files_location


# File Loading

In [None]:
# Normalization factors multiplied into wc_net_weight for events flagged as
# del1g_overlay / iso1g_overlay when the weights are rebuilt after loading.
iso1g_norm_factor = 0.05
del1g_norm_factor = 0.5

# Sub-directory of ../training_outputs whose predictions.parquet is loaded.
training = "all_vars"

# Reco categories of the multi-class training; the "prob_<category>" score
# columns are derived from these labels.
reco_sig_categories = train_category_labels


In [None]:
print("loading all_df.parquet...")
all_df = pl.read_parquet(f"{intermediate_files_location}/all_df.parquet")
print(f"{all_df.shape=}")

# Predictions only exist for events passing the preselection used during training,
# so the left join below leaves the prediction columns null for every other event.
print("loading predictions.parquet...")
predictions_df = pl.read_parquet(f"../training_outputs/{training}/predictions.parquet")
print(f"{predictions_df.shape=}")

# Only simulation is studied here; drop real data before merging.
all_df = all_df.filter(pl.col("filetype") != "data")

print("merging all_df and predictions_df...")
merged_df = all_df.join(predictions_df, on=["filetype", "run", "subrun", "event"], how="left")
del all_df, predictions_df
gc.collect()  # Force cleanup of deleted dataframes immediately

# Events without a prediction get a sentinel score of -1 in every class.
prob_categories = ["prob_" + cat for cat in reco_sig_categories]
merged_df = merged_df.with_columns([pl.col(prob).fill_null(-1) for prob in prob_categories])

num_train_events = merged_df.filter(pl.col("used_for_training") == True).height
num_test_events = merged_df.filter(pl.col("used_for_testing") == True).height
print(f"{merged_df.height=}")
print(f"{num_train_events=}")
print(f"{num_test_events=}")

num_train_plus_test_events = num_train_events + num_test_events
if num_train_plus_test_events == 0:
    # Without this guard the division below fails with an uninformative ZeroDivisionError.
    raise ValueError(
        "No events are flagged used_for_training or used_for_testing; "
        "check that predictions.parquet matches all_df.parquet."
    )
frac_test = num_test_events / num_train_plus_test_events
print(f"weighting up by the fraction of test events: {frac_test:.3f}")

# Test events are scaled up by 1/frac_test because the training events are
# dropped further down; events that are neither train nor test keep their weight.
test_scaled_weight = (
    pl.when(pl.col("used_for_testing") == True)
    .then(pl.col("wc_net_weight") / frac_test)
    .otherwise(pl.col("wc_net_weight"))
)

# Apply the overlay normalization on top of the test-fraction scaling
# (iso1g_overlay takes precedence over del1g_overlay, as in the when/then order).
merged_df = merged_df.with_columns(
    pl.when(pl.col("iso1g_overlay") == True)
    .then(test_scaled_weight * iso1g_norm_factor)
    .when(pl.col("del1g_overlay") == True)
    .then(test_scaled_weight * del1g_norm_factor)
    .otherwise(test_scaled_weight)
    .alias("wc_net_weight")
)

# Keep test events and events with no train/test flag at all; this removes
# the events that were used for training.
merged_df = merged_df.filter((pl.col("used_for_testing") == True) | pl.col("used_for_testing").is_null())

# Extract wc_truth_muonMomentum_3
def extract_muon_momentum_3(x):
    """Return element 3 of a momentum list, or -1 when it is unavailable.

    Args:
        x: Value to inspect. Only a ``list`` with more than three entries
            yields a real value; floats, shorter lists and any other type
            all map to the -1 sentinel.

    Returns:
        ``x[3]`` for a sufficiently long list, otherwise ``-1``.
    """
    has_fourth_entry = isinstance(x, list) and len(x) > 3
    return x[3] if has_fourth_entry else -1

#erin_sig_query = "(wc_match_completeness_energy>0.1*wc_truth_energyInside and wc_truth_single_photon==1 and (wc_truth_isCC==0 or (wc_truth_isCC==1 and abs(wc_truth_nuPdg)==14 and abs(wc_truth_muonMomentum_3-0.105658)<0.1)))"
# NOTE(review): the reference query above also requires
# wc_match_completeness_energy > 0.1 * wc_truth_energyInside, which is not
# applied in the expression below -- confirm that dropping it is intentional.

# Truth-level single photon.
erin_true_single_photon = pl.col("wc_truth_single_photon") == 1
# Neutral-current interaction.
erin_true_nc = pl.col("wc_truth_isCC") == 0
# Charged-current |PDG| == 14 interaction whose wc_truth_muonMomentum_3 lies within 0.1 of 0.105658.
erin_true_numucc_low_muon = (
    (pl.col("wc_truth_isCC") == 1)
    & (pl.col("wc_truth_nuPdg").abs() == 14)
    & ((pl.col("wc_truth_muonMomentum_3") - 0.105658).abs() < 0.1)
)

# Store the truth flag as an Int32 column (1 = signal, 0 = not signal).
merged_df = merged_df.with_columns(
    (erin_true_single_photon & (erin_true_nc | erin_true_numucc_low_muon))
    .cast(pl.Int32)
    .alias("erin_inclusive_1g_true_sig")
)


In [None]:
# Report the size of the merged frame once (the statement was previously duplicated).
print(f"{merged_df.height=}, {merged_df.width=}")


In [None]:
# Individual cuts of the "erin inclusive 1g" selection; an event must pass all of them.
erin_inclusive_1g_cuts = [
    pl.col("wc_shw_sp_n_20mev_showers") > 0,
    pl.col("wc_reco_nuvtxX") > 5.0,
    pl.col("wc_reco_nuvtxX") < 250.0,
    pl.col("wc_single_photon_numu_score") > 0.4,
    pl.col("wc_single_photon_other_score") > 0.2,
    pl.col("wc_single_photon_ncpi0_score") > -0.05,
    pl.col("wc_single_photon_nue_score") > -1.0,
    pl.col("wc_shw_sp_n_20br1_showers") == 1,
]

# AND the cuts together left to right.
erin_inclusive_1g_pass = erin_inclusive_1g_cuts[0]
for erin_cut in erin_inclusive_1g_cuts[1:]:
    erin_inclusive_1g_pass = erin_inclusive_1g_pass & erin_cut

# Store the selection decision as an Int32 flag column.
merged_df = merged_df.with_columns(
    erin_inclusive_1g_pass.cast(pl.Int32).alias("erin_inclusive_1g_sel")
)


In [None]:
# Index (into prob_categories) of the highest-scoring class for every event.
# Rows whose scores are all equal resolve to index 0, because argmax returns
# the first maximum.
reco_categories_argmax = merged_df.select(prob_categories).to_numpy().argmax(axis=1)

merged_df = merged_df.with_columns(
    pl.Series("reco_category_argmax_index", reco_categories_argmax)
)

# Preselected sample: events with a positive reconstructed neutrino energy.
presel_merged_df = merged_df.filter(pl.col("wc_kine_reco_Enu") > 0)


# Preselection Efficiencies

In [None]:
def _weighted_truth_counts(df, num_categories):
    """Return the summed wc_net_weight of `df` for each truth category index 0..num_categories-1."""
    return [
        df.filter(pl.col("del1g_detailed_signal_category") == category_index)["wc_net_weight"].sum()
        for category_index in range(num_categories)
    ]


def _print_truth_efficiencies(title, selected_by_category, total_by_category):
    """Print 'label: selected / total = efficiency' per truth category, using nan when total is 0."""
    print(title)
    for category_index, (num_selected, num_total) in enumerate(zip(selected_by_category, total_by_category)):
        # An empty truth category would otherwise abort the printout with ZeroDivisionError.
        efficiency = num_selected / num_total if num_total != 0 else np.nan
        print(f"{del1g_detailed_category_labels[category_index]}: {num_selected} / {num_total} = {efficiency:.3f}")


# All detailed truth categories except the last label are evaluated.
num_eval_truth_categories = len(del1g_detailed_category_labels[:-1])

# Denominator: weighted truth counts before any selection.
total_num_truth_by_category = _weighted_truth_counts(merged_df, num_eval_truth_categories)

# NOTE(review): the "generic" frame is the same object as the preselected
# frame, so both printouts below show identical numbers -- confirm intended.
generic_merged_df = presel_merged_df
total_num_generic_truth_by_category = _weighted_truth_counts(generic_merged_df, num_eval_truth_categories)

total_num_presel_truth_by_category = _weighted_truth_counts(presel_merged_df, num_eval_truth_categories)

_print_truth_efficiencies(
    "WC Generic Selection Topological Efficiencies:",
    total_num_generic_truth_by_category,
    total_num_truth_by_category,
)
_print_truth_efficiencies(
    "\nPreselection Topological Efficiencies:",
    total_num_presel_truth_by_category,
    total_num_truth_by_category,
)


In [None]:
# Load the pickled reco-category selection queries (indexed by reco category below).
# pickle.load executes arbitrary code from the file, so this must only ever
# point at a trusted, locally produced intermediate file.
reco_category_queries_path = f"{intermediate_files_location}/reco_category_queries.pkl"
with open(reco_category_queries_path, "rb") as f:
    reco_category_queries = pickle.load(f)


# Nominal Sel Efficiencies

In [None]:
# Weighted event counts after preselection:
# rows are truth categories (all detailed labels but the last), columns are reco categories.
sel_truth_labels = del1g_detailed_category_labels[:-1]
num_sel_truth_cats = len(sel_truth_labels)
num_sel_reco_cats = len(reco_sig_categories)

nominal_sel_matrix = np.zeros((num_sel_truth_cats, num_sel_reco_cats))
for truth_idx in tqdm(range(num_sel_truth_cats)):
    in_truth_category = pl.col("del1g_detailed_signal_category") == truth_idx
    for reco_idx in range(num_sel_reco_cats):
        selected_events = presel_merged_df.filter(in_truth_category & reco_category_queries[reco_idx])
        nominal_sel_matrix[truth_idx, reco_idx] = selected_events["wc_net_weight"].sum()

plt.figure(figsize=(10, 10))
plt.imshow(nominal_sel_matrix, norm=mpl.colors.LogNorm())
plt.colorbar(label="Number of Events")
# Annotate every cell with its weighted count (x = reco column, y = truth row).
for (truth_idx, reco_idx), num_selected in np.ndenumerate(nominal_sel_matrix):
    plt.text(reco_idx, truth_idx, f'{num_selected:.1f}', ha='center', va='center', fontsize=4)
plt.xticks(range(num_sel_reco_cats), reco_sig_categories, rotation=90, fontsize=4)
plt.yticks(range(num_sel_truth_cats), sel_truth_labels, fontsize=4)
plt.xlabel("Reconstructed Category")
plt.ylabel("Truth Category")
plt.title("Nominal Selection Matrix")
plt.show()


In [None]:
# Divide every truth row of the selection matrix by that truth category's
# total weighted count (taken before any selection) to get efficiencies.
truth_totals_column = np.asarray(total_num_truth_by_category)[:, np.newaxis]
nominal_eff_matrix = nominal_sel_matrix / truth_totals_column

eff_truth_labels = del1g_detailed_category_labels[:-1]
num_eff_truth_cats = len(eff_truth_labels)
num_eff_reco_cats = len(reco_sig_categories)

plt.figure(figsize=(10, 10))
plt.imshow(nominal_eff_matrix, cmap="Blues")
plt.colorbar(label="Efficiency")
# Annotate every cell with its efficiency (x = reco column, y = truth row).
for truth_idx in range(num_eff_truth_cats):
    for reco_idx in range(num_eff_reco_cats):
        cell_efficiency = nominal_eff_matrix[truth_idx, reco_idx]
        plt.text(reco_idx, truth_idx, f'{cell_efficiency:.3f}', ha='center', va='center', fontsize=4)
plt.xticks(range(num_eff_reco_cats), reco_sig_categories, rotation=90, fontsize=4)
plt.yticks(range(num_eff_truth_cats), eff_truth_labels, fontsize=4)
plt.xlabel("Reconstructed Category")
plt.ylabel("Truth Category")
plt.title("Nominal Selection Efficiency Matrix")
plt.show()

# 1g Efficiencies By Cut Value

In [None]:
# Combined multi-class scores: each new column is the sum of the listed
# per-topology probability columns (insertion order fixes the column order).
summed_prob_components = {
    "prob_1g": ["prob_1gNp", "prob_1g0p", "prob_1gNp1mu", "prob_1g0p1mu"],
    "prob_1g0mu": ["prob_1gNp", "prob_1g0p"],
    "prob_1g1mu": ["prob_1gNp1mu", "prob_1g0p1mu"],
    "prob_NC1pi0": ["prob_NC1pi0_Np", "prob_NC1pi0_0p"],
    "prob_numuCC1pi0": ["prob_numuCC1pi0_Np", "prob_numuCC1pi0_0p"],
    "prob_nueCC": ["prob_nueCC_Np", "prob_nueCC_0p"],
    "prob_numuCC": ["prob_numuCC_Np", "prob_numuCC_0p"],
}

derived_score_columns = []
for summed_name, component_names in summed_prob_components.items():
    summed_expr = pl.col(component_names[0])
    for component_name in component_names[1:]:
        summed_expr = summed_expr + pl.col(component_name)
    derived_score_columns.append(summed_expr.alias(summed_name))

# wc_nc_delta_score with -999 substituted where the reconstructed energy is
# negative or the score itself is null.
derived_score_columns.append(
    pl.when(pl.col("wc_kine_reco_Enu") < 0)
    .then(-999)
    .when(pl.col("wc_nc_delta_score").is_null())
    .then(-999)
    .otherwise(pl.col("wc_nc_delta_score"))
    .alias("wc_nc_delta_score_generic")
)

# None of the derived columns depends on another derived column, so they can
# all be added in a single pass.
presel_merged_df = presel_merged_df.with_columns(derived_score_columns)


In [None]:
# Larger default font for the figures produced from here on.
plt.rcParams['font.size'] = 14

# Truth categories for which the loop below evaluates efficiency and purity,
# grouped by family.
eff_eval_cats = [
    # NC Delta radiative
    "NCDeltaRad",
    "NCDeltaRad_1gNp",
    "NCDeltaRad_1g0p",
    # numu CC Delta radiative
    "numuCCDeltaRad",
    "numuCCDeltaRad_1gNp",
    "numuCCDeltaRad_1g0p",
    # inclusive single-photon definition
    "erin_inclusive_1g",
    # del1g
    "del1g_Np",
    "del1g_0p",
    "del1g_Np1mu",
    "del1g_0p1mu",
    "del1g",
    "del1g_outFV",
    # iso1g
    "iso1g",
    "iso1g_outFV",
    # NC 1pi0
    "NC1pi0",
    "NC1pi0_Np",
    "NC1pi0_0p",
    # numu CC 1pi0
    "numuCC1pi0",
    "numuCC1pi0_Np",
    "numuCC1pi0_0p",
    # nue CC
    "nueCC",
    "nueCC_Np",
    "nueCC_0p",
    # numu CC
    "numuCC",
    "numuCC_Np",
    "numuCC_0p",
    # remaining category
    "eta_other",
]

for eff_eval_cat in eff_eval_cats:

    eff_eval_cat_latex = eff_eval_cat
    if eff_eval_cat == "NCDeltaRad":
        eff_eval_cat_latex = r"NC $\Delta\rightarrow N \gamma$"
    elif eff_eval_cat == "NCDeltaRad_1gNp":
        eff_eval_cat_latex = r"NC $\Delta\rightarrow N \gamma$ $Np$"
    elif eff_eval_cat == "NCDeltaRad_1g0p":
        eff_eval_cat_latex = r"NC $\Delta\rightarrow N \gamma$ $0p$"
    elif eff_eval_cat == "numuCCDeltaRad":
        eff_eval_cat_latex = r"$\nu_\mu$ CC $\Delta\rightarrow N \gamma$"
    elif eff_eval_cat == "numuCCDeltaRad_1gNp":
        eff_eval_cat_latex = r"$\nu_\mu$ CC $\Delta\rightarrow N \gamma$ $Np$"
    elif eff_eval_cat == "numuCCDeltaRad_1g0p":
        eff_eval_cat_latex = r"$\nu_\mu$ CC $\Delta\rightarrow N \gamma$ $0p$"
    elif eff_eval_cat == "erin_inclusive_1g":
        eff_eval_cat_latex = r"Erin Inclusive $1\gamma$"
    elif eff_eval_cat == "del1g_Np":
        eff_eval_cat_latex = r"Del1g $1\gamma Np$"
    elif eff_eval_cat == "del1g_0p":
        eff_eval_cat_latex = r"Del1g $1\gamma0p$"
    elif eff_eval_cat == "del1g_Np1mu":
        eff_eval_cat_latex = r"Del1g $1\gamma Np1\mu$"
    elif eff_eval_cat == "del1g_0p1mu":
        eff_eval_cat_latex = r"Del1g $1\gamma 0p1\mu$"
    elif eff_eval_cat == "del1g_outFV":
        eff_eval_cat_latex = r"Del1g $1\gamma$ Out FV"
    elif eff_eval_cat == "iso1g":
        eff_eval_cat_latex = r"Iso1g $1\gamma 0p$"
    elif eff_eval_cat == "iso1g_outFV":
        eff_eval_cat_latex = r"Iso1g $1\gamma$ Out FV"
    elif eff_eval_cat == "NC1pi0":
        eff_eval_cat_latex = r"NC 1$\pi^0$"
    elif eff_eval_cat == "NC1pi0_Np":
        eff_eval_cat_latex = r"NC 1$\pi^0$ $Np$"
    elif eff_eval_cat == "NC1pi0_0p":
        eff_eval_cat_latex = r"NC 1$\pi^0$ $0p$"
    elif eff_eval_cat == "numuCC1pi0":
        eff_eval_cat_latex = r"$\nu_\mu$ CC 1$\pi^0$"
    elif eff_eval_cat == "numuCC1pi0_Np":
        eff_eval_cat_latex = r"$\nu_\mu$ CC 1$\pi^0$ $Np$"
    elif eff_eval_cat == "numuCC1pi0_0p":
        eff_eval_cat_latex = r"$\nu_\mu$ CC 1$\pi^0$ $0p$"
    elif eff_eval_cat == "nueCC":
        eff_eval_cat_latex = r"$\nu_e$ CC"
    elif eff_eval_cat == "nueCC_Np":
        eff_eval_cat_latex = r"$\nu_e$ CC $Np$"
    elif eff_eval_cat == "nueCC_0p":
        eff_eval_cat_latex = r"$\nu_e$ CC $0p$"
    elif eff_eval_cat == "numuCC":
        eff_eval_cat_latex = r"$\nu_\mu$ CC 0$\pi^0$"
    elif eff_eval_cat == "numuCC_Np":
        eff_eval_cat_latex = r"$\nu_\mu$ CC $Np$ 0$\pi^0$"
    elif eff_eval_cat == "numuCC_0p":
        eff_eval_cat_latex = r"$\nu_\mu$ CC $0p$ 0$\pi^0$"
        
    if "del1g" in eff_eval_cat:
        rel_filetypes_merged_df = merged_df.filter(pl.col("iso1g_overlay") == False)
        rel_filetypes_presel_merged_df = presel_merged_df.filter(pl.col("iso1g_overlay") == False)
    elif "iso1g" in eff_eval_cat:
        rel_filetypes_merged_df = merged_df.filter(pl.col("del1g_overlay") == False)
        rel_filetypes_presel_merged_df = presel_merged_df.filter(pl.col("del1g_overlay") == False)
    else:
        rel_filetypes_merged_df = merged_df.filter((pl.col("del1g_overlay") == False) & (pl.col("iso1g_overlay") == False))
        rel_filetypes_presel_merged_df = presel_merged_df.filter((pl.col("del1g_overlay") == False) & (pl.col("iso1g_overlay") == False))

    if eff_eval_cat == "NCDeltaRad":
        sig_df = rel_filetypes_presel_merged_df.filter(
            (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["NCDeltaRad_1gNp"]) |
            (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["NCDeltaRad_1g0p"])
        )
        bkg_df = rel_filetypes_presel_merged_df.filter(
            ~((pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["NCDeltaRad_1gNp"]) |
              (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["NCDeltaRad_1g0p"]))
        )
        total_sig = rel_filetypes_merged_df.filter(
            (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["NCDeltaRad_1gNp"]) |
            (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["NCDeltaRad_1g0p"])
        )["wc_net_weight"].sum()
        total_bkg = rel_filetypes_merged_df.filter(
            ~((pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["NCDeltaRad_1gNp"]) |
              (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["NCDeltaRad_1g0p"]))
        )["wc_net_weight"].sum()
    elif eff_eval_cat == "numuCCDeltaRad":
        sig_df = rel_filetypes_presel_merged_df.filter(
            (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["numuCCDeltaRad_1gNp"]) |
            (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["numuCCDeltaRad_1g0p"])
        )
        bkg_df = rel_filetypes_presel_merged_df.filter(
            ~((pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["numuCCDeltaRad_1gNp"]) |
              (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["numuCCDeltaRad_1g0p"]))
        )
        total_sig = rel_filetypes_merged_df.filter(
            (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["numuCCDeltaRad_1gNp"]) |
            (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["numuCCDeltaRad_1g0p"])
        )["wc_net_weight"].sum()
        total_bkg = rel_filetypes_merged_df.filter(
            ~((pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["numuCCDeltaRad_1gNp"]) |
              (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["numuCCDeltaRad_1g0p"]))
        )["wc_net_weight"].sum()

    elif eff_eval_cat == "erin_inclusive_1g":
        sig_df = rel_filetypes_presel_merged_df.filter(pl.col("erin_inclusive_1g_true_sig") == 1)
        bkg_df = rel_filetypes_presel_merged_df.filter(pl.col("erin_inclusive_1g_true_sig") == 0)
        total_sig = rel_filetypes_merged_df.filter(pl.col("erin_inclusive_1g_true_sig") == 1)["wc_net_weight"].sum()
        total_bkg = rel_filetypes_merged_df.filter(pl.col("erin_inclusive_1g_true_sig") == 0)["wc_net_weight"].sum()

    elif eff_eval_cat == "del1g":
        sig_df = rel_filetypes_presel_merged_df.filter(
            (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["del1g_Np"]) |
            (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["del1g_0p"]) |
            (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["del1g_Np1mu"]) |
            (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["del1g_0p1mu"])
        )
        bkg_df = rel_filetypes_presel_merged_df.filter(
            ~((pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["del1g_Np"]) |
              (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["del1g_0p"]) |
              (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["del1g_Np1mu"]) |
              (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["del1g_0p1mu"]))
        )
        total_sig = rel_filetypes_merged_df.filter(
            (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["del1g_Np"]) |
            (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["del1g_0p"]) |
            (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["del1g_Np1mu"]) |
            (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["del1g_0p1mu"])
        )["wc_net_weight"].sum()
        total_bkg = rel_filetypes_merged_df.filter(
            ~((pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["del1g_Np"]) |
              (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["del1g_0p"]) |
              (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["del1g_Np1mu"]) |
              (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["del1g_0p1mu"]))
        )["wc_net_weight"].sum()
    
    elif eff_eval_cat == "NC1pi0":
        sig_df = rel_filetypes_presel_merged_df.filter(
            (pl.col("del1g_simple_signal_category") == del1g_simple_categories_dic["NC1pi0_Np"]) |
            (pl.col("del1g_simple_signal_category") == del1g_simple_categories_dic["NC1pi0_0p"])
        )
        bkg_df = rel_filetypes_presel_merged_df.filter(
            ~((pl.col("del1g_simple_signal_category") == del1g_simple_categories_dic["NC1pi0_Np"]) |
              (pl.col("del1g_simple_signal_category") == del1g_simple_categories_dic["NC1pi0_0p"]))
        )
        total_sig = rel_filetypes_merged_df.filter(
            (pl.col("del1g_simple_signal_category") == del1g_simple_categories_dic["NC1pi0_Np"]) |
            (pl.col("del1g_simple_signal_category") == del1g_simple_categories_dic["NC1pi0_0p"])
        )["wc_net_weight"].sum()
        total_bkg = rel_filetypes_merged_df.filter(
            ~((pl.col("del1g_simple_signal_category") == del1g_simple_categories_dic["NC1pi0_Np"]) |
              (pl.col("del1g_simple_signal_category") == del1g_simple_categories_dic["NC1pi0_0p"]))
        )["wc_net_weight"].sum()

    elif eff_eval_cat == "numuCC1pi0":
        sig_df = rel_filetypes_presel_merged_df.filter(
            (pl.col("del1g_simple_signal_category") == del1g_simple_categories_dic["numuCC1pi0_Np"]) |
            (pl.col("del1g_simple_signal_category") == del1g_simple_categories_dic["numuCC1pi0_0p"])
        )
        bkg_df = rel_filetypes_presel_merged_df.filter(
            ~((pl.col("del1g_simple_signal_category") == del1g_simple_categories_dic["numuCC1pi0_Np"]) |
              (pl.col("del1g_simple_signal_category") == del1g_simple_categories_dic["numuCC1pi0_0p"]))
        )
        total_sig = rel_filetypes_merged_df.filter(
            (pl.col("del1g_simple_signal_category") == del1g_simple_categories_dic["numuCC1pi0_Np"]) |
            (pl.col("del1g_simple_signal_category") == del1g_simple_categories_dic["numuCC1pi0_0p"])
        )["wc_net_weight"].sum()
        total_bkg = rel_filetypes_merged_df.filter(
            ~((pl.col("del1g_simple_signal_category") == del1g_simple_categories_dic["numuCC1pi0_Np"]) |
              (pl.col("del1g_simple_signal_category") == del1g_simple_categories_dic["numuCC1pi0_0p"]))
        )["wc_net_weight"].sum()

    elif eff_eval_cat == "nueCC":
        sig_df = rel_filetypes_presel_merged_df.filter(
            (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["nueCC_Np"]) |
            (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["nueCC_0p"])
        )
        bkg_df = rel_filetypes_presel_merged_df.filter(
            ~((pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["nueCC_Np"]) |
              (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["nueCC_0p"]))
        )
        total_sig = rel_filetypes_merged_df.filter(
            (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["nueCC_Np"]) |
            (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["nueCC_0p"])
        )["wc_net_weight"].sum()
        total_bkg = rel_filetypes_merged_df.filter(
            ~((pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["nueCC_Np"]) |
              (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["nueCC_0p"]))
        )["wc_net_weight"].sum()

    elif eff_eval_cat == "numuCC":
        sig_df = rel_filetypes_presel_merged_df.filter(
            (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["numuCC_Np"]) |
            (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["numuCC_0p"])
        )
        bkg_df = rel_filetypes_presel_merged_df.filter(
            ~((pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["numuCC_Np"]) |
              (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["numuCC_0p"]))
        )
        total_sig = rel_filetypes_merged_df.filter(
            (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["numuCC_Np"]) |
            (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["numuCC_0p"])
        )["wc_net_weight"].sum()
        total_bkg = rel_filetypes_merged_df.filter(
            ~((pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["numuCC_Np"]) |
              (pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic["numuCC_0p"]))
        )["wc_net_weight"].sum()

    elif eff_eval_cat in ["NC1pi0_Np", "NC1pi0_0p", "numuCC1pi0_Np", "numuCC1pi0_0p"]: 
        # for these, we have to use the simple category, since we want to combine all the detailed pi0 breakdown categories
        sig_df = rel_filetypes_presel_merged_df.filter(pl.col("del1g_simple_signal_category") == del1g_simple_categories_dic[eff_eval_cat])
        bkg_df = rel_filetypes_presel_merged_df.filter(pl.col("del1g_simple_signal_category") != del1g_simple_categories_dic[eff_eval_cat])
        total_sig = rel_filetypes_merged_df.filter(pl.col("del1g_simple_signal_category") == del1g_simple_categories_dic[eff_eval_cat])["wc_net_weight"].sum()
        total_bkg = rel_filetypes_merged_df.filter(pl.col("del1g_simple_signal_category") != del1g_simple_categories_dic[eff_eval_cat])["wc_net_weight"].sum()

    else:
        sig_df = rel_filetypes_presel_merged_df.filter(pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic[eff_eval_cat])
        bkg_df = rel_filetypes_presel_merged_df.filter(pl.col("del1g_detailed_signal_category") != del1g_detailed_categories_dic[eff_eval_cat])
        total_sig = rel_filetypes_merged_df.filter(pl.col("del1g_detailed_signal_category") == del1g_detailed_categories_dic[eff_eval_cat])["wc_net_weight"].sum()
        total_bkg = rel_filetypes_merged_df.filter(pl.col("del1g_detailed_signal_category") != del1g_detailed_categories_dic[eff_eval_cat])["wc_net_weight"].sum()

    if total_sig == 0:
        print(f"No signal events found for {eff_eval_cat}!")
        continue

    multi_class_bdt_score_name = f"prob_{eff_eval_cat}"
    
    if eff_eval_cat == "NCDeltaRad":
        multi_class_bdt_score_name = "prob_1g0mu"
    elif eff_eval_cat == "NCDeltaRad_1gNp":
        multi_class_bdt_score_name = "prob_1gNp"
    elif eff_eval_cat == "NCDeltaRad_1g0p":
        multi_class_bdt_score_name = "prob_1g0p"

    elif eff_eval_cat == "numuCCDeltaRad":
        multi_class_bdt_score_name = "prob_1g1mu"
    elif eff_eval_cat == "numuCCDeltaRad_1gNp":
        multi_class_bdt_score_name = "prob_1gNp1mu"
    elif eff_eval_cat == "numuCCDeltaRad_1g0p":
        multi_class_bdt_score_name = "prob_1g0p1mu"

    elif eff_eval_cat == "erin_inclusive_1g":
        multi_class_bdt_score_name = "prob_1g0mu"

    elif eff_eval_cat == "del1g":
        multi_class_bdt_score_name = "prob_1g"
    elif eff_eval_cat == "del1g_Np":
        multi_class_bdt_score_name = "prob_1gNp"
    elif eff_eval_cat == "del1g_0p":
        multi_class_bdt_score_name = "prob_1g0p"
    elif eff_eval_cat == "del1g_Np1mu":
        multi_class_bdt_score_name = "prob_1gNp1mu"
    elif eff_eval_cat == "del1g_0p1mu":
        multi_class_bdt_score_name = "prob_1g0p1mu"
    elif eff_eval_cat == "del1g_outFV":
        multi_class_bdt_score_name = "prob_1g_outFV"

    elif eff_eval_cat == "iso1g":
        multi_class_bdt_score_name = "prob_1g0p"
    elif eff_eval_cat == "iso1g_outFV":
        multi_class_bdt_score_name = "prob_1g_outFV"
    
    else:
        multi_class_bdt_score_name = f"prob_{eff_eval_cat}"

    sig_bdt_scores = sig_df[multi_class_bdt_score_name].to_numpy()
    bkg_bdt_scores = bkg_df[multi_class_bdt_score_name].to_numpy()
    sig_weights = sig_df["wc_net_weight"].to_numpy()
    bkg_weights = bkg_df["wc_net_weight"].to_numpy()

    num_points = 500
    cutoffs = np.linspace(0, 1, num_points)
    all_effs = []
    all_purs = []
    for cutoff in cutoffs:
        sig_sel_weights = sig_weights[sig_bdt_scores > cutoff]
        sig_sel = np.sum(sig_sel_weights)
        bkg_sel_weights = bkg_weights[bkg_bdt_scores > cutoff]
        bkg_sel = np.sum(bkg_sel_weights)
        eff = sig_sel / total_sig if total_sig > 0 else np.nan
        pur = sig_sel / (sig_sel + bkg_sel) if sig_sel + bkg_sel > 0 else np.nan
        all_effs.append(eff)
        all_purs.append(pur)

    nominal_eff = None
    nominal_pur = None
    # evaluate the nominal selection efficiency and purity (default multi-class plotting cut value)
    if eff_eval_cat in ["1gNp", "1g0p", "1gNp1mu", "1g0p1mu", "1g_outFV", 
                        "NC1pi0_Np", "NC1pi0_0p", "numuCC1pi0_Np", "numuCC1pi0_0p", "nueCC_Np", "nueCC_0p", 
                        "del1g_Np", "del1g_0p", "del1g_Np1mu", "del1g_0p1mu", "iso1g",
                        "NCDeltaRad_1gNp", "NCDeltaRad_1g0p", "numuCCDeltaRad_1gNp", "numuCCDeltaRad_1g0p",
                        "nueCC_Np", "nueCC_0p",
                        "numuCC_Np", "numuCC_0p",
                        "eta_other"]:

        curr_reco_cat = eff_eval_cat
        if eff_eval_cat == "del1g_Np" or eff_eval_cat == "NCDeltaRad_1gNp":
            curr_reco_cat = "1gNp"
        elif eff_eval_cat == "del1g_0p" or eff_eval_cat == "iso1g" or eff_eval_cat == "NCDeltaRad_1g0p":
            curr_reco_cat = "1g0p"
        elif eff_eval_cat == "del1g_Np1mu" or eff_eval_cat == "numuCCDeltaRad_1gNp":
            curr_reco_cat = "1gNp1mu"
        elif eff_eval_cat == "del1g_0p1mu" or eff_eval_cat == "numuCCDeltaRad_1g0p":
            curr_reco_cat = "1g0p1mu"
        else:
            curr_reco_cat = eff_eval_cat

        nominal_sel_sig_df = sig_df.filter(reco_category_queries[del1g_simple_categories_dic[curr_reco_cat]])
        nominal_sel_sig_weights = nominal_sel_sig_df["wc_net_weight"].to_numpy()
        nominal_sel_sig = np.sum(nominal_sel_sig_weights)
        nominal_sel_bkg_df = bkg_df.filter(reco_category_queries[del1g_simple_categories_dic[curr_reco_cat]])
        nominal_sel_bkg_weights = nominal_sel_bkg_df["wc_net_weight"].to_numpy()
        nominal_sel_bkg = np.sum(nominal_sel_bkg_weights)
        nominal_eff = nominal_sel_sig / total_sig if total_sig > 0 else np.nan
        nominal_pur = nominal_sel_sig / (nominal_sel_sig + nominal_sel_bkg) if nominal_sel_sig + nominal_sel_bkg > 0 else np.nan

    if eff_eval_cat in ["NCDeltaRad", "NCDeltaRad_1gNp", "NCDeltaRad_1g0p", "del1g", "del1g_Np", "del1g_0p", "iso1g"]:
        # efficiency and purity curve for the WC NC Delta BDT, Np, 0p, and Xp (later, we only plot the proton state that makes sense for the signal category)
        sig_nc_delta_bdt_scores = sig_df["wc_nc_delta_score"].to_numpy()
        bkg_nc_delta_bdt_scores = bkg_df["wc_nc_delta_score"].to_numpy()
        sig_wc_reco_num_protons_35_MeV = sig_df["wc_reco_num_protons_35_MeV"].to_numpy()
        bkg_wc_reco_num_protons_35_MeV = bkg_df["wc_reco_num_protons_35_MeV"].to_numpy()
        sig_wc_match_isFCs = sig_df["wc_match_isFC"].to_numpy()
        bkg_wc_match_isFCs = bkg_df["wc_match_isFC"].to_numpy()
        all_nc_delta_bdt_scores = np.concatenate([sig_nc_delta_bdt_scores, bkg_nc_delta_bdt_scores])
        good_nc_delta_bdt_scores = all_nc_delta_bdt_scores[all_nc_delta_bdt_scores == all_nc_delta_bdt_scores]
        min_delta_score, max_delta_score = np.min(good_nc_delta_bdt_scores), np.max(good_nc_delta_bdt_scores)
        cutoffs = np.linspace(min_delta_score, max_delta_score, num_points)
        all_effs_nc_delta_Xp = []
        all_purs_nc_delta_Xp = []
        all_effs_nc_delta_Np = []
        all_purs_nc_delta_Np = []
        all_effs_nc_delta_0p = []
        all_purs_nc_delta_0p = []
        for cutoff in cutoffs:
            sig_sel_weights = sig_weights[np.logical_and(sig_nc_delta_bdt_scores > cutoff, sig_wc_match_isFCs == 1)]
            sig_sel = np.sum(sig_sel_weights)
            bkg_sel_weights = bkg_weights[np.logical_and(bkg_nc_delta_bdt_scores > cutoff, bkg_wc_match_isFCs == 1)]
            bkg_sel = np.sum(bkg_sel_weights)
            eff = sig_sel / total_sig if total_sig > 0 else np.nan
            pur = sig_sel / (sig_sel + bkg_sel) if sig_sel + bkg_sel > 0 else np.nan
            all_effs_nc_delta_Xp.append(eff)
            all_purs_nc_delta_Xp.append(pur)
            Np_sig_sel_weights = sig_weights[np.logical_and(np.logical_and(sig_nc_delta_bdt_scores > cutoff, sig_wc_match_isFCs == 1), sig_wc_reco_num_protons_35_MeV > 0)]
            Np_sig_sel = np.sum(Np_sig_sel_weights)
            Np_bkg_sel_weights = bkg_weights[np.logical_and(np.logical_and(bkg_nc_delta_bdt_scores > cutoff, bkg_wc_match_isFCs == 1), bkg_wc_reco_num_protons_35_MeV > 0)]
            Np_bkg_sel = np.sum(Np_bkg_sel_weights)
            Np_eff = Np_sig_sel / total_sig if total_sig > 0 else np.nan
            Np_pur = Np_sig_sel / (Np_sig_sel + Np_bkg_sel) if Np_sig_sel + Np_bkg_sel > 0 else np.nan
            all_effs_nc_delta_Np.append(Np_eff)
            all_purs_nc_delta_Np.append(Np_pur)
            zero_p_sig_sel_weights = sig_weights[np.logical_and(np.logical_and(sig_nc_delta_bdt_scores > cutoff, sig_wc_match_isFCs == 1), sig_wc_reco_num_protons_35_MeV == 0)]
            zero_p_sig_sel = np.sum(zero_p_sig_sel_weights)
            zero_p_bkg_sel_weights = bkg_weights[np.logical_and(np.logical_and(bkg_nc_delta_bdt_scores > cutoff, bkg_wc_match_isFCs == 1), bkg_wc_reco_num_protons_35_MeV == 0)]
            zero_p_bkg_sel = np.sum(zero_p_bkg_sel_weights)
            zero_p_eff = zero_p_sig_sel / total_sig if total_sig > 0 else np.nan
            zero_p_pur = zero_p_sig_sel / (zero_p_sig_sel + zero_p_bkg_sel) if zero_p_sig_sel + zero_p_bkg_sel > 0 else np.nan
            all_effs_nc_delta_0p.append(zero_p_eff)
            all_purs_nc_delta_0p.append(zero_p_pur)
        nc_delta_261_Xp_sig_sel_weights = sig_df.filter((pl.col("wc_nc_delta_score") > 2.61) & (pl.col("wc_match_isFC") == 1))["wc_net_weight"].to_numpy()
        nc_delta_261_Xp_sig_sel = np.sum(nc_delta_261_Xp_sig_sel_weights)
        nc_delta_261_Xp_bkg_sel_weights = bkg_df.filter((pl.col("wc_nc_delta_score") > 2.61) & (pl.col("wc_match_isFC") == 1))["wc_net_weight"].to_numpy()
        nc_delta_261_Xp_bkg_sel = np.sum(nc_delta_261_Xp_bkg_sel_weights)
        nc_delta_261_Xp_eff = nc_delta_261_Xp_sig_sel / total_sig if total_sig > 0 else np.nan
        nc_delta_261_Xp_pur = nc_delta_261_Xp_sig_sel / (nc_delta_261_Xp_sig_sel + nc_delta_261_Xp_bkg_sel) if nc_delta_261_Xp_sig_sel + nc_delta_261_Xp_bkg_sel > 0 else np.nan
        nc_delta_261_Np_sig_sel_weights = sig_df.filter((pl.col("wc_nc_delta_score") > 2.61) & (pl.col("wc_match_isFC") == 1) & (pl.col("wc_reco_num_protons_35_MeV") > 0))["wc_net_weight"].to_numpy()
        nc_delta_261_Np_sig_sel = np.sum(nc_delta_261_Np_sig_sel_weights)
        nc_delta_261_Np_bkg_sel_weights = bkg_df.filter((pl.col("wc_nc_delta_score") > 2.61) & (pl.col("wc_match_isFC") == 1) & (pl.col("wc_reco_num_protons_35_MeV") > 0))["wc_net_weight"].to_numpy()
        nc_delta_261_Np_bkg_sel = np.sum(nc_delta_261_Np_bkg_sel_weights)
        nc_delta_261_Np_eff = nc_delta_261_Np_sig_sel / total_sig if total_sig > 0 else np.nan
        nc_delta_261_Np_pur = nc_delta_261_Np_sig_sel / (nc_delta_261_Np_sig_sel + nc_delta_261_Np_bkg_sel) if nc_delta_261_Np_sig_sel + nc_delta_261_Np_bkg_sel > 0 else np.nan
        nc_delta_261_0p_sig_sel_weights = sig_df.filter((pl.col("wc_nc_delta_score") > 2.61) & (pl.col("wc_match_isFC") == 1) & (pl.col("wc_reco_num_protons_35_MeV") == 0))["wc_net_weight"].to_numpy()
        nc_delta_261_0p_sig_sel = np.sum(nc_delta_261_0p_sig_sel_weights)
        nc_delta_261_0p_bkg_sel_weights = bkg_df.filter((pl.col("wc_nc_delta_score") > 2.61) & (pl.col("wc_match_isFC") == 1) & (pl.col("wc_reco_num_protons_35_MeV") == 0))["wc_net_weight"].to_numpy()
        nc_delta_261_0p_bkg_sel = np.sum(nc_delta_261_0p_bkg_sel_weights)
        nc_delta_261_0p_eff = nc_delta_261_0p_sig_sel / total_sig if total_sig > 0 else np.nan
        nc_delta_261_0p_pur = nc_delta_261_0p_sig_sel / (nc_delta_261_0p_sig_sel + nc_delta_261_0p_bkg_sel) if nc_delta_261_0p_sig_sel + nc_delta_261_0p_bkg_sel > 0 else np.nan

    # efficiency and purity curve for the WC nueCC BDT, Np, 0p, and Xp (later, we only plot the proton state that makes sense for the signal category)
    if "nueCC" in eff_eval_cat:
        sig_nue_bdt_scores = sig_df["wc_nue_score"].to_numpy()
        bkg_nue_bdt_scores = bkg_df["wc_nue_score"].to_numpy()
        sig_wc_reco_num_protons_35_MeV = sig_df["wc_reco_num_protons_35_MeV"].to_numpy()
        bkg_wc_reco_num_protons_35_MeV = bkg_df["wc_reco_num_protons_35_MeV"].to_numpy()
        all_nue_bdt_scores = np.concatenate([sig_nue_bdt_scores, bkg_nue_bdt_scores])
        good_nue_bdt_scores = np.nan_to_num(all_nue_bdt_scores, nan=0, posinf=0, neginf=0)
        min_nue_score, max_nue_score = np.min(good_nue_bdt_scores), np.max(good_nue_bdt_scores)
        cutoffs = np.linspace(min_nue_score, max_nue_score, num_points)
        all_effs_nue_Xp = []
        all_purs_nue_Xp = []
        all_effs_nue_Np = []
        all_purs_nue_Np = []
        all_effs_nue_0p = []
        all_purs_nue_0p = []
        for cutoff in cutoffs:
            Xp_sig_sel_weights = sig_weights[sig_nue_bdt_scores > cutoff]
            Xp_sig_sel = np.sum(Xp_sig_sel_weights)
            Xp_bkg_sel_weights = bkg_weights[bkg_nue_bdt_scores > cutoff]
            Xp_bkg_sel = np.sum(Xp_bkg_sel_weights)
            Xp_eff = Xp_sig_sel / total_sig if total_sig > 0 else np.nan
            Xp_pur = Xp_sig_sel / (Xp_sig_sel + Xp_bkg_sel) if Xp_sig_sel + Xp_bkg_sel > 0 else np.nan
            all_effs_nue_Xp.append(Xp_eff)
            all_purs_nue_Xp.append(Xp_pur)
            Np_sig_sel_weights = sig_weights[np.logical_and(sig_nue_bdt_scores > cutoff, sig_wc_reco_num_protons_35_MeV > 0)]
            Np_sig_sel = np.sum(Np_sig_sel_weights)
            Np_bkg_sel_weights = bkg_weights[np.logical_and(bkg_nue_bdt_scores > cutoff, bkg_wc_reco_num_protons_35_MeV > 0)]
            Np_bkg_sel = np.sum(Np_bkg_sel_weights)
            Np_eff = Np_sig_sel / total_sig if total_sig > 0 else np.nan
            Np_pur = Np_sig_sel / (Np_sig_sel + Np_bkg_sel) if Np_sig_sel + Np_bkg_sel > 0 else np.nan
            all_effs_nue_Np.append(Np_eff)
            all_purs_nue_Np.append(Np_pur)
            zero_p_sig_sel_weights = sig_weights[np.logical_and(sig_nue_bdt_scores > cutoff, sig_wc_reco_num_protons_35_MeV == 0)]
            zero_p_sig_sel = np.sum(zero_p_sig_sel_weights)
            zero_p_bkg_sel_weights = bkg_weights[np.logical_and(bkg_nue_bdt_scores > cutoff, bkg_wc_reco_num_protons_35_MeV == 0)]
            zero_p_bkg_sel = np.sum(zero_p_bkg_sel_weights)
            zero_p_eff = zero_p_sig_sel / total_sig if total_sig > 0 else np.nan
            zero_p_pur = zero_p_sig_sel / (zero_p_sig_sel + zero_p_bkg_sel) if zero_p_sig_sel + zero_p_bkg_sel > 0 else np.nan
            all_effs_nue_0p.append(zero_p_eff)
            all_purs_nue_0p.append(zero_p_pur)
        nue_7_Xp_sig_sel_weights = sig_df.filter(pl.col("wc_nue_score") > 7)["wc_net_weight"].to_numpy()
        nue_7_Xp_sig_sel = np.sum(nue_7_Xp_sig_sel_weights)
        nue_7_Xp_bkg_sel_weights = bkg_df.filter(pl.col("wc_nue_score") > 7)["wc_net_weight"].to_numpy()
        nue_7_Xp_bkg_sel = np.sum(nue_7_Xp_bkg_sel_weights)
        nue_7_Xp_eff = nue_7_Xp_sig_sel / total_sig if total_sig > 0 else np.nan
        nue_7_Xp_pur = nue_7_Xp_sig_sel / (nue_7_Xp_sig_sel + nue_7_Xp_bkg_sel) if nue_7_Xp_sig_sel + nue_7_Xp_bkg_sel > 0 else np.nan
        nue_7_Np_sig_sel_weights = sig_df.filter((pl.col("wc_nue_score") > 7) & (pl.col("wc_reco_num_protons_35_MeV") > 0))["wc_net_weight"].to_numpy()
        nue_7_Np_sig_sel = np.sum(nue_7_Np_sig_sel_weights)
        nue_7_Np_bkg_sel_weights = bkg_df.filter((pl.col("wc_nue_score") > 7) & (pl.col("wc_reco_num_protons_35_MeV") > 0))["wc_net_weight"].to_numpy()
        nue_7_Np_bkg_sel = np.sum(nue_7_Np_bkg_sel_weights)
        nue_7_Np_eff = nue_7_Np_sig_sel / total_sig if total_sig > 0 else np.nan
        nue_7_Np_pur = nue_7_Np_sig_sel / (nue_7_Np_sig_sel + nue_7_Np_bkg_sel) if nue_7_Np_sig_sel + nue_7_Np_bkg_sel > 0 else np.nan
        nue_7_0p_sig_sel_weights = sig_df.filter((pl.col("wc_nue_score") > 7) & (pl.col("wc_reco_num_protons_35_MeV") == 0))["wc_net_weight"].to_numpy()
        nue_7_0p_sig_sel = np.sum(nue_7_0p_sig_sel_weights)
        nue_7_0p_bkg_sel_weights = bkg_df.filter((pl.col("wc_nue_score") > 7) & (pl.col("wc_reco_num_protons_35_MeV") == 0))["wc_net_weight"].to_numpy()
        nue_7_0p_bkg_sel = np.sum(nue_7_0p_bkg_sel_weights)
        nue_7_0p_eff = nue_7_0p_sig_sel / total_sig if total_sig > 0 else np.nan
        nue_7_0p_pur = nue_7_0p_sig_sel / (nue_7_0p_sig_sel + nue_7_0p_bkg_sel) if nue_7_0p_sig_sel + nue_7_0p_bkg_sel > 0 else np.nan

    # efficiency and purity curve for the WC numuCC BDT, Np, 0p, and Xp (later, we only plot the proton state that makes sense for the signal category)
    if "numuCC" in eff_eval_cat:
        sig_numuCC_bdt_scores = sig_df["wc_numu_score"].to_numpy()
        bkg_numuCC_bdt_scores = bkg_df["wc_numu_score"].to_numpy()
        sig_wc_reco_num_protons_35_MeV = sig_df["wc_reco_num_protons_35_MeV"].to_numpy()
        bkg_wc_reco_num_protons_35_MeV = bkg_df["wc_reco_num_protons_35_MeV"].to_numpy()
        all_numuCC_bdt_scores = np.concatenate([sig_numuCC_bdt_scores, bkg_numuCC_bdt_scores])
        good_numuCC_bdt_scores = np.nan_to_num(all_numuCC_bdt_scores, nan=0, posinf=0, neginf=0)
        min_numuCC_score, max_numuCC_score = np.min(good_numuCC_bdt_scores), np.max(good_numuCC_bdt_scores)
        cutoffs = np.linspace(min_numuCC_score, max_numuCC_score, num_points)
        all_effs_numuCC_Xp = []
        all_purs_numuCC_Xp = []
        all_effs_numuCC_Np = []
        all_purs_numuCC_Np = []
        all_effs_numuCC_0p = []
        all_purs_numuCC_0p = []
        for cutoff in cutoffs:
            Xp_sig_sel_weights = sig_weights[sig_numuCC_bdt_scores > cutoff]
            Xp_sig_sel = np.sum(Xp_sig_sel_weights)
            Xp_bkg_sel_weights = bkg_weights[bkg_numuCC_bdt_scores > cutoff]
            Xp_bkg_sel = np.sum(Xp_bkg_sel_weights)
            Xp_eff = Xp_sig_sel / total_sig if total_sig > 0 else np.nan
            Xp_pur = Xp_sig_sel / (Xp_sig_sel + Xp_bkg_sel) if Xp_sig_sel + Xp_bkg_sel > 0 else np.nan
            all_effs_numuCC_Xp.append(Xp_eff)
            all_purs_numuCC_Xp.append(Xp_pur)
            Np_sig_sel_weights = sig_weights[np.logical_and(sig_numuCC_bdt_scores > cutoff, sig_wc_reco_num_protons_35_MeV > 0)]
            Np_sig_sel = np.sum(Np_sig_sel_weights)
            Np_bkg_sel_weights = bkg_weights[np.logical_and(bkg_numuCC_bdt_scores > cutoff, bkg_wc_reco_num_protons_35_MeV > 0)]
            Np_bkg_sel = np.sum(Np_bkg_sel_weights)
            Np_eff = Np_sig_sel / total_sig if total_sig > 0 else np.nan
            Np_pur = Np_sig_sel / (Np_sig_sel + Np_bkg_sel) if Np_sig_sel + Np_bkg_sel > 0 else np.nan
            all_effs_numuCC_Np.append(Np_eff)
            all_purs_numuCC_Np.append(Np_pur)
            zero_p_sig_sel_weights = sig_weights[np.logical_and(sig_numuCC_bdt_scores > cutoff, sig_wc_reco_num_protons_35_MeV == 0)]
            zero_p_sig_sel = np.sum(zero_p_sig_sel_weights)
            zero_p_bkg_sel_weights = bkg_weights[np.logical_and(bkg_numuCC_bdt_scores > cutoff, bkg_wc_reco_num_protons_35_MeV == 0)]
            zero_p_bkg_sel = np.sum(zero_p_bkg_sel_weights)
            zero_p_eff = zero_p_sig_sel / total_sig if total_sig > 0 else np.nan
            zero_p_pur = zero_p_sig_sel / (zero_p_sig_sel + zero_p_bkg_sel) if zero_p_sig_sel + zero_p_bkg_sel > 0 else np.nan
            all_effs_numuCC_0p.append(zero_p_eff)
            all_purs_numuCC_0p.append(zero_p_pur)
        numuCC_09_Xp_sig_sel_weights = sig_weights[sig_numuCC_bdt_scores > 0.9]
        numuCC_09_Xp_sig_sel = np.sum(numuCC_09_Xp_sig_sel_weights)
        numuCC_09_Xp_bkg_sel_weights = bkg_weights[bkg_numuCC_bdt_scores > 0.9]
        numuCC_09_Xp_bkg_sel = np.sum(numuCC_09_Xp_bkg_sel_weights)
        numuCC_09_Xp_eff = numuCC_09_Xp_sig_sel / total_sig if total_sig > 0 else np.nan
        numuCC_09_Xp_pur = numuCC_09_Xp_sig_sel / (numuCC_09_Xp_sig_sel + numuCC_09_Xp_bkg_sel) if numuCC_09_Xp_sig_sel + numuCC_09_Xp_bkg_sel > 0 else np.nan
        numuCC_09_Np_sig_sel_weights = sig_weights[np.logical_and(sig_numuCC_bdt_scores > 0.9, sig_wc_reco_num_protons_35_MeV > 0)]
        numuCC_09_Np_sig_sel = np.sum(numuCC_09_Np_sig_sel_weights)
        numuCC_09_Np_bkg_sel_weights = bkg_weights[np.logical_and(bkg_numuCC_bdt_scores > 0.9, bkg_wc_reco_num_protons_35_MeV > 0)]
        numuCC_09_Np_bkg_sel = np.sum(numuCC_09_Np_bkg_sel_weights)
        numuCC_09_Np_eff = numuCC_09_Np_sig_sel / total_sig if total_sig > 0 else np.nan
        numuCC_09_Np_pur = numuCC_09_Np_sig_sel / (numuCC_09_Np_sig_sel + numuCC_09_Np_bkg_sel) if numuCC_09_Np_sig_sel + numuCC_09_Np_bkg_sel > 0 else np.nan
        numuCC_09_0p_sig_sel_weights = sig_weights[np.logical_and(sig_numuCC_bdt_scores > 0.9, sig_wc_reco_num_protons_35_MeV == 0)]
        numuCC_09_0p_sig_sel = np.sum(numuCC_09_0p_sig_sel_weights)
        numuCC_09_0p_bkg_sel_weights = bkg_weights[np.logical_and(bkg_numuCC_bdt_scores > 0.9, bkg_wc_reco_num_protons_35_MeV == 0)]
        numuCC_09_0p_bkg_sel = np.sum(numuCC_09_0p_bkg_sel_weights)
        numuCC_09_0p_eff = numuCC_09_0p_sig_sel / total_sig if total_sig > 0 else np.nan
        numuCC_09_0p_pur = numuCC_09_0p_sig_sel / (numuCC_09_0p_sig_sel + numuCC_09_0p_bkg_sel) if numuCC_09_0p_sig_sel + numuCC_09_0p_bkg_sel > 0 else np.nan

    erin_inclusive_1g_sig_sel_weights = sig_df.filter(pl.col("erin_inclusive_1g_sel") == 1)["wc_net_weight"].to_numpy()
    erin_inclusive_1g_sig_sel = np.sum(erin_inclusive_1g_sig_sel_weights)
    erin_inclusive_1g_bkg_sel_weights = bkg_df.filter(pl.col("erin_inclusive_1g_sel") == 1)["wc_net_weight"].to_numpy()
    erin_inclusive_1g_bkg_sel = np.sum(erin_inclusive_1g_bkg_sel_weights)
    erin_inclusive_1g_eff = erin_inclusive_1g_sig_sel / total_sig if total_sig > 0 else np.nan
    erin_inclusive_1g_pur = erin_inclusive_1g_sig_sel / (erin_inclusive_1g_sig_sel + erin_inclusive_1g_bkg_sel) if erin_inclusive_1g_sig_sel + erin_inclusive_1g_bkg_sel > 0 else np.nan

    # efficiency and purity curve for the WC NC pi0 BDT, Np, 0p, and Xp (later, we only plot the proton state that makes sense for the signal category)
    if "NC1pi0" in eff_eval_cat:
        nc_pi0_presel_sig_df = sig_df.filter(
            (pl.col("wc_kine_reco_Enu") > 0) &
            (pl.col("wc_kine_pio_energy_1") > 0) &
            (pl.col("wc_kine_pio_energy_2") > 0) &
            (pl.col("wc_match_isFC") == 1) &
            ~((pl.col("wc_reco_showerKE") > 0) & (pl.col("wc_nc_delta_score") > 2.61))
        )
        nc_pi0_presel_bkg_df = bkg_df.filter(
            (pl.col("wc_kine_reco_Enu") > 0) &
            (pl.col("wc_kine_pio_energy_1") > 0) &
            (pl.col("wc_kine_pio_energy_2") > 0) &
            (pl.col("wc_match_isFC") == 1) &
            ~((pl.col("wc_reco_showerKE") > 0) & (pl.col("wc_nc_delta_score") > 2.61))
        )
        nc_pi0_presel_sig_weights = nc_pi0_presel_sig_df["wc_net_weight"].to_numpy()
        nc_pi0_presel_bkg_weights = nc_pi0_presel_bkg_df["wc_net_weight"].to_numpy()
        sig_ncpi0_bdt_scores = nc_pi0_presel_sig_df["wc_nc_pio_score"].to_numpy()
        bkg_ncpi0_bdt_scores = nc_pi0_presel_bkg_df["wc_nc_pio_score"].to_numpy()
        sig_wc_reco_num_protons_35_MeV = nc_pi0_presel_sig_df["wc_reco_num_protons_35_MeV"].to_numpy()
        bkg_wc_reco_num_protons_35_MeV = nc_pi0_presel_bkg_df["wc_reco_num_protons_35_MeV"].to_numpy()
        all_ncpi0_bdt_scores = np.concatenate([sig_ncpi0_bdt_scores, bkg_ncpi0_bdt_scores])
        good_ncpi0_bdt_scores = np.nan_to_num(all_ncpi0_bdt_scores, nan=0, posinf=0, neginf=0)
        min_ncpi0_score, max_ncpi0_score = np.min(good_ncpi0_bdt_scores), np.max(good_ncpi0_bdt_scores)
        cutoffs = np.linspace(min_ncpi0_score, max_ncpi0_score, num_points)
        all_effs_ncpi0_Xp = []
        all_purs_ncpi0_Xp = []
        all_effs_ncpi0_Np = []
        all_purs_ncpi0_Np = []
        all_effs_ncpi0_0p = []
        all_purs_ncpi0_0p = []
        for cutoff in cutoffs:
            Xp_sig_sel_weights = nc_pi0_presel_sig_weights[sig_ncpi0_bdt_scores > cutoff]
            Xp_sig_sel = np.sum(Xp_sig_sel_weights)
            Xp_bkg_sel_weights = nc_pi0_presel_bkg_weights[bkg_ncpi0_bdt_scores > cutoff]
            Xp_bkg_sel = np.sum(Xp_bkg_sel_weights)
            Xp_eff = Xp_sig_sel / total_sig if total_sig > 0 else np.nan
            Xp_pur = Xp_sig_sel / (Xp_sig_sel + Xp_bkg_sel) if Xp_sig_sel + Xp_bkg_sel > 0 else np.nan
            all_effs_ncpi0_Xp.append(Xp_eff)
            all_purs_ncpi0_Xp.append(Xp_pur)
            Np_sig_sel_weights = nc_pi0_presel_sig_weights[np.logical_and(sig_ncpi0_bdt_scores > cutoff, sig_wc_reco_num_protons_35_MeV > 0)]
            Np_sig_sel = np.sum(Np_sig_sel_weights)
            Np_bkg_sel_weights = nc_pi0_presel_bkg_weights[np.logical_and(bkg_ncpi0_bdt_scores > cutoff, bkg_wc_reco_num_protons_35_MeV > 0)]
            Np_bkg_sel = np.sum(Np_bkg_sel_weights)
            Np_eff = Np_sig_sel / total_sig if total_sig > 0 else np.nan
            Np_pur = Np_sig_sel / (Np_sig_sel + Np_bkg_sel) if Np_sig_sel + Np_bkg_sel > 0 else np.nan
            all_effs_ncpi0_Np.append(Np_eff)
            all_purs_ncpi0_Np.append(Np_pur)
            zero_p_sig_sel_weights = nc_pi0_presel_sig_weights[np.logical_and(sig_ncpi0_bdt_scores > cutoff, sig_wc_reco_num_protons_35_MeV == 0)]
            zero_p_sig_sel = np.sum(zero_p_sig_sel_weights)
            zero_p_bkg_sel_weights = nc_pi0_presel_bkg_weights[np.logical_and(bkg_ncpi0_bdt_scores > cutoff, bkg_wc_reco_num_protons_35_MeV == 0)]
            zero_p_bkg_sel = np.sum(zero_p_bkg_sel_weights)
            zero_p_eff = zero_p_sig_sel / total_sig if total_sig > 0 else np.nan
            zero_p_pur = zero_p_sig_sel / (zero_p_sig_sel + zero_p_bkg_sel) if zero_p_sig_sel + zero_p_bkg_sel > 0 else np.nan
            all_effs_ncpi0_0p.append(zero_p_eff)
            all_purs_ncpi0_0p.append(zero_p_pur)
        nc_pi0_1816_Xp_sig_sel_weights = nc_pi0_presel_sig_df.filter(pl.col("wc_nc_pio_score") > 1.816)["wc_net_weight"].to_numpy()
        nc_pi0_1816_Xp_sig_sel = np.sum(nc_pi0_1816_Xp_sig_sel_weights)
        nc_pi0_1816_Xp_bkg_sel_weights = nc_pi0_presel_bkg_df.filter(pl.col("wc_nc_pio_score") > 1.816)["wc_net_weight"].to_numpy()
        nc_pi0_1816_Xp_bkg_sel = np.sum(nc_pi0_1816_Xp_bkg_sel_weights)
        nc_pi0_1816_Xp_eff = nc_pi0_1816_Xp_sig_sel / total_sig if total_sig > 0 else np.nan
        nc_pi0_1816_Xp_pur = nc_pi0_1816_Xp_sig_sel / (nc_pi0_1816_Xp_sig_sel + nc_pi0_1816_Xp_bkg_sel) if nc_pi0_1816_Xp_sig_sel + nc_pi0_1816_Xp_bkg_sel > 0 else np.nan
        nc_pi0_1816_Np_sig_sel_weights = nc_pi0_presel_sig_df.filter((pl.col("wc_nc_pio_score") > 1.816) & (pl.col("wc_reco_num_protons_35_MeV") > 0))["wc_net_weight"].to_numpy()
        nc_pi0_1816_Np_sig_sel = np.sum(nc_pi0_1816_Np_sig_sel_weights)
        nc_pi0_1816_Np_bkg_sel_weights = nc_pi0_presel_bkg_df.filter((pl.col("wc_nc_pio_score") > 1.816) & (pl.col("wc_reco_num_protons_35_MeV") > 0))["wc_net_weight"].to_numpy()
        nc_pi0_1816_Np_bkg_sel = np.sum(nc_pi0_1816_Np_bkg_sel_weights)
        nc_pi0_1816_Np_eff = nc_pi0_1816_Np_sig_sel / total_sig if total_sig > 0 else np.nan
        nc_pi0_1816_Np_pur = nc_pi0_1816_Np_sig_sel / (nc_pi0_1816_Np_sig_sel + nc_pi0_1816_Np_bkg_sel) if nc_pi0_1816_Np_sig_sel + nc_pi0_1816_Np_bkg_sel > 0 else np.nan
        nc_pi0_1816_0p_sig_sel_weights = nc_pi0_presel_sig_df.filter((pl.col("wc_nc_pio_score") > 1.816) & (pl.col("wc_reco_num_protons_35_MeV") == 0))["wc_net_weight"].to_numpy()
        nc_pi0_1816_0p_sig_sel = np.sum(nc_pi0_1816_0p_sig_sel_weights)
        nc_pi0_1816_0p_bkg_sel_weights = nc_pi0_presel_bkg_df.filter((pl.col("wc_nc_pio_score") > 1.816) & (pl.col("wc_reco_num_protons_35_MeV") == 0))["wc_net_weight"].to_numpy()
        nc_pi0_1816_0p_bkg_sel = np.sum(nc_pi0_1816_0p_bkg_sel_weights)
        nc_pi0_1816_0p_eff = nc_pi0_1816_0p_sig_sel / total_sig if total_sig > 0 else np.nan
        nc_pi0_1816_0p_pur = nc_pi0_1816_0p_sig_sel / (nc_pi0_1816_0p_sig_sel + nc_pi0_1816_0p_bkg_sel) if nc_pi0_1816_0p_sig_sel + nc_pi0_1816_0p_bkg_sel > 0 else np.nan

    # efficiency and purity for the WC cut-based numuCC Pi0 selection
    if "numuCC1pi0" in eff_eval_cat:
        # see https://github.com/BNLIF/wcp-uboone-bdt/blob/main/inc/WCPLEEANA/cuts.h#L3719
        wc_cutbased_cc_pi0_sig_sel_df = sig_df.filter(
            (pl.col("wc_kine_pio_flag") == 1) &
            (pl.col("wc_kine_pio_vtx_dis") < 9) &
            (pl.col("wc_kine_pio_energy_1") > 40) &
            (pl.col("wc_kine_pio_energy_2") > 25) &
            (pl.col("wc_kine_pio_dis_1") < 110) &
            (pl.col("wc_kine_pio_dis_2") < 120) &
            (pl.col("wc_kine_pio_angle") > 0) &
            (pl.col("wc_kine_pio_angle") < 174) &
            (pl.col("wc_kine_pio_mass") > 22) &
            (pl.col("wc_kine_pio_mass") < 300)
        )
        wc_cutbased_cc_pi0_bkg_sel_df = bkg_df.filter(
            (pl.col("wc_kine_pio_flag") == 1) &
            (pl.col("wc_kine_pio_vtx_dis") < 9) &
            (pl.col("wc_kine_pio_energy_1") > 40) &
            (pl.col("wc_kine_pio_energy_2") > 25) &
            (pl.col("wc_kine_pio_dis_1") < 110) &
            (pl.col("wc_kine_pio_dis_2") < 120) &
            (pl.col("wc_kine_pio_angle") > 0) &
            (pl.col("wc_kine_pio_angle") < 174) &
            (pl.col("wc_kine_pio_mass") > 22) &
            (pl.col("wc_kine_pio_mass") < 300)
        )
        wc_cutbased_cc_pi0_Xp_sig_sel = wc_cutbased_cc_pi0_sig_sel_df["wc_net_weight"].sum()
        wc_cutbased_cc_pi0_Xp_bkg_sel = wc_cutbased_cc_pi0_bkg_sel_df["wc_net_weight"].sum()
        wc_cutbased_cc_pi0_Xp_eff = wc_cutbased_cc_pi0_Xp_sig_sel / total_sig if total_sig > 0 else np.nan
        wc_cutbased_cc_pi0_Xp_pur = wc_cutbased_cc_pi0_Xp_sig_sel / (wc_cutbased_cc_pi0_Xp_sig_sel + wc_cutbased_cc_pi0_Xp_bkg_sel) if (wc_cutbased_cc_pi0_Xp_sig_sel + wc_cutbased_cc_pi0_Xp_bkg_sel) > 0 else np.nan
        wc_cutbased_cc_pi0_Np_sig_sel = wc_cutbased_cc_pi0_sig_sel_df.filter(pl.col("wc_reco_num_protons_35_MeV") > 0)["wc_net_weight"].sum()
        wc_cutbased_cc_pi0_Np_bkg_sel = wc_cutbased_cc_pi0_bkg_sel_df.filter(pl.col("wc_reco_num_protons_35_MeV") > 0)["wc_net_weight"].sum()
        wc_cutbased_cc_pi0_Np_eff = wc_cutbased_cc_pi0_Np_sig_sel / total_sig if total_sig > 0 else np.nan
        wc_cutbased_cc_pi0_Np_pur = wc_cutbased_cc_pi0_Np_sig_sel / (wc_cutbased_cc_pi0_Np_sig_sel + wc_cutbased_cc_pi0_Np_bkg_sel) if (wc_cutbased_cc_pi0_Np_sig_sel + wc_cutbased_cc_pi0_Np_bkg_sel) > 0 else np.nan
        wc_cutbased_cc_pi0_0p_sig_sel = wc_cutbased_cc_pi0_sig_sel_df.filter(pl.col("wc_reco_num_protons_35_MeV") == 0)["wc_net_weight"].sum()
        wc_cutbased_cc_pi0_0p_bkg_sel = wc_cutbased_cc_pi0_bkg_sel_df.filter(pl.col("wc_reco_num_protons_35_MeV") == 0)["wc_net_weight"].sum()
        wc_cutbased_cc_pi0_0p_eff = wc_cutbased_cc_pi0_0p_sig_sel / total_sig if total_sig > 0 else np.nan
        wc_cutbased_cc_pi0_0p_pur = wc_cutbased_cc_pi0_0p_sig_sel / (wc_cutbased_cc_pi0_0p_sig_sel + wc_cutbased_cc_pi0_0p_bkg_sel) if (wc_cutbased_cc_pi0_0p_sig_sel + wc_cutbased_cc_pi0_0p_bkg_sel) > 0 else np.nan

    # no-selection efficiency and purity (efficiency should always be 1)

    nosel_sig_sel = total_sig
    nosel_bkg_sel = total_bkg
    nosel_eff = nosel_sig_sel / total_sig if total_sig > 0 else np.nan
    nosel_pur = nosel_sig_sel / (nosel_sig_sel + nosel_bkg_sel) if nosel_sig_sel + nosel_bkg_sel > 0 else np.nan

    # Wire-Cell generic neutrino selection efficiency and purity
    generic_sig_sel_df = sig_df
    generic_sig_sel = generic_sig_sel_df["wc_net_weight"].sum()
    generic_bkg_sel_df = bkg_df
    generic_bkg_sel = generic_bkg_sel_df["wc_net_weight"].sum()
    generic_eff = generic_sig_sel / total_sig if total_sig > 0 else np.nan
    generic_pur = generic_sig_sel / (generic_sig_sel + generic_bkg_sel) if generic_sig_sel + generic_bkg_sel > 0 else np.nan

    plt.figure(figsize=(10, 6))

    plt.scatter(nosel_eff, nosel_pur, label="No Selection", marker="*", edgecolor="black", color="tab:green", s=500)
    plt.scatter(generic_eff, generic_pur, label="Generic Selection", marker="*", edgecolors="black", color="tab:red", s=500)

    plt.plot(all_effs, all_purs, color="tab:blue", label="Multi-Class BDT")

    if nominal_eff is not None and nominal_pur is not None:
        plt.scatter(nominal_eff, nominal_pur, label="Nominal Multi-Class BDT Sel", marker="*", edgecolors="black", color="tab:cyan", s=500)

    # plot the WC nueCC BDT efficiencies and purities only for nueCC signal categories
    if eff_eval_cat == "nueCC_Np":
        plt.plot(all_effs_nue_Np, all_purs_nue_Np, color="tab:pink", label="WC nue BDT")
        plt.scatter(nue_7_Np_eff, nue_7_Np_pur, label=r"WC $\nu_e$ CC $Np$ Sel", marker="*", edgecolors="black", color="tab:pink", s=500)
    elif eff_eval_cat == "nueCC_0p":
        plt.plot(all_effs_nue_0p, all_purs_nue_0p, color="tab:pink", label="WC nue BDT")
        plt.scatter(nue_7_0p_eff, nue_7_0p_pur, label=r"WC $\nu_e$ CC $0p$ Sel", marker="*", edgecolors="black", color="tab:pink", s=500)
    elif eff_eval_cat == "nueCC":
        plt.plot(all_effs_nue_Xp, all_purs_nue_Xp, color="tab:pink", label="WC nue BDT")
        plt.scatter(nue_7_Xp_eff, nue_7_Xp_pur, label=r"WC $\nu_e$ CC $Xp$ Sel", marker="*", edgecolors="black", color="tab:pink", s=500)

    # plot the WC numuCC BDT efficiencies and purities only for numuCC signal categories
    if eff_eval_cat == "numuCC_Np":
        plt.plot(all_effs_numuCC_Np, all_purs_numuCC_Np, color="tab:gray", label="WC numu BDT")
        plt.scatter(numuCC_09_Np_eff, numuCC_09_Np_pur, label=r"WC $\nu_\mu$ CC $Np$ Sel", marker="*", edgecolors="black", color="tab:gray", s=500)
    elif eff_eval_cat == "numuCC_0p":
        plt.plot(all_effs_numuCC_0p, all_purs_numuCC_0p, color="tab:gray", label="WC numu BDT")
        plt.scatter(numuCC_09_0p_eff, numuCC_09_0p_pur, label=r"WC $\nu_\mu$ CC $0p$ Sel", marker="*", edgecolors="black", color="tab:gray", s=500)
    elif eff_eval_cat == "numuCC":
        plt.plot(all_effs_numuCC_Xp, all_purs_numuCC_Xp, color="tab:gray", label="WC numu BDT")
        plt.scatter(numuCC_09_Xp_eff, numuCC_09_Xp_pur, label=r"WC $\nu_\mu$ CC $Xp$ Sel", marker="*", edgecolors="black", color="tab:gray", s=500)

    # plot the WC NC pi0 BDT efficiencies and purities only for NC pi0 signal categories
    if eff_eval_cat == "NC1pi0_Np":
        plt.plot(all_effs_ncpi0_Np, all_purs_ncpi0_Np, color="tab:olive", label=r"WC NC $\pi^0$ $Np$ BDT")
        plt.scatter(nc_pi0_1816_Np_eff, nc_pi0_1816_Np_pur, label=r"WC NC $\pi^0$ $Np$ Sel", marker="*", edgecolors="black", color="tab:olive", s=500)
    elif eff_eval_cat == "NC1pi0_0p":
        plt.plot(all_effs_ncpi0_0p, all_purs_ncpi0_0p, color="tab:olive", label=r"WC NC $\pi^0$ $0p$ BDT")
        plt.scatter(nc_pi0_1816_0p_eff, nc_pi0_1816_0p_pur, label=r"WC NC $\pi^0$ $0p$ Sel", marker="*", edgecolors="black", color="tab:olive", s=500)
    elif eff_eval_cat == "NC1pi0":
        plt.plot(all_effs_ncpi0_Xp, all_purs_ncpi0_Xp, color="tab:olive", label=r"WC NC $\pi^0$ $Xp$ BDT")
        plt.scatter(nc_pi0_1816_Xp_eff, nc_pi0_1816_Xp_pur, label=r"WC NC $\pi^0$ $Xp$ Sel", marker="*", edgecolors="black", color="tab:olive", s=500)

    # plot the WC cut-based numuCC pi0 selection efficiency and purity only for numuCC pi0 signal categories
    if eff_eval_cat == "numuCC1pi0_Np":
        plt.scatter(wc_cutbased_cc_pi0_Np_eff, wc_cutbased_cc_pi0_Np_pur, label=r"WC $\nu_\mu$ CC $\pi^0$ $Np$ Sel", marker="*", edgecolors="black", color="tab:olive", s=500)
    elif eff_eval_cat == "numuCC1pi0_0p":
        plt.scatter(wc_cutbased_cc_pi0_0p_eff, wc_cutbased_cc_pi0_0p_pur, label=r"WC $\nu_\mu$ CC $\pi^0$ $0p$ Sel", marker="*", edgecolors="black", color="tab:olive", s=500)
    elif eff_eval_cat == "numuCC1pi0":
        plt.scatter(wc_cutbased_cc_pi0_Xp_eff, wc_cutbased_cc_pi0_Xp_pur, label=r"WC $\nu_\mu$ CC $\pi^0$ $Xp$ Sel", marker="*", edgecolors="black", color="tab:olive", s=500)

    # plot the WC NC Delta BDT efficiencies and purities only for 1g signal categories
    if eff_eval_cat == "NCDeltaRad_1gNp" or eff_eval_cat == "del1g_Np":
        plt.plot(all_effs_nc_delta_Np, all_purs_nc_delta_Np, color="tab:orange", label=r"WC NC Delta $Np$ Sel")
        plt.scatter(nc_delta_261_Np_eff, nc_delta_261_Np_pur, label=r"WC NC Delta $Np$ Sel", marker="*", edgecolors="black", color="tab:orange", s=500)
    elif eff_eval_cat == "NCDeltaRad_1g0p" or eff_eval_cat == "del1g_0p" or eff_eval_cat == "iso1g":
        plt.plot(all_effs_nc_delta_0p, all_purs_nc_delta_0p, color="tab:orange", label=r"WC NC Delta $0p$ Sel")
        plt.scatter(nc_delta_261_0p_eff, nc_delta_261_0p_pur, label=r"WC NC Delta $0p$ Sel", marker="*", edgecolors="black", color="tab:orange", s=500)
    elif eff_eval_cat == "NCDeltaRad" or eff_eval_cat == "del1g":
        plt.plot(all_effs_nc_delta_Xp, all_purs_nc_delta_Xp, color="tab:orange", label=r"WC NC Delta $Xp$ Sel")
        plt.scatter(nc_delta_261_Xp_eff, nc_delta_261_Xp_pur, label=r"WC NC Delta $Xp$ Sel", marker="*", edgecolors="black", color="tab:orange", s=500)

    # always show Erin's inclusive 1g selection for comparison
    plt.scatter(erin_inclusive_1g_eff, erin_inclusive_1g_pur, label="Erin Inclusive 1g Sel", marker="*", edgecolors="black", color="tab:purple", s=500)

    if eff_eval_cat == "eta_other":
        # https://arxiv.org/pdf/2305.16249
        plt.scatter(0.136, 0.499, label=r"David Pandora $\eta$ Sel\n(eff/pur for only $\eta$ events)", marker="*", edgecolors="black", color="tab:brown", s=500)
    
    plt.xlabel("Efficiency")
    plt.ylabel("Purity")
    plt.title(f"True {eff_eval_cat_latex} Eff vs Pur")
    plt.xlim(0, 1)
    ymin, ymax = plt.ylim()
    plt.ylim(0, ymax)
    plt.legend(loc="upper right")

    plt.savefig(f"../plots/true_{eff_eval_cat}_eff_vs_pur.png")
    plt.show()


# Kinematic Efficiencies

In [None]:
def frac_efficiency_stat_error(unweighted_sel, unweighted_total):
    """Fractional binomial statistical error on a selection efficiency.

    For every bin the efficiency is ``eff = sel / total`` and the value
    returned is ``sqrt(eff * (1 - eff) / total) / eff``, i.e. the binomial
    uncertainty on the efficiency divided by the efficiency itself.

    Parameters
    ----------
    unweighted_sel : array_like
        Raw (unweighted) selected counts per bin.
    unweighted_total : array_like
        Raw (unweighted) total counts per bin; must broadcast against
        ``unweighted_sel``.

    Returns
    -------
    numpy.ndarray
        Float array of fractional errors. Entries are NaN wherever the total
        is zero or the efficiency is zero (the fractional error is undefined
        there).
    """
    # Coerce to float arrays so the element-wise "!= 0" masks below also work
    # for plain lists and scalars (a list compared to 0 would be a single bool).
    sel = np.asarray(unweighted_sel, dtype=float)
    total = np.asarray(unweighted_total, dtype=float)
    out_shape = np.broadcast(sel, total).shape
    has_events = total != 0

    unweighted_eff = np.divide(
        sel,
        total,
        out=np.full(out_shape, np.nan),
        where=has_events,
    )

    # Binomial variance of the efficiency, evaluated only where the denominator is non-zero
    variance = np.divide(
        unweighted_eff * (1 - unweighted_eff),
        total,
        out=np.full(out_shape, np.nan),
        where=has_events,
    )

    # Clamp at zero before the square root: sel > total would give a negative variance
    sigma = np.sqrt(np.maximum(variance, 0))

    # Fractional error; left as NaN where the efficiency is zero or undefined
    return np.divide(
        sigma,
        unweighted_eff,
        out=np.full(out_shape, np.nan),
        where=has_events & (unweighted_eff != 0),
    )

# Truth-level signal masks, written once and applied to both the full sample
# (merged_df) and the preselected sample (presel_merged_df).
_true_cat = pl.col("del1g_detailed_signal_category")
_is_true_iso1g = _true_cat == del1g_detailed_categories_dic["iso1g"]
_is_true_del1g = (
    (_true_cat == del1g_detailed_categories_dic["del1g_Np"])
    | (_true_cat == del1g_detailed_categories_dic["del1g_0p"])
    | (_true_cat == del1g_detailed_categories_dic["del1g_Np1mu"])
    | (_true_cat == del1g_detailed_categories_dic["del1g_0p1mu"])
)

all_iso1g_sig_df = merged_df.filter(_is_true_iso1g)
presel_iso1g_sig_df = presel_merged_df.filter(_is_true_iso1g)

all_del1g_sig_df = merged_df.filter(_is_true_del1g)
presel_del1g_sig_df = presel_merged_df.filter(_is_true_del1g)

# Preselected signal that was also assigned to a 1g reco category:
# iso1g is only compared against 1g0p, del1g against any of the four 1g categories.
_reco_cat = pl.col("reco_category")
presel_iso1g_sig_1g0p_df = presel_iso1g_sig_df.filter(
    _reco_cat == del1g_simple_categories_dic["1g0p"]
)
presel_del1g_sig_1g_combined_df = presel_del1g_sig_df.filter(
    (_reco_cat == del1g_simple_categories_dic["1gNp"])
    | (_reco_cat == del1g_simple_categories_dic["1g0p"])
    | (_reco_cat == del1g_simple_categories_dic["1gNp1mu"])
    | (_reco_cat == del1g_simple_categories_dic["1g0p1mu"])
)

# "Generic" stage: true signal with a positive reconstructed neutrino energy
_has_reco_enu = pl.col("wc_kine_reco_Enu") > 0
generic_iso1g_sig_df = all_iso1g_sig_df.filter(_has_reco_enu)
generic_del1g_sig_df = all_del1g_sig_df.filter(_has_reco_enu)


# 1D Shower Kinematics

## Iso1g Shower Energy

In [None]:
# Iso1g selection efficiency in bins of true leading shower energy
bins = np.linspace(0, 3000, 31)
bin_centers = (bins[:-1] + bins[1:]) / 2

# Denominators: every true iso1g event, weighted (for the efficiency itself)
# and unweighted (raw counts, for the binomial statistical error)
sig_counts = np.histogram(all_iso1g_sig_df["wc_true_leading_shower_energy"].to_numpy(), weights=all_iso1g_sig_df["wc_net_weight"].to_numpy(), bins=bins)[0]
sig_counts_no_weight = np.histogram(all_iso1g_sig_df["wc_true_leading_shower_energy"].to_numpy(), bins=bins)[0]

plt.figure(figsize=(10, 6))

# Grey backdrop: shape of the full signal spectrum; the 1e-4 per-event weight is an arbitrary display scale
plt.hist(all_iso1g_sig_df["wc_true_leading_shower_energy"].to_numpy(), weights=1e-4*np.ones(all_iso1g_sig_df.height), bins=bins, 
    color="tab:grey", alpha=0.5, zorder=-1, label=r"All Iso1g signal" "\n(arbitrary units)")

# Numerator sample for each selection stage drawn on the plot
iso1g_selection_dfs = {
    "1g0p": presel_iso1g_sig_1g0p_df,
    "presel": presel_iso1g_sig_df,
    "generic": generic_iso1g_sig_df,
}

for eff_eval_cat, sig_sel_df in iso1g_selection_dfs.items():
    sig_sel_counts = np.histogram(sig_sel_df["wc_true_leading_shower_energy"].to_numpy(), weights=sig_sel_df["wc_net_weight"].to_numpy(), bins=bins)[0]
    sig_sel_counts_no_weight = np.histogram(sig_sel_df["wc_true_leading_shower_energy"].to_numpy(), bins=bins)[0]
    # Bins with no signal in the denominator have an undefined efficiency: leave them
    # as NaN (errorbar skips NaN points) rather than drawing a misleading efficiency of 0.
    ratios = np.divide(
        sig_sel_counts,
        sig_counts,
        out=np.full_like(sig_sel_counts, np.nan, dtype=float),
        where=sig_counts != 0
    )
    # frac_efficiency_stat_error returns a fractional error, so scale by the efficiency
    errs = ratios * frac_efficiency_stat_error(sig_sel_counts_no_weight, sig_counts_no_weight)
    plt.errorbar(bin_centers, ratios, yerr=errs, fmt='o', markersize=5, capsize=5, label=f"{eff_eval_cat} Selection")

plt.xlabel("True Leading Shower Energy (MeV)")
plt.ylabel("Efficiency")
plt.title("Iso1g Signal Selection Efficiency")
plt.legend()
plt.xlim(bins[0], bins[-1])
#plt.ylim(0, 0.3)
plt.savefig("../plots/iso1g_shower_energy_eff.pdf")


## Del1g truth_energyInside

In [None]:
# Selection efficiency for true del1g signal as a function of wc_truth_energyInside,
# for successively looser selections.
bins = np.linspace(0, 3000, 31)
bin_centers = (bins[:-1] + bins[1:]) / 2

# Denominators: all true del1g signal events. The weighted counts feed the
# efficiency ratio, the unweighted counts feed the statistical error.
del1g_truth_energy_inside = all_del1g_sig_df["wc_truth_energyInside"].to_numpy()
sig_counts = np.histogram(del1g_truth_energy_inside, weights=all_del1g_sig_df["wc_net_weight"].to_numpy(), bins=bins)[0]
sig_counts_no_weight = np.histogram(del1g_truth_energy_inside, bins=bins)[0]

plt.figure(figsize=(10, 6))

# Grey reference shape of the full del1g signal sample. This previously drew the
# iso1g sample (and was labelled as such), which did not match the efficiencies
# plotted on top of it.
plt.hist(del1g_truth_energy_inside, weights=2e-5*np.ones(all_del1g_sig_df.height), bins=bins,
    color="tab:grey", alpha=0.5, zorder=-1, label=r"All Del1g signal" "\n(arbitrary units)")

# Legend name -> signal events surviving that selection.
del1g_selected_dfs = {
    "1g": presel_del1g_sig_1g_combined_df,
    "presel": presel_del1g_sig_df,
    "generic": generic_del1g_sig_df,
}

for eff_eval_cat, sig_sel_df in del1g_selected_dfs.items():
    sig_sel_counts = np.histogram(sig_sel_df["wc_truth_energyInside"].to_numpy(), weights=sig_sel_df["wc_net_weight"].to_numpy(), bins=bins)[0]
    sig_sel_counts_no_weight = np.histogram(sig_sel_df["wc_truth_energyInside"].to_numpy(), bins=bins)[0]
    # Bins without any true signal are left as NaN so they are not drawn.
    ratios = np.divide(
        sig_sel_counts,
        sig_counts,
        out=np.full_like(sig_sel_counts, np.nan, dtype=float),
        where=sig_counts != 0
    )
    # frac_efficiency_stat_error is defined elsewhere in the notebook; it is used
    # here as a fractional error that is scaled by the efficiency itself.
    errs = ratios * frac_efficiency_stat_error(sig_sel_counts_no_weight, sig_counts_no_weight)
    plt.errorbar(bin_centers, ratios, yerr=errs, fmt='o', markersize=5, capsize=5, label=f"{eff_eval_cat} Selection")

plt.xlabel("True Energy Deposited in TPC (MeV)")
plt.ylabel("Efficiency")
plt.title("Del1g Signal Selection Efficiency")
plt.legend()
plt.xlim(bins[0], bins[-1])
# Own output file: this cell used to write to iso1g_shower_energy_eff.pdf and
# silently overwrote the iso1g efficiency figure.
plt.savefig("../plots/del1g_energy_inside_eff.pdf")


# 2D Shower Kinematics

## Iso1g

In [None]:
# 2D (true leading shower energy, true leading shower cos(theta)) distribution of
# all true iso1g signal, followed by one 2D efficiency map per selection.
all_iso1g_sig_true_shower_energy = all_iso1g_sig_df["wc_true_leading_shower_energy"].to_numpy()
all_iso1g_sig_true_shower_costheta = all_iso1g_sig_df["wc_true_leading_shower_costheta"].to_numpy()
all_iso1g_sig_weights = all_iso1g_sig_df["wc_net_weight"].to_numpy()

x_bin_edges = np.linspace(0., 3000., 31)
y_bin_edges = np.linspace(-1., 1., 21)
bins = (x_bin_edges, y_bin_edges)

# NOTE: this changes the global matplotlib font size for every later figure too.
plt.rcParams['font.size'] = 16

plt.figure(dpi=100, figsize=(10,7))
plt.hist2d(all_iso1g_sig_true_shower_energy, all_iso1g_sig_true_shower_costheta, weights=all_iso1g_sig_weights, bins=bins, norm=mpl.colors.LogNorm())
plt.xlabel("True Primary Photon Energy (MeV)")
plt.ylabel("True Primary Photon Cos(theta)")
plt.colorbar(label="Count (weighted to 1.11e21 POT)")
plt.title("All True 1g Signal", pad=15)
plt.savefig("../plots/iso1g_2d_all_1g.pdf")

# Denominator of every efficiency map: weighted counts of all true signal.
# It does not depend on the selection, so it is computed once, outside the loop.
# Shape is (n_x_bins, n_y_bins).
counts_sig = np.histogram2d(all_iso1g_sig_true_shower_energy, all_iso1g_sig_true_shower_costheta,
        bins=bins, weights=all_iso1g_sig_weights)[0]

# One (x, y) bin-centre pair per 2D bin, flattened with x varying fastest.
# The efficiency map is drawn by filling each bin exactly once at its centre,
# weighted by that bin's efficiency.
x_bin_centers = x_bin_edges[:-1] + (x_bin_edges[1] - x_bin_edges[0]) / 2.
y_bin_centers = y_bin_edges[:-1] + (y_bin_edges[1] - y_bin_edges[0]) / 2.
bin_center_x_grid, bin_center_y_grid = np.meshgrid(x_bin_centers, y_bin_centers)
bin_center_x_arr = bin_center_x_grid.ravel()
bin_center_y_arr = bin_center_y_grid.ravel()

# Title/file tag -> signal events surviving that selection.
iso1g_selected_dfs = {
    "1g0p": presel_iso1g_sig_1g0p_df,
    "presel": presel_iso1g_sig_df,
    "generic": generic_iso1g_sig_df,
}

for eff_eval_cat, sig_sel_df in iso1g_selected_dfs.items():
    counts_sig_sel = np.histogram2d(sig_sel_df["wc_true_leading_shower_energy"].to_numpy(), sig_sel_df["wc_true_leading_shower_costheta"].to_numpy(),
            bins=bins, weights=sig_sel_df["wc_net_weight"].to_numpy())[0]
    # Guarded division: bins with no true signal stay NaN (drawn blank) without
    # triggering divide-by-zero warnings.
    eff_2d = np.divide(
        counts_sig_sel,
        counts_sig,
        out=np.full_like(counts_sig, np.nan, dtype=float),
        where=counts_sig != 0
    )
    # Transpose so that flattening walks x fastest, matching the bin-centre arrays.
    wc_eff_arr = eff_2d.T.ravel()

    plt.figure(dpi=100, figsize=(10,7))
    plt.hist2d(bin_center_x_arr, bin_center_y_arr, weights=wc_eff_arr, bins=bins)
    plt.xlabel("True Primary Photon Energy (MeV)")
    plt.ylabel(r"True Primary Photon Cos($\theta$)")
    plt.colorbar()
    plt.title(f"Efficiency of {eff_eval_cat} Selection For all True Iso1g Signal", pad=15)
    # File prefix fixed from "is1g" to "iso1g" to match iso1g_2d_all_1g.pdf.
    plt.savefig(f"../plots/iso1g_2d_eff_{eff_eval_cat}.pdf")


## Del1g

In [None]:
# 2D (true leading shower energy, true leading shower cos(theta)) distribution of
# all true del1g signal, followed by one 2D efficiency map per selection.
all_del1g_sig_true_shower_energy = all_del1g_sig_df["wc_true_leading_shower_energy"].to_numpy()
all_del1g_sig_true_shower_costheta = all_del1g_sig_df["wc_true_leading_shower_costheta"].to_numpy()
all_del1g_sig_weights = all_del1g_sig_df["wc_net_weight"].to_numpy()

x_bin_edges = np.linspace(0., 3000., 31)
y_bin_edges = np.linspace(-1., 1., 21)
bins = (x_bin_edges, y_bin_edges)

# NOTE: this changes the global matplotlib font size for every later figure too.
plt.rcParams['font.size'] = 16

plt.figure(dpi=100, figsize=(10,7))
plt.hist2d(all_del1g_sig_true_shower_energy, all_del1g_sig_true_shower_costheta, weights=all_del1g_sig_weights, bins=bins, norm=mpl.colors.LogNorm())
plt.xlabel("True Primary Photon Energy (MeV)")
plt.ylabel("True Primary Photon Cos(theta)")
plt.colorbar(label="Count (weighted to 1.11e21 POT)")
plt.title("All True 1g Signal", pad=15)
plt.savefig("../plots/del1g_2d_all_1g.pdf")

# Denominator of every efficiency map: weighted counts of all true signal.
# It does not depend on the selection, so it is computed once, outside the loop.
# Shape is (n_x_bins, n_y_bins).
counts_sig = np.histogram2d(all_del1g_sig_true_shower_energy, all_del1g_sig_true_shower_costheta,
        bins=bins, weights=all_del1g_sig_weights)[0]

# One (x, y) bin-centre pair per 2D bin, flattened with x varying fastest.
# The efficiency map is drawn by filling each bin exactly once at its centre,
# weighted by that bin's efficiency.
x_bin_centers = x_bin_edges[:-1] + (x_bin_edges[1] - x_bin_edges[0]) / 2.
y_bin_centers = y_bin_edges[:-1] + (y_bin_edges[1] - y_bin_edges[0]) / 2.
bin_center_x_grid, bin_center_y_grid = np.meshgrid(x_bin_centers, y_bin_centers)
bin_center_x_arr = bin_center_x_grid.ravel()
bin_center_y_arr = bin_center_y_grid.ravel()

# Title/file tag -> signal events surviving that selection.
del1g_selected_dfs = {
    "1g": presel_del1g_sig_1g_combined_df,
    "presel": presel_del1g_sig_df,
    "generic": generic_del1g_sig_df,
}

for eff_eval_cat, sig_sel_df in del1g_selected_dfs.items():
    counts_sig_sel = np.histogram2d(sig_sel_df["wc_true_leading_shower_energy"].to_numpy(), sig_sel_df["wc_true_leading_shower_costheta"].to_numpy(),
            bins=bins, weights=sig_sel_df["wc_net_weight"].to_numpy())[0]
    # Guarded division: bins with no true signal stay NaN (drawn blank) without
    # triggering divide-by-zero warnings.
    eff_2d = np.divide(
        counts_sig_sel,
        counts_sig,
        out=np.full_like(counts_sig, np.nan, dtype=float),
        where=counts_sig != 0
    )
    # Transpose so that flattening walks x fastest, matching the bin-centre arrays.
    wc_eff_arr = eff_2d.T.ravel()

    plt.figure(dpi=100, figsize=(10,7))
    plt.hist2d(bin_center_x_arr, bin_center_y_arr, weights=wc_eff_arr, bins=bins)
    plt.xlabel("True Primary Photon Energy (MeV)")
    plt.ylabel(r"True Primary Photon Cos($\theta$)")
    plt.colorbar()
    plt.title(f"Efficiency of {eff_eval_cat} Selection For all True Del1g Signal", pad=15)
    plt.savefig(f"../plots/del1g_2d_eff_{eff_eval_cat}.pdf")


# Pi0 Efficiencies

In [None]:
# Efficiency and purity of the WC NC pi0 Np / 0p cut-based selections, evaluated
# on events where both the iso1g_overlay and del1g_overlay flags are False.

# Building blocks shared by all filters below.
outside_1g_overlays = (pl.col("iso1g_overlay") == False) & (pl.col("del1g_overlay") == False)
true_category = pl.col("del1g_simple_signal_category")
reco_num_protons = pl.col("wc_reco_num_protons_35_MeV")

# Cuts common to the Np and 0p selections; the last term vetoes events that have
# wc_reco_showerKE > 0 together with wc_nc_delta_score > 2.61.
wc_ncpi0_common_cuts = (
    (pl.col("wc_nc_pio_score") > 1.816)
    & (pl.col("wc_kine_reco_Enu") > 0)
    & (pl.col("wc_kine_pio_energy_1") > 0)
    & (pl.col("wc_kine_pio_energy_2") > 0)
    & (pl.col("wc_match_isFC") == 1)
    & ~((pl.col("wc_reco_showerKE") > 0) & (pl.col("wc_nc_delta_score") > 2.61))
)
# The two channels differ only in the reconstructed proton requirement.
wc_ncpi0Np_cuts = (reco_num_protons > 0) & wc_ncpi0_common_cuts
wc_ncpi00p_cuts = (reco_num_protons == 0) & wc_ncpi0_common_cuts

# --- NC pi0 Np ---
# Truth split: events in the NC1pi0_Np category versus every other category.
all_true_ncpi0Np_df = merged_df.filter(outside_1g_overlays & (true_category == "NC1pi0_Np"))
all_true_notncpi0Np_df = merged_df.filter(outside_1g_overlays & (true_category != "NC1pi0_Np"))
total_true_ncpi0Np_sig = all_true_ncpi0Np_df["wc_net_weight"].sum()

wc_ncpi0Np_sel_true_ncpi0Np_df = all_true_ncpi0Np_df.filter(wc_ncpi0Np_cuts)
wc_ncpi0Np_sel_true_notncpi0Np_df = all_true_notncpi0Np_df.filter(wc_ncpi0Np_cuts)
wc_ncpi0Np_sel_true_ncpi0Np_count = wc_ncpi0Np_sel_true_ncpi0Np_df["wc_net_weight"].sum()
wc_ncpi0Np_sel_true_notncpi0Np_count = wc_ncpi0Np_sel_true_notncpi0Np_df["wc_net_weight"].sum()

# Efficiency = selected signal / all signal; purity = selected signal / all selected.
wc_ncpi0Np_sel_true_ncpi0Np_eff = wc_ncpi0Np_sel_true_ncpi0Np_count / total_true_ncpi0Np_sig
wc_ncpi0Np_sel_true_ncpi0Np_pur = wc_ncpi0Np_sel_true_ncpi0Np_count / (wc_ncpi0Np_sel_true_ncpi0Np_count + wc_ncpi0Np_sel_true_notncpi0Np_count)

# --- NC pi0 0p ---
all_true_ncpi00p_df = merged_df.filter(outside_1g_overlays & (true_category == "NC1pi0_0p"))
all_true_notncpi00p_df = merged_df.filter(outside_1g_overlays & (true_category != "NC1pi0_0p"))
total_true_ncpi00p_sig = all_true_ncpi00p_df["wc_net_weight"].sum()

wc_ncpi00p_sel_true_ncpi00p_df = all_true_ncpi00p_df.filter(wc_ncpi00p_cuts)
wc_ncpi00p_sel_true_notncpi00p_df = all_true_notncpi00p_df.filter(wc_ncpi00p_cuts)
wc_ncpi00p_sel_true_ncpi00p_count = wc_ncpi00p_sel_true_ncpi00p_df["wc_net_weight"].sum()
wc_ncpi00p_sel_true_notncpi00p_count = wc_ncpi00p_sel_true_notncpi00p_df["wc_net_weight"].sum()

wc_ncpi00p_sel_true_ncpi00p_eff = wc_ncpi00p_sel_true_ncpi00p_count / total_true_ncpi00p_sig
wc_ncpi00p_sel_true_ncpi00p_pur = wc_ncpi00p_sel_true_ncpi00p_count / (wc_ncpi00p_sel_true_ncpi00p_count + wc_ncpi00p_sel_true_notncpi00p_count)

print(f"WC NC pi0 Np true Np Eff: {wc_ncpi0Np_sel_true_ncpi0Np_eff}, Pur: {wc_ncpi0Np_sel_true_ncpi0Np_pur}")
print(f"WC NC pi0 0p true 0p Eff: {wc_ncpi00p_sel_true_ncpi00p_eff}, Pur: {wc_ncpi00p_sel_true_ncpi00p_pur}")


In [None]:
# Efficiency and purity of the multiBDT NC pi0 Np / 0p selections.
# Score thresholds: tuned to equal purity as above.
multiBDT_ncpi0_Np_cut_value = 0.218
multiBDT_ncpi0_0p_cut_value = 0.0745

# Shared filter expressions: the preselection applied before the score cut, and
# the per-channel cut on the corresponding BDT probability column.
passes_multiBDT_presel = pl.col("wc_kine_reco_Enu") > 0
passes_ncpi0Np_score = pl.col("prob_NC1pi0_Np") > multiBDT_ncpi0_Np_cut_value
passes_ncpi00p_score = pl.col("prob_NC1pi0_0p") > multiBDT_ncpi0_0p_cut_value

# --- NC pi0 Np ---
# Preselected truth samples (signal and everything else).
multiBDT_presel_true_ncpi0Np_df = all_true_ncpi0Np_df.filter(passes_multiBDT_presel)
multiBDT_presel_true_notncpi0Np_df = all_true_notncpi0Np_df.filter(passes_multiBDT_presel)
# Events that additionally pass the Np score cut. Despite its name, the second
# frame holds the *selected* non-signal events, not just the preselected ones.
multiBDT_ncpi0Np_sel_true_ncpi0Np_df = multiBDT_presel_true_ncpi0Np_df.filter(passes_ncpi0Np_score)
multiBDT_ncpi0_presel_true_notncpi0Np_df = multiBDT_presel_true_notncpi0Np_df.filter(passes_ncpi0Np_score)

multiBDT_ncpi0Np_sel_true_ncpi0Np_count = multiBDT_ncpi0Np_sel_true_ncpi0Np_df["wc_net_weight"].sum()
multiBDT_ncpi0Np_sel_true_notncpi0Np_count = multiBDT_ncpi0_presel_true_notncpi0Np_df["wc_net_weight"].sum()
# Efficiency is relative to the full truth sample total computed in the WC cell.
multiBDT_ncpi0Np_sel_true_ncpi0Np_eff = multiBDT_ncpi0Np_sel_true_ncpi0Np_count / total_true_ncpi0Np_sig
multiBDT_ncpi0Np_sel_true_ncpi0Np_pur = multiBDT_ncpi0Np_sel_true_ncpi0Np_count / (multiBDT_ncpi0Np_sel_true_ncpi0Np_count + multiBDT_ncpi0Np_sel_true_notncpi0Np_count)

# --- NC pi0 0p ---
multiBDT_presel_true_ncpi00p_df = all_true_ncpi00p_df.filter(passes_multiBDT_presel)
multiBDT_presel_true_notncpi00p_df = all_true_notncpi00p_df.filter(passes_multiBDT_presel)
multiBDT_ncpi00p_sel_true_ncpi00p_df = multiBDT_presel_true_ncpi00p_df.filter(passes_ncpi00p_score)
multiBDT_ncpi0_presel_true_notncpi00p_df = multiBDT_presel_true_notncpi00p_df.filter(passes_ncpi00p_score)

multiBDT_ncpi00p_sel_true_ncpi00p_count = multiBDT_ncpi00p_sel_true_ncpi00p_df["wc_net_weight"].sum()
multiBDT_ncpi00p_sel_true_notncpi00p_count = multiBDT_ncpi0_presel_true_notncpi00p_df["wc_net_weight"].sum()
multiBDT_ncpi00p_sel_true_ncpi00p_eff = multiBDT_ncpi00p_sel_true_ncpi00p_count / total_true_ncpi00p_sig
multiBDT_ncpi00p_sel_true_ncpi00p_pur = multiBDT_ncpi00p_sel_true_ncpi00p_count / (multiBDT_ncpi00p_sel_true_ncpi00p_count + multiBDT_ncpi00p_sel_true_notncpi00p_count)

print(f"MultiBDT NC pi0 Np true Np Eff: {multiBDT_ncpi0Np_sel_true_ncpi0Np_eff}, Pur: {multiBDT_ncpi0Np_sel_true_ncpi0Np_pur}")
print(f"MultiBDT NC pi0 0p true 0p Eff: {multiBDT_ncpi00p_sel_true_ncpi00p_eff}, Pur: {multiBDT_ncpi00p_sel_true_ncpi00p_pur}")


In [None]:
def _plot_ncpi0_efficiency(variable, bins, xlabel, all_true_df, selections, truth_label, title, out_path):
    """Plot selection efficiency versus one truth variable for one NC pi0 channel.

    Parameters
    ----------
    variable : str
        Column to histogram.
    bins : np.ndarray
        Bin edges for ``variable``.
    xlabel : str
        X-axis label.
    all_true_df : pl.DataFrame
        All true signal events of this channel. This is the efficiency
        denominator and the source of the grey reference histogram.
    selections : list of (str, pl.DataFrame)
        ``(legend label, selected signal events)`` pairs; one set of error bars
        is drawn per pair.
    truth_label : str
        Legend label of the grey reference histogram.
    title : str
        Figure title.
    out_path : str
        File the figure is saved to.
    """
    bin_centers = (bins[:-1] + bins[1:]) / 2

    # Denominators come from the channel's own truth sample: weighted counts for
    # the ratio, unweighted counts for the statistical error.
    true_values = all_true_df[variable].to_numpy()
    sig_counts = np.histogram(true_values, weights=all_true_df["wc_net_weight"].to_numpy(), bins=bins)[0]
    sig_counts_no_weight = np.histogram(true_values, bins=bins)[0]

    plt.figure(figsize=(10, 6))
    # Grey reference shape; the constant per-event weight only rescales it so
    # that it fits on the efficiency axis.
    plt.hist(true_values, weights=2e-5*np.ones(all_true_df.height), bins=bins,
        color="tab:grey", alpha=0.5, zorder=-1, label=truth_label)

    for label, sig_sel_df in selections:
        sig_sel_counts = np.histogram(sig_sel_df[variable].to_numpy(), weights=sig_sel_df["wc_net_weight"].to_numpy(), bins=bins)[0]
        sig_sel_counts_no_weight = np.histogram(sig_sel_df[variable].to_numpy(), bins=bins)[0]
        # Bins without any true signal are left as NaN so they are not drawn.
        ratios = np.divide(
            sig_sel_counts,
            sig_counts,
            out=np.full_like(sig_sel_counts, np.nan, dtype=float),
            where=sig_counts != 0
        )
        # frac_efficiency_stat_error is defined elsewhere in the notebook; it is
        # used here as a fractional error that is scaled by the efficiency itself.
        errs = ratios * frac_efficiency_stat_error(sig_sel_counts_no_weight, sig_counts_no_weight)
        plt.errorbar(bin_centers, ratios, yerr=errs, fmt='o', markersize=5, capsize=5, label=f"{label}")

    plt.xlabel(xlabel)
    plt.ylabel("Efficiency")
    plt.title(title)
    plt.legend(loc="upper right")
    plt.xlim(bins[0], bins[-1])
    plt.savefig(out_path)


# Truth variable -> (bin edges, x-axis label).
ncpi0_eff_variables = {
    "wc_truth_energyInside": (np.linspace(0, 3000, 31), "True Energy Deposited in TPC (MeV)"),
    "wc_true_leading_pi0_energy": (np.linspace(0, 1000, 21), "True Leading Pi0 Energy (MeV)"),
    "wc_true_leading_pi0_costheta": (np.linspace(-1, 1, 21), "True Leading Pi0 Cos(theta)"),
    "wc_true_leading_pi0_opening_angle": (np.linspace(0, 180, 19), "True Leading Pi0 Opening Angle (degrees)"),
    "wc_true_leading_shower_energy": (np.linspace(0, 3000, 31), "True Leading Shower Energy (MeV)"),
    "wc_true_leading_shower_costheta": (np.linspace(-1, 1, 21), "True Leading Shower Cos(theta)"),
    "wc_true_subleading_shower_energy": (np.linspace(0, 3000, 31), "True Subleading Shower Energy (MeV)"),
    "wc_true_subleading_shower_costheta": (np.linspace(-1, 1, 21), "True Subleading Shower Cos(theta)"),
}

# The generic selection does not depend on the plotted variable, so it is applied once.
multiBDT_generic_true_ncpi0Np_df = all_true_ncpi0Np_df.filter(pl.col("wc_kine_reco_Enu") > 0)
multiBDT_generic_true_ncpi00p_df = all_true_ncpi00p_df.filter(pl.col("wc_kine_reco_Enu") > 0)

# (legend label, selected signal events) per channel, in plotting order.
# NOTE(review): "MultiBDT Presel" and "WC Generic Sel" look like the same
# wc_kine_reco_Enu > 0 cut, so their points would overlap - confirm this is intended.
ncpi0Np_selections = [
    (r"WC NC $\pi^0$ $Np$ Sel", wc_ncpi0Np_sel_true_ncpi0Np_df),
    (r"MultiBDT NC $\pi^0$ $Np$ Sel", multiBDT_ncpi0Np_sel_true_ncpi0Np_df),
    (r"MultiBDT Presel", multiBDT_presel_true_ncpi0Np_df),
    (r"WC Generic Sel", multiBDT_generic_true_ncpi0Np_df),
]
ncpi00p_selections = [
    (r"WC NC $\pi^0$ $0p$ Sel", wc_ncpi00p_sel_true_ncpi00p_df),
    (r"MultiBDT NC $\pi^0$ $0p$ Sel", multiBDT_ncpi00p_sel_true_ncpi00p_df),
    (r"MultiBDT Presel", multiBDT_presel_true_ncpi00p_df),
    (r"WC Generic Sel", multiBDT_generic_true_ncpi00p_df),
]

for variable, (bins, xlabel) in ncpi0_eff_variables.items():
    # NC Pi0 Np
    _plot_ncpi0_efficiency(
        variable, bins, xlabel,
        all_true_df=all_true_ncpi0Np_df,
        selections=ncpi0Np_selections,
        truth_label=r"All NC $\pi^0$ True Np signal" "\n(arbitrary units)",
        title="NC1Pi0 True Np Signal Selection Efficiency\nComparison at Equal Purity",
        out_path=f"../plots/ncpi0_Np_eff_{variable}.pdf",
    )

    # NC Pi0 0p: the denominator is the 0p truth sample. It previously reused the
    # Np totals, which gave wrong 0p efficiencies and errors.
    _plot_ncpi0_efficiency(
        variable, bins, xlabel,
        all_true_df=all_true_ncpi00p_df,
        selections=ncpi00p_selections,
        truth_label=r"All NC $\pi^0$ True 0p signal" "\n(arbitrary units)",
        title="NC1Pi0 True 0p Signal Selection Efficiency\nComparison at Equal Purity",
        out_path=f"../plots/ncpi0_0p_eff_{variable}.pdf",
    )
