In [None]:
# Re-import edited modules (e.g. loaddf, syst, gump_cuts) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

from IPython.core.interactiveshell import InteractiveShell
# Display the value of every top-level expression statement in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt 
import matplotlib as mpl

from tqdm.auto import tqdm

import warnings
# Silence pandas PerformanceWarning (e.g. "DataFrame is highly fragmented" from the many
# single-column insertions done later in this notebook).
warnings.simplefilter(action='ignore', category=pd.errors.PerformanceWarning)

from multiprocess import Pool

In [None]:
import pyanalib.pandas_helpers as ph
from makedf.util import *

import kinematics
import gump_cuts as gc
import loaddf
import syst
import importlib

In [None]:
import os

# Output location for all saved figures, and whether the plotting cells save them.
PLOTDIR = "/Users/gputnam/Work/osc/cafpyana/plots-1-20-w-dirt/"

DOSAVE = True

# Make sure the output directory and its per-format subdirectories exist.
for _plot_subdir in ("", "/png", "/pdf"):
    os.makedirs(PLOTDIR + _plot_subdir, exist_ok=True)

In [None]:
# Detector whose samples are analysed, and whether the dirt MC sample is added.
DETECTOR, INCLUDE_DIRT = "SBND", True

In [None]:
DF_DIR = "/Users/gputnam/Work/osc/dfs/sbn-rewgted-4/"

# Input dataframe files per detector. Every branch must define ONBEAM, OFFBEAM, MC_FILES,
# DIRT_FILES, DETVAR_FILES and DETVAR_NAMES: later cells read all of them.
if DETECTOR == "ICARUS":
    ONBEAM = DF_DIR + "ICARUS_SpringRun2BNB_unblind_prescaled.df"
    OFFBEAM = DF_DIR + "ICARUS_SpringRun2BNBOff_unblind_prescaled.df"
    
    ONBEAMPOT = "/Users/gputnam/Work/osc/cafpyana/analysis_village/gump/Run2_BNB_uncalo_unblind_POT.df"

    MC_FILES = [DF_DIR + "ICARUS_SpringMCOverlay_rewgt.df"]
    DIRT_FILES = [DF_DIR + "ICARUS_SpringMCDirt_slimwgt.df"]
    
    DETVAR_FILES = []
    DETVAR_NAMES = []

elif DETECTOR == "SBND": 
    ONBEAM = DF_DIR + "SBND_SpringBNBData_Dev.df"
    OFFBEAM = DF_DIR + "SBND_SpringBNBOffData_5000.df"

    MC_FILES = [DF_DIR + "SBND_SpringMC_rewgt_%i.df" % i for i in range(5)]
    DIRT_FILES = [DF_DIR + "SBND_SpringLowEMC_rewgt_%i.df" % i for i in range(10)]

    DETVAR_FILES = [
        DF_DIR + "SBND_SpringMC_WMXThetaXW.df",
        DF_DIR + "SBND_SpringMC_WMYZ.df",
        DF_DIR + "SBND_SpringMC_0xSCE.df",
        DF_DIR + "SBND_SpringMC_2xSCE.df",
    ]
    DETVAR_NAMES = ["WM $X\\theta_{xw}$", "WM $YZ$", "0x SCE", "2x SCE"]
    
# elif DETECTOR == "SBND SPINE": 
    # ONBEAM = "/Users/gputnam/Work/osc/cafpyana/analysis_village/gump/SBND_SPINE_SpringBNBDevData.df"
    # OFFBEAM = "/Users/gputnam/Work/osc/cafpyana/analysis_village/gump/SBND_SPINE_SpringBNBOffData.df"
else:
    # Fail here with a clear message rather than with a NameError on ONBEAM several cells later.
    raise ValueError(f"No input files configured for DETECTOR={DETECTOR!r}")

In [None]:
import h5py

def read_dfs(file, key):
    """Read and concatenate every top-level HDF5 table in ``file`` whose name starts with ``key``.

    Parameters
    ----------
    file : str
        Path to the HDF5 file.
    key : str
        Key prefix to match (e.g. "hdr", "trig", "crt").

    Returns
    -------
    pandas.DataFrame
        ``pd.concat`` of the matching tables, in h5py key order.

    Raises
    ------
    ValueError
        If no key in ``file`` starts with ``key``.
    """
    # Only list the keys with h5py; close that handle before pandas opens the file itself.
    with h5py.File(file, "r") as f:
        keys = [k for k in f.keys() if k.startswith(key)]
    if not keys:
        # pd.concat([]) would otherwise fail with an unhelpful "No objects to concatenate".
        raise ValueError(f"No dataframes with key prefix {key!r} found in {file}")
    return pd.concat([pd.read_hdf(file, k) for k in keys])

In [None]:
# Beam-gate counts in ON and OFF data, and OFF_w: the weight applied to OFF-beam
# counts/histograms before they are subtracted from ON-beam ones.
if DETECTOR == "ICARUS":
    # NOTE(review): these (1 - 1/N) factors look like prescale corrections -- TODO confirm.
    gate_frac_ON = 1 - 1/100.
    gate_frac_OFF = 1 - 1/20.
    ngates_ON = read_dfs(ONBEAM, "trig").gate_delta.sum() * gate_frac_ON
    ngates_OFF = read_dfs(OFFBEAM, "trig").gate_delta.sum() * gate_frac_OFF

    OFF_w = ngates_ON / ngates_OFF
elif "SBND" in DETECTOR:
    ngates_ON = len(read_dfs(ONBEAM, "bnb"))
    ngates_OFF = read_dfs(OFFBEAM, "hdr").noffbeambnb.sum()

    # NOTE(review): fixed fractional reduction of the OFF weight; its origin is not shown here.
    f_factor = 0.0754
    OFF_w = (1. - f_factor) * ngates_ON / ngates_OFF

ngates_ON, ngates_OFF, OFF_w

In [None]:
if DETECTOR == "ICARUS":
    # Inverse of the mean gate_delta in the ON and OFF samples.
    inv_gate_delta_on = 1 / read_dfs(ONBEAM, "trig").gate_delta.mean()
    inv_gate_delta_off = 1 / read_dfs(OFFBEAM, "trig").gate_delta.mean()
    print("ON:", inv_gate_delta_on, "OFF:", inv_gate_delta_off)

In [None]:
# Total ON-beam POT.
if DETECTOR == "ICARUS":
    # POT = pd.read_hdf(ONBEAMPOT).pot.sum()*1e12
    # Join the POT table onto the ON-beam header rows by index; the overlapping "pot"
    # column from the POT table gets the "_y" suffix in the merge.
    on_hdr = read_dfs(ONBEAM, "hdr")
    pot_table = pd.read_hdf(ONBEAMPOT)
    merged_hdr = on_hdr.merge(pot_table, left_index=True, right_index=True, how="left")
    POT = merged_hdr.pot_y.sum() * 1e12
elif "SBND" in DETECTOR:
    POT = read_dfs(ONBEAM, "bnb").TOR875.sum()

POT

In [None]:
# ON-beam POT in units of 1e19. Reuses POT so this also works for ICARUS, which has no "bnb" table.
POT / 1e19

In [None]:
gates_header = "N GATES ON / 5e12 POT"
print(gates_header)
print(5e12 * ngates_ON / POT)

In [None]:
# Header-row counts in ON and OFF data, and the ON count minus the weighted OFF count.
NEVT_ON, NEVT_OFF = [len(read_dfs(path, "hdr")) for path in (ONBEAM, OFFBEAM)]

NEVT = NEVT_ON - OFF_w * NEVT_OFF

In [None]:
# ON-beam header count, POT, and header count per 1e15 POT.
pot_in_1e15 = POT / 1e15
NEVT_ON, POT, NEVT_ON / pot_in_1e15

In [None]:
# OFF-subtracted count, POT, and that count per 1e15 POT.
pot_in_1e15 = POT / 1e15
NEVT, POT, NEVT / pot_in_1e15

In [None]:
# CRT tables for ON- and OFF-beam data.
crt_ON, crt_OFF = [read_dfs(path, "crt") for path in (ONBEAM, OFFBEAM)]

In [None]:
def top_crt(crtdf):
    """Boolean mask of CRT hits whose ``plane`` lies in [30, 40] (both ends inclusive).

    NOTE(review): by the name, these planes are the top CRT -- confirm against the geometry.
    """
    return crtdf.plane.between(30, 40)

def side_crt(crtdf):
    """Boolean mask of CRT hits not selected by ``top_crt``."""
    return ~top_crt(crtdf)

In [None]:
# Disabled alternative: earliest top-CRT hit time (time > -1) per (level 0, level 1) index group.
# CRT_intime_hit_ON = crt_ON.time[top_crt(crt_ON) & (crt_ON.time > -1)].groupby(level=[0,1]).min()
# CRT_intime_hit_OFF = crt_OFF.time[top_crt(crt_OFF) & (crt_OFF.time > -1)].groupby(level=[0,1]).min()

# Times of all side-CRT hits (no timing selection is applied despite the name).
side_hits_on = side_crt(crt_ON)
side_hits_off = side_crt(crt_OFF)
CRT_intime_hit_ON = crt_ON.time[side_hits_on]
CRT_intime_hit_OFF = crt_OFF.time[side_hits_off]

In [None]:
bins = np.linspace(-2, 3, 51)
# bins = np.linspace(-10, 10, 101)

# Overlay the side-CRT hit-time spectra: ON beam, and OFF beam scaled by OFF_w.
N, bins = np.histogram(CRT_intime_hit_ON, bins=bins)
centers = 0.5 * (bins[1:] + bins[:-1])

plt.errorbar(centers, N, np.sqrt(N), color="black", linestyle="none", marker=".", label="Beam ON")

Noff, _ = np.histogram(CRT_intime_hit_OFF, bins=bins)
plt.errorbar(centers, OFF_w * Noff, OFF_w * np.sqrt(Noff), color="red", linestyle="none", marker=".", label="Beam OFF")

plt.legend()

In [None]:
# Explicitly re-import loaddf (normally redundant given %autoreload 2 above).
importlib.reload(loaddf)

In [None]:
# Nominal MC sample (at most 10 parallel load jobs); mcpot is later passed to scale_pot as its POT.
n_mc_jobs = min(10, len(MC_FILES))
df, mcpot = loaddf.loadl(MC_FILES, njob=n_mc_jobs)

In [None]:
# LOAD DETECTOR VARIATION SAMPLES

# LOAD DETECTOR VARIATION SAMPLES
# tqdm wraps the file list (not the finished results), so the bar advances as each file loads.
if len(DETVAR_FILES) > 0:
    detvars, detvar_pots = zip(*[loaddf.load(f, include_syst=False) for f in tqdm(DETVAR_FILES)])
else:
    detvars = []
    detvar_pots = []

In [None]:
# zip(...) above yields a tuple; later cells reassign elements of detvars and extend it.
detvars = list(detvars)

In [None]:
def scale_pot(df, pot, desired_pot):
    """Set the POT normalisation factor of ``df`` in its ``glob_scale`` column.

    Parameters
    ----------
    df : pandas.DataFrame
        Sample to scale; modified in place (column ``glob_scale`` is set).
    pot : float
        POT of the sample.
    desired_pot : float
        POT to normalise the sample to.

    Returns
    -------
    tuple
        ``(pot, scale)`` with ``scale = desired_pot / pot``.
    """
    print(f"POT: {pot}\nScaling to: {desired_pot}")
    # Use the argument (not the global POT) so the scale matches what is printed above.
    scale = desired_pot / pot
    df['glob_scale'] = scale
    return pot, scale

# Normalise the nominal MC and each detector-variation sample to the ON-beam POT.
scale_pot(df, mcpot, POT)
for i, detvar in enumerate(detvars):
    print(DETVAR_NAMES[i])
    scale_pot(detvar, detvar_pots[i], POT)

In [None]:
def ICARUS_dirtcut(df):
    """Mask of rows with a true vertex (true_vtx_x not NaN) that fail ``gc._fv_cut`` for ICARUS.

    ``gc._fv_cut`` is called with its four numeric arguments all 0.
    NOTE(review): presumably this keeps dirt events outside the ICARUS volume -- confirm in gump_cuts.
    """
    has_vertex = ~np.isnan(df.true_vtx_x)
    vertices = pd.DataFrame({axis: getattr(df, "true_vtx_" + axis) for axis in "xyz"})
    return has_vertex & ~gc._fv_cut(vertices, "ICARUS", 0, 0, 0, 0)

if INCLUDE_DIRT:
    dirt, dirtpot = loaddf.loadl(DIRT_FILES, include_syst=False)

    scale_pot(dirt, dirtpot, POT)
    
    if DETECTOR == "ICARUS":
        # .copy() so the column assignments below act on an owned frame, not a slice of one.
        dirt = dirt[ICARUS_dirtcut(dirt)].copy()

    dirt["crthit"] = False
    dirt["dirt"] = True

# breakdown_mode reads df.dirt, so define the flag even when no dirt sample is loaded.
df["dirt"] = False

In [None]:
if INCLUDE_DIRT:
    # Add dirt to the CV; it is also included (below) as a 100% systematic uncertainty.
    df = pd.concat([df, dirt])

    # Dirt was loaded with include_syst=False, so its universe weights are missing after the
    # concat; fill them with 1 so the systematic universes leave dirt events unchanged.
    univ_columns = [col for col in df.columns if "univ" in col]
    for col in univ_columns:
        df[col] = df[col].fillna(1)

    # Also add dirt to every detector-variation sample, for consistency (list updated in place).
    detvars[:] = [pd.concat([detvar, dirt]) for detvar in detvars]

In [None]:
# ADD IN PID VARIATIONS
def v_variation(df, setvars):
    """Return a copy of ``df`` without "univ" columns, with each ``new`` column replaced by ``old``.

    Parameters
    ----------
    df : pandas.DataFrame
        Nominal sample (not modified).
    setvars : iterable of (str, str)
        ``(new, old)`` pairs; column ``new`` of the copy is set to column ``old``.
    """
    df = df[[c for c in df.columns if "univ" not in c]].copy()
    for (new, old) in setvars:
        df[new] = df[old]
    return df


def _chi2_setvars(tag):
    """(nominal, varied) PID chi2 column pairs for the variation named by ``tag``."""
    # Targets are the nominal columns used by the plots/cuts (``*_of_prot_cand``); the
    # previous ``*_of_p_cand`` targets created unused columns, so the proton-candidate
    # variation never reached the selection.
    return [
        ("mu_chi2_of_mu_cand", "mu_chi2%s_of_mu_cand" % tag),
        ("mu_chi2_of_prot_cand", "mu_chi2%s_of_prot_cand" % tag),
        ("prot_chi2_of_mu_cand", "prot_chi2%s_of_mu_cand" % tag),
        ("prot_chi2_of_prot_cand", "prot_chi2%s_of_prot_cand" % tag),
    ]


def v_chi2smear(df):
    """Variation sample using the smeared (``chi2smear``) PID chi2 columns."""
    return v_variation(df, _chi2_setvars("smear"))


def v_chi2hi(df):
    """Variation sample using the ``chi2hi`` PID chi2 columns."""
    return v_variation(df, _chi2_setvars("hi"))

In [None]:
# Treat the two PID variations as additional detector-variation samples.
detvars.extend([v_chi2smear(df), v_chi2hi(df)])

DETVAR_NAMES.extend(["Smeared dE/dx", "Gain Hi"])

In [None]:
# Define systematic uncertainties
# Define systematic uncertainties: flux, cross section, detector variations, and dirt (if included).
syst_components = [
    loaddf.FluxSystematic(df),
    loaddf.XSecSystematic(df),
    syst.SystematicList([syst.SampleSystematic(d) for d in detvars]),
]
if INCLUDE_DIRT:
    syst_components.append(syst.SystSampleSystematic(dirt))
systematics = syst.SystematicList(syst_components)

In [None]:
# ON- and OFF-beam data: no truth information and no systematic weights; POT return ignored.
data_load_opts = dict(load_truth=False, include_syst=False)
ONdf, _ = loaddf.load(ONBEAM, **data_load_opts)
OFFdf, _ = loaddf.load(OFFBEAM, **data_load_opts)

In [None]:
# genie_mode values shown as separate numu CC categories; mode_labels[3:] names them in this order.
mode_list = [1, 10, 0]

mode_labels = ["Cosmic", "Dirt", 'Other $\\nu$', '$\\nu_\\mu$ CC RES', '$\\nu_\\mu$ CC MEC', '$\\nu_\\mu$ CC QE']
mode_colors = ["#95af8b", "#43140b", "#c89648", "#1e3f54", "#d54c28", "#315031"]

def nuFV(df):
    """Mask of rows with a true vertex (true_vtx_x not NaN) passing ``gc._fv_cut`` for DETECTOR.

    ``gc._fv_cut`` is called with its four numeric arguments all 0.
    """
    vtx = pd.DataFrame({
        "x": df.true_vtx_x,
        "y": df.true_vtx_y,
        "z": df.true_vtx_z,
    })
    return ~np.isnan(df.true_vtx_x) & gc._fv_cut(vtx, DETECTOR, 0, 0, 0, 0)

def breakdown_mode(var, df):
    """Split ``var`` into the categories of ``mode_labels`` (same order).

    Parameters
    ----------
    var : pandas.Series
        Values to split; must share ``df``'s index.
    df : pandas.DataFrame
        Rows matching ``var``, with genie_mode, true_nu_pdg, true_isnc and (normally) dirt.

    Returns
    -------
    list of pandas.Series
        [no genie_mode (cosmic), dirt, other nu, then one entry per ``mode_list`` value
        restricted to numu/numubar CC outside the dirt sample].
    """
    is_cosmic = np.isnan(df.genie_mode)
    numu_cc = (np.abs(df.true_nu_pdg) == 14) & (df.true_isnc == False)
    # "In volume" means "not from the dirt sample". The true-vertex check is only a fallback
    # for frames without a dirt flag (previously it was computed and then discarded).
    fid = ~df.dirt if "dirt" in df.columns else nuFV(df)
    in_mode_list = df.genie_mode.isin(mode_list)

    ret = [
        var[is_cosmic],
        var[~fid & ~is_cosmic],
        var[(~in_mode_list | ~numu_cc) & fid & ~is_cosmic],
    ]
    ret += [var[(df.genie_mode == i) & numu_cc & fid] for i in mode_list]
    return ret

In [None]:
pdg_list = [2212, 13, 211]
pdg_labels = ["$p$", "$\\mu$", "$\\pi^\\pm$", "Other"]
pdg_colors = ["#315031", "#d54c28", "#1e3f54", "#c89648"]

def breakdown_pdg(var, df, particle="p"):
    """Split ``var`` by the true PDG code of the ``particle`` candidate.

    Parameters
    ----------
    var : pandas.Series
        Values to split; must share ``df``'s index.
    df : pandas.DataFrame
        Must contain the column ``"<particle>_true_pdg"``.
    particle : str
        Column prefix of the candidate whose truth match is used.

    Returns
    -------
    list of pandas.Series
        One entry per ``pdg_list`` value (matched on |PDG|), then "Other" (no match or NaN),
        in the order of ``pdg_labels``.
    """
    # Take |PDG| before comparing: the old code applied abs to the boolean comparison,
    # so antiparticles (e.g. mu+, pi-) fell into "Other".
    abs_pdg = np.abs(df["%s_true_pdg" % particle])
    ret = [var[abs_pdg == i] for i in pdg_list]
    ret.append(var[~abs_pdg.isin(pdg_list)])
    return ret

In [None]:
FONTSIZE = 14
HAWKS_COLORS = ["#315031", "#d54c28", "#1e3f54", "#c89648", "#43140b", "#95af8b"]

def add_style(ax, xlabel, title="", det="ICARUS"):
    """Apply the common axis style to ``ax``.

    Inward ticks on all four sides, 1.5-wide spines, bold axis labels (the y label is
    always 'Area Normalized'), a title of bold ``det`` followed by ``title``, and a legend.
    """
    tick_style = dict(axis='both', which='both', direction='in', length=6, width=1.5,
                      labelsize=FONTSIZE, top=True, right=True)
    label_style = dict(fontsize=FONTSIZE, fontweight='bold')

    ax.tick_params(**tick_style)
    for edge in ax.spines.values():
        edge.set_linewidth(1.5)
    ax.set_xlabel(xlabel, **label_style)
    ax.set_ylabel('Area Normalized', **label_style)
    ax.set_title("$\\bf{{{}}}$  {}".format(det, title), fontsize=FONTSIZE + 2)
    ax.legend(fontsize=FONTSIZE)


In [None]:
def f_chi2(NMC, Ndata, cov):
    """Chi-square between MC and data over the bins with a positive MC prediction.

    Parameters
    ----------
    NMC, Ndata : numpy.ndarray
        Per-bin MC prediction and data.
    cov : numpy.ndarray
        Square covariance matrix matching the bins.

    Returns
    -------
    tuple
        ``(chi2, nbins)`` where ``nbins`` is the number of bins used; ``chi2`` is -1 when the
        reduced covariance is singular.
    """
    # Only bins with a positive MC prediction enter the chi2.
    which_bin = NMC > 0

    NMC = NMC[which_bin]
    Ndata = Ndata[which_bin]
    cov = cov[np.ix_(which_bin, which_bin)]

    delta = NMC - Ndata
    try:
        # Solve cov @ x = delta instead of forming cov^-1 explicitly (better conditioned).
        chi2 = delta @ np.linalg.solve(cov, delta)
    except np.linalg.LinAlgError:
        return -1, which_bin.sum()

    return chi2, which_bin.sum()

In [None]:
def make_plot_data(var, bins, cut, mc_weight, breakdown, areanorm, breakdown_labels, breakdown_colors, xlabel, title,
                   det="ICARUS", fillna=np.nan, nsystuniv=50):
    """Histogram MC (per breakdown category) and OFF-subtracted data for one variable.

    Reads the notebook globals ``df``, ``ONdf``, ``OFFdf``, ``OFF_w``, ``POT`` and ``systematics``.

    Parameters
    ----------
    var : str
        Column to histogram (NaN replaced by ``fillna``).
    bins : array-like or int
        Bin edges (or count) passed to ``np.histogram``.
    cut : str
        Boolean column selecting the rows to use in every sample.
    mc_weight : str
        MC weight column.
    breakdown : callable
        ``breakdown(values, rows)`` -> list of per-category selections (e.g. ``breakdown_mode``).
    areanorm : bool
        If True, MC and data are each normalised to unit area and the covariance is shape-only.
    nsystuniv : int
        Unused; kept for interface compatibility.

    Returns
    -------
    dict
        Histograms, errors, covariances, chi2/ndof and labels, as consumed by ``ratio_plot``.
    """
    # MC selection, shared by the total and the per-category histograms.
    mc_sel = df[df[cut]]
    mc_var = mc_sel[var].fillna(fillna)
    mc_w = mc_sel[mc_weight]

    # Total first: it fixes the bin edges (when ``bins`` is a count) for every other histogram,
    # instead of taking them from whichever breakdown category happens to come first.
    NMC, bins = np.histogram(mc_var, bins=bins, weights=mc_w)
    NMC_abs = NMC  # un-normalised total, needed by the systematics covariance
    NMC_breakdown = np.array([
        np.histogram(pvar, bins=bins, weights=w)[0]
        for pvar, w in zip(breakdown(mc_var, mc_sel), breakdown(mc_w, mc_sel))
    ])

    widths = np.diff(bins)
    if areanorm:
        norm = np.sum(NMC * widths)
        if norm > 1e-5:
            NMC = NMC / norm
            NMC_breakdown = NMC_breakdown / norm

    # Data: ON beam minus OFF beam scaled by OFF_w, with Poisson errors on both.
    NON, _ = np.histogram(ONdf.loc[ONdf[cut], var].fillna(fillna), bins=bins)
    NOff, _ = np.histogram(OFFdf.loc[OFFdf[cut], var].fillna(fillna), bins=bins)

    N = NON - NOff*OFF_w
    Nerr = np.sqrt(NON + NOff*OFF_w**2)
    if areanorm:
        norm = np.sum(N * widths)
        if norm > 1e-5:
            N = N / norm
            Nerr = Nerr / norm

    cov = systematics.cov(var, cut, bins, NMC_abs, shapeonly=areanorm)

    cov_w_stat = cov + np.diag(Nerr**2)  # add data statistical uncertainty
    chi2, ndof = f_chi2(NMC, N, cov_w_stat)

    return {
        "det": det,
        "title": title,
        "xlabel": xlabel,
        "bins": bins,
        "areanorm": areanorm,
        "breakdown_labels": breakdown_labels,
        "breakdown_colors": breakdown_colors,
        "NMC_breakdown": NMC_breakdown,
        "NMC_total": NMC,
        "NData": N,
        "NDataErr": Nerr,
        "cov": cov,
        "cov_w_stat": cov_w_stat,
        "chi2": chi2,
        "ndof": ndof,
        "POT": POT
    }


In [None]:
def ratio_plot(plt, plotdata):
    """Draw stacked MC vs. OFF-subtracted data, with a data/MC ratio panel below.

    Parameters
    ----------
    plt : module
        ``matplotlib.pyplot`` (passed in by the caller).
    plotdata : dict
        Output of ``make_plot_data``.

    Returns
    -------
    matplotlib.figure.Figure
        The new figure, which is also left as pyplot's current figure.
    """
    fig, (ax0, ax1) = plt.subplots(2, 1, height_ratios=[3, 1], sharex=True)
    bins = plotdata["bins"]
    centers = (bins[:-1] + bins[1:])/2
    NMC = plotdata["NMC_total"]

    def _to_edges(y):
        # fill_between(step="post") holds y[i] on [x[i], x[i+1]); drawing against all bin
        # edges with the last value repeated makes the final bin's band visible.
        return np.append(y, y[-1])

    def _over_mc(num):
        # Divide by the MC prediction, giving NaN (not inf / a warning) where it is zero.
        return np.divide(num, NMC, out=np.full(len(NMC), np.nan), where=(NMC != 0))

    # Stacked MC breakdown: one entry per bin centre per category, weighted by its content.
    NMC_breakdown = plotdata["NMC_breakdown"]
    fill = np.array([centers for _ in range(NMC_breakdown.shape[0])]).T
    ax0.hist(fill, bins=bins, stacked=True, label=plotdata["breakdown_labels"],
                    color=plotdata["breakdown_colors"], weights=NMC_breakdown.T)

    NData = plotdata["NData"]
    NDataErr = plotdata["NDataErr"]
    line = ax0.errorbar(centers, NData, NDataErr, color="black", linestyle="none", marker=".")

    # Systematic band on the MC total, from the diagonal of the covariance.
    err = np.sqrt(np.diag(plotdata["cov"]))
    ax0.fill_between(bins, _to_edges(NMC+err), _to_edges(NMC-err), facecolor="none", hatch="//", edgecolor="gray", linewidth=0.0, step="post")

    rel_err = _over_mc(err)
    ax1.errorbar(centers, _over_mc(NData), _over_mc(NDataErr), color="black", linestyle="none", marker=".")
    ax1.set_ylim([0.5, 1.5])
    ax1.axhline(1, color="red", linestyle="--")
    ax1.fill_between(bins, _to_edges(1+rel_err), _to_edges(1-rel_err), facecolor="none", hatch="//", edgecolor="gray", linewidth=0.0, step="post")

    ax0.tick_params(axis='both', which='both', direction='in', length=6, width=1.5, labelsize=FONTSIZE, top=True, right=True)
    ax1.tick_params(axis='both', which='both', direction='in', length=6, width=1.5, labelsize=FONTSIZE, top=True, right=True)
    for spine in ax0.spines.values():
        spine.set_linewidth(1.5)
    ax1.set_xlabel(plotdata["xlabel"], fontsize=FONTSIZE, fontweight='bold')
    
    if plotdata["areanorm"]:
        # Plain label: the old code applied '%' to a string with no placeholders, which only
        # worked because numpy scalars pass Python's mapping check.
        ax0.set_ylabel('Area Normalized', fontsize=FONTSIZE, fontweight='bold')
    else:
        ax0.set_ylabel('Events / %.1f$\\times 10^{19}$ POT' % (plotdata["POT"]/1e19), fontsize=FONTSIZE, fontweight='bold')

    det = plotdata["det"]
    title = plotdata["title"]
    ax0.set_title(f"$\\bf{{{det}}}$ {title}", fontsize=FONTSIZE+2)
    # Separate data legend; re-added below so the MC legend does not replace it.
    ld = ax0.legend([line], ["Data\n(ON Beam - OFF)"], frameon=False, loc="upper left", fontsize=10)

    # 20% head-room above the tallest entry for the legends.
    ax0_l0, ax0_hi = ax0.get_ylim()
    ax0.set_ylim([ax0_l0, ax0_hi*1.2])
    
    ax0.legend(fontsize=12)
    ax0.add_artist(ld)

    # One degree of freedom is subtracted for area-normalised (shape-only) comparisons.
    chi2_str = "$\\chi^2_\\mathrm{shape}$" if plotdata["areanorm"] else "$\\chi^2$"
    ax0.text(0.5, 0.98, "%s: %.1f / %i" % (chi2_str, plotdata["chi2"], plotdata["ndof"] - int(plotdata["areanorm"])),
            verticalalignment="top", horizontalalignment="center", fontsize=FONTSIZE-2, transform=ax0.transAxes)
    
    fig.subplots_adjust(hspace=0.05)
    return fig
    

In [None]:
def FV(df):
    """Containment mask: gc.slcfv_cut & gc.mufv_cut & gc.pfv_cut for the detector named by the
    first word of DETECTOR; SPINE configurations also require ``df.is_time_contained``.
    """
    det = DETECTOR.split(" ")[0]
    contained = gc.slcfv_cut(df, det) & gc.mufv_cut(df, det) & gc.pfv_cut(df, det)
    if "SPINE" not in DETECTOR:
        return contained
    return contained & (df.is_time_contained)
    
def simple_cosmic_rej(df, min_crlongtrkdiry=-0.3):
    """FV selection plus ``crlongtrkdiry > min_crlongtrkdiry`` (default -0.3).

    The threshold is a keyword parameter with the previous hard-coded value as default, so
    existing ``simple_cosmic_rej(df)`` calls are unchanged. (An unused local was removed.)
    """
    return FV(df) & (df.crlongtrkdiry > min_crlongtrkdiry)

def crtveto(df):
    """FV selection restricted to rows whose ``crthit`` flag is False."""
    in_fv = FV(df)
    return in_fv & ~df.crthit

def twoprong_cut(df):
    """FV selection requiring both ``other_shw_length`` and ``other_trk_length`` to be NaN
    (i.e. no additional shower or track beyond the two candidates)."""
    in_fv = FV(df)
    no_other_shower = np.isnan(df.other_shw_length)
    no_other_track = np.isnan(df.other_trk_length)
    return in_fv & no_other_shower & no_other_track

def pid_cut(df):
    """Two-prong selection plus particle-ID requirements.

    Non-SPINE: ``gc.pid_cut_df``. SPINE: ``prot_chi2_of_prot_cand > 0.6`` and
    ``mu_chi2_of_mu_cand > 0.6``.
    """
    if "SPINE" in DETECTOR:
        return twoprong_cut(df) & (df.prot_chi2_of_prot_cand > 0.6) & (df.mu_chi2_of_mu_cand > 0.6)
    return twoprong_cut(df) & gc.pid_cut_df(df)


In [None]:
# if DETECTOR == "ICARUS" or DETECTOR == "SBND":
#     cuts = [
#         FV,
#         crtveto,
#         simple_cosmic_rej,
#         twoprong_cut,
#         pid_cut,
#     ]
    
#     cutnames = [
#         "Contained",
#         "CRT Veto",
#         "Simple Cos. Rej.",
#         "Two Prong Cut",
#         "PID Cut",
#     ]
# elif DETECTOR == "SBND SPINE":
#     cuts = [
#         FV,
#         twoprong_cut,
#         pid_cut,
#     ]
    
#     cutnames = [
#         "Contained",
#         "Two Prong Cut",
#         "PID Cut",
#     ]

# plotvars = [
#     "crlongtrkdiry",
#     "nu_score",
#     "mu_chi2_of_prot_cand",
#     "prot_chi2_of_prot_cand",
#     "mu_chi2_of_mu_cand",
#     "prot_chi2_of_mu_cand",  
# ]

# if "SPINE" not in DETECTOR:
#     bins = [
#         np.linspace(-1,1,21),
#         np.linspace(0, 1, 21),
#         np.linspace(0, 80, 21),
#         np.linspace(0, 300, 21),
#         np.linspace(0, 80, 21),
#         np.linspace(0, 300, 21),
#     ]
# else:
#     bins = [
#         np.linspace(-1,1,21),
#         np.linspace(0, 1, 21),
#         np.linspace(0, 1, 21),
#         np.linspace(0, 1, 21),
#         np.linspace(0, 1, 21),
#         np.linspace(0, 1, 21),
#     ]

# labels = [
#     "CRLongTrkDirY",
#     "$\\nu$ Score",
#     "Proton Cand. $\\mu$-like PID",
#     "Proton Cand. $p$-like PID",
#     "Muon Cand. $\\mu$-like PID",
#     "Muon Cand. $p$-like PID",
# ]

# for c, cname in zip(cuts, cutnames):
#     df[cname] = c(df)
#     ONdf[cname] = c(ONdf)
#     OFFdf[cname] = c(OFFdf)

#     if INCLUDE_DIRT:
#         dirt[cname] = c(dirt)

#     for i in range(len(detvars)):
#         detvars[i][cname] = c(detvars[i])

In [None]:
# def inner(dat):
#     (v, b, l, cut, cutname) = dat
#     return v, make_plot_data(v, b, cutname, "glob_scale", breakdown_mode, False, mode_labels, 
#                                   mode_colors, l, cutname, fillna=-1, det=DETECTOR)

In [None]:
# all_plotdata_normed = {}


# with Pool(10) as p:
#     for cut, cutname in zip(cuts, cutnames):
#         all_plotdata_normed[cutname] = {}
#         inputs = [(v, b, l, cut, cutname) for (v, b, l) in zip(plotvars, bins, labels)]
#         for v, plotdata in tqdm(p.imap_unordered(inner, inputs), total=len(inputs)):    
#             all_plotdata_normed[cutname][v] = plotdata


In [None]:
# ifig = 0
# for cname in cutnames:
#     for v in plotvars:
#         plt.figure(ifig)
#         ratio_plot(plt, all_plotdata_normed[cname][v])
        
#         if DOSAVE:
#             savename_pdf = PLOTDIR + "/pdf/%s_%s_%s_potnorm.pdf" % (all_plotdata_normed[cname][v]["det"], cname.replace(" ", "").replace(".", "").lower(), v)
#             savename_png = PLOTDIR + "/png/%s_%s_%s_potnorm.png" % (all_plotdata_normed[cname][v]["det"], cname.replace(" ", "").replace(".", "").lower(), v)
#             plt.savefig(savename_pdf, bbox_inches="tight")
#             plt.savefig(savename_png, bbox_inches="tight")
#             plt.close()
#         else:
#             ifig += 1

In [None]:
# Selection stages (each adds a boolean column named after the stage) and variables to plot.
if DETECTOR == "ICARUS" or DETECTOR == "SBND":
    cuts = [
        FV,
        crtveto,
        simple_cosmic_rej,
        twoprong_cut,
        pid_cut,
    ]
    
    cutnames = [
        "Contained",
        "CRT Veto",
        "Simple Cos. Rej.",
        "Two Prong Cut",
        "PID Cut",
    ]
elif DETECTOR == "SBND SPINE":
    cuts = [
        FV,
        twoprong_cut,
        pid_cut,
    ]
    
    cutnames = [
        "Contained",
        "Two Prong Cut",
        "PID Cut",
    ]
else:
    # Fail clearly instead of silently reusing cut lists left over from an earlier run.
    raise ValueError(f"No cut list defined for DETECTOR={DETECTOR!r}")

plotvars = [
    "crlongtrkdiry",
    "nu_score",
    "other_trk_length",
    "other_shw_length",
    "mu_chi2_of_prot_cand",
    "prot_chi2_of_prot_cand",
    "mu_chi2_of_mu_cand",
    "prot_chi2_of_mu_cand",  
    "del_p",
    "p_len",
]

if "SPINE" not in DETECTOR:
    bins = [
        np.linspace(-1,1,21),
        np.linspace(0, 1, 21),
        np.array([-5] + list(np.linspace(0, 20, 5))),
        np.array([-10] + list(np.linspace(0, 100, 11))),
        np.linspace(0, 80, 21),
        np.linspace(0, 300, 21),
        np.linspace(0, 80, 21),
        np.linspace(0, 300, 21),
        np.linspace(0, 1.5, 16),
        np.linspace(0, 50, 11)
    ]
else:
    bins = [
        np.linspace(-1,1,21),
        np.linspace(0, 1, 21),
        np.array([-5] + list(np.linspace(0, 20, 5))),
        np.array([-10] + list(np.linspace(0, 100, 11))),
        np.linspace(0, 1, 21),
        np.linspace(0, 1, 21),
        np.linspace(0, 1, 21),
        np.linspace(0, 1, 21),
        # np.linspace(0, 1.5, 16),
        np.linspace(0, 1.0, 11),
        np.linspace(0, 50, 11)
    ]

labels = [
    "CRLongTrkDirY",
    "$\\nu$ Score",
    "Maximum Third Track Length [cm]",
    "Maximum Shower Length [cm]",
    "Proton Cand. $\\mu$-like PID",
    "Proton Cand. $p$-like PID",
    "Muon Cand. $\\mu$-like PID",
    "Muon Cand. $p$-like PID",
    "Transverse Momentum [GeV]",
    "Proton Cand. Length [cm]",
]

# These lists are zipped together later; zip would silently drop any unmatched entry.
assert len(plotvars) == len(bins) == len(labels), "plotvars, bins and labels must have equal length"
assert len(cuts) == len(cutnames), "cuts and cutnames must have equal length"

for c, cname in zip(cuts, cutnames):
    df[cname] = c(df)
    ONdf[cname] = c(ONdf)
    OFFdf[cname] = c(OFFdf)

    if INCLUDE_DIRT:
        dirt[cname] = c(dirt)

    for i in range(len(detvars)):
        detvars[i][cname] = c(detvars[i])

In [None]:
def inner(dat):
    """Pool worker: area-normalised plot data with the interaction-mode breakdown.

    ``dat`` is ``(var, bins, label, cut, cutname)``. The cut function is not used: the
    selection is read from the ``cutname`` column. Returns ``(var, plotdata)``.
    """
    v, b, l, _cut, cutname = dat
    plotdata = make_plot_data(v, b, cutname, "glob_scale", breakdown_mode, True, mode_labels,
                              mode_colors, l, cutname, fillna=-1, det=DETECTOR)
    return v, plotdata

In [None]:
all_plotdata = {}

# Build plot data for every (cut, variable) pair using 10 worker processes.
with Pool(10) as pool:
    for cut, cutname in zip(cuts, cutnames):
        all_plotdata[cutname] = per_var = {}
        jobs = [(v, b, l, cut, cutname) for v, b, l in zip(plotvars, bins, labels)]
        for v, plotdata in tqdm(pool.imap_unordered(inner, jobs), total=len(jobs)):
            per_var[v] = plotdata


In [None]:
# Draw (and optionally save) one ratio plot per (cut, variable).
for cname in cutnames:
    for v in plotvars:
        plotdata = all_plotdata[cname][v]
        # ratio_plot opens its own figure via plt.subplots; a preceding plt.figure() call
        # only left an extra, empty figure open for every plot.
        ratio_plot(plt, plotdata)

        if DOSAVE:
            stem = "%s_%s_%s" % (plotdata["det"].replace(" ", "-"), cname.replace(" ", "").replace(".", "").lower(), v)
            plt.savefig(PLOTDIR + "/pdf/" + stem + ".pdf", bbox_inches="tight")
            plt.savefig(PLOTDIR + "/png/" + stem + ".png", bbox_inches="tight")
            plt.close()

In [None]:
# Reduced selection stages (no CRT veto) and the PID variables for the true-PDG breakdown plots.
cut_table = [
    (FV, "Contained"),
    (simple_cosmic_rej, "Simple Cos. Rej."),
    (twoprong_cut, "Two Prong Cut"),
    (pid_cut, "PID Cut"),
]
cuts = [fn for fn, _ in cut_table]
cutnames = [name for _, name in cut_table]

plot_table = [
    ("mu_chi2_of_prot_cand", np.linspace(0, 80, 21), "Proton Cand. $\\chi^2_\\mu$"),
    ("prot_chi2_of_prot_cand", np.linspace(0, 300, 21), "Proton Cand. $\\chi^2_p$"),
    ("mu_chi2_of_mu_cand", np.linspace(0, 80, 21), "Muon Cand. $\\chi^2_\\mu$"),
    ("prot_chi2_of_mu_cand", np.linspace(0, 300, 21), "Muon Cand. $\\chi^2_p$"),
]
plotvars = [name for name, _, _ in plot_table]
bins = [edges for _, edges, _ in plot_table]
labels = [label for _, _, label in plot_table]

# (Re)compute each stage's boolean column on every sample that gets histogrammed.
for c, cname in zip(cuts, cutnames):
    frames = [df, ONdf, OFFdf]
    if INCLUDE_DIRT:
        frames.append(dirt)
    frames.extend(detvars)
    for frame in frames:
        frame[cname] = c(frame)

In [None]:
def inner(dat):
    """Pool worker: area-normalised plot data with the true-PDG breakdown.

    Redefines the earlier mode-breakdown ``inner``; the pool cell below uses this version.
    ``dat`` is ``(var, bins, label, cut, cutname)``; the selection is read from the
    ``cutname`` column. Returns ``(var, plotdata)``.
    """
    v, b, l, _cut, cutname = dat
    plotdata = make_plot_data(v, b, cutname, "glob_scale", breakdown_pdg, True, pdg_labels,
                              pdg_colors, l, cutname, fillna=-1, det=DETECTOR)
    return v, plotdata

In [None]:
all_plotdata_pdg = {}

# Build the true-PDG-breakdown plot data for every (cut, variable) pair using 4 worker processes.
with Pool(4) as pool:
    for cut, cutname in zip(cuts, cutnames):
        all_plotdata_pdg[cutname] = per_var = {}
        jobs = [(v, b, l, cut, cutname) for v, b, l in zip(plotvars, bins, labels)]
        for v, plotdata in tqdm(pool.imap_unordered(inner, jobs), total=len(jobs)):
            per_var[v] = plotdata


In [None]:
# Draw (and optionally save) one true-PDG-breakdown ratio plot per (cut, variable).
for cname in cutnames:
    for v in plotvars:
        plotdata = all_plotdata_pdg[cname][v]
        # ratio_plot opens its own figure via plt.subplots; a preceding plt.figure() call
        # only left an extra, empty figure open for every plot.
        ratio_plot(plt, plotdata)

        if DOSAVE:
            # Same naming scheme as the mode-breakdown plots: spaces in the detector name -> "-".
            stem = "%s_%s_%s_bkdwnpdg" % (plotdata["det"].replace(" ", "-"), cname.replace(" ", "").replace(".", "").lower(), v)
            plt.savefig(PLOTDIR + "/pdf/" + stem + ".pdf", bbox_inches="tight")
            plt.savefig(PLOTDIR + "/png/" + stem + ".png", bbox_inches="tight")
            plt.close()