# Stacked histogram reading
## plotting via mplhep and related tools
## and extracting data / MC to get KL divergences and related statistics

***
# Used packages
For plotting, calculating stats
***

In [None]:
import uproot
import numpy as np
import matplotlib.pyplot as plt
import mplhep as hep
#plt.style.use([hep.style.ROOT, hep.style.fira, hep.style.firamath])
#plt.style.use([hep.style.ROOT, hep.style.firamath])

In [None]:
# Use mplhep's `ROOT` style sheet for every figure created after this cell.
plt.style.use([hep.style.ROOT])

In [None]:
from scipy.stats import entropy

In [None]:
import matplotlib as mpl
from cycler import cycler
from matplotlib.offsetbox import AnchoredText

In [None]:
import itertools
def flip(items, ncol):
    """Reorder *items* by taking every ``ncol``-th element for each start offset.

    The result first yields ``items[0::ncol]``, then ``items[1::ncol]`` and so
    on. It is used below to reorder legend handles/labels before passing them
    to a legend with ``ncol`` columns.

    Returns an iterator (an ``itertools.chain`` object), not a list.
    """
    strided_slices = [items[offset::ncol] for offset in range(ncol)]
    return itertools.chain.from_iterable(strided_slices)

In [None]:
# Keyword-argument bundles for the different plot layers.

# Error-bar markers for data points.
data_err_opts = dict(
    linestyle='none',
    marker='.',
    markersize=10.,
    color='k',
    elinewidth=1,
)

# Fill of the stacked histograms.
stack_fill_opts = dict(
    alpha=0.8,
    edgecolor=(0, 0, 0, .5),
)

# Hatched band (with legend label) for the statistical uncertainty of the stack.
stack_error_opts = dict(
    label='Stat. Unc.',
    hatch='///',
    facecolor='none',
    edgecolor=(0, 0, 0, .5),
    linewidth=0,
)

# Same hatched band, without a label.
hatch_style = dict(
    facecolor='none',
    edgecolor=(0, 0, 0, 0.5),
    linewidth=0,
    hatch='///',
)

# Semi-transparent grey band.
shaded_style = dict(
    facecolor=(0, 0, 0, 0.3),
    linewidth=0,
)

In [None]:
def kldiv(A, B):
    """Return the Kullback-Leibler divergence sum(A * log(A / B)) over all bins.

    Parameters
    ----------
    A, B : array-like
        Bin contents of equal length. No normalization is applied here.

    Returns
    -------
    float
        Sum of the per-bin terms. Bins whose term evaluates to NaN
        (e.g. ``A == 0``) are skipped; infinite terms are kept.
    """
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    # Empty bins produce 0/0 or log(0); the resulting NaN terms are dropped
    # right below, so the corresponding warnings carry no information.
    with np.errstate(divide='ignore', invalid='ignore'):
        terms = A * np.log(A / B)
    return np.sum(terms[~np.isnan(terms)])


def jsdiv(P, Q, normalize_first=False):
    """Compute the Jensen-Shannon divergence between two probability distributions.

    Parameters
    ----------
    P, Q : array-like
        Distributions of equal length. They are expected to sum to 1 unless
        ``normalize_first`` is set.
    normalize_first : bool, optional
        If True, divide ``P`` and ``Q`` by their respective sums before
        computing the divergence.

    Returns
    -------
    float
        ``0.5 * (kldiv(P, M) + kldiv(Q, M))`` with ``M = 0.5 * (P + Q)``.
    """
    P = np.array(P)
    Q = np.array(Q)

    if normalize_first:
        P = P / P.sum()
        Q = Q / Q.sum()

    # Mixture distribution against which both inputs are compared.
    M = 0.5 * (P + Q)

    return 0.5 * (kldiv(P, M) + kldiv(Q, M))

In [None]:
def chi2(P, Q, normalize_first=False):
    """Return sum((P - Q)**2 / Q) over all bins, skipping NaN terms.

    Parameters
    ----------
    P, Q : array-like
        Bin contents of equal length; ``Q`` acts as the expectation.
    normalize_first : bool, optional
        If True, both inputs are divided by their own sum beforehand.
    """
    observed = np.array(P)
    expected = np.array(Q)
    if normalize_first:
        observed = observed / observed.sum()
        expected = expected / expected.sum()

    per_bin = (observed - expected) ** 2 / expected
    # Bins with 0/0 evaluate to NaN and are left out of the sum.
    keep = ~np.isnan(per_bin)
    return np.sum(per_bin[keep])

***
# Relevant final implementation starts here (update paths when necessary)
***

### Specify paths for different epochs
Running StackPlotter for three epochs or for a single epoch produces slightly different directory names, but the file names inside follow the same pattern.  
Specify the paths here, up to the point where the two cases become identical again.

In [None]:
# Earlier StackPlotter outputs, kept for reference:
#ana_outdir_adv = ['/nfs/dust/cms/user/anstein/ctag_condor/systPlots_220809_2017_best_adversarial_eps0p01_/Plots_220809_best_adversarial_eps0p01_minimal_Custom']
# with 30 bins only, for easier comparison even for low stats
#ana_outdir_adv = ['/nfs/dust/cms/user/anstein/ctag_condor/systPlots_220915_2017_best_adversarial_eps0p01_/Plots_220915_best_adversarial_eps0p01_minimal_Custom_30bins']
ana_outdir_adv = [
    '/nfs/dust/cms/user/anstein/ctag_condor/systPlots_220917_2017_best_adversarial_0p01_/Plots_220917_best_adversarial_0p01_minimal_Custom_30bins_v4',
]

# Only the 'best' epoch is available; it points to the (last) directory listed above.
ana_outdir_adv_dict = {'best': ana_outdir_adv[-1]}

# Log files containing a printout of the normalization factor, one per selection, in the order DY / WC / TTSEMI.
# They are used to figure out the scaling applied to the total stack (and to use it for the error calculation with sumw(2)!).
# Each contains the line 'Will normalize total MC to data, with factor:' followed by the factor on the next line, example: 0.461432322352
_adv_log_prefix = '/nfs/dust/cms/user/anstein/ctag_condor/StackPlotter_logs/_deepjet_adversarial_NEW_STYLE/log-22825378'
condor_logs_adv = [f'{_adv_log_prefix}.{process_number}.out' for process_number in (68, 66, 67)]

In [None]:
# Earlier StackPlotter outputs, kept for reference:
#ana_outdir_basic = ['/nfs/dust/cms/user/anstein/ctag_condor/systPlots_220811_2017_best_nominal_/Plots_220811_best_nominal_minimal_Custom']
# with 30 bins only, for easier comparison even for low stats
#ana_outdir_basic = ['/nfs/dust/cms/user/anstein/ctag_condor/systPlots_220915_2017_best_nominal_/Plots_220915_best_nominal_minimal_Custom_30bins']
ana_outdir_basic = [
    '/nfs/dust/cms/user/anstein/ctag_condor/systPlots_220917_2017_best_nominal_/Plots_220917_best_nominal_minimal_Custom_30bins_v4',
]
# Only the 'best' epoch is available; it points to the (last) directory listed above.
ana_outdir_basic_dict = {'best': ana_outdir_basic[-1]}

# Log files with the normalization-factor printout, one per selection, in the order DY / WC / TTSEMI.
_basic_log_prefix = '/nfs/dust/cms/user/anstein/ctag_condor/StackPlotter_logs/_deepjet_adversarial_NEW_STYLE/log-22825368'
condor_logs_basic = [f'{_basic_log_prefix}.{process_number}.out' for process_number in (68, 66, 67)]

In [None]:
def get_normalization_factors(correct_logs):
    """Read the MC-to-data normalization factor of each selection from its log file.

    Parameters
    ----------
    correct_logs : sequence of str
        Paths of the log files, ordered as 'DY_m', 'Wc_m', 'TT_semim'.

    Returns
    -------
    dict
        Selection name -> factor (float). The factor is the content of the line
        following 'Will normalize total MC to data, with factor:'. Every factor
        is also printed.
    """
    marker_line = 'Will normalize total MC to data, with factor:\n'
    selection_names = ['DY_m', 'Wc_m', 'TT_semim']

    factors_dict = {}
    for position, selection_name in enumerate(selection_names):
        with open(correct_logs[position], 'r') as logfile:
            log_lines = logfile.readlines()
        # The factor is written on the line right after the marker.
        factor_text = log_lines[log_lines.index(marker_line) + 1]
        factor = float(factor_text.strip('\n'))
        print(factor)
        factors_dict[selection_name] = factor
    return factors_dict

In [None]:
# Normalization factors per selection ('DY_m', 'Wc_m', 'TT_semim'), parsed from the
# condor logs of the nominal and of the adversarial training, respectively.
nominal_normalization_dict = get_normalization_factors(condor_logs_basic)
adversarial_normalization_dict = get_normalization_factors(condor_logs_adv)

The epochs need to be matched to the letters used in the StackPlotter file names.

In [None]:
# Maps an epoch identifier to the infix that appears in the histogram file names
# (inserted between 'Custom' and the discriminator name when building the path).
# Two epoch triplets (1/50/150 and 5/10/100) share the letters A/B/C;
# 'best' has no infix at all.
epoch_letter_dict = {1: '_A_',
                    50: '_B_',
                    150:'_C_',
                     5: '_A_',
                    10: '_B_',
                    100:'_C_',
                    'best':''}

In [None]:
def plotStack(tagger, selection, disc, mergeNonSelectedFlav=False, mergeLepWithLight=False, epoch='best', ratio_err_style='as_in_file', unc_variant='as_in_file'):
    """Draw a stacked MC histogram with data overlay and a Data/MC ratio panel, and save it as PDF.

    Parameters
    ----------
    tagger : str
        'adversarial_eps0p01' reads histograms from ``ana_outdir_adv_dict``, any
        other value from ``ana_outdir_basic_dict``. 'nominal' selects the nominal
        x-axis label and normalization factors, any other value the adversarial ones.
    selection : str
        'DY_m', 'Wc_m' or 'TT_semim'.
    disc : str
        Discriminator name as used in the file name (e.g. 'CvsB').
    mergeNonSelectedFlav : bool, optional
        For 'TT_semim' / 'Wc_m', sum the flavours that are not targeted by the
        selection over all processes ('Charm' or 'Bottom', 'udsg', 'lep').
        Has no effect for 'DY_m'.
    mergeLepWithLight : bool, optional
        Add the 'lep' histograms to the corresponding 'uds' histograms.
    epoch : hashable, optional
        Key into ``epoch_letter_dict`` and the output-directory dictionaries.
    ratio_err_style : {'as_in_file', 'aisafetypaper', 'commissioning'}, optional
        How the error bars of the ratio panel are computed.
    unc_variant : {'as_in_file', 'sqrt_of_sumw', 'sqrt_of_sumw2', 'sum_of_sqrt'}, optional
        How the MC uncertainty band is computed.

    Raises
    ------
    ValueError
        For an unknown ``ratio_err_style`` / ``unc_variant``, or when
        ``mergeNonSelectedFlav`` is requested for an unsupported selection.

    Notes
    -----
    Prints cross-check information, modifies the global matplotlib rcParams
    ('axes.axisbelow', 'axes.prop_cycle') and writes ``mplHisto_*.pdf`` into the
    current working directory. Returns None.
    """
    # Fail early with a clear message instead of an UnboundLocalError half-way through plotting.
    known_ratio_err_styles = ('as_in_file', 'aisafetypaper', 'commissioning')
    known_unc_variants = ('as_in_file', 'sqrt_of_sumw', 'sqrt_of_sumw2', 'sum_of_sqrt')
    if ratio_err_style not in known_ratio_err_styles:
        raise ValueError(f'Unknown ratio_err_style {ratio_err_style!r}, expected one of {known_ratio_err_styles}')
    if unc_variant not in known_unc_variants:
        raise ValueError(f'Unknown unc_variant {unc_variant!r}, expected one of {known_unc_variants}')
    if mergeNonSelectedFlav and selection not in ('DY_m', 'TT_semim', 'Wc_m'):
        raise ValueError(f'mergeNonSelectedFlav is not supported for selection {selection!r}')

    # Find the correct file and set axis labels
    analyzerDict = ana_outdir_adv_dict if tagger == 'adversarial_eps0p01' else ana_outdir_basic_dict
    if selection == 'DY_m':
        indexjet = '0'
    else:
        indexjet = 'muJet_idx'

    if selection == 'Wc_m':
        yText = 'Jet yield, OS-SS subtracted'
    else:
        yText = 'Jet yield'

    if tagger == 'nominal':
        xText = 'DeepJet (Nominal Training) ' + disc
        normalization_factor_dict = nominal_normalization_dict
    else:
        xText = 'DeepJet (Adversarial Training) ' + disc
        normalization_factor_dict = adversarial_normalization_dict

    region_text = r'$\bf{DY + jet}$' if selection == 'DY_m' else (r'$\bf{t\bar{t}}$' if selection == 'TT_semim' else r'$\bf{W + c}$')

    n_cols_legend = 3 if mergeLepWithLight else 4

    histo = uproot.open(analyzerDict[epoch]+f'/output_2017_PFNano_central/{selection}_jet_Custom{epoch_letter_dict[epoch]}{disc}_{indexjet}_.root')

    # Print info for cross-checks
    print(histo.keys())

    # Hard-coded ! Should match Stacker settings.
    bins = np.linspace(-0.2,1., 31)
    print(bins)

    # Colours depending on sample, created with the help of coolors.co
    colours = {'W + b jets' : '#C4FFFF', 'W + c jets' : '#00FFFF', 'W + udsg jets' : '#007474', 'W + lep' : '#0D98BA',
               'DY + b jets' : '#FCFCAC', 'DY + c jets' : '#FFFF00', 'DY + udsg jets' : '#5F7A33', 'DY + lep' : '#939E7A',
               r'$t\bar{t}$ + b jets' : '#FFBCD9', r'$t\bar{t}$ + c jets' : '#FC0FC0', r'$t\bar{t}$ + udsg jets' : '#8B008B', r'$t\bar{t}$ + lep' : '#86608E',
               'ST + b jets' : '#73C2FB', 'ST + c jets' : '#0000FF', 'ST + udsg jets' : '#002366', 'ST + lep' : '#126180',
               'Bottom' : '#BB0A21', 'Charm' : '#FF9505', 'udsg' : '#4C2882', 'lep' : '#252627'}

    # Load individual histograms depending on selection and merge as requested
    DYJets_b = histo['DYJets_b'].values()
    DYJets_c = histo['DYJets_c'].values()
    DYJets_uds = histo['DYJets_uds'].values()
    DYJets_lep = histo['DYJets_lep'].values()
    if mergeLepWithLight:
        DYJets_uds += DYJets_lep
        del DYJets_lep

    if selection != 'DY_m':
        WJets_b = histo['WJets_b'].values()
        WJets_c = histo['WJets_c'].values()
        WJets_uds = histo['WJets_uds'].values()
        WJets_lep = histo['WJets_lep'].values()
        if mergeLepWithLight:
            WJets_uds += WJets_lep
            del WJets_lep

        ttbar_b = histo['ttbar_b'].values()
        ttbar_c = histo['ttbar_c'].values()
        ttbar_uds = histo['ttbar_uds'].values()
        ttbar_lep = histo['ttbar_lep'].values()
        if mergeLepWithLight:
            ttbar_uds += ttbar_lep
            del ttbar_lep

        ST_b = histo['ST_b'].values()
        ST_c = histo['ST_c'].values()
        ST_uds = histo['ST_uds'].values()
        ST_lep = histo['ST_lep'].values()
        if mergeLepWithLight:
            ST_uds += ST_lep
            del ST_lep

        # Sum the flavours not targeted by the selection over all processes.
        if mergeNonSelectedFlav and selection == 'TT_semim':
            charm = DYJets_c + WJets_c + ttbar_c + ST_c
            light = DYJets_uds + WJets_uds + ttbar_uds + ST_uds
            if not mergeLepWithLight:
                lep = DYJets_lep + WJets_lep + ttbar_lep + ST_lep
        elif mergeNonSelectedFlav and selection == 'Wc_m':
            bottom = DYJets_b + WJets_b + ttbar_b + ST_b
            light = DYJets_uds + WJets_uds + ttbar_uds + ST_uds
            if not mergeLepWithLight:
                lep = DYJets_lep + WJets_lep + ttbar_lep + ST_lep

        if not mergeNonSelectedFlav and not mergeLepWithLight:
            in_stack_legend = ['W + b jets', 'W + c jets', 'W + udsg jets', 'W + lep',
                               'DY + b jets', 'DY + c jets', 'DY + udsg jets', 'DY + lep',
                               r'$t\bar{t}$ + b jets', r'$t\bar{t}$ + c jets', r'$t\bar{t}$ + udsg jets', r'$t\bar{t}$ + lep',
                               'ST + b jets', 'ST + c jets', 'ST + udsg jets', 'ST + lep']
            in_stack_histos = [WJets_b, WJets_c, WJets_uds, WJets_lep,
                               DYJets_b, DYJets_c, DYJets_uds, DYJets_lep,
                               ttbar_b, ttbar_c, ttbar_uds, ttbar_lep,
                               ST_b, ST_c, ST_uds, ST_lep]
        elif mergeNonSelectedFlav and not mergeLepWithLight:
            if selection == 'TT_semim':
                in_stack_legend = ['W + b jets',
                                   'DY + b jets',
                                   r'$t\bar{t}$ + b jets',
                                   'ST + b jets',
                                   'Charm',
                                   'udsg',
                                   'lep']
                in_stack_histos = [WJets_b,
                                   DYJets_b,
                                   ttbar_b,
                                   ST_b,
                                   charm,
                                   light,
                                   lep]
            elif selection == 'Wc_m':
                in_stack_legend = ['W + c jets',
                                   'DY + c jets',
                                   r'$t\bar{t}$ + c jets',
                                   'ST + c jets',
                                   'Bottom',
                                   'udsg',
                                   'lep']
                in_stack_histos = [WJets_c,
                                   DYJets_c,
                                   ttbar_c,
                                   ST_c,
                                   bottom,
                                   light,
                                   lep]
        elif not mergeNonSelectedFlav and mergeLepWithLight:
            in_stack_legend = ['W + b jets', 'W + c jets', 'W + udsg jets',
                               'DY + b jets', 'DY + c jets', 'DY + udsg jets',
                               r'$t\bar{t}$ + b jets', r'$t\bar{t}$ + c jets', r'$t\bar{t}$ + udsg jets',
                               'ST + b jets', 'ST + c jets', 'ST + udsg jets']
            in_stack_histos = [WJets_b, WJets_c, WJets_uds,
                               DYJets_b, DYJets_c, DYJets_uds,
                               ttbar_b, ttbar_c, ttbar_uds,
                               ST_b, ST_c, ST_uds]

        else: # merge lep with udsg and also merge non-selected categories
            if selection == 'TT_semim':
                in_stack_legend = ['W + b jets',
                                   'DY + b jets',
                                   r'$t\bar{t}$ + b jets',
                                   'ST + b jets',
                                   'Charm',
                                   'udsg',]
                in_stack_histos = [WJets_b,
                                   DYJets_b,
                                   ttbar_b,
                                   ST_b,
                                   charm,
                                   light]
            elif selection == 'Wc_m':
                in_stack_legend = ['W + c jets',
                                   'DY + c jets',
                                   r'$t\bar{t}$ + c jets',
                                   'ST + c jets',
                                   'Bottom',
                                   'udsg']
                in_stack_histos = [WJets_c,
                                   DYJets_c,
                                   ttbar_c,
                                   ST_c,
                                   bottom,
                                   light]

    else: # DY_m case
        if mergeLepWithLight:
            in_stack_legend = ['DY + b jets', 'DY + c jets', 'DY + udsg jets']
            in_stack_histos = [DYJets_b, DYJets_c, DYJets_uds]
        else:
            in_stack_legend = ['DY + b jets', 'DY + c jets', 'DY + udsg jets', 'DY + lep']
            in_stack_histos = [DYJets_b, DYJets_c, DYJets_uds, DYJets_lep]

    # Stack in reversed order (last listed sample is drawn first).
    in_stack_legend.reverse()
    in_stack_histos.reverse()

    # Total MC and data, with the errors stored in the file
    mcsum = histo['MCSum'].values()
    mcsum_err = histo['MCSum'].errors()
    data = histo['Data'].values()
    data_err = histo['Data'].errors()

    # One colour per stacked sample, plus black for the data points.
    # NOTE: both rcParams assignments below change global matplotlib state.
    colors = [colours[name] for name in in_stack_legend] + ['#000000']
    plt.rcParams['axes.axisbelow'] = True
    mpl.rcParams["axes.prop_cycle"] = cycler('color', colors)

    # (thanks to this great source: https://github.com/nsmith-/mpl-hep/blob/master/binder/gallery.ipynb)
    fig, ((ax1, ax2)) = plt.subplots(2, 1, figsize=(12, 12), gridspec_kw={"height_ratios": (3, 1), 'hspace': 0.0}, sharex=True)
    fig.tight_layout()
    hep.cms.label("Preliminary", data=True, lumi=41.5, year=2017,ax=ax1)
    #fig.subplots_adjust(hspace = 0.3)
    hep.histplot(in_stack_histos,
                 bins,
                 label=in_stack_legend, stack=True, histtype='fill',
                 yerr = True, xerr = True, ax=ax1, # **hatch_style
                 #fill_opts=stack_fill_opts,
                 #error_opts=stack_error_opts
                 )
    hep.histplot(data, bins, label='Data', histtype='errorbar', ax=ax1, xerr=True, yerr=np.sqrt(data))
    # Preliminary legend; it is replaced by the re-ordered one further below.
    # NOTE(review): presumably needed so that hep.mpl_magic can take the legend into account - confirm.
    ax1.legend(loc="upper right",ncol=n_cols_legend,labels=in_stack_legend+['Data'],fontsize=18)
    ax1.set_ylabel(yText)
    ax1.set_xlabel('')
    # Raw string: '\i' is not a valid Python escape sequence.
    at = AnchoredText(r'$\mu$ channel'+"\n"+region_text+"\n"+r'$\it{Pre-Calibration}$',
                      loc='upper left',frameon=False, prop=dict(size=16))
    ax1.add_artist(at)

    # Overlay an uncertainty hatch
    sumw = np.sum([i_histo for i_histo in in_stack_histos], axis=0)
    # NOTE(review): this squares the bin contents of each sample, not per-event weights - confirm this is intended.
    sumw2 = np.sum([i_histo**2 for i_histo in in_stack_histos], axis=0)
    if unc_variant == 'sqrt_of_sumw':
        scale_unc_by_norm_factor = (1 / normalization_factor_dict[selection])
        unc = np.sqrt(sumw) * scale_unc_by_norm_factor # should be equal to sqrt(mcsum) out of Stacker tool
    elif unc_variant == 'sqrt_of_sumw2':
        scale_unc_by_norm_factor = (1 / normalization_factor_dict[selection])
        unc = np.sqrt(sumw2) * scale_unc_by_norm_factor
    elif unc_variant == 'sum_of_sqrt':
        unc = np.sum([np.sqrt(i_histo) for i_histo in in_stack_histos], axis=0)
    elif unc_variant == 'as_in_file':
        unc = mcsum_err
    # Repeat the last bin so that step='post' also covers the last bin up to its right edge.
    unc = np.hstack([unc, unc[-1]])
    sumw_total = np.hstack([mcsum, mcsum[-1]])
    print(sumw2)
    print(unc)
    print(mcsum)
    print(data/mcsum)
    num_err = np.sqrt(data)
    denom_err = np.sqrt(mcsum)

    #denom_err_for_stack = np.hstack([denom_err, denom_err[-1]])

    ax1.fill_between(x=bins, y1=sumw_total - unc, y2=sumw_total + unc,
                    label='MC stat. unc.', step='post', **hatch_style
                   )
    # ========================================== Setup for ratio plot ==========================================
    # this is closer to what is currently done for commissioning,
    # where both sources are kept separate
    if ratio_err_style == 'commissioning':
        # Data uncertainty only; num_err already is sqrt(data), so no further square root here.
        ratio_err = num_err / mcsum
        ax2.fill_between(x=bins, y1=1 - unc/sumw_total, y2=1 + unc/sumw_total,
                        label='MC stat. unc.', step='post', **shaded_style
                   )
    # this merges the two uncertainty sources (data & mc) together,
    # like it was done for the aisafety paper as well as in Spandan's stacked histograms
    elif ratio_err_style == 'aisafetypaper':
        ratio_err = np.sqrt((num_err/mcsum)**2+(data/(mcsum**2)*denom_err)**2)
    elif ratio_err_style == 'as_in_file':
        ratio_err = np.sqrt((data_err/mcsum)**2 + (data/(mcsum**2) * mcsum_err)**2)
    # ----------------------------------------------------------------------------------------------------------

    # 0.02 is half the bin width of the hard-coded binning above.
    ax2.errorbar(bins[:-1] + 0.02,data/mcsum,xerr=0.02*np.ones(len(bins)-1),yerr=ratio_err,fmt='o',color='#000000')
    ax2.plot([-0.2,1.],[1,1],color='red')
    ax2.set_ylim(0.55,1.45)
    ax2.set_xlim(-0.2,1)
    ax2.set_ylabel('Data/MC', loc='center')
    ax2.set_xlabel(xText)
    ax2.set_yticks([0.6,0.7,0.8,0.9,1.0,1.1,1.2,1.3,1.4])
    ax1.plot([])
    hep.mpl_magic(ax = ax1)
    ax2.grid(which='minor', axis='y', alpha=0.85)
    ax2.grid(which='major', axis='y', alpha=0.95, color='black')
    #ax2.grid(which='major', axis='x', alpha=0.85)

    # Final legend, with handles and labels re-ordered via flip().
    handles, labels = ax1.get_legend_handles_labels()
    ax1.legend(flip(handles, n_cols_legend), flip(labels, n_cols_legend), loc="upper right",ncol=n_cols_legend,fontsize=18, handletextpad=0.5, columnspacing=1.0)
    #hep.sort_legend(ax1, in_stack_legend+['Data'])
    mergeNonSelectedFlavText = '' if not mergeNonSelectedFlav else 'mergedFlav_'
    mergeLepWithLightText = '' if not mergeLepWithLight else 'mergedLepUDSG_'
    fig.savefig(f'mplHisto_{tagger}_{selection}_{disc}_{mergeNonSelectedFlavText}{mergeLepWithLightText}{epoch}_{ratio_err_style}_{unc_variant}.pdf', bbox_inches='tight')

In [None]:
# Nominal training, 'DY_m' selection, CvsB discriminator (all other options at their defaults).
plotStack('nominal', 'DY_m', 'CvsB')

In [None]:
# Adversarial training, 'DY_m' selection, CvsB discriminator (all other options at their defaults).
plotStack('adversarial_eps0p01', 'DY_m', 'CvsB')

In [None]:
# Adversarial training, 'TT_semim' selection, CvsB discriminator (all other options at their defaults).
plotStack('adversarial_eps0p01', 'TT_semim', 'CvsB')

In [None]:
# Adversarial training, 'Wc_m' selection, CvsB discriminator (all other options at their defaults).
plotStack('adversarial_eps0p01', 'Wc_m', 'CvsB')

### Calculate KL divergence, JS divergence and $\chi^2$ (data vs. MC) for adversarial training

In [None]:
# Data-vs-MC KL divergence, JS divergence and chi2 for the adversarial training.
# Each per-selection list receives one entry per epoch; every entry holds one value
# per discriminator, in the order 'BvsL', 'CvsB', 'CvsL'.
# (The disabled epoch-to-epoch data/data and MC/MC comparison was removed from this cell.)
epochs = ['best']  # add the ones for which StackPlotter results exist
wms = ['DY_m', 'TT_semim', 'Wc_m']
# DY sel
kl_divs_all_adv_DY = []
js_divs_all_adv_DY = []
chi2_all_adv_DY = []
# Wc sel
kl_divs_all_adv_Wc = []
js_divs_all_adv_Wc = []
chi2_all_adv_Wc = []
# TT sel
kl_divs_all_adv_TT = []
js_divs_all_adv_TT = []
chi2_all_adv_TT = []

# Look-up of the result lists by selection name (the same list objects as above).
kl_holder_adv = {'DY_m' : kl_divs_all_adv_DY, 'Wc_m' : kl_divs_all_adv_Wc, 'TT_semim' : kl_divs_all_adv_TT}
js_holder_adv = {'DY_m' : js_divs_all_adv_DY, 'Wc_m' : js_divs_all_adv_Wc, 'TT_semim' : js_divs_all_adv_TT}
chi2_holder_adv = {'DY_m' : chi2_all_adv_DY, 'Wc_m' : chi2_all_adv_Wc, 'TT_semim' : chi2_all_adv_TT}

for wm in wms:
    indexjet = '0' if wm == 'DY_m' else 'muJet_idx'
    for e in epochs:
        kl_divs_adv = []
        js_divs_adv = []
        chi2_adv = []
        for disc in ['BvsL', 'CvsB', 'CvsL']:  # select those that exist for paths listed below
            # Close the ROOT file again once its histograms have been read.
            with uproot.open(ana_outdir_adv_dict[e]+f'/output_2017_PFNano_central/{wm}_jet_Custom{epoch_letter_dict[e]}{disc}_{indexjet}_.root') as histo:
                # data to MC
                num, denom = histo['Data'].values(), histo['MCSum'].values()
                # Avoid empty MC bins where data is present, and zero out bins with negative content.
                denom[(denom == 0) & (num != 0)] = 0.00001
                where_num_or_denom_negative = (denom < 0) | (num < 0)
                num[where_num_or_denom_negative] = 0
                denom[where_num_or_denom_negative] = 0
                this_kl_a = entropy(num, qk=denom)
                kl_divs_adv.append(this_kl_a)

                # NOTE(review): num/denom were edited in place above; whether this second read
                # returns untouched bin contents depends on uproot handing out a fresh array - confirm.
                num, denom = histo['Data'].values(), histo['MCSum'].values()
                this_js = jsdiv(num, denom, normalize_first=True)
                this_chi2 = chi2(num, denom, normalize_first=True)

                js_divs_adv.append(this_js)
                chi2_adv.append(this_chi2)

        kl_holder_adv[wm].append(kl_divs_adv)
        js_holder_adv[wm].append(js_divs_adv)
        chi2_holder_adv[wm].append(chi2_adv)

    # Print the current state of all result lists after each selection.
    print(kl_divs_all_adv_DY)
    print(js_divs_all_adv_DY)
    print(chi2_all_adv_DY)
    print(kl_divs_all_adv_Wc)
    print(js_divs_all_adv_Wc)
    print(chi2_all_adv_Wc)
    print(kl_divs_all_adv_TT)
    print(js_divs_all_adv_TT)
    print(chi2_all_adv_TT)
    print()

In [None]:
# Data-vs-MC KL divergence, JS divergence and chi2 for the nominal training.
# Each per-selection list receives one entry per epoch; every entry holds one value
# per discriminator, in the order 'BvsL', 'CvsB', 'CvsL'.
# (The disabled epoch-to-epoch data/data and MC/MC comparison was removed from this cell.)
epochs = ['best']  # add the ones for which StackPlotter results exist
wms = ['DY_m', 'TT_semim', 'Wc_m']
# DY sel
kl_divs_all_basic_DY = []
js_divs_all_basic_DY = []
chi2_all_basic_DY = []
# Wc sel
kl_divs_all_basic_Wc = []
js_divs_all_basic_Wc = []
chi2_all_basic_Wc = []
# TT sel
kl_divs_all_basic_TT = []
js_divs_all_basic_TT = []
chi2_all_basic_TT = []

# Look-up of the result lists by selection name (the same list objects as above).
kl_holder_basic = {'DY_m' : kl_divs_all_basic_DY, 'Wc_m' : kl_divs_all_basic_Wc, 'TT_semim' : kl_divs_all_basic_TT}
js_holder_basic = {'DY_m' : js_divs_all_basic_DY, 'Wc_m' : js_divs_all_basic_Wc, 'TT_semim' : js_divs_all_basic_TT}
chi2_holder_basic = {'DY_m' : chi2_all_basic_DY, 'Wc_m' : chi2_all_basic_Wc, 'TT_semim' : chi2_all_basic_TT}

for wm in wms:
    indexjet = '0' if wm == 'DY_m' else 'muJet_idx'
    for e in epochs:
        kl_divs_basic = []
        js_divs_basic = []
        chi2_basic = []
        for disc in ['BvsL', 'CvsB', 'CvsL']:  # select those that exist for paths listed below
            # Close the ROOT file again once its histograms have been read.
            with uproot.open(ana_outdir_basic_dict[e]+f'/output_2017_PFNano_central/{wm}_jet_Custom{epoch_letter_dict[e]}{disc}_{indexjet}_.root') as histo:
                # data to MC
                num, denom = histo['Data'].values(), histo['MCSum'].values()
                # Avoid empty MC bins where data is present, and zero out bins with negative content.
                denom[(denom == 0) & (num != 0)] = 0.00001
                where_num_or_denom_negative = (denom < 0) | (num < 0)
                num[where_num_or_denom_negative] = 0
                denom[where_num_or_denom_negative] = 0
                this_kl_a = entropy(num, qk=denom)
                kl_divs_basic.append(this_kl_a)

                # NOTE(review): num/denom were edited in place above; whether this second read
                # returns untouched bin contents depends on uproot handing out a fresh array - confirm.
                num, denom = histo['Data'].values(), histo['MCSum'].values()
                this_js = jsdiv(num, denom, normalize_first=True)
                this_chi2 = chi2(num, denom, normalize_first=True)

                js_divs_basic.append(this_js)
                chi2_basic.append(this_chi2)

        kl_holder_basic[wm].append(kl_divs_basic)
        js_holder_basic[wm].append(js_divs_basic)
        chi2_holder_basic[wm].append(chi2_basic)

    # Print the current state of all result lists after each selection.
    print(kl_divs_all_basic_DY)
    print(js_divs_all_basic_DY)
    print(chi2_all_basic_DY)
    print(kl_divs_all_basic_Wc)
    print(js_divs_all_basic_Wc)
    print(chi2_all_basic_Wc)
    print(kl_divs_all_basic_TT)
    print(js_divs_all_basic_TT)
    print(chi2_all_basic_TT)
    print()

In [None]:
#xses = ['BvsL', 'BvsC', 'CvsB', 'CvsL']

In [None]:
# x-axis categories of the summary plots; same discriminators, in the same order,
# as looped over when the divergences were computed.
xses = ['BvsL', 'CvsB', 'CvsL']

In [None]:
# JS divergence per discriminator, one panel per selection, nominal vs. adversarial.
fig, js_axes = plt.subplots(1, 3, figsize=(24, 8))
ax1, ax2, ax3 = js_axes
fig.subplots_adjust(wspace=0.3)
#fig.tight_layout()
js_panels = (
    (ax1, 'light jets', js_divs_all_basic_DY, js_divs_all_adv_DY),
    (ax2, 'c jets', js_divs_all_basic_Wc, js_divs_all_adv_Wc),
    (ax3, 'b jets', js_divs_all_basic_TT, js_divs_all_adv_TT),
)
for panel_ax, panel_flavour, panel_nominal, panel_adversarial in js_panels:
    panel_ax.scatter(xses, panel_nominal, label=f'{panel_flavour} nominal', s=100)
    panel_ax.scatter(xses, panel_adversarial, label=f'{panel_flavour} adversarial', s=100)
    panel_ax.legend(frameon=True)
    panel_ax.set_ylabel('JS divergence (a.u.)')
# No axes passed: the label is placed on the current axes.
hep.cms.label('Preliminary', data=True, lumi=41.54, year=2017)###fig.suptitle('JS divergence. Best epoch.', y=0.95)
#fig.savefig('js_best_epoch_custom_DeepJet_forRTG.pdf')

In [None]:
# Global setting: draw axis grid lines and ticks below the plotted artists.
plt.rcParams['axes.axisbelow'] = True

In [None]:
#fig, (ax1,ax2,ax3) = plt.subplots(1,3, figsize=(24,8))
# JS divergence of the DY selection (labelled 'light jets'), nominal vs. adversarial training.
fig, ax1 = plt.subplots(1, figsize=(9, 8))
fig.tight_layout()
for grid_options in (dict(which='minor', alpha=0.85),
                     dict(which='major', alpha=0.95, color='black')):
    ax1.grid(**grid_options)
for series_label, series_values in (('Nominal', js_divs_all_basic_DY),
                                    ('Adversarial', js_divs_all_adv_DY)):
    ax1.scatter(xses, series_values, label=series_label, s=200)
# Opaque legend without frame line, title aligned to the left.
leg = ax1.legend(frameon=True, framealpha=1, title='light jets, DeepJet')
leg._legend_box.align = "left"
leg.get_frame().set_linewidth(0.0)
ax1.set_ylabel('JS divergence (a.u.)')
hep.cms.label('Preliminary', data=True, lumi=41.5, year=2017)###fig.suptitle('JS divergence. Best epoch.', y=0.95)
#fig.savefig('js_best_epoch_custom_DeepJet_LIGHT_v2.pdf', bbox_inches='tight')
fig.savefig('js_best_epoch_custom_DeepJet_LIGHT_v2_30bins_v4.pdf', bbox_inches='tight')

In [None]:
# Single-panel scatter plot: chi^2 for light jets (DeepJet),
# nominal vs. adversarial, saved as a PDF.
# `xses` and the chi2_all_* sequences are defined earlier in the notebook.
fig, ax1 = plt.subplots(nrows=1, figsize=(9, 8))
fig.tight_layout()  # applied right after creation, before anything is drawn

# Minor grid first, then the (darker) major grid.
for _which, _grid_opts in (
    ('minor', {'alpha': 0.85}),
    ('major', {'alpha': 0.95, 'color': 'black'}),
):
    ax1.grid(which=_which, **_grid_opts)

# Plot order fixes which colour of the cycle each series receives.
for _label, _values in (
    ('Nominal', chi2_all_basic_DY),
    ('Adversarial', chi2_all_adv_DY),
):
    ax1.scatter(xses, _values, label=_label, s=200)

# Opaque legend frame without border line, content left-aligned.
# NOTE(review): `_legend_box` is a private matplotlib attribute; newer
# matplotlib releases appear to offer a public `alignment` option - confirm
# before changing.
leg = ax1.legend(frameon=True, framealpha=1, title='light jets, DeepJet')
leg._legend_box.align = "left"
leg.get_frame().set_linewidth(0.0)

ax1.set_ylabel(r'$\chi^2$ (a.u.)')
hep.cms.label('Preliminary', data=True, lumi=41.5, year=2017)
fig.savefig(
    'chi2_best_epoch_custom_DeepJet_LIGHT_v2_30bins_v4.pdf',
    bbox_inches='tight',
)

In [None]:
# Single-panel scatter plot: JS divergence for c jets (DeepJet),
# nominal vs. adversarial, saved as a PDF.
# `xses` and the js_divs_all_* sequences are defined earlier in the notebook.
fig, ax2 = plt.subplots(nrows=1, figsize=(9, 8))
fig.tight_layout()  # applied right after creation, before anything is drawn

# Minor grid first, then the (darker) major grid.
for _which, _grid_opts in (
    ('minor', {'alpha': 0.85}),
    ('major', {'alpha': 0.95, 'color': 'black'}),
):
    ax2.grid(which=_which, **_grid_opts)

# Plot order fixes which colour of the cycle each series receives.
for _label, _values in (
    ('Nominal', js_divs_all_basic_Wc),
    ('Adversarial', js_divs_all_adv_Wc),
):
    ax2.scatter(xses, _values, label=_label, s=200)

# Legend pinned to the upper-left corner; opaque frame without border line,
# content left-aligned.
# NOTE(review): `_legend_box` is a private matplotlib attribute; newer
# matplotlib releases appear to offer a public `alignment` option - confirm
# before changing.
leg = ax2.legend(
    frameon=True,
    framealpha=1,
    title='c jets, DeepJet',
    loc="upper left",
)
leg._legend_box.align = "left"
leg.get_frame().set_linewidth(0.0)

ax2.set_ylabel('JS divergence (a.u.)')
hep.cms.label('Preliminary', data=True, lumi=41.5, year=2017)
fig.savefig(
    'js_best_epoch_custom_DeepJet_CHARM_v2_30bins_v4.pdf',
    bbox_inches='tight',
)

In [None]:
# Single-panel scatter plot: chi^2 for c jets (DeepJet),
# nominal vs. adversarial, saved as a PDF.
# `xses` and the chi2_all_* sequences are defined earlier in the notebook.
fig, ax2 = plt.subplots(nrows=1, figsize=(9, 8))
fig.tight_layout()  # applied right after creation, before anything is drawn

# Minor grid first, then the (darker) major grid.
for _which, _grid_opts in (
    ('minor', {'alpha': 0.85}),
    ('major', {'alpha': 0.95, 'color': 'black'}),
):
    ax2.grid(which=_which, **_grid_opts)

# Plot order fixes which colour of the cycle each series receives.
for _label, _values in (
    ('Nominal', chi2_all_basic_Wc),
    ('Adversarial', chi2_all_adv_Wc),
):
    ax2.scatter(xses, _values, label=_label, s=200)

# Legend pinned to the upper-right corner; opaque frame without border line,
# content right-aligned.
# NOTE(review): `_legend_box` is a private matplotlib attribute; newer
# matplotlib releases appear to offer a public `alignment` option - confirm
# before changing.
leg = ax2.legend(
    frameon=True,
    framealpha=1,
    title='c jets, DeepJet',
    loc="upper right",
)
leg._legend_box.align = "right"
leg.get_frame().set_linewidth(0.0)

ax2.set_ylabel(r'$\chi^2$ (a.u.)')
hep.cms.label('Preliminary', data=True, lumi=41.5, year=2017)
fig.savefig(
    'chi2_best_epoch_custom_DeepJet_CHARM_v2_30bins_v4.pdf',
    bbox_inches='tight',
)

In [None]:
# Single-panel scatter plot: JS divergence for b jets (DeepJet),
# nominal vs. adversarial, saved as a PDF.
# `xses` and the js_divs_all_* sequences are defined earlier in the notebook.
fig, ax3 = plt.subplots(nrows=1, figsize=(9, 8))
fig.tight_layout()  # applied right after creation, before anything is drawn

# Minor grid first, then the (darker) major grid.
for _which, _grid_opts in (
    ('minor', {'alpha': 0.85}),
    ('major', {'alpha': 0.95, 'color': 'black'}),
):
    ax3.grid(which=_which, **_grid_opts)

# Plot order fixes which colour of the cycle each series receives.
for _label, _values in (
    ('Nominal', js_divs_all_basic_TT),
    ('Adversarial', js_divs_all_adv_TT),
):
    ax3.scatter(xses, _values, label=_label, s=200)

# Opaque legend frame without border line, content left-aligned.
# NOTE(review): `_legend_box` is a private matplotlib attribute; newer
# matplotlib releases appear to offer a public `alignment` option - confirm
# before changing.
leg = ax3.legend(frameon=True, framealpha=1, title='b jets, DeepJet')
leg._legend_box.align = "left"
leg.get_frame().set_linewidth(0.0)

ax3.set_ylabel('JS divergence (a.u.)')
hep.cms.label('Preliminary', data=True, lumi=41.5, year=2017)
fig.savefig(
    'js_best_epoch_custom_DeepJet_BOTTOM_v2_30bins_v4.pdf',
    bbox_inches='tight',
)

In [None]:
# Single-panel scatter plot: chi^2 for b jets (DeepJet),
# nominal vs. adversarial, saved as a PDF.
# `xses` and the chi2_all_* sequences are defined earlier in the notebook.
fig, ax3 = plt.subplots(nrows=1, figsize=(9, 8))
fig.tight_layout()  # applied right after creation, before anything is drawn

# Minor grid first, then the (darker) major grid.
for _which, _grid_opts in (
    ('minor', {'alpha': 0.85}),
    ('major', {'alpha': 0.95, 'color': 'black'}),
):
    ax3.grid(which=_which, **_grid_opts)

# Plot order fixes which colour of the cycle each series receives.
for _label, _values in (
    ('Nominal', chi2_all_basic_TT),
    ('Adversarial', chi2_all_adv_TT),
):
    ax3.scatter(xses, _values, label=_label, s=200)

# Opaque legend frame without border line, content left-aligned.
# NOTE(review): `_legend_box` is a private matplotlib attribute; newer
# matplotlib releases appear to offer a public `alignment` option - confirm
# before changing.
leg = ax3.legend(frameon=True, framealpha=1, title='b jets, DeepJet')
leg._legend_box.align = "left"
leg.get_frame().set_linewidth(0.0)

ax3.set_ylabel(r'$\chi^2$ (a.u.)')
hep.cms.label('Preliminary', data=True, lumi=41.5, year=2017)
fig.savefig(
    'chi2_best_epoch_custom_DeepJet_BOTTOM_v2_30bins_v4.pdf',
    bbox_inches='tight',
)

In [None]:
# Three-panel overview of the JS divergence: one panel each for the DY, Wc and
# TT series, nominal vs. adversarial. Saved as a PDF.
# `xses` and the js_divs_all_* sequences are defined earlier in the notebook.
fig, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=3, figsize=(24, 10))
fig.subplots_adjust(wspace=0.3)

# (axes, ((label, values), ...)) per panel; within a panel the nominal series
# is drawn first so it receives the first colour of the cycle.
_panels = (
    (ax1, (('DY nominal', js_divs_all_basic_DY),
           ('DY adversarial', js_divs_all_adv_DY))),
    (ax2, (('Wc nominal', js_divs_all_basic_Wc),
           ('Wc adversarial', js_divs_all_adv_Wc))),
    (ax3, (('TT nominal', js_divs_all_basic_TT),
           ('TT adversarial', js_divs_all_adv_TT))),
)
for _ax, _series in _panels:
    for _label, _values in _series:
        _ax.scatter(xses, _values, label=_label)

# Legend and y-axis label are the same on every panel.
for _ax in (ax1, ax2, ax3):
    _ax.legend()
    _ax.set_ylabel('JS divergence (a.u.)')

fig.suptitle('JS divergence. Best epoch.', y=0.95)
fig.savefig('js_best_epoch_custom_DeepJet_30bins_v4.pdf')

In [None]:
# Three-panel overview of the KL divergence: one panel each for the DY, Wc and
# TT series, nominal vs. adversarial. Saved as a PDF.
# `xses` and the kl_divs_all_* sequences are defined earlier in the notebook.
fig, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=3, figsize=(24, 10))
fig.subplots_adjust(wspace=0.3)

# (axes, ((label, values), ...)) per panel; within a panel the nominal series
# is drawn first so it receives the first colour of the cycle.
_panels = (
    (ax1, (('DY nominal', kl_divs_all_basic_DY),
           ('DY adversarial', kl_divs_all_adv_DY))),
    (ax2, (('Wc nominal', kl_divs_all_basic_Wc),
           ('Wc adversarial', kl_divs_all_adv_Wc))),
    (ax3, (('TT nominal', kl_divs_all_basic_TT),
           ('TT adversarial', kl_divs_all_adv_TT))),
)
for _ax, _series in _panels:
    for _label, _values in _series:
        _ax.scatter(xses, _values, label=_label)

# Legend and y-axis label are the same on every panel.
for _ax in (ax1, ax2, ax3):
    _ax.legend()
    _ax.set_ylabel('KL divergence (a.u.)')

fig.suptitle('KL divergence. Best epoch.', y=0.95)
fig.savefig('kl_best_epoch_custom_DeepJet_30bins_v4.pdf')

$\chi^2$

In [None]:
# Three-panel overview of chi^2: one panel each for the DY, Wc and TT series,
# nominal vs. adversarial. Saved as a PDF.
# `xses` and the chi2_all_* sequences are defined earlier in the notebook.
fig, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=3, figsize=(24, 10))
fig.subplots_adjust(wspace=0.3)

# (axes, ((label, values), ...)) per panel; within a panel the nominal series
# is drawn first so it receives the first colour of the cycle.
_panels = (
    (ax1, (('DY nominal', chi2_all_basic_DY),
           ('DY adversarial', chi2_all_adv_DY))),
    (ax2, (('Wc nominal', chi2_all_basic_Wc),
           ('Wc adversarial', chi2_all_adv_Wc))),
    (ax3, (('TT nominal', chi2_all_basic_TT),
           ('TT adversarial', chi2_all_adv_TT))),
)
for _ax, _series in _panels:
    for _label, _values in _series:
        _ax.scatter(xses, _values, label=_label)

# Legend and y-axis label are the same on every panel.
for _ax in (ax1, ax2, ax3):
    _ax.legend()
    _ax.set_ylabel(r'$\chi^2$ (a.u.)')

fig.suptitle(r'$\mathcal{\chi^2}$. Best epoch.', y=0.95)
fig.savefig('chi2_best_epoch_custom_DeepJet_30bins_v4.pdf')

In [None]:
#!tar czf JS_divs_2022_09_15.tar.gz JS_divergences/*.svg