In [None]:
import os
import glob
import numpy as np
import awkward

In [None]:
import matplotlib
%matplotlib notebook
import matplotlib.pyplot as plt

In [None]:
# import matplotlib
# %matplotlib inline
# import matplotlib.pyplot as plt

In [None]:
from sklearn.metrics import roc_auc_score, roc_curve, auc

In [None]:
# Preselection efficiencies, taken from the training outputs -- update if needed.
# Key 0 is the background entry (used as `bkg_eff`); every positive key is a
# signal point that the plots label as '<key> MeV'.
presel_eff = {
    0: 0.0531643892564612,
    1: 0.9818927295320153,
    10: 0.9940605634291205,
    100: 0.9957324620030694,
    1000: 0.9973451587037725,
}

In [None]:
# Input files produced by the evaluation step.
sig_filelist = ['OUTPUTDIR/test_sb_OUTPUT.awkd']
# glob returns matches in arbitrary (filesystem-dependent) order; sort them so
# the concatenated background arrays are built in the same order on every run.
bkg_filelist = sorted(glob.glob('OUTPUTDIR/test_b/*.awkd'))
# bkg_filelist = []  # uncomment this if you don't have the extra bkg samples

In [None]:
# Read every input file into memory, one loaded table per file.
sig_tables = list(map(awkward.load, sig_filelist))
bkg_tables = list(map(awkward.load, bkg_filelist))

In [None]:
# Branches pulled out of every loaded table.  Commented-out names are further
# options that are currently not loaded; uncomment a line to include it.
load_branches = [
    'EcalVetoGabriel_recon.' + leaf for leaf in (
        'nReadoutHits_',
        'deepestLayerHit_',
        'summedDet_',
        'summedTightIso_',
        'maxCellDep_',
        'showerRMS_',
        'xStd_',
        'yStd_',
        'avgLayerHit_',
        'stdLayerHit_',
        'ecalBackEnergy_',
        # 'electronContainmentEnergy_',
        # 'photonContainmentEnergy_',
        # 'outsideContainmentEnergy_',
        # 'outsideContainmentNHits_',
        # 'outsideContainmentXStd_',
        # 'outsideContainmentYStd_',
        'discValue_',
        # 'recoilPx_',
        # 'recoilPy_',
        # 'recoilPz_',
        'recoilX_',
        'recoilY_',
        # 'ecalLayerEdepReadout_',
    )
] + [
    # 'ecalDigis_recon.id_',
    # 'ecalDigis_recon.energy_',
    'ParticleNet_extra_label',
    'ParticleNet_disc',
    'TargetSPRecoilE_pt',  # use this for plotting: this is the recoil electron pT at TargetSP
]

In [None]:
# Merge each requested branch across all signal and background tables into one
# array per branch, stored in `a` under the branch name.
a = {}
for k in load_branches:
    # A table that lacks the branch contributes zeros shaped like its
    # 'ParticleNet_disc' column, so the per-table lengths stay aligned.
    arrs = [
        tab[k] if k in tab else np.zeros_like(tab['ParticleNet_disc'])
        for tab in sig_tables + bkg_tables
    ]
    merged = awkward.concatenate(arrs)
    if k.startswith('EcalVeto'):
        # presumably converts the jagged array to a rectangular one -- confirm
        # against the awkward version in use
        merged = merged.regular()
        # Drop a trailing length-1 axis so these branches end up 1-D.
        if merged.ndim == 2 and merged.shape[1] == 1:
            merged = merged[:, 0]
    a[k] = merged

In [None]:
# Sanity check: list every loaded branch together with the shape of its array.
for name, values in a.items():
    print(name, values.shape)

In [None]:
def to_categorical(y, num_classes=None):
    """One-hot encode a vector of integer class labels.

    # Arguments
        y: array-like of class indices; it is cast to int and flattened
            before encoding.
        num_classes: number of columns in the result. When omitted (or
            falsy) it is taken as ``max(y) + 1``.
    # Returns
        An int array of shape ``(len(y), num_classes)`` with a single 1 per
        row, in the column given by that row's label.
    """
    class_ids = np.array(y, dtype='int').ravel()
    width = num_classes if num_classes else np.max(class_ids) + 1
    n_rows = class_ids.shape[0]
    one_hot = np.zeros((n_rows, width), dtype='int')
    # Fancy indexing sets element (row, label) for every row in one step.
    row_ids = np.arange(n_rows)
    one_hot[row_ids, class_ids] = 1
    return one_hot


def plotROC(y_preds, y_truth, sample_weight=None, output=None, labels=['signal'], sig_eff=1, bkg_eff=1, **kwargs):
    """Draw one ROC curve per discriminator on a new figure.

    # Arguments
        y_preds: sequence of score arrays, one per discriminator.
        y_truth: truth labels passed to ``roc_curve`` for every discriminator.
        sample_weight: optional per-event weights passed to ``roc_curve``.
        output: if truthy, path the figure is saved to.
        labels: legend names, paired with ``y_preds`` by position (``zip``
            stops at the shorter of the two, so extra scores are not drawn).
        sig_eff: factor the true positive rate is multiplied by.
        bkg_eff: factor the false positive rate is multiplied by.
        **kwargs: ``xlim`` / ``ylim`` (default ``[0, 1]``) and the boolean
            flags ``logx`` / ``logy``.
    # Returns
        dict mapping each label to the list returned by ``get_signal_effs``
        for the scaled curve.
    """
    from sklearn.metrics import auc, roc_curve

    outputs = {}

    plt.figure()

    for label, pred in zip(labels, y_preds):
        fpr, tpr, thresholds = roc_curve(y_truth, pred, sample_weight=sample_weight)
        # The AUC is computed on the unscaled curve, hence the "auc*" in the legend.
        roc_auc = auc(fpr, tpr)
        fpr *= bkg_eff
        tpr *= sig_eff

        legend = '%s (auc* = %0.6f)' % (label, roc_auc)
        print(legend)
        eff = get_signal_effs(fpr, tpr, thresholds)
        outputs[label] = eff
        print(eff)
        plt.plot(fpr, tpr, label=legend)

    plt.xlim(kwargs.get('xlim', [0, 1]))
    plt.ylim(kwargs.get('ylim', [0, 1]))
    # Raw strings: a plain '\e' is an invalid escape sequence in a normal literal.
    plt.xlabel(r'False positive rate ($\epsilon_{B}$)')
    plt.ylabel(r'True positive rate ($\epsilon_{S}$)')
    plt.legend(loc='best')
    if kwargs.get('logy', False):
        plt.yscale('log')
    if kwargs.get('logx', False):
        plt.xscale('log')
    plt.grid()
    if output:
        plt.savefig(output)
    return outputs

# Default false-positive-rate working points at which ROC curves are sampled.
mistags=[1e-3, 1e-4, 1e-5, 1e-6]
def get_signal_effs(fpr, tpr, thresholds, mistags=mistags):
    """Read off ROC working points at the requested false positive rates.

    # Arguments
        fpr, tpr, thresholds: equally long, index-aligned sequences
            describing a ROC curve.
        mistags: false positive rates to look up.
    # Returns
        A list with one ``(fpr, tpr, threshold)`` tuple per entry of
        ``mistags``, taken at the first index whose ``fpr`` is strictly
        greater than that entry.
    # Raises
        ValueError: if no point of the curve has an ``fpr`` above one of
            the requested values.
    """
    fpr_values = np.asarray(fpr)
    outputs = []
    for m in mistags:
        above = fpr_values > m
        if not above.any():
            raise ValueError('no ROC point has a false positive rate above %g' % m)
        # argmax of a boolean array is the position of its first True.
        idx = int(np.argmax(above))
        outputs.append((fpr[idx], tpr[idx], thresholds[idx]))
    return outputs


In [None]:
# Per-event process label: the plots below treat 0 as background and any
# positive value as a signal mass point.
test_extra_labels = a['ParticleNet_extra_label']
# Binary truth passed to the ROC curves: True for every signal event.
test_labels = test_extra_labels>0

In [None]:
# One ROC figure per signal mass point, comparing ParticleNet with the BDT.
roc_info = {}
for k, sig_presel in presel_eff.items():
    if k <= 0:
        continue  # key 0 is the background entry, not a signal mass point
    mass = '%d MeV' % k
    print(mass)
    # Weight 1 for background events and for this mass point, 0 for the other
    # signal masses, so each curve only uses the relevant events.
    in_sample = np.logical_or(test_extra_labels == 0, test_extra_labels == k)
    roc_info[k] = plotROC(
        [a['ParticleNet_disc'], a['EcalVetoGabriel_recon.discValue_']],
        test_labels,
        sample_weight=in_sample,
        sig_eff=sig_presel,
        bkg_eff=presel_eff[0],
        labels=['ParticleNet', 'BDT'],
        xlim=[1e-6, .01],
        ylim=[0, 1],
        logx=True,
    )

In [None]:
# Display the (fpr, tpr, threshold) working points returned by plotROC per signal mass.
roc_info

In [None]:
# Histogram binning per plotted variable: each key is a variable name accepted
# by plot_sig_vs_bkg / plot_trend and each value is passed to hist as `bins`.
# Only uncommented entries are plotted; uncomment a line to add that variable.
plot_bins = {
#     'EcalVetoGabriel_recon.nReadoutHits_':np.linspace(0, 50, 51),
#     'EcalVetoGabriel_recon.deepestLayerHit_':np.linspace(0, 35, 36),
#     'EcalVetoGabriel_recon.summedDet_':np.linspace(0, 2000, 41),
#     'EcalVetoGabriel_recon.summedTightIso_':np.linspace(0, 400, 41),
#     'EcalVetoGabriel_recon.maxCellDep_':np.linspace(0, 400, 41),
#     'EcalVetoGabriel_recon.showerRMS_':np.linspace(0, 250, 26),
#     'EcalVetoGabriel_recon.xStd_':np.linspace(0, 200, 41),
#     'EcalVetoGabriel_recon.yStd_':np.linspace(0, 200, 41),
#     'EcalVetoGabriel_recon.avgLayerHit_':np.linspace(0, 35, 36),
#     'EcalVetoGabriel_recon.stdLayerHit_':np.linspace(0, 20, 21),
#     'EcalVetoGabriel_recon.ecalBackEnergy_':np.linspace(0, 200, 41),
# #     'EcalVetoGabriel_recon.discValue_':np.linspace(0.9, 1, 51),
#     'EcalVetoGabriel_recon.recoilX_':np.linspace(-400, 400, 81),
#     'EcalVetoGabriel_recon.recoilY_':np.linspace(-400, 400, 81),

    'TargetSPRecoilE_pt':np.linspace(-50, 200, 51),

#     'ParticleNet_disc':np.linspace(0, 1, 51),
}

# plot_bins = {
#     'EcalVetoGabriel_recon.ecalLayerEdepReadout_:%d'%i:np.linspace(0, 1000, 51) for i in range(34)
# }


In [None]:
# Optional fixed palette for the overlaid histograms drawn by plot_trend,
# which passes this value to hist as `color=`.
colors = ['#636363', '#74c476', '#3182bd', '#f03b20', '#bd0026']
# Deliberate override: the palette above is currently disabled, so plot_trend
# passes color=None.  Delete this line to use the palette.
colors = None

In [None]:
def plot_sig_vs_bkg(var_name):
    """Overlay the normalized distribution of one variable for every process.

    # Arguments
        var_name: a key of ``plot_bins``. Either a branch name, or
            ``'<branch>:<index>'`` to plot column ``<index>`` of a 2-D branch.

    One step histogram is drawn per key of ``presel_eff`` (0 is labelled
    'BKG', other keys '<key> MeV') on a new figure with a log y axis.
    """
    if ':' in var_name:
        var, var_idx = var_name.split(':')
        var_idx = int(var_idx)
    else:
        var, var_idx = var_name, None

    bins = plot_bins[var_name]
    arrays = []
    labels = []
    for proc in presel_eff.keys():
        pos = test_extra_labels==proc
        arr = a[var][pos]
        # Compare against None: column 0 is a valid index but is falsy.
        if var_idx is not None:
            arr = arr[:, var_idx]
        if not isinstance(bins, int):
            # Pull out-of-range values onto the edge bins so they stay visible.
            arr = np.clip(arr, min(bins), max(bins))
        arrays.append(arr)
        labels.append('BKG' if proc==0 else '%d MeV'%proc)
    plt.figure()
    # `density` replaces the `normed` keyword, which matplotlib no longer accepts.
    plt.hist(arrays, bins=bins, label=labels, density=True, histtype='step', log=True)
    plt.legend()
    plt.xlabel(var_name)

In [None]:
# plot_sig_vs_bkg('TargetSPRecoilE_pt')

In [None]:
# One signal-vs-background overlay per configured variable.
for var in plot_bins.keys():
    print(var)
    plot_sig_vs_bkg(var)

In [None]:
def plot_trend(var_name, proc, eff_levels=None, mistag_levels=[1e-3, 1e-4, 1e-5, 1e-6]):
    """Show how a variable's shape changes under successively tighter cuts.

    Two panels are drawn for the events of one process: cuts on the
    ParticleNet score (left) and on the BDT score (right). Each panel overlays
    the inclusive distribution and one distribution per cut.

    # Arguments
        var_name: a key of ``plot_bins``. Either a branch name, or
            ``'<branch>:<index>'`` to plot column ``<index>`` of a 2-D branch.
        proc: process label to select (0 is titled 'BKG', others '<proc> MeV').
        eff_levels: efficiency levels of the selected process at which to
            cut. Only used when ``mistag_levels`` is None; if both are None
            a default set is chosen depending on ``proc``.
        mistag_levels: when not None, the cut values are the thresholds
            stored in ``roc_info[1]`` and these numbers are only used for the
            legend. NOTE(review): those thresholds were computed for the
            module-level ``mistags``, so passing other values here relabels
            the curves without moving the cuts -- confirm this is intended.
    """
    if ':' in var_name:
        var, var_idx = var_name.split(':')
        var_idx = int(var_idx)
    else:
        var, var_idx = var_name, None
    if eff_levels is None and mistag_levels is None:
        eff_levels = [1e-3, 1e-4, 1e-5, 1e-6] if proc==0 else [0.9, 0.7, 0.5, 0.2]
    bins = plot_bins[var_name]
    f, axes = plt.subplots(1, 2, figsize=(12, 5))
    f.suptitle('%d MeV'%proc if proc>0 else 'BKG', fontsize=16)
    pos0 = test_extra_labels==proc
    a_tmp = {k:a[k][pos0] for k in ('ParticleNet_disc', 'EcalVetoGabriel_recon.discValue_', var)}
    # Compare against None: column 0 is a valid index but is falsy.
    if var_idx is not None:
        a_tmp[var] = a_tmp[var][:, var_idx]
    for i, k in enumerate(['ParticleNet_disc', 'EcalVetoGabriel_recon.discValue_']):
        arrs = []
        labels = []
        scores = a_tmp[k]
        # The leading -99 is a cut placed below the scores, used for the 'inclusive' curve.
        if mistag_levels is None:
            # Turn each target efficiency (relative to the preselection
            # efficiency) into the score percentile that keeps that fraction.
            pcts = (1 - np.array(eff_levels)/presel_eff[proc]) * 100.
            print('effs=', pcts)
            thresholds = [-99] + list(np.percentile(scores, pcts))
        else:
            thresholds = [-99] + [info[-1] for info in roc_info[1]['ParticleNet' if i==0 else 'BDT']]
        print(k, thresholds)
        for idx, thres in enumerate(thresholds):
            pos = scores>thres
            selected = a_tmp[var][pos]
            if not isinstance(bins, int):
                # Pull out-of-range values onto the edge bins so they stay visible.
                selected = np.clip(selected, min(bins), max(bins))
            arrs.append(selected)
            # Raw strings: a plain '\e' is an invalid escape sequence in a normal literal.
            if idx==0:
                labels.append('inclusive')
            elif mistag_levels is None:
                labels.append(r'$\epsilon_{B}$=%.0e'%eff_levels[idx-1] if proc==0 else r'$\epsilon_{S}$=%.1f'%eff_levels[idx-1])
            else:
                labels.append(r'$\epsilon_{B}$=%.0e'%mistag_levels[idx-1])

        axes[i].hist(arrs, histtype='step', bins=bins, density=True, log=True, label=labels, color=colors, linewidth=2)
        axes[i].set_title('ParticleNet' if i==0 else 'BDT')
        axes[i].set_xlabel(var)
        axes[i].set_ylim(1e-6, 30)
        axes[i].legend(loc='best')

In [None]:
# Compare shapes with successively tighter selections.
# With mistag_levels=None the cuts are derived per process from fixed
# efficiency levels of that process: background efficiency for BKG, signal
# efficiency (the same levels for every mass point) for the signal samples.

for var in plot_bins:
    print(var)
    for proc in (0, 1, 10, 100, 1000):
        plot_trend(var, proc=proc, mistag_levels=None)

In [None]:
# compare shapes with successively tighter selections
# keep the BKG mistag rate the same for all signal points (default mistag_levels: cuts are the thresholds stored in roc_info)

# for var in plot_bins:
#     print(var)
#     plot_trend(var, proc=0)
#     plot_trend(var, proc=1)
#     plot_trend(var, proc=10)
#     plot_trend(var, proc=100)
#     plot_trend(var, proc=1000)